From d26c18e25c493ca55add9713ed151081c8de7ecf Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Wed, 23 Apr 2025 12:40:05 -0700 Subject: [PATCH] fix token type --- fs/config.go | 2 +- fs/ggml/ggml.go | 21 ++++++++++++++++----- model/models/gemma2/model.go | 2 +- model/models/gemma3/model.go | 2 +- model/models/gemma3/model_text.go | 2 +- model/models/llama/model.go | 2 +- model/models/mistral3/model_text.go | 2 +- model/models/mllama/model.go | 2 +- model/models/mllama/model_text.go | 8 ++++---- model/models/mllama/model_vision.go | 8 ++++---- model/process_text.go | 2 +- model/process_text_spm_test.go | 6 +++--- model/process_text_test.go | 2 +- 13 files changed, 36 insertions(+), 25 deletions(-) diff --git a/fs/config.go b/fs/config.go index bc5bfa550..89a1b134c 100644 --- a/fs/config.go +++ b/fs/config.go @@ -8,6 +8,6 @@ type Config interface { Bool(string, ...bool) bool Strings(string, ...[]string) []string - Uints(string, ...[]uint32) []uint32 + Ints(string, ...[]int32) []int32 Floats(string, ...[]float32) []float32 } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index fcb69decc..51cae50ab 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -108,6 +108,10 @@ func (kv KV) Strings(key string, defaultValue ...[]string) []string { return keyValue(kv, key, &array[string]{}).values } +func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 { + return keyValue(kv, key, &array[int32]{}).values +} + func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 { return keyValue(kv, key, &array[uint32]{}).values } @@ -124,11 +128,18 @@ func (kv KV) OllamaEngineRequired() bool { } type valueTypes interface { - string | uint32 | uint64 | float32 | bool | - *array[string] | *array[uint32] | *array[uint64] | *array[float32] | *array[bool] + uint8 | int8 | uint16 | int16 | + uint32 | int32 | uint64 | int64 | + string | float32 | float64 | bool } -func keyValue[T valueTypes](kv KV, key string, defaultValue ...T) T { +type arrayValueTypes interface { + 
*array[uint8] | *array[int8] | *array[uint16] | *array[int16] | + *array[uint32] | *array[int32] | *array[uint64] | *array[int64] | + *array[string] | *array[float32] | *array[float64] | *array[bool] +} + +func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) T { if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") { key = kv.Architecture() + "." + key } @@ -450,9 +461,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri case "mllama": var visionTokens, tiles uint64 = 1601, 4 - crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers") + crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers") for i := range kv { - if slices.Contains(crossAttentionLayers, uint32(i)) { + if slices.Contains(crossAttentionLayers, int32(i)) { kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) * 4 * // sizeof(float32) visionTokens * diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 752cb5cc2..d418f6827 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -42,7 +42,7 @@ func New(c fs.Config) (model.Model, error) { &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: c.Ints("tokenizer.ggml.token_type"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index cef058e2a..bf396b6a0 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) { &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: c.Ints("tokenizer.ggml.token_type"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: 
c.Bool("tokenizer.ggml.add_bos_token", true), EOS: int32(1), diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 2d7bb20a7..c1e843d8f 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -49,7 +49,7 @@ func newTextModel(c fs.Config) *TextModel { &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), Scores: c.Floats("tokenizer.ggml.scores"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: c.Ints("tokenizer.ggml.token_type"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), }, diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 68980dd76..3e5a54278 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -41,7 +41,7 @@ func New(c fs.Config) (model.Model, error) { c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go index c256cbf17..1bf72acd8 100644 --- a/model/models/mistral3/model_text.go +++ b/model/models/mistral3/model_text.go @@ -152,7 +152,7 @@ func NewTextModel(c fs.Config) (*TextModel, error) { c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: 
c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index a0fc6b693..149876c9c 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -43,7 +43,7 @@ func New(c fs.Config) (model.Model, error) { c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), - Types: c.Uints("tokenizer.ggml.token_type"), + Types: c.Ints("tokenizer.ggml.token_type"), Merges: c.Strings("tokenizer.ggml.merges"), BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 261897c33..490eb696c 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -177,7 +177,7 @@ type TextDecoder struct { func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { for i, layer := range d.Layers { layerType := selfAttentionLayer - if slices.Contains(opts.crossAttentionLayers, uint32(i)) { + if slices.Contains(opts.crossAttentionLayers, int32(i)) { layerType = crossAttentionLayer } @@ -202,7 +202,7 @@ type TextModelOptions struct { eps, ropeBase, ropeScale float32 ropeDim uint32 - crossAttentionLayers []uint32 + crossAttentionLayers []int32 } type TextModel struct { @@ -225,7 +225,7 @@ func newTextModel(c fs.Config) *TextModel { var decoderLayers []TextDecoderLayer for i := range c.Uint("block_count") { var textDecoderLayer TextDecoderLayer - if slices.Contains(c.Uints("attention.cross_attention_layers"), i) { + 
if slices.Contains(c.Ints("attention.cross_attention_layers"), int32(i)) { textDecoderLayer = &TextCrossAttentionDecoderLayer{} } else { textDecoderLayer = &TextSelfAttentionDecoderLayer{} @@ -244,7 +244,7 @@ func newTextModel(c fs.Config) *TextModel { ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), ropeDim: c.Uint("rope.dimension_count"), - crossAttentionLayers: c.Uints("attention.cross_attention_layers"), + crossAttentionLayers: c.Ints("attention.cross_attention_layers"), }, } } diff --git a/model/models/mllama/model_vision.go b/model/models/mllama/model_vision.go index 8b10bde88..bd3d150a3 100644 --- a/model/models/mllama/model_vision.go +++ b/model/models/mllama/model_vision.go @@ -96,10 +96,10 @@ type VisionEncoder struct { Layers []VisionEncoderLayer } -func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []uint32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) { +func (e *VisionEncoder) Forward(ctx ml.Context, hiddenState ml.Tensor, intermediateLayersIndices []int32, opts *VisionModelOptions) (ml.Tensor, []ml.Tensor) { var intermediateHiddenStates []ml.Tensor for i, layer := range e.Layers { - if slices.Contains(intermediateLayersIndices, uint32(i)) { + if slices.Contains(intermediateLayersIndices, int32(i)) { intermediateHiddenStates = append(intermediateHiddenStates, hiddenState.Reshape(ctx, append([]int{1}, hiddenState.Shape()...)...)) } @@ -154,7 +154,7 @@ type VisionModelOptions struct { imageSize, patchSize int eps float32 - intermediateLayersIndices []uint32 + intermediateLayersIndices []int32 } type VisionModel struct { @@ -229,7 +229,7 @@ func newVisionModel(c fs.Config) *VisionModel { eps: c.Float("vision.attention.layer_norm_epsilon"), - intermediateLayersIndices: c.Uints("vision.intermediate_layers_indices"), + intermediateLayersIndices: c.Ints("vision.intermediate_layers_indices"), }, } } diff --git a/model/process_text.go b/model/process_text.go index 
ce0b2d98a..90b220a2e 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -37,7 +37,7 @@ type TextProcessor interface { type Vocabulary struct { Values []string - Types []uint32 + Types []int32 Scores []float32 Merges []string diff --git a/model/process_text_spm_test.go b/model/process_text_spm_test.go index 4813333ee..50ac26787 100644 --- a/model/process_text_spm_test.go +++ b/model/process_text_spm_test.go @@ -35,9 +35,9 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel { sentencepiece.ModelProto_SentencePiece_CONTROL, sentencepiece.ModelProto_SentencePiece_UNUSED, sentencepiece.ModelProto_SentencePiece_BYTE: - v.Types = append(v.Types, uint32(t)) + v.Types = append(v.Types, int32(t)) default: - tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL) + tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL) // todo parse the special tokens file // - this will roundtrip correctly but the <start_of_turn> and // <end_of_turn> tokens aren't processed @@ -124,7 +124,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) { "<0xC3>", "<0xA3>", }, - Types: []uint32{ + Types: []int32{ TOKEN_TYPE_NORMAL, TOKEN_TYPE_BYTE, TOKEN_TYPE_BYTE, diff --git a/model/process_text_test.go b/model/process_text_test.go index f48303212..7e310b56e 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -28,7 +28,7 @@ func llama(t testing.TB) BytePairEncoding { t.Fatal(err) } - types := make([]uint32, len(vocab)) + types := make([]int32, len(vocab)) tokens := make([]string, len(vocab)) for token, id := range vocab { tokens[id] = token