Mirror of https://github.com/crowdsecurity/crowdsec.git, synced 2025-05-10 20:05:55 +02:00

Commit f6bd0a56e1: merge 1300906ac7 into 505ad36dfd

19 changed files with 4811 additions and 15 deletions
Dockerfile

@@ -1,6 +1,23 @@
+FROM rust:1.70.0-bullseye AS rust_build
+
+WORKDIR /
+
+RUN apt-get update && \
+    apt-get install -y -q \
+    build-essential \
+    curl \
+    git \
+    make
+
+RUN git clone https://github.com/daulet/tokenizers.git /tokenizer && \
+    cd /tokenizer && \
+    cargo build --release && \
+    cp target/release/libtokenizers.a /tokenizer/libtokenizers.a
+
 FROM docker.io/golang:1.24-bookworm AS build

 ARG BUILD_VERSION
+ARG ONNXRUNTIME_VERSION=1.18.1

 WORKDIR /go/src/crowdsec

@@ -24,7 +41,36 @@ RUN apt-get update && \

 COPY . .

-RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
+COPY --from=rust_build /tokenizer/libtokenizers.a /usr/local/lib/
+
+# INSTALL ONNXRUNTIME
+RUN cd /tmp && \
+    wget -O onnxruntime.tgz https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION}.tgz && \
+    tar -C /tmp -xvf onnxruntime.tgz && \
+    mv onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION} onnxruntime && \
+    rm -rf onnxruntime.tgz && \
+    cp -R onnxruntime/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib && \
+    cp onnxruntime/include/*.h /usr/local/include && \
+    rm -rf onnxruntime
+
+RUN ln -s /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib/libonnxruntime.so
+
+RUN ls -la /usr/local/include
+RUN ls -la /usr/local/lib
+
+RUN ldconfig
+
+# Test that linking works with a simple program
+RUN echo "#include <onnxruntime_c_api.h>" > test.c && \
+    echo "int main() { return 0; }" >> test.c && \
+    gcc test.c -L/usr/local/lib -lonnxruntime -o test_executable && ./test_executable
+
+RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=0 \
+    CGO_CFLAGS="-D_LARGEFILE64_SOURCE -I/usr/local/include" \
+    CGO_CPPFLAGS="-I/usr/local/include" \
+    CGO_LDFLAGS="-L/usr/local/lib -lstdc++ -lonnxruntime /usr/local/lib/libtokenizers.a -ldl -lm" \
+    LIBRARY_PATH="/usr/local/lib" \
+    LD_LIBRARY_PATH="/usr/local/lib" && \
     cd crowdsec-v* && \
     ./wizard.sh --docker-mode && \
     cd - >/dev/null && \

@@ -37,6 +83,8 @@ RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
 FROM docker.io/debian:bookworm-slim AS slim

+ARG ONNXRUNTIME_VERSION=1.18.1
+
 ENV DEBIAN_FRONTEND=noninteractive
 ENV DEBCONF_NOWARNINGS="yes"

@@ -58,6 +106,15 @@ COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/l
 COPY --from=build /etc/crowdsec /staging/etc/crowdsec
 COPY --from=build /go/src/crowdsec/docker/docker_start.sh /
 COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml
+
+# Note: copied in since the ML libraries cannot be linked statically yet
+COPY --from=build /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION}
+COPY --from=build /usr/local/lib/libtokenizers.a /usr/lib/libtokenizers.a
+RUN ln -s /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so
+COPY --from=build /usr/local/lib/libre2.* /usr/lib/
+
+RUN ls -la /usr/lib

 RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml && \
     yq eval -i ".plugin_config.group = \"nogroup\"" /staging/etc/crowdsec/config.yaml
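
These stages slot into the existing image build, so producing an image is unchanged; a hedged sketch (the tag is hypothetical), keeping in mind that the install step above hardcodes the linux-aarch64 onnxruntime artifact:

    docker build -t crowdsec-ml --build-arg ONNXRUNTIME_VERSION=1.18.1 .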
Makefile (69 changed lines)
@@ -60,11 +60,9 @@ bool = $(if $(filter $(call lc, $1),1 yes true),1,0)

 #--------------------------------------
 #
-# Define MAKE_FLAGS and LD_OPTS for the sub-makefiles in cmd/
+# Define LD_OPTS for the sub-makefiles in cmd/
 #

-MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)"
-
 LD_OPTS_VARS= \
 	-X 'github.com/crowdsecurity/go-cs-lib/version.Version=$(BUILD_VERSION)' \
 	-X 'github.com/crowdsecurity/go-cs-lib/version.BuildDate=$(BUILD_TIMESTAMP)' \
@@ -116,6 +114,14 @@ ifneq (,$(RE2_CHECK))
 endif
 endif

+#--------------------------------------
+#
+# List of required build-time dependencies
+
+DEPS_DIR := $(CURDIR)/build/deps
+
+DEPS_FILES :=
+
 #--------------------------------------
 #
 # Handle optional components and build profiles, to save space on the final binaries.
@@ -142,7 +148,8 @@ COMPONENTS := \
 	datasource_s3 \
 	datasource_syslog \
 	datasource_wineventlog \
-	cscli_setup
+	cscli_setup \
+	mlsupport

 comma := ,
 space := $(empty) $(empty)
@@ -152,6 +159,9 @@ space := $(empty) $(empty)
 # keep only datasource-file
 EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,$(COMPONENTS)))

+# ml-support requires pre-built static libraries and weighs 20MB
+EXCLUDE_DEFAULT := mlsupport
+
 # example
 # EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3
@@ -160,8 +170,10 @@ BUILD_PROFILE ?= default
 # Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set
 ifeq ($(BUILD_PROFILE),minimal)
 EXCLUDE ?= $(EXCLUDE_MINIMAL)
-else ifneq ($(BUILD_PROFILE),default)
-$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default)
+else ifeq ($(BUILD_PROFILE),default)
+EXCLUDE ?= $(EXCLUDE_DEFAULT)
+else ifneq ($(BUILD_PROFILE),full)
+$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default, full)
 endif

 # Create list of excluded components from the EXCLUDE variable
@@ -179,6 +191,24 @@ ifneq ($(COMPONENT_TAGS),)
 GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS))
 endif

+ifeq ($(filter mlsupport,$(EXCLUDE_LIST)),)
+$(info mlsupport is included)
+# Set additional variables when mlsupport is included
+ifneq ($(call bool,$(BUILD_RE2_WASM)),1)
+$(error for now, the flag BUILD_RE2_WASM is required for mlsupport)
+endif
+CGO_CPPFLAGS := -I$(DEPS_DIR)/src/onnxruntime/include/onnxruntime/core/session
+CGO_LDFLAGS := -L$(DEPS_DIR)/lib -lstdc++ -lonnxruntime -ldl -lm
+LIBRARY_PATH := $(DEPS_DIR)/lib
+DEPS_FILES += $(DEPS_DIR)/lib/libtokenizers.a
+DEPS_FILES += $(DEPS_DIR)/lib/libonnxruntime.a
+DEPS_FILES += $(DEPS_DIR)/src/onnxruntime
+else
+CGO_CPPFLAGS :=
+CGO_LDFLAGS :=
+LIBRARY_PATH :=
+endif
+
 #--------------------------------------

 ifeq ($(call bool,$(BUILD_STATIC)),1)
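
Putting the profile and component hunks together, a build that keeps mlsupport might look like this (a sketch; the full profile leaves EXCLUDE empty, and BUILD_RE2_WASM=1 is enforced by the check above):

    # default profile: mlsupport excluded, no extra dependencies
    make build

    # keep mlsupport (BUILD_RE2_WASM is required for now)
    make build BUILD_PROFILE=full BUILD_RE2_WASM=1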
@@ -208,7 +238,7 @@ endif
 #--------------------------------------

 .PHONY: build
-build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins
+build: build-info download-deps crowdsec cscli plugins ## Build crowdsec, cscli and plugins

 .PHONY: build-info
 build-info: ## Print build information
@@ -235,6 +265,29 @@ endif
 .PHONY: all
 all: clean test build ## Clean, test and build (requires localstack)

+.PHONY: download-deps
+download-deps: $(DEPS_FILES)
+
+$(DEPS_DIR)/lib/libtokenizers.a:
+	curl --fail -L --output $@ --create-dirs \
+		https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libtokenizers.a
+
+$(DEPS_DIR)/lib/libonnxruntime.a:
+	curl --fail -L --output $@ --create-dirs \
+		https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libonnxruntime.a
+
+$(DEPS_DIR)/src/onnxruntime:
+	git clone --depth 1 https://github.com/microsoft/onnxruntime $(DEPS_DIR)/src/onnxruntime -b v1.19.2
+
+# Full list of flags that are passed down to the sub-makefiles in cmd/
+MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)" CGO_CPPFLAGS="$(CGO_CPPFLAGS)" LIBRARY_PATH="$(LIBRARY_PATH)"
+
+.PHONY: clean-deps
+clean-deps:
+	@$(RM) -r $(DEPS_DIR)
+
 .PHONY: plugins
 plugins: ## Build notification plugins
 	@$(foreach plugin,$(PLUGINS), \
@@ -260,7 +313,7 @@ clean-rpm:
 	@$(RM) -r rpm/SRPMS

 .PHONY: clean
-clean: clean-debian clean-rpm bats-clean ## Remove build artifacts
+clean: clean-debian clean-rpm clean-deps bats-clean ## Remove build artifacts
 	@$(MAKE) -C $(CROWDSEC_FOLDER) clean $(MAKE_FLAGS)
 	@$(MAKE) -C $(CSCLI_FOLDER) clean $(MAKE_FLAGS)
 	@$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR)
go.mod (2 changed lines)
@@ -25,8 +25,10 @@ require (
 	github.com/creack/pty v1.1.21 // indirect
 	github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26
 	github.com/crowdsecurity/go-cs-lib v0.0.19
+	github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4
 	github.com/crowdsecurity/grokky v0.2.2
 	github.com/crowdsecurity/machineid v1.0.2
+	github.com/daulet/tokenizers v0.9.0
 	github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
 	github.com/dghubble/sling v1.4.2
 	github.com/distribution/reference v0.6.0 // indirect
go.sum (4 changed lines)
@@ -113,10 +113,14 @@ github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen
 github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
 github.com/crowdsecurity/go-cs-lib v0.0.19 h1:wA4O8hGrEntTGn7eZTJqnQ3mrAje5JvQAj8DNbe5IZg=
 github.com/crowdsecurity/go-cs-lib v0.0.19/go.mod h1:hz2FOHFXc0vWzH78uxo2VebtPQ9Snkbdzy3TMA20tVQ=
+github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4 h1:CwzISIxoKp0dJLrJJIlhvQPuzirpS9QH07guxK5LIeg=
+github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4/go.mod h1:YfyL16lx2wA8Z6t/TG1x1/FBngOIpuCuo7nM/FSuP54=
 github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4=
 github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
 github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
 github.com/crowdsecurity/machineid v1.0.2/go.mod h1:XWUSlnS0R0+u/JK5ulidwlbceNT3ZOCKteoVQEn6Luo=
+github.com/daulet/tokenizers v0.9.0 h1:PSjFUGeuhqb3C0GKP9hdvtHvJ6L1AZceV+0nYGACtCk=
+github.com/daulet/tokenizers v0.9.0/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs=
 github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
pkg/cwversion/component (Built map)

@@ -19,10 +19,11 @@ var Built = map[string]bool{
 	"datasource_loki":         false,
 	"datasource_s3":           false,
 	"datasource_syslog":       false,
-	"datasource_wineventlog":  false,
 	"datasource_victorialogs": false,
-	"datasource_http":         false,
+	"datasource_wineventlog":  false,
+	"datasource_http":         false,
 	"cscli_setup":             false,
+	"mlsupport":               false,
 }

 func Register(name string) {
pkg/exprhelpers/anomalydetection.go (new file, 57 lines)
@@ -0,0 +1,57 @@
//go:build !no_mlsupport

package exprhelpers

import (
	"errors"
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/cwversion/component"
	"github.com/crowdsecurity/crowdsec/pkg/ml"
)

var robertaInferencePipeline *ml.RobertaClassificationInferencePipeline

//nolint:gochecknoinits
func init() {
	component.Register("mlsupport")
}

func InitRobertaInferencePipeline(modelBundlePath string) error {
	var err error

	fmt.Println("Initializing Roberta Inference Pipeline")

	robertaInferencePipeline, err = ml.NewRobertaInferencePipeline(modelBundlePath)
	if err != nil {
		return err
	}
	if robertaInferencePipeline == nil {
		fmt.Println("Failed to initialize Roberta Inference Pipeline")
	}

	return nil
}

func IsAnomalous(params ...any) (any, error) {
	if len(params) < 2 {
		return nil, errors.New("IsAnomalous expects two parameters")
	}

	verb, ok1 := params[0].(string)
	httpPath, ok2 := params[1].(string)

	if !ok1 || !ok2 {
		return nil, errors.New("parameters must be strings")
	}

	text := verb + " " + httpPath
	log.Println("Verb : ", verb)
	log.Println("HTTP Path : ", httpPath)
	log.Println("Text to analyze for Anomaly: ", text)

	if robertaInferencePipeline == nil {
		return nil, errors.New("Roberta Inference Pipeline not properly initialized")
	}

	result, err := robertaInferencePipeline.PredictLabel(text)
	if err != nil {
		return nil, err
	}

	return result == 1, nil
}
pkg/exprhelpers/anomalydetection_stub.go (new file, 29 lines)
@@ -0,0 +1,29 @@
//go:build no_mlsupport

package exprhelpers

import (
	"errors"
	"fmt"
)

var robertaInferencePipeline *RobertaInferencePipelineStub

type RobertaInferencePipelineStub struct{}

func InitRobertaInferencePipeline(modelBundlePath string) error {
	fmt.Println("Stub: InitRobertaInferencePipeline called with no ML support")
	return nil
}

func IsAnomalous(params ...any) (any, error) {
	if len(params) < 2 {
		return nil, errors.New("IsAnomalous expects two parameters")
	}

	_, ok1 := params[0].(string)
	_, ok2 := params[1].(string)

	if !ok1 || !ok2 {
		return nil, errors.New("parameters must be strings")
	}

	fmt.Println("Stub: IsAnomalous called with no ML support")

	return false, nil
}
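
The stub is compiled in whenever the no_mlsupport build tag is set, which is what excluding the component does; for a quick check outside the Makefile, something like:

    go build -tags no_mlsupport ./...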
pkg/exprhelpers/anomalydetection_test.go (new file, 54 lines)
@@ -0,0 +1,54 @@
//go:build !no_mlsupport

package exprhelpers

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestAnomalyDetection(t *testing.T) {
	tests := []struct {
		name         string
		params       []any
		expectResult any
		err          error
	}{
		{
			name:         "Empty verb, non-path text",
			params:       []any{"", "hello"},
			expectResult: false,
			err:          nil,
		},
		{
			name:         "Empty verb",
			params:       []any{"", "/somepath"},
			expectResult: false,
			err:          nil,
		},
		{
			name:         "Empty path",
			params:       []any{"GET", ""},
			expectResult: true,
			err:          nil,
		},
		{
			name:         "Valid verb and path",
			params:       []any{"GET", "/somepath"},
			expectResult: false,
			err:          nil,
		},
	}

	tarFilePath := "tests/anomaly_detection_bundle_test.tar"

	if err := InitRobertaInferencePipeline(tarFilePath); err != nil {
		t.Fatalf("failed to initialize RobertaInferencePipeline: %v", err)
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, _ := IsAnomalous(tt.params...)
			assert.Equal(t, tt.expectResult, result)
		})
	}
}
|
@ -509,6 +509,13 @@ var exprFuncs = []exprCustomFunc{
|
|||
new(func(string) *net.IPNet),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "IsAnomalous",
|
||||
function: IsAnomalous,
|
||||
signature: []interface{}{
|
||||
new(func(string, string) (bool, error)),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "JA4H",
|
||||
function: JA4H,
|
||||
|
@ -518,6 +525,6 @@ var exprFuncs = []exprCustomFunc{
|
|||
},
|
||||
}
|
||||
|
||||
//go 1.20 "CutPrefix": strings.CutPrefix,
|
||||
//go 1.20 "CutPrefix": strings.CutPrefix,
|
||||
//go 1.20 "CutSuffix": strings.CutSuffix,
|
||||
//"Cut": strings.Cut, -> returns more than 2 values, not supported by expr
|
||||
|
|
|
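
With this entry registered, expressions can call the helper directly, e.g. IsAnomalous(evt.Parsed.verb, evt.Parsed.request) — the event field names here are hypothetical and depend on the parser that populated the event.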
pkg/exprhelpers (data file helpers)

@@ -38,9 +38,10 @@ import (
 )

 var (
-	dataFile      map[string][]string
-	dataFileRegex map[string][]*regexp.Regexp
-	dataFileRe2   map[string][]*re2.Regexp
+	dataFile            map[string][]string
+	dataFileRegex       map[string][]*regexp.Regexp
+	dataFileRe2         map[string][]*re2.Regexp
+	mlRobertaModelFiles map[string]struct{}
 )

 // This is used to (optionally) cache regexp results for RegexpInFile operations

@@ -125,6 +126,7 @@ func Init(databaseClient *database.Client) error {
 	dataFile = make(map[string][]string)
 	dataFileRegex = make(map[string][]*regexp.Regexp)
 	dataFileRe2 = make(map[string][]*re2.Regexp)
+	mlRobertaModelFiles = make(map[string]struct{})
 	dbClient = databaseClient

 	XMLCacheInit()

@@ -205,6 +207,16 @@ func FileInit(fileFolder string, filename string, fileType string) error {

 	filepath := filepath.Join(fileFolder, filename)

+	if fileType == "ml_roberta_model" {
+		err := InitRobertaInferencePipeline(filepath)
+		if err != nil {
+			log.Errorf("unable to init roberta model : %s", err)
+			return err
+		}
+		mlRobertaModelFiles[filename] = struct{}{}
+		return nil
+	}
+
 	file, err := os.Open(filepath)
 	if err != nil {
 		return err
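
For context, data files reach FileInit through a hub item's data section; a hedged sketch of how a model bundle might be declared with the new type (URL and file name hypothetical):

    data:
      - source_url: https://example.com/anomaly_detection_bundle.tar
        dest_file: anomaly_detection_bundle.tar
        type: ml_roberta_model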
@@ -300,6 +312,8 @@ func existsInFileMaps(filename string, ftype string) (bool, error) {
 		}
 	case "string":
 		_, ok = dataFile[filename]
+	case "ml_roberta_model":
+		// model bundles are tracked in their own map, filled by FileInit
+		_, ok = mlRobertaModelFiles[filename]
 	default:
 		err = fmt.Errorf("unknown data type '%s' for : '%s'", ftype, filename)
 	}
pkg/exprhelpers/tests/anomaly_detection_bundle_test.tar (new binary file; not shown)
pkg/ml/onnx.go (new file, 101 lines)
@@ -0,0 +1,101 @@
//go:build !no_mlsupport

package ml

import (
	"fmt"

	onnxruntime "github.com/crowdsecurity/go-onnxruntime"
)

type OrtSession struct {
	ORTSession        *onnxruntime.ORTSession
	ORTEnv            *onnxruntime.ORTEnv
	ORTSessionOptions *onnxruntime.ORTSessionOptions
}

func NewOrtSession(modelPath string) (*OrtSession, error) {
	ortEnv := onnxruntime.NewORTEnv(onnxruntime.ORT_LOGGING_LEVEL_ERROR, "development")
	if ortEnv == nil {
		return nil, fmt.Errorf("failed to create ORT environment")
	}

	ortSessionOptions := onnxruntime.NewORTSessionOptions()
	if ortSessionOptions == nil {
		ortEnv.Close()
		return nil, fmt.Errorf("failed to create ORT session options")
	}

	fmt.Println("Creating ORT session from model path:", modelPath)

	session, err := onnxruntime.NewORTSession(ortEnv, modelPath, ortSessionOptions)
	if err != nil {
		fmt.Println("Error creating ORT session")
		ortEnv.Close()
		ortSessionOptions.Close()
		return nil, err
	}

	fmt.Println("Done creating ORT session")

	return &OrtSession{
		ORTSession:        session,
		ORTEnv:            ortEnv,
		ORTSessionOptions: ortSessionOptions,
	}, nil
}

func (ort *OrtSession) Predict(inputs []onnxruntime.TensorValue) ([]onnxruntime.TensorValue, error) {
	res, err := ort.ORTSession.Predict(inputs)
	if err != nil {
		return nil, err
	}

	return res, nil
}

func (ort *OrtSession) PredictLabel(inputIds []onnxruntime.TensorValue) (int, error) {
	res, err := ort.Predict(inputIds)
	if err != nil {
		return 0, err
	}

	label, err := PredictionToLabel(res)
	if err != nil {
		return 0, err
	}

	return label, nil
}

func GetTensorValue(input []int64, shape []int64) onnxruntime.TensorValue {
	return onnxruntime.TensorValue{
		Shape: shape,
		Value: input,
	}
}

// PredictionToLabel returns the argmax over the single output tensor,
// assuming a [1 2] shape and float32 values.
func PredictionToLabel(res []onnxruntime.TensorValue) (int, error) {
	if len(res) != 1 {
		return 0, fmt.Errorf("expected one output tensor, got %d", len(res))
	}

	outputTensor := res[0]

	maxIndex := 0
	maxProb := outputTensor.Value.([]float32)[0]

	for i, value := range outputTensor.Value.([]float32) {
		if value > maxProb {
			maxProb = value
			maxIndex = i
		}
	}

	return maxIndex, nil
}

func (ort *OrtSession) Close() {
	ort.ORTSession.Close()
	ort.ORTEnv.Close()
	ort.ORTSessionOptions.Close()
}
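
As a sanity check of the argmax mapping, a hypothetical example test (the TensorValue literal mirrors what GetTensorValue builds):

    func ExamplePredictionToLabel() {
    	// a [1 2] output tensor with logits {0.2, 1.7}: index 1 wins
    	label, _ := PredictionToLabel([]onnxruntime.TensorValue{
    		{Shape: []int64{1, 2}, Value: []float32{0.2, 1.7}},
    	})
    	fmt.Println(label)
    	// Output: 1
    }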
pkg/ml/robertapipeline.go (new file, 153 lines)
@@ -0,0 +1,153 @@
//go:build !no_mlsupport

package ml

import (
	"archive/tar"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	onnxruntime "github.com/crowdsecurity/go-onnxruntime"
)

type RobertaClassificationInferencePipeline struct {
	inputShape []int64
	tokenizer  *Tokenizer
	ortSession *OrtSession
}

var bundleFileList = []string{
	"model.onnx",
	"tokenizer.json",
	"tokenizer_config.json",
}

func NewRobertaInferencePipeline(bundleFilePath string) (*RobertaClassificationInferencePipeline, error) {
	tempDir, err := os.MkdirTemp("", "crowdsec_roberta_model_assets")
	if err != nil {
		return nil, fmt.Errorf("could not create temp directory: %v", err)
	}

	if err := extractTarFile(bundleFilePath, tempDir); err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to extract tar file: %v", err)
	}

	// the bundle is expected to contain a top-level directory named after the archive
	outputDir := filepath.Join(tempDir, strings.Split(filepath.Base(bundleFilePath), ".tar")[0])

	for _, file := range bundleFileList {
		if _, err := os.Stat(filepath.Join(outputDir, file)); os.IsNotExist(err) {
			os.RemoveAll(tempDir)
			return nil, fmt.Errorf("missing required file: %s, in %s", file, outputDir)
		}
	}

	ortSession, err := NewOrtSession(filepath.Join(outputDir, "model.onnx"))
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, err
	}

	tokenizer, err := NewTokenizer(outputDir)
	if err != nil {
		ortSession.Close()
		os.RemoveAll(tempDir)
		return nil, err
	}

	inputShape := []int64{1, int64(tokenizer.modelMaxLength)}

	if err := os.RemoveAll(tempDir); err != nil {
		ortSession.Close()
		tokenizer.Close()
		return nil, fmt.Errorf("could not remove temp directory: %v", err)
	}

	return &RobertaClassificationInferencePipeline{
		inputShape: inputShape,
		tokenizer:  tokenizer,
		ortSession: ortSession,
	}, nil
}

func (r *RobertaClassificationInferencePipeline) Close() {
	r.tokenizer.Close()
	r.ortSession.Close()
}

func (pipeline *RobertaClassificationInferencePipeline) PredictLabel(text string) (int, error) {
	options := EncodeOptions{
		AddSpecialTokens:    true,
		PadToMaxLength:      true, // TODO: the ONNX input format leads to a segfault without this
		ReturnAttentionMask: true,
		Truncate:            true,
	}

	ids, _, attentionMask, err := pipeline.tokenizer.Encode(text, options)
	if err != nil {
		return 0, fmt.Errorf("error encoding text: %w", err)
	}

	label, err := pipeline.ortSession.PredictLabel([]onnxruntime.TensorValue{
		GetTensorValue(ids, pipeline.inputShape),
		GetTensorValue(attentionMask, pipeline.inputShape),
	})
	if err != nil {
		return 0, err
	}

	return label, nil
}

func extractTarFile(tarFilePath, outputDir string) error {
	file, err := os.Open(tarFilePath)
	if err != nil {
		return fmt.Errorf("could not open tar file: %v", err)
	}
	defer file.Close()

	tarReader := tar.NewReader(file)

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading tar file: %v", err)
		}

		targetPath := filepath.Join(outputDir, header.Name)

		// guard against path traversal ("zip slip") via crafted entry names
		if !strings.HasPrefix(targetPath, filepath.Clean(outputDir)+string(os.PathSeparator)) {
			return fmt.Errorf("invalid path in tar archive: %s", header.Name)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(targetPath, 0755); err != nil {
				return fmt.Errorf("could not create directory %s: %v", targetPath, err)
			}
		case tar.TypeReg:
			if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
				return fmt.Errorf("could not create directory %s: %v", filepath.Dir(targetPath), err)
			}

			outFile, err := os.Create(targetPath)
			if err != nil {
				return fmt.Errorf("could not create file %s: %v", targetPath, err)
			}

			if _, err := io.Copy(outFile, tarReader); err != nil {
				outFile.Close()
				return fmt.Errorf("could not copy data to file %s: %v", targetPath, err)
			}
			outFile.Close()
		default:
			fmt.Printf("Unsupported file type in tar: %s\n", header.Name)
		}
	}
	return nil
}
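
Note the layout NewRobertaInferencePipeline expects: the required files must sit in a top-level directory named after the archive itself, since outputDir is derived from the tar's base name. A hedged sketch of packaging a valid bundle:

    mkdir anomaly_detection_bundle_test
    cp model.onnx tokenizer.json tokenizer_config.json anomaly_detection_bundle_test/
    tar -cf anomaly_detection_bundle_test.tar anomaly_detection_bundle_test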
pkg/ml/robertapipeline_test.go (new file, 112 lines)
@@ -0,0 +1,112 @@
//go:build !no_mlsupport

package ml

import (
	"fmt"
	"log"
	"runtime"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func BenchmarkPredictLabel(b *testing.B) {
	log.Println("Starting benchmark for PredictLabel")

	tarFilePath := "tests/anomaly_detection_bundle_test.tar"

	pipeline, err := NewRobertaInferencePipeline(tarFilePath)
	if err != nil {
		b.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
	}
	defer pipeline.Close()

	text := "POST /"

	// read memory stats before the timed loop so both time and
	// allocations are measured over the same b.N predictions
	var memStart runtime.MemStats
	runtime.ReadMemStats(&memStart)

	b.ResetTimer()
	startTime := time.Now()

	for n := 0; n < b.N; n++ {
		if n%1000 == 0 {
			log.Printf("Running iteration %d", n)
		}

		_, err := pipeline.PredictLabel(text)
		if err != nil {
			b.Fatalf("Prediction failed: %v", err)
		}
	}

	b.StopTimer()

	var memEnd runtime.MemStats
	runtime.ReadMemStats(&memEnd)

	totalAlloc := memEnd.TotalAlloc - memStart.TotalAlloc
	allocPerOp := totalAlloc / uint64(b.N)

	totalTime := time.Since(startTime)

	fmt.Printf("Benchmark Results:\n")
	fmt.Printf("  Total time: %s\n", totalTime)
	fmt.Printf("  Average time per operation: %s\n", totalTime/time.Duration(b.N))
	fmt.Printf("  Operations per second: %.2f ops/s\n", float64(b.N)/totalTime.Seconds())
	fmt.Printf("  Memory allocated per operation: %d bytes\n", allocPerOp)
	fmt.Printf("  Total memory allocated: %d bytes\n", totalAlloc)
}

func TestPredictLabel(t *testing.T) {
	tests := []struct {
		name       string
		text       string
		expectedID int
		label      int
	}{
		{
			name:       "Malicious request",
			text:       "GET /lib/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php?",
			expectedID: 0,
			label:      1,
		},
		{
			name:       "Safe request",
			text:       "GET /online/_ui/responsive/theme-miglog/img/header+Navigation/icon-delivery.svg",
			expectedID: 0,
			label:      0,
		},
	}

	tarFilePath := "tests/anomaly_detection_bundle_test.tar"

	pipeline, err := NewRobertaInferencePipeline(tarFilePath)
	if err != nil {
		t.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
	}
	defer pipeline.Close()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			prediction, err := pipeline.PredictLabel(tt.text)
			if err != nil {
				t.Errorf("PredictLabel returned error: %v", err)
			}

			assert.Equal(t, tt.label, prediction, "Predicted label does not match the expected label")
		})
	}
}
pkg/ml/tests/anomaly_detection_bundle_test.tar (new binary file; not shown)
pkg/ml/tests/tokenizer.json (new file, 3830 lines; diff too large to display)
pkg/ml/tests/tokenizer_config.json (new file, 57 lines)
@@ -0,0 +1,57 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<pad>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "3": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "4": {
      "content": "<mask>",
      "lstrip": true,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": true,
  "cls_token": "<s>",
  "eos_token": "</s>",
  "errors": "replace",
  "mask_token": "<mask>",
  "model_max_length": 512,
  "pad_token": "<pad>",
  "sep_token": "</s>",
  "tokenizer_class": "RobertaTokenizer",
  "trim_offsets": true,
  "unk_token": "<unk>"
}
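
Given this config, findTokenID in pkg/ml/tokenizer.go resolves the pad_token "<pad>" to ID 1 via added_tokens_decoder, and model_max_length of 512 becomes the pipeline's [1, 512] input shape.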
pkg/ml/tokenizer.go (new file, 160 lines)
@@ -0,0 +1,160 @@
//go:build !no_mlsupport

package ml

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strconv"

	tokenizers "github.com/daulet/tokenizers"
)

type Tokenizer struct {
	tk             *tokenizers.Tokenizer
	modelMaxLength int
	padTokenID     int
	tokenizerClass string
}

type tokenizerConfig struct {
	ModelMaxLen       int                               `json:"model_max_length"`
	PadToken          string                            `json:"pad_token"`
	TokenizerClass    string                            `json:"tokenizer_class"`
	AddedTokenDecoder map[string]map[string]interface{} `json:"added_tokens_decoder"`
}

func loadTokenizerConfig(filename string) (*tokenizerConfig, error) {
	file, err := os.ReadFile(filename)
	if err != nil {
		fmt.Println("Error reading tokenizer config file")
		return nil, err
	}
	config := &tokenizerConfig{}
	if err := json.Unmarshal(file, config); err != nil {
		fmt.Println("Error unmarshalling tokenizer config")
		return nil, err
	}
	return config, nil
}

func findTokenID(tokens map[string]map[string]interface{}, tokenContent string) int {
	for key, value := range tokens {
		if content, ok := value["content"]; ok && content == tokenContent {
			if tokenID, err := strconv.Atoi(key); err == nil {
				return tokenID
			}
		}
	}
	return -1
}

func NewTokenizer(datadir string) (*Tokenizer, error) {
	defaultMaxLen := 512
	defaultPadTokenID := 1
	defaultTokenizerClass := "RobertaTokenizer"

	// check that tokenizer.json exists
	tokenizerPath := filepath.Join(datadir, "tokenizer.json")
	if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) {
		return nil, fmt.Errorf("tokenizer.json not found in %s", datadir)
	}

	tk, err := tokenizers.FromFile(tokenizerPath)
	if err != nil {
		return nil, err
	}

	configFile := filepath.Join(datadir, "tokenizer_config.json")
	config, err := loadTokenizerConfig(configFile)
	if err != nil {
		fmt.Println("Warning: Could not load tokenizer config, using default values.")
		return &Tokenizer{
			tk:             tk,
			modelMaxLength: defaultMaxLen,
			padTokenID:     defaultPadTokenID,
			tokenizerClass: defaultTokenizerClass,
		}, nil
	}

	// Use default values if any required config is missing
	modelMaxLen := config.ModelMaxLen
	if modelMaxLen == 0 {
		modelMaxLen = defaultMaxLen
	}

	padTokenID := findTokenID(config.AddedTokenDecoder, config.PadToken)
	if padTokenID == -1 {
		padTokenID = defaultPadTokenID
	}

	tokenizerClass := config.TokenizerClass
	if tokenizerClass == "" {
		tokenizerClass = defaultTokenizerClass
	}

	return &Tokenizer{
		tk:             tk,
		modelMaxLength: modelMaxLen,
		padTokenID:     padTokenID,
		tokenizerClass: tokenizerClass,
	}, nil
}

type EncodeOptions struct {
	AddSpecialTokens    bool
	PadToMaxLength      bool
	ReturnAttentionMask bool
	Truncate            bool
}

func (t *Tokenizer) Encode(text string, options EncodeOptions) ([]int64, []string, []int64, error) {
	if t.tk == nil {
		return nil, nil, nil, fmt.Errorf("tokenizer is not initialized")
	}

	ids, tokens := t.tk.Encode(text, options.AddSpecialTokens)

	// Truncate to max length (right truncation)
	if len(ids) > t.modelMaxLength && options.Truncate {
		ids = ids[:t.modelMaxLength]
		tokens = tokens[:t.modelMaxLength]
	}

	// convert []uint32 IDs to []int64 for the ONNX input tensors
	int64Ids := make([]int64, len(ids))
	for i, id := range ids {
		int64Ids[i] = int64(id)
	}

	// Padding to max length
	if options.PadToMaxLength && len(int64Ids) < t.modelMaxLength {
		paddingLength := t.modelMaxLength - len(int64Ids)
		for i := 0; i < paddingLength; i++ {
			int64Ids = append(int64Ids, int64(t.padTokenID))
			tokens = append(tokens, "<pad>")
		}
	}

	// Creating the attention mask: 1 for real tokens, 0 for padding
	var attentionMask []int64
	if options.ReturnAttentionMask {
		attentionMask = make([]int64, len(int64Ids))
		for i := range attentionMask {
			if int64Ids[i] != int64(t.padTokenID) {
				attentionMask[i] = 1
			} else {
				attentionMask[i] = 0
			}
		}
	}

	return int64Ids, tokens, attentionMask, nil
}

func (t *Tokenizer) Close() {
	t.tk.Close()
}
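
The Roberta pipeline calls Encode with padding enabled (see PredictLabel above); a hedged sketch of that path, assuming the test assets directory:

    tk, err := NewTokenizer("tests")
    if err != nil {
    	panic(err) // illustrative only
    }
    defer tk.Close()

    ids, _, mask, _ := tk.Encode("GET /", EncodeOptions{
    	AddSpecialTokens:    true,
    	PadToMaxLength:      true,
    	ReturnAttentionMask: true,
    	Truncate:            true,
    })
    // len(ids) == 512 == len(mask); mask is 1 for real tokens, 0 for "<pad>" padding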
pkg/ml/tokenizer_test.go (new file, 105 lines)
@@ -0,0 +1,105 @@
//go:build !no_mlsupport

package ml

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestTokenize(t *testing.T) {
	tests := []struct {
		name             string
		inputText        string
		tokenizerPath    string
		encodeOptions    EncodeOptions
		expectedIds      []int64
		expectedTokens   []string
		expectedMask     []int64
		expectTruncation bool
	}{
		{
			name:          "Tokenize 'this is some text'",
			inputText:     "this is some text",
			tokenizerPath: "tests/small-champion-model",
			encodeOptions: EncodeOptions{
				AddSpecialTokens:    true,
				PadToMaxLength:      false,
				ReturnAttentionMask: true,
				Truncate:            true,
			},
			expectedIds:    []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 268, 488, 2},
			expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "te", "xt", "</s>"},
			expectedMask:   []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
		},
		{
			name:          "Tokenize 'this is some new texts'",
			inputText:     "this is some new texts",
			tokenizerPath: "tests/small-champion-model",
			encodeOptions: EncodeOptions{
				AddSpecialTokens:    true,
				PadToMaxLength:      false,
				ReturnAttentionMask: true,
				Truncate:            true,
			},
			expectedIds:    []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 1959, 225, 268, 488, 87, 2},
			expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "new", "Ġ", "te", "xt", "s", "</s>"},
			expectedMask:   []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
		},
	}

	tokenizer, err := NewTokenizer("tests")
	if err != nil {
		t.Errorf("NewTokenizer returned error: %v", err)
		return
	}
	defer tokenizer.Close()

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ids, tokens, attentionMask, err := tokenizer.Encode(tt.inputText, tt.encodeOptions)
			if err != nil {
				t.Errorf("Encode returned error: %v", err)
			}

			assert.Equal(t, tt.expectedIds, ids, "IDs do not match")
			assert.Equal(t, tt.expectedTokens, tokens, "Tokens do not match")
			if tt.encodeOptions.ReturnAttentionMask {
				assert.Equal(t, tt.expectedMask, attentionMask, "Attention mask does not match")
			}
		})
	}
}

func TestTokenizeLongString(t *testing.T) {
	longString := strings.Repeat("a", 1024)

	tokenizer, err := NewTokenizer("tests")
	if err != nil {
		t.Errorf("NewTokenizer returned error: %v", err)
		return
	}
	defer tokenizer.Close()

	encodeOptions := EncodeOptions{
		AddSpecialTokens:    true,
		PadToMaxLength:      false,
		ReturnAttentionMask: true,
		Truncate:            true,
	}

	ids, tokens, _, err := tokenizer.Encode(longString, encodeOptions)
	if err != nil {
		t.Errorf("Encode returned error: %v", err)
	}

	assert.Equal(t, 512, len(ids), "IDs length does not match for long string with truncation")
	assert.Equal(t, 512, len(tokens), "Tokens length does not match for long string with truncation")
}