// Mirror of https://github.com/ollama/ollama.git
// Synced 2025-05-11 18:36:41 +02:00
// 100 lines, 2.6 KiB, Go
package llama4
|
|
|
|
import (
|
|
"bytes"
|
|
"image"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
type Model struct {
|
|
model.Base
|
|
model.BytePairEncoding
|
|
|
|
*VisionModel `gguf:"v,vision"`
|
|
*Projector `gguf:"mm"`
|
|
*TextModel
|
|
}
|
|
|
|
type Projector struct {
|
|
Linear1 *nn.Linear `gguf:"linear_1"`
|
|
}
|
|
|
|
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
|
|
return p.Linear1.Forward(ctx, visionOutputs)
|
|
}
|
|
|
|
func New(c fs.Config) (model.Model, error) {
|
|
m := Model{
|
|
BytePairEncoding: model.NewBytePairEncoding(
|
|
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
|
&model.Vocabulary{
|
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
Types: c.Uints("tokenizer.ggml.token_type"),
|
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
|
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
|
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
|
},
|
|
),
|
|
VisionModel: newVisionModel(c),
|
|
TextModel: newTextModel(c),
|
|
}
|
|
|
|
m.Cache = kvcache.NewWrapperCache(
|
|
// TODO: pretend this is chunked attention for now
|
|
kvcache.NewSWACache(8192, m.Shift),
|
|
kvcache.NewCausalCache(m.Shift),
|
|
)
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
|
if len(m.VisionModel.Layers) < 1 {
|
|
return nil, model.ErrNoVisionModel
|
|
}
|
|
|
|
img, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f32s, aspectRatio, err := m.ProcessImage(ctx, img)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
|
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
|
|
return m.Projector.Forward(ctx, visionOutputs), nil
|
|
}
|
|
|
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
|
}
|
|
|
|
func init() {
|
|
model.Register("llama4", New)
|
|
}
|