input: Rename Options to Batch

Options is no longer very descriptive of this struct.
This commit is contained in:
Jesse Gross 2025-03-19 14:28:15 -07:00 committed by Jesse Gross
parent ffbfe833da
commit 0c220935bd
14 changed files with 73 additions and 60 deletions

View file

@ -52,7 +52,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass. // StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding // For each token in the coming batch, there must be a corresponding
// entry in positions and seqs. // entry in positions and seqs.
StartForward(ctx ml.Context, opts input.Options) error StartForward(ctx ml.Context, batch input.Batch) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32) CopyPrefix(srcSeq, dstSeq int, len int32)

View file

@ -140,10 +140,10 @@ func (c *Causal) Close() {
} }
} }
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(opts.Positions) c.curBatchSize = len(batch.Positions)
c.curSequences = opts.Sequences c.curSequences = batch.Sequences
c.curPositions = opts.Positions c.curPositions = batch.Positions
c.opts.Except = nil c.opts.Except = nil
var err error var err error
@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
} }
c.curCellRange = newRange() c.curCellRange = newRange()
for i, pos := range opts.Positions { for i, pos := range batch.Positions {
seq := opts.Sequences[i] seq := batch.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}

View file

@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext() context := backend.NewContext()
defer context.Close() defer context.Close()
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
if err != nil { if err != nil {
panic(err) panic(err)
} }

View file

@ -79,10 +79,10 @@ func (c *EncoderCache) Close() {
} }
} }
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
// We work with the most recent image // We work with the most recent image
if len(opts.Multimodal) > 0 { if len(batch.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
} }
return nil return nil

View file

@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
} }
} }
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
for i, cache := range c.caches { for i, cache := range c.caches {
err := cache.StartForward(ctx, opts) err := cache.StartForward(ctx, batch)
if err != nil { if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- { for j := i - 1; j >= 0; j-- {
for k := range opts.Positions { for k := range batch.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
} }
} }
return err return err

View file

@ -33,11 +33,24 @@ type MultimodalIndex struct {
Multimodal any Multimodal any
} }
// Options contains the inputs for a model forward pass // Batch contains the inputs for a model forward pass
type Options struct { type Batch struct {
Inputs []int32 // Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs []int32
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex Multimodal []MultimodalIndex
Positions []int32
Sequences []int // Positions is the position for each Input, relative to its sequence. Equal
Outputs []int32 // in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
} }

View file

@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface { type Model interface {
Forward(ml.Context, input.Options) (ml.Tensor, error) Forward(ml.Context, input.Batch) (ml.Tensor, error)
Backend() ml.Backend Backend() ml.Backend
Config() config Config() config
@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice t.Kind() == reflect.Slice
} }
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) { if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
} }
if len(opts.Positions) < 1 { if len(batch.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1") return nil, errors.New("batch size cannot be less than 1")
} }
cache := m.Config().Cache cache := m.Config().Cache
if cache != nil { if cache != nil {
err := cache.StartForward(ctx, opts) err := cache.StartForward(ctx, batch)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
t, err := m.Forward(ctx, opts) t, err := m.Forward(ctx, batch)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{} type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
panic("unimplemented") panic("unimplemented")
} }

View file

@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil return result, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
} }
func init() { func init() {

View file

@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs) hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize))) hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings // set image embeddings
var except []int var except []int
for _, image := range opts.Multimodal { for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor) visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

View file

@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual) return hiddenState.Add(ctx, residual)
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil return inputs, nil
} }
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor var crossAttentionStates ml.Tensor
if len(opts.Multimodal) > 0 { if len(batch.Multimodal) > 0 {
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor) images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 { if len(images) > 0 {
crossAttentionStates = images[len(images)-1] crossAttentionStates = images[len(images)-1]
} }
} }
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions)) positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil { if err != nil {
return nil, err return nil, err
} }
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs)) outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -348,7 +348,7 @@ func (s *Server) processBatch() error {
} }
defer s.mu.Unlock() defer s.mu.Unlock()
var options input.Options var batch input.Batch
for i, seq := range s.seqs { for i, seq := range s.seqs {
if seq == nil { if seq == nil {
@ -395,17 +395,17 @@ func (s *Server) processBatch() error {
} }
} }
options.Inputs = append(options.Inputs, inp.Token) batch.Inputs = append(batch.Inputs, inp.Token)
if inp.Multimodal != nil { if inp.Multimodal != nil {
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
} }
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id) batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(options.Outputs) seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) { if j+1 == len(seq.inputs) {
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
} }
seq.pendingInputs = append(seq.pendingInputs, inp) seq.pendingInputs = append(seq.pendingInputs, inp)
} }
@ -413,14 +413,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):] seq.inputs = seq.inputs[len(seq.pendingInputs):]
} }
if len(options.Inputs) == 0 { if len(batch.Inputs) == 0 {
return nil return nil
} }
ctx := s.model.Backend().NewContext() ctx := s.model.Backend().NewContext()
defer ctx.Close() defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, options) modelOutput, err := model.Forward(ctx, s.model, batch)
if err != nil { if err != nil {
return fmt.Errorf("failed to decode batch: %w", err) return fmt.Errorf("failed to decode batch: %w", err)
} }
@ -460,7 +460,7 @@ func (s *Server) processBatch() error {
} }
// sample a token // sample a token
vocabSize := len(logits) / len(options.Outputs) vocabSize := len(logits) / len(batch.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil { if err != nil {