commit f6bd0a56e1 by David Lequin, 2025-05-09 15:22:46 +00:00 (committed by GitHub)
19 changed files with 4811 additions and 15 deletions


@@ -1,6 +1,23 @@
FROM rust:1.70.0-bullseye AS rust_build
WORKDIR /
RUN apt-get update && \
apt-get install -y -q \
build-essential \
curl \
git \
make
RUN git clone https://github.com/daulet/tokenizers.git /tokenizer && \
cd /tokenizer && \
cargo build --release && \
cp target/release/libtokenizers.a /tokenizer/libtokenizers.a
FROM docker.io/golang:1.24-bookworm AS build
ARG BUILD_VERSION
ARG ONNXRUNTIME_VERSION=1.18.1
WORKDIR /go/src/crowdsec
@@ -24,7 +41,36 @@ RUN apt-get update && \
COPY . .
RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
COPY --from=rust_build /tokenizer/libtokenizers.a /usr/local/lib/
# INSTALL ONNXRUNTIME
RUN cd /tmp && \
wget -O onnxruntime.tgz https://github.com/microsoft/onnxruntime/releases/download/v${ONNXRUNTIME_VERSION}/onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION}.tgz && \
tar -C /tmp -xvf onnxruntime.tgz && \
mv onnxruntime-linux-aarch64-${ONNXRUNTIME_VERSION} onnxruntime && \
rm -rf onnxruntime.tgz && \
cp -R onnxruntime/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib && \
cp onnxruntime/include/*.h /usr/local/include && \
rm -rf onnxruntime
RUN ln -s /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/local/lib/libonnxruntime.so
RUN ls -la /usr/local/include
RUN ls -la /usr/local/lib
RUN ldconfig
# Test if linking works with a simple program
RUN echo "#include <onnxruntime_c_api.h>" > test.c && \
echo "int main() { return 0; }" >> test.c && \
gcc test.c -L/usr/local/lib -lonnxruntime -o test_executable && ./test_executable
RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=0 \
CGO_CFLAGS="-D_LARGEFILE64_SOURCE -I/usr/local/include" \
CGO_CPPFLAGS="-I/usr/local/include" \
CGO_LDFLAGS="-L/usr/local/lib -lstdc++ -lonnxruntime /usr/local/lib/libtokenizers.a -ldl -lm" \
LIBRARY_PATH="/usr/local/lib" \
LD_LIBRARY_PATH="/usr/local/lib" && \
cd crowdsec-v* && \
./wizard.sh --docker-mode && \
cd - >/dev/null && \
@@ -37,6 +83,8 @@ RUN make clean release DOCKER_BUILD=1 BUILD_STATIC=1 && \
FROM docker.io/debian:bookworm-slim AS slim
ARG ONNXRUNTIME_VERSION=1.18.1
ENV DEBIAN_FRONTEND=noninteractive
ENV DEBCONF_NOWARNINGS="yes"
@@ -58,6 +106,15 @@ COPY --from=build /go/bin/yq /usr/local/bin/crowdsec /usr/local/bin/cscli /usr/l
COPY --from=build /etc/crowdsec /staging/etc/crowdsec
COPY --from=build /go/src/crowdsec/docker/docker_start.sh /
COPY --from=build /go/src/crowdsec/docker/config.yaml /staging/etc/crowdsec/config.yaml
# Note: copying these libraries since we can't build them statically yet
COPY --from=build /usr/local/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION}
COPY --from=build /usr/local/lib/libtokenizers.a /usr/lib/libtokenizers.a
RUN ln -s /usr/lib/libonnxruntime.so.${ONNXRUNTIME_VERSION} /usr/lib/libonnxruntime.so
COPY --from=build /usr/local/lib/libre2.* /usr/lib/
RUN ls -la /usr/lib
RUN yq -n '.url="http://0.0.0.0:8080"' | install -m 0600 /dev/stdin /staging/etc/crowdsec/local_api_credentials.yaml && \
yq eval -i ".plugin_config.group = \"nogroup\"" /staging/etc/crowdsec/config.yaml

Makefile

@@ -60,11 +60,9 @@ bool = $(if $(filter $(call lc, $1),1 yes true),1,0)
#--------------------------------------
#
# Define MAKE_FLAGS and LD_OPTS for the sub-makefiles in cmd/
# Define LD_OPTS for the sub-makefiles in cmd/
#
MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)"
LD_OPTS_VARS= \
-X 'github.com/crowdsecurity/go-cs-lib/version.Version=$(BUILD_VERSION)' \
-X 'github.com/crowdsecurity/go-cs-lib/version.BuildDate=$(BUILD_TIMESTAMP)' \
@@ -116,6 +114,14 @@ ifneq (,$(RE2_CHECK))
endif
endif
#--------------------------------------
#
# List of required build-time dependencies
DEPS_DIR := $(CURDIR)/build/deps
DEPS_FILES :=
#--------------------------------------
#
# Handle optional components and build profiles, to save space on the final binaries.
@@ -142,7 +148,8 @@ COMPONENTS := \
datasource_s3 \
datasource_syslog \
datasource_wineventlog \
cscli_setup
cscli_setup \
mlsupport
comma := ,
space := $(empty) $(empty)
@@ -152,6 +159,9 @@ space := $(empty) $(empty)
# keep only datasource_file
EXCLUDE_MINIMAL := $(subst $(space),$(comma),$(filter-out datasource_file,$(COMPONENTS)))
# mlsupport requires pre-built static libraries and adds roughly 20MB to the binaries
EXCLUDE_DEFAULT := mlsupport
# example
# EXCLUDE_MEDIUM := datasource_kafka,datasource_kinesis,datasource_s3
@@ -160,8 +170,10 @@ BUILD_PROFILE ?= default
# Set the EXCLUDE_LIST based on the chosen profile, unless EXCLUDE is already set
ifeq ($(BUILD_PROFILE),minimal)
EXCLUDE ?= $(EXCLUDE_MINIMAL)
else ifneq ($(BUILD_PROFILE),default)
$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default)
else ifeq ($(BUILD_PROFILE),default)
EXCLUDE ?= $(EXCLUDE_DEFAULT)
else ifneq ($(BUILD_PROFILE),full)
$(error Invalid build profile specified: $(BUILD_PROFILE). Valid profiles are: minimal, default, full)
endif
# Create list of excluded components from the EXCLUDE variable
@@ -179,6 +191,24 @@ ifneq ($(COMPONENT_TAGS),)
GO_TAGS := $(GO_TAGS),$(subst $(space),$(comma),$(COMPONENT_TAGS))
endif
ifeq ($(filter mlsupport,$(EXCLUDE_LIST)),)
$(info mlsupport is included)
# Set additional variables when mlsupport is included
ifneq ($(call bool,$(BUILD_RE2_WASM)),1)
$(error for now, the flag BUILD_RE2_WASM is required for mlsupport)
endif
CGO_CPPFLAGS := -I$(DEPS_DIR)/src/onnxruntime/include/onnxruntime/core/session
CGO_LDFLAGS := -L$(DEPS_DIR)/lib -lstdc++ -lonnxruntime -ldl -lm
LIBRARY_PATH := $(DEPS_DIR)/lib
DEPS_FILES += $(DEPS_DIR)/lib/libtokenizers.a
DEPS_FILES += $(DEPS_DIR)/lib/libonnxruntime.a
DEPS_FILES += $(DEPS_DIR)/src/onnxruntime
else
CGO_CPPFLAGS :=
CGO_LDFLAGS :=
LIBRARY_PATH :=
endif
#--------------------------------------
ifeq ($(call bool,$(BUILD_STATIC)),1)
@@ -208,7 +238,7 @@ endif
#--------------------------------------
.PHONY: build
build: build-info crowdsec cscli plugins ## Build crowdsec, cscli and plugins
build: build-info download-deps crowdsec cscli plugins ## Build crowdsec, cscli and plugins
.PHONY: build-info
build-info: ## Print build information
@@ -235,6 +265,29 @@ endif
.PHONY: all
all: clean test build ## Clean, test and build (requires localstack)
.PHONY: download-deps
download-deps: $(DEPS_FILES)
$(DEPS_DIR)/lib/libtokenizers.a:
curl --fail -L --output $@ --create-dirs \
https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libtokenizers.a
$(DEPS_DIR)/lib/libonnxruntime.a:
curl --fail -L --output $@ --create-dirs \
https://github.com/crowdsecurity/packaging-onnx/releases/download/test/libonnxruntime.a
$(DEPS_DIR)/src/onnxruntime:
git clone --depth 1 https://github.com/microsoft/onnxruntime $(DEPS_DIR)/src/onnxruntime -b v1.19.2
# Full list of flags that are passed down to the sub-makefiles in cmd/
MAKE_FLAGS = --no-print-directory GOARCH=$(GOARCH) GOOS=$(GOOS) RM="$(RM)" WIN_IGNORE_ERR="$(WIN_IGNORE_ERR)" CP="$(CP)" CPR="$(CPR)" MKDIR="$(MKDIR)" CGO_CPPFLAGS="$(CGO_CPPFLAGS)" CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH="$(LIBRARY_PATH)"
.PHONY: clean-deps
clean-deps:
@$(RM) -r $(DEPS_DIR)
.PHONY: plugins
plugins: ## Build notification plugins
@$(foreach plugin,$(PLUGINS), \
@@ -260,7 +313,7 @@ clean-rpm:
@$(RM) -r rpm/SRPMS
.PHONY: clean
clean: clean-debian clean-rpm bats-clean ## Remove build artifacts
clean: clean-debian clean-rpm clean-deps bats-clean ## Remove build artifacts
@$(MAKE) -C $(CROWDSEC_FOLDER) clean $(MAKE_FLAGS)
@$(MAKE) -C $(CSCLI_FOLDER) clean $(MAKE_FLAGS)
@$(RM) $(CROWDSEC_BIN) $(WIN_IGNORE_ERR)

go.mod (2 changes)

@@ -25,8 +25,10 @@ require (
github.com/creack/pty v1.1.21 // indirect
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26
github.com/crowdsecurity/go-cs-lib v0.0.19
github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4
github.com/crowdsecurity/grokky v0.2.2
github.com/crowdsecurity/machineid v1.0.2
github.com/daulet/tokenizers v0.9.0
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc
github.com/dghubble/sling v1.4.2
github.com/distribution/reference v0.6.0 // indirect

go.sum (4 changes)

@@ -113,10 +113,14 @@ github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:r97WNVC30Uen
github.com/crowdsecurity/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:zpv7r+7KXwgVUZnUNjyP22zc/D7LKjyoY02weH2RBbk=
github.com/crowdsecurity/go-cs-lib v0.0.19 h1:wA4O8hGrEntTGn7eZTJqnQ3mrAje5JvQAj8DNbe5IZg=
github.com/crowdsecurity/go-cs-lib v0.0.19/go.mod h1:hz2FOHFXc0vWzH78uxo2VebtPQ9Snkbdzy3TMA20tVQ=
github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4 h1:CwzISIxoKp0dJLrJJIlhvQPuzirpS9QH07guxK5LIeg=
github.com/crowdsecurity/go-onnxruntime v0.0.0-20240801073851-3fd7de0127b4/go.mod h1:YfyL16lx2wA8Z6t/TG1x1/FBngOIpuCuo7nM/FSuP54=
github.com/crowdsecurity/grokky v0.2.2 h1:yALsI9zqpDArYzmSSxfBq2dhYuGUTKMJq8KOEIAsuo4=
github.com/crowdsecurity/grokky v0.2.2/go.mod h1:33usDIYzGDsgX1kHAThCbseso6JuWNJXOzRQDGXHtWM=
github.com/crowdsecurity/machineid v1.0.2 h1:wpkpsUghJF8Khtmn/tg6GxgdhLA1Xflerh5lirI+bdc=
github.com/crowdsecurity/machineid v1.0.2/go.mod h1:XWUSlnS0R0+u/JK5ulidwlbceNT3ZOCKteoVQEn6Luo=
github.com/daulet/tokenizers v0.9.0 h1:PSjFUGeuhqb3C0GKP9hdvtHvJ6L1AZceV+0nYGACtCk=
github.com/daulet/tokenizers v0.9.0/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs=
github.com/davecgh/go-spew v0.0.0-20161028175848-04cdfd42973b/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=


@@ -19,10 +19,11 @@ var Built = map[string]bool{
"datasource_loki": false,
"datasource_s3": false,
"datasource_syslog": false,
"datasource_wineventlog": false,
"datasource_victorialogs": false,
"datasource_http": false,
"datasource_wineventlog": false,
"datasource_http": false,
"cscli_setup": false,
"mlsupport": false,
}
func Register(name string) {


@@ -0,0 +1,57 @@
//go:build !no_mlsupport
package exprhelpers
import (
"errors"
"fmt"
"log"
"github.com/crowdsecurity/crowdsec/pkg/cwversion/component"
"github.com/crowdsecurity/crowdsec/pkg/ml"
)
var robertaInferencePipeline *ml.RobertaClassificationInferencePipeline
//nolint:gochecknoinits
func init() {
component.Register("mlsupport")
}
func InitRobertaInferencePipeline(modelBundlePath string) error {
var err error
fmt.Println("Initializing Roberta Inference Pipeline")
robertaInferencePipeline, err = ml.NewRobertaInferencePipeline(modelBundlePath)
if err != nil {
return err
}
if robertaInferencePipeline == nil {
return errors.New("failed to initialize Roberta inference pipeline")
}
return nil
}
func IsAnomalous(params ...any) (any, error) {
if len(params) < 2 {
return nil, errors.New("IsAnomalous expects two string parameters")
}
verb, ok1 := params[0].(string)
httpPath, ok2 := params[1].(string)
if !ok1 || !ok2 {
return nil, errors.New("parameters must be strings")
}
text := verb + " " + httpPath
log.Println("Text to analyze for anomaly: ", text)
if robertaInferencePipeline == nil {
return nil, errors.New("roberta inference pipeline not properly initialized")
}
result, err := robertaInferencePipeline.PredictLabel(text)
if err != nil {
return nil, err
}
return result == 1, nil
}
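For context, a minimal standalone sketch of exercising the two helpers above directly; the bundle path is hypothetical and must point to a tar containing model.onnx, tokenizer.json and tokenizer_config.json:

package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
)

func main() {
	// Hypothetical bundle path: a tar holding model.onnx, tokenizer.json
	// and tokenizer_config.json, as the pipeline requires.
	if err := exprhelpers.InitRobertaInferencePipeline("/var/lib/crowdsec/data/anomaly_detection_bundle.tar"); err != nil {
		log.Fatalf("pipeline init failed: %v", err)
	}
	verdict, err := exprhelpers.IsAnomalous("GET", "/lib/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php?")
	if err != nil {
		log.Fatalf("inference failed: %v", err)
	}
	fmt.Println("anomalous:", verdict) // bool: model label == 1
}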


@@ -0,0 +1,29 @@
//go:build no_mlsupport
package exprhelpers
import (
"errors"
"fmt"
)
var robertaInferencePipeline *RobertaInferencePipelineStub
type RobertaInferencePipelineStub struct{}
func InitRobertaInferencePipeline(modelBundlePath string) error {
fmt.Println("Stub: InitRobertaInferencePipeline called with no ML support")
return nil
}
func IsAnomalous(params ...any) (any, error) {
if len(params) < 2 {
return nil, errors.New("IsAnomalous expects two string parameters")
}
_, ok1 := params[0].(string)
_, ok2 := params[1].(string)
if !ok1 || !ok2 {
return nil, errors.New("parameters must be strings")
}
fmt.Println("Stub: IsAnomalous called with no ML support")
return false, nil
}


@@ -0,0 +1,54 @@
//go:build !no_mlsupport

package exprhelpers
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAnomalyDetection(t *testing.T) {
tests := []struct {
name string
params []any
expectResult any
err error
}{
{
name: "Empty verb and path",
params: []any{"", "hello"},
expectResult: false,
err: nil,
},
{
name: "Empty verb",
params: []any{"", "/somepath"},
expectResult: false,
err: nil,
},
{
name: "Empty path",
params: []any{"GET", ""},
expectResult: true,
err: nil,
},
{
name: "Valid verb and path",
params: []any{"GET", "/somepath"},
expectResult: false,
err: nil,
},
}
tarFilePath := "tests/anomaly_detection_bundle_test.tar"
if err := InitRobertaInferencePipeline(tarFilePath); err != nil {
t.Fatalf("failed to initialize RobertaInferencePipeline: %v", err)
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, _ := IsAnomalous(tt.params...)
assert.Equal(t, tt.expectResult, result)
})
}
}


@@ -509,6 +509,13 @@ var exprFuncs = []exprCustomFunc{
new(func(string) *net.IPNet),
},
},
{
name: "IsAnomalous",
function: IsAnomalous,
signature: []interface{}{
new(func(string, string) (bool, error)),
},
},
{
name: "JA4H",
function: JA4H,
@@ -518,6 +525,6 @@ var exprFuncs = []exprCustomFunc{
},
}
//go 1.20 "CutPrefix": strings.CutPrefix,
//go 1.20 "CutSuffix": strings.CutSuffix,
//"Cut": strings.Cut, -> returns more than 2 values, not supported by expr


@@ -38,9 +38,10 @@ import (
)
var (
dataFile map[string][]string
dataFileRegex map[string][]*regexp.Regexp
dataFileRe2 map[string][]*re2.Regexp
dataFile map[string][]string
dataFileRegex map[string][]*regexp.Regexp
dataFileRe2 map[string][]*re2.Regexp
mlRobertaModelFiles map[string]struct{}
)
// This is used to (optionally) cache regexp results for RegexpInFile operations
@@ -125,6 +126,7 @@ func Init(databaseClient *database.Client) error {
dataFile = make(map[string][]string)
dataFileRegex = make(map[string][]*regexp.Regexp)
dataFileRe2 = make(map[string][]*re2.Regexp)
mlRobertaModelFiles = make(map[string]struct{})
dbClient = databaseClient
XMLCacheInit()
@@ -205,6 +207,16 @@ func FileInit(fileFolder string, filename string, fileType string) error {
filepath := filepath.Join(fileFolder, filename)
if fileType == "ml_roberta_model" {
err := InitRobertaInferencePipeline(filepath)
if err != nil {
log.Errorf("unable to init roberta model : %s", err)
return err
}
mlRobertaModelFiles[filename] = struct{}{}
return nil
}
file, err := os.Open(filepath)
if err != nil {
return err
@@ -300,6 +312,8 @@ func existsInFileMaps(filename string, ftype string) (bool, error) {
}
case "string":
_, ok = dataFile[filename]
case "ml_roberta_model":
_, ok = dataFile[filename]
default:
err = fmt.Errorf("unknown data type '%s' for : '%s'", ftype, filename)
}
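As a usage note, a small hypothetical sketch of loading a model bundle through this data-file path instead of calling the pipeline directly; the folder and file names are made up, and exprhelpers.Init must have run first so the mlRobertaModelFiles map exists:

package main

import (
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/exprhelpers"
)

func main() {
	// Hypothetical data folder and bundle name. The "ml_roberta_model" file
	// type routes the file to InitRobertaInferencePipeline instead of the
	// regular data-file maps.
	if err := exprhelpers.FileInit("/var/lib/crowdsec/data", "anomaly_detection_bundle.tar", "ml_roberta_model"); err != nil {
		log.Fatalf("failed to load model bundle: %v", err)
	}
}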

pkg/ml/onnx.go (new file, 101 lines)

@@ -0,0 +1,101 @@
//go:build !no_mlsupport
package ml
import (
"fmt"
onnxruntime "github.com/crowdsecurity/go-onnxruntime"
)
type OrtSession struct {
ORTSession *onnxruntime.ORTSession
ORTEnv *onnxruntime.ORTEnv
ORTSessionOptions *onnxruntime.ORTSessionOptions
}
func NewOrtSession(modelPath string) (*OrtSession, error) {
ortEnv := onnxruntime.NewORTEnv(onnxruntime.ORT_LOGGING_LEVEL_ERROR, "development")
if ortEnv == nil {
return nil, fmt.Errorf("failed to create ORT environment")
}
ortSessionOptions := onnxruntime.NewORTSessionOptions()
if ortSessionOptions == nil {
ortEnv.Close()
return nil, fmt.Errorf("failed to create ORT session options")
}
fmt.Println("Creating ORT session from model path:", modelPath)
session, err := onnxruntime.NewORTSession(ortEnv, modelPath, ortSessionOptions)
if err != nil {
fmt.Println("Error creating ORT session")
ortEnv.Close()
ortSessionOptions.Close()
return nil, err
}
fmt.Println("Done creating ORT session")
return &OrtSession{
ORTSession: session,
ORTEnv: ortEnv,
ORTSessionOptions: ortSessionOptions,
}, nil
}
func (ort *OrtSession) Predict(inputs []onnxruntime.TensorValue) ([]onnxruntime.TensorValue, error) {
res, err := ort.ORTSession.Predict(inputs)
if err != nil {
return nil, err
}
return res, nil
}
func (ort *OrtSession) PredictLabel(inputIds []onnxruntime.TensorValue) (int, error) {
res, err := ort.Predict(inputIds)
if err != nil {
return 0, err
}
label, err := PredictionToLabel(res)
if err != nil {
return 0, err
}
return label, nil
}
func GetTensorValue(input []int64, shape []int64) onnxruntime.TensorValue {
return onnxruntime.TensorValue{
Shape: shape,
Value: input,
}
}
func PredictionToLabel(res []onnxruntime.TensorValue) (int, error) {
if len(res) != 1 {
return 0, fmt.Errorf("expected one output tensor, got %d", len(res))
}
probs, ok := res[0].Value.([]float32)
if !ok {
return 0, fmt.Errorf("expected a []float32 output tensor, got %T", res[0].Value)
}
if len(probs) == 0 {
return 0, fmt.Errorf("empty output tensor")
}
// Argmax over the output tensor (shape [1, num_labels])
maxIndex := 0
maxProb := probs[0]
for i, value := range probs {
if value > maxProb {
maxProb = value
maxIndex = i
}
}
return maxIndex, nil
}
func (ort *OrtSession) Close() {
ort.ORTSession.Close()
ort.ORTEnv.Close()
ort.ORTSessionOptions.Close()
}
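To make the session contract concrete, a hedged sketch of driving OrtSession directly with the two-tensor layout the Roberta pipeline uses (token ids plus attention mask, shape [1, seq_len]); the model path and id values are made up:

package main

import (
	"fmt"
	"log"

	onnxruntime "github.com/crowdsecurity/go-onnxruntime"

	"github.com/crowdsecurity/crowdsec/pkg/ml"
)

func main() {
	sess, err := ml.NewOrtSession("/tmp/model.onnx") // hypothetical model path
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// Two input tensors of shape [1, seq_len], mirroring what the Roberta
	// pipeline feeds in: token ids and the matching attention mask.
	shape := []int64{1, 4}
	inputs := []onnxruntime.TensorValue{
		ml.GetTensorValue([]int64{0, 35, 48, 2}, shape),
		ml.GetTensorValue([]int64{1, 1, 1, 1}, shape),
	}
	label, err := sess.PredictLabel(inputs) // argmax over the output logits
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("predicted label:", label)
}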

pkg/ml/robertapipeline.go (new file, 153 lines)

@@ -0,0 +1,153 @@
//go:build !no_mlsupport
package ml
import (
"archive/tar"
"fmt"
"io"
"os"
"path/filepath"
"strings"
onnxruntime "github.com/crowdsecurity/go-onnxruntime"
)
type RobertaClassificationInferencePipeline struct {
inputShape []int64
tokenizer *Tokenizer
ortSession *OrtSession
}
var bundleFileList = []string{
"model.onnx",
"tokenizer.json",
"tokenizer_config.json",
}
func NewRobertaInferencePipeline(bundleFilePath string) (*RobertaClassificationInferencePipeline, error) {
tempDir, err := os.MkdirTemp("", "crowdsec_roberta_model_assets")
if err != nil {
return nil, fmt.Errorf("could not create temp directory: %v", err)
}
if err := extractTarFile(bundleFilePath, tempDir); err != nil {
os.RemoveAll(tempDir)
return nil, fmt.Errorf("failed to extract tar file: %v", err)
}
outputDir := filepath.Join(tempDir, strings.TrimSuffix(filepath.Base(bundleFilePath), ".tar"))
for _, file := range bundleFileList {
if _, err := os.Stat(filepath.Join(outputDir, file)); os.IsNotExist(err) {
os.RemoveAll(tempDir)
return nil, fmt.Errorf("missing required file: %s, in %s", file, outputDir)
}
}
ortSession, err := NewOrtSession(filepath.Join(outputDir, "model.onnx"))
if err != nil {
os.RemoveAll(tempDir)
return nil, err
}
tokenizer, err := NewTokenizer(outputDir)
if err != nil {
ortSession.Close()
os.RemoveAll(tempDir)
return nil, err
}
inputShape := []int64{1, int64(tokenizer.modelMaxLength)}
if err := os.RemoveAll(tempDir); err != nil {
ortSession.Close()
tokenizer.Close()
return nil, fmt.Errorf("could not remove temp directory: %v", err)
}
return &RobertaClassificationInferencePipeline{
inputShape: inputShape,
tokenizer: tokenizer,
ortSession: ortSession,
}, nil
}
func (r *RobertaClassificationInferencePipeline) Close() {
r.tokenizer.Close()
r.ortSession.Close()
}
func (pipeline *RobertaClassificationInferencePipeline) PredictLabel(text string) (int, error) {
options := EncodeOptions{
AddSpecialTokens: true,
PadToMaxLength: true, // TODO: the ONNX input format leads to a segfault without this
ReturnAttentionMask: true,
Truncate: true,
}
ids, _, attentionMask, err := pipeline.tokenizer.Encode(text, options)
if err != nil {
return 0, fmt.Errorf("error encoding text: %w", err)
}
label, err := pipeline.ortSession.PredictLabel([]onnxruntime.TensorValue{
GetTensorValue(ids, pipeline.inputShape),
GetTensorValue(attentionMask, pipeline.inputShape),
})
if err != nil {
return 0, err
}
return label, nil
}
func extractTarFile(tarFilePath, outputDir string) error {
file, err := os.Open(tarFilePath)
if err != nil {
return fmt.Errorf("could not open tar file: %v", err)
}
defer file.Close()
tarReader := tar.NewReader(file)
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("error reading tar file: %v", err)
}
targetPath := filepath.Join(outputDir, header.Name)
// Guard against path traversal via crafted entry names ("Zip Slip")
if !strings.HasPrefix(targetPath, filepath.Clean(outputDir)+string(os.PathSeparator)) {
return fmt.Errorf("invalid path in tar archive: %s", header.Name)
}
switch header.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(targetPath, 0755); err != nil {
return fmt.Errorf("could not create directory %s: %v", targetPath, err)
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil {
return fmt.Errorf("could not create directory %s: %v", filepath.Dir(targetPath), err)
}
outFile, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("could not create file %s: %v", targetPath, err)
}
if _, err := io.Copy(outFile, tarReader); err != nil {
outFile.Close()
return fmt.Errorf("could not copy data to file %s: %v", targetPath, err)
}
outFile.Close()
default:
fmt.Printf("Unsupported file type in tar: %s\n", header.Name)
}
}
return nil
}


@@ -0,0 +1,112 @@
//go:build !no_mlsupport

package ml
import (
"fmt"
"log"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func BenchmarkPredictLabel(b *testing.B) {
log.Println("Starting benchmark for PredictLabel")
tarFilePath := "tests/anomaly_detection_bundle_test.tar"
pipeline, err := NewRobertaInferencePipeline(tarFilePath)
if err != nil {
b.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
}
defer pipeline.Close()
text := "POST /"
b.ResetTimer()
startTime := time.Now()
for n := 0; n < b.N; n++ {
if n%1000 == 0 {
log.Printf("Running iteration %d", n)
}
_, err := pipeline.PredictLabel(text)
if err != nil {
b.Fatalf("Prediction failed: %v", err)
}
}
var memStart runtime.MemStats
runtime.ReadMemStats(&memStart)
for n := 0; n < b.N; n++ {
_, err := pipeline.PredictLabel(text)
if err != nil {
b.Fatalf("Prediction failed: %v", err)
}
}
b.StopTimer()
var memEnd runtime.MemStats
runtime.ReadMemStats(&memEnd)
totalAlloc := memEnd.TotalAlloc - memStart.TotalAlloc
allocPerOp := totalAlloc / uint64(b.N)
totalTime := time.Since(startTime)
log.Printf("Total benchmark time: %s\n", totalTime)
log.Printf("Average time per prediction: %s\n", totalTime/time.Duration(b.N))
log.Printf("Number of operations: %d\n", b.N)
log.Printf("Operations per second: %.2f ops/s\n", float64(b.N)/totalTime.Seconds())
log.Printf("Memory allocated per operation: %d bytes\n", allocPerOp)
log.Printf("Total memory allocated: %d bytes\n", totalAlloc)
fmt.Printf("Benchmark Results:\n")
fmt.Printf(" Total time: %s\n", totalTime)
fmt.Printf(" Average time per operation: %s\n", totalTime/time.Duration(b.N))
fmt.Printf(" Operations per second: %.2f ops/s\n", float64(b.N)/totalTime.Seconds())
fmt.Printf(" Memory allocated per operation: %d bytes\n", allocPerOp)
fmt.Printf(" Total memory allocated: %d bytes\n", totalAlloc)
}
func TestPredictLabel(t *testing.T) {
tests := []struct {
name string
text string
expectedID int
label int
}{
{
name: "Malicious request",
text: "GET /lib/vendor/phpunit/phpunit/src/Util/PHP/eval-stdin.php?",
expectedID: 0,
label: 1,
},
{
name: "Safe request",
text: "GET /online/_ui/responsive/theme-miglog/img/header+Navigation/icon-delivery.svg",
expectedID: 0,
label: 0,
},
}
tarFilePath := "tests/anomaly_detection_bundle_test.tar"
pipeline, err := NewRobertaInferencePipeline(tarFilePath)
if err != nil {
t.Fatalf("NewRobertaInferencePipeline returned error: %v", err)
}
defer pipeline.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
prediction, err := pipeline.PredictLabel(tt.text)
if err != nil {
t.Errorf("PredictLabel returned error: %v", err)
}
assert.Equal(t, tt.label, prediction, "Predicted label does not match the expected label")
})
}
}

Binary file not shown.

pkg/ml/tests/tokenizer.json (new file, 3830 lines): file diff suppressed because it is too large


@@ -0,0 +1,57 @@
{
"add_prefix_space": false,
"added_tokens_decoder": {
"0": {
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"1": {
"content": "<pad>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"2": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"3": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false,
"special": true
},
"4": {
"content": "<mask>",
"lstrip": true,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"bos_token": "<s>",
"clean_up_tokenization_spaces": true,
"cls_token": "<s>",
"eos_token": "</s>",
"errors": "replace",
"mask_token": "<mask>",
"model_max_length": 512,
"pad_token": "<pad>",
"sep_token": "</s>",
"tokenizer_class": "RobertaTokenizer",
"trim_offsets": true,
"unk_token": "<unk>"
}

pkg/ml/tokenizer.go (new file, 160 lines)

@@ -0,0 +1,160 @@
//go:build !no_mlsupport
package ml
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strconv"
tokenizers "github.com/daulet/tokenizers"
)
type Tokenizer struct {
tk *tokenizers.Tokenizer
modelMaxLength int
padTokenID int
tokenizerClass string
}
type tokenizerConfig struct {
ModelMaxLen int `json:"model_max_length"`
PadToken string `json:"pad_token"`
TokenizerClass string `json:"tokenizer_class"`
AddedTokenDecoder map[string]map[string]interface{} `json:"added_tokens_decoder"`
}
func loadTokenizerConfig(filename string) (*tokenizerConfig, error) {
file, err := os.ReadFile(filename)
if err != nil {
fmt.Println("Error reading tokenizer config file")
return nil, err
}
config := &tokenizerConfig{}
if err := json.Unmarshal(file, config); err != nil {
fmt.Println("Error unmarshalling tokenizer config")
return nil, err
}
return config, nil
}
func findTokenID(tokens map[string]map[string]interface{}, tokenContent string) int {
for key, value := range tokens {
if content, ok := value["content"]; ok && content == tokenContent {
if tokenID, err := strconv.Atoi(key); err == nil {
return tokenID
}
}
}
return -1
}
func NewTokenizer(datadir string) (*Tokenizer, error) {
defaultMaxLen := 512
defaultPadTokenID := 1
defaultTokenizerClass := "RobertaTokenizer"
// check if tokenizer.json exists
tokenizerPath := filepath.Join(datadir, "tokenizer.json")
if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) {
return nil, fmt.Errorf("tokenizer.json not found in %s", datadir)
}
tk, err := tokenizers.FromFile(tokenizerPath)
if err != nil {
return nil, err
}
configFile := filepath.Join(datadir, "tokenizer_config.json")
config, err := loadTokenizerConfig(configFile)
if err != nil {
fmt.Println("Warning: Could not load tokenizer config, using default values.")
return &Tokenizer{
tk: tk,
modelMaxLength: defaultMaxLen,
padTokenID: defaultPadTokenID,
tokenizerClass: defaultTokenizerClass,
}, nil
}
// Use default values if any required config value is missing
modelMaxLen := config.ModelMaxLen
if modelMaxLen == 0 {
modelMaxLen = defaultMaxLen
}
padTokenID := findTokenID(config.AddedTokenDecoder, config.PadToken)
if padTokenID == -1 {
padTokenID = defaultPadTokenID
}
tokenizerClass := config.TokenizerClass
if tokenizerClass == "" {
tokenizerClass = defaultTokenizerClass
}
return &Tokenizer{
tk: tk,
modelMaxLength: modelMaxLen,
padTokenID: padTokenID,
tokenizerClass: tokenizerClass,
}, nil
}
type EncodeOptions struct {
AddSpecialTokens bool
PadToMaxLength bool
ReturnAttentionMask bool
Truncate bool
}
func (t *Tokenizer) Encode(text string, options EncodeOptions) ([]int64, []string, []int64, error) {
if t.tk == nil {
return nil, nil, nil, fmt.Errorf("tokenizer is not initialized")
}
ids, tokens := t.tk.Encode(text, options.AddSpecialTokens)
// Truncate to max length (right truncation)
if len(ids) > t.modelMaxLength && options.Truncate {
ids = ids[:t.modelMaxLength]
tokens = tokens[:t.modelMaxLength]
}
// Convert the []uint32 ids to []int64 for the ONNX input tensors
int64Ids := make([]int64, len(ids))
for i, id := range ids {
int64Ids[i] = int64(id)
}
// Padding to max length
if options.PadToMaxLength && len(int64Ids) < t.modelMaxLength {
paddingLength := t.modelMaxLength - len(int64Ids)
for i := 0; i < paddingLength; i++ {
int64Ids = append(int64Ids, int64(t.padTokenID))
tokens = append(tokens, "<pad>")
}
}
// Creating attention mask
var attentionMask []int64
if options.ReturnAttentionMask {
attentionMask = make([]int64, len(int64Ids))
for i := range attentionMask {
if int64Ids[i] != int64(t.padTokenID) {
attentionMask[i] = 1
} else {
attentionMask[i] = 0
}
}
}
return int64Ids, tokens, attentionMask, nil
}
func (t *Tokenizer) Close() {
t.tk.Close()
}
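A brief hypothetical sketch of the Encode contract under padding, assuming a directory with the tokenizer files; with PadToMaxLength set, ids are right-padded with the pad token id and the attention mask zeroes out the padding:

package main

import (
	"fmt"
	"log"

	"github.com/crowdsecurity/crowdsec/pkg/ml"
)

func main() {
	// Hypothetical directory holding tokenizer.json and tokenizer_config.json.
	tok, err := ml.NewTokenizer("/path/to/tokenizer/dir")
	if err != nil {
		log.Fatal(err)
	}
	defer tok.Close()

	ids, tokens, mask, err := tok.Encode("GET /index.html", ml.EncodeOptions{
		AddSpecialTokens:    true,
		PadToMaxLength:      true, // right-pad ids with the pad token id up to model_max_length
		ReturnAttentionMask: true, // 1 for real tokens, 0 for padding
		Truncate:            true,
	})
	if err != nil {
		log.Fatal(err)
	}
	// All three slices have length model_max_length (512 by default).
	fmt.Println(len(ids), len(tokens), len(mask))
}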

pkg/ml/tokenizer_test.go (new file, 105 lines)

@@ -0,0 +1,105 @@
//go:build !no_mlsupport
package ml
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTokenize(t *testing.T) {
tests := []struct {
name string
inputText string
tokenizerPath string
encodeOptions EncodeOptions
expectedIds []int64
expectedTokens []string
expectedMask []int64
expectTruncation bool
}{
{
name: "Tokenize 'this is some text'",
inputText: "this is some text",
tokenizerPath: "tests/small-champion-model",
encodeOptions: EncodeOptions{
AddSpecialTokens: true,
PadToMaxLength: false,
ReturnAttentionMask: true,
Truncate: true,
},
expectedIds: []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 268, 488, 2},
expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "te", "xt", "</s>"},
expectedMask: []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
},
{
name: "Tokenize 'this is some new texts'",
inputText: "this is some new texts",
tokenizerPath: "tests/small-champion-model",
encodeOptions: EncodeOptions{
AddSpecialTokens: true,
PadToMaxLength: false,
ReturnAttentionMask: true,
Truncate: true,
},
expectedIds: []int64{0, 435, 774, 225, 774, 225, 501, 334, 225, 1959, 225, 268, 488, 87, 2},
expectedTokens: []string{"<s>", "th", "is", "Ġ", "is", "Ġ", "so", "me", "Ġ", "new", "Ġ", "te", "xt", "s", "</s>"},
expectedMask: []int64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
},
}
tokenizer, err := NewTokenizer("tests")
if err != nil {
t.Errorf("NewTokenizer returned error: %v", err)
return
}
defer tokenizer.Close()
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ids, tokens, attentionMask, err := tokenizer.Encode(tt.inputText, tt.encodeOptions)
if err != nil {
t.Errorf("Encode returned error: %v", err)
}
assert.Equal(t, tt.expectedIds, ids, "IDs do not match")
assert.Equal(t, tt.expectedTokens, tokens, "Tokens do not match")
if tt.encodeOptions.ReturnAttentionMask {
assert.Equal(t, tt.expectedMask, attentionMask, "Attention mask does not match")
}
})
}
}
func TestTokenizeLongString(t *testing.T) {
longString := strings.Repeat("a", 1024)
tokenizer, err := NewTokenizer("tests")
if err != nil {
t.Errorf("NewTokenizer returned error: %v", err)
return
}
defer tokenizer.Close()
encodeOptions := EncodeOptions{
AddSpecialTokens: true,
PadToMaxLength: false,
ReturnAttentionMask: true,
Truncate: true,
}
ids, tokens, _, err := tokenizer.Encode(longString, encodeOptions)
if err != nil {
t.Errorf("Encode returned error: %v", err)
}
assert.Equal(t, 512, len(ids), "IDs length does not match for long string with truncation")
assert.Equal(t, 512, len(tokens), "Tokens length does not match for long string with truncation")
}