mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
ollamarunner: Multi-modal worst case graph
We currently preallocate compute graph memory for the worst case batch of text tokens. This adds support for doing the same for images. Note that image models are more complicated than text models in how they process their inputs so there may be cases where this approach isn't completely generic for all models. It covers all currently supported models though.
This commit is contained in:
parent
d00fac92e4
commit
086d683f9c
2 changed files with 93 additions and 14 deletions
|
@ -52,12 +52,12 @@ func (m *multimodalStore) addMultimodal(embedding []input.Multimodal) {
|
|||
// getMultimodal takes a source set of tensors (which may contain a whole or
|
||||
// parts of one or more images) and returns the equivalent that can be used in
|
||||
// the current context
|
||||
func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal) ([]input.Multimodal, error) {
|
||||
func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
|
||||
out := make([]input.Multimodal, len(in))
|
||||
for i := range out {
|
||||
if in[i].Tensor != nil {
|
||||
var err error
|
||||
out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor)
|
||||
out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,7 +69,7 @@ func (m *multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in [
|
|||
return out, nil
|
||||
}
|
||||
|
||||
func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor) (ml.Tensor, error) {
|
||||
func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
|
||||
entry := m.m[in]
|
||||
|
||||
if entry.data == nil {
|
||||
|
@ -87,19 +87,32 @@ func (m *multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Te
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
computeCtx.Forward(tensors...).Compute(tensors...)
|
||||
|
||||
computeCtx.Forward(tensors...)
|
||||
entry.data = make([][]float32, len(entry.mm))
|
||||
for i, t := range entry.mm {
|
||||
if t.Tensor != nil {
|
||||
entry.data[i] = t.Tensor.Floats()
|
||||
|
||||
if !reserve {
|
||||
computeCtx.Compute(tensors...)
|
||||
|
||||
for i, t := range entry.mm {
|
||||
if t.Tensor != nil {
|
||||
entry.data[i] = t.Tensor.Floats()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := computeCtx.Reserve()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, t := range entry.mm {
|
||||
if in == t.Tensor {
|
||||
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
|
||||
if !reserve {
|
||||
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
|
||||
} else {
|
||||
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,14 @@
|
|||
package ollamarunner
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"image"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net"
|
||||
|
@ -21,6 +23,7 @@ import (
|
|||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"golang.org/x/image/bmp"
|
||||
"golang.org/x/sync/semaphore"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
@ -443,7 +446,7 @@ func (s *Server) processBatch() error {
|
|||
|
||||
batchInputs = append(batchInputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal)
|
||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -731,12 +734,76 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
var err error
|
||||
inputs := make([]input.Input, s.batchSize)
|
||||
mmStore := newMultimodalStore()
|
||||
|
||||
// Multimodal strategy:
|
||||
// - Encode a 2048x2048 image. This assumes that a single image of this
|
||||
// size is sufficient to trigger the worst case. This is currently true
|
||||
// because for existing models, only a single image fits in a batch.
|
||||
// - Add the embedding to a full batch of tokens - this is necessary because
|
||||
// the model may be looking for non-image data, such as <image> tags.
|
||||
// - Run PostTokenize to execute any transformations between generated
|
||||
// embeddings and what the forward pass expects.
|
||||
// - The result may now be larger than a batch (images may not fit in a
|
||||
// single batch), so trim based on what will fit and must be grouped together.
|
||||
// - Fill out the rest of the space with text tokens.
|
||||
if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
|
||||
mmCtx := s.model.Backend().NewContext()
|
||||
defer mmCtx.Close()
|
||||
|
||||
img := image.NewGray(image.Rect(0, 0, 2048, 2048))
|
||||
var buf bytes.Buffer
|
||||
bmp.Encode(&buf, img)
|
||||
|
||||
inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes())
|
||||
if err != nil {
|
||||
// The model isn't really multimodal for this situation - just make a text batch.
|
||||
goto formBatch
|
||||
}
|
||||
|
||||
mmStore.addMultimodal(inputs[0].Multimodal)
|
||||
|
||||
inputs, err = multimodalProcessor.PostTokenize(inputs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, inp := range inputs {
|
||||
minBatch := 1 + inp.SameBatch
|
||||
if minBatch > s.batchSize {
|
||||
inputs = inputs[i:min(i+minBatch, len(inputs))]
|
||||
break
|
||||
} else if i+minBatch > s.batchSize {
|
||||
inputs = inputs[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(inputs) < s.batchSize {
|
||||
newInputs := make([]input.Input, s.batchSize)
|
||||
copy(newInputs, inputs)
|
||||
inputs = newInputs
|
||||
}
|
||||
}
|
||||
|
||||
formBatch:
|
||||
var batch input.Batch
|
||||
|
||||
inputs := make([]int32, s.batchSize)
|
||||
batchInputs := make([]int32, len(inputs))
|
||||
batch.Positions = make([]int32, len(inputs))
|
||||
batch.Sequences = make([]int, len(inputs))
|
||||
for i := range inputs {
|
||||
for i, inp := range inputs {
|
||||
batchInputs[i] = inp.Token
|
||||
if inp.Multimodal != nil {
|
||||
mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
|
||||
}
|
||||
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
|
||||
|
@ -745,8 +812,7 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
var err error
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue