From f0c66e6dea7d79c0f6106540d20cea37f93bd97f Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 3 Apr 2025 15:18:29 -0700 Subject: [PATCH] llama4 --- convert/convert.go | 2 + convert/convert_llama.go | 13 +- convert/convert_llama4.go | 167 ++++++++++++++++++ convert/reader.go | 16 +- convert/reader_safetensors.go | 15 ++ convert/reader_torch.go | 11 ++ fs/ggml/ggml.go | 1 + ml/backend.go | 4 + ml/backend/ggml/ggml.go | 39 ++++- model/models/llama4/model.go | 100 +++++++++++ model/models/llama4/model_text.go | 223 ++++++++++++++++++++++++ model/models/llama4/model_vision.go | 256 ++++++++++++++++++++++++++++ model/models/models.go | 1 + 13 files changed, 833 insertions(+), 15 deletions(-) create mode 100644 convert/convert_llama4.go create mode 100644 model/models/llama4/model.go create mode 100644 model/models/llama4/model_text.go create mode 100644 model/models/llama4/model_vision.go diff --git a/convert/convert.go b/convert/convert.go index f6c7f3a55..ffcc2b8ab 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -173,6 +173,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error { switch p.Architectures[0] { case "LlamaForCausalLM": conv = &llamaModel{} + case "Llama4ForConditionalGeneration": + conv = &llama4Model{} case "Mistral3ForConditionalGeneration": conv = &mistral3Model{} case "MixtralForCausalLM": diff --git a/convert/convert_llama.go b/convert/convert_llama.go index 679d062ea..0caaa1949 100644 --- a/convert/convert_llama.go +++ b/convert/convert_llama.go @@ -42,6 +42,8 @@ type llamaModel struct { LayerNormEpsilon float32 `json:"layer_norm_epsilon"` NormEpsilon float32 `json:"norm_epsilon"` HeadDim uint32 `json:"head_dim"` + + skipRepack bool } var _ ModelConverter = (*llamaModel)(nil) @@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV { kv["llama.rope.dimension_count"] = p.HiddenSize / headCount } + if p.HeadDim > 0 { + kv["llama.attention.head_dim"] = p.HeadDim + } + if p.RopeTheta > 0 { kv["llama.rope.freq_base"] = p.RopeTheta } @@ -133,9 +139,10 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor { } for _, t := range ts { - if strings.HasSuffix(t.Name(), "attn_q.weight") || - strings.HasSuffix(t.Name(), "attn_k.weight") { - t.SetRepacker(p.repack) + if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") { + if !p.skipRepack { + t.SetRepacker(p.repack) + } } out = append(out, ggml.Tensor{ diff --git a/convert/convert_llama4.go b/convert/convert_llama4.go new file mode 100644 index 000000000..14463a179 --- /dev/null +++ b/convert/convert_llama4.go @@ -0,0 +1,167 @@ +package convert + +import ( + "slices" + "strings" + + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" + + "github.com/ollama/ollama/fs/ggml" +) + +type llama4Model struct { + ModelParameters + TextModel struct { + llamaModel + NumExpertsPerToken uint32 `json:"num_experts_per_tok"` + NumLocalExperts uint32 `json:"num_local_experts"` + InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"` + UseQKNorm bool `json:"use_qk_norm"` + IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"` + } `json:"text_config"` + VisionModel struct { + NumHiddenLayers uint32 `json:"num_hidden_layers"` + HiddenSize uint32 `json:"hidden_size"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + ImageSize uint32 `json:"image_size"` + PatchSize uint32 `json:"patch_size"` + RopeTheta float32 `json:"rope_theta"` + NormEpsilon float32 `json:"norm_eps"` + PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"` + } `json:"vision_config"` +} + +// KV implements ModelConverter. +func (p *llama4Model) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "llama4" + + for k, v := range p.TextModel.KV(t) { + if strings.HasPrefix(k, "llama.") { + kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v + } + } + + kv["llama4.intermediate_size"] = p.TextModel.IntermediateSizeMLP + kv["llama4.intermediate_size_moe"] = p.TextModel.IntermediateSize + + kv["llama4.expert_count"] = p.TextModel.NumLocalExperts + kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken + kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep + kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm + + kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers + kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize + kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize + kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads + kv["llama4.vision.image_size"] = p.VisionModel.ImageSize + kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize + kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta + kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon + kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio + return kv +} + +// Replacements implements ModelConverter. +func (p *llama4Model) Replacements() []string { + return append( + p.TextModel.Replacements(), + "language_model.", "", + "vision_model", "v", + "multi_modal_projector", "mm", + "feed_forward.down_proj", "ffn_down", + "feed_forward.up_proj", "ffn_up", + "feed_forward.gate_proj", "ffn_gate", + "feed_forward.", "ffn_", + "shared_expert.down_proj", "down_shexp", + "shared_expert.gate_proj", "gate_shexp", + "shared_expert.up_proj", "up_shexp", + "experts.down_proj", "down_exps.weight", + "experts.gate_up_proj", "gate_up_exps.weight", + "router", "gate_inp", + "patch_embedding.linear", "patch_embedding", + ) +} + +// Tensors implements ModelConverter. +func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor { + var out []ggml.Tensor + + var textTensors []Tensor + for _, t := range ts { + if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") { + out = append(out, ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } else if strings.Contains(t.Name(), "ffn_gate_up_exps") { + // gate and up projectors are fused + // dims[1], dims[2] must be swapped + // [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size] + halfDim := int(t.Shape()[2]) / 2 + + newShape := slices.Clone(t.Shape()) + newShape[1], newShape[2] = newShape[2]/2, newShape[1] + for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} { + // clone tensor since we need separate repackers + tt := t.Clone() + tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim))) + out = append(out, ggml.Tensor{ + Name: strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name), + Kind: tt.Kind(), + Shape: newShape, + WriterTo: tt, + }) + } + } else if strings.Contains(t.Name(), "ffn_down_exps") { + // dims[1], dims[2] must be swapped + // [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size] + t.SetRepacker(p.repack()) + newShape := slices.Clone(t.Shape()) + newShape[1], newShape[2] = newShape[2], newShape[1] + out = append(out, ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: newShape, + WriterTo: t, + }) + } else { + textTensors = append(textTensors, t) + } + } + + p.TextModel.skipRepack = true + out = append(out, p.TextModel.Tensors(textTensors)...) + return out +} + +func (p *llama4Model) repack(slice ...tensor.Slice) Repacker { + return func(name string, data []float32, shape []uint64) ([]float32, error) { + dims := make([]int, len(shape)) + for i, dim := range shape { + dims[i] = int(dim) + } + + var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + t, err := t.Slice(slice...) + if err != nil { + return nil, err + } + + if err := t.T(0, 2, 1); err != nil { + return nil, err + } + + t = tensor.Materialize(t) + // flatten tensor so it can be return as a vector + if err := t.Reshape(t.Shape().TotalSize()); err != nil { + return nil, err + } + + return native.VectorF32(t.(*tensor.Dense)) + } +} diff --git a/convert/reader.go b/convert/reader.go index 904b13a42..ab81d5c0b 100644 --- a/convert/reader.go +++ b/convert/reader.go @@ -11,14 +11,15 @@ type Tensor interface { Name() string Shape() []uint64 Kind() uint32 - SetRepacker(repacker) + SetRepacker(Repacker) WriteTo(io.Writer) (int64, error) + Clone() Tensor } type tensorBase struct { - name string - shape []uint64 - repacker + name string + shape []uint64 + repacker Repacker } func (t tensorBase) Name() string { @@ -36,7 +37,8 @@ const ( func (t tensorBase) Kind() uint32 { if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") || - t.name == "token_types.weight" { + t.name == "token_types.weight" || + t.name == "v.positional_embedding_vlm" { // these tensors are always F32 return 0 } @@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 { } } -func (t *tensorBase) SetRepacker(fn repacker) { +func (t *tensorBase) SetRepacker(fn Repacker) { t.repacker = fn } -type repacker func(string, []float32, []uint64) ([]float32, error) +type Repacker func(string, []float32, []uint64) ([]float32, error) func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) { patterns := []struct { diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go index b21d219c2..f58585321 100644 --- a/convert/reader_safetensors.go +++ b/convert/reader_safetensors.go @@ -94,6 +94,21 @@ type safetensor struct { *tensorBase } +func (st safetensor) Clone() Tensor { + return &safetensor{ + fs: st.fs, + path: st.path, + dtype: st.dtype, + offset: st.offset, + size: st.size, + tensorBase: &tensorBase{ + name: st.name, + repacker: st.repacker, + shape: slices.Clone(st.shape), + }, + } +} + func (st safetensor) WriteTo(w io.Writer) (int64, error) { f, err := st.fs.Open(st.path) if err != nil { diff --git a/convert/reader_torch.go b/convert/reader_torch.go index 1b3e1c9f1..7f6d6c872 100644 --- a/convert/reader_torch.go +++ b/convert/reader_torch.go @@ -43,6 +43,17 @@ type torch struct { *tensorBase } +func (t torch) Clone() Tensor { + return torch{ + storage: t.storage, + tensorBase: &tensorBase{ + name: t.name, + shape: t.shape, + repacker: t.repacker, + }, + } +} + func (pt torch) WriteTo(w io.Writer) (int64, error) { return 0, nil } diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 646dc75df..947295e36 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -124,6 +124,7 @@ func (kv KV) OllamaEngineRequired() bool { return slices.Contains([]string{ "gemma3", "mistral3", + "llama4", }, kv.Architecture()) } diff --git a/ml/backend.go b/ml/backend.go index 70c2fd8e2..0cd33bd8a 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -133,6 +133,7 @@ type Tensor interface { Mul(ctx Context, t2 Tensor) Tensor Mulmat(ctx Context, t2 Tensor) Tensor MulmatFullPrec(ctx Context, t2 Tensor) Tensor + MulmatID(ctx Context, t2, ids Tensor) Tensor Softmax(ctx Context) Tensor LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor @@ -150,6 +151,7 @@ type Tensor interface { Tanh(ctx Context) Tensor GELU(ctx Context) Tensor SILU(ctx Context) Tensor + Sigmoid(ctx Context) Tensor Reshape(ctx Context, shape ...int) Tensor View(ctx Context, offset int, shape ...int) Tensor @@ -168,6 +170,8 @@ type Tensor interface { Rows(ctx Context, t2 Tensor) Tensor Copy(ctx Context, t2 Tensor) Tensor Duplicate(ctx Context) Tensor + + TopK(ctx Context, k int) Tensor } // ScaledDotProductAttention implements a fused attention diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index c486b7477..177ac6fd0 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor { } } +func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t), + } +} + func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { - tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) - if b != nil { - tt = tt.Add(ctx, b) + tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps)) + if w != nil { + tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t) + if b != nil { + tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t) + } } - return tt + return &Tensor{b: t.b, t: tt} } func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { - return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w) + tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps)) + if w != nil { + tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t) + } + + return &Tensor{b: t.b, t: tt} } func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor { @@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor { } } +func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t), + } +} + func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor { if len(shape) != 4 { panic("expected 4 dimensions") @@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor { t: C.ggml_dup(ctx.(*Context).ctx, t.t), } } + +func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor { + return &Tensor{ + b: t.b, + t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)), + } +} diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go new file mode 100644 index 000000000..bbc64e74b --- /dev/null +++ b/model/models/llama4/model.go @@ -0,0 +1,100 @@ +package llama4 + +import ( + "bytes" + "image" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Model struct { + model.Base + model.BytePairEncoding + + *VisionModel `gguf:"v,vision"` + *Projector `gguf:"mm"` + *TextModel +} + +type Projector struct { + Linear1 *nn.Linear `gguf:"linear_1"` +} + +func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor { + return p.Linear1.Forward(ctx, visionOutputs) +} + +func New(c fs.Config) (model.Model, error) { + m := Model{ + BytePairEncoding: model.NewBytePairEncoding( + c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Uints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")), + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + }, + ), + VisionModel: newVisionModel(c), + TextModel: newTextModel(c), + } + + m.Cache = kvcache.NewWrapperCache( + // TODO: pretend this is chunked attention for now + kvcache.NewSWACache(8192, m.Shift), + kvcache.NewCausalCache(m.Shift), + ) + + return &m, nil +} + +func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) { + if len(m.VisionModel.Layers) < 1 { + return nil, model.ErrNoVisionModel + } + + img, _, err := image.Decode(bytes.NewReader(multimodalData)) + if err != nil { + return nil, err + } + + f32s, aspectRatio, err := m.ProcessImage(ctx, img) + if err != nil { + return nil, err + } + + pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s)) + if err != nil { + return nil, err + } + + visionOutputs := m.VisionModel.Forward(ctx, pixelValues) + visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3)) + return m.Projector.Forward(ctx, visionOutputs), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions)) + if err != nil { + return nil, err + } + + outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)) + if err != nil { + return nil, err + } + + return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil +} + +func init() { + model.Register("llama4", New) +} diff --git a/model/models/llama4/model_text.go b/model/models/llama4/model_text.go new file mode 100644 index 000000000..77e54e34a --- /dev/null +++ b/model/models/llama4/model_text.go @@ -0,0 +1,223 @@ +package llama4 + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model/input" +) + +type TextAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_factors"` +} + +func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor { + batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + if useRope { + query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) + key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale) + + if opts.useQKNorm { + query = query.RMSNorm(ctx, nil, opts.eps) + key = key.RMSNorm(ctx, nil, opts.eps) + } + } + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) + return sa.Output.Forward(ctx, attention) +} + +type TextMLP struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type TextExperts struct { + Gate ml.Tensor `gguf:"ffn_gate_exps.weight"` + Up ml.Tensor `gguf:"ffn_up_exps.weight"` + Down ml.Tensor `gguf:"ffn_down_exps.weight"` +} + +func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor { + experts := routerLogits.TopK(ctx, opts.numExpertsUsed) + scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts) + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed) + hiddenStates = hiddenStates.Mul(ctx, scores) + + upStates := e.Up.MulmatID(ctx, hiddenStates, experts) + gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts) + downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts) + + nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))) + } + + return nextStates +} + +// TextSharedExpert is TextMLP with different names +type TextSharedExpert struct { + Gate *nn.Linear `gguf:"ffn_gate_shexp"` + Up *nn.Linear `gguf:"ffn_up_shexp"` + Down *nn.Linear `gguf:"ffn_down_shexp"` +} + +func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type TextMOE struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Experts *TextExperts + SharedExpert *TextSharedExpert +} + +func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + routerLogits := moe.Router.Forward(ctx, hiddenStates) + + sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts) + routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts) + return sharedStates.Add(ctx, routedStates) +} + +type TextFeedForward interface { + Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor +} + +type TextLayer struct { + AttentionNorm *nn.LayerNorm `gguf:"attn_norm"` + Attention *TextAttention + + FFNNorm *nn.LayerNorm `gguf:"ffn_norm"` + FeedForward TextFeedForward +} + +func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor { + residual := hiddenStates + + // self attention + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + residual = hiddenStates + + hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts) + + return residual.Add(ctx, hiddenStates) +} + +type TextOptions struct { + hiddenSize int + numHeads, numKVHeads, headDim int + numExperts, numExpertsUsed int + ropeDim int + ropeBase, ropeScale float32 + eps float32 + interleaveLayerStep int + useQKNorm bool +} + +type TextModel struct { + Layers []TextLayer `gguf:"blk"` + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.LayerNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + *TextOptions +} + +func newTextModel(c fs.Config) *TextModel { + layers := make([]TextLayer, c.Uint("block_count")) + interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1) + for i := range layers { + if (i+1)%int(interleaveLayerStep) == 0 { + layers[i] = TextLayer{FeedForward: &TextMOE{}} + } else { + layers[i] = TextLayer{FeedForward: &TextMLP{}} + } + } + + return &TextModel{ + Layers: layers, + TextOptions: &TextOptions{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.head_dim", 128)), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + ropeDim: int(c.Uint("rope.dimension_count")), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.freq_scale", 1), + eps: c.Float("attention.layer_norm_rms_epsilon"), + interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)), + useQKNorm: c.Bool("use_qk_norm", true), + }, + } +} + +func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor { + hiddenStates := m.TokenEmbedding.Forward(ctx, inputs) + + for i, layer := range m.Layers { + cache.SetLayer(i) + wc := cache.(*kvcache.WrapperCache) + wc.SetLayerType(1) + useChunkedAttention := (i+1)%4 != 0 + if useChunkedAttention { + wc.SetLayerType(0) + } + + var lastLayerOutputs ml.Tensor + if i == len(m.Layers)-1 { + lastLayerOutputs = outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions) + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps) + return m.Output.Forward(ctx, hiddenStates) +} + +func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil +} diff --git a/model/models/llama4/model_vision.go b/model/models/llama4/model_vision.go new file mode 100644 index 000000000..3bf9cee75 --- /dev/null +++ b/model/models/llama4/model_vision.go @@ -0,0 +1,256 @@ +package llama4 + +import ( + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +type VisionAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor. +// This is equivalent to the Pytorch implmentation using half rotations: +// +// cos, sin = torch.cos(freqs), torch.sin(freqs) +// cos = cos.unsqueeze(-1) +// sin = sin.unsqueeze(-1) +// t = t.reshape(*t.shape[:-1], -1, 2) +// t_out = (t * cos) + (_rotate_half(t) * sin) +// t_out = t_out.flatten(3) +// +// Which is equivalent to the Pytorch implementation using complex numbers: +// +// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2)) +// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:] +// freqs_ci = freqs_ci.to(t_.device) +// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3) +// +// Due to the 1) the dimensional and 2) the datatype limitations of current backends, +// we need to use a different approach to achieve the same result. +func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { + width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3) + + t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3)) + + // t1 = t[..., 0::2] + t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx) + t1 = t1.Reshape(ctx, width/2, height, channels, tiles) + + // t2 = t[..., 1::2] + t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx) + t2 = t2.Reshape(ctx, width/2, height, channels, tiles) + + // cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1) + cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0) + cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3)) + cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + cosOut = cosOut.Reshape(ctx, width, height, channels, tiles) + + // sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1) + sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0) + sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3)) + sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx) + sinOut = sinOut.Reshape(ctx, width, height, channels, tiles) + + return cosOut.Add(ctx, sinOut) +} + +func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor { + headDim := opts.hiddenSize / opts.numHeads + + query := sa.Query.Forward(ctx, hiddenState) + key := sa.Key.Forward(ctx, hiddenState) + value := sa.Value.Forward(ctx, hiddenState) + + query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2)) + key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2)) + value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2)) + + query = applyVisionRotaryEmbedding(ctx, query, cos, sin) + key = applyVisionRotaryEmbedding(ctx, key, cos, sin) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3)) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + FC1 *nn.Linear `gguf:"fc1"` + FC2 *nn.Linear `gguf:"fc2"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor { + hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx) + hiddenStates = mlp.FC2.Forward(ctx, hiddenStates) + return hiddenStates +} + +type VisionLayer struct { + InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"` + *VisionAttention + + PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"` + *VisionMLP `gguf:"mlp"` +} + +func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor { + residual := hiddenStates + + // self attention + hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + // MLP + residual = hiddenStates + hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + return hiddenStates +} + +type VisionAdapter struct { + FC1 *nn.Linear `gguf:"mlp.fc1"` + FC2 *nn.Linear `gguf:"mlp.fc2"` +} + +func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor { + patches := hiddenStates.Dim(1) + patchSize := int(math.Sqrt(float64(patches))) + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2)) + + channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3) + + channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio) + hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles) + hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + + channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio) + hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles) + hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + + hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles) + + hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx) + hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx) + return hiddenStates +} + +type VisionOptions struct { + hiddenSize, numHeads int + imageSize, patchSize int + + ropeTheta float32 + eps float32 + pixelShuffleRatio float32 +} + +type PatchEmbedding struct { + *nn.Linear +} + +func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor { + kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2)) + hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1) + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3)) + return p.Linear.Forward(ctx, hiddenStates) +} + +type VisionModel struct { + Layers []VisionLayer `gguf:"blk"` + + *PatchEmbedding `gguf:"patch_embedding"` + ClassEmbedding ml.Tensor `gguf:"class_embedding"` + PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"` + + LayerNormPre *nn.LayerNorm `gguf:"layernorm_pre"` + LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"` + + *VisionAdapter `gguf:"vision_adapter"` + + *VisionOptions +} + +func newVisionModel(c fs.Config) *VisionModel { + return &VisionModel{ + Layers: make([]VisionLayer, c.Uint("vision.block_count")), + VisionOptions: &VisionOptions{ + hiddenSize: int(c.Uint("vision.embedding_length")), + numHeads: int(c.Uint("vision.attention.head_count")), + imageSize: int(c.Uint("vision.image_size")), + patchSize: int(c.Uint("vision.patch_size")), + ropeTheta: float32(c.Float("vision.rope.freq_base")), + eps: c.Float("vision.layer_norm_epsilon"), + pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")), + }, + } +} + +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor { + hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions) + hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1) + + hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding) + hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps) + + cos, sin := m.rotaryEmbedding(ctx) + for _, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions) + } + + hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps) + hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0) + hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions) + return hiddenStates +} + +// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function +// which in turn mimics Python's // operator. +func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T { + if b == 0 { + panic("division by zero") + } + + if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 { + return a / b + } + + return a/b - 1 +} + +func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) { + patchesPerSide := m.imageSize / m.patchSize + numPatches := patchesPerSide*patchesPerSide + 1 + + headDim := m.hiddenSize / m.numHeads + freqDim := headDim / 2 + + freqs := make([]float32, numPatches*freqDim) + for i := range numPatches - 1 { + for j := 0; j < freqDim; j += 2 { + positionX := i*freqDim/2 + j/2 + positionY := (i+numPatches)*freqDim/2 + j/2 + ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim)) + freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq) + freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq) + } + } + + ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2) + if err != nil { + panic(err) + } + + ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches) + return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx) +} diff --git a/model/models/models.go b/model/models/models.go index c5da2894b..73b4c53a5 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -4,6 +4,7 @@ import ( _ "github.com/ollama/ollama/model/models/gemma2" _ "github.com/ollama/ollama/model/models/gemma3" _ "github.com/ollama/ollama/model/models/llama" + _ "github.com/ollama/ollama/model/models/llama4" _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" )