package mistral3

import (
	"math"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)
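
// batchSize is hard-coded to 1: this vision encoder processes a single
// image per forward pass.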
var batchSize int = 1
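
// rotateHalf splits t in half along its first dimension and returns
// (-x2, x1), the rotation used by the rotate-half RoPE formulation.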
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
	x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
	x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
	return x2.Neg(ctx).Concat(ctx, x1, 0)
}
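
// applyRotaryPositionalEmbedding implements the rotate-half form of RoPE:
// t*cos + rotateHalf(t)*sin.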
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
	return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
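
// VisionSelfAttention holds the projection weights for one attention block
// of the vision encoder; the gguf tags map each field to its tensor name.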
type VisionSelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
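
// Forward computes multi-head self-attention over the patch sequence,
// applying rotary embeddings to the query and key projections.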
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	query := sa.Query.Forward(ctx, hiddenStates)
	key := sa.Key.Forward(ctx, hiddenStates)
	value := sa.Value.Forward(ctx, hiddenStates)

	query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
	key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
	value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)

	query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
	key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
	attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
	return sa.Output.Forward(ctx, attention)
}
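
// VisionMLP is the gated (SwiGLU-style) feed-forward block of an encoder layer.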
type VisionMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}
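
// Forward applies the gated activation: down(silu(gate(x)) * up(x)).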
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
	return mlp.Down.Forward(ctx, hiddenStates)
}
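
// VisionEncoderLayer is a pre-norm transformer block: RMSNorm -> attention
// -> residual add, then RMSNorm -> MLP -> residual add.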
type VisionEncoderLayer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *VisionSelfAttention
	FFNNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *VisionMLP
}
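
// Forward runs one encoder layer over the patch sequence.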
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}
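
// VisionModelOptions carries the hyperparameters read from GGUF metadata.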
type VisionModelOptions struct {
	hiddenSize       int
	numHeads         int
	headDim          int
	intermediateSize int
	imageSize        int
	patchSize        int
	numChannels      int
	eps              float32
	ropeBase         float32
}
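
// VisionModel is the vision tower: a convolutional patch embedding followed
// by a stack of RMS-normed transformer encoder layers.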
type VisionModel struct {
	PatchEmbedding *nn.Conv2D           `gguf:"patch_conv"`
	EncoderNorm    *nn.RMSNorm          `gguf:"encoder_norm"`
	Layers         []VisionEncoderLayer `gguf:"blk"`

	*VisionModelOptions
}
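
// positionalEmbedding builds the 2D RoPE inverse-frequency table. Even
// frequency indices encode the patch row, odd indices the patch column; the
// two halves are expanded over the patch grid, concatenated, duplicated, and
// gathered by position ID so that Cos()/Sin() of the result can be passed
// straight to the attention layers.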
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
	maxPatchesPerSide := m.imageSize / m.patchSize
	frequencies := m.headDim / 2
	frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
	frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
	for i := range frequencies {
		for j := range maxPatchesPerSide {
			frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
			if i%2 == 0 {
				frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
			} else {
				frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
			}
		}
	}

	h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
	if err != nil {
		panic(err)
	}

	w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
	if err != nil {
		panic(err)
	}

	h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	h = h.Repeat(ctx, 1, maxPatchesPerSide)
	h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	w = w.Repeat(ctx, 2, maxPatchesPerSide)

	inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
	inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
	return inverseFrequencies.Rows(ctx, positionIDs)
}
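
// Forward embeds pixelValues into patches, applies 2D rotary position
// embeddings, and runs the patch sequence through every encoder layer.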
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
	numPatchesW := pixelValues.Dim(0) / m.patchSize
	numPatchesH := pixelValues.Dim(1) / m.patchSize
	numPatches := numPatchesW * numPatchesH

	hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
	hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)

	// Prepare position IDs for 2D RoPE
	positions := make([]int32, numPatches)
	for h := range numPatchesH {
		for w := range numPatchesW {
			idx := h*numPatchesW + w
			positions[idx] = int32(h*m.imageSize/m.patchSize + w)
		}
	}

	positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
	if err != nil {
		panic(err)
	}

	positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
	cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
	cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
	sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))

	for _, layer := range m.Layers {
		hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
	}

	return hiddenStates
}
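
// newVisionModel reads the vision hyperparameters from the model config,
// falling back to the listed defaults when a key is absent.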
func newVisionModel(c fs.Config) *VisionModel {
	return &VisionModel{
		Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
		VisionModelOptions: &VisionModelOptions{
			hiddenSize:       int(c.Uint("vision.embedding_length", 1024)),
			numHeads:         int(c.Uint("vision.attention.head_count", 16)),
			headDim:          int(c.Uint("vision.attention.key_length", 64)),
			intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
			imageSize:        int(c.Uint("vision.image_size", 1540)),
			patchSize:        int(c.Uint("vision.patch_size", 14)),
			numChannels:      int(c.Uint("vision.num_channels", 3)),
			eps:              c.Float("vision.attention.layer_norm_epsilon", 1e-5),
			ropeBase:         c.Float("vision.rope.freq_base", 10000.0),
		},
	}
}