ollamarunner: Multi-modal worst case graph

We currently preallocate compute graph memory for the worst case
batch of text tokens. This adds support for doing the same for
images.

Note that image models are more complicated than text models in
how they process their inputs, so there may be cases where this
approach isn't completely generic for all models. However, it
covers all currently supported models.
Jesse Gross 2025-04-07 13:59:11 -07:00
parent d00fac92e4
commit 086d683f9c
2 changed files with 93 additions and 14 deletions
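
In outline, the same graph-building code now runs in two modes: a normal
pass builds the compute graph and executes it, while a reservation pass
builds the identical graph and only sizes memory for it. A minimal sketch
of that pattern, using the ml.Context calls that appear in the diffs below
(Forward, Compute, Reserve); the wrapper function itself is illustrative,
not part of the commit:

	// runOrReserve builds the graph for the given tensors, then either
	// executes it (normal pass) or only reserves memory for it
	// (worst-case preallocation pass).
	func runOrReserve(ctx ml.Context, tensors []ml.Tensor, reserve bool) error {
		ctx.Forward(tensors...)

		if !reserve {
			// Real batch: run the graph and produce output data.
			ctx.Compute(tensors...)
			return nil
		}

		// Worst-case pass: allocate graph memory without computing,
		// so later real batches can't hit an allocation failure.
		return ctx.Reserve()
	}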

runner/ollamarunner/multimodal.go

@@ -52,12 +52,12 @@ func (m *multimodalStore) addMultimodal(embedding []input.Multimodal) {
 // getMultimodal takes a source set of tensors (which may contain a whole or
 // parts of one or more images) and returns the equivalent that can be used in
 // the current context
-func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
 	out := make([]input.Multimodal, len(in))
 	for i := range out {
 		if in[i].Tensor != nil {
 			var err error
-			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
 			if err != nil {
 				return nil, err
 			}
@@ -69,7 +69,7 @@ func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in [
 	return out, nil
 }
 
-func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
 	entry := m.m[in]
 
 	if entry.data == nil {
@@ -87,19 +87,32 @@ func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Te
 			return nil, nil
 		}
 
-		computeCtx.Forward(tensors...).Compute(tensors...)
+		computeCtx.Forward(tensors...)
 
 		entry.data = make([][]float32, len(entry.mm))
 
-		for i, t := range entry.mm {
-			if t.Tensor != nil {
-				entry.data[i] = t.Tensor.Floats()
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
+			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
 			}
 		}
 	}
 
 	for i, t := range entry.mm {
 		if in == t.Tensor {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
 		}
 	}
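
During a reservation pass the image embeddings are never actually
computed, so getTensor cannot return real data; it instead hands back an
empty tensor with the right shape and dtype, which is all that graph
construction needs. The final branch above reduces to this pattern
(sketch only; FromFloatSlice, Empty, DType, and Shape are the calls from
the diff, while the helper itself is invented for illustration):

	// placeholderOrData mirrors getTensor's final branch: saved float
	// data for a normal pass, a correctly-shaped empty tensor for a
	// reservation pass, where only shape and dtype matter.
	func placeholderOrData(ctx ml.Context, t ml.Tensor, data []float32, reserve bool) (ml.Tensor, error) {
		if !reserve {
			return ctx.Input().FromFloatSlice(data, t.Shape()...)
		}
		return ctx.Input().Empty(t.DType(), t.Shape()...), nil
	}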

runner/ollamarunner/runner.go

@@ -1,12 +1,14 @@
 package ollamarunner
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"hash/maphash"
+	"image"
 	"log"
 	"log/slog"
 	"net"
@@ -21,6 +23,7 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/image/bmp"
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
@@ -443,7 +446,7 @@ func (s *Server) processBatch() error {
 			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
+				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 				if err != nil {
 					return err
 				}
@@ -731,12 +734,76 @@ func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
+	var err error
+	inputs := make([]input.Input, s.batchSize)
+	mmStore := newMultimodalStore()
+
+	// Multimodal strategy:
+	// - Encode a 2048x2048 image. This assumes that a single image of this
+	// size is sufficient to trigger the worst case. This is currently true
+	// because for existing models, only a single image fits in a batch.
+	// - Add the embedding to a full batch of tokens - this is necessary because
+	// the model may be looking for non-image data, such as <image> tags.
+	// - Run PostTokenize to execute any transformations between generated
+	// embeddings and what the forward pass expects.
+	// - The result may now be larger than a batch (images may not fit in a
+	// single batch), so trim based on what will fit and must be grouped together.
+	// - Fill out the rest of the space with text tokens.
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+		mmCtx := s.model.Backend().NewContext()
+		defer mmCtx.Close()
+
+		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+		var buf bytes.Buffer
+		bmp.Encode(&buf, img)
+
+		inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes())
+		if err != nil {
+			// The model isn't really multimodal for this situation - just make a text batch.
+			goto formBatch
+		}
+		mmStore.addMultimodal(inputs[0].Multimodal)
+
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
+		if err != nil {
+			return err
+		}
+
+		for i, inp := range inputs {
+			minBatch := 1 + inp.SameBatch
+			if minBatch > s.batchSize {
+				inputs = inputs[i:min(i+minBatch, len(inputs))]
+				break
+			} else if i+minBatch > s.batchSize {
+				inputs = inputs[:i]
+				break
+			}
+		}
+
+		if len(inputs) < s.batchSize {
+			newInputs := make([]input.Input, s.batchSize)
+			copy(newInputs, inputs)
+			inputs = newInputs
+		}
+	}
+
+formBatch:
 	var batch input.Batch
 
-	inputs := make([]int32, s.batchSize)
+	batchInputs := make([]int32, len(inputs))
 	batch.Positions = make([]int32, len(inputs))
 	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
+	for i, inp := range inputs {
+		batchInputs[i] = inp.Token
+		if inp.Multimodal != nil {
+			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+			if err != nil {
+				return err
+			}
+			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+		}
+
 		batch.Positions[i] = int32(i)
 	}
@@ -745,8 +812,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}
 
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 	if err != nil {
 		return err
 	}
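
The trimming loop in reserveWorstCaseGraph deserves a worked example:
each input must stay with the SameBatch inputs that follow it, so a group
of 1+SameBatch inputs either fits entirely within the batch or the batch
is cut just before the group; a group bigger than the whole batch instead
becomes the batch by itself. A self-contained sketch with hypothetical
sizes (fakeInput stands in for input.Input; only SameBatch matters here):

	package main

	import "fmt"

	// fakeInput is a stand-in for ollama's input.Input; SameBatch is the
	// number of following inputs that must share this input's batch.
	type fakeInput struct {
		SameBatch int
	}

	// trim applies the same rule as reserveWorstCaseGraph: keep whole
	// groups of 1+SameBatch inputs, cutting where a group no longer fits.
	func trim(inputs []fakeInput, batchSize int) []fakeInput {
		for i, inp := range inputs {
			minBatch := 1 + inp.SameBatch
			if minBatch > batchSize {
				// The group alone exceeds the batch: it becomes the batch.
				return inputs[i:min(i+minBatch, len(inputs))]
			} else if i+minBatch > batchSize {
				// The group would spill past the end: cut before it starts.
				return inputs[:i]
			}
		}
		return inputs
	}

	func main() {
		// Hypothetical: an image expands to 6 inputs that must stay
		// together (SameBatch=5 on the first), then text tokens.
		in := []fakeInput{{SameBatch: 5}, {}, {}, {}, {}, {}, {}, {}}
		fmt.Println(len(trim(in, 8))) // 8: everything fits
		fmt.Println(len(trim(in, 4))) // 6: the group overflows, kept whole
	}

In the second case the reserved "batch" is larger than batchSize, which
matches the strategy comment: a single image's inputs may not fit in one
batch, and the worst-case graph must account for that.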