package model

import (
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"testing"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
)

// loadSentencePieceVocab builds a Vocabulary from the Gemma 2 SentencePiece
// tokenizer.model in testdata and wraps it in a SentencePieceModel.
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
	if err != nil {
		t.Fatal(err)
	}

	var spm sentencepiece.ModelProto
	if err := proto.Unmarshal(bts, &spm); err != nil {
		t.Fatal(err)
	}

	var v Vocabulary

	for _, piece := range spm.GetPieces() {
		v.Values = append(v.Values, piece.GetPiece())
		v.Scores = append(v.Scores, piece.GetScore())
		switch t := piece.GetType(); t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
			sentencepiece.ModelProto_SentencePiece_CONTROL,
			sentencepiece.ModelProto_SentencePiece_UNUSED,
			sentencepiece.ModelProto_SentencePiece_BYTE:
			v.Types = append(v.Types, int32(t))
		default:
			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
			// todo parse the special tokens file
			//   - this will roundtrip correctly but the <start_of_turn> and
			//     <end_of_turn> tokens aren't processed
			v.Types = append(v.Types, tt)
		}
	}

	return NewSentencePieceModel(&v)
}

func TestSentencePieceEncode(t *testing.T) {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
	slog.SetDefault(logger)

	tokenizer := loadSentencePieceVocab(t)

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
			"你好",
			"Hello 你好 world!",
			"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
			"Multilingual: 你好 こんにちは Привет Hola مرحبا",
			"Numbers and symbols: 123456789 +- */",
			"Special tokens: <bos> text <eos>",
			"Code snippets: func main() { fmt.Println(\"Hello World\") }",
			"Long text: " +
				"Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
				"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
				"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Fatal(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q [%#v]", got, want, ids)
			}
		}
	})

	t.Run("special tokens", func(t *testing.T) {
		type candidate struct {
			token string
			ids   []int32
		}

		// In the Gemma 2 tokenizer, <bos> is id 2 and <eos> is id 1.
		cases := []candidate{
			{"<bos>", []int32{2}},
			{"<eos>", []int32{1}},
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want.token, true)
			if err != nil {
				t.Fatal(err)
			}

			if !slices.Equal(ids, want.ids) {
				t.Errorf("got %#v, want %#v", ids, want.ids)
			}
		}
	})
}

func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
	vocab := &Vocabulary{
		Values: []string{
			"normal",
			"<0xEA>",
			"<0x41>",
			"<0xC3>",
			"<0xA3>",
		},
		Types: []int32{
			TOKEN_TYPE_NORMAL,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
			TOKEN_TYPE_BYTE,
		},
		Scores: []float32{0, 0, 0, 0, 0},
	}

	spm := NewSentencePieceModel(vocab)

	tests := []struct {
		name     string
		ids      []int32
		expected string
	}{
		{
			name:     "single byte token",
			ids:      []int32{1},
			expected: "\xea",
		},
		{
			name:     "ASCII byte token",
			ids:      []int32{2},
			expected: "A",
		},
		{
			// 0xC3 0xA3 is the UTF-8 encoding of 'ã' (U+00E3).
			name:     "multiple byte tokens forming UTF-8 character",
			ids:      []int32{3, 4},
			expected: "ã",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result, err := spm.Decode(tt.ids)
			if err != nil {
				t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
			}

			if result != tt.expected {
				t.Errorf("got %q, want %q", result, tt.expected)
			}
		})
	}
}