mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
input: Rename Options to Batch
Options is no longer very descriptive of this struct.
This commit is contained in:
parent
ffbfe833da
commit
0c220935bd
14 changed files with 73 additions and 60 deletions
|
@ -52,7 +52,7 @@ type Cache interface {
|
|||
// StartForward is called before the start of the model's forward pass.
|
||||
// For each token in the coming batch, there must be a corresponding
|
||||
// entry in positions and seqs.
|
||||
StartForward(ctx ml.Context, opts input.Options) error
|
||||
StartForward(ctx ml.Context, batch input.Batch) error
|
||||
|
||||
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
|
||||
CopyPrefix(srcSeq, dstSeq int, len int32)
|
||||
|
|
|
@ -140,10 +140,10 @@ func (c *Causal) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
c.curBatchSize = len(opts.Positions)
|
||||
c.curSequences = opts.Sequences
|
||||
c.curPositions = opts.Positions
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
var err error
|
||||
|
@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
|
|||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range opts.Positions {
|
||||
seq := opts.Sequences[i]
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
|
|
|
@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
|||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
|
||||
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -79,10 +79,10 @@ func (c *EncoderCache) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
// We work with the most recent image
|
||||
if len(opts.Multimodal) > 0 {
|
||||
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
|
||||
if len(batch.Multimodal) > 0 {
|
||||
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
for i, cache := range c.caches {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch)
|
||||
if err != nil {
|
||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
for k := range opts.Positions {
|
||||
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
|
||||
for k := range batch.Positions {
|
||||
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
|
||||
}
|
||||
}
|
||||
return err
|
||||
|
|
|
@ -33,11 +33,24 @@ type MultimodalIndex struct {
|
|||
Multimodal any
|
||||
}
|
||||
|
||||
// Options contains the inputs for a model forward pass
|
||||
type Options struct {
|
||||
Inputs []int32
|
||||
// Batch contains the inputs for a model forward pass
|
||||
type Batch struct {
|
||||
// Inputs is the input tokens, including placeholders for multimodal inputs.
|
||||
Inputs []int32
|
||||
|
||||
// Multimodal is a set of multimodal embeddings previously created by
|
||||
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
|
||||
// models or for batches without multimodal elements.
|
||||
Multimodal []MultimodalIndex
|
||||
Positions []int32
|
||||
Sequences []int
|
||||
Outputs []int32
|
||||
|
||||
// Positions is the position for each Input, relative to its sequence. Equal
|
||||
// in length to Inputs.
|
||||
Positions []int32
|
||||
|
||||
// Sequences is the sequence for each Input. Equal in length to Inputs.
|
||||
Sequences []int
|
||||
|
||||
// Outputs are the set of indicies into Inputs for which output data should
|
||||
// be returned.
|
||||
Outputs []int32
|
||||
}
|
||||
|
|
|
@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
|
|||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
||||
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
|
@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
|
|||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
||||
if len(opts.Positions) != len(opts.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(opts.Positions) < 1 {
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, opts)
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
|
|||
|
||||
type notTextProcessorModel struct{}
|
||||
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
|
||||
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
|
|
|
@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
var except []int
|
||||
for _, image := range opts.Multimodal {
|
||||
for _, image := range batch.Multimodal {
|
||||
visionOutputs := image.Multimodal.(ml.Tensor)
|
||||
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
|
||||
|
||||
|
|
|
@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
|||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
return inputs, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
var crossAttentionStates ml.Tensor
|
||||
if len(opts.Multimodal) > 0 {
|
||||
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||
if len(batch.Multimodal) > 0 {
|
||||
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
|
||||
if len(images) > 0 {
|
||||
crossAttentionStates = images[len(images)-1]
|
||||
}
|
||||
}
|
||||
|
||||
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
|
||||
inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -348,7 +348,7 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var options input.Options
|
||||
var batch input.Batch
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
|
@ -395,17 +395,17 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
}
|
||||
|
||||
options.Inputs = append(options.Inputs, inp.Token)
|
||||
batch.Inputs = append(batch.Inputs, inp.Token)
|
||||
if inp.Multimodal != nil {
|
||||
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
|
||||
}
|
||||
|
||||
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
options.Sequences = append(options.Sequences, seq.cache.Id)
|
||||
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
||||
batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
||||
|
||||
seq.iBatch = len(options.Outputs)
|
||||
seq.iBatch = len(batch.Outputs)
|
||||
if j+1 == len(seq.inputs) {
|
||||
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
|
||||
}
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
|
@ -413,14 +413,14 @@ func (s *Server) processBatch() error {
|
|||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
if len(options.Inputs) == 0 {
|
||||
if len(batch.Inputs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
modelOutput, err := model.Forward(ctx, s.model, options)
|
||||
modelOutput, err := model.Forward(ctx, s.model, batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
@ -460,7 +460,7 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(options.Outputs)
|
||||
vocabSize := len(logits) / len(batch.Outputs)
|
||||
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
if err != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue