input: Rename Options to Batch

Options is no longer very descriptive of this struct.
Jesse Gross 2025-03-19 14:28:15 -07:00 committed by Jesse Gross
parent ffbfe833da
commit 0c220935bd
14 changed files with 73 additions and 60 deletions


@@ -33,11 +33,24 @@ type MultimodalIndex struct {
 	Multimodal any
 }
 
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
-	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
+// Batch contains the inputs for a model forward pass
+type Batch struct {
+	// Inputs is the input tokens, including placeholders for multimodal inputs.
+	Inputs []int32
+
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
+
+	// Positions is the position for each Input, relative to its sequence. Equal
+	// in length to Inputs.
+	Positions []int32
+
+	// Sequences is the sequence for each Input. Equal in length to Inputs.
+	Sequences []int
+
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
+	Outputs []int32
 }
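
For orientation, here is a minimal sketch (not part of this commit) of how a caller might populate the renamed struct for a text-only batch. The token IDs are invented and the helper name is hypothetical; the field semantics follow the comments above, and the import path is assumed to be the repository's model/input package.

import "github.com/ollama/ollama/model/input"

// textOnlyBatch builds a hypothetical two-sequence batch: three tokens total,
// with outputs requested for the last token of each sequence. Token IDs are
// placeholders.
func textOnlyBatch() input.Batch {
	return input.Batch{
		Inputs:    []int32{4521, 892, 17},
		Positions: []int32{0, 1, 0}, // position of each token within its own sequence
		Sequences: []int{0, 0, 1},   // sequence ID of each token
		Outputs:   []int32{1, 2},    // indices into Inputs whose outputs should be returned
		// Multimodal stays nil for a text-only batch.
	}
}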


@@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, input.Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Batch) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
-	if len(opts.Positions) != len(opts.Sequences) {
-		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
+	if len(batch.Positions) != len(batch.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
 
-	if len(opts.Positions) < 1 {
+	if len(batch.Positions) < 1 {
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			return nil, err
 		}
 	}
 
-	t, err := m.Forward(ctx, opts)
+	t, err := m.Forward(ctx, batch)
 	if err != nil {
 		return nil, err
 	}
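
A similarly hedged sketch of the renamed entry point in use; ctx and m stand in for an already-initialized ml.Context and model.Model, whose setup is outside this diff, and the import paths are assumed from the packages shown here.

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// decodeOnce is a hypothetical helper: it runs one forward pass over a
// single-sequence batch of placeholder tokens, requesting output only for the
// last token. Positions and Sequences must be the same length (and non-empty)
// or the wrapper above returns an error; the struct's documentation also
// expects both to match Inputs in length.
func decodeOnce(ctx ml.Context, m model.Model) (ml.Tensor, error) {
	return model.Forward(ctx, m, input.Batch{
		Inputs:    []int32{4521, 892, 17},
		Positions: []int32{0, 1, 2},
		Sequences: []int{0, 0, 0},
		Outputs:   []int32{2},
	})
}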


@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
 	panic("unimplemented")
 }


@@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}


@@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return result, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {


@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	// set image embeddings
 	var except []int
-	for _, image := range opts.Multimodal {
+	for _, image := range batch.Multimodal {
 		visionOutputs := image.Multimodal.(ml.Tensor)
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))


@@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}


@@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if len(opts.Multimodal) > 0 {
-		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+	if len(batch.Multimodal) > 0 {
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
 		if len(images) > 0 {
 			crossAttentionStates = images[len(images)-1]
 		}
 	}
 
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}