From 0fbfcf3c9c7bfdbf4616238595eafd7eca2a916c Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Wed, 19 Mar 2025 14:36:21 -0700 Subject: [PATCH] model: Pass input tensor instead of raw data to models Rather than directly giving the input data to models, we can pass a tensor instead. In the short term, this saves some duplicated code. Longer term, we will want to overlap setting up the next batch with processing of the current one. In this case, we will only have the shape of tensor but it will not be loaded with data at the time of graph generation. By passing only a tensor to models now, we set up this possibility and prevent them from relying on data that they won't have in the future. Although the same could be done for Positions and Outputs, in some cases we either need the raw input data or don't use them at all. Therefore, for now we leave them as they are and allow models to convert them to tensors as needed. --- model/input/input.go | 4 +++- model/model.go | 8 +++++++- model/models/gemma2/model.go | 7 +------ model/models/gemma3/model.go | 7 +------ model/models/llama/model.go | 7 +------ model/models/mllama/model.go | 7 +------ runner/ollamarunner/runner.go | 11 ++++++----- 7 files changed, 20 insertions(+), 31 deletions(-) diff --git a/model/input/input.go b/model/input/input.go index ce43efb58..d66f52a0d 100644 --- a/model/input/input.go +++ b/model/input/input.go @@ -1,5 +1,7 @@ package input +import "github.com/ollama/ollama/ml" + // Input represents one token in the input stream type Input struct { // Token is a single element of text. @@ -36,7 +38,7 @@ type MultimodalIndex struct { // Batch contains the inputs for a model forward pass type Batch struct { // Inputs is the input tokens, including placeholders for multimodal inputs. - Inputs []int32 + Inputs ml.Tensor // Multimodal is a set of multimodal embeddings previously created by // EncodeMultimodal, along with an index into Inputs. Unused for text-only diff --git a/model/model.go b/model/model.go index 94156ae2b..ab29916ab 100644 --- a/model/model.go +++ b/model/model.go @@ -280,7 +280,7 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) { if len(batch.Positions) != len(batch.Sequences) { return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences)) } @@ -289,6 +289,12 @@ func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) { return nil, errors.New("batch size cannot be less than 1") } + var err error + batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs)) + if err != nil { + return nil, err + } + cache := m.Config().Cache if cache != nil { err := cache.StartForward(ctx, batch) diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 2b347d72c..67c69ee86 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -169,11 +169,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) - if err != nil { - return nil, err - } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err @@ -184,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize))) if len(m.Layers) == gemma27BLayerCount { diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go index 900bf31e6..567ad1a45 100644 --- a/model/models/gemma3/model.go +++ b/model/models/gemma3/model.go @@ -140,11 +140,6 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) { } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) - if err != nil { - return nil, err - } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err @@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil } func init() { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index e5ecd29ed..5c173997b 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -140,11 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten } func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { - inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) - if err != nil { - return nil, err - } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err @@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { return nil, err } - hiddenState := m.TokenEmbedding.Forward(ctx, inputs) + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) for i, layer := range m.Layers { m.Cache.SetLayer(i) diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 6d9c608e9..988a189d4 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -144,11 +144,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } } - inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs)) - if err != nil { - return nil, err - } - positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) if err != nil { return nil, err @@ -160,7 +155,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { } // TODO: attention mask, cross attention mask - return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil } func init() { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 91463f93f..443b34b05 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -348,6 +348,7 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() + var batchInputs []int32 var batch input.Batch for i, seq := range s.seqs { @@ -395,9 +396,9 @@ func (s *Server) processBatch() error { } } - batch.Inputs = append(batch.Inputs, inp.Token) + batchInputs = append(batchInputs, inp.Token) if inp.Multimodal != nil { - batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal}) + batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal}) } batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) @@ -405,7 +406,7 @@ func (s *Server) processBatch() error { seq.iBatch = len(batch.Outputs) if j+1 == len(seq.inputs) { - batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1)) + batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1)) } seq.pendingInputs = append(seq.pendingInputs, inp) } @@ -413,14 +414,14 @@ func (s *Server) processBatch() error { seq.inputs = seq.inputs[len(seq.pendingInputs):] } - if len(batch.Inputs) == 0 { + if len(batchInputs) == 0 { return nil } ctx := s.model.Backend().NewContext() defer ctx.Close() - modelOutput, err := model.Forward(ctx, s.model, batch) + modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch) if err != nil { return fmt.Errorf("failed to decode batch: %w", err) }