diff --git a/kvcache/cache.go b/kvcache/cache.go index 07015b9e0..405c79733 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -56,8 +56,9 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding - // entry in positions and seqs. - StartForward(ctx ml.Context, batch input.Batch) error + // entry in positions and seqs. reserve is to preallocate memory + // without actually storing data in the cache. + StartForward(ctx ml.Context, batch input.Batch, reserve bool) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index 4fc18d88f..466722845 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -146,51 +146,60 @@ func (c *Causal) Close() { } } -func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - c.updateSlidingWindow() + if !reserve { + c.updateSlidingWindow() + + var err error + c.curLoc, err = c.findStartLoc() + if errors.Is(err, ErrKvCacheFull) { + c.defrag() + c.curLoc, err = c.findStartLoc() + } + if err != nil { + return err + } + + c.curCellRange = newRange() + for i, pos := range batch.Positions { + seq := batch.Sequences[i] + + c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} + + seqRange, ok := c.cellRanges[seq] + if !ok { + seqRange = newRange() + } + + if c.curLoc+i > seqRange.max { + seqRange.max = c.curLoc + i + } + if seqRange.max > c.curCellRange.max { + c.curCellRange.max = seqRange.max + } + + if c.curLoc+i < seqRange.min { + seqRange.min = c.curLoc + i + } + if seqRange.min < c.curCellRange.min { + c.curCellRange.min = seqRange.min + } + c.cellRanges[seq] = seqRange + } 
+ } else { + // If we are reserving memory, don't update any of the cache metadata but set the size + // to the worst case. + c.curLoc = 0 + c.curCellRange.min = 0 + c.curCellRange.max = len(c.cells) - 1 + } var err error - c.curLoc, err = c.findStartLoc() - if errors.Is(err, ErrKvCacheFull) { - c.defrag() - c.curLoc, err = c.findStartLoc() - } - if err != nil { - return err - } - - c.curCellRange = newRange() - for i, pos := range batch.Positions { - seq := batch.Sequences[i] - - c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} - - seqRange, ok := c.cellRanges[seq] - if !ok { - seqRange = newRange() - } - - if c.curLoc+i > seqRange.max { - seqRange.max = c.curLoc + i - } - if seqRange.max > c.curCellRange.max { - c.curCellRange.max = seqRange.max - } - - if c.curLoc+i < seqRange.min { - seqRange.min = c.curLoc + i - } - if seqRange.min < c.curCellRange.min { - c.curCellRange.min = seqRange.min - } - c.cellRanges[seq] = seqRange - } - c.curMask, err = c.buildMask(ctx) return err diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bd63214cb..78f600905 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -281,7 +281,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}) + err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) if err != nil { panic(err) } @@ -315,7 +315,7 @@ func TestCanResume(t *testing.T) { err := cache.StartForward(context, input.Batch{ Positions: []int32{0, 1, 2, 3}, Sequences: []int{0, 0, 0, 0}, - }) + }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } @@ -342,7 +342,7 @@ func TestCanResume(t *testing.T) { err = cache.StartForward(context, input.Batch{ Positions: []int32{4, 5}, Sequences: []int{0, 0}, - }) + }, false) if err != nil { t.Fatalf("StartForward failed: 
%v", err) } @@ -440,6 +440,8 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} +func (c *testContext) Reserve() error { return nil } + func (c *testContext) MaxGraphNodes() int { return 10 } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 03d650a3f..0f269c3ee 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -27,6 +27,11 @@ type EncoderCache struct { // anything will be stored) curPos int32 + // curReserve indicates that this forward pass is only for + // memory reservation and we should not update our metadata + // based on it. + curReserve bool + // ** cache metadata ** // was something stored in the cache? @@ -83,12 +88,14 @@ func (c *EncoderCache) Close() { } } -func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { // We work with the most recent image if len(batch.Multimodal) > 0 { c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index] } + c.curReserve = reserve + return nil } @@ -105,8 +112,10 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { } func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { - c.encoderPos = c.curPos - c.encoderCached = true + if !c.curReserve { + c.encoderPos = c.curPos + c.encoderCached = true + } if c.config.PermutedV { value = value.Permute(ctx, 1, 2, 0, 3) diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 926bc2d41..7533d959e 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -41,9 +41,9 @@ func (c *WrapperCache) Close() { } } -func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, batch) + err := cache.StartForward(ctx, batch, reserve) if err != nil { // unwind on 
error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { diff --git a/ml/backend.go b/ml/backend.go index fffc04a48..b2a83cfd5 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -97,6 +97,13 @@ type Context interface { Forward(...Tensor) Context Compute(...Tensor) + + // Reserve is analogous to Compute but rather than executing a + // graph, simply preallocates memory. Typically called with a + // worst case graph to ensure all resources are available for + // future inference. + Reserve() error + MaxGraphNodes() int Close() diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 0aafd60b3..24bdd903d 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -10,6 +10,7 @@ import "C" import ( "context" + "errors" "fmt" "io" "log/slog" @@ -42,8 +43,12 @@ func devices() []*C.struct_ggml_backend_device { } type Backend struct { - meta *fsggml.GGML - sched *C.struct_ggml_backend_sched + meta *fsggml.GGML + + sched *C.struct_ggml_backend_sched + schedBackends []*C.struct_ggml_backend + schedBufts []*C.struct_ggml_backend_buffer_type + tensors map[string]*C.struct_ggml_tensor // input is the backend used for inputs @@ -389,8 +394,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, schedBackends = append(schedBackends, b) schedBufts = append(schedBufts, bt) - slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) - if C.ggml_backend_is_cpu(b) { // set number of threads for cpu backend C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads))) @@ -409,7 +412,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, C.size_t(maxGraphNodes), C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), ), - input: deviceBufferTypes[input.d], + schedBackends: schedBackends, + schedBufts: schedBufts, + input: deviceBufferTypes[input.d], layers: func()
map[int]*C.struct_ggml_backend_buffer_type { m := make(map[int]*C.struct_ggml_backend_buffer_type) for i, layer := range layers { @@ -534,6 +539,24 @@ func (c Context) Compute(tensors ...ml.Tensor) { } } +func (c Context) Reserve() error { + if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { + C.ggml_backend_sched_reset(c.b.sched) + return errors.New("failed to reserve graph") + } + + slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) + for i := range c.b.schedBackends { + size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i]) + slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), + "size", format.HumanBytes2(uint64(size))) + } + + C.ggml_backend_sched_reset(c.b.sched) + + return nil +} + func (c Context) MaxGraphNodes() int { return c.maxGraphNodes } diff --git a/model/model.go b/model/model.go index bc8944d22..ab96c4c70 100644 --- a/model/model.go +++ b/model/model.go @@ -299,7 +299,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, batch) + err := cache.StartForward(ctx, batch, false) if err != nil { return nil, err } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 543b4b2fa..062b654cf 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -448,7 +448,7 @@ func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {} func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {} func (m *mockCache) Close() {} -func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } +func (m *mockCache) StartForward(ctx ml.Context, batch 
input.Batch, reserve bool) error { return nil } func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} func (m *mockCache) SetConfig(ml.CacheConfig) {} func (m *mockCache) CanResume(seq int, pos int32) bool { return true } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 7b7e09402..fee052805 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -728,6 +728,51 @@ func (m *multiLPath) String() string { return strings.Join(*m, ", ") } +func (s *Server) reserveWorstCaseGraph() error { + ctx := s.model.Backend().NewContext() + defer ctx.Close() + + var batch input.Batch + + inputs := make([]int32, s.batchSize) + batch.Positions = make([]int32, len(inputs)) + batch.Sequences = make([]int, len(inputs)) + for i := range inputs { + batch.Positions[i] = int32(i) + } + + batch.Outputs = make([]int32, s.parallel) + for i := range batch.Outputs { + batch.Outputs[i] = int32(i) + } + + var err error + batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) + if err != nil { + return err + } + + cache := s.model.Config().Cache + if cache != nil { + err := cache.StartForward(ctx, batch, true) + if err != nil { + return err + } + } + + t, err := s.model.Forward(ctx, batch) + if err != nil { + return err + } + + err = ctx.Forward(t).Reserve() + if err != nil { + return err + } + + return nil +} + func (s *Server) loadModel( ctx context.Context, mpath string, @@ -765,6 +810,11 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + err = s.reserveWorstCaseGraph() + if err != nil { + panic(err) + } + s.status = llm.ServerStatusReady s.ready.Done() }