model: Update encoder cache to use multimodal input processing handler

The encoder cache needs to know the position of images in the input
stream so that it can tell when to delete them. Previously, images
didn't have a position, so we implied one by breaking batches before
an image and then assuming the image was in the first position.
However, multimodal objects are now given explicit positions in the
input stream, so we can use those instead.
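As an illustrative sketch only (the literal token and position values
below are made up; only the types come from the new model/input
package), an image's position can now be recovered from its index into
the input slice:

package main

import (
    "fmt"

    "github.com/ollama/ollama/model/input"
)

func main() {
    // Three text tokens in one sequence; an image rides alongside Inputs[1].
    opts := input.Options{
        Inputs:     []int32{10, 11, 12},
        Positions:  []int32{0, 1, 2},
        Sequences:  []int{0, 0, 0},
        Multimodal: []input.MultimodalIndex{{Index: 1, Multimodal: "image data"}},
    }

    // The encoder cache can look up the image's position directly instead of
    // assuming the image is always first in the batch.
    last := opts.Multimodal[len(opts.Multimodal)-1]
    fmt.Println("image position:", opts.Positions[last.Index]) // image position: 1
}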

Breaking batches was also a way to simulate a cross attention mask
for mllama. However, given that it only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
Authored by Jesse Gross on 2025-03-08 15:45:31 -08:00; committed by Jesse Gross
parent 4614fafae0
commit a1cda80bcb
13 changed files with 157 additions and 160 deletions


@@ -4,6 +4,7 @@ import (
     "errors"
 
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/model/input"
 )
 
 var (
@@ -51,7 +52,7 @@ type Cache interface {
     // StartForward is called before the start of the model's forward pass.
     // For each token in the coming batch, there must be a corresponding
     // entry in positions and seqs.
-    StartForward(ctx ml.Context, positions []int32, seqs []int) error
+    StartForward(ctx ml.Context, opts input.Options) error
 
     // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
     CopyPrefix(srcSeq, dstSeq int, len int32)


@@ -8,6 +8,7 @@ import (
     "slices"
 
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/model/input"
 )
 
 type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -140,10 +141,10 @@ func (c *Causal) Close() {
     }
 }
 
-func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-    c.curBatchSize = len(positions)
-    c.curSequences = seqs
-    c.curPositions = positions
+func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
+    c.curBatchSize = len(opts.Positions)
+    c.curSequences = opts.Sequences
+    c.curPositions = opts.Positions
 
     var err error
     c.curLoc, err = c.findStartLoc()
@@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
     }
 
     c.curCellRange = newRange()
-    for i, pos := range positions {
-        seq := seqs[i]
+    for i, pos := range opts.Positions {
+        seq := opts.Sequences[i]
 
         c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}


@@ -6,6 +6,7 @@ import (
     "testing"
 
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/model/input"
 )
 
 type testCase struct {
@@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
             context := backend.NewContext()
             defer context.Close()
 
-            err := cache.StartForward(context, test.pos, test.seqs)
+            err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
             if err != nil {
                 panic(err)
             }


@@ -4,6 +4,7 @@ import (
     "fmt"
 
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/model/input"
 )
 
 // Encoder cache stores K and V tensors that are position independent
@@ -78,9 +79,11 @@ func (c *EncoderCache) Close() {
     }
 }
 
-func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-    // The image is always in the first position
-    c.curPos = positions[0]
+func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+    // We work with the most recent image
+    if len(opts.Multimodal) > 0 {
+        c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+    }
 
     return nil
 }


@@ -4,6 +4,7 @@ import (
     "math"
 
     "github.com/ollama/ollama/ml"
+    "github.com/ollama/ollama/model/input"
 )
 
 // Wrapper cache is a container for multiple types of caches,
@@ -40,14 +41,14 @@ func (c *WrapperCache) Close() {
     }
 }
 
-func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
     for i, cache := range c.caches {
-        err := cache.StartForward(ctx, positions, seqs)
+        err := cache.StartForward(ctx, opts)
         if err != nil {
             // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
             for j := i - 1; j >= 0; j-- {
-                for k := range positions {
-                    _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
+                for k := range opts.Positions {
+                    _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
                 }
             }
 
             return err

model/input/input.go (new file, 37 lines)

@@ -0,0 +1,37 @@
+package input
+
+// Input represents one token in the input stream
+type Input struct {
+    // Token is a single element of text.
+    Token int32
+
+    // Multimodal is opaque data representing a non-text
+    // element such as an image (or part of one if the image
+    // can be processed in pieces). It may be either together
+    // with Token or on its own.
+    Multimodal any
+
+    // MultimodalHash is a unique representation of the data
+    // stored in Multimodal, used for caching and comparing
+    // equality.
+    MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+    Index      int
+    Multimodal any
+}
+
+// Options contains the inputs for a model forward pass
+type Options struct {
+    Inputs     []int32
+    Multimodal []MultimodalIndex
+    Positions  []int32
+    Sequences  []int
+    Outputs    []int32
+}


@@ -19,66 +19,12 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     _ "github.com/ollama/ollama/ml/backend"
+    "github.com/ollama/ollama/model/input"
 )
 
-// Input represents one token in the input stream
-type Input struct {
-    // Token is a single element of text.
-    Token int32
-
-    // Multimodal is opaque data representing a non-text
-    // element such as an image (or part of one if the image
-    // can be processed in pieces). It may be either together
-    // with Token or on its own.
-    Multimodal any
-
-    // MultimodalHash is a unique representation of the data
-    // stored in Multimodal, used for caching and comparing
-    // equality.
-    MultimodalHash uint64
-}
-
-// MultimodalIndex is a multimodal element (such as an image)
-// together with an index into the slice of Inputs with the
-// corresponding token. Note that the index is not the same
-// as the position - to find that use the index with the
-// Positions slice.
-type MultimodalIndex struct {
-    Index      int
-    Multimodal any
-}
-
-// Options contains the inputs for a model forward pass
-type Options struct {
-    Inputs     []int32
-    Multimodal []MultimodalIndex
-    Positions  []int32
-    Sequences  []int
-    Outputs    []int32
-}
-
-type config struct {
-    Cache kvcache.Cache
-}
-
-// Base implements the common fields and methods for all models
-type Base struct {
-    b ml.Backend
-    config
-}
-
-// Backend returns the underlying backend that will run the model
-func (m *Base) Backend() ml.Backend {
-    return m.b
-}
-
-func (m *Base) Config() config {
-    return m.config
-}
-
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-    Forward(ml.Context, Options) (ml.Tensor, error)
+    Forward(ml.Context, input.Options) (ml.Tensor, error)
 
     Backend() ml.Backend
     Config() config
@@ -112,7 +58,26 @@ type MultimodalProcessor interface {
     // This function is also responsible for updating MultimodalHash for any Multimodal
     // that is modified to ensure that there is a unique hash value that accurately
     // represents the contents.
-    PostTokenize(ml.Context, []Input) ([]Input, error)
+    PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+}
+
+// Base implements the common fields and methods for all models
+type Base struct {
+    b ml.Backend
+    config
+}
+
+type config struct {
+    Cache kvcache.Cache
+}
+
+// Backend returns the underlying backend that will run the model
+func (m *Base) Backend() ml.Backend {
+    return m.b
+}
+
+func (m *Base) Config() config {
+    return m.config
 }
 
 var models = make(map[string]func(ml.Config) (Model, error))
@@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool {
         t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
     if len(opts.Positions) != len(opts.Sequences) {
         return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
     }
@@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
 
     cache := m.Config().Cache
     if cache != nil {
-        err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+        err := cache.StartForward(ctx, opts)
         if err != nil {
             return nil, err
         }


@@ -11,6 +11,7 @@ import (
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/backend/ggml"
     "github.com/ollama/ollama/ml/nn"
+    "github.com/ollama/ollama/model/input"
 )
 
 func TestParseTags(t *testing.T) {
@@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
     panic("unimplemented")
 }


@@ -9,6 +9,7 @@ import (
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/model"
+    "github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
@@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor,
     return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
     inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
     if err != nil {
         return nil, err


@@ -12,6 +12,7 @@ import (
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/ml/nn"
     "github.com/ollama/ollama/model"
+    "github.com/ollama/ollama/model/input"
 )
 
 type Model struct {
@@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
     return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
-    var images []model.Input
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+    var images []input.Input
     fnvHash := fnv.New64a()
 
     for i := range inputs {
@@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
         }
     }
 
-    inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
+    inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
 
     return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
     var crossAttentionStates ml.Tensor
-    if opts.Multimodal != nil {
-        crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
+    if len(opts.Multimodal) > 0 {
+        crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
     }
 
     inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))


@@ -10,6 +10,7 @@ import (
     "github.com/ollama/ollama/kvcache"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/model"
+    "github.com/ollama/ollama/model/input"
 )
 
 type InputCache struct {
@@ -79,7 +80,7 @@ type InputCacheSlot struct {
     Id int
 
     // Inputs that are stored in the KV cache
-    Inputs []model.Input
+    Inputs []input.Input
 
     // is this cache actively being processed as part of a sequence?
     InUse bool
@@ -88,7 +89,7 @@ type InputCacheSlot struct {
     lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
     var slot *InputCacheSlot
     var numPast int32
     var err error
@@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
     return slot, prompt, nil
 }
 
-func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
     longest := int32(-1)
     var longestSlot *InputCacheSlot
 
@@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
     return longestSlot, longest, nil
 }
 
-func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
     oldest := time.Now()
     var oldestSlot *InputCacheSlot
 
@@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
     if longest > 0 && longestSlot != oldestSlot {
         slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
             len(longestSlot.Inputs))
-        oldestSlot.Inputs = make([]model.Input, longest)
+        oldestSlot.Inputs = make([]input.Input, longest)
         copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
         if c.cache != nil {
             c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
     return oldestSlot, longest, nil
 }
 
-func countCommonPrefix(a []model.Input, b []model.Input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
     var count int32
 
     for i := range a {


@@ -5,7 +5,7 @@ import (
     "testing"
     "time"
 
-    "github.com/ollama/ollama/model"
+    "github.com/ollama/ollama/model/input"
 )
 
 func TestCountCommon(t *testing.T) {
@@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
     tests := []struct {
         name string
-        t1 []model.Input
-        t2 []model.Input
+        t1 []input.Input
+        t2 []input.Input
         expected int32
     }{
         {
             name: "Equal",
-            t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-            t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
             expected: 3,
         },
         {
             name: "Prefix",
-            t1: []model.Input{{Token: 1}},
-            t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []input.Input{{Token: 1}},
+            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
             expected: 1,
         },
         {
             name: "Image Prefix",
-            t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
-            t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+            t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
+            t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
             expected: 1,
         },
         {
             name: "Mixed",
-            t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+            t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+            t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
             expected: 2,
         },
        {
             name: "Mixed, Same Length",
-            t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-            t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+            t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+            t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
             expected: 1,
         },
         {
             name: "Empty",
-            t1: []model.Input{},
-            t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+            t1: []input.Input{},
+            t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
             expected: 0,
         },
         {
             name: "Both Empty",
-            t1: []model.Input{},
-            t2: []model.Input{},
+            t1: []input.Input{},
+            t2: []input.Input{},
             expected: 0,
         },
     }
@@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
     tests := []struct {
         name string
         cache InputCache
-        prompt []model.Input
+        prompt []input.Input
         longest expected
         best expected
     }{
@@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
             cache: InputCache{slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{},
+                    Inputs: []input.Input{},
                     InUse: false,
                     lastUsed: time.Time{},
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{},
+                    Inputs: []input.Input{},
                     InUse: false,
                     lastUsed: time.Time{},
                 },
             }},
-            prompt: []model.Input{{Token: 1}},
+            prompt: []input.Input{{Token: 1}},
             longest: expected{result: 0, len: 0},
             best: expected{result: 0, len: 0},
         },
@@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
             cache: InputCache{slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{{Token: 1}},
+                    Inputs: []input.Input{{Token: 1}},
                     InUse: false,
                     lastUsed: time.Now().Add(-time.Second),
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
                     InUse: false,
                     lastUsed: time.Now().Add(-2 * time.Second),
                 },
             }},
-            prompt: []model.Input{{Token: 1}, {Token: 2}},
+            prompt: []input.Input{{Token: 1}, {Token: 2}},
             longest: expected{result: 1, len: 2},
             best: expected{result: 1, len: 2},
         },
@@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
             cache: InputCache{slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
                     InUse: false,
                     lastUsed: time.Now().Add(-time.Second),
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{},
+                    Inputs: []input.Input{},
                     InUse: false,
                     lastUsed: time.Time{},
                 },
             }},
-            prompt: []model.Input{{Token: 2}},
+            prompt: []input.Input{{Token: 2}},
             longest: expected{result: 0, len: 0},
             best: expected{result: 1, len: 0},
         },
@@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
             slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
                     InUse: false,
                     lastUsed: time.Now().Add(-time.Second),
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{},
+                    Inputs: []input.Input{},
                     InUse: false,
                     lastUsed: time.Time{},
                 },
             },
             },
-            prompt: []model.Input{{Token: 1}},
+            prompt: []input.Input{{Token: 1}},
             longest: expected{result: 0, len: 1},
             best: expected{result: 1, len: 1},
         },
@@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
             cache: InputCache{slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{{Token: 1}},
+                    Inputs: []input.Input{{Token: 1}},
                     InUse: false,
                     lastUsed: time.Now().Add(-time.Second),
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
                     InUse: false,
                     lastUsed: time.Now().Add(-2 * time.Second),
                 },
             }},
-            prompt: []model.Input{{Token: 2}, {Token: 3}},
+            prompt: []input.Input{{Token: 2}, {Token: 3}},
             longest: expected{result: 0, len: 0},
             best: expected{result: 1, len: 0},
         },
@@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
             cache: InputCache{slots: []InputCacheSlot{
                 {
                     Id: 0,
-                    Inputs: []model.Input{{Token: 1}, {Token: 2}},
+                    Inputs: []input.Input{{Token: 1}, {Token: 2}},
                     InUse: true,
                     lastUsed: time.Now().Add(-time.Second),
                 },
                 {
                     Id: 1,
-                    Inputs: []model.Input{{Token: 1}},
+                    Inputs: []input.Input{{Token: 1}},
                     InUse: false,
                     lastUsed: time.Now().Add(-2 * time.Second),
                 },
             }},
-            prompt: []model.Input{{Token: 1}, {Token: 2}},
+            prompt: []input.Input{{Token: 1}, {Token: 2}},
             longest: expected{result: 1, len: 1},
             best: expected{result: 1, len: 2},
         },


@@ -26,6 +26,7 @@ import (
     "github.com/ollama/ollama/api"
     "github.com/ollama/ollama/ml"
     "github.com/ollama/ollama/model"
+    "github.com/ollama/ollama/model/input"
     "github.com/ollama/ollama/runner/common"
     "github.com/ollama/ollama/sample"
 
@@ -41,10 +42,10 @@ type Sequence struct {
     iBatch int
 
     // prompt inputs left to evaluate
-    inputs []model.Input
+    inputs []input.Input
 
     // inputs that have been added to a batch but not yet submitted to Forward
-    pendingInputs []model.Input
+    pendingInputs []input.Input
 
     // tokens that have been generated but not returned yet (e.g. for stop sequences)
     pendingResponses []string
@@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
-    var inputs []model.Input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+    var inputs []input.Input
 
     var parts []string
     var matches [][]string
@@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
             }
 
             for _, t := range tokens {
-                inputs = append(inputs, model.Input{Token: t})
+                inputs = append(inputs, input.Input{Token: t})
             }
 
             // image - decode and store
@@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
             _, _ = s.multimodalHash.Write(images[imageIndex].Data)
             imageHash := s.multimodalHash.Sum64()
 
-            inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+            inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
             postTokenize = true
         }
     }
@@ -250,9 +251,6 @@ type Server struct {
     // KV cache
     cache *InputCache
 
-    // next sequence for prompt processing to avoid starvation
-    nextSeq int
-
     // multimodalHash generates hashes for comparing equality
     // of non-text data
     multimodalHash maphash.Hash
@@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
     }
     defer s.mu.Unlock()
 
-    var options model.Options
+    var options input.Options
 
-    seqIdx := s.nextSeq - 1
-    for range s.seqs {
-        seqIdx = (seqIdx + 1) % len(s.seqs)
-        seq := s.seqs[seqIdx]
-
+    for i, seq := range s.seqs {
         if seq == nil {
             continue
         }
 
         // if past the num predict limit
         if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-            s.removeSequence(seqIdx, "limit")
+            s.removeSequence(i, "limit")
             continue
         }
 
         if !s.cache.enabled {
             seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-            seq.cache.Inputs = []model.Input{}
+            seq.cache.Inputs = []input.Input{}
         }
 
-        for i, input := range seq.inputs {
+        for j, inp := range seq.inputs {
             if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
                 if len(seq.pendingInputs) == 0 {
                     err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,33 +357,23 @@ func (s *Server) processBatch() error {
                 }
             }
 
-            if i >= s.batchSize {
+            if j >= s.batchSize {
                 break
             }
 
-            // TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
-            // to the encoder cache.
-            //
-            // Break the batch when switching from text to images so that images are always at the beginning.
-            if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
-                (len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
-                s.nextSeq = seqIdx
-                break
-            }
-
-            options.Inputs = append(options.Inputs, input.Token)
-            if input.Multimodal != nil {
-                options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
+            options.Inputs = append(options.Inputs, inp.Token)
+            if inp.Multimodal != nil {
+                options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
             }
 
             options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
             options.Sequences = append(options.Sequences, seq.cache.Id)
 
             seq.iBatch = len(options.Outputs)
-            if i+1 == len(seq.inputs) {
+            if j+1 == len(seq.inputs) {
                 options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
             }
-            seq.pendingInputs = append(seq.pendingInputs, input)
+            seq.pendingInputs = append(seq.pendingInputs, inp)
         }
 
         seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
         // After calling Forward, pending inputs are now in the cache
        if len(seq.pendingInputs) > 0 {
             seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-            seq.pendingInputs = []model.Input{}
+            seq.pendingInputs = []input.Input{}
         }
 
         // don't sample prompt processing
@@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
             return err
         }
 
-        seq.inputs = []model.Input{{Token: token}}
+        seq.inputs = []input.Input{{Token: token}}
 
         seq.pendingResponses = append(seq.pendingResponses, piece)
         sequence := strings.Join(seq.pendingResponses, "")