mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
160 lines
3.9 KiB
Go
160 lines
3.9 KiB
Go
//go:build !no_mlsupport
|
|
|
|
package ml
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
|
|
tokenizers "github.com/daulet/tokenizers"
|
|
)
|
|
|
|
type Tokenizer struct {
|
|
tk *tokenizers.Tokenizer
|
|
modelMaxLength int
|
|
padTokenID int
|
|
tokenizerClass string
|
|
}
|
|
|
|
type tokenizerConfig struct {
|
|
ModelMaxLen int `json:"model_max_length"`
|
|
PadToken string `json:"pad_token"`
|
|
TokenizerClass string `json:"tokenizer_class"`
|
|
AddedTokenDecoder map[string]map[string]interface{} `json:"added_tokens_decoder"`
|
|
}
|
|
|
|
func loadTokenizerConfig(filename string) (*tokenizerConfig, error) {
|
|
file, err := os.ReadFile(filename)
|
|
if err != nil {
|
|
fmt.Println("Error reading tokenizer config file")
|
|
return nil, err
|
|
}
|
|
config := &tokenizerConfig{}
|
|
if err := json.Unmarshal(file, config); err != nil {
|
|
fmt.Println("Error unmarshalling tokenizer config")
|
|
return nil, err
|
|
}
|
|
return config, nil
|
|
}
|
|
|
|
func findTokenID(tokens map[string]map[string]interface{}, tokenContent string) int {
|
|
for key, value := range tokens {
|
|
if content, ok := value["content"]; ok && content == tokenContent {
|
|
if tokenID, err := strconv.Atoi(key); err == nil {
|
|
return tokenID
|
|
}
|
|
}
|
|
}
|
|
return -1
|
|
}
|
|
|
|
func NewTokenizer(datadir string) (*Tokenizer, error) {
|
|
defaultMaxLen := 512
|
|
defaultPadTokenID := 1
|
|
defaultTokenizerClass := "RobertaTokenizer"
|
|
|
|
// check if tokenizer.json exists
|
|
tokenizerPath := filepath.Join(datadir, "tokenizer.json")
|
|
if _, err := os.Stat(tokenizerPath); os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("tokenizer.json not found in %s", datadir)
|
|
}
|
|
|
|
tk, err := tokenizers.FromFile(tokenizerPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
configFile := filepath.Join(datadir, "tokenizer_config.json")
|
|
config, err := loadTokenizerConfig(configFile)
|
|
if err != nil {
|
|
fmt.Println("Warning: Could not load tokenizer config, using default values.")
|
|
return &Tokenizer{
|
|
tk: tk,
|
|
modelMaxLength: defaultMaxLen,
|
|
padTokenID: defaultPadTokenID,
|
|
tokenizerClass: defaultTokenizerClass,
|
|
}, nil
|
|
}
|
|
|
|
// Use default values if any required config is missing
|
|
// modelMaxLen := 256
|
|
modelMaxLen := config.ModelMaxLen
|
|
if modelMaxLen == 0 {
|
|
modelMaxLen = defaultMaxLen
|
|
}
|
|
|
|
padTokenID := findTokenID(config.AddedTokenDecoder, config.PadToken)
|
|
if padTokenID == -1 {
|
|
padTokenID = defaultPadTokenID
|
|
}
|
|
|
|
tokenizerClass := config.TokenizerClass
|
|
if tokenizerClass == "" {
|
|
tokenizerClass = defaultTokenizerClass
|
|
}
|
|
|
|
return &Tokenizer{
|
|
tk: tk,
|
|
modelMaxLength: modelMaxLen,
|
|
padTokenID: padTokenID,
|
|
tokenizerClass: tokenizerClass,
|
|
}, nil
|
|
}
|
|
|
|
type EncodeOptions struct {
|
|
AddSpecialTokens bool
|
|
PadToMaxLength bool
|
|
ReturnAttentionMask bool
|
|
Truncate bool
|
|
}
|
|
|
|
func (t *Tokenizer) Encode(text string, options EncodeOptions) ([]int64, []string, []int64, error) {
|
|
if t.tk == nil {
|
|
return nil, nil, nil, fmt.Errorf("tokenizer is not initialized")
|
|
}
|
|
|
|
ids, tokens := t.tk.Encode(text, options.AddSpecialTokens)
|
|
|
|
// Truncate to max length (right truncation)
|
|
if len(ids) > t.modelMaxLength && options.Truncate {
|
|
ids = ids[:t.modelMaxLength]
|
|
tokens = tokens[:t.modelMaxLength]
|
|
}
|
|
|
|
//[]uint32 to []int64
|
|
int64Ids := make([]int64, len(ids))
|
|
for i, id := range ids {
|
|
int64Ids[i] = int64(id)
|
|
}
|
|
|
|
// Padding to max length
|
|
if options.PadToMaxLength && len(int64Ids) < t.modelMaxLength {
|
|
paddingLength := t.modelMaxLength - len(int64Ids)
|
|
for i := 0; i < paddingLength; i++ {
|
|
int64Ids = append(int64Ids, int64(t.padTokenID))
|
|
tokens = append(tokens, "<pad>")
|
|
}
|
|
}
|
|
|
|
// Creating attention mask
|
|
var attentionMask []int64
|
|
if options.ReturnAttentionMask {
|
|
attentionMask = make([]int64, len(int64Ids))
|
|
for i := range attentionMask {
|
|
if int64Ids[i] != int64(t.padTokenID) {
|
|
attentionMask[i] = 1
|
|
} else {
|
|
attentionMask[i] = 0
|
|
}
|
|
}
|
|
}
|
|
|
|
return int64Ids, tokens, attentionMask, nil
|
|
}
|
|
|
|
func (t *Tokenizer) Close() {
|
|
t.tk.Close()
|
|
}
|