From dbb149e6f78673cc1c84e6527321c740d8d36a9a Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 3 Apr 2025 12:50:20 -0700 Subject: [PATCH] ollamarunner: Preallocate worst case graph at startup Currently, the KV cache and graph are lazily allocated as needed. The cache is fully allocated on first use of the corresponding layer whereas the graph grows with the size of the context. This can be an issue if another application allocates more VRAM after we do our calculations - Ollama will crash in the middle of inference. If we instead allocate the maximum needed memory at startup of the runner, we will either succeed or fail at that point rather than at some surprising time in the future. Currently, this only generates a worst case batch for text, which means that vision models may get a partial allocation and continue to lazily allocate the rest. --- kvcache/cache.go | 5 +- kvcache/causal.go | 85 +++++++++++++++++-------------- kvcache/causal_test.go | 8 +-- kvcache/encoder.go | 15 ++++-- kvcache/wrapper.go | 4 +- ml/backend.go | 7 +++ ml/backend/ggml/ggml.go | 33 ++++++++++-- model/model.go | 2 +- runner/ollamarunner/cache_test.go | 2 +- runner/ollamarunner/runner.go | 50 ++++++++++++++++++ 10 files changed, 156 insertions(+), 55 deletions(-) diff --git a/kvcache/cache.go b/kvcache/cache.go index 07015b9e0..405c79733 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -56,8 +56,9 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding - // entry in positions and seqs. - StartForward(ctx ml.Context, batch input.Batch) error + // entry in positions and seqs. reserve is to preallocate memory + // without actually storing data in the cache. + StartForward(ctx ml.Context, batch input.Batch, reserve bool) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index 4fc18d88f..466722845 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -146,51 +146,60 @@ func (c *Causal) Close() { } } -func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { c.curBatchSize = len(batch.Positions) c.curSequences = batch.Sequences c.curPositions = batch.Positions c.opts.Except = nil - c.updateSlidingWindow() + if !reserve { + c.updateSlidingWindow() + + var err error + c.curLoc, err = c.findStartLoc() + if errors.Is(err, ErrKvCacheFull) { + c.defrag() + c.curLoc, err = c.findStartLoc() + } + if err != nil { + return err + } + + c.curCellRange = newRange() + for i, pos := range batch.Positions { + seq := batch.Sequences[i] + + c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} + + seqRange, ok := c.cellRanges[seq] + if !ok { + seqRange = newRange() + } + + if c.curLoc+i > seqRange.max { + seqRange.max = c.curLoc + i + } + if seqRange.max > c.curCellRange.max { + c.curCellRange.max = seqRange.max + } + + if c.curLoc+i < seqRange.min { + seqRange.min = c.curLoc + i + } + if seqRange.min < c.curCellRange.min { + c.curCellRange.min = seqRange.min + } + c.cellRanges[seq] = seqRange + } + } else { + // If we are reserving memory, don't update any of the cache metadata but set the size + // to the worst case. + c.curLoc = 0 + c.curCellRange.min = 0 + c.curCellRange.max = len(c.cells) - 1 + } var err error - c.curLoc, err = c.findStartLoc() - if errors.Is(err, ErrKvCacheFull) { - c.defrag() - c.curLoc, err = c.findStartLoc() - } - if err != nil { - return err - } - - c.curCellRange = newRange() - for i, pos := range batch.Positions { - seq := batch.Sequences[i] - - c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} - - seqRange, ok := c.cellRanges[seq] - if !ok { - seqRange = newRange() - } - - if c.curLoc+i > seqRange.max { - seqRange.max = c.curLoc + i - } - if seqRange.max > c.curCellRange.max { - c.curCellRange.max = seqRange.max - } - - if c.curLoc+i < seqRange.min { - seqRange.min = c.curLoc + i - } - if seqRange.min < c.curCellRange.min { - c.curCellRange.min = seqRange.min - } - c.cellRanges[seq] = seqRange - } - c.curMask, err = c.buildMask(ctx) return err diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index bd63214cb..78f600905 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -281,7 +281,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}) + err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false) if err != nil { panic(err) } @@ -315,7 +315,7 @@ func TestCanResume(t *testing.T) { err := cache.StartForward(context, input.Batch{ Positions: []int32{0, 1, 2, 3}, Sequences: []int{0, 0, 0, 0}, - }) + }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } @@ -342,7 +342,7 @@ func TestCanResume(t *testing.T) { err = cache.StartForward(context, input.Batch{ Positions: []int32{4, 5}, Sequences: []int{0, 0}, - }) + }, false) if err != nil { t.Fatalf("StartForward failed: %v", err) } @@ -440,6 +440,8 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c } func (c *testContext) Compute(...ml.Tensor) {} +func (c *testContext) Reserve() error { return nil } + func (c *testContext) MaxGraphNodes() int { return 10 } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 03d650a3f..0f269c3ee 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -27,6 +27,11 @@ type EncoderCache struct { // anything will be stored) curPos int32 + // curReserve indicates that this forward pass is only for + // memory reservation and we should not update our metadata + // based on it. + curReserve bool + // ** cache metadata ** // was something stored in the cache? @@ -83,12 +88,14 @@ func (c *EncoderCache) Close() { } } -func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { // We work with the most recent image if len(batch.Multimodal) > 0 { c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index] } + c.curReserve = reserve + return nil } @@ -105,8 +112,10 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { } func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) { - c.encoderPos = c.curPos - c.encoderCached = true + if !c.curReserve { + c.encoderPos = c.curPos + c.encoderCached = true + } if c.config.PermutedV { value = value.Permute(ctx, 1, 2, 0, 3) diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 926bc2d41..7533d959e 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -41,9 +41,9 @@ func (c *WrapperCache) Close() { } } -func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error { +func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, batch) + err := cache.StartForward(ctx, batch, reserve) if err != nil { // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { diff --git a/ml/backend.go b/ml/backend.go index fffc04a48..b2a83cfd5 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -97,6 +97,13 @@ type Context interface { Forward(...Tensor) Context Compute(...Tensor) + + // Reserve is analogous to Compute but rather than executing a + // graph, simply preallocates memory. Typically called with a + // worst case graph to ensure all resources are available for + // for future inference. + Reserve() error + MaxGraphNodes() int Close() diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index 0aafd60b3..24bdd903d 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -10,6 +10,7 @@ import "C" import ( "context" + "errors" "fmt" "io" "log/slog" @@ -42,8 +43,12 @@ func devices() []*C.struct_ggml_backend_device { } type Backend struct { - meta *fsggml.GGML - sched *C.struct_ggml_backend_sched + meta *fsggml.GGML + + sched *C.struct_ggml_backend_sched + schedBackends []*C.struct_ggml_backend + schedBufts []*C.struct_ggml_backend_buffer_type + tensors map[string]*C.struct_ggml_tensor // input is the backend used for inputs @@ -389,8 +394,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, schedBackends = append(schedBackends, b) schedBufts = append(schedBufts, bt) - slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt))) - if C.ggml_backend_is_cpu(b) { // set number of threads for cpu backend C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads))) @@ -409,7 +412,9 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, C.size_t(maxGraphNodes), C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)), ), - input: deviceBufferTypes[input.d], + schedBackends: schedBackends, + schedBufts: schedBufts, + input: deviceBufferTypes[input.d], layers: func() map[int]*C.struct_ggml_backend_buffer_type { m := make(map[int]*C.struct_ggml_backend_buffer_type) for i, layer := range layers { @@ -534,6 +539,24 @@ func (c Context) Compute(tensors ...ml.Tensor) { } } +func (c Context) Reserve() error { + if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) { + C.ggml_backend_sched_reset(c.b.sched) + return errors.New("failed to reserve graph") + } + + slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched)) + for i := range c.b.schedBackends { + size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i]) + slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), + "size", format.HumanBytes2(uint64(size))) + } + + C.ggml_backend_sched_reset(c.b.sched) + + return nil +} + func (c Context) MaxGraphNodes() int { return c.maxGraphNodes } diff --git a/model/model.go b/model/model.go index bc8944d22..ab96c4c70 100644 --- a/model/model.go +++ b/model/model.go @@ -299,7 +299,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, batch) + err := cache.StartForward(ctx, batch, false) if err != nil { return nil, err } diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 543b4b2fa..062b654cf 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -448,7 +448,7 @@ func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {} func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {} func (m *mockCache) Close() {} -func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil } +func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil } func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {} func (m *mockCache) SetConfig(ml.CacheConfig) {} func (m *mockCache) CanResume(seq int, pos int32) bool { return true } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 7b7e09402..fee052805 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -728,6 +728,51 @@ func (m *multiLPath) String() string { return strings.Join(*m, ", ") } +func (s *Server) reserveWorstCaseGraph() error { + ctx := s.model.Backend().NewContext() + defer ctx.Close() + + var batch input.Batch + + inputs := make([]int32, s.batchSize) + batch.Positions = make([]int32, len(inputs)) + batch.Sequences = make([]int, len(inputs)) + for i := range inputs { + batch.Positions[i] = int32(i) + } + + batch.Outputs = make([]int32, s.parallel) + for i := range batch.Outputs { + batch.Outputs[i] = int32(i) + } + + var err error + batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) + if err != nil { + return err + } + + cache := s.model.Config().Cache + if cache != nil { + err := cache.StartForward(ctx, batch, true) + if err != nil { + return err + } + } + + t, err := s.model.Forward(ctx, batch) + if err != nil { + return err + } + + err = ctx.Forward(t).Reserve() + if err != nil { + return err + } + + return nil +} + func (s *Server) loadModel( ctx context.Context, mpath string, @@ -765,6 +810,11 @@ func (s *Server) loadModel( s.seqs = make([]*Sequence, s.parallel) s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + err = s.reserveWorstCaseGraph() + if err != nil { + panic(err) + } + s.status = llm.ServerStatusReady s.ready.Done() }