diff --git a/runner/ollamarunner/multimodal.go b/runner/ollamarunner/multimodal.go
index bb381c5e0..004a5f7ce 100644
--- a/runner/ollamarunner/multimodal.go
+++ b/runner/ollamarunner/multimodal.go
@@ -52,12 +52,12 @@ func (m *multimodalStore) addMultimodal(embedding []input.Multimodal) {
 // getMultimodal takes a source set of tensors (which may contain a whole or
 // parts of one or more images) and returns the equivalent that can be used in
 // the current context
-func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
+func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
 	out := make([]input.Multimodal, len(in))
 	for i := range out {
 		if in[i].Tensor != nil {
 			var err error
-			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
+			out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
 			if err != nil {
 				return nil, err
 			}
@@ -69,7 +69,7 @@ func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in [
 	return out, nil
 }
 
-func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
+func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
 	entry := m.m[in]
 
 	if entry.data == nil {
@@ -87,19 +87,32 @@ func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Te
 			return nil, nil
 		}
 
-		computeCtx.Forward(tensors...).Compute(tensors...)
-
+		computeCtx.Forward(tensors...)
 		entry.data = make([][]float32, len(entry.mm))
-		for i, t := range entry.mm {
-			if t.Tensor != nil {
-				entry.data[i] = t.Tensor.Floats()
+
+		if !reserve {
+			computeCtx.Compute(tensors...)
+
+			for i, t := range entry.mm {
+				if t.Tensor != nil {
+					entry.data[i] = t.Tensor.Floats()
+				}
+			}
+		} else {
+			err := computeCtx.Reserve()
+			if err != nil {
+				return nil, err
 			}
 		}
 	}
 
 	for i, t := range entry.mm {
 		if in == t.Tensor {
-			return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			if !reserve {
+				return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
+			} else {
+				return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
+			}
 		}
 	}
 
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 916313532..3fee73cd7 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1,12 +1,14 @@
 package ollamarunner
 
 import (
+	"bytes"
 	"context"
 	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
 	"hash/maphash"
+	"image"
 	"log"
 	"log/slog"
 	"net"
@@ -21,6 +23,7 @@ import (
 	"time"
 	"unicode/utf8"
 
+	"golang.org/x/image/bmp"
 	"golang.org/x/sync/semaphore"
 
 	"github.com/ollama/ollama/api"
@@ -443,7 +446,7 @@ func (s *Server) processBatch() error {
 			batchInputs = append(batchInputs, inp.Token)
 
 			if inp.Multimodal != nil {
-				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
+				mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
 				if err != nil {
 					return err
 				}
@@ -731,12 +734,76 @@ func (s *Server) reserveWorstCaseGraph() error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
+	var err error
+	inputs := make([]input.Input, s.batchSize)
+	mmStore := newMultimodalStore()
+
+	// Multimodal strategy:
+	// - Encode a 2048x2048 image. This assumes that a single image of this
+	// size is sufficient to trigger the worst case. This is currently true
+	// because for existing models, only a single image fits in a batch.
+	// - Add the embedding to a full batch of tokens - this is necessary because
+	// the model may be looking for non-image data, such as <image> tags.
+	// - Run PostTokenize to execute any transformations between generated
+	// embeddings and what the forward pass expects.
+	// - The result may now be larger than a batch (images may not fit in a
+	// single batch), so trim based on what will fit and must be grouped together.
+	// - Fill out the rest of the space with text tokens.
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+		mmCtx := s.model.Backend().NewContext()
+		defer mmCtx.Close()
+
+		img := image.NewGray(image.Rect(0, 0, 2048, 2048))
+		var buf bytes.Buffer
+		bmp.Encode(&buf, img)
+
+		inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes())
+		if err != nil {
+			// The model isn't really multimodal for this situation - just make a text batch.
+			goto formBatch
+		}
+
+		mmStore.addMultimodal(inputs[0].Multimodal)
+
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
+		if err != nil {
+			return err
+		}
+
+		for i, inp := range inputs {
+			minBatch := 1 + inp.SameBatch
+			if minBatch > s.batchSize {
+				inputs = inputs[i:min(i+minBatch, len(inputs))]
+				break
+			} else if i+minBatch > s.batchSize {
+				inputs = inputs[:i]
+				break
+			}
+		}
+
+		if len(inputs) < s.batchSize {
+			newInputs := make([]input.Input, s.batchSize)
+			copy(newInputs, inputs)
+			inputs = newInputs
+		}
+	}
+
+formBatch:
 	var batch input.Batch
 
-	inputs := make([]int32, s.batchSize)
+	batchInputs := make([]int32, len(inputs))
 	batch.Positions = make([]int32, len(inputs))
 	batch.Sequences = make([]int, len(inputs))
-	for i := range inputs {
+	for i, inp := range inputs {
+		batchInputs[i] = inp.Token
+		if inp.Multimodal != nil {
+			mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
+			if err != nil {
+				return err
+			}
+			batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
+		}
+
 		batch.Positions[i] = int32(i)
 	}
 
@@ -745,8 +812,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 		batch.Outputs[i] = int32(i)
 	}
 
-	var err error
-	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
 	if err != nil {
 		return err
 	}
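A note on the new `reserve` parameter in multimodal.go: with it set, `getTensor` still builds the compute graph (`computeCtx.Forward`) but calls `Reserve()` instead of `Compute()`, and returns a shape- and dtype-matched placeholder via `ctx.Input().Empty(...)` rather than real data. Below is a minimal sketch of that control flow using toy stand-ins, not the real `ml` package:

```go
package main

import "fmt"

// tensor is a toy stand-in for ml.Tensor; only shape and data matter here.
type tensor struct {
	shape []int
	data  []float32
}

// getTensor mirrors the branch added in multimodal.go: the graph is always
// built, but values are only computed and copied back when reserve is false.
// When reserving, the caller only needs a tensor of the right shape so the
// worst-case graph can be sized, so an empty placeholder suffices.
func getTensor(src *tensor, reserve bool) *tensor {
	if !reserve {
		// Compute path: materialize the values (stubbed here as a copy).
		out := make([]float32, len(src.data))
		copy(out, src.data)
		return &tensor{shape: src.shape, data: out}
	}
	// Reserve path: shape-only placeholder, no computation, no transfer.
	return &tensor{shape: src.shape}
}

func main() {
	src := &tensor{shape: []int{4, 8}, data: make([]float32, 32)}
	fmt.Println(len(getTensor(src, false).data)) // 32: real data
	fmt.Println(len(getTensor(src, true).data))  // 0: placeholder
}
```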
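The worst-case image in `reserveWorstCaseGraph` is built with the standard `image` package and `golang.org/x/image/bmp`, both real packages used by the diff. This fragment isolates just that step and is runnable as-is:

```go
package main

import (
	"bytes"
	"fmt"
	"image"

	"golang.org/x/image/bmp"
)

func main() {
	// Zero-filled 8-bit grayscale image at the assumed worst-case size.
	img := image.NewGray(image.Rect(0, 0, 2048, 2048))

	var buf bytes.Buffer
	if err := bmp.Encode(&buf, img); err != nil {
		panic(err)
	}

	// Grayscale BMP is uncompressed: roughly width*height bytes of pixel
	// data plus fixed-size headers and a 256-entry palette.
	fmt.Println(buf.Len())
}
```

Since the output only feeds `EncodeMultimodal` to size the graph, the pixel values are irrelevant; presumably an uncompressed format and a blank grayscale image are used because encoding them is cheap and deterministic. (Writing to an in-memory `bytes.Buffer` cannot fail for a valid image, which is likely why the diff ignores `Encode`'s error.)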
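The trimming loop is the subtlest part of `reserveWorstCaseGraph`: a group of inputs linked by `SameBatch` must either be kept intact (even if the group alone overflows the batch) or cut off before it starts, never split across the boundary. Here is a standalone restatement of that logic, with a hypothetical `Input` type standing in for the real `input.Input`:

```go
package main

import "fmt"

// Input mirrors the fields of input.Input used by the trimming logic; this
// type is a stand-in for illustration, not the real package.
type Input struct {
	Token     int32
	SameBatch int // subsequent inputs that must stay in the same batch
}

// trimToBatch reproduces the trimming loop from reserveWorstCaseGraph.
func trimToBatch(inputs []Input, batchSize int) []Input {
	for i, inp := range inputs {
		minBatch := 1 + inp.SameBatch
		if minBatch > batchSize {
			// The group alone exceeds the batch: keep just this group.
			return inputs[i:min(i+minBatch, len(inputs))]
		} else if i+minBatch > batchSize {
			// The group would straddle the batch boundary: stop before it.
			return inputs[:i]
		}
	}
	return inputs
}

func main() {
	// A group starting at index 2 spans 4 inputs; with batch size 4 it
	// would end at index 6, so the result is trimmed to indices [0, 2).
	inputs := []Input{{Token: 1}, {Token: 2}, {SameBatch: 3}, {}, {}, {}}
	fmt.Println(len(trimToBatch(inputs, 4))) // 2
}
```

After trimming, the diff pads `inputs` back out to `s.batchSize` with zero-valued text tokens so the reserved batch is always full-width.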