From 086d683f9ca0105591738ea587029f91a78ef881 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Mon, 7 Apr 2025 13:59:11 -0700
Subject: [PATCH] ollamarunner: Multi-modal worst case graph

We currently preallocate compute graph memory for the worst case batch
of text tokens. This adds support for doing the same for images.

Note that image models are more complicated than text models in how they
process their inputs, so there may be cases where this approach isn't
completely generic for all models. It covers all currently supported
models though.
---
 runner/ollamarunner/multimodal.go | 31 +++++++++----
 runner/ollamarunner/runner.go     | 76 +++++++++++++++++++++++++++++--
 2 files changed, 93 insertions(+), 14 deletions(-)

diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go
index bb381c5e0..004a5f7ce 100644
--- a/runner/ollamarunner/multimodal.go
+++ b/runner/ollamarunner/multimodal.go
@@ -52,12 +52,12 @@ func (m *multimodalStore) addMultimodal(embedding []input.Multimodal) {
 // getMultimodal takes a source set of tensors (which may contain a whole or
 // parts of one or more images) and returns the equivalent that can be used in
 // the current context
-func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
 	out := make([]input.Multimodal, len(in))
 	for i := range out {
 		if in[i].Tensor != nil {
 			var err error
-			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
 			if err != nil {
 				return nil, err
 			}
@@ -69,7 +69,7 @@ func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in [
 	return out, nil
 }
 
-func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
 	entry := m.m[in]
 
 	if entry.data == nil {
@@ -87,19 +87,32 @@ func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Te
 			return nil, nil
 		}
 
-		computeCtx.Forward(tensors...).Compute(tensors...)
-
+		computeCtx.Forward(tensors...)
 		entry.data = make([][]float32, len(entry.mm))
-		for i, t := range entry.mm {
-			if t.Tensor != nil {
-				entry.data[i] = t.Tensor.Floats()
+
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
+			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
 			}
 		}
 	}
 
 	for i, t := range entry.mm {
 		if in == t.Tensor {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
 		}
 	}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 916313532..3fee73cd7 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1,12 +1,14 @@
 package ollamarunner
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"hash/maphash"
+	"image"
 	"log"
 	"log/slog"
 	"net"
@@ -21,6 +23,7 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/image/bmp"
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
@@ -443,7 +446,7 @@ func (s *Server) processBatch() error {
 			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
+				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 				if err != nil {
 					return err
 				}
@@ -731,12 +734,76 @@ func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
+	var err error
+	inputs := make([]input.Input, s.batchSize)
+	mmStore := newMultimodalStore()
+
+	// Multimodal strategy:
+	// - Encode a 2048x2048 image. This assumes that a single image of this
+	// size is sufficient to trigger the worst case. This is currently true
+	// because for existing models, only a single image fits in a batch.
+	// - Add the embedding to a full batch of tokens - this is necessary because
+	// the model may be looking for non-image data, such as <image> tags.
+	// - Run PostTokenize to execute any transformations between generated
+	// embeddings and what the forward pass expects.
+	// - The result may now be larger than a batch (images may not fit in a
+	// single batch), so trim based on what will fit and must be grouped together.
+	// - Fill out the rest of the space with text tokens.
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+		mmCtx := s.model.Backend().NewContext()
+		defer mmCtx.Close()
+
+		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+		var buf bytes.Buffer
+		bmp.Encode(&buf, img)
+
+		inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes())
+		if err != nil {
+			// The model isn't really multimodal for this situation - just make a text batch.
+			goto formBatch
+		}
+
+		mmStore.addMultimodal(inputs[0].Multimodal)
+
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
+		if err != nil {
+			return err
+		}
+
+		for i, inp := range inputs {
+			minBatch := 1 + inp.SameBatch
+			if minBatch > s.batchSize {
+				inputs = inputs[i:min(i+minBatch, len(inputs))]
+				break
+			} else if i+minBatch > s.batchSize {
+				inputs = inputs[:i]
+				break
+			}
+		}
+
+		if len(inputs) < s.batchSize {
+			newInputs := make([]input.Input, s.batchSize)
+			copy(newInputs, inputs)
+			inputs = newInputs
+		}
+	}
+
+formBatch:
 	var batch input.Batch
 
-	inputs := make([]int32, s.batchSize)
+	batchInputs := make([]int32, len(inputs))
 	batch.Positions = make([]int32, len(inputs))
 	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
+	for i, inp := range inputs {
+		batchInputs[i] = inp.Token
+		if inp.Multimodal != nil {
+			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+			if err != nil {
+				return err
+			}
+			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+		}
+
 		batch.Positions[i] = int32(i)
 	}
 
@@ -745,8 +812,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}
 
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 	if err != nil {
 		return err
 	}
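The core of the multimodal.go change is one contract: the same forward graph is always built, but in reserve mode Compute is replaced by Reserve and the caller gets a shape-only placeholder via ctx.Input().Empty instead of cached data. A toy model of that contract, runnable on its own (the tensor type and materialize function here are invented stand-ins; the real ml package needs a live backend):

    package main

    import "fmt"

    // tensor is a stand-in for ml.Tensor: a shape plus optional data.
    type tensor struct {
        shape []int
        data  []float32
    }

    // materialize imitates getTensor's two paths: the compute path runs
    // the graph and returns real values, while the reserve path only
    // needs the shape so the backend can size its worst-case allocation.
    func materialize(shape []int, reserve bool) tensor {
        n := 1
        for _, dim := range shape {
            n *= dim
        }

        if !reserve {
            // Compute path: stands in for Compute + Floats + FromFloatSlice.
            return tensor{shape: shape, data: make([]float32, n)}
        }

        // Reserve path: stands in for Reserve + Empty. No data is ever
        // produced, which is why reserveWorstCaseGraph is cheap to run.
        return tensor{shape: shape}
    }

    func main() {
        fmt.Println(len(materialize([]int{3, 4}, false).data)) // 12
        fmt.Println(len(materialize([]int{3, 4}, true).data))  // 0
    }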
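On the runner side, the worst-case image is synthesized in memory rather than loaded from anywhere. This standalone sketch isolates that step; the 2048x2048 size and grayscale pixel format come straight from the patch, while the main wrapper and the printed length are illustrative only:

    package main

    import (
        "bytes"
        "fmt"
        "image"

        "golang.org/x/image/bmp"
    )

    func main() {
        // A blank 2048x2048 grayscale image, assumed large enough to
        // trigger the worst-case vision-encoder graph for current models.
        img := image.NewGray(image.Rect(0, 0, 2048, 2048))

        // BMP keeps the encode trivial and dependency-light; the bytes
        // stand in for user-supplied image data.
        var buf bytes.Buffer
        if err := bmp.Encode(&buf, img); err != nil {
            panic(err)
        }

        // buf.Bytes() is what reserveWorstCaseGraph hands to
        // EncodeMultimodal.
        fmt.Printf("%d bytes of BMP data\n", buf.Len())
    }

The patch itself drops bmp.Encode's error; since writes to a bytes.Buffer do not fail, that is harmless in practice.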
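The trimming loop after PostTokenize is the subtlest part of the batch construction, and its two branches do opposite things: a SameBatch group bigger than a whole batch is kept intact by itself (the reserved batch may then exceed the nominal batch size, since the group cannot be split), while a group that merely straddles the end of the batch is cut off before it starts, mirroring how processBatch would defer it to the next batch. A worked example with invented sizes (fakeInput, trim, and the numbers are hypothetical; the loop body matches the patch):

    package main

    import "fmt"

    // fakeInput mirrors the one field of input.Input the loop reads:
    // SameBatch is how many of the following inputs must land in the
    // same batch as this one (e.g. the rest of an image's embeddings).
    type fakeInput struct {
        SameBatch int
    }

    func trim(inputs []fakeInput, batchSize int) []fakeInput {
        for i, inp := range inputs {
            minBatch := 1 + inp.SameBatch
            if minBatch > batchSize {
                // The group can never fit in one batch: reserve for the
                // group by itself, even though it exceeds batchSize.
                return inputs[i:min(i+minBatch, len(inputs))]
            } else if i+minBatch > batchSize {
                // The group fits in a batch, but not in the space left:
                // stop before it begins.
                return inputs[:i]
            }
        }
        return inputs
    }

    func main() {
        // An image that expands to 600 inputs with a batch size of 512:
        // the whole group is kept even though it is oversized.
        oversized := make([]fakeInput, 1024)
        oversized[0].SameBatch = 599
        fmt.Println(len(trim(oversized, 512))) // 600

        // A 300-input image starting at position 400 would straddle the
        // batch boundary, so everything from it onward is dropped.
        straddling := make([]fakeInput, 1024)
        straddling[400].SameBatch = 299
        fmt.Println(len(trim(straddling, 512))) // 400
    }

After trimming, the patch pads short results back up to batchSize with zero-valued text inputs, so the reserved graph always covers at least a full text batch.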