mirror of https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
llama4

parent 54055a6dae
commit f0c66e6dea

13 changed files with 833 additions and 15 deletions
@@ -173,6 +173,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 	switch p.Architectures[0] {
 	case "LlamaForCausalLM":
 		conv = &llamaModel{}
+	case "Llama4ForConditionalGeneration":
+		conv = &llama4Model{}
 	case "Mistral3ForConditionalGeneration":
 		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
@@ -42,6 +42,8 @@ type llamaModel struct {
 	LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
 	NormEpsilon      float32 `json:"norm_epsilon"`
+	HeadDim          uint32  `json:"head_dim"`

+	skipRepack bool
 }

 var _ ModelConverter = (*llamaModel)(nil)
@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
 		kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
 	}

+	if p.HeadDim > 0 {
+		kv["llama.attention.head_dim"] = p.HeadDim
+	}
+
 	if p.RopeTheta > 0 {
 		kv["llama.rope.freq_base"] = p.RopeTheta
 	}
@@ -133,9 +139,10 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
 	}

 	for _, t := range ts {
-		if strings.HasSuffix(t.Name(), "attn_q.weight") ||
-			strings.HasSuffix(t.Name(), "attn_k.weight") {
-			t.SetRepacker(p.repack)
+		if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
+			if !p.skipRepack {
+				t.SetRepacker(p.repack)
+			}
 		}

 		out = append(out, ggml.Tensor{
convert/convert_llama4.go (new file, 167 lines)

@@ -0,0 +1,167 @@
package convert

import (
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type llama4Model struct {
	ModelParameters
	TextModel struct {
		llamaModel
		NumExpertsPerToken     uint32 `json:"num_experts_per_tok"`
		NumLocalExperts        uint32 `json:"num_local_experts"`
		InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
		UseQKNorm              bool   `json:"use_qk_norm"`
		IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
	} `json:"text_config"`
	VisionModel struct {
		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
		HiddenSize        uint32  `json:"hidden_size"`
		IntermediateSize  uint32  `json:"intermediate_size"`
		NumAttentionHeads uint32  `json:"num_attention_heads"`
		ImageSize         uint32  `json:"image_size"`
		PatchSize         uint32  `json:"patch_size"`
		RopeTheta         float32 `json:"rope_theta"`
		NormEpsilon       float32 `json:"norm_eps"`
		PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
	} `json:"vision_config"`
}

// KV implements ModelConverter.
func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
	kv := p.ModelParameters.KV(t)
	kv["general.architecture"] = "llama4"

	for k, v := range p.TextModel.KV(t) {
		if strings.HasPrefix(k, "llama.") {
			kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
		}
	}

	kv["llama4.intermediate_size"] = p.TextModel.IntermediateSizeMLP
	kv["llama4.intermediate_size_moe"] = p.TextModel.IntermediateSize

	kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
	kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
	kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
	kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm

	kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
	kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
	kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
	kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
	kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
	kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
	kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
	kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
	kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
	return kv
}

// Replacements implements ModelConverter.
func (p *llama4Model) Replacements() []string {
	return append(
		p.TextModel.Replacements(),
		"language_model.", "",
		"vision_model", "v",
		"multi_modal_projector", "mm",
		"feed_forward.down_proj", "ffn_down",
		"feed_forward.up_proj", "ffn_up",
		"feed_forward.gate_proj", "ffn_gate",
		"feed_forward.", "ffn_",
		"shared_expert.down_proj", "down_shexp",
		"shared_expert.gate_proj", "gate_shexp",
		"shared_expert.up_proj", "up_shexp",
		"experts.down_proj", "down_exps.weight",
		"experts.gate_up_proj", "gate_up_exps.weight",
		"router", "gate_inp",
		"patch_embedding.linear", "patch_embedding",
	)
}

// Tensors implements ModelConverter.
func (p *llama4Model) Tensors(ts []Tensor) []ggml.Tensor {
	var out []ggml.Tensor

	var textTensors []Tensor
	for _, t := range ts {
		if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		} else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
			// gate and up projectors are fused
			// dims[1], dims[2] must be swapped
			// [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
			halfDim := int(t.Shape()[2]) / 2

			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2]/2, newShape[1]
			for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
				// clone tensor since we need separate repackers
				tt := t.Clone()
				tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
				out = append(out, ggml.Tensor{
					Name:     strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
					Kind:     tt.Kind(),
					Shape:    newShape,
					WriterTo: tt,
				})
			}
		} else if strings.Contains(t.Name(), "ffn_down_exps") {
			// dims[1], dims[2] must be swapped
			// [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
			t.SetRepacker(p.repack())
			newShape := slices.Clone(t.Shape())
			newShape[1], newShape[2] = newShape[2], newShape[1]
			out = append(out, ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    newShape,
				WriterTo: t,
			})
		} else {
			textTensors = append(textTensors, t)
		}
	}

	p.TextModel.skipRepack = true
	out = append(out, p.TextModel.Tensors(textTensors)...)
	return out
}

func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
	return func(name string, data []float32, shape []uint64) ([]float32, error) {
		dims := make([]int, len(shape))
		for i, dim := range shape {
			dims[i] = int(dim)
		}

		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
		t, err := t.Slice(slice...)
		if err != nil {
			return nil, err
		}

		if err := t.T(0, 2, 1); err != nil {
			return nil, err
		}

		t = tensor.Materialize(t)
		// flatten tensor so it can be returned as a vector
		if err := t.Reshape(t.Shape().TotalSize()); err != nil {
			return nil, err
		}

		return native.VectorF32(t.(*tensor.Dense))
	}
}
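Aside: the repack above slices the fused tensor along its last dimension, swaps dims 1 and 2, then flattens. A minimal, self-contained sketch of that pipeline using the same github.com/pdevine/tensor calls (illustrative only; the toy shape and values are not from the commit):

	package main

	import (
		"fmt"

		"github.com/pdevine/tensor"
		"github.com/pdevine/tensor/native"
	)

	func main() {
		// toy fused gate_up tensor: [experts=1, hidden=2, 2*intermediate=4];
		// per hidden row, the first half is gate, the second half is up
		data := []float32{1, 2, 3, 4, 5, 6, 7, 8}
		var t tensor.Tensor = tensor.New(tensor.WithShape(1, 2, 4), tensor.WithBacking(data))

		t, _ = t.Slice(nil, nil, tensor.S(0, 2)) // gate half, cf. tensor.S(i*halfDim, (i+1)*halfDim)
		_ = t.T(0, 2, 1)                         // swap dims 1 and 2, as the repacker does
		t = tensor.Materialize(t)
		_ = t.Reshape(t.Shape().TotalSize()) // flatten to a vector (errors elided for brevity)

		v, _ := native.VectorF32(t.(*tensor.Dense))
		fmt.Println(v) // [1 5 2 6]
	}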
@@ -11,14 +11,15 @@ type Tensor interface {
 	Name() string
 	Shape() []uint64
 	Kind() uint32
-	SetRepacker(repacker)
+	SetRepacker(Repacker)
 	WriteTo(io.Writer) (int64, error)
+	Clone() Tensor
 }

 type tensorBase struct {
-	name  string
-	shape []uint64
-	repacker
+	name     string
+	shape    []uint64
+	repacker Repacker
 }

 func (t tensorBase) Name() string {
@@ -36,7 +37,8 @@ const (

 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
-		t.name == "token_types.weight" {
+		t.name == "token_types.weight" ||
+		t.name == "v.positional_embedding_vlm" {
 		// these tensors are always F32
 		return 0
 	}
@@ -51,11 +53,11 @@ func (t tensorBase) Kind() uint32 {
 	}
 }

-func (t *tensorBase) SetRepacker(fn repacker) {
+func (t *tensorBase) SetRepacker(fn Repacker) {
 	t.repacker = fn
 }

-type repacker func(string, []float32, []uint64) ([]float32, error)
+type Repacker func(string, []float32, []uint64) ([]float32, error)

 func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 	patterns := []struct {
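With the type exported, converters outside reader.go can supply repack hooks. A minimal conforming Repacker, for illustration only (not part of the commit):

	// identity Repacker: returns the flattened tensor data unchanged
	var identity Repacker = func(name string, data []float32, shape []uint64) ([]float32, error) {
		return data, nil
	}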
@@ -94,6 +94,21 @@ type safetensor struct {
 	*tensorBase
 }

+func (st safetensor) Clone() Tensor {
+	return &safetensor{
+		fs:     st.fs,
+		path:   st.path,
+		dtype:  st.dtype,
+		offset: st.offset,
+		size:   st.size,
+		tensorBase: &tensorBase{
+			name:     st.name,
+			repacker: st.repacker,
+			shape:    slices.Clone(st.shape),
+		},
+	}
+}
+
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	f, err := st.fs.Open(st.path)
 	if err != nil {
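Clone exists so a single on-disk tensor can be emitted twice with distinct repackers; the fused gate/up expert tensor in convert_llama4.go above is the motivating case. A sketch of the intended use (fused is a hypothetical Tensor; halfDim and p.repack as in convert_llama4.go):

	gate := fused.Clone() // fused: hypothetical [experts, hidden, 2*intermediate] tensor
	gate.SetRepacker(p.repack(nil, nil, tensor.S(0, halfDim)))

	up := fused.Clone()
	up.SetRepacker(p.repack(nil, nil, tensor.S(halfDim, 2*halfDim)))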
@@ -43,6 +43,17 @@ type torch struct {
 	*tensorBase
 }

+func (t torch) Clone() Tensor {
+	return torch{
+		storage: t.storage,
+		tensorBase: &tensorBase{
+			name:     t.name,
+			shape:    t.shape,
+			repacker: t.repacker,
+		},
+	}
+}
+
 func (pt torch) WriteTo(w io.Writer) (int64, error) {
 	return 0, nil
 }
@@ -124,6 +124,7 @@ func (kv KV) OllamaEngineRequired() bool {
 	return slices.Contains([]string{
 		"gemma3",
 		"mistral3",
+		"llama4",
 	}, kv.Architecture())
 }
@@ -133,6 +133,7 @@ type Tensor interface {
 	Mul(ctx Context, t2 Tensor) Tensor
 	Mulmat(ctx Context, t2 Tensor) Tensor
 	MulmatFullPrec(ctx Context, t2 Tensor) Tensor
+	MulmatID(ctx Context, t2, ids Tensor) Tensor

 	Softmax(ctx Context) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor

@@ -150,6 +151,7 @@
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context) Tensor
 	SILU(ctx Context) Tensor
+	Sigmoid(ctx Context) Tensor

 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor

@@ -168,6 +170,8 @@
 	Rows(ctx Context, t2 Tensor) Tensor
 	Copy(ctx Context, t2 Tensor) Tensor
 	Duplicate(ctx Context) Tensor
+
+	TopK(ctx Context, k int) Tensor
 }

 // ScaledDotProductAttention implements a fused attention
@@ -884,17 +884,32 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }

+func (t *Tensor) MulmatID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_mul_mat_id(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, ids.(*Tensor).t),
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
-	tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
-	if b != nil {
-		tt = tt.Add(ctx, b)
+	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+		if b != nil {
+			tt = C.ggml_add(ctx.(*Context).ctx, tt, b.(*Tensor).t)
+		}
 	}

-	return tt
+	return &Tensor{b: t.b, t: tt}
 }

 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	tt := C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))
+	if w != nil {
+		tt = C.ggml_mul(ctx.(*Context).ctx, tt, w.(*Tensor).t)
+	}
+
+	return &Tensor{b: t.b, t: tt}
 }

 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
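Both norms now tolerate a nil weight. llama4's QK normalization relies on this, calling RMSNorm with no learned scale (quoted from model_text.go later in this diff):

	if opts.useQKNorm {
		query = query.RMSNorm(ctx, nil, opts.eps)
		key = key.RMSNorm(ctx, nil, opts.eps)
	}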
@@ -995,6 +1010,13 @@ func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
 	}
 }

+func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_sigmoid_inplace(ctx.(*Context).ctx, t.t),
+	}
+}
+
 func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
 	if len(shape) != 4 {
 		panic("expected 4 dimensions")
@@ -1158,3 +1180,10 @@ func (t *Tensor) Duplicate(ctx ml.Context) ml.Tensor {
 		t: C.ggml_dup(ctx.(*Context).ctx, t.t),
 	}
 }
+
+func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
+	}
+}
model/models/llama4/model.go (new file, 100 lines)

@@ -0,0 +1,100 @@
package llama4

import (
	"bytes"
	"image"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.BytePairEncoding

	*VisionModel `gguf:"v,vision"`
	*Projector   `gguf:"mm"`
	*TextModel
}

type Projector struct {
	Linear1 *nn.Linear `gguf:"linear_1"`
}

func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
	return p.Linear1.Forward(ctx, visionOutputs)
}

func New(c fs.Config) (model.Model, error) {
	m := Model{
		BytePairEncoding: model.NewBytePairEncoding(
			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
			&model.Vocabulary{
				Values: c.Strings("tokenizer.ggml.tokens"),
				Types:  c.Uints("tokenizer.ggml.token_type"),
				Merges: c.Strings("tokenizer.ggml.merges"),
				BOS:    int32(c.Uint("tokenizer.ggml.bos_token_id")),
				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
				EOS:    int32(c.Uint("tokenizer.ggml.eos_token_id")),
				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
			},
		),
		VisionModel: newVisionModel(c),
		TextModel:   newTextModel(c),
	}

	m.Cache = kvcache.NewWrapperCache(
		// TODO: pretend this is chunked attention for now
		kvcache.NewSWACache(8192, m.Shift),
		kvcache.NewCausalCache(m.Shift),
	)

	return &m, nil
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
	if len(m.VisionModel.Layers) < 1 {
		return nil, model.ErrNoVisionModel
	}

	img, _, err := image.Decode(bytes.NewReader(multimodalData))
	if err != nil {
		return nil, err
	}

	f32s, aspectRatio, err := m.ProcessImage(ctx, img)
	if err != nil {
		return nil, err
	}

	pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
	if err != nil {
		return nil, err
	}

	visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
	visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
	return m.Projector.Forward(ctx, visionOutputs), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
	if err != nil {
		return nil, err
	}

	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
	if err != nil {
		return nil, err
	}

	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}

func init() {
	model.Register("llama4", New)
}
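The two caches pair with the per-layer selection in model_text.go below: NewWrapperCache keeps its arguments in order, so layer type 0 is the 8192-window SWA cache (standing in for chunked attention) and layer type 1 is the global causal cache. The selection rule, paraphrased from TextModel.Forward:

	wc := cache.(*kvcache.WrapperCache)
	wc.SetLayerType(1) // every 4th layer: global causal cache
	if (i+1)%4 != 0 {
		wc.SetLayerType(0) // all other layers: SWA ("chunked") cache
	}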
model/models/llama4/model_text.go (new file, 223 lines)

@@ -0,0 +1,223 @@
package llama4

import (
	"cmp"
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model/input"
)

type TextAttention struct {
	Query       *nn.Linear `gguf:"attn_q"`
	Key         *nn.Linear `gguf:"attn_k"`
	Value       *nn.Linear `gguf:"attn_v"`
	Output      *nn.Linear `gguf:"attn_output"`
	RopeFactors ml.Tensor  `gguf:"rope_factors"`
}

func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)

	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
	key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

	if useRope {
		query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
		key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)

		if opts.useQKNorm {
			query = query.RMSNorm(ctx, nil, opts.eps)
			key = key.RMSNorm(ctx, nil, opts.eps)
		}
	}

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
	return sa.Output.Forward(ctx, attention)
}

type TextMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type TextExperts struct {
	Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
	Up   ml.Tensor `gguf:"ffn_up_exps.weight"`
	Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}

func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
	experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
	scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
	hiddenStates = hiddenStates.Mul(ctx, scores)

	upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
	gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
	downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)

	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
	for i := 1; i < opts.numExpertsUsed; i++ {
		nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
	}

	return nextStates
}
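
// Editorial note (not in the original commit): routing weights come from a
// sigmoid over the router logits rather than a softmax, and the hidden states
// are scaled by those weights before the expert matmuls, not after.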

// TextSharedExpert is TextMLP with different names
type TextSharedExpert struct {
	Gate *nn.Linear `gguf:"ffn_gate_shexp"`
	Up   *nn.Linear `gguf:"ffn_up_shexp"`
	Down *nn.Linear `gguf:"ffn_down_shexp"`
}

func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}

type TextMOE struct {
	Router       *nn.Linear `gguf:"ffn_gate_inp"`
	Experts      *TextExperts
	SharedExpert *TextSharedExpert
}

func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
	hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
	routerLogits := moe.Router.Forward(ctx, hiddenStates)

	sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
	routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
	return sharedStates.Add(ctx, routedStates)
}

type TextFeedForward interface {
	Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}

type TextLayer struct {
	AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
	Attention     *TextAttention

	FFNNorm     *nn.LayerNorm `gguf:"ffn_norm"`
	FeedForward TextFeedForward
}

func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts)

	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)
	residual = hiddenStates

	hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)

	return residual.Add(ctx, hiddenStates)
}

type TextOptions struct {
	hiddenSize                    int
	numHeads, numKVHeads, headDim int
	numExperts, numExpertsUsed    int
	ropeDim                       int
	ropeBase, ropeScale           float32
	eps                           float32
	interleaveLayerStep           int
	useQKNorm                     bool
}

type TextModel struct {
	Layers []TextLayer `gguf:"blk"`

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	OutputNorm     *nn.LayerNorm `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}

func newTextModel(c fs.Config) *TextModel {
	layers := make([]TextLayer, c.Uint("block_count"))
	interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
	for i := range layers {
		if (i+1)%int(interleaveLayerStep) == 0 {
			layers[i] = TextLayer{FeedForward: &TextMOE{}}
		} else {
			layers[i] = TextLayer{FeedForward: &TextMLP{}}
		}
	}

	return &TextModel{
		Layers: layers,
		TextOptions: &TextOptions{
			hiddenSize:          int(c.Uint("embedding_length")),
			numHeads:            int(c.Uint("attention.head_count")),
			numKVHeads:          int(c.Uint("attention.head_count_kv")),
			headDim:             int(c.Uint("attention.head_dim", 128)),
			numExperts:          int(c.Uint("expert_count")),
			numExpertsUsed:      int(c.Uint("expert_used_count")),
			ropeDim:             int(c.Uint("rope.dimension_count")),
			ropeBase:            c.Float("rope.freq_base"),
			ropeScale:           c.Float("rope.freq_scale", 1),
			eps:                 c.Float("attention.layer_norm_rms_epsilon"),
			interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
			useQKNorm:           c.Bool("use_qk_norm", true),
		},
	}
}

func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
	hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)

	for i, layer := range m.Layers {
		cache.SetLayer(i)
		wc := cache.(*kvcache.WrapperCache)
		wc.SetLayerType(1)
		useChunkedAttention := (i+1)%4 != 0
		if useChunkedAttention {
			wc.SetLayerType(0)
		}

		var lastLayerOutputs ml.Tensor
		if i == len(m.Layers)-1 {
			lastLayerOutputs = outputs
		}

		hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
	}

	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
	return m.Output.Forward(ctx, hiddenStates)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}
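Aside: how the interleave step in newTextModel plays out. A worked sketch (illustrative; the 6-layer count and step value are hypothetical, not from the commit):

	// feed-forward type per layer when interleave_moe_layer_step = 2:
	// every second layer (1-indexed: 2, 4, 6) gets a TextMOE, the rest a dense TextMLP;
	// with step = 1 (the default above), every layer is MoE
	step := 2
	for i := 0; i < 6; i++ {
		if (i+1)%step == 0 {
			fmt.Printf("blk.%d: TextMOE\n", i)
		} else {
			fmt.Printf("blk.%d: TextMLP\n", i)
		}
	}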
model/models/llama4/model_vision.go (new file, 256 lines)

@@ -0,0 +1,256 @@
package llama4

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type VisionAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
//	cos, sin = torch.cos(freqs), torch.sin(freqs)
//	cos = cos.unsqueeze(-1)
//	sin = sin.unsqueeze(-1)
//	t = t.reshape(*t.shape[:-1], -1, 2)
//	t_out = (t * cos) + (_rotate_half(t) * sin)
//	t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
//	t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
//	freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_)  # freqs_ci[:,:,None,:]
//	freqs_ci = freqs_ci.to(t_.device)
//	t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the datatype limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)

	t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))

	// t1 = t[..., 0::2]
	t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t1 = t1.Reshape(ctx, width/2, height, channels, tiles)

	// t2 = t[..., 1::2]
	t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
	t2 = t2.Reshape(ctx, width/2, height, channels, tiles)

	// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
	cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
	cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
	cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)

	// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
	sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
	sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
	sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)

	return cosOut.Add(ctx, sinOut)
}

func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	headDim := opts.hiddenSize / opts.numHeads

	query := sa.Query.Forward(ctx, hiddenState)
	key := sa.Key.Forward(ctx, hiddenState)
	value := sa.Value.Forward(ctx, hiddenState)

	query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
	key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
	value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))

	query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
	key = applyVisionRotaryEmbedding(ctx, key, cos, sin)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
	return sa.Output.Forward(ctx, attention)
}

type VisionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
	return hiddenStates
}

type VisionLayer struct {
	InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
	*VisionAttention

	PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
	*VisionMLP        `gguf:"mlp"`
}

func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
	residual := hiddenStates

	// self attention
	hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	// MLP
	residual = hiddenStates
	hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	return hiddenStates
}

type VisionAdapter struct {
	FC1 *nn.Linear `gguf:"mlp.fc1"`
	FC2 *nn.Linear `gguf:"mlp.fc2"`
}

func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	patches := hiddenStates.Dim(1)
	patchSize := int(math.Sqrt(float64(patches)))

	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))

	channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
	hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
	hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)

	hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)

	hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
	hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
	return hiddenStates
}
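
// Editorial note (not in the original commit): with pixel_shuffle_ratio = 0.5,
// the two reshape+permute passes above halve the patch grid along each axis and
// fold each 2x2 neighborhood into channels (4x), i.e. a space-to-depth shuffle.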

type VisionOptions struct {
	hiddenSize, numHeads int
	imageSize, patchSize int

	ropeTheta         float32
	eps               float32
	pixelShuffleRatio float32
}

type PatchEmbedding struct {
	*nn.Linear
}

func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
	kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
	hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
	return p.Linear.Forward(ctx, hiddenStates)
}

type VisionModel struct {
	Layers []VisionLayer `gguf:"blk"`

	*PatchEmbedding     `gguf:"patch_embedding"`
	ClassEmbedding      ml.Tensor `gguf:"class_embedding"`
	PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`

	LayerNormPre  *nn.LayerNorm `gguf:"layernorm_pre"`
	LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`

	*VisionAdapter `gguf:"vision_adapter"`

	*VisionOptions
}

func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionLayer, c.Uint("vision.block_count")),
		VisionOptions: &VisionOptions{
			hiddenSize:        int(c.Uint("vision.embedding_length")),
			numHeads:          int(c.Uint("vision.attention.head_count")),
			imageSize:         int(c.Uint("vision.image_size")),
			patchSize:         int(c.Uint("vision.patch_size")),
			ropeTheta:         float32(c.Float("vision.rope.freq_base")),
			eps:               c.Float("vision.layer_norm_epsilon"),
			pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
		},
	}
}

func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
	hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)

	hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
	hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)

	cos, sin := m.rotaryEmbedding(ctx)
	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
	}

	hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
	hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
	hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
	return hiddenStates
}

// floorDiv is a helper function to perform floor division. This mimics PyTorch's div(round_mode='floor') function
// which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
	if b == 0 {
		panic("division by zero")
	}

	if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
		return a / b
	}

	return a/b - 1
}
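
// Editorial example (not in the original commit): floorDiv(-3, 2) == -2,
// matching Python's -3 // 2, whereas Go's native -3/2 truncates to -1.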

func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
	patchesPerSide := m.imageSize / m.patchSize
	numPatches := patchesPerSide*patchesPerSide + 1

	headDim := m.hiddenSize / m.numHeads
	freqDim := headDim / 2

	freqs := make([]float32, numPatches*freqDim)
	for i := range numPatches - 1 {
		for j := 0; j < freqDim; j += 2 {
			positionX := i*freqDim/2 + j/2
			positionY := (i+numPatches)*freqDim/2 + j/2
			ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
			freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
			freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
		}
	}

	ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
	if err != nil {
		panic(err)
	}

	ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
	return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
@@ -4,6 +4,7 @@ import (
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/llama"
+	_ "github.com/ollama/ollama/model/models/llama4"
 	_ "github.com/ollama/ollama/model/models/mistral3"
 	_ "github.com/ollama/ollama/model/models/mllama"
 )