mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
model: Update encoder cache to use multimodal input processing handler
The encoder cache needs to know the position of images in the input stream so that it knows when to delete them. Previously images didn't have a position, so we implied one by breaking batches before an image and then assuming the image was in the first position. However, multimodal objects are now given explicit positions in the input stream, so we can use that instead. Breaking batches was also a way to simulate a cross attention mask for mllama. However, given that it only supports a single sequence and a single image, this mask doesn't serve any real purpose. Removing the batch break does not appear to affect the quality of the output. Most of this is simply moving the input data structures to a new package to avoid import cycles.
This commit is contained in:
parent
4614fafae0
commit
a1cda80bcb
13 changed files with 157 additions and 160 deletions
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type InputCache struct {
|
||||
|
@ -79,7 +80,7 @@ type InputCacheSlot struct {
|
|||
Id int
|
||||
|
||||
// Inputs that are stored in the KV cache
|
||||
Inputs []model.Input
|
||||
Inputs []input.Input
|
||||
|
||||
// is this cache actively being processed as part of a sequence?
|
||||
InUse bool
|
||||
|
@ -88,7 +89,7 @@ type InputCacheSlot struct {
|
|||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
|
@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp
|
|||
return slot, prompt, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
longest := int32(-1)
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
|
@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot
|
|||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
oldest := time.Now()
|
||||
var oldestSlot *InputCacheSlot
|
||||
|
||||
|
@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
|
|||
if longest > 0 && longestSlot != oldestSlot {
|
||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||
len(longestSlot.Inputs))
|
||||
oldestSlot.Inputs = make([]model.Input, longest)
|
||||
oldestSlot.Inputs = make([]input.Input, longest)
|
||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||
if c.cache != nil {
|
||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||
|
@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
|
|||
return oldestSlot, longest, nil
|
||||
}
|
||||
|
||||
func countCommonPrefix(a []model.Input, b []model.Input) int32 {
|
||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
||||
var count int32
|
||||
|
||||
for i := range a {
|
||||
|
|
|
@ -5,7 +5,7 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
func TestCountCommon(t *testing.T) {
|
||||
|
@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
|
|||
|
||||
tests := []struct {
|
||||
name string
|
||||
t1 []model.Input
|
||||
t2 []model.Input
|
||||
t1 []input.Input
|
||||
t2 []input.Input
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Equal",
|
||||
t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Prefix",
|
||||
t1: []model.Input{{Token: 1}},
|
||||
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []input.Input{{Token: 1}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Image Prefix",
|
||||
t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
|
||||
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
|
||||
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Mixed, Same Length",
|
||||
t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
|
||||
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
t1: []model.Input{},
|
||||
t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Both Empty",
|
||||
t1: []model.Input{},
|
||||
t2: []model.Input{},
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []model.Input
|
||||
prompt []input.Input
|
||||
longest expected
|
||||
best expected
|
||||
}{
|
||||
|
@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{},
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{},
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []model.Input{{Token: 1}},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 0, len: 0},
|
||||
},
|
||||
|
@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{{Token: 1}},
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []model.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 2},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
|
@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{},
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []model.Input{{Token: 2}},
|
||||
prompt: []input.Input{{Token: 2}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
|
@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{},
|
||||
Inputs: []input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []model.Input{{Token: 1}},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 1},
|
||||
best: expected{result: 1, len: 1},
|
||||
},
|
||||
|
@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{{Token: 1}},
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []model.Input{{Token: 2}, {Token: 3}},
|
||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
|
@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []model.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []model.Input{{Token: 1}},
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []model.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 1},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
|
|
|
@ -26,6 +26,7 @@ import (
|
|||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/runner/common"
|
||||
"github.com/ollama/ollama/sample"
|
||||
|
||||
|
@ -41,10 +42,10 @@ type Sequence struct {
|
|||
iBatch int
|
||||
|
||||
// prompt inputs left to evaluate
|
||||
inputs []model.Input
|
||||
inputs []input.Input
|
||||
|
||||
// inputs that have been added to a batch but not yet submitted to Forward
|
||||
pendingInputs []model.Input
|
||||
pendingInputs []input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
|
@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
|||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
|
||||
var inputs []model.Input
|
||||
func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
|
||||
var inputs []input.Input
|
||||
var parts []string
|
||||
var matches [][]string
|
||||
|
||||
|
@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
|
|||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
inputs = append(inputs, model.Input{Token: t})
|
||||
inputs = append(inputs, input.Input{Token: t})
|
||||
}
|
||||
|
||||
// image - decode and store
|
||||
|
@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
|
|||
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
|
||||
imageHash := s.multimodalHash.Sum64()
|
||||
|
||||
inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
postTokenize = true
|
||||
}
|
||||
}
|
||||
|
@ -250,9 +251,6 @@ type Server struct {
|
|||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
// multimodalHash generates hashes for comparing equality
|
||||
// of non-text data
|
||||
multimodalHash maphash.Hash
|
||||
|
@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var options model.Options
|
||||
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
var options input.Options
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, "limit")
|
||||
s.removeSequence(i, "limit")
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []model.Input{}
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
}
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
for j, inp := range seq.inputs {
|
||||
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
|
@ -363,33 +357,23 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
}
|
||||
|
||||
if i >= s.batchSize {
|
||||
if j >= s.batchSize {
|
||||
break
|
||||
}
|
||||
|
||||
// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
|
||||
// to the encoder cache.
|
||||
//
|
||||
// Break the batch when switching from text to images so that images are always at the beginning.
|
||||
if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
|
||||
(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
|
||||
s.nextSeq = seqIdx
|
||||
break
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, input.Token)
|
||||
if input.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
}
|
||||
|
||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(options.Outputs)
|
||||
if i+1 == len(seq.inputs) {
|
||||
if j+1 == len(seq.inputs) {
|
||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
|
@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
|
|||
// After calling Forward, pending inputs are now in the cache
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
seq.pendingInputs = []model.Input{}
|
||||
seq.pendingInputs = []input.Input{}
|
||||
}
|
||||
|
||||
// don't sample prompt processing
|
||||
|
@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
|
|||
return err
|
||||
}
|
||||
|
||||
seq.inputs = []model.Input{{Token: token}}
|
||||
seq.inputs = []input.Input{{Token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue