model: Update encoder cache to use multimodal input processing handler

The encoder cache needs to know the position of images in the input
stream so that it knows when to delete them. Previously, images didn't
have a position, so we implied one by breaking batches before an
image and then assuming the image was in the first position. However,
multimodal objects are now given explicit positions in the input
stream, so we can use those instead.
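
For context, here is a minimal sketch of the input element and the index
record that carry those explicit positions, reconstructed from how they are
used in the diff below; the exact field types are assumptions rather than
the authoritative definitions:

package input

// Input is one element of the prompt stream: a text token or a multimodal
// object such as a decoded image. Its index in the prompt slice is its
// position, so the encoder cache can tell when an image has been consumed
// and can be deleted.
type Input struct {
	Token          int32  // text token (type assumed)
	Multimodal     any    // decoded image data; nil for plain tokens (type assumed)
	MultimodalHash uint64 // hash of the raw image bytes, used to compare inputs
}

// MultimodalIndex ties a multimodal object to its explicit position in the
// flattened batch, which is what the encoder cache consumes.
type MultimodalIndex struct {
	Index      int // position within the batch inputs
	Multimodal any
}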

Breaking batches was also a way to simulate a cross-attention mask
for mllama. However, given that mllama only supports a single sequence
and a single image, this mask doesn't serve any real purpose.
Removing the batch break does not appear to affect the quality of
the output.

Most of this is simply moving the input data structures to a new
package to avoid import cycles.
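
The new package is github.com/ollama/ollama/model/input, which the files
below now import. A sketch of the batch structure that lands there alongside
Input and MultimodalIndex, again reconstructed from its usage in processBatch
with assumed types; keeping these types in a small leaf package presumably
lets the model, cache, and runner packages all share them without importing
each other:

package input

// Options describes one flattened batch built across sequences in
// processBatch. Field names match their usage in the diff; element types
// are assumptions.
type Options struct {
	Inputs     []int32           // token IDs, in batch order
	Multimodal []MultimodalIndex // multimodal objects plus their positions in Inputs
	Positions  []int32           // position of each entry within its own sequence
	Sequences  []int             // cache slot ID that each entry belongs to
	Outputs    []int32           // batch indices whose logits should be returned
}
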
Jesse Gross, 2025-03-08 15:45:31 -08:00 (committed by Jesse Gross)
commit a1cda80bcb, parent 4614fafae0
13 changed files with 157 additions and 160 deletions


@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type InputCache struct {
@@ -79,7 +80,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []model.Input
Inputs []input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -88,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]model.Input, longest)
oldestSlot.Inputs = make([]input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
return oldestSlot, longest, nil
}
func countCommonPrefix(a []model.Input, b []model.Input) int32 {
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
var count int32
for i := range a {


@@ -5,7 +5,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
func TestCountCommon(t *testing.T) {
@@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []model.Input
t2 []model.Input
t1 []input.Input
t2 []input.Input
expected int32
}{
{
name: "Equal",
t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []model.Input{{Token: 1}},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []model.Input{},
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []model.Input{},
t2: []model.Input{},
t1: []input.Input{},
t2: []input.Input{},
expected: 0,
},
}
@@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []model.Input
prompt []input.Input
longest expected
best expected
}{
@@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []model.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []model.Input{{Token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []model.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []model.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []model.Input{{Token: 2}},
prompt: []input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []model.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []model.Input{{Token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []model.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []model.Input{{Token: 2}, {Token: 3}},
prompt: []input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []model.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []model.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []model.Input{{Token: 1}, {Token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},


@@ -26,6 +26,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -41,10 +42,10 @@ type Sequence struct {
iBatch int
// prompt inputs left to evaluate
inputs []model.Input
inputs []input.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []model.Input
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
var inputs []model.Input
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
var inputs []input.Input
var parts []string
var matches [][]string
@@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
}
for _, t := range tokens {
inputs = append(inputs, model.Input{Token: t})
inputs = append(inputs, input.Input{Token: t})
}
// image - decode and store
@@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
@@ -250,9 +251,6 @@ type Server struct {
// KV cache
cache *InputCache
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
@@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var options model.Options
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
var options input.Options
for i, seq := range s.seqs {
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, "limit")
s.removeSequence(i, "limit")
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []model.Input{}
seq.cache.Inputs = []input.Input{}
}
for i, input := range seq.inputs {
for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,33 +357,23 @@ }
}
}
if i >= s.batchSize {
if j >= s.batchSize {
break
}
// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
// to the encoder cache.
//
// Break the batch when switching from text to images so that images are always at the beginning.
if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
s.nextSeq = seqIdx
break
}
options.Inputs = append(options.Inputs, input.Token)
if input.Multimodal != nil {
options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
}
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(options.Outputs)
if i+1 == len(seq.inputs) {
if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, input)
seq.pendingInputs = append(seq.pendingInputs, inp)
}
seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []model.Input{}
seq.pendingInputs = []input.Input{}
}
// don't sample prompt processing
@@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
return err
}
seq.inputs = []model.Input{{Token: token}}
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")