From ed443a03930a10bec6182c55091f0880baa1e620 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Tue, 17 Dec 2024 19:59:41 -0800 Subject: [PATCH] Runner for Ollama engine This provides integration with the new Ollama engine (5824541 next ollama runner (#7913)) and the rest of the Ollama infrastructure such as the runner and Ollama server. In addition, it also builds out the KV cache infrastructure to support requirements of how Ollama runs models such as: - Parallel processing - Memory management for defragmentation and shifting - Multi-modal modals Both old and new engines continue to be supported. By default, only the old engine is used. To enable the new engine: Start the server with the OLLAMA_NEW_ENGINE environment variable set: OLLAMA_NEW_ENGINE=1 ./ollama serve Start a model that is supported by the Ollama engine. This one is Llama 3.1 8b Q4_K_M: ./ollama run jessegross/llama3.1 --- cache/cache.go | 63 -- cmd/cmd.go | 7 +- cmd/runner/main.go | 2 +- envconfig/config.go | 3 + kvcache/cache.go | 54 + kvcache/causal.go | 455 +++++++++ kvcache/causal_test.go | 506 ++++++++++ kvcache/encoder.go | 97 ++ kvcache/wrapper.go | 93 ++ llm/server.go | 3 + ml/backend.go | 26 +- ml/backend/ggml/ggml.go | 15 +- model/model.go | 128 +-- model/model_test.go | 8 +- model/models/llama/model.go | 30 +- model/models/mllama/model.go | 22 +- model/models/mllama/model_text.go | 62 +- {llama/runner => runner}/README.md | 0 {llama/runner => runner/common}/stop.go | 10 +- {llama/runner => runner/common}/stop_test.go | 6 +- {llama/runner => runner/llamarunner}/cache.go | 2 +- .../llamarunner}/cache_test.go | 2 +- {llama/runner => runner/llamarunner}/image.go | 2 +- .../llamarunner}/image_test.go | 2 +- .../runner => runner/llamarunner}/runner.go | 14 +- runner/ollamarunner/cache.go | 280 ++++++ runner/ollamarunner/cache_test.go | 291 ++++++ runner/ollamarunner/runner.go | 945 ++++++++++++++++++ runner/runner.go | 24 + server/prompt.go | 42 +- server/routes.go | 2 +- 31 files changed, 2952 insertions(+), 244 deletions(-) delete mode 100644 cache/cache.go create mode 100644 kvcache/cache.go create mode 100644 kvcache/causal.go create mode 100644 kvcache/causal_test.go create mode 100644 kvcache/encoder.go create mode 100644 kvcache/wrapper.go rename {llama/runner => runner}/README.md (100%) rename {llama/runner => runner/common}/stop.go (87%) rename {llama/runner => runner/common}/stop_test.go (96%) rename {llama/runner => runner/llamarunner}/cache.go (99%) rename {llama/runner => runner/llamarunner}/cache_test.go (99%) rename {llama/runner => runner/llamarunner}/image.go (99%) rename {llama/runner => runner/llamarunner}/image_test.go (99%) rename {llama/runner => runner/llamarunner}/runner.go (98%) create mode 100644 runner/ollamarunner/cache.go create mode 100644 runner/ollamarunner/cache_test.go create mode 100644 runner/ollamarunner/runner.go create mode 100644 runner/runner.go diff --git a/cache/cache.go b/cache/cache.go deleted file mode 100644 index 572b886ed..000000000 --- a/cache/cache.go +++ /dev/null @@ -1,63 +0,0 @@ -package cache - -import ( - "github.com/ollama/ollama/ml" -) - -type Options struct { - Position int -} - -type Cache interface { - Sub(i int) Cache - Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) -} - -type Simple struct { - DType ml.DType - Capacity int - - keys, values []ml.Tensor -} - -func (c *Simple) Sub(i int) Cache { - if i >= len(c.keys) { - c.keys = append(c.keys, make([]ml.Tensor, i-len(c.keys)+1)...) - c.values = append(c.values, make([]ml.Tensor, i-len(c.values)+1)...) - } - - return &Simple{ - keys: c.keys[i : i+1], - values: c.values[i : i+1], - Capacity: c.Capacity, - DType: c.DType, - } -} - -func (c *Simple) Put(ctx ml.Context, key, value ml.Tensor, opts Options) (ml.Tensor, ml.Tensor) { - if c.keys[0] == nil || c.values[0] == nil { - c.keys[0] = ctx.Zeros(c.DType, key.Dim(0)*key.Dim(1)*c.Capacity) - c.values[0] = ctx.Zeros(c.DType, value.Dim(0)*value.Dim(1)*c.Capacity) - } - - ctx.Forward(key.Copy(ctx, c.keys[0].View(ctx, key.Stride(2)*opts.Position, key.Dim(0)*key.Dim(1)*key.Dim(2)))) - ctx.Forward(value.Copy(ctx, c.values[0].View(ctx, value.Stride(2)*opts.Position, value.Dim(0)*value.Dim(1)*value.Dim(2)))) - - n := min(c.Capacity, key.Dim(2)+opts.Position) - - key = c.keys[0].View(ctx, 0, - key.Dim(0), key.Stride(1), - key.Dim(1), key.Stride(2), - n, - ) - - value = c.values[0].View(ctx, 0, - value.Dim(0), value.Stride(1), - value.Dim(1), value.Stride(2), - n, - ) - - // TODO shift context if necessary - - return key, value -} diff --git a/cmd/cmd.go b/cmd/cmd.go index 17c607171..80ece4c60 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -35,9 +35,9 @@ import ( "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/format" "github.com/ollama/ollama/llama" - "github.com/ollama/ollama/llama/runner" "github.com/ollama/ollama/parser" "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" "github.com/ollama/ollama/version" @@ -338,7 +338,10 @@ func RunHandler(cmd *cobra.Command, args []string) error { return err } - opts.MultiModal = len(info.ProjectorInfo) != 0 + // TODO(jessegross): We should either find another way to know if this is + // a vision model or remove the logic. Also consider that other modalities will + // need different behavior anyways. + opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine() opts.ParentModel = info.Details.ParentModel if interactive { diff --git a/cmd/runner/main.go b/cmd/runner/main.go index 34b0e9d21..fbfafc7ff 100644 --- a/cmd/runner/main.go +++ b/cmd/runner/main.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/ollama/ollama/llama/runner" + "github.com/ollama/ollama/runner" ) func main() { diff --git a/envconfig/config.go b/envconfig/config.go index 0ca3b64cd..fbd881ba7 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -165,6 +165,8 @@ var ( IntelGPU = Bool("OLLAMA_INTEL_GPU") // MultiUserCache optimizes prompt caching for multi-user scenarios MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE") + // Enable the new Ollama engine + NewEngine = Bool("OLLAMA_NEW_ENGINE") ) func String(s string) func() string { @@ -250,6 +252,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"}, "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, + "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/kvcache/cache.go b/kvcache/cache.go new file mode 100644 index 000000000..5d8b2f9b5 --- /dev/null +++ b/kvcache/cache.go @@ -0,0 +1,54 @@ +package kvcache + +import ( + "errors" + + "github.com/ollama/ollama/ml" +) + +var ( + ErrKvCacheFull = errors.New("could not find a kv cache slot") + ErrNotSupported = errors.New("model does not support operation") +) + +type Cache interface { + // ** used by model implementations ** + + // SetLayer sets the active layer of the cache + SetLayer(layer int) + + // Get returns the history of key and value tensors plus a mask + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) + + // Put stores a batch of key and value in the cache + // + // The shape of the tensors is documented in the specific + // cache implementation used. + Put(ctx ml.Context, key, value ml.Tensor) + + // ** cache management ** + + // Init sets up runtime parameters + Init(backend ml.Backend, dtype ml.DType, capacity int32) + + // Close closes the cache and frees resources associated with it + Close() + + // StartForward is called before the start of the model's forward pass. + // For each token in the coming batch, there must be a corresponding + // entry in positions and seqs. + StartForward(ctx ml.Context, positions []int32, seqs []int) error + + // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq + CopyPrefix(srcSeq, dstSeq int, len int32) + + // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set + // endIndex to math.MaxInt32 to remove everything starting at beginIndex. + // + // If an error occurs, the entire context for the sequence should be + // removed by calling Remove(seq, 0, math.MaxInt32) + Remove(seq int, beginIndex, endIndex int32) error +} diff --git a/kvcache/causal.go b/kvcache/causal.go new file mode 100644 index 000000000..5d46f8d4c --- /dev/null +++ b/kvcache/causal.go @@ -0,0 +1,455 @@ +package kvcache + +import ( + "errors" + "fmt" + "log/slog" + "math" + "slices" + + "github.com/ollama/ollama/ml" +) + +type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) + +// Causal cache stores K and V tensors according to their position in the +// sequence. Returns the history and a mask for attending to past tokens +// +// The tensors are of shape embed dim, kv heads, batch size +// The mask is of shape history size, batch size +type Causal struct { + DType ml.DType + Capacity int32 + windowSize int32 + + // ** current forward pass ** + + // the active layer for Get and Put + curLayer int + + // starting location for data storage for this batch + curLoc int + + // size of the current batch + curBatchSize int + + // mask of the cache as used by this batch + curMask ml.Tensor + + // locations in the cache that are needed for this batch + curCellRange cellRange + + // ** cache metadata ** + + // for each possible location in the cache, stores the position and set of sequences + // that reference the data there + cells []cacheCell + + // maps from sequence to the range of locations where it is stored in the cache + cellRanges map[int]cellRange + + // ** cache data storage ** + + shiftFn shiftFn + backend ml.Backend + cacheCtx ml.Context + keys, values []ml.Tensor +} + +type cacheCell struct { + pos int32 + sequences []int +} + +type cellRange struct { + min int + max int +} + +func NewCausalCache(shift shiftFn) *Causal { + return &Causal{windowSize: math.MaxInt32, shiftFn: shift} +} + +func NewSWACache(windowSize int32, shift shiftFn) *Causal { + return &Causal{windowSize: windowSize, shiftFn: shift} +} + +func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + c.DType = dtype + c.Capacity = capacity + c.cells = make([]cacheCell, capacity) + c.cellRanges = make(map[int]cellRange) + c.backend = backend + c.cacheCtx = backend.NewContext() +} + +func (c *Causal) Close() { + c.cacheCtx.Close() +} + +func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + c.curBatchSize = len(positions) + + var err error + c.curLoc, err = c.findStartLoc() + if errors.Is(err, ErrKvCacheFull) { + c.defrag() + c.curLoc, err = c.findStartLoc() + } + if err != nil { + return err + } + + c.curCellRange = newRange() + for i, pos := range positions { + seq := seqs[i] + + c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} + + seqRange, ok := c.cellRanges[seq] + if !ok { + seqRange = newRange() + } + + if c.curLoc+i > seqRange.max { + seqRange.max = c.curLoc + i + } + if seqRange.max > c.curCellRange.max { + c.curCellRange.max = seqRange.max + } + + if c.curLoc+i < seqRange.min { + seqRange.min = c.curLoc + i + } + if seqRange.min < c.curCellRange.min { + c.curCellRange.min = seqRange.min + } + c.cellRanges[seq] = seqRange + } + + c.curMask, err = c.buildMask(ctx, positions, seqs) + + return err +} + +func newRange() cellRange { + return cellRange{ + min: math.MaxInt, + max: 0, + } +} + +// Find the first contiguous block of at least curBatchSize +func (c *Causal) findStartLoc() (int, error) { + var start, count int + for i := range c.cells { + if len(c.cells[i].sequences) == 0 { + count++ + if count >= c.curBatchSize { + return start, nil + } + } else { + start = i + 1 + count = 0 + } + } + + return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity) +} + +// Builds a mask of history x batch indicating whether for each token in the batch the +// token in the history should apply. This is based on both the sequence and causality (the +// position of the history is not ahead of the token in the batch). +func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) { + // TODO(jessegross): This does not do padding, which is required for flash attention + len := c.curCellRange.max - c.curCellRange.min + 1 + mask := make([]float32, c.curBatchSize*len) + + for i := range c.curBatchSize { + for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { + if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] || + c.cells[j].pos < positions[i]-c.windowSize { + mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1)) + } + } + } + + return ctx.FromFloatSlice(mask, len, c.curBatchSize) +} + +func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) { + for _, obj := range objs { + if obj == nil { + continue + } + + srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len) + dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len) + + ctx.Forward(srcView.Copy(ctx, dstView)) + } +} + +func (c *Causal) defrag() { + slog.Debug("defragmenting kv cache") + + // Defrag strategy: + // - Search for empty holes at the beginning of the cache, + // filling them with active data starting at the end + // - If there are contiguous elements that need to be moved, + // combine them into a single operation by holding new moves + // until we see that the next one is non-contiguous + // - Fill up the context with the maximum number of operations it + // can hold then compute that and continue with a new context + // + // We could try to optimize placement by grouping blocks from + // the same sequences together but most likely the next forward + // pass will disrupt this anyways, so the real world benefit + // seems limited as this time. + + ctx := c.backend.NewContext() + + // For every move, 6 tensors are required per layer (2 views and a + // copy for each of k and v). + layers := 0 + for _, key := range c.keys { + if key == nil { + continue + } + layers++ + } + + maxMoves := ctx.MaxTensors() / (6 * layers) + moves := 0 + + var pendingSrc, pendingDst, pendingLen int + src := len(c.cells) - 1 + + for dst := 0; dst < src; dst++ { + if len(c.cells[dst].sequences) == 0 { + for ; src > dst; src-- { + if len(c.cells[src].sequences) != 0 { + c.cells[dst] = c.cells[src] + c.cells[src] = cacheCell{} + + if pendingLen > 0 { + if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen { + pendingSrc = src + pendingLen++ + break + } else { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + } + + pendingSrc = src + pendingDst = dst + pendingLen = 1 + + break + } + } + } + + if moves >= maxMoves { + ctx.Compute() + ctx.Close() + ctx = c.backend.NewContext() + + moves = 0 + } + } + + if pendingLen > 0 { + moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen) + moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen) + moves++ + } + + if moves > 0 { + ctx.Compute() + } + ctx.Close() + + // Reset range metadata + for seq := range c.cellRanges { + seqRange := newRange() + + for i, cell := range c.cells { + if slices.Contains(cell.sequences, seq) { + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + c.cellRanges[seq] = seqRange + } +} + +func (c *Causal) SetLayer(layer int) { + if layer >= len(c.keys) { + c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) + c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) + } + + c.curLayer = layer +} + +func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + key := c.keys[c.curLayer] + value := c.values[c.curLayer] + + key = key.View(ctx, key.Stride(2)*c.curCellRange.min, + key.Dim(0), key.Stride(1), + key.Dim(1), key.Stride(2), + c.curMask.Dim(0), + ) + + value = value.View(ctx, key.Stride(2)*c.curCellRange.min, + value.Dim(0), value.Stride(1), + value.Dim(1), value.Stride(2), + c.curMask.Dim(0), + ) + + return key, value, c.curMask +} + +func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { + if c.curBatchSize != key.Dim(2) { + panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2))) + } + + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity)) + c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity)) + } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2)))) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2)))) +} + +func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) { + seqRange := newRange() + + for i := range c.cells { + // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end + if slices.Contains(c.cells[i].sequences, dstSeq) { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq }) + } + + if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len { + c.cells[i].sequences = append(c.cells[i].sequences, dstSeq) + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + + c.cellRanges[dstSeq] = seqRange +} + +func (c *Causal) shift(seq int, beginIndex, offset int32) error { + if c.shiftFn == nil { + return ErrNotSupported + } + + ctx := c.backend.NewContext() + defer ctx.Close() + + seqRange := c.cellRanges[seq] + size := seqRange.max - seqRange.min + 1 + + offsets := make([]int32, size) + for i := range offsets { + cell := c.cells[seqRange.min+i] + + if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { + offsets[i] = offset + } + } + + kShift, err := ctx.FromIntSlice(offsets, len(offsets)) + if err != nil { + return err + } + + for i, key := range c.keys { + if key == nil { + continue + } + + key = key.View(ctx, key.Stride(2)*seqRange.min, + key.Dim(0), key.Stride(1), + key.Dim(1), key.Stride(2), + size, + ) + + roped, err := c.shiftFn(ctx, i, key, kShift) + if err != nil { + return err + } + + ctx.Forward(roped.Copy(ctx, key)) + } + + ctx.Compute() + + return nil +} + +func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error { + var offset int32 + if endIndex != math.MaxInt32 { + offset = beginIndex - endIndex + } + + seqRange := newRange() + + for i := range c.cells { + if slices.Contains(c.cells[i].sequences, seq) { + if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex { + c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq }) + } else { + if c.cells[i].pos >= endIndex { + if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) { + // TODO(jessegross): Need to be careful about data shared between sequences + return errors.New("shifting on cells shared by multiple sequences not yet implemented") + } + + c.cells[i].pos += offset + } + if i < seqRange.min { + seqRange.min = i + } + if i > seqRange.max { + seqRange.max = i + } + } + } + } + + if seqRange == newRange() { + delete(c.cellRanges, seq) + return nil + } + + c.cellRanges[seq] = seqRange + + if endIndex != math.MaxInt32 { + err := c.shift(seq, endIndex+offset, offset) + if err != nil { + return err + } + } + + return nil +} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go new file mode 100644 index 000000000..0b614df06 --- /dev/null +++ b/kvcache/causal_test.go @@ -0,0 +1,506 @@ +package kvcache + +import ( + "math" + "slices" + "testing" + + "github.com/ollama/ollama/ml" +) + +type testCase struct { + name string + in []float32 + inShape []int + seqs []int + pos []int32 + expected []float32 + expectedShape []int + expectedMask []float32 +} + +func TestStore(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + inShape: []int{2, 3, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234}, + expectedShape: []int{2, 3, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{115, 215, 125, 225, 135, 235}, + inShape: []int{2, 3, 1}, + seqs: []int{0}, + pos: []int32{4}, + expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235}, + expectedShape: []int{2, 3, 5}, + expectedMask: []float32{0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestSWA(t *testing.T) { + backend := &testBackend{} + cache := NewSWACache(1, nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF32, 16) + + tests := []testCase{ + { + name: "SlidingWindow", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestSequences(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(nil) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + { + name: "SecondBatch", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{2, 2}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestRemove(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.Add(ctx, shift), nil + }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 1, 1}, + pos: []int32{0, 1, 0, 1}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + err := cache.Remove(0, 1, math.MaxInt32) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveEnd", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 1}, + pos: []int32{1, 2}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) + + err = cache.Remove(0, 0, 1) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "RemoveMiddle", + in: []float32{7, 8}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{1, 2}, + expected: []float32{7, 8, 3, 4, 4}, + expectedShape: []int{1, 1, 5}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestDefrag(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.Add(ctx, shift), nil + }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + inShape: []int{1, 1, 16}, + seqs: []int{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + expectedShape: []int{1, 1, 16}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + err := cache.Remove(0, 2, 4) + if err != nil { + panic(err) + } + + err = cache.Remove(0, 13, math.MaxInt32) + if err != nil { + panic(err) + } + + tests = []testCase{ + { + name: "Defrag", + in: []float32{17, 18, 19}, + inShape: []int{1, 1, 3}, + seqs: []int{0, 0, 0}, + pos: []int32{16, 17, 18}, + expected: []float32{1, 2, 12, 13, 3, 4, 5, 6, 7, 8, 9, 10, 11, 17, 18, 19}, + expectedShape: []int{1, 1, 16}, + expectedMask: []float32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func TestCopy(t *testing.T) { + backend := &testBackend{} + cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil }) + defer cache.Close() + + cache.Init(backend, ml.DTypeF16, 16) + + tests := []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0}, + }, + } + + testCache(t, backend, cache, tests) + + cache.CopyPrefix(0, 1, 2) + + tests = []testCase{ + { + name: "Copy", + in: []float32{5, 6}, + inShape: []int{1, 1, 2}, + seqs: []int{1, 1}, + pos: []int32{3, 4}, + expected: []float32{1, 2, 3, 4, 5, 6}, + expectedShape: []int{1, 1, 6}, + expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0}, + }, + } + + testCache(t, backend, cache, tests) +} + +func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + context := backend.NewContext() + defer context.Close() + + err := cache.StartForward(context, test.pos, test.seqs) + if err != nil { + panic(err) + } + + cache.SetLayer(0) + tensor, _ := context.FromFloatSlice(test.in, test.inShape...) + cache.Put(context, tensor, tensor) + + out, _, mask := cache.Get(context) + + context.Forward(out) + context.Forward(mask) + context.Compute(out, mask) + + if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { + t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) + } + }) + } +} + +type testBackend struct{} + +func (b *testBackend) Config() ml.Config { + panic("not implemented") +} + +func (b *testBackend) Get(name string) ml.Tensor { + panic("not implemented") +} + +func (b *testBackend) NewContext() ml.Context { + return &testContext{} +} + +type testContext struct{} + +func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + total := 0 + + if len(shape) > 0 { + total = 1 + for _, s := range shape { + total *= s + } + } + + return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape} +} + +func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { + t := c.Zeros(ml.DTypeF32, shape...).(*testTensor) + + copy(t.data, s) + + return t, nil +} + +func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { + f := make([]float32, len(s)) + for i := range f { + f[i] = float32(s[i]) + } + + out, _ := c.FromFloatSlice(f, shape...) + out.(*testTensor).dtype = ml.DTypeI32 + + return out, nil +} + +func (c *testContext) Forward(ml.Tensor) {} + +func (c *testContext) Compute(...ml.Tensor) {} + +func (c *testContext) MaxTensors() int { + return 10 +} + +func (c *testContext) Close() {} + +type testTensor struct { + dtype ml.DType + elementSize int + data []float32 + shape []int +} + +func (t *testTensor) Dim(n int) int { + return t.shape[n] +} + +func (t *testTensor) Stride(n int) int { + stride := t.elementSize + for i := range n { + stride *= t.shape[i] + } + + return stride +} + +func (t *testTensor) Shape() []int { + return t.shape +} + +func (t *testTensor) DType() ml.DType { + return t.dtype +} + +func (t *testTensor) Bytes() []byte { + panic("not implemented") +} + +func (t *testTensor) Floats() []float32 { + out := make([]float32, len(t.data)) + copy(out, t.data) + return out +} + +func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor) + + for i := range out.data { + out.data[i] = t.data[i] + t2.(*testTensor).data[i] + } + + return out +} + +func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim uint32, base, scale float32) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor { + offset /= t.elementSize + + var s []int + + switch len(shape) { + case 1: + s = []int{shape[0]} + case 5: + s = []int{shape[0], shape[2], shape[4]} + default: + panic("unsupported number of dimensions") + } + + context := &testContext{} + + view := context.Zeros(t.dtype, s...).(*testTensor) + view.data = t.data[offset : offset+len(view.data)] + + return view +} + +func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + panic("not implemented") +} + +func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor { + copy(t2.(*testTensor).data, t.data) + return nil +} diff --git a/kvcache/encoder.go b/kvcache/encoder.go new file mode 100644 index 000000000..8a44c194b --- /dev/null +++ b/kvcache/encoder.go @@ -0,0 +1,97 @@ +package kvcache + +import ( + "github.com/ollama/ollama/ml" +) + +// Encoder cache stores K and V tensors that are position independent +// +// The tensors can be of any shape and will be returned as they were stored +// The mask is currently always nil +// +// Not currently safe for multiple sequences +type EncoderCache struct { + // ** current forward pass ** + + // the active layer for Get and Put + curLayer int + + // if something is stored during this pass, this + // will be the position (but there is no guarantee + // anything will be stored) + curPos int32 + + // ** cache metadata ** + + // was something stored in the cache? + encoderCached bool + + // position of the cached data + encoderPos int32 + + // ** cache data storage ** + + cacheCtx ml.Context + keys, values []ml.Tensor +} + +func NewEncoderCache() *EncoderCache { + return &EncoderCache{} +} + +func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + c.cacheCtx = backend.NewContext() +} + +func (c *EncoderCache) Close() { + c.cacheCtx.Close() +} + +func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + // The image is always in the first position + c.curPos = positions[0] + + return nil +} + +func (c *EncoderCache) SetLayer(layer int) { + if layer >= len(c.keys) { + c.keys = append(c.keys, make([]ml.Tensor, layer-len(c.keys)+1)...) + c.values = append(c.values, make([]ml.Tensor, layer-len(c.values)+1)...) + } + + c.curLayer = layer +} + +func (c *EncoderCache) EncoderCached() bool { + return c.encoderCached +} + +func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.keys[c.curLayer], c.values[c.curLayer], nil +} + +func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.encoderPos = c.curPos + c.encoderCached = true + + if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil { + c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...) + c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...) + } + + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer])) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer])) +} + +func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) { + panic("encoder cache does not support multiple sequences") +} + +func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error { + if c.encoderPos >= beginIndex && c.encoderPos < endIndex { + c.encoderCached = false + } + + return nil +} diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go new file mode 100644 index 000000000..2d4c1089a --- /dev/null +++ b/kvcache/wrapper.go @@ -0,0 +1,93 @@ +package kvcache + +import ( + "math" + + "github.com/ollama/ollama/ml" +) + +// Wrapper cache is a container for multiple types of caches, +// such as for the encoding and decoding portions of a model. +type WrapperCache struct { + // caches we are wrapping + caches []Cache + + // cache to be used for this layer + curType int +} + +func NewWrapperCache(caches ...Cache) *WrapperCache { + return &WrapperCache{ + caches: caches, + } +} + +func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) { + for _, cache := range c.caches { + cache.Init(backend, dtype, capacity) + } +} + +func (c *WrapperCache) Close() { + for _, cache := range c.caches { + cache.Close() + } +} + +func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { + for i, cache := range c.caches { + err := cache.StartForward(ctx, positions, seqs) + if err != nil { + // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail + for j := i - 1; j >= 0; j-- { + for k := range positions { + _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32) + } + } + return err + } + } + + c.curType = 0 + return nil +} + +func (c *WrapperCache) SetLayer(layer int) { + for _, cache := range c.caches { + cache.SetLayer(layer) + } +} + +func (c *WrapperCache) SetLayerType(layerType int) { + c.curType = layerType +} + +func (c *WrapperCache) UnderlyingCache() Cache { + return c.caches[c.curType] +} + +func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { + return c.caches[c.curType].Get(ctx) +} + +func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) { + c.caches[c.curType].Put(ctx, key, value) +} + +func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) { + for _, cache := range c.caches { + cache.CopyPrefix(srcSeq, dstSeq, len) + } +} + +func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error { + // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail + for _, cache := range c.caches { + err := cache.Remove(seq, beginIndex, endIndex) + if err != nil { + return err + } + } + + return nil +} diff --git a/llm/server.go b/llm/server.go index 765d7bb9b..50ba91f18 100644 --- a/llm/server.go +++ b/llm/server.go @@ -275,6 +275,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } finalParams := []string{"runner"} + if envconfig.NewEngine() { + finalParams = append(finalParams, "--ollama-engine") + } finalParams = append(finalParams, params...) finalParams = append(finalParams, "--port", strconv.Itoa(port)) diff --git a/ml/backend.go b/ml/backend.go index acfb67637..0e99ab5a8 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -50,6 +50,7 @@ type Context interface { Forward(Tensor) Compute(...Tensor) + MaxTensors() int Close() } @@ -118,7 +119,7 @@ type DumpOptions struct { Precision int } -func Dump(t Tensor, opts ...DumpOptions) string { +func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { if len(opts) < 1 { opts = append(opts, DumpOptions{ Items: 3, @@ -128,11 +129,17 @@ func Dump(t Tensor, opts ...DumpOptions) string { switch t.DType() { case DTypeF32: - return dump[[]float32](t, opts[0].Items, func(f float32) string { + return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { + return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) + }) + case DTypeF16: + f32 := ctx.Zeros(DTypeF32, t.Shape()...) + f32 = t.Copy(ctx, f32) + return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) }) case DTypeI32: - return dump[[]int32](t, opts[0].Items, func(i int32) string { + return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string { return strconv.FormatInt(int64(i), 10) }) default: @@ -140,10 +147,10 @@ func Dump(t Tensor, opts ...DumpOptions) string { } } -func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string { - bts := t.Bytes() - if bts == nil { - return "" +func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string { + if t.Bytes() == nil { + ctx.Forward(t) + ctx.Compute(t) } s := make(S, mul(t.Shape()...)) @@ -191,7 +198,8 @@ func dump[S ~[]E, E number](t Tensor, items int, fn func(E) string) string { type DType int const ( - DTypeF32 DType = iota + DTypeOther DType = iota + DTypeF32 + DTypeF16 DTypeI32 - DTypeOther ) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 609570672..6a727a60c 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -258,6 +258,10 @@ func (c *Context) Compute(tensors ...ml.Tensor) { } } +func (c *Context) MaxTensors() int { + return c.nodes +} + func shapeToGGML(shape []int) *C.int64_t { sh := make([]C.int64_t, len(shape)) for i, s := range shape { @@ -282,6 +286,8 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { switch dtype { case ml.DTypeF32: t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape)) + case ml.DTypeF16: + t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape)) case ml.DTypeI32: t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape)) default: @@ -389,6 +395,8 @@ func (t *Tensor) DType() ml.DType { switch t.t._type { case C.GGML_TYPE_F32: return ml.DTypeF32 + case C.GGML_TYPE_F16: + return ml.DTypeF16 case C.GGML_TYPE_I32: return ml.DTypeI32 default: @@ -580,9 +588,14 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi ropeFactors = &Tensor{} } + dequant := t.t + if C.ggml_is_quantized(t.t._type) { + dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32) + } + return &Tensor{ t: C.ggml_rope_ext( - ctx.(*Context).ctx, t.t, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, + ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t, C.int(ropeDim), 131072, // YaRN n_ctx_train ropeTypeNorm, // ROPE_TYPE_NORM diff --git a/model/model.go b/model/model.go index 4a86d7d60..8a8c9b297 100644 --- a/model/model.go +++ b/model/model.go @@ -1,6 +1,7 @@ package model import ( + "errors" "fmt" "image" _ "image/jpeg" @@ -15,102 +16,42 @@ import ( _ "golang.org/x/image/tiff" _ "golang.org/x/image/webp" - "github.com/ollama/ollama/cache" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" ) -type Cache struct { - cache.Cache - cache.Options -} - -func (c Cache) Sub(i int) Cache { - if c.Cache != nil { - return Cache{ - Cache: c.Cache.Sub(i), - Options: c.Options, - } - } - - return c -} - -func (c Cache) Put(ctx ml.Context, key, value ml.Tensor, opts cache.Options) (ml.Tensor, ml.Tensor) { - if c.Cache != nil { - return c.Cache.Put(ctx, key, value, opts) - } - - return key, value -} - type Options struct { - inputs []int32 - - Offset int + Inputs []int32 + Positions []int32 + Sequences []int + Outputs []int32 Images []image.Image - - Cache } -func (opts Options) Inputs() []int32 { - return opts.inputs[opts.Offset:] -} - -func (opts Options) Positions() []int32 { - positions := make([]int32, len(opts.inputs)-opts.Offset) - for i := range positions { - positions[i] = int32(opts.Offset + i) - } - - return positions -} - -type OptionsFunc func(Model, *Options) - -func WithInputIDs(ids []int32) OptionsFunc { - return func(m Model, opts *Options) { - opts.inputs = ids - } -} - -func WithOffset(offset int) OptionsFunc { - return func(m Model, opts *Options) { - opts.Offset = offset - opts.Cache.Position = offset - } -} - -func WithImage(img image.Image) OptionsFunc { - return func(m Model, opts *Options) { - opts.Images = append(opts.Images, img) - } -} - -func WithCache(c cache.Cache) OptionsFunc { - return func(m Model, opts *Options) { - opts.Cache = Cache{ - Cache: c, - Options: cache.Options{ - Position: opts.Offset, - }, - } - } +type config struct { + Cache kvcache.Cache } type Base struct { b ml.Backend + config } func (m *Base) Backend() ml.Backend { return m.b } +func (m *Base) Config() config { + return m.config +} + type Model interface { Forward(ml.Context, Options) (ml.Tensor, error) Backend() ml.Backend + Config() config } var models = make(map[string]func(ml.Config) (Model, error)) @@ -146,12 +87,14 @@ func New(s string) (Model, error) { return nil, err } + base := Base{b: b, config: m.Config()} + v := reflect.ValueOf(m) - v.Elem().Set(populateFields(b, v.Elem())) + v.Elem().Set(populateFields(base, v.Elem())) return m, nil } -func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { +func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value { t := v.Type() if t.Kind() == reflect.Struct { @@ -170,7 +113,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { } if tt == reflect.TypeOf((*Base)(nil)).Elem() { - vv.Set(reflect.ValueOf(Base{b: b})) + vv.Set(reflect.ValueOf(base)) } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() { var fn func([]Tag) [][]string fn = func(tags []Tag) (values [][]string) { @@ -196,21 +139,21 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { names := fn(tagsCopy) for _, name := range names { - if tensor := b.Get(strings.Join(name, ".")); tensor != nil { + if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil { slog.Debug("found tensor", "", tensor) vv.Set(reflect.ValueOf(tensor)) break } } } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface { - setPointer(b, vv, tagsCopy) + setPointer(base, vv, tagsCopy) } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array { for i := range vv.Len() { vvv := vv.Index(i) if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface { - setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) + setPointer(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})) } else { - vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) + vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...)) } } } @@ -228,7 +171,7 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value { return v } -func setPointer(b ml.Backend, v reflect.Value, tags []Tag) { +func setPointer(base Base, v reflect.Value, tags []Tag) { vv := v if v.Kind() == reflect.Interface { if v.IsNil() { @@ -243,7 +186,7 @@ func setPointer(b ml.Backend, v reflect.Value, tags []Tag) { vv = reflect.New(v.Type().Elem()).Elem() } - if f := populateFields(b, vv, tags...); f.CanAddr() { + if f := populateFields(base, vv, tags...); f.CanAddr() { v.Set(f.Addr()) } } @@ -277,18 +220,27 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(m Model, optsFuncs ...OptionsFunc) (ml.Tensor, error) { - var opts Options - for _, optsFunc := range optsFuncs { - optsFunc(m, &opts) +func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { + if len(opts.Positions) != len(opts.Sequences) { + return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) + } + + if len(opts.Positions) < 1 { + return nil, errors.New("batch size cannot be less than 1") + } + + cache := m.Config().Cache + if cache != nil { + err := cache.StartForward(ctx, opts.Positions, opts.Sequences) + if err != nil { + return nil, err + } } - ctx := m.Backend().NewContext() t, err := m.Forward(ctx, opts) if err != nil { return nil, err } - defer ctx.Close() ctx.Forward(t) ctx.Compute(t) diff --git a/model/model_test.go b/model/model_test.go index 2ba12acde..02b8aa3c2 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -78,7 +78,7 @@ func TestPopulateFields(t *testing.T) { var m fakeModel v := reflect.ValueOf(&m) - v.Elem().Set(populateFields(&fakeBackend{ + v.Elem().Set(populateFields(Base{b: &fakeBackend{ names: []string{ "input.weight", "blk.0.attn_q.weight", @@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) { "output_norm.weight", "output.weight", }, - }, v.Elem())) + }}, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}}, @@ -121,11 +121,11 @@ func TestPopulateFieldsAlternateName(t *testing.T) { m := fakeModel{} v := reflect.ValueOf(&m) - v.Elem().Set(populateFields(&fakeBackend{ + v.Elem().Set(populateFields(Base{b: &fakeBackend{ names: []string{ "input.weight", }, - }, v.Elem())) + }}, v.Elem())) if diff := cmp.Diff(fakeModel{ Input: &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}}, diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 6efcc9bb7..b2c5c2c7b 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -3,6 +3,7 @@ package llama import ( "math" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" @@ -28,7 +29,7 @@ type Model struct { } func New(c ml.Config) (model.Model, error) { - return &Model{ + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ @@ -49,7 +50,11 @@ func New(c ml.Config) (model.Model, error) { ropeScale: c.Float("rope.freq_scale", 1), ropeDim: c.Uint("rope.dimension_count"), }, - }, nil + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + + return &m, nil } type SelfAttention struct { @@ -59,7 +64,7 @@ type SelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads @@ -74,7 +79,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - k, v = cache.Put(ctx, k, v, cache.Options) + cache.Put(ctx, k, v) + k, v, mask := cache.Get(ctx) q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -82,6 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten kq := k.MulmatFullPrec(ctx, q) kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) + kq = kq.Add(ctx, mask) kq = kq.Softmax(ctx) kqv := v.Mulmat(ctx, kq) @@ -91,6 +98,10 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten return sa.Output.Forward(ctx, kqv) } +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, m.Options.RopeFactors, m.Options.ropeDim, m.Options.ropeBase, m.Options.ropeScale), nil +} + type MLP struct { Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` @@ -109,7 +120,7 @@ type Layer struct { MLP *MLP } -func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache model.Cache, opts *Options) ml.Tensor { +func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { residual := hiddenState hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -123,12 +134,12 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cach } func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { - inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) + inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) + positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } @@ -136,13 +147,14 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) for i, layer := range m.Layers { - hiddenState = layer.Forward(ctx, hiddenState, positions, opts.Cache.Sub(i), m.Options) + m.Cache.SetLayer(i) + hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options) } hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.Output.Forward(ctx, hiddenState) - outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) + outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index e5b275b0b..a1460d940 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -1,6 +1,7 @@ package mllama import ( + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" @@ -18,8 +19,13 @@ type Model struct { ImageProcessor } +const ( + crossAttentionLayer = iota + selfAttentionLayer +) + func New(c ml.Config) (model.Model, error) { - return &Model{ + m := Model{ BytePairEncoding: model.NewBytePairEncoding( c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ @@ -33,7 +39,11 @@ func New(c ml.Config) (model.Model, error) { ImageProcessor: newImageProcessor(c), VisionModel: newVisionModel(c), TextModel: newTextModel(c), - }, nil + } + + m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift)) + + return &m, nil } func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { @@ -73,20 +83,20 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { crossAttentionStates = m.Projector.Forward(ctx, crossAttentionStates) } - inputs, err := ctx.FromIntSlice(opts.Inputs(), len(opts.Inputs())) + inputs, err := ctx.FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err } - positions, err := ctx.FromIntSlice(opts.Positions(), len(opts.Positions())) + positions, err := ctx.FromIntSlice(opts.Positions, len(opts.Positions)) if err != nil { return nil, err } // TODO: attention mask, cross attention mask - hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, opts.Cache) + hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)) - outputs, err := ctx.FromIntSlice([]int32{int32(len(opts.Positions())) - 1}, 1) + outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index 2b05a60ea..1e48086a3 100644 --- a/model/models/mllama/model_text.go +++ b/model/models/mllama/model_text.go @@ -4,9 +4,9 @@ import ( "math" "slices" + "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" - "github.com/ollama/ollama/model" ) type TextSelfAttention struct { @@ -16,7 +16,7 @@ type TextSelfAttention struct { Output *nn.Linear `gguf:"attn_output"` } -func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads @@ -31,7 +31,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas value := sa.Value.Forward(ctx, hiddenState) value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) - key, value = cache.Put(ctx, key, value, cache.Options) + cache.Put(ctx, key, value) + key, value, mask := cache.Get(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -39,11 +40,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas scores := key.MulmatFullPrec(ctx, query) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim))) - - if mask != nil { - scores = scores.Add(ctx, mask) - } - + scores = scores.Add(ctx, mask) scores = scores.Softmax(ctx) attention := value.Mulmat(ctx, scores) @@ -53,6 +50,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas return sa.Output.Forward(ctx, attention) } +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + // This will only get called for layers in the cache, which are just the self attention layers + return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil +} + type TextMLP struct { Up *nn.Linear `gguf:"ffn_up"` Down *nn.Linear `gguf:"ffn_down"` @@ -72,7 +74,7 @@ type TextSelfAttentionDecoderLayer struct { MLP *TextMLP } -func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -94,23 +96,29 @@ type TextCrossAttention struct { Output *nn.Linear `gguf:"cross_attn_o_proj"` } -func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { batchSize := hiddenState.Dim(1) headDim := opts.hiddenSize / opts.numHeads - numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) query := ca.Query.Forward(ctx, hiddenState) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = ca.QueryNorm.Forward(ctx, query, opts.eps) - key := ca.Key.Forward(ctx, crossAttentionStates) - key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) - key = ca.KeyNorm.Forward(ctx, key, opts.eps) + var key, value ml.Tensor + if crossAttentionStates != nil { + numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) - value := ca.Value.Forward(ctx, crossAttentionStates) - value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + key = ca.Key.Forward(ctx, crossAttentionStates) + key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + key = ca.KeyNorm.Forward(ctx, key, opts.eps) - // TODO cache key, value + value = ca.Value.Forward(ctx, crossAttentionStates) + value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles) + + cache.Put(ctx, key, value) + } else { + key, value, _ = cache.Get(ctx) + } query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) @@ -137,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct { MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` } -func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { residual := hiddenState hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) @@ -153,17 +161,25 @@ func (d TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, } type TextDecoderLayer interface { - Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor + Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor } type TextDecoder struct { Layers []TextDecoderLayer } -func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache, opts *TextModelOptions) ml.Tensor { +func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { for i, layer := range d.Layers { - if !slices.Contains(opts.crossAttentionLayers, uint32(i)) || crossAttentionStates != nil { - hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache.Sub(i), opts) + layerType := selfAttentionLayer + if slices.Contains(opts.crossAttentionLayers, uint32(i)) { + layerType = crossAttentionLayer + } + + cache.SetLayer(i) + cache.SetLayerType(layerType) + + if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() { + hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts) } } @@ -189,7 +205,7 @@ type TextModel struct { *TextModelOptions } -func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache model.Cache) ml.Tensor { +func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) diff --git a/llama/runner/README.md b/runner/README.md similarity index 100% rename from llama/runner/README.md rename to runner/README.md diff --git a/llama/runner/stop.go b/runner/common/stop.go similarity index 87% rename from llama/runner/stop.go rename to runner/common/stop.go index 8dcb08d33..3f27a286e 100644 --- a/llama/runner/stop.go +++ b/runner/common/stop.go @@ -1,10 +1,10 @@ -package runner +package common import ( "strings" ) -func findStop(sequence string, stops []string) (bool, string) { +func FindStop(sequence string, stops []string) (bool, string) { for _, stop := range stops { if strings.Contains(sequence, stop) { return true, stop @@ -14,7 +14,7 @@ func findStop(sequence string, stops []string) (bool, string) { return false, "" } -func containsStopSuffix(sequence string, stops []string) bool { +func ContainsStopSuffix(sequence string, stops []string) bool { for _, stop := range stops { for i := 1; i <= len(stop); i++ { if strings.HasSuffix(sequence, stop[:i]) { @@ -29,7 +29,7 @@ func containsStopSuffix(sequence string, stops []string) bool { // truncateStop removes the provided stop string from pieces, // returning the partial pieces with stop removed, including truncating // the last piece if required (and signalling if this was the case) -func truncateStop(pieces []string, stop string) ([]string, bool) { +func TruncateStop(pieces []string, stop string) ([]string, bool) { joined := strings.Join(pieces, "") index := strings.Index(joined, stop) @@ -65,7 +65,7 @@ func truncateStop(pieces []string, stop string) ([]string, bool) { return result, tokenTruncated } -func incompleteUnicode(token string) bool { +func IncompleteUnicode(token string) bool { incomplete := false // check if there is incomplete UTF-8 character at the end diff --git a/llama/runner/stop_test.go b/runner/common/stop_test.go similarity index 96% rename from llama/runner/stop_test.go rename to runner/common/stop_test.go index 31dc161f3..8df267eb4 100644 --- a/llama/runner/stop_test.go +++ b/runner/common/stop_test.go @@ -1,4 +1,4 @@ -package runner +package common import ( "reflect" @@ -52,7 +52,7 @@ func TestTruncateStop(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, resultTrunc := truncateStop(tt.pieces, tt.stop) + result, resultTrunc := TruncateStop(tt.pieces, tt.stop) if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc { t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc) } @@ -120,7 +120,7 @@ func TestIncompleteUnicode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := incompleteUnicode(tt.input) + result := IncompleteUnicode(tt.input) if result != tt.expected { t.Errorf("incompleteUnicode(%s): have %v; want %v", tt.input, result, tt.expected) } diff --git a/llama/runner/cache.go b/runner/llamarunner/cache.go similarity index 99% rename from llama/runner/cache.go rename to runner/llamarunner/cache.go index e8a2d2994..d29e94b6b 100644 --- a/llama/runner/cache.go +++ b/runner/llamarunner/cache.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "errors" diff --git a/llama/runner/cache_test.go b/runner/llamarunner/cache_test.go similarity index 99% rename from llama/runner/cache_test.go rename to runner/llamarunner/cache_test.go index 9c838ed33..c0656dd97 100644 --- a/llama/runner/cache_test.go +++ b/runner/llamarunner/cache_test.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "testing" diff --git a/llama/runner/image.go b/runner/llamarunner/image.go similarity index 99% rename from llama/runner/image.go rename to runner/llamarunner/image.go index c1932443c..e7e30a4d8 100644 --- a/llama/runner/image.go +++ b/runner/llamarunner/image.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "errors" diff --git a/llama/runner/image_test.go b/runner/llamarunner/image_test.go similarity index 99% rename from llama/runner/image_test.go rename to runner/llamarunner/image_test.go index d5c3bc1e2..2e1efaec8 100644 --- a/llama/runner/image_test.go +++ b/runner/llamarunner/image_test.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "reflect" diff --git a/llama/runner/runner.go b/runner/llamarunner/runner.go similarity index 98% rename from llama/runner/runner.go rename to runner/llamarunner/runner.go index 60ae88dac..93d6bfabe 100644 --- a/llama/runner/runner.go +++ b/runner/llamarunner/runner.go @@ -1,4 +1,4 @@ -package runner +package llamarunner import ( "context" @@ -24,6 +24,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/runner/common" ) // input is an element of the prompt to process, either @@ -498,12 +499,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "") - if ok, stop := findStop(sequence, seq.stop); ok { + if ok, stop := common.FindStop(sequence, seq.stop); ok { slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) var tokenTruncated bool origLen := len(seq.pendingResponses) - seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) newLen := len(seq.pendingResponses) // Update the cache based on the tokens that will be returned: @@ -524,11 +525,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch) continue } - if containsStopSuffix(sequence, seq.stop) { + if common.ContainsStopSuffix(sequence, seq.stop) { continue } - if incompleteUnicode(sequence) { + if common.IncompleteUnicode(sequence) { continue } @@ -885,9 +886,6 @@ func (s *Server) loadModel( } func Execute(args []string) error { - if args[0] == "runner" { - args = args[1:] - } fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") ppath := fs.String("mmproj", "", "Path to projector binary file") diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go new file mode 100644 index 000000000..e1fa98b1a --- /dev/null +++ b/runner/ollamarunner/cache.go @@ -0,0 +1,280 @@ +package ollamarunner + +import ( + "errors" + "fmt" + "log/slog" + "math" + "reflect" + "time" + + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" +) + +type InputCache struct { + // context window size (per slot) + numCtx int32 + + // does the cache store data or do we need to always send the full input? + // note that when enabled is false the underlying cache may either be nil + // or a non-nil dummy that doesn't actually store anything + enabled bool + + // individual KV caches + slots []InputCacheSlot + + // optimize cache eviction for multiple users + multiUserCache bool + + cache kvcache.Cache +} + +func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) { + if kvSize/int32(numSlots) < 1 { + return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots) + } + + slots := make([]InputCacheSlot, numSlots) + + for i := range slots { + slots[i] = InputCacheSlot{ + Id: i, + Inputs: make([]input, 0), + } + } + + cache := model.Config().Cache + if cache != nil { + cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize) + } + + return &InputCache{ + numCtx: kvSize / int32(numSlots), + enabled: cache != nil, + slots: slots, + multiUserCache: multiUserCache, + cache: cache, + }, nil +} + +func kvCacheTypeFromStr(s string) ml.DType { + switch s { + case "q8_0": + panic("kv cache quantization not yet implemented") + case "q4_0": + panic("kv cache quantization not yet implemented") + default: + return ml.DTypeF16 + } +} + +func (c *InputCache) Close() { + c.cache.Close() +} + +// Locking: Operations on InputCacheSlot (including finding one +// through LoadCacheSlot) require a lock to be be held that serializes +// these operations with each other and processBatch + +type InputCacheSlot struct { + // Index in the KV cache + Id int + + // Inputs that are stored in the KV cache + Inputs []input + + // is this cache actively being processed as part of a sequence? + InUse bool + + // last time this cache was used (as of start of processing) + lastUsed time.Time +} + +func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) { + var slot *InputCacheSlot + var numPast int32 + var err error + + // In single-user scenarios, the longest cache slot works fine for getting good input + // cache hit rates and it keeps the footprint of the cache small, which improves throughput. + // For multiple users, the "best" cache slot produces better input cache hit rates + // at the cost of worse performance when we miss the input cache. + if !c.multiUserCache { + slot, numPast, err = c.findLongestCacheSlot(prompt) + } else { + slot, numPast, err = c.findBestCacheSlot(prompt) + } + if err != nil { + return nil, nil, err + } + + if !cachePrompt { + numPast = 0 + } + + slot.InUse = true + slot.lastUsed = time.Now() + + if numPast == int32(len(prompt)) { + // Leave one input to sample so we can get a response + numPast-- + } + + if c.cache != nil { + err = c.cache.Remove(slot.Id, numPast, math.MaxInt32) + if err != nil { + // Some models don't support partial erasure + err = c.cache.Remove(slot.Id, 0, math.MaxInt32) + if err != nil { + return nil, nil, err + } + numPast = 0 + } + } + + slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt), + "used", numPast, "remaining", int32(len(prompt))-numPast) + + prompt = prompt[numPast:] + slot.Inputs = slot.Inputs[:numPast] + + return slot, prompt, nil +} + +func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { + longest := int32(-1) + var longestSlot *InputCacheSlot + + for i, s := range c.slots { + if s.InUse { + continue + } + + count := countCommonPrefix(s.Inputs, prompt) + if count > longest { + longest = count + longestSlot = &c.slots[i] + } + } + + if longestSlot == nil { + return nil, 0, errors.New("no available cache slots") + } + + return longestSlot, longest, nil +} + +func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) { + oldest := time.Now() + var oldestSlot *InputCacheSlot + + longest := int32(-1) + var longestSlot *InputCacheSlot + + for i, s := range c.slots { + count := countCommonPrefix(s.Inputs, prompt) + if count > longest { + longest = count + longestSlot = &c.slots[i] + } + + if s.lastUsed.Compare(oldest) < 0 && !s.InUse { + oldest = s.lastUsed + oldestSlot = &c.slots[i] + } + } + + if longest == int32(len(longestSlot.Inputs)) && !longestSlot.InUse { + return longestSlot, longest, nil + } + + if oldestSlot.InUse { + return nil, 0, errors.New("no available cache slots") + } + + if len(oldestSlot.Inputs) != 0 { + slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs), + "used", oldestSlot.lastUsed) + } + + if longest > 0 && longestSlot != oldestSlot { + slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", + len(longestSlot.Inputs)) + oldestSlot.Inputs = make([]input, longest) + copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) + if c.cache != nil { + c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) + } + } + + return oldestSlot, longest, nil +} + +func countCommonPrefix(a []input, b []input) int32 { + var count int32 + + for i := range a { + if i >= len(b) { + break + } + + if !reflect.DeepEqual(a[i], b[i]) { + break + } + + count++ + } + + return count +} + +func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { + targetFree := (c.numCtx - numKeep) / 2 + targetFree = max(targetFree, 1) + + currentFree := c.numCtx - inputLen + discard := targetFree - currentFree + + if discard < 0 { + discard = 0 + } + + return discard +} + +// Frees up space in the KV cache by deleting the oldest half of history and shifting +// the newest half into that space (saving numKeep inputs at the beginning). +// +// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx) +func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { + if numKeep >= c.numCtx { + return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx) + } + + inputLen := int32(len(slot.Inputs)) + discard := c.ShiftDiscard(inputLen, numKeep) + + if discard <= 0 { + return nil + } + + slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs), + "keep", numKeep, "discard", discard) + + // TODO (jessegross): KV cache removal can fail for certain types of models + if c.cache != nil { + err := c.cache.Remove(slot.Id, numKeep, numKeep+discard) + if err != nil { + return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err) + } + } + + for i := numKeep + discard; i < inputLen; i++ { + slot.Inputs[i-discard] = slot.Inputs[i] + } + slot.Inputs = slot.Inputs[:inputLen-discard] + + return nil +} diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go new file mode 100644 index 000000000..99e67b4fa --- /dev/null +++ b/runner/ollamarunner/cache_test.go @@ -0,0 +1,291 @@ +package ollamarunner + +import ( + "image" + "testing" + "time" +) + +func TestCountCommon(t *testing.T) { + imgA := image.NewRGBA(image.Rect(0, 0, 100, 100)) + imgB := image.NewRGBA(image.Rect(0, 0, 50, 50)) + imgC := image.NewRGBA(image.Rect(50, 50, 100, 100)) + + tests := []struct { + name string + t1 []input + t2 []input + expected int32 + }{ + { + name: "Equal", + t1: []input{{token: 1}, {token: 2}, {token: 3}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 3, + }, + { + name: "Prefix", + t1: []input{{token: 1}}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 1, + }, + { + name: "Image Prefix", + t1: []input{{image: imgA}}, + t2: []input{{image: imgA}, {image: imgB}, {image: imgC}}, + expected: 1, + }, + { + name: "Mixed", + t1: []input{{token: 1}, {image: imgA}}, + t2: []input{{token: 1}, {image: imgA}, {token: 5}}, + expected: 2, + }, + { + name: "Empty", + t1: []input{}, + t2: []input{{token: 1}, {token: 2}, {token: 3}}, + expected: 0, + }, + { + name: "Both Empty", + t1: []input{}, + t2: []input{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := countCommonPrefix(tt.t1, tt.t2) + if result != tt.expected { + t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected) + } + }) + } +} + +func TestFindCacheSlot(t *testing.T) { + type expected struct { + result int + len int32 + } + + tests := []struct { + name string + cache InputCache + prompt []input + longest expected + best expected + }{ + { + name: "Empty", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 0, len: 0}, + }, + { + name: "Extend", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 2}, + best: expected{result: 1, len: 2}, + }, + { + name: "New", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }}, + prompt: []input{{token: 2}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "Fork", + cache: InputCache{ + slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{}, + InUse: false, + lastUsed: time.Time{}, + }, + }, + }, + prompt: []input{{token: 1}}, + longest: expected{result: 0, len: 1}, + best: expected{result: 1, len: 1}, + }, + { + name: "Evict", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 2}, {token: 3}}, + longest: expected{result: 0, len: 0}, + best: expected{result: 1, len: 0}, + }, + { + name: "In use", + cache: InputCache{slots: []InputCacheSlot{ + { + Id: 0, + Inputs: []input{{token: 1}, {token: 2}}, + InUse: true, + lastUsed: time.Now().Add(-time.Second), + }, + { + Id: 1, + Inputs: []input{{token: 1}}, + InUse: false, + lastUsed: time.Now().Add(-2 * time.Second), + }, + }}, + prompt: []input{{token: 1}, {token: 2}}, + longest: expected{result: 1, len: 1}, + best: expected{result: 1, len: 2}, + }, + } + + for _, tt := range tests { + t.Run("Longest-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findLongestCacheSlot: err %v", err) + } else if result.Id != tt.longest.result || resultLen != tt.longest.len { + t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.longest.result, resultLen, tt.longest.len) + } + }) + } + + for _, tt := range tests { + t.Run("Best-"+tt.name, func(t *testing.T) { + result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt) + if err != nil { + t.Errorf("findBestCacheSlot: err %v", err) + } else if result.Id != tt.best.result || resultLen != tt.best.len { + t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v", + result.Id, tt.best.result, resultLen, tt.best.len) + } + }) + } +} + +func TestShiftDiscard(t *testing.T) { + tests := []struct { + name string + numCtx int32 + numKeep int32 + inputLen int32 + expected int32 + }{ + { + name: "Shift", + numCtx: 2048, + numKeep: 5, + inputLen: 2048, + expected: 1021, + }, + { + name: "Max Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 2048, + expected: 1, + }, + { + name: "No Keep", + numCtx: 2048, + numKeep: 0, + inputLen: 2048, + expected: 1024, + }, + { + name: "Truncate", + numCtx: 2048, + numKeep: 5, + inputLen: 5000, + expected: 3973, + }, + { + name: "Truncate Keep", + numCtx: 2048, + numKeep: 2047, + inputLen: 5000, + expected: 2953, + }, + { + name: "No Op", + numCtx: 2048, + numKeep: 5, + inputLen: 512, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := InputCache{numCtx: tt.numCtx} + result := c.ShiftDiscard(tt.inputLen, tt.numKeep) + if result != tt.expected { + t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected) + } + }) + } +} diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go new file mode 100644 index 000000000..d5a3b3407 --- /dev/null +++ b/runner/ollamarunner/runner.go @@ -0,0 +1,945 @@ +package ollamarunner + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "image" + "log" + "log/slog" + "net" + "net/http" + "os" + "path/filepath" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" + + "golang.org/x/sync/semaphore" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/runner/common" + "github.com/ollama/ollama/sample" + + _ "github.com/ollama/ollama/model/models" +) + +// input is an element of the prompt to process, either a token or an image +type input struct { + token int32 + + image image.Image +} + +type Sequence struct { + // batch index + iBatch int + + // prompt inputs left to evaluate + inputs []input + + // inputs that have been added to a batch but not yet submitted to Forward + pendingInputs []input + + // tokens that have been generated but not returned yet (e.g. for stop sequences) + pendingResponses []string + + // input cache being used by this sequence + cache *InputCacheSlot + + // channel to send responses over + responses chan string + + // channel to stop decoding (such as if the remote connection is closed) + quit chan bool + + // number of tokens to predict + numPredict int + + // set of samplers to run on generated logits + samplers []sample.Sampler + + // channel to send back the embedding if embedding only + embedding chan []float32 + + // stop sequences + stop []string + + // number of inputs to keep at the beginning when shifting context window + numKeep int32 + + // true if an embedding are to be returned instead of text generation + embeddingOnly bool + + doneReason string + + // Metrics + startProcessingTime time.Time + startGenerationTime time.Time + numPredicted int + numPromptInputs int +} + +type NewSequenceParams struct { + numPredict int + stop []string + numKeep int32 + samplers []sample.Sampler + embedding bool +} + +func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) { + s.ready.Wait() + + startTime := time.Now() + + inputs, err := s.inputs(prompt, images) + if err != nil { + return nil, fmt.Errorf("failed to process inputs: %w", err) + } else if len(inputs) == 0 { + return nil, errors.New("no input provided") + } + + if params.numKeep < 0 { + params.numKeep = int32(len(inputs)) + } + + // Ensure that at least 1 input can be discarded during shift + params.numKeep = min(params.numKeep, s.cache.numCtx-1) + + if int32(len(inputs)) > s.cache.numCtx { + discard := int32(len(inputs)) - s.cache.numCtx + newInputs := inputs[:params.numKeep] + newInputs = append(newInputs, inputs[params.numKeep+discard:]...) + + slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs)) + inputs = newInputs + } + + // TODO(jessegross): Ingest cached history for grammar + + return &Sequence{ + inputs: inputs, + numPromptInputs: len(inputs), + startProcessingTime: startTime, + numPredict: params.numPredict, + pendingResponses: make([]string, 0), + responses: make(chan string, 100), + quit: make(chan bool, 1), + embedding: make(chan []float32, 1), + samplers: params.samplers, + embeddingOnly: params.embedding, + stop: params.stop, + numKeep: params.numKeep, + }, nil +} + +// inputs processes the prompt and images into a list of inputs +// by splitting the prompt on [img-] tags, tokenizing text and +// decoding images +func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) { + var inputs []input + var parts []string + var matches [][]string + + // TODO(jessegross): This can sometimes trigger for matching text in the + // user's prompt. We previously tried to avoid it by only looking for images + // on image models. We don't have a clear indication now but it would be better + // to properly escape it in any case. + re := regexp.MustCompile(`\[img-(\d+)\]`) + parts = re.Split(prompt, -1) + matches = re.FindAllStringSubmatch(prompt, -1) + + for i, part := range parts { + // text - tokenize + tokens, err := s.model.(model.TextProcessor).Encode(part) + if err != nil { + return nil, err + } + + for _, t := range tokens { + inputs = append(inputs, input{token: t}) + } + + // image - decode and store + if i < len(matches) { + n, _ := strconv.Atoi(matches[i][1]) + + imageIndex := -1 + for j := range images { + if images[j].ID == n { + imageIndex = j + break + } + } + + if imageIndex < 0 { + return nil, fmt.Errorf("invalid image index: %d", n) + } + + image, _, err := image.Decode(bytes.NewReader(images[imageIndex].Data)) + if err != nil { + return nil, err + } + + inputs = append(inputs, input{image: image}) + } + } + + return inputs, nil +} + +type Server struct { + // is the server ready to process requests? + // protects access to model and image + ready sync.WaitGroup + + // loaded model + model model.Model + + // status for external health reporting - loading, ready to serve, etc. + status ServerStatus + + // current progress on loading the model + progress float32 + + // number of simultaneous requests to handle + parallel int + + // maximum number of elements in a batch (per sequence) + // TODO (jmorganca): make this n_batch + batchSize int + + // protects access to everything below this line + // this is context state needed for decoding + mu sync.Mutex + + // indicates that data is ready for processing + cond *sync.Cond + + // the list of simultaneous sequences being evaluated + seqs []*Sequence + + // seqs can have a maximum of parallel entries, which + // is enfoced by seqSem + seqsSem *semaphore.Weighted + + // KV cache + cache *InputCache + + // next sequence for prompt processing to avoid starvation + nextSeq int +} + +func (s *Server) allNil() bool { + for _, item := range s.seqs { + if item != nil { + return false + } + } + return true +} + +func flushPending(seq *Sequence) bool { + joined := strings.Join(seq.pendingResponses, "") + seq.pendingResponses = []string{} + + // Check if there are any partial UTF-8 characters remaining. + // We already check and queue as we are generating but some may + // still make it here: + // - Sequence is ending, e.g. generation limit has been hit + // - Invalid characters in the middle of a string + // This is a stricter check to ensure we never output invalid Unicode. + for !utf8.ValidString(joined) { + joined = joined[:len(joined)-1] + } + + if len(joined) == 0 { + return true + } + + select { + case seq.responses <- joined: + return true + case <-seq.quit: + return false + } +} + +func (s *Server) removeSequence(seqIndex int, reason string) { + seq := s.seqs[seqIndex] + + flushPending(seq) + seq.doneReason = reason + close(seq.responses) + close(seq.embedding) + seq.cache.InUse = false + s.seqs[seqIndex] = nil + s.seqsSem.Release(1) +} + +func (s *Server) run(ctx context.Context) { + s.ready.Wait() + + for { + select { + case <-ctx.Done(): + return + default: + err := s.processBatch() + if err != nil { + panic(err) + } + } + } +} + +func (s *Server) processBatch() error { + s.mu.Lock() + for s.allNil() { + s.cond.Wait() // Wait until an item is added + } + defer s.mu.Unlock() + + var options model.Options + imgSeq := -1 + + seqIdx := s.nextSeq - 1 + for range s.seqs { + seqIdx = (seqIdx + 1) % len(s.seqs) + seq := s.seqs[seqIdx] + + if seq == nil { + continue + } + + // if past the num predict limit + if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { + s.removeSequence(seqIdx, "limit") + continue + } + + if !s.cache.enabled { + seq.inputs = append(seq.cache.Inputs, seq.inputs...) + seq.cache.Inputs = []input{} + } + + for i, input := range seq.inputs { + if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { + if len(seq.pendingInputs) == 0 { + err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) + if err != nil { + return err + } + } else { + break + } + } + + if i >= s.batchSize { + break + } + + // TODO(jessegross): Image inputs need to be rethought - it's + // it doesn't work well for different types of models or multiple sequences + if input.image != nil { + if len(seq.pendingInputs) != len(options.Images) { + break + } + + if imgSeq != seqIdx && imgSeq != -1 { + s.nextSeq = seqIdx + break + } + + imgSeq = seqIdx + options.Images = append(options.Images, input.image) + seq.pendingInputs = append(seq.pendingInputs, input) + continue + } + + options.Inputs = append(options.Inputs, input.token) + options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) + options.Sequences = append(options.Sequences, seq.cache.Id) + + seq.iBatch = len(options.Outputs) + if i+1 == len(seq.inputs) { + options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) + } + seq.pendingInputs = append(seq.pendingInputs, input) + } + + seq.inputs = seq.inputs[len(seq.pendingInputs):] + } + + if len(options.Inputs) == 0 { + return nil + } + + ctx := s.model.Backend().NewContext() + defer ctx.Close() + + modelOutput, err := model.Forward(ctx, s.model, options) + if err != nil { + return fmt.Errorf("failed to decode batch: %w", err) + } + + f32s := modelOutput.Floats() + + // TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s + logits := make([]float64, len(f32s)) + for i, f32 := range f32s { + logits[i] = float64(f32) + } + + for i, seq := range s.seqs { + if seq == nil { + continue + } + + // After calling Forward, pending inputs are now in the cache + if len(seq.pendingInputs) > 0 { + seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) + seq.pendingInputs = []input{} + } + + // don't sample prompt processing + if len(seq.inputs) != 0 { + if !s.cache.enabled { + return errors.New("caching disabled but unable to fit entire input in a batch") + } + continue + } + + seq.numPredicted++ + if seq.numPredicted == 1 { + seq.startGenerationTime = time.Now() + } + + // if done processing the prompt, generate an embedding and return + if seq.embeddingOnly { + // TODO(jessegross): Embedding support + s.removeSequence(i, "") + continue + } + + // sample a token + vocabSize := len(f32s) / len(options.Outputs) + tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...) + if err != nil { + return err + } + + // TODO(jessegross): Sampler will output a single int32 in the future + token := int32(tokens[0]) + + // if it's an end of sequence token, break + if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { + // TODO (jmorganca): we should send this back + // as it's important for the /api/generate context + // seq.responses <- piece + + s.removeSequence(i, "stop") + continue + } + + piece, err := s.model.(model.TextProcessor).Decode([]int32{token}) + if err != nil { + return err + } + + seq.inputs = []input{{token: token}} + + seq.pendingResponses = append(seq.pendingResponses, piece) + sequence := strings.Join(seq.pendingResponses, "") + + if ok, stop := common.FindStop(sequence, seq.stop); ok { + slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop) + + var tokenTruncated bool + origLen := len(seq.pendingResponses) + seq.pendingResponses, tokenTruncated = common.TruncateStop(seq.pendingResponses, stop) + newLen := len(seq.pendingResponses) + + // Update the cache based on the tokens that will be returned: + // - We have 1 token more than is currently in the cache because + // the last one generated wasn't submitted to Decode + // - Remove any stop sequences that we stripped out + // - If truncateStop removed a portion of a token, drop that + // - As defense-in-depth, if truncatedToken didn't find a stop token + // remove the extra one that we added to the cache len + tokenLen := len(seq.cache.Inputs) + 1 + tokenLen -= origLen - newLen + if tokenTruncated || origLen == newLen { + tokenLen-- + } + seq.cache.Inputs = seq.cache.Inputs[:tokenLen] + + s.removeSequence(i, "stop") + continue + } + + if common.ContainsStopSuffix(sequence, seq.stop) { + continue + } + + if common.IncompleteUnicode(sequence) { + continue + } + + if !flushPending(seq) { + s.removeSequence(i, "connection") + } + } + + return nil +} + +// TODO (jmorganca): use structs from the api package to avoid duplication +// this way the api acts as a proxy instead of using a different api for the +// runner +type Options struct { + api.Runner + + NumKeep int `json:"n_keep"` + Seed int `json:"seed"` + NumPredict int `json:"n_predict"` + TopK int `json:"top_k"` + TopP float32 `json:"top_p"` + MinP float32 `json:"min_p"` + TypicalP float32 `json:"typical_p"` + RepeatLastN int `json:"repeat_last_n"` + Temperature float32 `json:"temperature"` + RepeatPenalty float32 `json:"repeat_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + FrequencyPenalty float32 `json:"frequency_penalty"` + Mirostat int `json:"mirostat"` + MirostatTau float32 `json:"mirostat_tau"` + MirostatEta float32 `json:"mirostat_eta"` + Stop []string `json:"stop"` +} + +type ImageData struct { + Data []byte `json:"data"` + ID int `json:"id"` + AspectRatioID int `json:"aspect_ratio_id"` +} + +type CompletionRequest struct { + Prompt string `json:"prompt"` + Images []ImageData `json:"image_data"` + Grammar string `json:"grammar"` + CachePrompt bool `json:"cache_prompt"` + + Options +} + +type Timings struct { + PredictedN int `json:"predicted_n"` + PredictedMS float64 `json:"predicted_ms"` + PromptN int `json:"prompt_n"` + PromptMS float64 `json:"prompt_ms"` +} + +type CompletionResponse struct { + Content string `json:"content"` + Stop bool `json:"stop"` + + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + StoppedLimit bool `json:"stopped_limit,omitempty"` + PredictedN int `json:"predicted_n,omitempty"` + PredictedMS float64 `json:"predicted_ms,omitempty"` + PromptN int `json:"prompt_n,omitempty"` + PromptMS float64 `json:"prompt_ms,omitempty"` + + Timings Timings `json:"timings"` +} + +func getSamplers(_ CompletionRequest) []sample.Sampler { + // TODO(jessegross): Waiting for sampling code + + /*samplingParams.TopK = req.TopK + samplingParams.TopP = req.TopP + samplingParams.MinP = req.MinP + samplingParams.TypicalP = req.TypicalP + samplingParams.Temp = req.Temperature + samplingParams.RepeatLastN = req.RepeatLastN + samplingParams.PenaltyRepeat = req.RepeatPenalty + samplingParams.PenaltyFreq = req.FrequencyPenalty + samplingParams.PenaltyPresent = req.PresencePenalty + samplingParams.Mirostat = req.Mirostat + samplingParams.MirostatTau = req.MirostatTau + samplingParams.MirostatEta = req.MirostatEta + samplingParams.Seed = uint32(req.Seed) + samplingParams.Grammar = req.Grammar*/ + + return []sample.Sampler{sample.Greedy()} +} + +func (s *Server) completion(w http.ResponseWriter, r *http.Request) { + var req CompletionRequest + req.Options = Options(api.DefaultOptions()) + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Set the headers to indicate streaming + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Transfer-Encoding", "chunked") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ + numPredict: req.NumPredict, + stop: req.Stop, + numKeep: int32(req.NumKeep), + samplers: getSamplers(req), + embedding: false, + }) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting completion request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + for { + select { + case <-r.Context().Done(): + close(seq.quit) + return + case content, ok := <-seq.responses: + if ok { + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Content: content, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + close(seq.quit) + return + } + + flusher.Flush() + } else { + // Send the final response + if err := json.NewEncoder(w).Encode(&CompletionResponse{ + Stop: true, + StoppedLimit: seq.doneReason == "limit", + Timings: Timings{ + PromptN: seq.numPromptInputs, + PromptMS: float64(seq.startGenerationTime.Sub(seq.startProcessingTime).Milliseconds()), + PredictedN: seq.numPredicted, + PredictedMS: float64(time.Since(seq.startGenerationTime).Milliseconds()), + }, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError) + } + + return + } + } + } +} + +type EmbeddingRequest struct { + Content string `json:"content"` + CachePrompt bool `json:"cache_prompt"` +} + +type EmbeddingResponse struct { + Embedding []float32 `json:"embedding"` +} + +func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) { + var req EmbeddingRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + + slog.Debug("embedding request", "content", req.Content) + + seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true}) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError) + return + } + + // Ensure there is a place to put the sequence, released when removed from s.seqs + if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { + if errors.Is(err, context.Canceled) { + slog.Info("aborting embeddings request due to client closing the connection") + } else { + slog.Error("Failed to acquire semaphore", "error", err) + } + return + } + + s.mu.Lock() + found := false + for i, sq := range s.seqs { + if sq == nil { + seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt) + if err != nil { + s.mu.Unlock() + http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError) + return + } + s.seqs[i] = seq + s.cond.Signal() + found = true + break + } + } + s.mu.Unlock() + + if !found { + http.Error(w, "could not find an available sequence", http.StatusInternalServerError) + return + } + + embedding := <-seq.embedding + + if err := json.NewEncoder(w).Encode(&EmbeddingResponse{ + Embedding: embedding, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type HealthResponse struct { + Status string `json:"status"` + Progress float32 `json:"progress"` +} + +type ServerStatus int + +const ( + ServerStatusReady ServerStatus = iota + ServerStatusLoadingModel + ServerStatusError +) + +func (s ServerStatus) ToString() string { + switch s { + case ServerStatusReady: + return "ok" + case ServerStatusLoadingModel: + return "loading model" + default: + return "server error" + } +} + +func (s *Server) health(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(&HealthResponse{ + Status: s.status.ToString(), + Progress: s.progress, + }); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } +} + +type multiLPath []string + +func (m *multiLPath) Set(value string) error { + *m = append(*m, value) + return nil +} + +func (m *multiLPath) String() string { + return strings.Join(*m, ", ") +} + +func (s *Server) loadModel( + mpath string, + lpath multiLPath, + parallel int, + kvCacheType string, + kvSize int, + multiUserCache bool, +) { + var err error + s.model, err = model.New(mpath) + if err != nil { + panic(err) + } + + // TODO(jessegross): LoRA loading + if lpath.String() != "" { + panic("loras are not yet implemented") + } + + s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache) + if err != nil { + panic(err) + } + + if !s.cache.enabled && parallel > 1 { + parallel = 1 + slog.Warn("model does not support caching, disabling parallel processing") + } + + s.parallel = parallel + s.seqs = make([]*Sequence, s.parallel) + s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + + s.status = ServerStatusReady + s.ready.Done() +} + +func Execute(args []string) error { + fs := flag.NewFlagSet("runner", flag.ExitOnError) + mpath := fs.String("model", "", "Path to model binary file") + parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") + batchSize := fs.Int("batch-size", 512, "Batch size") + _ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") + _ = fs.Int("main-gpu", 0, "Main GPU") + _ = fs.Bool("flash-attn", false, "Enable flash attention") + kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") + kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") + port := fs.Int("port", 8080, "Port to expose the server on") + _ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") + verbose := fs.Bool("verbose", false, "verbose output (default: disabled)") + _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") + _ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing") + _ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") + multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") + + var lpaths multiLPath + fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)") + + fs.Usage = func() { + fmt.Fprintf(fs.Output(), "Runner usage\n") + fs.PrintDefaults() + } + if err := fs.Parse(args); err != nil { + return err + } + level := slog.LevelInfo + if *verbose { + level = slog.LevelDebug + } + handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: level, + AddSource: true, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + if attr.Key == slog.SourceKey { + source := attr.Value.Any().(*slog.Source) + source.File = filepath.Base(source.File) + } + return attr + }, + }) + slog.SetDefault(slog.New(handler)) + slog.Info("starting ollama engine") + // TODO(jessegross): Some system info would be useful + + server := &Server{ + batchSize: *batchSize, + status: ServerStatusLoadingModel, + } + + // TODO(jessegross): Parameters that need to be implemented: + // n-gpu-layers + // main-gpu + // flash-attn + // threads + // no-mmap + // mlock + // tensor-split + + /*var tensorSplitFloats []float32 + if *tensorSplit != "" { + stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1) + + tensorSplitFloats = make([]float32, 0, len(stringFloats)) + for _, s := range stringFloats { + f, _ := strconv.ParseFloat(s, 32) + tensorSplitFloats = append(tensorSplitFloats, float32(f)) + } + }*/ + + server.ready.Add(1) + go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) + + server.cond = sync.NewCond(&server.mu) + + ctx, cancel := context.WithCancel(context.Background()) + go server.run(ctx) + + addr := "127.0.0.1:" + strconv.Itoa(*port) + listener, err := net.Listen("tcp", addr) + if err != nil { + fmt.Println("Listen error:", err) + cancel() + return err + } + defer listener.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/embedding", server.embeddings) + mux.HandleFunc("/completion", server.completion) + mux.HandleFunc("/health", server.health) + + httpServer := http.Server{ + Handler: mux, + } + + log.Println("Server listening on", addr) + if err := httpServer.Serve(listener); err != nil { + log.Fatal("server error:", err) + return err + } + + cancel() + return nil +} diff --git a/runner/runner.go b/runner/runner.go new file mode 100644 index 000000000..500fdd72e --- /dev/null +++ b/runner/runner.go @@ -0,0 +1,24 @@ +package runner + +import ( + "github.com/ollama/ollama/runner/llamarunner" + "github.com/ollama/ollama/runner/ollamarunner" +) + +func Execute(args []string) error { + if args[0] == "runner" { + args = args[1:] + } + + var newRunner bool + if args[0] == "--ollama-engine" { + args = args[1:] + newRunner = true + } + + if newRunner { + return ollamarunner.Execute(args) + } else { + return llamarunner.Execute(args) + } +} diff --git a/server/prompt.go b/server/prompt.go index 610891729..233dffd69 100644 --- a/server/prompt.go +++ b/server/prompt.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/llm" "github.com/ollama/ollama/model/models/mllama" "github.com/ollama/ollama/template" @@ -92,26 +93,33 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api. var imgData llm.ImageData if isMllama { - data, opts, err := mllama.Preprocess(bytes.NewReader(i)) - if err != nil { - return "", nil, err - } + if envconfig.NewEngine() { + imgData = llm.ImageData{ + ID: len(images), + Data: i, + } + } else { + data, opts, err := mllama.Preprocess(bytes.NewReader(i)) + if err != nil { + return "", nil, err + } - buf := new(bytes.Buffer) - err = binary.Write(buf, binary.LittleEndian, data) - if err != nil { - return "", nil, err - } + buf := new(bytes.Buffer) + err = binary.Write(buf, binary.LittleEndian, data) + if err != nil { + return "", nil, err + } - ar, ok := opts["aspectRatioIndex"].(int) - if !ok { - return "", nil, fmt.Errorf("missing aspect ratio for image") - } + ar, ok := opts["aspectRatioIndex"].(int) + if !ok { + return "", nil, fmt.Errorf("missing aspect ratio for image") + } - imgData = llm.ImageData{ - ID: len(images), - Data: buf.Bytes(), - AspectRatioID: ar, + imgData = llm.ImageData{ + ID: len(images), + Data: buf.Bytes(), + AspectRatioID: ar, + } } imgPrompt = "<|image|>" } else { diff --git a/server/routes.go b/server/routes.go index 779d3205d..95485cb81 100644 --- a/server/routes.go +++ b/server/routes.go @@ -203,7 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { images := make([]llm.ImageData, len(req.Images)) for i := range req.Images { - if isMllama { + if isMllama && !envconfig.NewEngine() { data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i])) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})