mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
feat: add new Ollama engine using ggml through cgo. This change introduces a new way to run pretrained models. It introduces 3 high-level interfaces and a bunch of smaller helper interfaces to facilitate this. - `model.Model` defines the interface for a model architecture. Models such as `llama` and `mllama`, which are provided as examples, can implement the model's forward propagation in the `Forward` method. This method will be called to generate completions. This interface can be found in `model/model.go`. - `ml.Backend` defines the interface for a backend tensor library, in this case `ggml`. Among other things, a Backend is responsible for loading a pretrained model into hardware (GPU, CPU, etc.) and providing an interface for Models to access loaded tensors. This interface can be found in `ml/backend.go`. - `ml.Tensor` defines the interface for a tensor and tensor operations. This is the first implementation of the new engine. Follow-up PRs will implement more features: - non-greedy sampling (#8410) - integration with Ollama and KV caching (#8301) - more model support (#9080) with more coming soon. Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
313 lines
6.4 KiB
Go
313 lines
6.4 KiB
Go
package model
|
|
|
|
import (
|
|
"cmp"
|
|
"iter"
|
|
"log/slog"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
|
|
)
|
|
|
|
// Special names a category of special token that a token ID can be tested
// against (see TextProcessor.Is).
type Special int32

// Special token categories.
const (
	// SpecialBOS selects the token ID stored in Vocabulary.BOS
	// (presumably the beginning-of-sequence token).
	SpecialBOS Special = 0
	// SpecialEOS selects the token ID stored in Vocabulary.EOS
	// (presumably the end-of-sequence token).
	SpecialEOS Special = 1
)
|
|
|
|
type TextProcessor interface {
|
|
Encode(string) ([]int32, error)
|
|
Decode([]int32) (string, error)
|
|
Is(uint32, Special) bool
|
|
}
|
|
|
|
// Vocabulary holds a tokenizer's token table together with lazily built
// lookup indexes. Because it embeds sync.Once values it must be used through
// a pointer and not copied after first use.
type Vocabulary struct {
	// Values is the text of every token, indexed by token ID.
	Values []string
	// Types is indexed in parallel with Values; a value of 3 marks a token
	// as special (see SpecialVocabulary).
	Types []uint32
	// Scores is not read anywhere in this file.
	// NOTE(review): presumably one score per token, parallel to Values — confirm.
	Scores []uint32
	// Merges lists merge rules as "left right" strings; a rule's index is
	// its rank (see Merge).
	Merges []string

	// BOS is the token ID matched by SpecialBOS.
	BOS uint32
	// EOS is the token ID matched by SpecialEOS.
	EOS uint32

	// special caches the special token strings; guarded by specialOnce.
	specialOnce sync.Once
	special     []string

	// values maps token text to token ID; guarded by valuesOnce.
	valuesOnce sync.Once
	values     map[string]int32

	// merge maps a "left right" rule to its rank; guarded by mergeOnce.
	mergeOnce sync.Once
	merge     map[string]int32
}
|
|
|
|
func (v *Vocabulary) Is(id uint32, special Special) bool {
|
|
switch special {
|
|
case SpecialBOS:
|
|
return id == v.BOS
|
|
case SpecialEOS:
|
|
return id == v.EOS
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (v *Vocabulary) Encode(s string) int32 {
|
|
v.valuesOnce.Do(func() {
|
|
v.values = make(map[string]int32, len(v.Values))
|
|
for i, value := range v.Values {
|
|
v.values[value] = int32(i)
|
|
}
|
|
})
|
|
|
|
if id, ok := v.values[s]; ok {
|
|
return id
|
|
}
|
|
|
|
return -1
|
|
}
|
|
|
|
func (v *Vocabulary) Decode(id int32) string {
|
|
return v.Values[id]
|
|
}
|
|
|
|
func (v *Vocabulary) SpecialVocabulary() []string {
|
|
v.specialOnce.Do(func() {
|
|
for i := range v.Values {
|
|
if v.Types[i] == 3 {
|
|
v.special = append(v.special, v.Values[i])
|
|
}
|
|
}
|
|
})
|
|
|
|
return v.special
|
|
}
|
|
|
|
func (v *Vocabulary) Merge(left, right string) int {
|
|
v.mergeOnce.Do(func() {
|
|
v.merge = make(map[string]int32, len(v.Merges))
|
|
for i, merge := range v.Merges {
|
|
v.merge[merge] = int32(i)
|
|
}
|
|
})
|
|
|
|
if id, ok := v.merge[left+" "+right]; ok {
|
|
return int(id)
|
|
}
|
|
|
|
return -1
|
|
}
|
|
|
|
type BytePairEncoding struct {
|
|
pre *regexp2.Regexp
|
|
vocab *Vocabulary
|
|
}
|
|
|
|
func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding {
|
|
return BytePairEncoding{
|
|
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
|
vocab: vocab,
|
|
}
|
|
}
|
|
|
|
func (bpe BytePairEncoding) Is(id uint32, special Special) bool {
|
|
return bpe.vocab.Is(id, special)
|
|
}
|
|
|
|
func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
|
|
return func(yield func(string) bool) {
|
|
for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) {
|
|
if !yield(m.String()) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// fragment is a piece of the input text together with its token IDs, if
// they are already known. Encode leaves ids empty for ordinary text that
// still has to be tokenized and fills it in for special tokens.
type fragment struct {
	// value is the text covered by this fragment.
	value string
	// ids holds the fragment's token IDs; empty means "not encoded yet".
	ids []int32
}
|
|
|
|
// pair is a candidate merge of two adjacent symbols, queued by Encode in a
// heap ordered by rank.
type pair struct {
	// a is the index of the left symbol in Encode's merges slice.
	a int
	// b is the index of the right symbol in Encode's merges slice.
	b int
	// rank is the merge rule's rank as returned by Vocabulary.Merge.
	rank int
	// value is the concatenated text of both symbols at the time the pair
	// was queued; Encode uses it to detect stale entries.
	value string
}
|
|
|
|
// merge is one symbol in the linked list Encode maintains while applying
// merge rules to a pre-tokenized piece of text.
type merge struct {
	// p is the index of the previous symbol (-1 for the first one).
	p int
	// n is the index of the next symbol (one past the end for the last one).
	n int
	// runes is the symbol's text; it is set to nil once the symbol has been
	// merged into its left neighbour.
	runes []rune
}
|
|
|
|
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
|
fragments := []fragment{{value: s}}
|
|
for _, special := range bpe.vocab.SpecialVocabulary() {
|
|
// TODO: process special tokens concurrently
|
|
id := bpe.vocab.Encode(special)
|
|
for i := 0; i < len(fragments); i++ {
|
|
frag := fragments[i]
|
|
if len(frag.ids) > 0 {
|
|
continue
|
|
}
|
|
|
|
var middle []fragment
|
|
switch i := strings.Index(frag.value, special); {
|
|
case i < 0:
|
|
middle = append(middle, frag)
|
|
case i > 0:
|
|
middle = append(middle, fragment{value: frag.value[:i]})
|
|
fallthrough
|
|
default:
|
|
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
|
if rest := frag.value[i+len(special):]; rest != "" {
|
|
middle = append(middle, fragment{value: rest})
|
|
}
|
|
}
|
|
|
|
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
|
}
|
|
}
|
|
|
|
var ids []int32
|
|
for _, frag := range fragments {
|
|
if len(frag.ids) > 0 {
|
|
ids = append(ids, frag.ids...)
|
|
slog.Debug("encoded", "text", frag.value, "ids", frag.ids, "special", true)
|
|
continue
|
|
}
|
|
|
|
for split := range bpe.split(frag.value) {
|
|
// TODO: process splits concurrently
|
|
var sb strings.Builder
|
|
for _, b := range []byte(split) {
|
|
r := rune(b)
|
|
switch {
|
|
case r == 0x00ad:
|
|
r = 0x0143
|
|
case r <= 0x0020:
|
|
r = r + 0x0100
|
|
case r >= 0x007e && r <= 0x00a0:
|
|
r = r + 0x00a2
|
|
}
|
|
|
|
sb.WriteRune(r)
|
|
}
|
|
|
|
// short circuit if the fragment is in the vocabulary
|
|
if id := bpe.vocab.Encode(sb.String()); id >= 0 {
|
|
ids = append(ids, id)
|
|
slog.Debug("encoded", "text", sb.String(), "ids", []int32{id})
|
|
continue
|
|
}
|
|
|
|
runes := []rune(sb.String())
|
|
merges := make([]merge, len(runes))
|
|
for r := range runes {
|
|
merges[r] = merge{
|
|
p: r - 1,
|
|
n: r + 1,
|
|
runes: []rune{runes[r]},
|
|
}
|
|
}
|
|
|
|
pairwise := func(a, b int) *pair {
|
|
if a < 0 || b >= len(runes) {
|
|
return nil
|
|
}
|
|
|
|
left, right := string(merges[a].runes), string(merges[b].runes)
|
|
rank := bpe.vocab.Merge(left, right)
|
|
if rank < 0 {
|
|
return nil
|
|
}
|
|
|
|
return &pair{
|
|
a: a,
|
|
b: b,
|
|
rank: rank,
|
|
value: left + right,
|
|
}
|
|
}
|
|
|
|
pairs := heap.NewWith(func(i, j *pair) int {
|
|
return cmp.Compare(i.rank, j.rank)
|
|
})
|
|
|
|
for i := range len(runes) - 1 {
|
|
if pair := pairwise(i, i+1); pair != nil {
|
|
pairs.Push(pair)
|
|
}
|
|
}
|
|
|
|
for !pairs.Empty() {
|
|
pair, _ := pairs.Pop()
|
|
|
|
left, right := merges[pair.a], merges[pair.b]
|
|
if len(left.runes) == 0 || len(right.runes) == 0 ||
|
|
string(left.runes)+string(right.runes) != pair.value {
|
|
continue
|
|
}
|
|
|
|
merges[pair.a].runes = append(left.runes, right.runes...)
|
|
merges[pair.b].runes = nil
|
|
|
|
merges[pair.a].n = right.n
|
|
if right.n < len(merges) {
|
|
merges[right.n].p = pair.a
|
|
}
|
|
|
|
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
pairs.Push(pair)
|
|
}
|
|
|
|
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
|
pairs.Push(pair)
|
|
}
|
|
}
|
|
|
|
for _, merge := range merges {
|
|
if len(merge.runes) > 0 {
|
|
// TODO: handle the edge case where the rune isn't in the vocabulary
|
|
if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
ids = append(ids, id)
|
|
slog.Debug("encoded", "text", string(merge.runes), "ids", []int32{id})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return ids, nil
|
|
}
|
|
|
|
func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|
var sb strings.Builder
|
|
for _, id := range ids {
|
|
for _, r := range bpe.vocab.Decode(id) {
|
|
switch {
|
|
case r == 0x0100:
|
|
// this produces 0x00 aka NULL
|
|
continue
|
|
case r == 0x0143:
|
|
r = 0x00ad
|
|
case r > 0x0100 && r <= 0x0120:
|
|
r = r - 0x0100
|
|
case r > 0x0120 && r <= 0x0142:
|
|
r = r - 0x00a2
|
|
}
|
|
|
|
// NOTE: not using WriteRune here because it writes the UTF-8
|
|
// encoding of the rune which is _not_ what we want
|
|
if err := sb.WriteByte(byte(r)); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
}
|
|
|
|
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
|
return sb.String(), nil
|
|
}
|