mirror of
https://github.com/ollama/ollama.git
synced 2025-05-17 15:04:26 +02:00
* get eos_token_id from generation_config.json * refactor * include both ids and strings in trace * comments * remove special case for gemma3 special vocab (#10743)
132 lines
3.6 KiB
Go
132 lines
3.6 KiB
Go
package mllama
|
|
|
|
import (
|
|
"bytes"
|
|
"image"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
type Model struct {
|
|
model.Base
|
|
model.BytePairEncoding
|
|
|
|
*VisionModel `gguf:"v,vision"`
|
|
*TextModel
|
|
|
|
Projector *nn.Linear `gguf:"mm.0"`
|
|
|
|
ImageProcessor
|
|
}
|
|
|
|
// Layer kind identifiers. They are not referenced in this file; presumably
// the text model elsewhere in the package uses them to tell the two decoder
// layer kinds apart (confirm against the text model source).
const (
	// crossAttentionLayer has the value 0.
	crossAttentionLayer = iota
	// selfAttentionLayer has the value 1.
	selfAttentionLayer
)
|
|
|
|
func New(c fs.Config) (model.Model, error) {
|
|
m := Model{
|
|
BytePairEncoding: model.NewBytePairEncoding(
|
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
&model.Vocabulary{
|
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
|
EOS: append(
|
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
|
),
|
|
},
|
|
),
|
|
ImageProcessor: newImageProcessor(c),
|
|
VisionModel: newVisionModel(c),
|
|
TextModel: newTextModel(c),
|
|
}
|
|
|
|
encoderCache := kvcache.NewEncoderCache()
|
|
encoderCache.SetConfig(ml.CacheConfig{})
|
|
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
|
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
|
|
return nil, model.ErrNoVisionModel
|
|
}
|
|
|
|
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ratio.numTiles() < m.maxNumTiles {
|
|
// Pad tiles to maxNumTiles
|
|
f32s = slices.Grow(f32s, m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles)
|
|
f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
|
|
}
|
|
|
|
pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
|
|
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
|
projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates)
|
|
|
|
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
|
}
|
|
|
|
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|
for i := range inputs {
|
|
if inputs[i].Multimodal != nil {
|
|
inputs[i].Token = 128256 // <|image|>
|
|
}
|
|
}
|
|
|
|
return inputs, nil
|
|
}
|
|
|
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|
var crossAttentionStates ml.Tensor
|
|
if len(batch.Multimodal) > 0 {
|
|
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
|
|
}
|
|
|
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// TODO: attention mask, cross attention mask
|
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
|
}
|
|
|
|
func init() {
|
|
model.Register("mllama", New)
|
|
}
|