diff --git a/model/input/input.go b/model/input/input.go
index ce43efb58..d66f52a0d 100644
--- a/model/input/input.go
+++ b/model/input/input.go
@@ -1,5 +1,7 @@
 package input
 
+import "github.com/ollama/ollama/ml"
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
@@ -36,7 +38,7 @@ type MultimodalIndex struct {
 // Batch contains the inputs for a model forward pass
 type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
-	Inputs []int32
+	Inputs ml.Tensor
 
 	// Multimodal is a set of multimodal embeddings previously created by
 	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
diff --git a/model/model.go b/model/model.go
index 94156ae2b..ab29916ab 100644
--- a/model/model.go
+++ b/model/model.go
@@ -280,7 +280,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
 	if len(batch.Positions) != len(batch.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
@@ -289,6 +289,12 @@ func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
+	var err error
+	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	if err != nil {
+		return nil, err
+	}
+
 	cache := m.Config().Cache
 	if cache != nil {
 		err := cache.StartForward(ctx, batch)
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 2b347d72c..67c69ee86 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -169,11 +169,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -184,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 900bf31e6..567ad1a45 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -140,11 +140,6 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index e5ecd29ed..5c173997b 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -140,11 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 6d9c608e9..988a189d4 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -144,11 +144,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		}
 	}
 
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -160,7 +155,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 91463f93f..443b34b05 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -348,6 +348,7 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
+	var batchInputs []int32
 	var batch input.Batch
 
 	for i, seq := range s.seqs {
@@ -395,9 +396,9 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			batch.Inputs = append(batch.Inputs, inp.Token)
+			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -405,7 +406,7 @@ func (s *Server) processBatch() error {
 
 			seq.iBatch = len(batch.Outputs)
 			if j+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
+				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
-	if len(batch.Inputs) == 0 {
+	if len(batchInputs) == 0 {
 		return nil
 	}
 
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
-	modelOutput, err := model.Forward(ctx, s.model, batch)
+	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
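
Note on the pattern this patch establishes: Batch.Inputs changes from []int32 to ml.Tensor, and the per-model boilerplate that converted the token slice into a tensor is hoisted into the shared model.Forward entry point, so the input tensor is built exactly once per forward pass. A condensed sketch of the resulting division of labor, with names taken from the patch (cache setup, batch validation, and output handling elided for brevity):

	// Condensed sketch, not the full function body: model.Forward now owns
	// the slice-to-tensor conversion for every model.
	func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
		var err error
		// Materialize the raw token IDs into a tensor once, up front.
		batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
		if err != nil {
			return nil, err
		}
		// ... cache.StartForward and related bookkeeping elided ...
		// Each model's Forward then consumes batch.Inputs directly, e.g.
		// m.TokenEmbedding.Forward(ctx, batch.Inputs), instead of calling
		// ctx.Input().FromIntSlice itself.
		return m.Forward(ctx, batch)
	}

One consequence visible in the runner changes above: once Inputs is a tensor, the runner can no longer use len(batch.Inputs) for bookkeeping, so processBatch keeps the raw token IDs in a local batchInputs slice for MultimodalIndex and Outputs indexing and for the empty-batch check, and passes that slice to model.Forward alongside the batch.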