mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
runner: clear cache when shift is not possible (#9433)
Clear KV cache when shift operation is not supported by model. Added KvCacheCanShift() check to handle models that can't perform cache shifts, falling back to full cache clear while preserving logical token history to maintain expected behavior when context window fills up.
This commit is contained in:
parent
ef27d52e79
commit
66b2539238
6 changed files with 180 additions and 14 deletions
|
@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
|
|||
C.llama_kv_cache_defrag(c.c)
|
||||
}
|
||||
|
||||
func (c *Context) KvCacheCanShift() bool {
|
||||
return bool(C.llama_kv_cache_can_shift(c.c))
|
||||
}
|
||||
|
||||
// Get the embeddings for a sequence id
|
||||
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
|
||||
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
|
||||
|
|
|
@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
|||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
|
||||
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
|
@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
inputLen := len(slot.Inputs)
|
||||
discard := c.ShiftDiscard(inputLen, numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
|
@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
|||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
||||
}
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
||||
var shiftFailed bool
|
||||
|
||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
||||
if c.lc.KvCacheCanShift() {
|
||||
// For models that support shifting, attempt to shift the KV cache
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
} else {
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
|
||||
}
|
||||
} else {
|
||||
// For models that don't support shifting
|
||||
shiftFailed = true
|
||||
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
|
||||
}
|
||||
|
||||
if shiftFailed {
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Clear the entire KV cache
|
||||
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||
// Reset the slot inputs since we've cleared the cache
|
||||
slot.Inputs = []input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
|
||||
// Standard shift succeeded - update input array
|
||||
for i := numKeep + discard; i < inputLen; i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
||||
slot.Inputs = slot.Inputs[:inputLen-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
|||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Continue processing as normal
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break
|
||||
|
|
|
@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
|||
return discard
|
||||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
|
@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
|||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if c.cache != nil {
|
||||
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
|
||||
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
|
||||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, -1)
|
||||
slot.Inputs = []input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,10 +1,13 @@
|
|||
package ollamarunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
|
@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementation of the Cache interface
|
||||
type mockCache struct {
|
||||
shouldFail bool
|
||||
}
|
||||
|
||||
// Implement only the methods needed for the test
|
||||
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
|
||||
if m.shouldFail {
|
||||
return fmt.Errorf("mock cache removal error")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stub implementations for other interface methods
|
||||
func (m *mockCache) SetLayer(layer int) {}
|
||||
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
|
||||
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||
func (m *mockCache) Close() {}
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
|
||||
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||
|
||||
func TestShiftCacheSlot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
wantInputsLen int
|
||||
}{
|
||||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
wantInputsLen: 6, // After discarding 4 tokens
|
||||
},
|
||||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
wantInputsLen: 0, // Original inputs should be cleared
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &mockCache{shouldFail: tt.cacheErr}
|
||||
c := InputCache{
|
||||
numCtx: tt.numCtx,
|
||||
cache: mock,
|
||||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
err := c.ShiftCacheSlot(slot, tt.numKeep)
|
||||
|
||||
if tt.wantErr != nil {
|
||||
if err == nil {
|
||||
t.Errorf("Expected error but got nil")
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.As(err, &tt.wantErr) {
|
||||
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(slot.Inputs) != tt.wantInputsLen {
|
||||
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -407,7 +407,15 @@ func (s *Server) processBatch() error {
|
|||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
return err
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue