Allow models to force a new batch

This is useful for a few things:
 - Work around bugs, such as two images ending up in one batch
 - Keep an image in a single batch for fully connected attention
 - Improve performance by not evaluating embeddings multiple times
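
For illustration (an editor's sketch, not part of the commit), a model's PostTokenize might mark the token that introduces an image so the runner opens a fresh batch there; the token IDs, imgEmbedding, and imgHash below are hypothetical:

    // Hypothetical sketch of marked inputs: the image marker token sets
    // BatchBreak so the runner starts a new batch before the image.
    inputs := []input.Input{
        {Token: 10},                   // ordinary text token
        {Token: 42, BatchBreak: true}, // image marker: a new batch starts here
        {Multimodal: imgEmbedding, MultimodalHash: imgHash},
    }
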
Jesse Gross, 2025-03-10 20:03:29 -07:00 (committed by Michael Yang)
parent a8e83a7654
commit 06007c0a18
4 changed files with 10 additions and 14 deletions

@@ -15,6 +15,12 @@ type Input struct {
 	// stored in Multimodal, used for caching and comparing
 	// equality.
 	MultimodalHash uint64
+
+	// BatchBreak forces a new batch to be started with this
+	// input. For example, this can be used to align images
+	// with batches. Note that batches may be divided in additional
+	// locations as well.
+	BatchBreak bool
 }
 
 // MultimodalIndex is a multimodal element (such as an image)

@@ -112,8 +112,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 			result = append(result, inp)
 		} else {
 			imageInputs := []input.Input{
 				{Token: 108},    // "\n\n"
-				{Token: 255999}, // "<start_of_image>"
+				{Token: 255999, BatchBreak: true}, // "<start_of_image>"
 			}
 
 			result = append(result, imageInputs...)

@@ -363,7 +363,7 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if j >= s.batchSize {
+			if j >= s.batchSize || (inp.BatchBreak && len(seq.pendingInputs) != 0) {
 				break
 			}
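
As a self-contained sketch of the splitting rule above (simplified to a single flat input list, ignoring the cache and per-sequence state), the hypothetical splitBatches helper below mirrors the new condition, including the len(seq.pendingInputs) != 0 guard that keeps a BatchBreak input from splitting a batch that is still empty:

    package main

    import "fmt"

    // Input mirrors the two fields of input.Input that matter here.
    type Input struct {
        Token      int32
        BatchBreak bool
    }

    // splitBatches divides inputs into batches of at most batchSize,
    // additionally starting a new batch before any input marked
    // BatchBreak unless that input would begin the batch anyway.
    func splitBatches(inputs []Input, batchSize int) [][]Input {
        var batches [][]Input
        var pending []Input
        for _, inp := range inputs {
            if len(pending) >= batchSize || (inp.BatchBreak && len(pending) != 0) {
                batches = append(batches, pending)
                pending = nil
            }
            pending = append(pending, inp)
        }
        if len(pending) != 0 {
            batches = append(batches, pending)
        }
        return batches
    }

    func main() {
        inputs := []Input{
            {Token: 1},
            {Token: 2},
            {Token: 3, BatchBreak: true}, // e.g. "<start_of_image>"
            {Token: 4},
        }
        fmt.Println(splitBatches(inputs, 8))
        // Prints: [[{1 false} {2 false}] [{3 true} {4 false}]]
    }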

@@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	var system []api.Message
 
 	isMllama := checkMllamaModelFamily(m)
-	isGemma3 := checkGemma3ModelFamily(m)
 
 	var imageNumTokens int
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
-		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
+		if isMllama && len(msgs[i].Images) > 1 {
 			return "", nil, errTooManyImages
 		}
@@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool {
 		}
 	}
 	return false
 }
-
-func checkGemma3ModelFamily(m *Model) bool {
-	for _, arch := range m.Config.ModelFamilies {
-		if arch == "gemma3" {
-			return true
-		}
-	}
-	return false
-}