mirror of
https://github.com/ollama/ollama.git
synced 2025-05-13 03:16:47 +02:00
model: fix issues with spm tokenizer for Gemma 3 (#10081)
This commit is contained in:
parent
b42970063d
commit
b51e0f397c
5 changed files with 175 additions and 113 deletions
|
@ -38,7 +38,6 @@ const (
|
||||||
func New(c ml.Config) (model.Model, error) {
|
func New(c ml.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePieceModel: model.NewSentencePieceModel(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|
|
@ -55,7 +55,6 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||||
func New(c ml.Config) (model.Model, error) {
|
func New(c ml.Config) (model.Model, error) {
|
||||||
m := Model{
|
m := Model{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePieceModel: model.NewSentencePieceModel(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|
|
@ -45,7 +45,6 @@ func newTextModel(c ml.Config) *TextModel {
|
||||||
|
|
||||||
m := TextModel{
|
m := TextModel{
|
||||||
SentencePieceModel: model.NewSentencePieceModel(
|
SentencePieceModel: model.NewSentencePieceModel(
|
||||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
||||||
&model.Vocabulary{
|
&model.Vocabulary{
|
||||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
|
|
@ -1,29 +1,23 @@
|
||||||
package model
|
package model
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"iter"
|
"container/heap"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/dlclark/regexp2"
|
|
||||||
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const spmWhitespaceSep = "▁"
|
const spmWhitespaceSep = "▁"
|
||||||
|
|
||||||
func replaceWhitespaceBySeperator(s string) string {
|
|
||||||
return strings.ReplaceAll(s, " ", spmWhitespaceSep)
|
|
||||||
}
|
|
||||||
|
|
||||||
type SentencePieceModel struct {
|
type SentencePieceModel struct {
|
||||||
maxTokenLen int
|
maxTokenLen int
|
||||||
pre *regexp2.Regexp
|
|
||||||
vocab *Vocabulary
|
vocab *Vocabulary
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ TextProcessor = (*SentencePieceModel)(nil)
|
var _ TextProcessor = (*SentencePieceModel)(nil)
|
||||||
|
|
||||||
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||||
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||||
|
|
||||||
counter := map[int]int{}
|
counter := map[int]int{}
|
||||||
|
@ -44,7 +38,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
||||||
|
|
||||||
return SentencePieceModel{
|
return SentencePieceModel{
|
||||||
maxTokenLen: maxTokenLen,
|
maxTokenLen: maxTokenLen,
|
||||||
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
|
||||||
vocab: vocab,
|
vocab: vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,20 +46,9 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
||||||
return spm.vocab.Is(id, special)
|
return spm.vocab.Is(id, special)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
|
||||||
return func(yield func(string) bool) {
|
|
||||||
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
|
||||||
if !yield(m.String()) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range spm.vocab.SpecialVocabulary() {
|
for _, special := range spm.vocab.SpecialVocabulary() {
|
||||||
// TODO: process special tokens concurrently
|
|
||||||
id := spm.vocab.Encode(special)
|
id := spm.vocab.Encode(special)
|
||||||
for i := 0; i < len(fragments); i++ {
|
for i := 0; i < len(fragments); i++ {
|
||||||
frag := fragments[i]
|
frag := fragments[i]
|
||||||
|
@ -91,7 +73,6 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
slog.Debug("fragments", "frags", fragments)
|
|
||||||
|
|
||||||
var ids []int32
|
var ids []int32
|
||||||
for _, frag := range fragments {
|
for _, frag := range fragments {
|
||||||
|
@ -100,105 +81,96 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for split := range spm.split(frag.value) {
|
text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
|
||||||
split = replaceWhitespaceBySeperator(split)
|
|
||||||
|
|
||||||
var sb strings.Builder
|
if id := spm.vocab.Encode(text); id >= 0 {
|
||||||
sb.Write([]byte(split))
|
ids = append(ids, id)
|
||||||
if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
continue
|
||||||
ids = append(ids, id)
|
}
|
||||||
continue
|
|
||||||
|
q := &queue{}
|
||||||
|
heap.Init(q)
|
||||||
|
|
||||||
|
runes := []rune(text)
|
||||||
|
merges := make([]merge, len(runes))
|
||||||
|
for r := range runes {
|
||||||
|
merges[r] = merge{
|
||||||
|
p: r - 1,
|
||||||
|
n: r + 1,
|
||||||
|
runes: []rune{runes[r]},
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
runes := []rune(sb.String())
|
pairwise := func(a, b int) *candidate {
|
||||||
pq := queue.NewWith(func(a, b any) int {
|
if a < 0 || b >= len(runes) {
|
||||||
priA := a.(*candidate)
|
|
||||||
priB := b.(*candidate)
|
|
||||||
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
})
|
|
||||||
|
|
||||||
merges := make([]merge, len(runes))
|
|
||||||
for r := range runes {
|
|
||||||
merges[r] = merge{
|
|
||||||
p: r - 1,
|
|
||||||
n: r + 1,
|
|
||||||
runes: []rune{runes[r]},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("tokenizer", "merges", merges)
|
|
||||||
|
|
||||||
pairwise := func(a, b int) *candidate {
|
|
||||||
if a < 0 || b >= len(runes) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
left, right := string(merges[a].runes), string(merges[b].runes)
|
|
||||||
if id := spm.vocab.Encode(left + right); id >= 0 {
|
|
||||||
return &candidate{
|
|
||||||
a: a,
|
|
||||||
b: b,
|
|
||||||
score: spm.vocab.Scores[id],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range len(runes) - 1 {
|
left, right := string(merges[a].runes), string(merges[b].runes)
|
||||||
if pair := pairwise(i, i+1); pair != nil {
|
if id := spm.vocab.Encode(left + right); id >= 0 {
|
||||||
pq.Enqueue(pair)
|
return &candidate{
|
||||||
|
a: a,
|
||||||
|
b: b,
|
||||||
|
score: spm.vocab.Scores[id],
|
||||||
|
size: len(left) + len(right),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pqv := pq.Values()
|
return nil
|
||||||
for _, v := range pqv {
|
}
|
||||||
e := v.(*candidate)
|
|
||||||
slog.Debug("candidate", "candidate", e)
|
for i := range len(runes) - 1 {
|
||||||
|
if pair := pairwise(i, i+1); pair != nil {
|
||||||
|
heap.Push(q, pair)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for q.Len() > 0 {
|
||||||
|
pair := heap.Pop(q).(*candidate)
|
||||||
|
left, right := merges[pair.a], merges[pair.b]
|
||||||
|
|
||||||
|
if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for !pq.Empty() {
|
merges[pair.a].runes = append(left.runes, right.runes...)
|
||||||
v, _ := pq.Dequeue()
|
merges[pair.b].runes = nil
|
||||||
pair := v.(*candidate)
|
merges[pair.a].n = right.n
|
||||||
left, right := merges[pair.a], merges[pair.b]
|
if right.n < len(merges) {
|
||||||
|
merges[right.n].p = pair.a
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug("pair", "left", left, "right", right)
|
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
||||||
if len(left.runes) == 0 || len(right.runes) == 0 {
|
heap.Push(q, pair)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
||||||
|
heap.Push(q, pair)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, merge := range merges {
|
||||||
|
if token := string(merge.runes); token != "" {
|
||||||
|
id := spm.vocab.Encode(token)
|
||||||
|
|
||||||
|
if id >= 0 {
|
||||||
|
ids = append(ids, id)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
|
// Fallback to byte tokenization
|
||||||
continue
|
var result []int32
|
||||||
}
|
for _, b := range []byte(token) {
|
||||||
|
byteToken := fmt.Sprintf("<0x%02X>", b)
|
||||||
merges[pair.a].runes = append(left.runes, right.runes...)
|
unknownID := spm.vocab.Encode(byteToken)
|
||||||
merges[pair.b].runes = nil
|
if unknownID >= 0 {
|
||||||
merges[pair.a].n = right.n
|
result = append(result, unknownID)
|
||||||
if right.n < len(merges) {
|
|
||||||
merges[right.n].p = pair.a
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
||||||
pq.Enqueue(pair)
|
|
||||||
}
|
|
||||||
|
|
||||||
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
|
||||||
pq.Enqueue(pair)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
slog.Debug("merges", "merges", merges)
|
|
||||||
|
|
||||||
for _, merge := range merges {
|
|
||||||
if len(merge.runes) > 0 {
|
|
||||||
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
||||||
ids = append(ids, id)
|
|
||||||
} else {
|
} else {
|
||||||
slog.Debug("missing token", "token", string(merge.runes))
|
slog.Debug("unknown byte token", "byte", b, "token", byteToken)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ids = append(ids, result...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -229,6 +201,30 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
type candidate struct {
|
type candidate struct {
|
||||||
a, b int
|
a, b int
|
||||||
score float32
|
score float32
|
||||||
|
size int
|
||||||
|
}
|
||||||
|
|
||||||
|
type queue []*candidate
|
||||||
|
|
||||||
|
func (q queue) Len() int { return len(q) }
|
||||||
|
|
||||||
|
func (q queue) Less(i, j int) bool {
|
||||||
|
return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
|
||||||
|
|
||||||
|
func (q *queue) Push(x interface{}) {
|
||||||
|
item := x.(*candidate)
|
||||||
|
*q = append(*q, item)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *queue) Pop() interface{} {
|
||||||
|
old := *q
|
||||||
|
n := len(old)
|
||||||
|
item := old[n-1]
|
||||||
|
*q = old[0 : n-1]
|
||||||
|
return item
|
||||||
}
|
}
|
||||||
|
|
||||||
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||||
|
@ -236,11 +232,26 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
data := spm.vocab.Decode(id)
|
data := spm.vocab.Decode(id)
|
||||||
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
||||||
if _, err := sb.WriteString(data); err != nil {
|
|
||||||
return "", err
|
// For tokenizers that use byte tokens like "<0xEA>"
|
||||||
|
// convert them to the partial unicode character
|
||||||
|
// so they are buffered correctly by the runner instead
|
||||||
|
// of being sent back to the api as "<0xEA>"
|
||||||
|
if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
|
||||||
|
byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to parse hex byte: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sb.WriteByte(byte(byteVal)); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := sb.WriteString(data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
|
||||||
return sb.String(), nil
|
return sb.String(), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,8 +25,6 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
preTokenizer := `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`
|
|
||||||
|
|
||||||
var v Vocabulary
|
var v Vocabulary
|
||||||
|
|
||||||
for _, piece := range spm.GetPieces() {
|
for _, piece := range spm.GetPieces() {
|
||||||
|
@ -47,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewSentencePieceModel(preTokenizer, &v)
|
return NewSentencePieceModel(&v)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSentencePieceEncode(t *testing.T) {
|
func TestSentencePieceEncode(t *testing.T) {
|
||||||
|
@ -116,3 +114,59 @@ func TestSentencePieceEncode(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||||||
|
vocab := &Vocabulary{
|
||||||
|
Values: []string{
|
||||||
|
"normal",
|
||||||
|
"<0xEA>",
|
||||||
|
"<0x41>",
|
||||||
|
"<0xC3>",
|
||||||
|
"<0xA3>",
|
||||||
|
},
|
||||||
|
Types: []uint32{
|
||||||
|
TOKEN_TYPE_NORMAL,
|
||||||
|
TOKEN_TYPE_BYTE,
|
||||||
|
TOKEN_TYPE_BYTE,
|
||||||
|
TOKEN_TYPE_BYTE,
|
||||||
|
TOKEN_TYPE_BYTE,
|
||||||
|
},
|
||||||
|
Scores: []float32{0, 0, 0, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
spm := NewSentencePieceModel(vocab)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ids []int32
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "single byte token",
|
||||||
|
ids: []int32{1},
|
||||||
|
expected: "\xea",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ASCII byte token",
|
||||||
|
ids: []int32{2},
|
||||||
|
expected: "A",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple byte tokens forming UTF-8 character",
|
||||||
|
ids: []int32{3, 4},
|
||||||
|
expected: "ã",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := spm.Decode(tt.ids)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||||||
|
}
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("got %q, want %q", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue