diff --git a/kvcache/cache.go b/kvcache/cache.go index d35489057..aa0a20562 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -52,7 +52,7 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding // entry in positions and seqs. - StartForward(ctx ml.Context, opts input.Options) error + StartForward(ctx ml.Context, batch input.Batch) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index edf6666da..79fa24e87 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -140,10 +140,10 @@ func (c *Causal) Close() { } } -func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { - c.curBatchSize = len(opts.Positions) - c.curSequences = opts.Sequences - c.curPositions = opts.Positions +func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error { + c.curBatchSize = len(batch.Positions) + c.curSequences = batch.Sequences + c.curPositions = batch.Positions c.opts.Except = nil var err error @@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { } c.curCellRange = newRange() - for i, pos := range opts.Positions { - seq := opts.Sequences[i] + for i, pos := range batch.Positions { + seq := batch.Sequences[i] c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 56d85ceb6..0f2385db7 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) + err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}) if err != nil { panic(err) } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 6a9df2abc..94c5d99c3 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -79,10 +79,10 @@ func (c *EncoderCache) Close() { } } -func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { +func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error { // We work with the most recent image - if len(opts.Multimodal) > 0 { - c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] + if len(batch.Multimodal) > 0 { + c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index] } return nil diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index aaccd1661..c85807a04 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -41,14 +41,14 @@ func (c *WrapperCache) Close() { } } -func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { +func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, opts) + err := cache.StartForward(ctx, batch) if err != nil { // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { - for k := range opts.Positions { - _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) + for k := range batch.Positions { + _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32) } } return err diff --git a/model/input/input.go b/model/input/input.go index 30bdcf065..ce43efb58 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -33,11 +33,24 @@ type MultimodalIndex struct { Multimodal any } -// Options contains the inputs for a model forward pass -type Options struct { - Inputs []int32 +// Batch contains the inputs for a model forward pass +type Batch struct { + // Inputs is the input tokens, including placeholders for multimodal inputs. + Inputs []int32 + + // Multimodal is a set of multimodal embeddings previously created by + // EncodeMultimodal, along with an index into Inputs. Unused for text-only + // models or for batches without multimodal elements. Multimodal []MultimodalIndex - Positions []int32 - Sequences []int - Outputs []int32 + + // Positions is the position for each Input, relative to its sequence. Equal + // in length to Inputs. + Positions []int32 + + // Sequences is the sequence for each Input. Equal in length to Inputs. + Sequences []int + + // Outputs are the set of indicies into Inputs for which output data should + // be returned. + Outputs []int32 } diff --git a/model/model.go b/model/model.go index 53e47add9..94156ae2b 100644 --- a/model/model.go +++ b/model/model.go @@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { - Forward(ml.Context, input.Options) (ml.Tensor, error) + Forward(ml.Context, input.Batch) (ml.Tensor, error) Backend() ml.Backend Config() config @@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { - if len(opts.Positions) != len(opts.Sequences) { - return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) +func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { + if len(batch.Positions) != len(batch.Sequences) { + return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) } - if len(opts.Positions) < 1 { + if len(batch.Positions) < 1 { return nil, errors.New("batch size cannot be less than 1") } cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, opts) + err := cache.StartForward(ctx, batch) if err != nil { return nil, err } } - t, err := m.Forward(ctx, opts) + t, err := m.Forward(ctx, batch) if err != nil { return nil, err } diff --git a/model/model_test.go b/model/model_test.go index 354dd1d8b..0b1ea08e8 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) { type notTextProcessorModel struct{} -func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { +func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) { panic("unimplemented") } diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 29ffa2318..2b347d72c 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) if err != nil { return nil, err } - positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err } - outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) if err != nil { return nil, err } diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 95f89ad48..900bf31e6 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { return result, nil } -func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) if err != nil { return nil, err } - positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err } - outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) if err != nil { return nil, err } - return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil + return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 567f65a5e..7d8b6577e 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, return hiddenState.Add(ctx, residual) } -func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) // set image embeddings var except []int - for _, image := range opts.Multimodal { + for _, image := range batch.Multimodal { visionOutputs := image.Multimodal.(ml.Tensor) ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 87eb9b750..e5ecd29ed 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) if err != nil { return nil, err } - positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err } - outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) if err != nil { return nil, err } diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 0aa11f178..6d9c608e9 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { return inputs, nil } -func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { var crossAttentionStates ml.Tensor - if len(opts.Multimodal) > 0 { - images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) + if len(batch.Multimodal) > 0 { + images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor) if len(images) > 0 { crossAttentionStates = images[len(images)-1] } } - inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) + inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) if err != nil { return nil, err } - positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err } - outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) if err != nil { return nil, err } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 9a1a549cd..91463f93f 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -348,7 +348,7 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() - var options input.Options + var batch input.Batch for i, seq := range s.seqs { if seq == nil { @@ -395,17 +395,17 @@ func (s *Server) processBatch() error { } } - options.Inputs = append(options.Inputs, inp.Token) + batch.Inputs = append(batch.Inputs, inp.Token) if inp.Multimodal != nil { - options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) + batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal}) } - options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) - options.Sequences = append(options.Sequences, seq.cache.Id) + batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) + batch.Sequences = append(batch.Sequences, seq.cache.Id) - seq.iBatch = len(options.Outputs) + seq.iBatch = len(batch.Outputs) if j+1 == len(seq.inputs) { - options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) + batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1)) } seq.pendingInputs = append(seq.pendingInputs, inp) } @@ -413,14 +413,14 @@ func (s *Server) processBatch() error { seq.inputs = seq.inputs[len(seq.pendingInputs):] } - if len(options.Inputs) == 0 { + if len(batch.Inputs) == 0 { return nil } ctx := s.model.Backend().NewContext() defer ctx.Close() - modelOutput, err := model.Forward(ctx, s.model, options) + modelOutput, err := model.Forward(ctx, s.model, batch) if err != nil { return fmt.Errorf("failed to decode batch: %w", err) } @@ -460,7 +460,7 @@ func (s *Server) processBatch() error { } // sample a token - vocabSize := len(logits) / len(options.Outputs) + vocabSize := len(logits) / len(batch.Outputs) token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) if err != nil {