diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index 572b886ed..000000000 --- a/cache/cache.go +++ /dev/null @@ -1,63 +0,0 @@ -package cache - -import ( - "github.com/ollama/ollama/ml" -) - -type Options struct { - Position int -} - -type Cache interface { - Sub(i int) Cache - Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) -} - -type Simple struct { - DType ml.DType - Capacity int - - keys, values []ml.Tensor -} - -func (c *Simple) Sub(i int) Cache { - if i >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...) - } - - return &Simple{ - keys: c.keys[i : i+1], - values: c.values[i : i+1], - Capacity: c.Capacity, - DType: c.DType, - } -} - -func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) { - if c.keys[0] == nil || c.values[0] == nil { - c.keys[0] = ctx.Zeros(c.DType, key.Dim(0)*key.Dim(1)*c.Capacity) - c.values[0] = ctx.Zeros(c.DType, value.Dim(0)*value.Dim(1)*c.Capacity) - } - - ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, key.Stride(2)*opts.Position, key.Dim(0)*key.Dim(1)*key.Dim(2)))) - ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, value.Stride(2)*opts.Position, value.Dim(0)*value.Dim(1)*value.Dim(2)))) - - n := min(c.Capacity, key.Dim(2)+opts.Position) - - key = c.keys[0].View(ctx, 0, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), - n, - ) - - value = c.values[0].View(ctx, 0, - value.Dim(0), value.Stride(1), - value.Dim(1), value.Stride(2), - n, - ) - - // TODO shift context if necessary - - return key, value -} diff --git a/cmd/cmd.go b/cmd/cmd.go index 17c607171..80ece4c60 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,9 +35,9 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llama" - "github.com/ollama/ollama/llama/runner" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -338,7 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - opts.MultiModal = len(info.ProjectorInfo) != 0 + // TODO(jessegross): We should either find another way to know if this is + // a vision model or remove the logic. Also consider that other modalities will + // need different behavior anyways. 
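+ // Treat the model as multimodal if it has a projector or if the new engine is enabled.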
+ opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine() opts.ParentModel = info.Details.ParentModel if interactive { diff --git a/cmd/runner/main.go b/cmd/runner/main.go index 34b0e9d21..fbfafc7ff 100644 --- a/cmd/runner/main.go +++ b/cmd/runner/main.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/ollama/ollama/llama/runner" + "github.com/ollama/ollama/runner" ) func main() { diff --git a/envconfig/config.go b/envconfig/config.go index 0ca3b64cd..fbd881ba7 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -165,6 +165,8 @@ var ( IntelGPU = Bool("OLLAMA_INTEL_GPU") // MultiUserCache optimizes prompt caching for multi-user scenarios MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") + // Enable the new Ollama engine + NewEngine = Bool("OLLAMA_NEW_ENGINE") ) func String(s string) func() string { @@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/kvcache/cache.go b/kvcache/cache.go new file mode 100644 index 000000000..5d8b2f9b5 --- /dev/null +++ b/kvcache/cache.go @@ -0,0 +1,54 @@ +package kvcache + +import ( + "errors" + + "github.com/ollama/ollama/ml" +) + +var ( + ErrKvCacheFull = errors.New("could not find a kv cache slot") + ErrNotSupported = errors.New("model does not support operation") +) + +type Cache interface { + // ** used by model implementations ** + + // SetLayer sets the active layer of the cache + SetLayer(layer int) + + // Get returns the history of key and value tensors plus a mask + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) + + // Put stores a batch of key and value in the cache + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Put(ctx ml.Context, key, value ml.Tensor) + + // ** cache management ** + + // Init sets up runtime parameters + Init(backend ml.Backend, dtype ml.DType, capacity int32) + + // Close closes the cache and frees resources associated with it + Close() + + // StartForward is called before the start of the model's forward pass. + // For each token in the coming batch, there must be a corresponding + // entry in positions and seqs. + StartForward(ctx ml.Context, positions []int32, seqs []int) error + + // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq + CopyPrefix(srcSeq, dstSeq int, len int32) + + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set + // endIndex to math.MaxInt32 to remove everything starting at beginIndex. 
+ // + // If an error occurs, the entire context for the sequence should be + // removed by calling Remove(seq, 0, math.MaxInt32) + Remove(seq int, beginIndex, endIndex int32) error +} diff --git a/kvcache/causal.go b/kvcache/causal.go new file mode 100644 index 000000000..5d46f8d4c --- /dev/null +++ b/kvcache/causal.go @@ -0,0 +1,455 @@ +package kvcache + +import ( + "errors" + "fmt" + "log/slog" + "math" + "slices" + + "github.com/ollama/ollama/ml" +) + +type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) + +// Causal cache stores K and V tensors according to their position in the +// sequence. Returns the history and a mask for attending to past tokens +// +// The tensors are of shape embed dim, kv heads, batch size +// The mask is of shape history size, batch size +type Causal struct { + DType ml.DType + Capacity int32 + windowSize int32 + + // ** current forward pass ** + + // the active layer for Get and Put + curLayer int + + // starting location for data storage for this batch + curLoc int + + // size of the current batch + curBatchSize int + + // mask of the cache as used by this batch + curMask ml.Tensor + + // locations in the cache that are needed for this batch + curCellRange cellRange + + // ** cache metadata ** + + // for each possible location in the cache, stores the position and set of sequences + // that reference the data there + cells []cacheCell + + // maps from sequence to the range of locations where it is stored in the cache + cellRanges map[int]cellRange + + // ** cache data storage ** + + shiftFn shiftFn + backend ml.Backend + cacheCtx ml.Context + keys, values []ml.Tensor +} + +type cacheCell struct { + pos int32 + sequences []int +} + +type cellRange struct { + min int + max int +} + +func NewCausalCache(shift shiftFn) *Causal { + return &Causal{windowSize: math.MaxInt32, shiftFn: shift} +} + +func NewSWACache(windowSize int32, shift shiftFn) *Causal { + return &Causal{windowSize: windowSize, shiftFn: shift} +} + +func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + c.DType = dtype + c.Capacity = capacity + c.cells = make([]cacheCell, capacity) + c.cellRanges = make(map[int]cellRange) + c.backend = backend + c.cacheCtx = backend.NewContext() +} + +func (c *Causal) Close() { + c.cacheCtx.Close() +} + +func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + c.curBatchSize = len(positions) + + var err error + c.curLoc, err = c.findStartLoc() + if errors.Is(err, ErrKvCacheFull) { + c.defrag() + c.curLoc, err = c.findStartLoc() + } + if err != nil { + return err + } + + c.curCellRange = newRange() + for i, pos := range positions { + seq := seqs[i] + + c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} + + seqRange, ok := c.cellRanges[seq] + if !ok { + seqRange = newRange() + } + + if c.curLoc+i > seqRange.max { + seqRange.max = c.curLoc + i + } + if seqRange.max > c.curCellRange.max { + c.curCellRange.max = seqRange.max + } + + if c.curLoc+i < seqRange.min { + seqRange.min = c.curLoc + i + } + if seqRange.min < c.curCellRange.min { + c.curCellRange.min = seqRange.min + } + c.cellRanges[seq] = seqRange + } + + c.curMask, err = c.buildMask(ctx, positions, seqs) + + return err +} + +func newRange() cellRange { + return cellRange{ + min: math.MaxInt, + max: 0, + } +} + +// Find the first contiguous block of at least curBatchSize +func (c *Causal) findStartLoc() (int, error) { + var start, count int + for i := range c.cells { + if len(c.cells[i].sequences) == 0 { + 
count++ + if count >= c.curBatchSize { + return start, nil + } + } else { + start = i + 1 + count = 0 + } + } + + return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) +} + +// Builds a mask of history x batch indicating whether, for each token in the batch, the +// token in the history should apply. This is based on both the sequence and causality (the +// position of the history is not ahead of the token in the batch). +func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { + // TODO(jessegross): This does not do padding, which is required for flash attention + len := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, c.curBatchSize*len) + + for i := range c.curBatchSize { + for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { + if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || + c.cells[j].pos < positions[i]-c.windowSize { + mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + } + } + } + + return ctx.FromFloatSlice(mask, len, c.curBatchSize) +} + +func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { + for _, obj := range objs { + if obj == nil { + continue + } + + srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len) + dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len) + + ctx.Forward(srcView.Copy(ctx, dstView)) + } +} + +func (c *Causal) defrag() { + slog.Debug("defragmenting kv cache") + + // Defrag strategy: + // - Search for empty holes at the beginning of the cache, + // filling them with active data starting at the end + // - If there are contiguous elements that need to be moved, + // combine them into a single operation by holding new moves + // until we see that the next one is non-contiguous + // - Fill up the context with the maximum number of operations it + // can hold, then compute that and continue with a new context + // + // We could try to optimize placement by grouping blocks from + // the same sequences together but most likely the next forward + // pass will disrupt this anyway, so the real-world benefit + // seems limited at this time. + + ctx := c.backend.NewContext() + + // For every move, 6 tensors are required per layer (2 views and a + // copy for each of k and v).
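+ // Count only the layers that have allocated key tensors; unallocated layers need no moves.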
+ layers := 0 + for _, key := range c.keys { + if key == nil { + continue + } + layers++ + } + + maxMoves := ctx.MaxTensors() / (6 * layers) + moves := 0 + + var pendingSrc, pendingDst, pendingLen int + src := len(c.cells) - 1 + + for dst := 0; dst < src; dst++ { + if len(c.cells[dst].sequences) == 0 { + for ; src > dst; src-- { + if len(c.cells[src].sequences) != 0 { + c.cells[dst] = c.cells[src] + c.cells[src] = cacheCell{} + + if pendingLen > 0 { + if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen { + pendingSrc = src + pendingLen++ + break + } else { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + } + + pendingSrc = src + pendingDst = dst + pendingLen = 1 + + break + } + } + } + + if moves >= maxMoves { + ctx.Compute() + ctx.Close() + ctx = c.backend.NewContext() + + moves = 0 + } + } + + if pendingLen > 0 { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + + if moves > 0 { + ctx.Compute() + } + ctx.Close() + + // Reset range metadata + for seq := range c.cellRanges { + seqRange := newRange() + + for i, cell := range c.cells { + if slices.Contains(cell.sequences, seq) { + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + c.cellRanges[seq] = seqRange + } +} + +func (c *Causal) SetLayer(layer int) { + if layer >= len(c.keys) { + c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) + c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) + } + + c.curLayer = layer +} + +func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + key := c.keys[c.curLayer] + value := c.values[c.curLayer] + + key = key.View(ctx, key.Stride(2)*c.curCellRange.min, + key.Dim(0), key.Stride(1), + key.Dim(1), key.Stride(2), + c.curMask.Dim(0), + ) + + value = value.View(ctx, key.Stride(2)*c.curCellRange.min, + value.Dim(0), value.Stride(1), + value.Dim(1), value.Stride(2), + c.curMask.Dim(0), + ) + + return key, value, c.curMask +} + +func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { + if c.curBatchSize != key.Dim(2) { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2))) + } + + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity)) + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) + } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2)))) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2)))) +} + +func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { + seqRange := newRange() + + for i := range c.cells { + // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end + if slices.Contains(c.cells[i].sequences, dstSeq) { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq }) + } + + if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len { + c.cells[i].sequences = append(c.cells[i].sequences, dstSeq) + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + 
c.cellRanges[dstSeq] = seqRange +} + +func (c *Causal) shift(seq int, beginIndex, offset int32) error { + if c.shiftFn == nil { + return ErrNotSupported + } + + ctx := c.backend.NewContext() + defer ctx.Close() + + seqRange := c.cellRanges[seq] + size := seqRange.max - seqRange.min + 1 + + offsets := make([]int32, size) + for i := range offsets { + cell := c.cells[seqRange.min+i] + + if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { + offsets[i] = offset + } + } + + kShift, err := ctx.FromIntSlice(offsets, len(offsets)) + if err != nil { + return err + } + + for i, key := range c.keys { + if key == nil { + continue + } + + key = key.View(ctx, key.Stride(2)*seqRange.min, + key.Dim(0), key.Stride(1), + key.Dim(1), key.Stride(2), + size, + ) + + roped, err := c.shiftFn(ctx, i, key, kShift) + if err != nil { + return err + } + + ctx.Forward(roped.Copy(ctx, key)) + } + + ctx.Compute() + + return nil +} + +func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + var offset int32 + if endIndex != math.MaxInt32 { + offset = beginIndex - endIndex + } + + seqRange := newRange() + + for i := range c.cells { + if slices.Contains(c.cells[i].sequences, seq) { + if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) + } else { + if c.cells[i].pos >= endIndex { + if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { + // TODO(jessegross): Need to be careful about data shared between sequences + return errors.New("shifting on cells shared by multiple sequences not yet implemented") + } + + c.cells[i].pos += offset + } + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + } + + if seqRange == newRange() { + delete(c.cellRanges, seq) + return nil + } + + c.cellRanges[seq] = seqRange + + if endIndex != math.MaxInt32 { + err := c.shift(seq, endIndex+offset, offset) + if err != nil { + return err + } + } + + return nil +} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go new file mode 100644 index 000000000..0b614df06 --- /dev/null +++ b/kvcache/causal_test.go @@ -0,0 +1,506 @@ +package kvcache + +import ( + "math" + "slices" + "testing" + + "github.com/ollama/ollama/ml" +) + +type testCase struct { + name string + in []float32 + inShape []int + seqs []int + pos []int32 + expected []float32 + expectedShape []int + expectedMask []float32 +} + +func TestStore(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + inShape: []int{2, 3, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + expectedShape: []int{2, 3, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{115, 215, 125, 225, 135, 235}, + inShape: []int{2, 3, 1}, + seqs: []int{0}, + pos: []int32{4}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 
124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, + expectedShape: []int{2, 3, 5}, + expectedMask: []float32{0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestSWA(t *testing.T) { + backend := &testBackend{} + cache := NewSWACache(1, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF32, 16) + + tests := []testCase{ + { + name: "SlidingWindow", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestSequences(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{2, 2}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestRemove(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.Add(ctx, shift), nil + }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + err := cache.Remove(0, 1, math.MaxInt32) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveEnd", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{1, 2}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) + + err = cache.Remove(0, 0, 1) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveMiddle", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{1, 2}, + 
expected: []float32{7, 8, 3, 4, 4}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestDefrag(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.Add(ctx, shift), nil + }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + inShape: []int{1, 1, 16}, + seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + expectedShape: []int{1, 1, 16}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + err := cache.Remove(0, 2, 4) + if err != nil { + panic(err) + } + + err = cache.Remove(0, 13, math.MaxInt32) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "Defrag", + in: []float32{17, 18, 19}, + inShape: []int{1, 1, 3}, + seqs: []int{0, 0, 0}, + pos: []int32{16, 17, 18}, + expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19}, + expectedShape: []int{1, 1, 16}, + expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestCopy(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + cache.CopyPrefix(0, 1, 2) + + tests = []testCase{ + { + name: "Copy", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{3, 4}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, test.pos, test.seqs) + if err != nil { + panic(err) + } + + cache.SetLayer(0) + tensor, _ := context.FromFloatSlice(test.in, test.inShape...) 
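+ // The same tensor is used for both key and value so the cache contents can be compared directly against the input.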
+ cache.Put(context, tensor, tensor) + + out, _, mask := cache.Get(context) + + context.Forward(out) + context.Forward(mask) + context.Compute(out, mask) + + if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { + t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) + } + }) + } +} + +type testBackend struct{} + +func (b *testBackend) Config() ml.Config { + panic("not implemented") +} + +func (b *testBackend) Get(name string) ml.Tensor { + panic("not implemented") +} + +func (b *testBackend) NewContext() ml.Context { + return &testContext{} +} + +type testContext struct{} + +func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + total := 0 + + if len(shape) > 0 { + total = 1 + for _, s := range shape { + total *= s + } + } + + return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} +} + +func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { + t := c.Zeros(ml.DTypeF32, shape...).(*testTensor) + + copy(t.data, s) + + return t, nil +} + +func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { + f := make([]float32, len(s)) + for i := range f { + f[i] = float32(s[i]) + } + + out, _ := c.FromFloatSlice(f, shape...) + out.(*testTensor).dtype = ml.DTypeI32 + + return out, nil +} + +func (c *testContext) Forward(ml.Tensor) {} + +func (c *testContext) Compute(...ml.Tensor) {} + +func (c *testContext) MaxTensors() int { + return 10 +} + +func (c *testContext) Close() {} + +type testTensor struct { + dtype ml.DType + elementSize int + data []float32 + shape []int +} + +func (t *testTensor) Dim(n int) int { + return t.shape[n] +} + +func (t *testTensor) Stride(n int) int { + stride := t.elementSize + for i := range n { + stride *= t.shape[i] + } + + return stride +} + +func (t *testTensor) Shape() []int { + return t.shape +} + +func (t *testTensor) DType() ml.DType { + return t.dtype +} + +func (t *testTensor) Bytes() []byte { + panic("not implemented") +} + +func (t *testTensor) Floats() []float32 { + out := make([]float32, len(t.data)) + copy(out, t.data) + return out +} + +func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor) + + for i := range out.data { + out.data[i] = t.data[i] + t2.(*testTensor).data[i] + } + + return out +} + +func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors 
ml.Tensor, dim uint32, base, scale float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { + offset /= t.elementSize + + var s []int + + switch len(shape) { + case 1: + s = []int{shape[0]} + case 5: + s = []int{shape[0], shape[2], shape[4]} + default: + panic("unsupported number of dimensions") + } + + context := &testContext{} + + view := context.Zeros(t.dtype, s...).(*testTensor) + view.data = t.data[offset : offset+len(view.data)] + + return view +} + +func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + copy(t2.(*testTensor).data, t.data) + return nil +} diff --git a/kvcache/encoder.go b/kvcache/encoder.go new file mode 100644 index 000000000..8a44c194b --- /dev/null +++ b/kvcache/encoder.go @@ -0,0 +1,97 @@ +package kvcache + +import ( + "github.com/ollama/ollama/ml" +) + +// Encoder cache stores K and V tensors that are position independent +// +// The tensors can be of any shape and will be returned as they were stored +// The mask is currently always nil +// +// Not currently safe for multiple sequences +type EncoderCache struct { + // ** current forward pass ** + + // the active layer for Get and Put + curLayer int + + // if something is stored during this pass, this + // will be the position (but there is no guarantee + // anything will be stored) + curPos int32 + + // ** cache metadata ** + + // was something stored in the cache? + encoderCached bool + + // position of the cached data + encoderPos int32 + + // ** cache data storage ** + + cacheCtx ml.Context + keys, values []ml.Tensor +} + +func NewEncoderCache() *EncoderCache { + return &EncoderCache{} +} + +func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + c.cacheCtx = backend.NewContext() +} + +func (c *EncoderCache) Close() { + c.cacheCtx.Close() +} + +func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + // The image is always in the first position + c.curPos = positions[0] + + return nil +} + +func (c *EncoderCache) SetLayer(layer int) { + if layer >= len(c.keys) { + c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) + c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) 
+ } + + c.curLayer = layer +} + +func (c *EncoderCache) EncoderCached() bool { + return c.encoderCached +} + +func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.keys[c.curLayer], c.values[c.curLayer], nil +} + +func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.encoderPos = c.curPos + c.encoderCached = true + + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) + c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) + } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer])) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer])) +} + +func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { + panic("encoder cache does not support multiple sequences") +} + +func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { + if c.encoderPos >= beginIndex && c.encoderPos < endIndex { + c.encoderCached = false + } + + return nil +} diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go new file mode 100644 index 000000000..2d4c1089a --- /dev/null +++ b/kvcache/wrapper.go @@ -0,0 +1,93 @@ +package kvcache + +import ( + "math" + + "github.com/ollama/ollama/ml" +) + +// Wrapper cache is a container for multiple types of caches, +// such as for the encoding and decoding portions of a model. +type WrapperCache struct { + // caches we are wrapping + caches []Cache + + // cache to be used for this layer + curType int +} + +func NewWrapperCache(caches ...Cache) *WrapperCache { + return &WrapperCache{ + caches: caches, + } +} + +func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + for _, cache := range c.caches { + cache.Init(backend, dtype, capacity) + } +} + +func (c *WrapperCache) Close() { + for _, cache := range c.caches { + cache.Close() + } +} + +func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + for i, cache := range c.caches { + err := cache.StartForward(ctx, positions, seqs) + if err != nil { + // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail + for j := i - 1; j >= 0; j-- { + for k := range positions { + _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32) + } + } + return err + } + } + + c.curType = 0 + return nil +} + +func (c *WrapperCache) SetLayer(layer int) { + for _, cache := range c.caches { + cache.SetLayer(layer) + } +} + +func (c *WrapperCache) SetLayerType(layerType int) { + c.curType = layerType +} + +func (c *WrapperCache) UnderlyingCache() Cache { + return c.caches[c.curType] +} + +func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.caches[c.curType].Get(ctx) +} + +func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.caches[c.curType].Put(ctx, key, value) +} + +func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { + for _, cache := range c.caches { + cache.CopyPrefix(srcSeq, dstSeq, len) + } +} + +func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { + // If one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail + for _, cache := range c.caches { + err := cache.Remove(seq, beginIndex, endIndex) + if err != nil { + return err + } + } + + return nil +} diff --git a/llm/server.go b/llm/server.go index 765d7bb9b..50ba91f18 100644 --- a/llm/server.go +++ b/llm/server.go @@ -275,6 +275,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML,
adapt port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } finalParams := []string{"runner"} + if envconfig.NewEngine() { + finalParams = append(finalParams, "--ollama-engine") + } finalParams = append(finalParams, params...) finalParams = append(finalParams, "--port", strconv.Itoa(port)) diff --git a/ml/backend.go b/ml/backend.go index acfb67637..0e99ab5a8 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -50,6 +50,7 @@ type Context interface { Forward(Tensor) Compute(...Tensor) + MaxTensors() int Close() } @@ -118,7 +119,7 @@ type DumpOptions struct { Precision int } -func Dump(t Tensor, opts ...DumpOptions) string { +func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { if len(opts) < 1 { opts = append(opts, DumpOptions{ Items: 3, @@ -128,11 +129,17 @@ func Dump(t Tensor, opts ...DumpOptions) string { switch t.DType() { case DTypeF32: - return dump[[]float32](t, opts[0].Items, func(f float32) string { + return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { + return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) + }) + case DTypeF16: + f32 := ctx.Zeros(DTypeF32, t.Shape()...) + f32 = t.Copy(ctx, f32) + return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeI32: - return dump[[]int32](t, opts[0].Items, func(i int32) string { + return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string { return strconv.FormatInt(int64(i), 10) }) default: @@ -140,10 +147,10 @@ func Dump(t Tensor, opts ...DumpOptions) string { } } -func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string { - bts := t.Bytes() - if bts == nil { - return "" +func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string { + if t.Bytes() == nil { + ctx.Forward(t) + ctx.Compute(t) } s := make(S, mul(t.Shape()...)) @@ -191,7 +198,8 @@ func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string { type DType int const ( - DTypeF32 DType = iota + DTypeOther DType = iota + DTypeF32 + DTypeF16 DTypeI32 - DTypeOther ) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 609570672..6a727a60c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -258,6 +258,10 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } +func (c *Context) MaxTensors() int { + return c.nodes +} + func shapeToGGML(shape []int) *C.int64_t { sh := make([]C.int64_t, len(shape)) for i, s := range shape { @@ -282,6 +286,8 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { switch dtype { case ml.DTypeF32: t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) + case ml.DTypeF16: + t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeI32: t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) default: @@ -389,6 +395,8 @@ func (t *Tensor) DType() ml.DType { switch t.t._type { case C.GGML_TYPE_F32: return ml.DTypeF32 + case C.GGML_TYPE_F16: + return ml.DTypeF16 case C.GGML_TYPE_I32: return ml.DTypeI32 default: @@ -580,9 +588,14 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi ropeFactors = &Tensor{} } + dequant := t.t + if C.ggml_is_quantized(t.t._type) { + dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) + } + return &Tensor{ t: C.ggml_rope_ext( - ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, + 
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), 131072, // YaRN n_ctx_train ropeTypeNorm, // ROPE_TYPE_NORM diff --git a/model/model.go b/model/model.go index 4a86d7d60..8a8c9b297 100644 --- a/model/model.go +++ b/model/model.go @@ -1,6 +1,7 @@ package model import ( + "errors" "fmt" "image" _ "image/jpeg" @@ -15,102 +16,42 @@ import ( _ "golang.org/x/image/tiff" _ "golang.org/x/image/webp" - "github.com/ollama/ollama/cache" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" ) -type Cache struct { - cache.Cache - cache.Options -} - -func (c Cache) Sub(i int) Cache { - if c.Cache != nil { - return Cache{ - Cache: c.Cache.Sub(i), - Options: c.Options, - } - } - - return c -} - -func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) { - if c.Cache != nil { - return c.Cache.Put(ctx, key, value, opts) - } - - return key, value -} - type Options struct { - inputs []int32 - - Offset int + Inputs []int32 + Positions []int32 + Sequences []int + Outputs []int32 Images []image.Image - - Cache } -func (opts Options) Inputs() []int32 { - return opts.inputs[opts.Offset:] -} - -func (opts Options) Positions() []int32 { - positions := make([]int32, len(opts.inputs)-opts.Offset) - for i := range positions { - positions[i] = int32(opts.Offset + i) - } - - return positions -} - -type OptionsFunc func(Model, *Options) - -func WithInputIDs(ids []int32) OptionsFunc { - return func(m Model, opts *Options) { - opts.inputs = ids - } -} - -func WithOffset(offset int) OptionsFunc { - return func(m Model, opts *Options) { - opts.Offset = offset - opts.Cache.Position = offset - } -} - -func WithImage(img image.Image) OptionsFunc { - return func(m Model, opts *Options) { - opts.Images = append(opts.Images, img) - } -} - -func WithCache(c cache.Cache) OptionsFunc { - return func(m Model, opts *Options) { - opts.Cache = Cache{ - Cache: c, - Options: cache.Options{ - Position: opts.Offset, - }, - } - } +type config struct { + Cache kvcache.Cache } type Base struct { b ml.Backend + config } func (m *Base) Backend() ml.Backend { return m.b } +func (m *Base) Config() config { + return m.config +} + type Model interface { Forward(ml.Context, Options) (ml.Tensor, error) Backend() ml.Backend + Config() config } var models = make(map[string]func(ml.Config) (Model, error)) @@ -146,12 +87,14 @@ func New(s string) (Model, error) { return nil, err } + base := Base{b: b, config: m.Config()} + v := reflect.ValueOf(m) - v.Elem().Set(populateFields(b, v.Elem())) + v.Elem().Set(populateFields(base, v.Elem())) return m, nil } -func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { +func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() if t.Kind() == reflect.Struct { @@ -170,7 +113,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { } if tt == reflect.TypeOf((*Base)(nil)).Elem() { - vv.Set(reflect.ValueOf(Base{b: b})) + vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { var fn func([]Tag) [][]string fn = func(tags []Tag) (values [][]string) { @@ -196,21 +139,21 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { names := fn(tagsCopy) for _, name := range names { - if tensor := b.Get(strings.Join(name, ".")); tensor != nil { + if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { slog.Debug("found tensor", "", 
tensor) vv.Set(reflect.ValueOf(tensor)) break } } } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { - setPointer(b, vv, tagsCopy) + setPointer(base, vv, tagsCopy) } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) } } } @@ -228,7 +171,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { return v } -func setPointer(b ml.Backend, v reflect.Value, tags []Tag) { +func setPointer(base Base, v reflect.Value, tags []Tag) { vv := v if v.Kind() == reflect.Interface { if v.IsNil() { @@ -243,7 +186,7 @@ func setPointer(b ml.Backend, v reflect.Value, tags []Tag) { vv = reflect.New(v.Type().Elem()).Elem() } - if f := populateFields(b, vv, tags...); f.CanAddr() { + if f := populateFields(base, vv, tags...); f.CanAddr() { v.Set(f.Addr()) } } @@ -277,18 +220,27 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) { - var opts Options - for _, optsFunc := range optsFuncs { - optsFunc(m, &opts) +func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { + if len(opts.Positions) != len(opts.Sequences) { + return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) + } + + if len(opts.Positions) < 1 { + return nil, errors.New("batch size cannot be less than 1") + } + + cache := m.Config().Cache + if cache != nil { + err := cache.StartForward(ctx, opts.Positions, opts.Sequences) + if err != nil { + return nil, err + } } - ctx := m.Backend().NewContext() t, err := m.Forward(ctx, opts) if err != nil { return nil, err } - defer ctx.Close() ctx.Forward(t) ctx.Compute(t) diff --git a/model/model_test.go b/model/model_test.go index 2ba12acde..02b8aa3c2 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -78,7 +78,7 @@ func TestPopulateFields(t *testing.T) { var m fakeModel v := reflect.ValueOf(&m) - v.Elem().Set(populateFields(&fakeBackend{ + v.Elem().Set(populateFields(Base{b: &fakeBackend{ names: []string{ "input.weight", "blk.0.attn_q.weight", @@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) { "output_norm.weight", "output.weight", }, - }, v.Elem())) + }}, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}}, @@ -121,11 +121,11 @@ func TestPopulateFieldsAlternateName(t *testing.T) { m := fakeModel{} v := reflect.ValueOf(&m) - v.Elem().Set(populateFields(&fakeBackend{ + v.Elem().Set(populateFields(Base{b: &fakeBackend{ names: []string{ "input.weight", }, - }, v.Elem())) + }}, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}}, diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6efcc9bb7..b2c5c2c7b 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -3,6 +3,7 @@ package llama import ( "math" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" @@ -28,7 +29,7 @@ type Model struct { } func New(c ml.Config) 
(model.Model, error) { - return &Model{ + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ @@ -49,7 +50,11 @@ func New(c ml.Config) (model.Model, error) { ropeScale: c.Float("rope.freq_scale", 1), ropeDim: c.Uint("rope.dimension_count"), }, - }, nil + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + + return &m, nil } type SelfAttention struct { @@ -59,7 +64,7 @@ type SelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads @@ -74,7 +79,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k, v = cache.Put(ctx, k, v, cache.Options) + cache.Put(ctx, k, v) + k, v, mask := cache.Get(ctx) q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -82,6 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten kq := k.MulmatFullPrec(ctx, q) kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) + kq = kq.Add(ctx, mask) kq = kq.Softmax(ctx) kqv := v.Mulmat(ctx, kq) @@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten return sa.Output.Forward(ctx, kqv) } +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil +} + type MLP struct { Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` @@ -109,7 +120,7 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach } func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { - inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) + inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) + positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } @@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) for i, layer := range m.Layers { - hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options) + m.Cache.SetLayer(i) + hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.Output.Forward(ctx, hiddenState) - outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) + outputs, err 
:= ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index e5b275b0b..a1460d940 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -1,6 +1,7 @@ package mllama import ( + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" @@ -18,8 +19,13 @@ type Model struct { ImageProcessor } +const ( + crossAttentionLayer = iota + selfAttentionLayer +) + func New(c ml.Config) (model.Model, error) { - return &Model{ + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ @@ -33,7 +39,11 @@ func New(c ml.Config) (model.Model, error) { ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), TextModel: newTextModel(c), - }, nil + } + + m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift)) + + return &m, nil } func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { @@ -73,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates) } - inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) + inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) + positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } // TODO: attention mask, cross attention mask - hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache) + hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)) - outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) + outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 2b05a60ea..1e48086a3 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -4,9 +4,9 @@ import ( "math" "slices" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model" ) type TextSelfAttention struct { @@ -16,7 +16,7 @@ type TextSelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads @@ -31,7 +31,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key, value = cache.Put(ctx, key, value, cache.Options) + cache.Put(ctx, key, value) + key, value, mask := cache.Get(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -39,11 
+40,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas scores := key.MulmatFullPrec(ctx, query) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - - if mask != nil { - scores = scores.Add(ctx, mask) - } - + scores = scores.Add(ctx, mask) scores = scores.Softmax(ctx) attention := value.Mulmat(ctx, scores) @@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas return sa.Output.Forward(ctx, attention) } +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + // This will only get called for layers in the cache, which are just the self attention layers + return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil +} + type TextMLP struct { Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` @@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct { MLP *TextMLP } -func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -94,23 +96,29 @@ type TextCrossAttention struct { Output *nn.Linear `gguf:"cross_attn_o_proj"` } -func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) query := ca.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - key := ca.Key.Forward(ctx, crossAttentionStates) - key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) - key = ca.KeyNorm.Forward(ctx, key, opts.eps) + var key, value ml.Tensor + if crossAttentionStates != nil { + numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) - value := ca.Value.Forward(ctx, crossAttentionStates) - value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + key = ca.Key.Forward(ctx, crossAttentionStates) + key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + key = ca.KeyNorm.Forward(ctx, key, opts.eps) - // TODO cache key, value + value = ca.Value.Forward(ctx, crossAttentionStates) + value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + + cache.Put(ctx, key, value) + } else { + key, value, _ = cache.Get(ctx) + } query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct { MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` } -func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache 
*kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, } type TextDecoderLayer interface { - Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor + Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor } type TextDecoder struct { Layers []TextDecoderLayer } -func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { for i, layer := range d.Layers { - if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil { - hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts) + layerType := selfAttentionLayer + if slices.Contains(opts.crossAttentionLayers, uint32(i)) { + layerType = crossAttentionLayer + } + + cache.SetLayer(i) + cache.SetLayerType(layerType) + + if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() { + hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts) } } @@ -189,7 +205,7 @@ type TextModel struct { *TextModelOptions } -func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor { +func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/llama/runner/README.md b/runner/README.md similarity index 100% rename from llama/runner/README.md rename to runner/README.md diff --git a/llama/runner/stop.go b/runner/common/stop.go similarity index 87% rename from llama/runner/stop.go rename to runner/common/stop.go index 8dcb08d33..3f27a286e 100644 --- a/llama/runner/stop.go +++ b/runner/common/stop.go @@ -1,10 +1,10 @@ -package runner +package common import ( "strings" ) -func findStop(sequence string, stops []string) (bool, string) { +func FindStop(sequence string, stops []string) (bool, string) { for _, stop := range stops { if strings.Contains(sequence, stop) { return true, stop @@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) { return false, "" } -func containsStopSuffix(sequence string, stops []string) bool { +func ContainsStopSuffix(sequence string, stops []string) bool { for _, stop := range stops { for i := 1; i <= len(stop); i++ { if strings.HasSuffix(sequence, stop[:i]) { @@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // 
returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func truncateStop(pieces []string, stop string) ([]string, bool) { +func TruncateStop(pieces []string, stop string) ([]string, bool) { joined := strings.Join(pieces, "") index := strings.Index(joined, stop) @@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) { return result, tokenTruncated } -func incompleteUnicode(token string) bool { +func IncompleteUnicode(token string) bool { incomplete := false // check if there is incomplete UTF-8 character at the end diff --git a/llama/runner/stop_test.go b/runner/common/stop_test.go similarity index 96% rename from llama/runner/stop_test.go rename to runner/common/stop_test.go index 31dc161f3..8df267eb4 100644 --- a/llama/runner/stop_test.go +++ b/runner/common/stop_test.go @@ -1,4 +1,4 @@ -package runner +package common import ( "reflect" @@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, resultTrunc := truncateStop(tt.pieces, tt.stop) + result, resultTrunc := TruncateStop(tt.pieces, tt.stop) if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) } @@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := incompleteUnicode(tt.input) + result := IncompleteUnicode(tt.input) if result != tt.expected { t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected) } diff --git a/llama/runner/cache.go b/runner/llamarunner/cache.go similarity index 99% rename from llama/runner/cache.go rename to runner/llamarunner/cache.go index e8a2d2994..d29e94b6b 100644 --- a/llama/runner/cache.go +++ b/runner/llamarunner/cache.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "errors" diff --git a/llama/runner/cache_test.go b/runner/llamarunner/cache_test.go similarity index 99% rename from llama/runner/cache_test.go rename to runner/llamarunner/cache_test.go index 9c838ed33..c0656dd97 100644 --- a/llama/runner/cache_test.go +++ b/runner/llamarunner/cache_test.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "testing" diff --git a/llama/runner/image.go b/runner/llamarunner/image.go similarity index 99% rename from llama/runner/image.go rename to runner/llamarunner/image.go index c1932443c..e7e30a4d8 100644 --- a/llama/runner/image.go +++ b/runner/llamarunner/image.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "errors" diff --git a/llama/runner/image_test.go b/runner/llamarunner/image_test.go similarity index 99% rename from llama/runner/image_test.go rename to runner/llamarunner/image_test.go index d5c3bc1e2..2e1efaec8 100644 --- a/llama/runner/image_test.go +++ b/runner/llamarunner/image_test.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "reflect" diff --git a/llama/runner/runner.go b/runner/llamarunner/runner.go similarity index 98% rename from llama/runner/runner.go rename to runner/llamarunner/runner.go index 60ae88dac..93d6bfabe 100644 --- a/llama/runner/runner.go +++ b/runner/llamarunner/runner.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "context" @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + 
"github.com/ollama/ollama/runner/common" ) // input is an element of the prompt to process, either @@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") - if ok, stop := findStop(sequence, seq.stop); ok { + if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) var tokenTruncated bool origLen := len(seq.pendingResponses) - seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) newLen := len(seq.pendingResponses) // Update the cache based on the tokens that will be returned: @@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - if containsStopSuffix(sequence, seq.stop) { + if common.ContainsStopSuffix(sequence, seq.stop) { continue } - if incompleteUnicode(sequence) { + if common.IncompleteUnicode(sequence) { continue } @@ -885,9 +886,6 @@ func (s *Server) loadModel( } func Execute(args []string) error { - if args[0] == "runner" { - args = args[1:] - } fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") ppath := fs.String("mmproj", "", "Path to projector binary file") diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go new file mode 100644 index 000000000..e1fa98b1a --- /dev/null +++ b/runner/ollamarunner/cache.go @@ -0,0 +1,280 @@ +package ollamarunner + +import ( + "errors" + "fmt" + "log/slog" + "math" + "reflect" + "time" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" +) + +type InputCache struct { + // context window size (per slot) + numCtx int32 + + // does the cache store data or do we need to always send the full input? 
+ // note that when enabled is false the underlying cache may either be nil + // or a non-nil dummy that doesn't actually store anything + enabled bool + + // individual KV caches + slots []InputCacheSlot + + // optimize cache eviction for multiple users + multiUserCache bool + + cache kvcache.Cache +} + +func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) { + if kvSize/int32(numSlots) < 1 { + return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + } + + slots := make([]InputCacheSlot, numSlots) + + for i := range slots { + slots[i] = InputCacheSlot{ + Id: i, + Inputs: make([]input, 0), + } + } + + cache := model.Config().Cache + if cache != nil { + cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize) + } + + return &InputCache{ + numCtx: kvSize / int32(numSlots), + enabled: cache != nil, + slots: slots, + multiUserCache: multiUserCache, + cache: cache, + }, nil +} + +func kvCacheTypeFromStr(s string) ml.DType { + switch s { + case "q8_0": + panic("kv cache quantization not yet implemented") + case "q4_0": + panic("kv cache quantization not yet implemented") + default: + return ml.DTypeF16 + } +} + +func (c *InputCache) Close() { + c.cache.Close() +} + +// Locking: Operations on InputCacheSlot (including finding one +// through LoadCacheSlot) require a lock to be held that serializes +// these operations with each other and processBatch + +type InputCacheSlot struct { + // Index in the KV cache + Id int + + // Inputs that are stored in the KV cache + Inputs []input + + // is this cache actively being processed as part of a sequence? + InUse bool + + // last time this cache was used (as of start of processing) + lastUsed time.Time +} + +func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { + var slot *InputCacheSlot + var numPast int32 + var err error + + // In single-user scenarios, the longest cache slot works fine for getting good input + // cache hit rates and it keeps the footprint of the cache small, which improves throughput. + // For multiple users, the "best" cache slot produces better input cache hit rates + // at the cost of worse performance when we miss the input cache.
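The comment above captures the slot-selection trade-off: single-user setups reuse the slot sharing the longest prefix with the prompt, while multi-user setups also weigh how recently a slot was used. Both strategies rest on measuring that shared prefix. A minimal sketch of the idea, assuming plain int32 tokens in place of the runner's input type (longestPrefix is an illustrative name, not a function in this patch):

```go
package main

import "fmt"

// longestPrefix reports how many leading tokens a cached slot shares with an
// incoming prompt; cached work up to that point can be reused instead of
// being recomputed.
func longestPrefix(cached, prompt []int32) int32 {
	var n int32
	for i := range cached {
		if i >= len(prompt) || cached[i] != prompt[i] {
			break
		}
		n++
	}
	return n
}

func main() {
	cached := []int32{1, 2, 3, 4}
	prompt := []int32{1, 2, 9}
	// Only the first two tokens match, so two cache entries are reusable.
	fmt.Println(longestPrefix(cached, prompt)) // 2
}
```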
+ if !c.multiUserCache { + slot, numPast, err = c.findLongestCacheSlot(prompt) + } else { + slot, numPast, err = c.findBestCacheSlot(prompt) + } + if err != nil { + return nil, nil, err + } + + if !cachePrompt { + numPast = 0 + } + + slot.InUse = true + slot.lastUsed = time.Now() + + if numPast == int32(len(prompt)) { + // Leave one input to sample so we can get a response + numPast-- + } + + if c.cache != nil { + err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) + if err != nil { + // Some models don't support partial erasure + err = c.cache.Remove(slot.Id, 0, math.MaxInt32) + if err != nil { + return nil, nil, err + } + numPast = 0 + } + } + + slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt), + "used", numPast, "remaining", int32(len(prompt))-numPast) + + prompt = prompt[numPast:] + slot.Inputs = slot.Inputs[:numPast] + + return slot, prompt, nil +} + +func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { + longest := int32(-1) + var longestSlot *InputCacheSlot + + for i, s := range c.slots { + if s.InUse { + continue + } + + count := countCommonPrefix(s.Inputs, prompt) + if count > longest { + longest = count + longestSlot = &c.slots[i] + } + } + + if longestSlot == nil { + return nil, 0, errors.New("no available cache slots") + } + + return longestSlot, longest, nil +} + +func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { + oldest := time.Now() + var oldestSlot *InputCacheSlot + + longest := int32(-1) + var longestSlot *InputCacheSlot + + for i, s := range c.slots { + count := countCommonPrefix(s.Inputs, prompt) + if count > longest { + longest = count + longestSlot = &c.slots[i] + } + + if s.lastUsed.Compare(oldest) < 0 && !s.InUse { + oldest = s.lastUsed + oldestSlot = &c.slots[i] + } + } + + if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse { + return longestSlot, longest, nil + } + + if oldestSlot.InUse { + return nil, 0, errors.New("no available cache slots") + } + + if len(oldestSlot.Inputs) != 0 { + slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs), + "used", oldestSlot.lastUsed) + } + + if longest > 0 && longestSlot != oldestSlot { + slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", + len(longestSlot.Inputs)) + oldestSlot.Inputs = make([]input, longest) + copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) + if c.cache != nil { + c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) + } + } + + return oldestSlot, longest, nil +} + +func countCommonPrefix(a []input, b []input) int32 { + var count int32 + + for i := range a { + if i >= len(b) { + break + } + + if !reflect.DeepEqual(a[i], b[i]) { + break + } + + count++ + } + + return count +} + +func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) + + currentFree := c.numCtx - inputLen + discard := targetFree - currentFree + + if discard < 0 { + discard = 0 + } + + return discard +} + +// Frees up space in the KV cache by deleting the oldest half of history and shifting +// the newest half into that space (saving numKeep inputs at the beginning). +// +// Assumes that at least 1 entry can be freed up by shifting (i.e. 
numKeep < numCtx) +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { + if numKeep >= c.numCtx { + return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) + } + + inputLen := int32(len(slot.Inputs)) + discard := c.ShiftDiscard(inputLen, numKeep) + + if discard <= 0 { + return nil + } + + slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), + "keep", numKeep, "discard", discard) + + // TODO (jessegross): KV cache removal can fail for certain types of models + if c.cache != nil { + err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) + if err != nil { + return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + } + } + + for i := numKeep + discard; i < inputLen; i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:inputLen-discard] + + return nil +} diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go new file mode 100644 index 000000000..99e67b4fa --- /dev/null +++ b/runner/ollamarunner/cache_test.go @@ -0,0 +1,291 @@ +package ollamarunner + +import ( + "image" + "testing" + "time" +) + +func TestCountCommon(t *testing.T) { + imgA := image.NewRGBA(image.Rect(0, 0, 100, 100)) + imgB := image.NewRGBA(image.Rect(0, 0, 50, 50)) + imgC := image.NewRGBA(image.Rect(50, 50, 100, 100)) + + tests := []struct { + name string + t1 []input + t2 []input + expected int32 + }{ + { + name: "Equal", + t1: []input{{token: 1}, {token: 2}, {token: 3}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 3, + }, + { + name: "Prefix", + t1: []input{{token: 1}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 1, + }, + { + name: "Image Prefix", + t1: []input{{image: imgA}}, + t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, + expected: 1, + }, + { + name: "Mixed", + t1: []input{{token: 1}, {image: imgA}}, + t2: []input{{token: 1}, {image: imgA}, {token: 5}}, + expected: 2, + }, + { + name: "Empty", + t1: []input{}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 0, + }, + { + name: "Both Empty", + t1: []input{}, + t2: []input{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := countCommonPrefix(tt.t1, tt.t2) + if result != tt.expected { + t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected) + } + }) + } +} + +func TestFindCacheSlot(t *testing.T) { + type expected struct { + result int + len int32 + } + + tests := []struct { + name string + cache InputCache + prompt []input + longest expected + best expected + }{ + { + name: "Empty", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 0, len: 0}, + }, + { + name: "Extend", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 2}, + best: expected{result: 1, len: 2}, + }, + { + name: "New", + cache: 
InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 2}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "Fork", + cache: InputCache{ + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }, + }, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 1}, + best: expected{result: 1, len: 1}, + }, + { + name: "Evict", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 2}, {token: 3}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "In use", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 1}, + best: expected{result: 1, len: 2}, + }, + } + + for _, tt := range tests { + t.Run("Longest-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findLongestCacheSlot: err %v", err) + } else if result.Id != tt.longest.result || resultLen != tt.longest.len { + t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.longest.result, resultLen, tt.longest.len) + } + }) + } + + for _, tt := range tests { + t.Run("Best-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findBestCacheSlot: err %v", err) + } else if result.Id != tt.best.result || resultLen != tt.best.len { + t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.best.result, resultLen, tt.best.len) + } + }) + } +} + +func TestShiftDiscard(t *testing.T) { + tests := []struct { + name string + numCtx int32 + numKeep int32 + inputLen int32 + expected int32 + }{ + { + name: "Shift", + numCtx: 2048, + numKeep: 5, + inputLen: 2048, + expected: 1021, + }, + { + name: "Max Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 2048, + expected: 1, + }, + { + name: "No Keep", + numCtx: 2048, + numKeep: 0, + inputLen: 2048, + expected: 1024, + }, + { + name: "Truncate", + numCtx: 2048, + numKeep: 5, + inputLen: 5000, + expected: 3973, + }, + { + name: "Truncate Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 5000, + expected: 2953, + }, + { + name: "No Op", + numCtx: 2048, + numKeep: 5, + inputLen: 512, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := InputCache{numCtx: tt.numCtx} + result := c.ShiftDiscard(tt.inputLen, tt.numKeep) + if result != tt.expected { + t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected) + } + }) + } +} diff --git 
a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go new file mode 100644 index 000000000..d5a3b3407 --- /dev/null +++ b/runner/ollamarunner/runner.go @@ -0,0 +1,945 @@ +package ollamarunner + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "image" + "log" + "log/slog" + "net" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "golang.org/x/sync/semaphore" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/runner/common" + "github.com/ollama/ollama/sample" + + _ "github.com/ollama/ollama/model/models" +) + +// input is an element of the prompt to process, either a token or an image +type input struct { + token int32 + + image image.Image +} + +type Sequence struct { + // batch index + iBatch int + + // prompt inputs left to evaluate + inputs []input + + // inputs that have been added to a batch but not yet submitted to Forward + pendingInputs []input + + // tokens that have been generated but not returned yet (e.g. for stop sequences) + pendingResponses []string + + // input cache being used by this sequence + cache *InputCacheSlot + + // channel to send responses over + responses chan string + + // channel to stop decoding (such as if the remote connection is closed) + quit chan bool + + // number of tokens to predict + numPredict int + + // set of samplers to run on generated logits + samplers []sample.Sampler + + // channel to send back the embedding if embedding only + embedding chan []float32 + + // stop sequences + stop []string + + // number of inputs to keep at the beginning when shifting context window + numKeep int32 + + // true if an embedding is to be returned instead of text generation + embeddingOnly bool + + doneReason string + + // Metrics + startProcessingTime time.Time + startGenerationTime time.Time + numPredicted int + numPromptInputs int +} + +type NewSequenceParams struct { + numPredict int + stop []string + numKeep int32 + samplers []sample.Sampler + embedding bool +} + +func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { + s.ready.Wait() + + startTime := time.Now() + + inputs, err := s.inputs(prompt, images) + if err != nil { + return nil, fmt.Errorf("failed to process inputs: %w", err) + } else if len(inputs) == 0 { + return nil, errors.New("no input provided") + } + + if params.numKeep < 0 { + params.numKeep = int32(len(inputs)) + } + + // Ensure that at least 1 input can be discarded during shift + params.numKeep = min(params.numKeep, s.cache.numCtx-1) + + if int32(len(inputs)) > s.cache.numCtx { + discard := int32(len(inputs)) - s.cache.numCtx + newInputs := inputs[:params.numKeep] + newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
+ + slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) + inputs = newInputs + } + + // TODO(jessegross): Ingest cached history for grammar + + return &Sequence{ + inputs: inputs, + numPromptInputs: len(inputs), + startProcessingTime: startTime, + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplers: params.samplers, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, + }, nil +} + +// inputs processes the prompt and images into a list of inputs +// by splitting the prompt on [img-] tags, tokenizing text and +// decoding images +func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { + var inputs []input + var parts []string + var matches [][]string + + // TODO(jessegross): This can sometimes trigger for matching text in the + // user's prompt. We previously tried to avoid it by only looking for images + // on image models. We don't have a clear indication now but it would be better + // to properly escape it in any case. + re := regexp.MustCompile(`\[img-(\d+)\]`) + parts = re.Split(prompt, -1) + matches = re.FindAllStringSubmatch(prompt, -1) + + for i, part := range parts { + // text - tokenize + tokens, err := s.model.(model.TextProcessor).Encode(part) + if err != nil { + return nil, err + } + + for _, t := range tokens { + inputs = append(inputs, input{token: t}) + } + + // image - decode and store + if i < len(matches) { + n, _ := strconv.Atoi(matches[i][1]) + + imageIndex := -1 + for j := range images { + if images[j].ID == n { + imageIndex = j + break + } + } + + if imageIndex < 0 { + return nil, fmt.Errorf("invalid image index: %d", n) + } + + image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) + if err != nil { + return nil, err + } + + inputs = append(inputs, input{image: image}) + } + } + + return inputs, nil +} + +type Server struct { + // is the server ready to process requests? + // protects access to model and image + ready sync.WaitGroup + + // loaded model + model model.Model + + // status for external health reporting - loading, ready to serve, etc. + status ServerStatus + + // current progress on loading the model + progress float32 + + // number of simultaneous requests to handle + parallel int + + // maximum number of elements in a batch (per sequence) + // TODO (jmorganca): make this n_batch + batchSize int + + // protects access to everything below this line + // this is context state needed for decoding + mu sync.Mutex + + // indicates that data is ready for processing + cond *sync.Cond + + // the list of simultaneous sequences being evaluated + seqs []*Sequence + + // seqs can have a maximum of parallel entries, which + // is enforced by seqSem + seqsSem *semaphore.Weighted + + // KV cache + cache *InputCache + + // next sequence for prompt processing to avoid starvation + nextSeq int +} + +func (s *Server) allNil() bool { + for _, item := range s.seqs { + if item != nil { + return false + } + } + return true +} + +func flushPending(seq *Sequence) bool { + joined := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = []string{} + + // Check if there are any partial UTF-8 characters remaining. + // We already check and queue as we are generating but some may + // still make it here: + // - Sequence is ending, e.g.
generation limit has been hit + // - Invalid characters in the middle of a string + // This is a stricter check to ensure we never output invalid Unicode. + for !utf8.ValidString(joined) { + joined = joined[:len(joined)-1] + } + + if len(joined) == 0 { + return true + } + + select { + case seq.responses <- joined: + return true + case <-seq.quit: + return false + } +} + +func (s *Server) removeSequence(seqIndex int, reason string) { + seq := s.seqs[seqIndex] + + flushPending(seq) + seq.doneReason = reason + close(seq.responses) + close(seq.embedding) + seq.cache.InUse = false + s.seqs[seqIndex] = nil + s.seqsSem.Release(1) +} + +func (s *Server) run(ctx context.Context) { + s.ready.Wait() + + for { + select { + case <-ctx.Done(): + return + default: + err := s.processBatch() + if err != nil { + panic(err) + } + } + } +} + +func (s *Server) processBatch() error { + s.mu.Lock() + for s.allNil() { + s.cond.Wait() // Wait until an item is added + } + defer s.mu.Unlock() + + var options model.Options + imgSeq := -1 + + seqIdx := s.nextSeq - 1 + for range s.seqs { + seqIdx = (seqIdx + 1) % len(s.seqs) + seq := s.seqs[seqIdx] + + if seq == nil { + continue + } + + // if past the num predict limit + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.removeSequence(seqIdx, "limit") + continue + } + + if !s.cache.enabled { + seq.inputs = append(seq.cache.Inputs, seq.inputs...) + seq.cache.Inputs = []input{} + } + + for i, input := range seq.inputs { + if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { + if len(seq.pendingInputs) == 0 { + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } + } else { + break + } + } + + if i >= s.batchSize { + break + } + + // TODO(jessegross): Image inputs need to be rethought - + // they don't work well for different types of models or multiple sequences + if input.image != nil { + if len(seq.pendingInputs) != len(options.Images) { + break + } + + if imgSeq != seqIdx && imgSeq != -1 { + s.nextSeq = seqIdx + break + } + + imgSeq = seqIdx + options.Images = append(options.Images, input.image) + seq.pendingInputs = append(seq.pendingInputs, input) + continue + } + + options.Inputs = append(options.Inputs, input.token) + options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) + options.Sequences = append(options.Sequences, seq.cache.Id) + + seq.iBatch = len(options.Outputs) + if i+1 == len(seq.inputs) { + options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) + } + seq.pendingInputs = append(seq.pendingInputs, input) + } + + seq.inputs = seq.inputs[len(seq.pendingInputs):] + } + + if len(options.Inputs) == 0 { + return nil + } + + ctx := s.model.Backend().NewContext() + defer ctx.Close() + + modelOutput, err := model.Forward(ctx, s.model, options) + if err != nil { + return fmt.Errorf("failed to decode batch: %w", err) + } + + f32s := modelOutput.Floats() + + // TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s + logits := make([]float64, len(f32s)) + for i, f32 := range f32s { + logits[i] = float64(f32) + } + + for i, seq := range s.seqs { + if seq == nil { + continue + } + + // After calling Forward, pending inputs are now in the cache + if len(seq.pendingInputs) > 0 { + seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
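When the next input would overflow a slot's context window, processBatch calls ShiftCacheSlot, which discards roughly half of the history that is not protected by numKeep. A standalone restatement of the ShiftDiscard arithmetic from cache.go above, using the same numbers as the "Shift" test case to make the halving concrete (this is a sketch for illustration, not code added by the patch):

```go
package main

import "fmt"

// shiftDiscard mirrors InputCache.ShiftDiscard: aim to leave about half of the
// non-kept context free, then discard whatever is missing beyond the space
// that is already free.
func shiftDiscard(numCtx, inputLen, numKeep int32) int32 {
	targetFree := (numCtx - numKeep) / 2
	targetFree = max(targetFree, 1)

	currentFree := numCtx - inputLen
	discard := targetFree - currentFree
	if discard < 0 {
		discard = 0
	}
	return discard
}

func main() {
	// A full 2048-token context keeping the first 5 inputs discards 1021
	// entries, freeing roughly half of the shiftable history.
	fmt.Println(shiftDiscard(2048, 2048, 5)) // 1021
}
```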
+ seq.pendingInputs = []input{} + } + + // don't sample prompt processing + if len(seq.inputs) != 0 { + if !s.cache.enabled { + return errors.New("caching disabled but unable to fit entire input in a batch") + } + continue + } + + seq.numPredicted++ + if seq.numPredicted == 1 { + seq.startGenerationTime = time.Now() + } + + // if done processing the prompt, generate an embedding and return + if seq.embeddingOnly { + // TODO(jessegross): Embedding support + s.removeSequence(i, "") + continue + } + + // sample a token + vocabSize := len(f32s) / len(options.Outputs) + tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...) + if err != nil { + return err + } + + // TODO(jessegross): Sampler will output a single int32 in the future + token := int32(tokens[0]) + + // if it's an end of sequence token, break + if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { + // TODO (jmorganca): we should send this back + // as it's important for the /api/generate context + // seq.responses <- piece + + s.removeSequence(i, "stop") + continue + } + + piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) + if err != nil { + return err + } + + seq.inputs = []input{{token: token}} + + seq.pendingResponses = append(seq.pendingResponses, piece) + sequence := strings.Join(seq.pendingResponses, "") + + if ok, stop := common.FindStop(sequence, seq.stop); ok { + slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) + + var tokenTruncated bool + origLen := len(seq.pendingResponses) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) + newLen := len(seq.pendingResponses) + + // Update the cache based on the tokens that will be returned: + // - We have 1 token more than is currently in the cache because + // the last one generated wasn't submitted to Decode + // - Remove any stop sequences that we stripped out + // - If truncateStop removed a portion of a token, drop that + // - As defense-in-depth, if truncatedToken didn't find a stop token + // remove the extra one that we added to the cache len + tokenLen := len(seq.cache.Inputs) + 1 + tokenLen -= origLen - newLen + if tokenTruncated || origLen == newLen { + tokenLen-- + } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] + + s.removeSequence(i, "stop") + continue + } + + if common.ContainsStopSuffix(sequence, seq.stop) { + continue + } + + if common.IncompleteUnicode(sequence) { + continue + } + + if !flushPending(seq) { + s.removeSequence(i, "connection") + } + } + + return nil +} + +// TODO (jmorganca): use structs from the api package to avoid duplication +// this way the api acts as a proxy instead of using a different api for the +// runner +type Options struct { + api.Runner + + NumKeep int `json:"n_keep"` + Seed int `json:"seed"` + NumPredict int `json:"n_predict"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TypicalP float32 `json:"typical_p"` + RepeatLastN int `json:"repeat_last_n"` + Temperature float32 `json:"temperature"` + RepeatPenalty float32 `json:"repeat_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + FrequencyPenalty float32 `json:"frequency_penalty"` + Mirostat int `json:"mirostat"` + MirostatTau float32 `json:"mirostat_tau"` + MirostatEta float32 `json:"mirostat_eta"` + Stop []string `json:"stop"` +} + +type ImageData struct { + Data []byte `json:"data"` + ID int `json:"id"` + AspectRatioID int `json:"aspect_ratio_id"` +} + +type CompletionRequest struct { + Prompt 
string `json:"prompt"` + Images []ImageData `json:"image_data"` + Grammar string `json:"grammar"` + CachePrompt bool `json:"cache_prompt"` + + Options +} + +type Timings struct { + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` +} + +type CompletionResponse struct { + Content string `json:"content"` + Stop bool `json:"stop"` + + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` + PredictedN int `json:"predicted_n,omitempty"` + PredictedMS float64 `json:"predicted_ms,omitempty"` + PromptN int `json:"prompt_n,omitempty"` + PromptMS float64 `json:"prompt_ms,omitempty"` + + Timings Timings `json:"timings"` +} + +func getSamplers(_ CompletionRequest) []sample.Sampler { + // TODO(jessegross): Waiting for sampling code + + /*samplingParams.TopK = req.TopK + samplingParams.TopP = req.TopP + samplingParams.MinP = req.MinP + samplingParams.TypicalP = req.TypicalP + samplingParams.Temp = req.Temperature + samplingParams.RepeatLastN = req.RepeatLastN + samplingParams.PenaltyRepeat = req.RepeatPenalty + samplingParams.PenaltyFreq = req.FrequencyPenalty + samplingParams.PenaltyPresent = req.PresencePenalty + samplingParams.Mirostat = req.Mirostat + samplingParams.MirostatTau = req.MirostatTau + samplingParams.MirostatEta = req.MirostatEta + samplingParams.Seed = uint32(req.Seed) + samplingParams.Grammar = req.Grammar*/ + + return []sample.Sampler{sample.Greedy()} +} + +func (s *Server) completion(w http.ResponseWriter, r *http.Request) { + var req CompletionRequest + req.Options = Options(api.DefaultOptions()) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Set the headers to indicate streaming + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Transfer-Encoding", "chunked") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ + numPredict: req.NumPredict, + stop: req.Stop, + numKeep: int32(req.NumKeep), + samplers: getSamplers(req), + embedding: false, + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting completion request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + for { + select { + case <-r.Context().Done(): + close(seq.quit) + return + case content, ok := <-seq.responses: + if ok { + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Content: content, + }); err != nil { + 
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + + flusher.Flush() + } else { + // Send the final response + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Stop: true, + StoppedLimit: seq.doneReason == "limit", + Timings: Timings{ + PromptN: seq.numPromptInputs, + PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), + PredictedN: seq.numPredicted, + PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), + }, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) + } + + return + } + } + } +} + +type EmbeddingRequest struct { + Content string `json:"content"` + CachePrompt bool `json:"cache_prompt"` +} + +type EmbeddingResponse struct { + Embedding []float32 `json:"embedding"` +} + +func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { + var req EmbeddingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + slog.Debug("embedding request", "content", req.Content) + + seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true}) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting embeddings request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + embedding := <-seq.embedding + + if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + Embedding: embedding, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type HealthResponse struct { + Status string `json:"status"` + Progress float32 `json:"progress"` +} + +type ServerStatus int + +const ( + ServerStatusReady ServerStatus = iota + ServerStatusLoadingModel + ServerStatusError +) + +func (s ServerStatus) ToString() string { + switch s { + case ServerStatusReady: + return "ok" + case ServerStatusLoadingModel: + return "loading model" + default: + return "server error" + } +} + +func (s *Server) health(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(&HealthResponse{ + Status: s.status.ToString(), + Progress: s.progress, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type multiLPath []string + +func (m *multiLPath) Set(value string) error { + *m = append(*m, value) + return nil +} + +func (m *multiLPath) String() string { + return 
strings.Join(*m, ", ") +} + +func (s *Server) loadModel( + mpath string, + lpath multiLPath, + parallel int, + kvCacheType string, + kvSize int, + multiUserCache bool, +) { + var err error + s.model, err = model.New(mpath) + if err != nil { + panic(err) + } + + // TODO(jessegross): LoRA loading + if lpath.String() != "" { + panic("loras are not yet implemented") + } + + s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache) + if err != nil { + panic(err) + } + + if !s.cache.enabled && parallel > 1 { + parallel = 1 + slog.Warn("model does not support caching, disabling parallel processing") + } + + s.parallel = parallel + s.seqs = make([]*Sequence, s.parallel) + s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + + s.status = ServerStatusReady + s.ready.Done() +} + +func Execute(args []string) error { + fs := flag.NewFlagSet("runner", flag.ExitOnError) + mpath := fs.String("model", "", "Path to model binary file") + parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") + batchSize := fs.Int("batch-size", 512, "Batch size") + _ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") + _ = fs.Int("main-gpu", 0, "Main GPU") + _ = fs.Bool("flash-attn", false, "Enable flash attention") + kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") + kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") + port := fs.Int("port", 8080, "Port to expose the server on") + _ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") + verbose := fs.Bool("verbose", false, "verbose output (default: disabled)") + _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") + _ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing") + _ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") + multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") + + var lpaths multiLPath + fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)") + + fs.Usage = func() { + fmt.Fprintf(fs.Output(), "Runner usage\n") + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + level := slog.LevelInfo + if *verbose { + level = slog.LevelDebug + } + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + AddSource: true, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.SourceKey { + source := attr.Value.Any().(*slog.Source) + source.File = filepath.Base(source.File) + } + return attr + }, + }) + slog.SetDefault(slog.New(handler)) + slog.Info("starting ollama engine") + // TODO(jessegross): Some system info would be useful + + server := &Server{ + batchSize: *batchSize, + status: ServerStatusLoadingModel, + } + + // TODO(jessegross): Parameters that need to be implemented: + // n-gpu-layers + // main-gpu + // flash-attn + // threads + // no-mmap + // mlock + // tensor-split + + /*var tensorSplitFloats []float32 + if *tensorSplit != "" { + stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1) + + tensorSplitFloats = make([]float32, 0, len(stringFloats)) + for _, s := range stringFloats { + f, _ := strconv.ParseFloat(s, 32) + tensorSplitFloats = append(tensorSplitFloats, float32(f)) + } + }*/ + + server.ready.Add(1) + go server.loadModel(*mpath, lpaths, 
*parallel, *kvCacheType, *kvSize, *multiUserCache) + + server.cond = sync.NewCond(&server.mu) + + ctx, cancel := context.WithCancel(context.Background()) + go server.run(ctx) + + addr := "127.0.0.1:" + strconv.Itoa(*port) + listener, err := net.Listen("tcp", addr) + if err != nil { + fmt.Println("Listen error:", err) + cancel() + return err + } + defer listener.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/embedding", server.embeddings) + mux.HandleFunc("/completion", server.completion) + mux.HandleFunc("/health", server.health) + + httpServer := http.Server{ + Handler: mux, + } + + log.Println("Server listening on", addr) + if err := httpServer.Serve(listener); err != nil { + log.Fatal("server error:", err) + return err + } + + cancel() + return nil +} diff --git a/runner/runner.go b/runner/runner.go new file mode 100644 index 000000000..500fdd72e --- /dev/null +++ b/runner/runner.go @@ -0,0 +1,24 @@ +package runner + +import ( + "github.com/ollama/ollama/runner/llamarunner" + "github.com/ollama/ollama/runner/ollamarunner" +) + +func Execute(args []string) error { + if args[0] == "runner" { + args = args[1:] + } + + var newRunner bool + if args[0] == "--ollama-engine" { + args = args[1:] + newRunner = true + } + + if newRunner { + return ollamarunner.Execute(args) + } else { + return llamarunner.Execute(args) + } +} diff --git a/server/prompt.go b/server/prompt.go index 610891729..233dffd69 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/template" @@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var imgData llm.ImageData if isMllama { - data, opts, err := mllama.Preprocess(bytes.NewReader(i)) - if err != nil { - return "", nil, err - } + if envconfig.NewEngine() { + imgData = llm.ImageData{ + ID: len(images), + Data: i, + } + } else { + data, opts, err := mllama.Preprocess(bytes.NewReader(i)) + if err != nil { + return "", nil, err + } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.LittleEndian, data) - if err != nil { - return "", nil, err - } + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, data) + if err != nil { + return "", nil, err + } - ar, ok := opts["aspectRatioIndex"].(int) - if !ok { - return "", nil, fmt.Errorf("missing aspect ratio for image") - } + ar, ok := opts["aspectRatioIndex"].(int) + if !ok { + return "", nil, fmt.Errorf("missing aspect ratio for image") + } - imgData = llm.ImageData{ - ID: len(images), - Data: buf.Bytes(), - AspectRatioID: ar, + imgData = llm.ImageData{ + ID: len(images), + Data: buf.Bytes(), + AspectRatioID: ar, + } } imgPrompt = "<|image|>" } else { diff --git a/server/routes.go b/server/routes.go index 779d3205d..95485cb81 100644 --- a/server/routes.go +++ b/server/routes.go @@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama { + if isMllama && !envconfig.NewEngine() { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
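For context on how the new runner is exercised: it serves /completion, /embedding, and /health, writing one JSON object per response chunk. A rough client sketch against a runner started with the default -port 8080 (the chunk struct below is an illustrative subset of CompletionResponse, assumed here for brevity rather than defined by this patch):

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// chunk holds the subset of the runner's CompletionResponse a client needs
// to stream text and detect the final message.
type chunk struct {
	Content string `json:"content"`
	Stop    bool   `json:"stop"`
}

func main() {
	body, _ := json.Marshal(map[string]any{
		"prompt":       "Why is the sky blue?",
		"cache_prompt": true,
	})

	// Assumes a runner listening on the default 127.0.0.1:8080.
	resp, err := http.Post("http://127.0.0.1:8080/completion", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	dec := json.NewDecoder(bufio.NewReader(resp.Body))
	for {
		var c chunk
		if err := dec.Decode(&c); err != nil {
			break
		}
		fmt.Print(c.Content)
		if c.Stop {
			fmt.Println()
			break
		}
	}
}
```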