diff --git a/Dockerfile.debian b/Dockerfile.debian
index 70714f624..8513390df 100644
--- a/Dockerfile.debian
+++ b/Dockerfile.debian
@@ -1,6 +1,23 @@
+FROM rust:1.70.0-bullseye AS rust_build
+
+WORKDIR /
+
+RUN apt-get update && \
+    apt-get install -y -q \
+        build-essential \
+        curl \
+        git \
+        make
+
+RUN git clone https://github.com/daulet/tokenizers.git /tokenizer && \
+    cd /tokenizer && \
+    cargo build --release && \
+    cp target/release/libtokenizers.a /tokenizer/libtokenizers.a
+
 FROM docker.io/golang:1.24-bookworm AS build
 
 ARG BUILD_VERSION
+ARG ONNXRUNTIME_VERSION=1.18.1
 
 WORKDIR /go/src/crowdsec
 
@@ -24,7 +41,36 @@ RUN apt-get update && \
 
 COPY . .
 
-RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
+COPY --from=rust_build /tokenizer/libtokenizers.a /usr/local/lib/
+
+# INSTALL ONNXRUNTIME
+RUN cd /tmp && \
+    wget -O onnxruntime.tgz https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION}.tgz && \
+    tar -C /tmp -xvf onnxruntime.tgz && \
+    mv onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION} onnxruntime && \
+    rm -rf onnxruntime.tgz && \
+    cp -R onnxruntime/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib && \
+    cp onnxruntime/include/*.h /usr/local/include && \
+    rm -rf onnxruntime
+
+RUN ln -s /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib/libonnxruntime.so
+
+RUN ls -la /usr/local/include
+RUN ls -la /usr/local/lib
+
+RUN ldconfig
+
+# Test that linking against onnxruntime works with a simple program
+RUN echo "#include <onnxruntime_c_api.h>" > test.c && \
+    echo "int main() { return 0; }" >> test.c && \
+    gcc test.c -L/usr/local/lib -lonnxruntime -o test_executable && ./test_executable
+
+RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=0 \
+    CGO_CFLAGS="-D_LARGEFILE64_SOURCE -I/usr/local/include" \
+    CGO_CPPFLAGS="-I/usr/local/include" \
+    CGO_LDFLAGS="-L/usr/local/lib -lstdc++ -lonnxruntime /usr/local/lib/libtokenizers.a -ldl -lm" \
+    LIBRARY_PATH="/usr/local/lib" \
+    LD_LIBRARY_PATH="/usr/local/lib" && \
     cd crowdsec-v* && \
     ./wizard.sh --docker-mode && \
     cd - >/dev/null && \
@@ -37,6 +83,8 @@ RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
 
 FROM docker.io/debian:bookworm-slim AS slim
 
+ARG ONNXRUNTIME_VERSION=1.18.1
+
 ENV DEBIAN_FRONTEND=noninteractive
 ENV DEBCONF_NOWARNINGS="yes"
 
@@ -58,6 +106,15 @@ COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/l
 COPY --from=build /etc/crowdsec /staging/etc/crowdsec
 COPY --from=build /go/src/crowdsec/docker/docker_start.sh /
 COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml
+
+# Note: copy the shared libraries, since we can't link them statically yet
+COPY --from=build /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION}
+COPY --from=build /usr/local/lib/libtokenizers.a /usr/lib/libtokenizers.a
+RUN ln -s /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so
+COPY --from=build /usr/local/lib/libre2.* /usr/lib/
+
+RUN ls -la /usr/lib
+
 RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml && \
     yq eval -i ".plugin_config.group = \"nogroup\"" /staging/etc/crowdsec/config.yaml
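Outside the image build, the same link check can be reproduced from Go through cgo. This is not part of the patch, just a minimal sketch assuming the headers were copied to /usr/local/include and the shared object to /usr/local/lib as above:

```go
// smoke_onnx.go — hypothetical cgo smoke test, mirroring the C check above.
package main

/*
#cgo CPPFLAGS: -I/usr/local/include
#cgo LDFLAGS: -L/usr/local/lib -lonnxruntime
#include <onnxruntime_c_api.h>

// cgo cannot call C function pointers directly, so wrap the lookup in a helper.
static const char* ort_version(void) {
	return OrtGetApiBase()->GetVersionString();
}
*/
import "C"

import "fmt"

func main() {
	// If this prints the runtime version, compiling, linking and loading all succeeded.
	fmt.Println("onnxruntime version:", C.GoString(C.ort_version()))
}
```
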
diff --git a/Makefile b/Makefile
index 3a04f174c..699b9d296 100644
--- a/Makefile
+++ b/Makefile
@@ -60,11 +60,9 @@ bool = $(if $(filter $(call lc, $1),1 yes true),1,0)
 #--------------------------------------
 #
-# Define MAKE_FLAGS and LD_OPTS for the sub-makefiles in cmd/
+# Define LD_OPTS for the sub-makefiles in cmd/
 #
-MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)"
-
 LD_OPTS_VARS= \
 -X 'github.com/crowdsecurity/go-cs-lib/version.Version=$(BUILD_VERSION)' \
 -X 'github.com/crowdsecurity/go-cs-lib/version.BuildDate=$(BUILD_TIMESTAMP)' \
@@ -116,6 +114,14 @@ ifneq (,$(RE2_CHECK))
 endif
 endif
 
+#--------------------------------------
+#
+# List of required build-time dependencies
+
+DEPS_DIR := $(CURDIR)/build/deps
+
+DEPS_FILES :=
+
 #--------------------------------------
 #
 # Handle optional components and build profiles, to save space on the final binaries.
@@ -142,7 +148,8 @@ COMPONENTS := \
 	datasource_s3 \
 	datasource_syslog \
 	datasource_wineventlog \
-	cscli_setup
+	cscli_setup \
+	mlsupport
 
 comma := ,
 space := $(empty) $(empty)
@@ -152,6 +159,9 @@ space := $(empty) $(empty)
 # keep only datasource-file
 EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,,$(COMPONENTS)))
 
+# mlsupport requires pre-built static libraries and weighs ~20MB
+EXCLUDE_DEFAULT := mlsupport
+
 # example
 # EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3
 
@@ -160,8 +170,10 @@ BUILD_PROFILE ?= default
 # Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set
 ifeq ($(BUILD_PROFILE),minimal)
 EXCLUDE ?= $(EXCLUDE_MINIMAL)
-else ifneq ($(BUILD_PROFILE),default)
-$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default)
+else ifeq ($(BUILD_PROFILE),default)
+EXCLUDE ?= $(EXCLUDE_DEFAULT)
+else ifneq ($(BUILD_PROFILE),full)
+$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default, full)
 endif
 
 # Create list of excluded components from the EXCLUDE variable
@@ -179,6 +191,24 @@ ifneq ($(COMPONENT_TAGS),)
 GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS))
 endif
 
+ifeq ($(filter mlsupport,$(EXCLUDE_LIST)),)
+    $(info mlsupport is included)
+    # Set additional variables when mlsupport is included
+ifneq ($(call bool,$(BUILD_RE2_WASM)),1)
+    $(error for now, the flag BUILD_RE2_WASM is required for mlsupport)
+endif
+    CGO_CPPFLAGS := -I$(DEPS_DIR)/src/onnxruntime/include/onnxruntime/core/session
+    CGO_LDFLAGS := -L$(DEPS_DIR)/lib -lstdc++ -lonnxruntime -ldl -lm
+    LIBRARY_PATH := $(DEPS_DIR)/lib
+    DEPS_FILES += $(DEPS_DIR)/lib/libtokenizers.a
+    DEPS_FILES += $(DEPS_DIR)/lib/libonnxruntime.a
+    DEPS_FILES += $(DEPS_DIR)/src/onnxruntime
+else
+    CGO_CPPFLAGS :=
+    CGO_LDFLAGS :=
+    LIBRARY_PATH :=
+endif
+
 #--------------------------------------
 
 ifeq ($(call bool,$(BUILD_STATIC)),1)
@@ -208,7 +238,7 @@ endif
 #--------------------------------------
 
 .PHONY: build
-build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins
+build: build-info download-deps crowdsec cscli plugins ## Build crowdsec, cscli and plugins
 
 .PHONY: build-info
 build-info: ## Print build information
@@ -235,6 +265,29 @@ endif
 .PHONY: all
 all: clean test build ## Clean, test and build (requires localstack)
 
+.PHONY: download-deps
+download-deps: $(DEPS_FILES)
+
+$(DEPS_DIR)/lib/libtokenizers.a:
+	curl --fail -L --output $@ --create-dirs \
+		https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libtokenizers.a
+
+$(DEPS_DIR)/lib/libonnxruntime.a:
+	curl --fail -L --output $@ --create-dirs \
+		https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libonnxruntime.a
+
+$(DEPS_DIR)/src/onnxruntime:
+	git clone --depth 1 https://github.com/microsoft/onnxruntime $(DEPS_DIR)/src/onnxruntime -b v1.19.2
+
+# Full list of flags that are passed down to the sub-makefiles in cmd/
+
+MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)" CGO_CPPFLAGS="$(CGO_CPPFLAGS)" LIBRARY_PATH="$(LIBRARY_PATH)"
+
+.PHONY: clean-deps
+clean-deps:
+	@$(RM) -r $(DEPS_DIR)
+
 .PHONY: plugins
 plugins: ## Build notification plugins
 	@$(foreach plugin,$(PLUGINS), \
@@ -260,7 +313,7 @@ clean-rpm:
 	@$(RM) -r rpm/SRPMS
 
 .PHONY: clean
-clean: clean-debian clean-rpm bats-clean ## Remove build artifacts
+clean: clean-debian clean-rpm clean-deps bats-clean ## Remove build artifacts
 	@$(MAKE) -C $(CROWDSEC_FOLDER) clean $(MAKE_FLAGS)
 	@$(MAKE) -C $(CSCLI_FOLDER) clean $(MAKE_FLAGS)
 	@$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR)
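With these targets in place, ML support is opted into through the new `full` profile rather than shipped by default; the `default` profile keeps excluding mlsupport so ordinary builds stay small. A hypothetical invocation, assuming the packaging-onnx release assets are reachable:

```
make build BUILD_PROFILE=full BUILD_RE2_WASM=1
```
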
diff --git a/go.mod b/go.mod
index 931ca9e7e..865d150b0 100644
--- a/go.mod
+++ b/go.mod
@@ -25,8 +25,10 @@ require (
 	github.com/creack/pty v1.1.21 // indirect
 	github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26
 	github.com/crowdsecurity/go-cs-lib v0.0.19
+	github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4
 	github.com/crowdsecurity/grokky v0.2.2
 	github.com/crowdsecurity/machineid v1.0.2
+	github.com/daulet/tokenizers v0.9.0
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
 	github.com/dghubble/sling v1.4.2
 	github.com/distribution/reference v0.6.0 // indirect
diff --git a/go.sum b/go.sum
index e9a0c8b49..8796ff234 100644
--- a/go.sum
+++ b/go.sum
@@ -113,10 +113,14 @@ github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
 github.com/crowdsecurity/go-cs-lib v0.0.19 h1:wA4O8hGrEntTGn7eZTJqnQ3mrAje5JvQAj8DNbe5IZg=
 github.com/crowdsecurity/go-cs-lib v0.0.19/go.mod h1:hz2FOHFXc0vWzH78uxo2VebtPQ9Snkbdzy3TMA20tVQ=
+github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4 h1:CwzISIxoKp0dJLrJJIlhvQPuzirpS9QH07guxK5LIeg=
+github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4/go.mod h1:YfyL16lx2wA8Z6t/TG1x1/FBngOIpuCuo7nM/FSuP54=
 github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4=
 github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
 github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
 github.com/crowdsecurity/machineid v1.0.2/go.mod h1:XWUSlnS0R0+u/JK5ulidwlbceNT3ZOCKteoVQEn6Luo=
+github.com/daulet/tokenizers v0.9.0 h1:PSjFUGeuhqb3C0GKP9hdvtHvJ6L1AZceV+0nYGACtCk=
+github.com/daulet/tokenizers v0.9.0/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs=
 github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
"datasource_http": false, "cscli_setup": false, + "mlsupport": false, } func Register(name string) { diff --git a/pkg/exprhelpers/anomalydetection.go b/pkg/exprhelpers/anomalydetection.go new file mode 100644 index 000000000..6282073a9 --- /dev/null +++ b/pkg/exprhelpers/anomalydetection.go @@ -0,0 +1,57 @@ +//go:build !no_mlsupport + +package exprhelpers + +import ( + "errors" + "fmt" + "log" + + "github.com/crowdsecurity/crowdsec/pkg/cwversion/component" + "github.com/crowdsecurity/crowdsec/pkg/ml" +) + +var robertaInferencePipeline *ml.RobertaClassificationInferencePipeline + +//nolint:gochecknoinits +func init() { + component.Register("mlsupport") +} + +func InitRobertaInferencePipeline(modelBundlePath string) error { + var err error + + fmt.Println("Initializing Roberta Inference Pipeline") + + robertaInferencePipeline, err = ml.NewRobertaInferencePipeline(modelBundlePath) + if err != nil { + return err + } + if robertaInferencePipeline == nil { + fmt.Println("Failed to initialize Roberta Inference Pipeline") + } + + return nil +} + +func IsAnomalous(params ...any) (any, error) { + verb, ok1 := params[0].(string) + httpPath, ok2 := params[1].(string) + + if !ok1 || !ok2 { + return nil, errors.New("parameters must be strings") + } + + text := verb + " " + httpPath + log.Println("Verb : ", verb) + log.Println("HTTP Path : ", httpPath) + log.Println("Text to analyze for Anomaly: ", text) + + if robertaInferencePipeline == nil { + return nil, errors.New("Roberta Inference Pipeline not properly initialized") + } + + result, err := robertaInferencePipeline.PredictLabel(text) + boolean_label := result == 1 + return boolean_label, err +} diff --git a/pkg/exprhelpers/anomalydetection_stub.go b/pkg/exprhelpers/anomalydetection_stub.go new file mode 100644 index 000000000..48029026a --- /dev/null +++ b/pkg/exprhelpers/anomalydetection_stub.go @@ -0,0 +1,29 @@ +//go:build no_mlsupport + +package exprhelpers + +import ( + "errors" + "fmt" +) + +var robertaInferencePipeline *RobertaInferencePipelineStub + +type RobertaInferencePipelineStub struct{} + +func InitRobertaInferencePipeline(modelBundlePath string) error { + fmt.Println("Stub: InitRobertaInferencePipeline called with no ML support") + return nil +} + +func IsAnomalous(params ...any) (any, error) { + _, ok1 := params[0].(string) + _, ok2 := params[1].(string) + + if !ok1 || !ok2 { + return nil, errors.New("parameters must be strings") + } + fmt.Println("IsAnomalous: InitRobertaInferencePipeline called with no ML support") + + return false, nil +} diff --git a/pkg/exprhelpers/anomalydetection_test.go b/pkg/exprhelpers/anomalydetection_test.go new file mode 100644 index 000000000..6ff005753 --- /dev/null +++ b/pkg/exprhelpers/anomalydetection_test.go @@ -0,0 +1,54 @@ +package exprhelpers + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAnomalyDetection(t *testing.T) { + tests := []struct { + name string + params []any + expectResult any + err error + }{ + { + name: "Empty verb and path", + params: []any{"", "hello"}, + expectResult: false, + err: nil, + }, + { + name: "Empty verb", + params: []any{"", "/somepath"}, + expectResult: false, + err: nil, + }, + { + name: "Empty path", + params: []any{"GET", ""}, + expectResult: true, + err: nil, + }, + { + name: "Valid verb and path", + params: []any{"GET", "/somepath"}, + expectResult: false, + err: nil, + }, + } + + tarFilePath := "tests/anomaly_detection_bundle_test.tar" + + if err := InitRobertaInferencePipeline(tarFilePath); err != nil { + t.Fatalf("failed to 
diff --git a/pkg/exprhelpers/anomalydetection_test.go b/pkg/exprhelpers/anomalydetection_test.go
new file mode 100644
index 000000000..6ff005753
--- /dev/null
+++ b/pkg/exprhelpers/anomalydetection_test.go
@@ -0,0 +1,54 @@
+package exprhelpers
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestAnomalyDetection(t *testing.T) {
+	tests := []struct {
+		name         string
+		params       []any
+		expectResult any
+		err          error
+	}{
+		{
+			name:         "Empty verb and path",
+			params:       []any{"", "hello"},
+			expectResult: false,
+			err:          nil,
+		},
+		{
+			name:         "Empty verb",
+			params:       []any{"", "/somepath"},
+			expectResult: false,
+			err:          nil,
+		},
+		{
+			name:         "Empty path",
+			params:       []any{"GET", ""},
+			expectResult: true,
+			err:          nil,
+		},
+		{
+			name:         "Valid verb and path",
+			params:       []any{"GET", "/somepath"},
+			expectResult: false,
+			err:          nil,
+		},
+	}
+
+	tarFilePath := "tests/anomaly_detection_bundle_test.tar"
+
+	if err := InitRobertaInferencePipeline(tarFilePath); err != nil {
+		t.Fatalf("failed to initialize RobertaInferencePipeline: %v", err)
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, _ := IsAnomalous(tt.params...)
+			assert.Equal(t, tt.expectResult, result)
+		})
+	}
+}
diff --git a/pkg/exprhelpers/expr_lib.go b/pkg/exprhelpers/expr_lib.go
index e0d7f6d97..9d385e886 100644
--- a/pkg/exprhelpers/expr_lib.go
+++ b/pkg/exprhelpers/expr_lib.go
@@ -509,6 +509,13 @@ var exprFuncs = []exprCustomFunc{
 			new(func(string) *net.IPNet),
 		},
 	},
+	{
+		name:     "IsAnomalous",
+		function: IsAnomalous,
+		signature: []interface{}{
+			new(func(string, string) (bool, error)),
+		},
+	},
 	{
 		name:     "JA4H",
 		function: JA4H,
@@ -518,6 +525,6 @@ var exprFuncs = []exprCustomFunc{
 	},
 }
 
-//go 1.20 "CutPrefix": strings.CutPrefix,
+//go 1.20 "CutPrefix": strings.CutPrefix,
 //go 1.20 "CutSuffix": strings.CutSuffix,
 //"Cut": strings.Cut, -> returns more than 2 values, not supported by expr
diff --git a/pkg/exprhelpers/helpers.go b/pkg/exprhelpers/helpers.go
index 6c99c53dd..4b215cf85 100644
--- a/pkg/exprhelpers/helpers.go
+++ b/pkg/exprhelpers/helpers.go
@@ -38,9 +38,10 @@ import (
 )
 
 var (
-	dataFile      map[string][]string
-	dataFileRegex map[string][]*regexp.Regexp
-	dataFileRe2   map[string][]*re2.Regexp
+	dataFile            map[string][]string
+	dataFileRegex       map[string][]*regexp.Regexp
+	dataFileRe2         map[string][]*re2.Regexp
+	mlRobertaModelFiles map[string]struct{}
 )
 
 // This is used to (optionally) cache regexp results for RegexpInFile operations
@@ -125,6 +126,7 @@ func Init(databaseClient *database.Client) error {
 	dataFile = make(map[string][]string)
 	dataFileRegex = make(map[string][]*regexp.Regexp)
 	dataFileRe2 = make(map[string][]*re2.Regexp)
+	mlRobertaModelFiles = make(map[string]struct{})
 	dbClient = databaseClient
 
 	XMLCacheInit()
@@ -205,6 +207,16 @@ func FileInit(fileFolder string, filename string, fileType string) error {
 
 	filepath := filepath.Join(fileFolder, filename)
 
+	if fileType == "ml_roberta_model" {
+		err := InitRobertaInferencePipeline(filepath)
+		if err != nil {
+			log.Errorf("unable to init roberta model : %s", err)
+			return err
+		}
+		mlRobertaModelFiles[filename] = struct{}{}
+		return nil
+	}
+
 	file, err := os.Open(filepath)
 	if err != nil {
 		return err
@@ -300,6 +312,8 @@ func existsInFileMaps(filename string, ftype string) (bool, error) {
 		}
 	case "string":
 		_, ok = dataFile[filename]
+	case "ml_roberta_model":
+		_, ok = mlRobertaModelFiles[filename]
 	default:
 		err = fmt.Errorf("unknown data type '%s' for : '%s'", ftype, filename)
 	}
diff --git a/pkg/exprhelpers/tests/anomaly_detection_bundle_test.tar b/pkg/exprhelpers/tests/anomaly_detection_bundle_test.tar
new file mode 100644
index 000000000..fddc62786
Binary files /dev/null and b/pkg/exprhelpers/tests/anomaly_detection_bundle_test.tar differ
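Putting the two hunks together, loading a bundle through the new ml_roberta_model file type could look like the following sketch. The paths are hypothetical, and Init must run first so the new map is allocated:

```go
package main

import (
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
)

func main() {
	// Init allocates the lookup maps (including the new mlRobertaModelFiles).
	if err := exprhelpers.Init(nil); err != nil {
		log.Fatal(err)
	}

	// FileInit with the ml_roberta_model type routes to InitRobertaInferencePipeline.
	if err := exprhelpers.FileInit("/var/lib/crowdsec/data", "anomaly_detection_bundle.tar", "ml_roberta_model"); err != nil {
		log.Fatalf("loading roberta bundle: %s", err)
	}

	// The helper is now backed by a live pipeline.
	verdict, err := exprhelpers.IsAnomalous("GET", "/phpunit/src/Util/PHP/eval-stdin.php")
	if err != nil {
		log.Fatal(err)
	}

	log.Println("anomalous:", verdict)
}
```
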
options") + } + + fmt.Println("Creating ORT session from model path:", modelPath) + + session, err := onnxruntime.NewORTSession(ortEnv, modelPath, ortSessionOptions) + if err != nil { + fmt.Println("Error creating ORT session") + ortEnv.Close() + ortSessionOptions.Close() + return nil, err + } + + fmt.Println("Done creating ORT session") + + return &OrtSession{ + ORTSession: session, + ORTEnv: ortEnv, + ORTSessionOptions: ortSessionOptions, + }, nil +} + +func (ort *OrtSession) Predict(inputs []onnxruntime.TensorValue) ([]onnxruntime.TensorValue, error) { + res, err := ort.ORTSession.Predict(inputs) + if err != nil { + return nil, err + } + + return res, nil +} + +func (ort *OrtSession) PredictLabel(inputIds []onnxruntime.TensorValue) (int, error) { + res, err := ort.Predict(inputIds) + if err != nil { + return 0, err + } + + label, err := PredicitonToLabel(res) + if err != nil { + return 0, err + } + + return label, nil +} + +func GetTensorValue(input []int64, shape []int64) onnxruntime.TensorValue { + return onnxruntime.TensorValue{ + Shape: shape, + Value: input, + } +} + +func PredicitonToLabel(res []onnxruntime.TensorValue) (int, error) { + if len(res) != 1 { + return 0, fmt.Errorf("expected one output tensor, got %d", len(res)) + } + + outputTensor := res[0] + + maxIndex := 0 // Assuming the output tensor shape is [1 2], and we want to find the index of the max value + maxProb := outputTensor.Value.([]float32)[0] // Assuming the values are float32 + + for i, value := range outputTensor.Value.([]float32) { + if value > maxProb { + maxProb = value + maxIndex = i + } + } + + return maxIndex, nil +} + +func (os *OrtSession) Close() { + os.ORTSession.Close() + os.ORTEnv.Close() + os.ORTSessionOptions.Close() +} diff --git a/pkg/ml/robertapipeline.go b/pkg/ml/robertapipeline.go new file mode 100644 index 000000000..63d44cb3e --- /dev/null +++ b/pkg/ml/robertapipeline.go @@ -0,0 +1,153 @@ +//go:build !no_ml_support + +package ml + +import ( + "archive/tar" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + onnxruntime "github.com/crowdsecurity/go-onnxruntime" +) + +type RobertaClassificationInferencePipeline struct { + inputShape []int64 + tokenizer *Tokenizer + ortSession *OrtSession +} + +var bundleFileList = []string{ + "model.onnx", + "tokenizer.json", + "tokenizer_config.json", +} + +func NewRobertaInferencePipeline(bundleFilePath string) (*RobertaClassificationInferencePipeline, error) { + tempDir, err := os.MkdirTemp("", "crowdsec_roberta_model_assets") + if err != nil { + return nil, fmt.Errorf("could not create temp directory: %v", err) + } + + if err := extractTarFile(bundleFilePath, tempDir); err != nil { + os.RemoveAll(tempDir) + return nil, fmt.Errorf("failed to extract tar file: %v", err) + } + + outputDir := filepath.Join(tempDir, strings.Split(filepath.Base(bundleFilePath), ".tar")[0]) + + for _, file := range bundleFileList { + if _, err := os.Stat(filepath.Join(outputDir, file)); os.IsNotExist(err) { + os.RemoveAll(tempDir) + return nil, fmt.Errorf("missing required file: %s, in %s", file, outputDir) + } + } + + ortSession, err := NewOrtSession(filepath.Join(outputDir, "model.onnx")) + if err != nil { + os.RemoveAll(tempDir) + return nil, err + } + + tokenizer, err := NewTokenizer(outputDir) + if err != nil { + ortSession.Close() + os.RemoveAll(tempDir) + return nil, err + } + + inputShape := []int64{1, int64(tokenizer.modelMaxLength)} + + if err := os.RemoveAll(tempDir); err != nil { + ortSession.Close() + tokenizer.Close() + return nil, fmt.Errorf("could not remove 
diff --git a/pkg/ml/robertapipeline.go b/pkg/ml/robertapipeline.go
new file mode 100644
index 000000000..63d44cb3e
--- /dev/null
+++ b/pkg/ml/robertapipeline.go
@@ -0,0 +1,153 @@
+//go:build !no_mlsupport
+
+package ml
+
+import (
+	"archive/tar"
+	"fmt"
+	"io"
+	"os"
+	"path/filepath"
+	"strings"
+
+	onnxruntime "github.com/crowdsecurity/go-onnxruntime"
+)
+
+type RobertaClassificationInferencePipeline struct {
+	inputShape []int64
+	tokenizer  *Tokenizer
+	ortSession *OrtSession
+}
+
+var bundleFileList = []string{
+	"model.onnx",
+	"tokenizer.json",
+	"tokenizer_config.json",
+}
+
+func NewRobertaInferencePipeline(bundleFilePath string) (*RobertaClassificationInferencePipeline, error) {
+	tempDir, err := os.MkdirTemp("", "crowdsec_roberta_model_assets")
+	if err != nil {
+		return nil, fmt.Errorf("could not create temp directory: %v", err)
+	}
+
+	if err := extractTarFile(bundleFilePath, tempDir); err != nil {
+		os.RemoveAll(tempDir)
+		return nil, fmt.Errorf("failed to extract tar file: %v", err)
+	}
+
+	outputDir := filepath.Join(tempDir, strings.Split(filepath.Base(bundleFilePath), ".tar")[0])
+
+	for _, file := range bundleFileList {
+		if _, err := os.Stat(filepath.Join(outputDir, file)); os.IsNotExist(err) {
+			os.RemoveAll(tempDir)
+			return nil, fmt.Errorf("missing required file: %s, in %s", file, outputDir)
+		}
+	}
+
+	ortSession, err := NewOrtSession(filepath.Join(outputDir, "model.onnx"))
+	if err != nil {
+		os.RemoveAll(tempDir)
+		return nil, err
+	}
+
+	tokenizer, err := NewTokenizer(outputDir)
+	if err != nil {
+		ortSession.Close()
+		os.RemoveAll(tempDir)
+		return nil, err
+	}
+
+	inputShape := []int64{1, int64(tokenizer.modelMaxLength)}
+
+	if err := os.RemoveAll(tempDir); err != nil {
+		ortSession.Close()
+		tokenizer.Close()
+		return nil, fmt.Errorf("could not remove temp directory: %v", err)
+	}
+
+	return &RobertaClassificationInferencePipeline{
+		inputShape: inputShape,
+		tokenizer:  tokenizer,
+		ortSession: ortSession,
+	}, nil
+}
+
+func (r *RobertaClassificationInferencePipeline) Close() {
+	r.tokenizer.Close()
+	r.ortSession.Close()
+}
+
+func (pipeline *RobertaClassificationInferencePipeline) PredictLabel(text string) (int, error) {
+	options := EncodeOptions{
+		AddSpecialTokens:    true,
+		PadToMaxLength:      true, // TODO: the ONNX input format leads to a segfault without this
+		ReturnAttentionMask: true,
+		Truncate:            true,
+	}
+
+	ids, _, attentionMask, err := pipeline.tokenizer.Encode(text, options)
+	if err != nil {
+		return 0, fmt.Errorf("error encoding text: %w", err)
+	}
+
+	label, err := pipeline.ortSession.PredictLabel([]onnxruntime.TensorValue{
+		GetTensorValue(ids, pipeline.inputShape),
+		GetTensorValue(attentionMask, pipeline.inputShape),
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	return label, nil
+}
+
+func extractTarFile(tarFilePath, outputDir string) error {
+	file, err := os.Open(tarFilePath)
+	if err != nil {
+		return fmt.Errorf("could not open tar file: %v", err)
+	}
+	defer file.Close()
+
+	tarReader := tar.NewReader(file)
+
+	for {
+		header, err := tarReader.Next()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return fmt.Errorf("error reading tar file: %v", err)
+		}
+
+		targetPath := filepath.Join(outputDir, header.Name)
+
+		// guard against path traversal ("zip slip") from crafted archive entries
+		if !strings.HasPrefix(targetPath, filepath.Clean(outputDir)+string(os.PathSeparator)) {
+			return fmt.Errorf("tar entry %s escapes the output directory", header.Name)
+		}
+
+		switch header.Typeflag {
+		case tar.TypeDir:
+			if err := os.MkdirAll(targetPath, 0755); err != nil {
+				return fmt.Errorf("could not create directory %s: %v", targetPath, err)
+			}
+		case tar.TypeReg:
+			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
+				return fmt.Errorf("could not create directory %s: %v", filepath.Dir(targetPath), err)
+			}
+
+			outFile, err := os.Create(targetPath)
+			if err != nil {
+				return fmt.Errorf("could not create file %s: %v", targetPath, err)
+			}
+
+			if _, err := io.Copy(outFile, tarReader); err != nil {
+				outFile.Close()
+				return fmt.Errorf("could not copy data to file %s: %v", targetPath, err)
+			}
+			outFile.Close()
+		default:
+			fmt.Printf("Unsupported file type in tar: %s\n", header.Name)
+		}
+	}
+	return nil
+}
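End to end, the pipeline hides the tar extraction, tokenizer setup and tensor plumbing from the caller. A hypothetical caller (bundle path illustrative):

```go
package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/ml"
)

func main() {
	// The bundle is a tar holding model.onnx, tokenizer.json and tokenizer_config.json.
	pipeline, err := ml.NewRobertaInferencePipeline("anomaly_detection_bundle.tar")
	if err != nil {
		log.Fatal(err)
	}
	defer pipeline.Close()

	label, err := pipeline.PredictLabel("GET /lib/phpunit/Util/PHP/eval-stdin.php")
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("label:", label) // 1 = anomalous, 0 = benign, per the expr helper's reading
}
```
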
diff --git a/pkg/ml/robertapipeline_test.go b/pkg/ml/robertapipeline_test.go
new file mode 100644
index 000000000..f97dc4dac
--- /dev/null
+++ b/pkg/ml/robertapipeline_test.go
@@ -0,0 +1,112 @@
+//go:build !no_mlsupport
+
+package ml
+
+import (
+	"fmt"
+	"log"
+	"runtime"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func BenchmarkPredictLabel(b *testing.B) {
+	log.Println("Starting benchmark for PredictLabel")
+
+	tarFilePath := "tests/anomaly_detection_bundle_test.tar"
+
+	pipeline, err := NewRobertaInferencePipeline(tarFilePath)
+	if err != nil {
+		b.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
+	}
+	defer pipeline.Close()
+
+	text := "POST /"
+
+	b.ResetTimer()
+	startTime := time.Now()
+
+	for n := 0; n < b.N; n++ {
+		if n%1000 == 0 {
+			log.Printf("Running iteration %d", n)
+		}
+
+		_, err := pipeline.PredictLabel(text)
+		if err != nil {
+			b.Fatalf("Prediction failed: %v", err)
+		}
+	}
+
+	// Capture the wall time of the timed loop before the second, memory-profiling pass,
+	// so the per-operation figures below are not inflated by it.
+	totalTime := time.Since(startTime)
+
+	var memStart runtime.MemStats
+	runtime.ReadMemStats(&memStart)
+
+	for n := 0; n < b.N; n++ {
+		_, err := pipeline.PredictLabel(text)
+		if err != nil {
+			b.Fatalf("Prediction failed: %v", err)
+		}
+	}
+
+	b.StopTimer()
+
+	var memEnd runtime.MemStats
+	runtime.ReadMemStats(&memEnd)
+
+	totalAlloc := memEnd.TotalAlloc - memStart.TotalAlloc
+	allocPerOp := totalAlloc / uint64(b.N)
+
+	log.Printf("Total benchmark time: %s\n", totalTime)
+	log.Printf("Average time per prediction: %s\n", totalTime/time.Duration(b.N))
+	log.Printf("Number of operations: %d\n", b.N)
+	log.Printf("Operations per second: %.2f ops/s\n", float64(b.N)/totalTime.Seconds())
+	log.Printf("Memory allocated per operation: %d bytes\n", allocPerOp)
+	log.Printf("Total memory allocated: %d bytes\n", totalAlloc)
+
+	fmt.Printf("Benchmark Results:\n")
+	fmt.Printf("  Total time: %s\n", totalTime)
+	fmt.Printf("  Average time per operation: %s\n", totalTime/time.Duration(b.N))
+	fmt.Printf("  Operations per second: %.2f ops/s\n", float64(b.N)/totalTime.Seconds())
+	fmt.Printf("  Memory allocated per operation: %d bytes\n", allocPerOp)
+	fmt.Printf("  Total memory allocated: %d bytes\n", totalAlloc)
+}
+
+func TestPredictLabel(t *testing.T) {
+	tests := []struct {
+		name       string
+		text       string
+		expectedID int
+		label      int
+	}{
+		{
+			name:       "Malicious request",
+			text:       "GET /lib/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php?",
+			expectedID: 0,
+			label:      1,
+		},
+		{
+			name:       "Safe request",
+			text:       "GET /online/_ui/responsive/theme-miglog/img/header+Navigation/icon-delivery.svg",
+			expectedID: 0,
+			label:      0,
+		},
+	}
+
+	tarFilePath := "tests/anomaly_detection_bundle_test.tar"
+
+	pipeline, err := NewRobertaInferencePipeline(tarFilePath)
+	if err != nil {
+		t.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
+	}
+	defer pipeline.Close()
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			prediction, err := pipeline.PredictLabel(tt.text)
+			if err != nil {
+				t.Errorf("PredictLabel returned error: %v", err)
+			}
+
+			assert.Equal(t, tt.label, prediction, "Predicted label does not match the expected label")
+		})
+	}
+}
diff --git a/pkg/ml/tests/anomaly_detection_bundle_test.tar b/pkg/ml/tests/anomaly_detection_bundle_test.tar
new file mode 100644
index 000000000..fddc62786
Binary files /dev/null and b/pkg/ml/tests/anomaly_detection_bundle_test.tar differ
diff --git a/pkg/ml/tests/tokenizer.json b/pkg/ml/tests/tokenizer.json
new file mode 100644
index 000000000..8332802d1
--- /dev/null
+++ b/pkg/ml/tests/tokenizer.json
@@ -0,0 +1,3830 @@
+{
+  "version": "1.0",
+  "truncation": null,
+  "padding": null,
+  "added_tokens": [
+    {
+      "id": 0,
+      "content": "<s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": true
+    },
+    {
+      "id": 1,
+      "content": "<pad>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": true
+    },
+    {
+      "id": 2,
+      "content": "</s>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": true
+    },
+    {
+      "id": 3,
+      "content": "<unk>",
+      "single_word": false,
+      "lstrip": false,
+      "rstrip": false,
+      "normalized": true,
+      "special": true
+    },
+    {
+      "id": 4,
+      "content": "<mask>",
+      "single_word": false,
+      "lstrip": true,
+      "rstrip": false,
+      "normalized": false,
+      "special": true
+    }
+  ],
+  "normalizer": null,
+  "pre_tokenizer": {
+    "type": "ByteLevel",
+    "add_prefix_space": false,
+    "trim_offsets": true,
+    "use_regex": true
+  },
+  "post_processor": {
+    "type": "RobertaProcessing",
+    "sep": [
+      "</s>",
+      2
+    ],
+    "cls": [
+      "<s>",
+      0
+    ],
+    "trim_offsets": true,
+    "add_prefix_space": false
+  },
+  "decoder": {
+    "type": "ByteLevel",
+    "add_prefix_space": true,
+    "trim_offsets": true,
+    "use_regex": true
+  },
+  "model": {
+    "type": "BPE",
+    "dropout": null,
+    "unk_token": null,
+    "continuing_subword_prefix": "",
+    "end_of_word_suffix": "",
+    "fuse_unk": false,
+ "byte_fallback": false, + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "": 4, + "!": 5, + "\"": 6, + "#": 7, + "$": 8, + "%": 9, + "&": 10, + "'": 11, + "(": 12, + ")": 13, + "*": 14, + "+": 15, + ",": 16, + "-": 17, + ".": 18, + "/": 19, + "0": 20, + "1": 21, + "2": 22, + "3": 23, + "4": 24, + "5": 25, + "6": 26, + "7": 27, + "8": 28, + "9": 29, + ":": 30, + ";": 31, + "<": 32, + "=": 33, + ">": 34, + "?": 35, + "@": 36, + "A": 37, + "B": 38, + "C": 39, + "D": 40, + "E": 41, + "F": 42, + "G": 43, + "H": 44, + "I": 45, + "J": 46, + "K": 47, + "L": 48, + "M": 49, + "N": 50, + "O": 51, + "P": 52, + "Q": 53, + "R": 54, + "S": 55, + "T": 56, + "U": 57, + "V": 58, + "W": 59, + "X": 60, + "Y": 61, + "Z": 62, + "[": 63, + "\\": 64, + "]": 65, + "^": 66, + "_": 67, + "`": 68, + "a": 69, + "b": 70, + "c": 71, + "d": 72, + "e": 73, + "f": 74, + "g": 75, + "h": 76, + "i": 77, + "j": 78, + "k": 79, + "l": 80, + "m": 81, + "n": 82, + "o": 83, + "p": 84, + "q": 85, + "r": 86, + "s": 87, + "t": 88, + "u": 89, + "v": 90, + "w": 91, + "x": 92, + "y": 93, + "z": 94, + "{": 95, + "|": 96, + "}": 97, + "~": 98, + "¡": 99, + "¢": 100, + "£": 101, + "¤": 102, + "¥": 103, + "¦": 104, + "§": 105, + "¨": 106, + "©": 107, + "ª": 108, + "«": 109, + "¬": 110, + "®": 111, + "¯": 112, + "°": 113, + "±": 114, + "²": 115, + "³": 116, + "´": 117, + "µ": 118, + "¶": 119, + "·": 120, + "¸": 121, + "¹": 122, + "º": 123, + "»": 124, + "¼": 125, + "½": 126, + "¾": 127, + "¿": 128, + "À": 129, + "Á": 130, + "Â": 131, + "Ã": 132, + "Ä": 133, + "Å": 134, + "Æ": 135, + "Ç": 136, + "È": 137, + "É": 138, + "Ê": 139, + "Ë": 140, + "Ì": 141, + "Í": 142, + "Î": 143, + "Ï": 144, + "Ð": 145, + "Ñ": 146, + "Ò": 147, + "Ó": 148, + "Ô": 149, + "Õ": 150, + "Ö": 151, + "×": 152, + "Ø": 153, + "Ù": 154, + "Ú": 155, + "Û": 156, + "Ü": 157, + "Ý": 158, + "Þ": 159, + "ß": 160, + "à": 161, + "á": 162, + "â": 163, + "ã": 164, + "ä": 165, + "å": 166, + "æ": 167, + "ç": 168, + "è": 169, + "é": 170, + "ê": 171, + "ë": 172, + "ì": 173, + "í": 174, + "î": 175, + "ï": 176, + "ð": 177, + "ñ": 178, + "ò": 179, + "ó": 180, + "ô": 181, + "õ": 182, + "ö": 183, + "÷": 184, + "ø": 185, + "ù": 186, + "ú": 187, + "û": 188, + "ü": 189, + "ý": 190, + "þ": 191, + "ÿ": 192, + "Ā": 193, + "ā": 194, + "Ă": 195, + "ă": 196, + "Ą": 197, + "ą": 198, + "Ć": 199, + "ć": 200, + "Ĉ": 201, + "ĉ": 202, + "Ċ": 203, + "ċ": 204, + "Č": 205, + "č": 206, + "Ď": 207, + "ď": 208, + "Đ": 209, + "đ": 210, + "Ē": 211, + "ē": 212, + "Ĕ": 213, + "ĕ": 214, + "Ė": 215, + "ė": 216, + "Ę": 217, + "ę": 218, + "Ě": 219, + "ě": 220, + "Ĝ": 221, + "ĝ": 222, + "Ğ": 223, + "ğ": 224, + "Ġ": 225, + "ġ": 226, + "Ģ": 227, + "ģ": 228, + "Ĥ": 229, + "ĥ": 230, + "Ħ": 231, + "ħ": 232, + "Ĩ": 233, + "ĩ": 234, + "Ī": 235, + "ī": 236, + "Ĭ": 237, + "ĭ": 238, + "Į": 239, + "į": 240, + "İ": 241, + "ı": 242, + "IJ": 243, + "ij": 244, + "Ĵ": 245, + "ĵ": 246, + "Ķ": 247, + "ķ": 248, + "ĸ": 249, + "Ĺ": 250, + "ĺ": 251, + "Ļ": 252, + "ļ": 253, + "Ľ": 254, + "ľ": 255, + "Ŀ": 256, + "ŀ": 257, + "Ł": 258, + "ł": 259, + "Ń": 260, + "22": 261, + "re": 262, + "Ġ/": 263, + "on": 264, + "ET": 265, + "GET": 266, + "20": 267, + "te": 268, + "se": 269, + "¿½": 270, + "�": 271, + "Ġ1": 272, + "in": 273, + "AA": 274, + "am": 275, + "le": 276, + "ro": 277, + "st": 278, + "��": 279, + "ti": 280, + "amp": 281, + "en": 282, + "co": 283, + "ref": 284, + "er": 285, + "ge": 286, + "00": 287, + "al": 288, + "66": 289, + "or": 290, + "an": 291, + "ht": 292, + "at": 293, + "ar": 294, + "con": 295, + "ww": 296, + "17": 297, + "id": 
298, + "29": 299, + "28": 300, + "25": 301, + "ss": 302, + "80": 303, + "lo": 304, + "tp": 305, + "$%": 306, + "ad": 307, + "10": 308, + "age": 309, + "js": 310, + "..": 311, + "23": 312, + "pg": 313, + "com": 314, + "27": 315, + "19": 316, + "http": 317, + "he": 318, + "16": 319, + "de": 320, + "di": 321, + "ul": 322, + "24": 323, + ":/": 324, + "ce": 325, + "li": 326, + "OM": 327, + "to": 328, + "ct": 329, + ":-": 330, + "30": 331, + "OS": 332, + "����": 333, + "me": 334, + "nt": 335, + "://": 336, + "pro": 337, + "ff": 338, + "15": 339, + "ne": 340, + "wp": 341, + "ROM": 342, + "https": 343, + "26": 344, + "AAAA": 345, + "ty": 346, + "ri": 347, + "13": 348, + "ac": 349, + "row": 350, + "ho": 351, + "\":": 352, + "POS": 353, + "14": 354, + "im": 355, + "brow": 356, + "browse": 357, + "il": 358, + "as": 359, + "pl": 360, + "ab": 361, + "ur": 362, + "tion": 363, + "conte": 364, + "18": 365, + "tr": 366, + "Ġ0": 367, + "99": 368, + "oo": 369, + "ate": 370, + "si": 371, + "ap": 372, + "up": 373, + "www": 374, + "po": 375, + "40": 376, + "TI": 377, + "Ġ×": 378, + "fi": 379, + "vi": 380, + "FROM": 381, + "90": 382, + ",\"": 383, + "<": 671, + "sp": 672, + "sta": 673, + "59": 674, + "bQ": 675, + "AK": 676, + "In": 677, + "�%": 678, + "ex": 679, + "31": 680, + "AS": 681, + "eo": 682, + "png": 683, + "gB": 684, + "AF": 685, + "tid": 686, + "nl": 687, + "AL": 688, + "cur": 689, + "ol": 690, + "plug": 691, + "Ġ//": 692, + "bb": 693, + "42": 694, + ".%": 695, + "ated": 696, + "BQ": 697, + "UN": 698, + "-%": 699, + "51": 700, + "ack": 701, + "86": 702, + "OD": 703, + "tmp": 704, + "app": 705, + "mi": 706, + "type": 707, + "mK": 708, + "39": 709, + "sk": 710, + "cy": 711, + "alert": 712, + "06": 713, + "view": 714, + "zi": 715, + "AN": 716, + "asset": 717, + "54": 718, + "art": 719, + "{\"": 720, + "71": 721, + "rf": 722, + "etc": 723, + "2022": 724, + "Ġ24": 725, + "br": 726, + "VIE": 727, + "sO": 728, + "VIEW": 729, + "plugins": 730, + "CH": 731, + "76": 732, + "TF": 733, + "+/": 734, + "by": 735, + "file": 736, + "74": 737, + "+-": 738, + "ffff": 739, + "OTE": 740, + "RER": 741, + "ds": 742, + "name": 743, + "SEC": 744, + "REF": 745, + "AM": 746, + "koi": 747, + "static": 748, + "assets": 749, + "AP": 750, + "TID": 751, + "TYPE": 752, + "pre": 753, + "event": 754, + "CJ": 755, + "created": 756, + "ITION": 757, + "POSITION": 758, + "SECTION": 759, + "000": 760, + "ERRER": 761, + "REFERRER": 762, + "jn": 763, + "pgsess": 764, + "cmd": 765, + "PROM": 766, + "OTED": 767, + "PROMOTED": 768, + "user": 769, + "qf": 770, + "CB": 771, + "version": 772, + "dd": 773, + "is": 774, + "Aw": 775, + "dex": 776, + "Mz": 777, + "Ġ9": 778, + "medi": 779, + "'%": 780, + "go": 781, + "84": 782, + "ubcat": 783, + "*?": 784, + "07": 785, + "ment": 786, + "Bsubcat": 787, + "out": 788, + "ZW": 789, + "/$%": 790, + "Ġ-": 791, + "INFO": 792, + "pb": 793, + "NE": 794, + "dit": 795, + "html": 796, + "IF": 797, + "YW": 798, + "LE": 799, + "ba": 800, + "cs": 801, + "201": 802, + "*?*?": 803, + "DICT": 804, + "ir": 805, + "Lj": 806, + "85": 807, + "Afalse": 808, + "bu": 809, + "comm": 810, + "79": 811, + "os": 812, + "ĠS": 813, + "confirm": 814, + "bd": 815, + "EF": 816, + "end": 817, + "index": 818, + "Ġ1080": 819, + "lect": 820, + "Ij": 821, + "mpt": 822, + "fc": 823, + "vb": 824, + "Ac": 825, + "Im": 826, + "Nj": 827, + "04": 828, + "prompt": 829, + "CQ": 830, + "ci": 831, + "low": 832, + "aW": 833, + "je": 834, + "87": 835, + "89": 836, + "wd": 837, + "Ii": 838, + ")\">": 839, + "Ah": 840, + "Vi": 841, + "uv": 842, + "RJ": 
843, + "AJ": 844, + "81": 845, + "ps": 846, + "form": 847, + "fd": 848, + "Da": 849, + "Mj": 850, + "lin": 851, + "la": 852, + "hp": 853, + "ĠC": 854, + "Ġ4": 855, + "Au": 856, + "bi": 857, + "ky": 858, + "CF": 859, + "Eg": 860, + "ZX": 861, + "px": 862, + "ft": 863, + "bn": 864, + "Ax": 865, + "SI": 866, + "MC": 867, + "all": 868, + "sm": 869, + "drag": 870, + "HJ": 871, + "wa": 872, + "69": 873, + "YB": 874, + "TM": 875, + "news": 876, + ":$%": 877, + "1660": 878, + "uk": 879, + "wn": 880, + "json": 881, + "82": 882, + "yJ": 883, + "cm": 884, + "vid": 885, + "dm": 886, + "../../../../": 887, + "CE": 888, + "ls": 889, + "media": 890, + "Id": 891, + "cc": 892, + "oa": 893, + "Bh": 894, + "AU": 895, + "bg": 896, + "mes": 897, + "ath": 898, + "ר": 899, + "DQ": 900, + "ces": 901, + "mZ": 902, + "pp": 903, + "kt": 904, + "RR": 905, + "hi": 906, + "Cg": 907, + "ud": 908, + "61": 909, + "ax": 910, + "sl": 911, + "hb": 912, + "my": 913, + "iL": 914, + "QU": 915, + "My": 916, + "Ġ1660": 917, + "zL": 918, + "BE": 919, + "AO": 920, + "WU": 921, + "ks": 922, + "ni": 923, + "Re": 924, + "da": 925, + "-_": 926, + "ry": 927, + "df": 928, + "per": 929, + "Sy": 930, + "AV": 931, + "auto": 932, + "bl": 933, + "Ġid": 934, + "IQ": 935, + "72": 936, + "62": 937, + "za": 938, + "gE": 939, + "ui": 940, + "sho": 941, + "svg": 942, + "tran": 943, + "nte": 944, + "Zm": 945, + "win": 946, + "jw": 947, + "gAAAA": 948, + "St": 949, + "ea": 950, + "uF": 951, + "Og": 952, + "onte": 953, + "ll": 954, + "OR": 955, + "AT": 956, + "KJ": 957, + "HE": 958, + "Ġconte": 959, + "ditable": 960, + "nteditable": 961, + "Ġcontenteditable": 962, + "GF": 963, + "cess": 964, + "sw": 965, + "Oj": 966, + "Si": 967, + "GB": 968, + "Om": 969, + "QQ": 970, + "use": 971, + "ŀ×": 972, + "AH": 973, + "TU": 974, + "VF": 975, + "cr": 976, + "aw": 977, + "Nl": 978, + "DE": 979, + "py": 980, + "Mi": 981, + "No": 982, + "Oi": 983, + "Ro": 984, + "bc": 985, + "Ġ|": 986, + "ak": 987, + "CC": 988, + "pf": 989, + "MS": 990, + "GZ": 991, + "Uy": 992, + "pc": 993, + "url": 994, + "ly": 995, + "video": 996, + "Ex": 997, + "REEN": 998, + "TJ": 999, + "chr": 1000, + "bf": 1001, + "GV": 1002, + "ude": 1003, + "AZ": 1004, + "cB": 1005, + "Co": 1006, + "mar": 1007, + "SCREEN": 1008, + "ze": 1009, + "cu": 1010, + "ai": 1011, + "boo": 1012, + "Ym": 1013, + "Gg": 1014, + "cC": 1015, + "ĠP": 1016, + "pagead": 1017, + "HM": 1018, + "ow": 1019, + "dl": 1020, + ".\\": 1021, + "ĠGoogle": 1022, + "gr": 1023, + "dy": 1024, + "Anull": 1025, + "pv": 1026, + "YX": 1027, + "lower": 1028, + "ren": 1029, + "ache": 1030, + "SS": 1031, + "Nz": 1032, + "Ig": 1033, + "FB": 1034, + "NT": 1035, + "hCsO": 1036, + "mKTy": 1037, + "bQhCsO": 1038, + "ĠGooglemKTy": 1039, + "ĠGooglemKTybQhCsO": 1040, + "LT": 1041, + "we": 1042, + "nc": 1043, + "Az": 1044, + "Gl": 1045, + "Ho": 1046, + "MB": 1047, + "Fi": 1048, + "Ai": 1049, + "xw": 1050, + "db": 1051, + "Wl": 1052, + "fs": 1053, + "Blower": 1054, + "XB": 1055, + "rent": 1056, + "IV": 1057, + "Ny": 1058, + "tit": 1059, + "Tb": 1060, + "clude": 1061, + "WS": 1062, + "304": 1063, + "google": 1064, + "gt": 1065, + "GR": 1066, + "ef": 1067, + "heck": 1068, + "Il": 1069, + "Ġ1040": 1070, + "Is": 1071, + "253": 1072, + "\"}": 1073, + "wv": 1074, + "AX": 1075, + "Ay": 1076, + "gAAAAAB": 1077, + "GH": 1078, + "gif": 1079, + "arch": 1080, + "WF": 1081, + "CM": 1082, + "׾": 1083, + "path": 1084, + "PR": 1085, + "KQ": 1086, + "cgi": 1087, + "=/": 1088, + "az": 1089, + "OT": 1090, + "De": 1091, + "uZ": 1092, + "CN": 1093, + "nQ": 1094, + "uc": 1095, 
+ "WE": 1096, + "next": 1097, + "Yz": 1098, + "CG": 1099, + "wo": 1100, + "Li": 1101, + "Cj": 1102, + "Ġt": 1103, + "VV": 1104, + "Av": 1105, + "LB": 1106, + "UF": 1107, + "themes": 1108, + "ting": 1109, + "play": 1110, + "htm": 1111, + "ord": 1112, + "Ad": 1113, + "ĠON": 1114, + "jh": 1115, + "300": 1116, + "UE": 1117, + "ZS": 1118, + "DS": 1119, + ".\\.\\": 1120, + "IB": 1121, + "XR": 1122, + "LL": 1123, + "ka": 1124, + "ery": 1125, + "YQ": 1126, + "unt": 1127, + "ME": 1128, + "gd": 1129, + "AW": 1130, + "Ġevent": 1131, + "src": 1132, + "gn": 1133, + "HV": 1134, + "red": 1135, + "200": 1136, + "UR": 1137, + "HR": 1138, + "Fk": 1139, + "MR": 1140, + "xy": 1141, + "rd": 1142, + "MG": 1143, + "pu": 1144, + "ND": 1145, + "Ġen": 1146, + "×Ķ": 1147, + "/*": 1148, + "./.": 1149, + "IP": 1150, + "Mk": 1151, + "Mozi": 1152, + "sd": 1153, + "MP": 1154, + "mV": 1155, + "Zj": 1156, + "sole": 1157, + "pd": 1158, + "HB": 1159, + "hy": 1160, + "ix": 1161, + "WQ": 1162, + "880": 1163, + "NF": 1164, + "OU": 1165, + "EV": 1166, + "NS": 1167, + "100": 1168, + "vo": 1169, + "cchr": 1170, + "console": 1171, + "ini": 1172, + "rb": 1173, + "xx": 1174, + "Ne": 1175, + "UJ": 1176, + "Aj": 1177, + "gy": 1178, + "UK": 1179, + "GQ": 1180, + "ob": 1181, + ")> <", + "s p", + "st a", + "5 9", + "b Q", + "A K", + "I n", + "� %", + "e x", + "3 1", + "A S", + "e o", + "pn g", + "g B", + "A F", + "ti d", + "n l", + "A L", + "c ur", + "o l", + "pl ug", + "Ġ/ /", + "b b", + "4 2", + ". %", + "ate d", + "B Q", + "U N", + "- %", + "5 1", + "ac k", + "8 6", + "O D", + "tm p", + "ap p", + "m i", + "ty pe", + "m K", + "3 9", + "s k", + "c y", + "ale rt", + "0 6", + "vi ew", + "z i", + "A N", + "as set", + "5 4", + "ar t", + "{ \"", + "7 1", + "r f", + "et c", + "20 22", + "Ġ 24", + "b r", + "V IE", + "s O", + "VIE W", + "plug ins", + "C H", + "7 6", + "T F", + "+ /", + "b y", + "fi le", + "7 4", + "+ -", + "ff ff", + "O TE", + "RE R", + "d s", + "n ame", + "SE C", + "RE F", + "A M", + "ko i", + "sta tic", + "asset s", + "A P", + "TI D", + "TY PE", + "p re", + "ev ent", + "C J", + "cre ated", + "I TION", + "POS ITION", + "SEC TION", + "00 0", + "ER RER", + "REF ERRER", + "j n", + "pg sess", + "c md", + "P ROM", + "OTE D", + "PROM OTED", + "u ser", + "q f", + "C B", + "ver sion", + "d d", + "i s", + "A w", + "de x", + "M z", + "Ġ 9", + "me di", + "' %", + "g o", + "8 4", + "ub cat", + "* ?", + "0 7", + "m ent", + "Bs ubcat", + "o ut", + "Z W", + "/ $%", + "Ġ -", + "IN FO", + "p b", + "N E", + "di t", + "ht ml", + "I F", + "Y W", + "L E", + "b a", + "c s", + "20 1", + "*? 
*?", + "DI CT", + "i r", + "L j", + "8 5", + "A false", + "b u", + "com m", + "7 9", + "o s", + "Ġ S", + "confi rm", + "b d", + "E F", + "en d", + "in dex", + "Ġ10 80", + "le ct", + "I j", + "m pt", + "f c", + "v b", + "A c", + "I m", + "N j", + "0 4", + "pro mpt", + "C Q", + "c i", + "lo w", + "a W", + "j e", + "8 7", + "8 9", + "w d", + "I i", + ")\" >", + "A h", + "V i", + "u v", + "R J", + "A J", + "8 1", + "p s", + "for m", + "f d", + "D a", + "M j", + "l in", + "l a", + "h p", + "Ġ C", + "Ġ 4", + "A u", + "b i", + "k y", + "C F", + "E g", + "Z X", + "p x", + "f t", + "b n", + "A x", + "S I", + "M C", + "al l", + "s m", + "dr ag", + "H J", + "w a", + "6 9", + "Y B", + "T M", + "ne ws", + ": $%", + "166 0", + "u k", + "w n", + "js on", + "8 2", + "y J", + "c m", + "v id", + "d m", + "../../ ../../", + "C E", + "l s", + "medi a", + "I d", + "c c", + "o a", + "B h", + "A U", + "b g", + "me s", + "at h", + "× ¨", + "D Q", + "ce s", + "m Z", + "p p", + "k t", + "R R", + "h i", + "C g", + "u d", + "6 1", + "a x", + "s l", + "h b", + "m y", + "i L", + "Q U", + "M y", + "Ġ166 0", + "z L", + "B E", + "A O", + "W U", + "k s", + "n i", + "R e", + "d a", + "- _", + "r y", + "d f", + "p er", + "S y", + "A V", + "au to", + "b l", + "Ġ id", + "I Q", + "7 2", + "6 2", + "z a", + "g E", + "u i", + "s ho", + "sv g", + "tr an", + "n te", + "Z m", + "w in", + "j w", + "g AAAA", + "S t", + "e a", + "u F", + "O g", + "on te", + "l l", + "O R", + "A T", + "K J", + "H E", + "Ġ conte", + "dit able", + "nte ditable", + "Ġconte nteditable", + "G F", + "ce ss", + "s w", + "O j", + "S i", + "G B", + "O m", + "Q Q", + "u se", + "ŀ ×", + "A H", + "T U", + "V F", + "c r", + "a w", + "N l", + "D E", + "p y", + "M i", + "N o", + "O i", + "R o", + "b c", + "Ġ |", + "a k", + "C C", + "p f", + "M S", + "G Z", + "U y", + "p c", + "ur l", + "l y", + "vid eo", + "E x", + "RE EN", + "T J", + "ch r", + "b f", + "G V", + "u de", + "A Z", + "c B", + "C o", + "m ar", + "SC REEN", + "z e", + "c u", + "a i", + "b oo", + "Y m", + "G g", + "c C", + "Ġ P", + "page ad", + "H M", + "o w", + "d l", + ". \\", + "ĠG oogle", + "g r", + "d y", + "An ull", + "p v", + "Y X", + "low er", + "re n", + "ac he", + "S S", + "N z", + "I g", + "F B", + "N T", + "hC sO", + "mK Ty", + "bQ hCsO", + "ĠGoogle mKTy", + "ĠGooglemKTy bQhCsO", + "L T", + "w e", + "n c", + "A z", + "G l", + "H o", + "M B", + "F i", + "A i", + "x w", + "d b", + "W l", + "f s", + "B lower", + "X B", + "re nt", + "I V", + "N y", + "ti t", + "T b", + "cl ude", + "W S", + "30 4", + "g oogle", + "g t", + "G R", + "e f", + "he ck", + "I l", + "Ġ10 40", + "I s", + "25 3", + "\" }", + "w v", + "A X", + "A y", + "gAAAA AB", + "G H", + "gi f", + "ar ch", + "W F", + "C M", + "× ľ", + "p ath", + "P R", + "K Q", + "c gi", + "= /", + "a z", + "O T", + "D e", + "u Z", + "C N", + "n Q", + "u c", + "W E", + "ne xt", + "Y z", + "C G", + "w o", + "L i", + "C j", + "Ġ t", + "V V", + "A v", + "L B", + "U F", + "the mes", + "t ing", + "pl ay", + "ht m", + "or d", + "A d", + "Ġ ON", + "j h", + "3 00", + "U E", + "Z S", + "D S", + ".\\ .\\", + "I B", + "X R", + "L L", + "k a", + "er y", + "Y Q", + "u nt", + "M E", + "g d", + "A W", + "Ġ event", + "sr c", + "g n", + "H V", + "re d", + "20 0", + "U R", + "H R", + "F k", + "M R", + "x y", + "r d", + "M G", + "p u", + "N D", + "Ġ en", + "× Ķ", + "/ *", + ". 
/.", + "I P", + "M k", + "Mo zi", + "s d", + "M P", + "m V", + "Z j", + "so le", + "p d", + "H B", + "h y", + "i x", + "W Q", + "8 80", + "N F", + "O U", + "E V", + "N S", + "1 00", + "v o", + "c chr", + "con sole", + "in i", + "r b", + "x x", + "N e", + "U J", + "A j", + "g y", + "U K", + "G Q", + "o b", + ") >", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "4": { + "content": "", + "lstrip": true, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "bos_token": "", + "clean_up_tokenization_spaces": true, + "cls_token": "", + "eos_token": "", + "errors": "replace", + "mask_token": "", + "model_max_length": 512, + "pad_token": "", + "sep_token": "", + "tokenizer_class": "RobertaTokenizer", + "trim_offsets": true, + "unk_token": "" +} diff --git a/pkg/ml/tokenizer.go b/pkg/ml/tokenizer.go new file mode 100644 index 000000000..4a3cdc6ee --- /dev/null +++ b/pkg/ml/tokenizer.go @@ -0,0 +1,160 @@ +//go:build !no_mlsupport + +package ml + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + + tokenizers "github.com/daulet/tokenizers" +) + +type Tokenizer struct { + tk *tokenizers.Tokenizer + modelMaxLength int + padTokenID int + tokenizerClass string +} + +type tokenizerConfig struct { + ModelMaxLen int `json:"model_max_length"` + PadToken string `json:"pad_token"` + TokenizerClass string `json:"tokenizer_class"` + AddedTokenDecoder map[string]map[string]interface{} `json:"added_tokens_decoder"` +} + +func loadTokenizerConfig(filename string) (*tokenizerConfig, error) { + file, err := os.ReadFile(filename) + if err != nil { + fmt.Println("Error reading tokenizer config file") + return nil, err + } + config := &tokenizerConfig{} + if err := json.Unmarshal(file, config); err != nil { + fmt.Println("Error unmarshalling tokenizer config") + return nil, err + } + return config, nil +} + +func findTokenID(tokens map[string]map[string]interface{}, tokenContent string) int { + for key, value := range tokens { + if content, ok := value["content"]; ok && content == tokenContent { + if tokenID, err := strconv.Atoi(key); err == nil { + return tokenID + } + } + } + return -1 +} + +func NewTokenizer(datadir string) (*Tokenizer, error) { + defaultMaxLen := 512 + defaultPadTokenID := 1 + defaultTokenizerClass := "RobertaTokenizer" + + // check if tokenizer.json exists + tokenizerPath := filepath.Join(datadir, "tokenizer.json") + if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) { + return nil, fmt.Errorf("tokenizer.json not found in %s", datadir) + } + + tk, err := tokenizers.FromFile(tokenizerPath) + if err != nil { + return nil, err + } + + configFile := filepath.Join(datadir, "tokenizer_config.json") + config, err := loadTokenizerConfig(configFile) + if err != nil { + fmt.Println("Warning: Could not load tokenizer config, using default values.") + return &Tokenizer{ + tk: tk, + modelMaxLength: defaultMaxLen, + padTokenID: defaultPadTokenID, + tokenizerClass: defaultTokenizerClass, + }, nil + } + + // Use default values if any required config is missing + // modelMaxLen := 256 + 
diff --git a/pkg/ml/tokenizer.go b/pkg/ml/tokenizer.go
new file mode 100644
index 000000000..4a3cdc6ee
--- /dev/null
+++ b/pkg/ml/tokenizer.go
@@ -0,0 +1,160 @@
+//go:build !no_mlsupport
+
+package ml
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strconv"
+
+	tokenizers "github.com/daulet/tokenizers"
+)
+
+type Tokenizer struct {
+	tk             *tokenizers.Tokenizer
+	modelMaxLength int
+	padTokenID     int
+	tokenizerClass string
+}
+
+type tokenizerConfig struct {
+	ModelMaxLen       int                               `json:"model_max_length"`
+	PadToken          string                            `json:"pad_token"`
+	TokenizerClass    string                            `json:"tokenizer_class"`
+	AddedTokenDecoder map[string]map[string]interface{} `json:"added_tokens_decoder"`
+}
+
+func loadTokenizerConfig(filename string) (*tokenizerConfig, error) {
+	file, err := os.ReadFile(filename)
+	if err != nil {
+		fmt.Println("Error reading tokenizer config file")
+		return nil, err
+	}
+	config := &tokenizerConfig{}
+	if err := json.Unmarshal(file, config); err != nil {
+		fmt.Println("Error unmarshalling tokenizer config")
+		return nil, err
+	}
+	return config, nil
+}
+
+func findTokenID(tokens map[string]map[string]interface{}, tokenContent string) int {
+	for key, value := range tokens {
+		if content, ok := value["content"]; ok && content == tokenContent {
+			if tokenID, err := strconv.Atoi(key); err == nil {
+				return tokenID
+			}
+		}
+	}
+	return -1
+}
+
+func NewTokenizer(datadir string) (*Tokenizer, error) {
+	defaultMaxLen := 512
+	defaultPadTokenID := 1
+	defaultTokenizerClass := "RobertaTokenizer"
+
+	// check if tokenizer.json exists
+	tokenizerPath := filepath.Join(datadir, "tokenizer.json")
+	if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) {
+		return nil, fmt.Errorf("tokenizer.json not found in %s", datadir)
+	}
+
+	tk, err := tokenizers.FromFile(tokenizerPath)
+	if err != nil {
+		return nil, err
+	}
+
+	configFile := filepath.Join(datadir, "tokenizer_config.json")
+	config, err := loadTokenizerConfig(configFile)
+	if err != nil {
+		fmt.Println("Warning: Could not load tokenizer config, using default values.")
+		return &Tokenizer{
+			tk:             tk,
+			modelMaxLength: defaultMaxLen,
+			padTokenID:     defaultPadTokenID,
+			tokenizerClass: defaultTokenizerClass,
+		}, nil
+	}
+
+	// Use default values if any required config value is missing
+	modelMaxLen := config.ModelMaxLen
+	if modelMaxLen == 0 {
+		modelMaxLen = defaultMaxLen
+	}
+
+	padTokenID := findTokenID(config.AddedTokenDecoder, config.PadToken)
+	if padTokenID == -1 {
+		padTokenID = defaultPadTokenID
+	}
+
+	tokenizerClass := config.TokenizerClass
+	if tokenizerClass == "" {
+		tokenizerClass = defaultTokenizerClass
+	}
+
+	return &Tokenizer{
+		tk:             tk,
+		modelMaxLength: modelMaxLen,
+		padTokenID:     padTokenID,
+		tokenizerClass: tokenizerClass,
+	}, nil
+}
+
+type EncodeOptions struct {
+	AddSpecialTokens    bool
+	PadToMaxLength      bool
+	ReturnAttentionMask bool
+	Truncate            bool
+}
+
+func (t *Tokenizer) Encode(text string, options EncodeOptions) ([]int64, []string, []int64, error) {
+	if t.tk == nil {
+		return nil, nil, nil, fmt.Errorf("tokenizer is not initialized")
+	}
+
+	ids, tokens := t.tk.Encode(text, options.AddSpecialTokens)
+
+	// Truncate to max length (right truncation)
+	if len(ids) > t.modelMaxLength && options.Truncate {
+		ids = ids[:t.modelMaxLength]
+		tokens = tokens[:t.modelMaxLength]
+	}
+
+	// convert []uint32 to []int64
+	int64Ids := make([]int64, len(ids))
+	for i, id := range ids {
+		int64Ids[i] = int64(id)
+	}
+
+	// Pad to max length
+	if options.PadToMaxLength && len(int64Ids) < t.modelMaxLength {
+		paddingLength := t.modelMaxLength - len(int64Ids)
+		for i := 0; i < paddingLength; i++ {
+			int64Ids = append(int64Ids, int64(t.padTokenID))
+			tokens = append(tokens, "<pad>")
+		}
+	}
+
+	// Create the attention mask: 1 for real tokens, 0 for padding
+	var attentionMask []int64
+	if options.ReturnAttentionMask {
+		attentionMask = make([]int64, len(int64Ids))
+		for i := range attentionMask {
+			if int64Ids[i] != int64(t.padTokenID) {
+				attentionMask[i] = 1
+			} else {
+				attentionMask[i] = 0
+			}
+		}
+	}
+
+	return int64Ids, tokens, attentionMask, nil
+}
+
+func (t *Tokenizer) Close() {
+	t.tk.Close()
+}
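A sketch of the Encode contract as the pipeline uses it, assuming the two fixture files above sit in ./tests; ids, tokens and mask line up index by index:

```go
package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/ml"
)

func main() {
	tok, err := ml.NewTokenizer("tests") // needs tokenizer.json and tokenizer_config.json
	if err != nil {
		log.Fatal(err)
	}
	defer tok.Close()

	ids, tokens, mask, err := tok.Encode("GET /index.php", ml.EncodeOptions{
		AddSpecialTokens:    true,
		PadToMaxLength:      true, // pads ids with the pad token id up to model_max_length (512 here)
		ReturnAttentionMask: true, // 1 for real tokens, 0 for padding
		Truncate:            true,
	})
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println(len(ids), len(tokens), len(mask)) // all 512 once padded
}
```
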
diff --git a/pkg/ml/tokenizer_test.go b/pkg/ml/tokenizer_test.go
new file mode 100644
index 000000000..ade3c6226
--- /dev/null
+++ b/pkg/ml/tokenizer_test.go
@@ -0,0 +1,105 @@
+//go:build !no_mlsupport
+
+package ml
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestTokenize(t *testing.T) {
+	tests := []struct {
+		name             string
+		inputText        string
+		tokenizerPath    string
+		encodeOptions    EncodeOptions
+		expectedIds      []int64
+		expectedTokens   []string
+		expectedMask     []int64
+		expectTruncation bool
+	}{
+		{
+			name:          "Tokenize 'this is some text'",
+			inputText:     "this is some text",
+			tokenizerPath: "tests/small-champion-model",
+			encodeOptions: EncodeOptions{
+				AddSpecialTokens:    true,
+				PadToMaxLength:      false,
+				ReturnAttentionMask: true,
+				Truncate:            true,
+			},
+			expectedIds:    []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 268, 488, 2},
+			expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "te", "xt", "</s>"},
+			expectedMask:   []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+		},
+		{
+			name:          "Tokenize 'this is some new texts'",
+			inputText:     "this is some new texts",
+			tokenizerPath: "tests/small-champion-model",
+			encodeOptions: EncodeOptions{
+				AddSpecialTokens:    true,
+				PadToMaxLength:      false,
+				ReturnAttentionMask: true,
+				Truncate:            true,
+			},
+			expectedIds:    []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 1959, 225, 268, 488, 87, 2},
+			expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "new", "Ġ", "te", "xt", "s", "</s>"},
+			expectedMask:   []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
+		},
+	}
+
+	tokenizer, err := NewTokenizer("tests")
+	if err != nil {
+		t.Errorf("NewTokenizer returned error: %v", err)
+		return
+	}
+	defer tokenizer.Close()
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			ids, tokens, attentionMask, err := tokenizer.Encode(tt.inputText, tt.encodeOptions)
+			if err != nil {
+				t.Errorf("Encode returned error: %v", err)
+			}
+
+			assert.Equal(t, tt.expectedIds, ids, "IDs do not match")
+			assert.Equal(t, tt.expectedTokens, tokens, "Tokens do not match")
+			if tt.encodeOptions.ReturnAttentionMask {
+				assert.Equal(t, tt.expectedMask, attentionMask, "Attention mask does not match")
+			}
+		})
+	}
+}
+
+func TestTokenizeLongString(t *testing.T) {
+	var builder strings.Builder
+	for i := 0; i < 1024; i++ {
+		builder.WriteString("a")
+	}
+	longString := builder.String()
+
+	tokenizer, err := NewTokenizer("tests")
+	if err != nil {
+		t.Errorf("NewTokenizer returned error: %v", err)
+		return
+	}
+	defer tokenizer.Close()
+
+	encodeOptions := EncodeOptions{
+		AddSpecialTokens:    true,
+		PadToMaxLength:      false,
+		ReturnAttentionMask: true,
+		Truncate:            true,
+	}
+
+	ids, tokens, _, err := tokenizer.Encode(longString, encodeOptions)
+	if err != nil {
+		t.Errorf("Encode returned error: %v", err)
+	}
+
+	assert.Equal(t, 512, len(ids), "IDs length does not match for long string with truncation")
+	assert.Equal(t, 512, len(tokens), "Tokens length does not match for long string with truncation")
+}