model: Pass input tensor instead of raw data to models

Rather than giving the raw input data to models directly, we can
pass a tensor. In the short term, this saves some duplicated
code.

Longer term, we will want to overlap setting up the next batch with
processing of the current one. In that case, we will have only the
shape of the tensor; it will not yet be loaded with data at the time
of graph generation. By passing only a tensor to models now, we set
up this possibility and prevent them from relying on data that they
won't have in the future.
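
To make that contract concrete, here is a hypothetical model-side
sketch (not code from this change; the Model fields and the Forward
signature are illustrative assumptions):

    // Hypothetical sketch: batch.Inputs arrives as an ml.Tensor whose
    // shape is known when the graph is generated, but whose data may
    // not have been copied in yet.
    func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    	// Fine: the lookup becomes a node in the compute graph and is
    	// evaluated later, once the data has actually been loaded.
    	hidden := m.TokenEmbedding.Forward(ctx, batch.Inputs)

    	// Not fine: reading element values at graph-generation time
    	// would rely on data the runner no longer guarantees is there.
    	// vals := batch.Inputs.Floats()

    	return m.Output.Forward(ctx, hidden), nil
    }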

Although the same could be done for Positions and Outputs, in some
cases we need the raw input data and in others we don't use them at
all. Therefore, for now we leave them as they are and let models
convert them to tensors as needed.
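
For example, a model that wants positions as a tensor can convert
them on demand; a minimal sketch, assuming a FromIntSlice-style
constructor on the input context:

    // Sketch only: Positions stays a []int32 on input.Batch, so the
    // model builds the tensor itself. FromIntSlice and Input() are
    // assumptions about the ml API, not part of this change.
    positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
    if err != nil {
    	return nil, err
    }
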
Jesse Gross committed 2025-03-19 14:36:21 -07:00
parent 0c220935bd
commit 0fbfcf3c9c
7 changed files with 20 additions and 31 deletions


@@ -348,6 +348,7 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
+	var batchInputs []int32
 	var batch input.Batch
 
 	for i, seq := range s.seqs {
@@ -395,9 +396,9 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			batch.Inputs = append(batch.Inputs, inp.Token)
+			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -405,7 +406,7 @@ func (s *Server) processBatch() error {
 
 			seq.iBatch = len(batch.Outputs)
 			if j+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
+				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
-	if len(batch.Inputs) == 0 {
+	if len(batchInputs) == 0 {
 		return nil
 	}
 
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
-	modelOutput, err := model.Forward(ctx, s.model, batch)
+	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
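
For context, the new call site above implies that model.Forward now
takes the raw token IDs alongside the batch and creates the input
tensor itself. A minimal sketch of what that entry point might look
like (the body and the FromIntSlice helper are assumptions, not the
actual implementation):

    // Sketch of the entry point implied by the call above: the runner
    // keeps the raw token IDs ([]int32) for bookkeeping, and the
    // tensor is built in one place before the model runs.
    func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
    	if len(inputs) == 0 {
    		return nil, errors.New("no inputs in batch")
    	}

    	var err error
    	// Models downstream now see batch.Inputs only as an ml.Tensor.
    	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
    	if err != nil {
    		return nil, err
    	}

    	return m.Forward(ctx, batch)
    }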