mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
164 lines
5.1 KiB
Go
164 lines
5.1 KiB
Go
package mistral3
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"strings"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
// TextOptions holds the text-decoder hyperparameters that NewTextModel reads
// from the model config and that every layer's Forward consults.
type TextOptions struct {
	// Width and attention head layout. headDim may be zero, in which case
	// SelfAttention.Forward falls back to hiddenSize / numHeads.
	hiddenSize int
	numHeads   int
	numKVHeads int
	headDim    int

	// RMSNorm epsilon followed by the rotary-embedding parameters.
	eps       float32
	ropeBase  float32
	ropeScale float32
	ropeDim   uint32
}
|
|
|
|
// TextModel is the text decoder: a token embedding, a stack of Layers, a
// final RMSNorm and an output projection, plus the shared TextOptions.
//
// The `gguf` tags name the tensors each field corresponds to (presumably
// consumed by the model package's loader — confirm there). Output carries an
// alternate name of token_embd, presumably so the embedding matrix is reused
// when the file has no separate output tensor — confirm against the loader.
type TextModel struct {
	model.Base

	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Layers         []Layer       `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output,alt:token_embd"`

	*TextOptions
}
|
|
|
|
// SelfAttention holds the query, key, value and output projections used by a
// Layer's attention sub-block.
type SelfAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}
|
|
|
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
|
batchSize := hiddenState.Dim(1)
|
|
ropeType := uint32(0)
|
|
headDim := opts.headDim
|
|
if headDim == 0 {
|
|
headDim = opts.hiddenSize / opts.numHeads
|
|
}
|
|
|
|
q := sa.Query.Forward(ctx, hiddenState)
|
|
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
|
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
|
k := sa.Key.Forward(ctx, hiddenState)
|
|
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
|
|
|
v := sa.Value.Forward(ctx, hiddenState)
|
|
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
|
|
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
|
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
|
return sa.Output.Forward(ctx, kqv)
|
|
}
|
|
|
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
|
|
}
|
|
|
|
// MLP holds the gated feed-forward projections of a Layer. MLP.Forward
// computes Down(SILU(Gate(x)) * Up(x)).
type MLP struct {
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
	Gate *nn.Linear `gguf:"ffn_gate"`
}
|
|
|
|
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
|
return mlp.Down.Forward(ctx, hiddenState)
|
|
}
|
|
|
|
// Layer is one decoder block. In Layer.Forward, AttentionNorm is applied
// before SelfAttention and MLPNorm before MLP, and each sub-block's result is
// added back to its input.
//
// SelfAttention and MLP carry no gguf tag of their own; their inner fields'
// tags are presumably resolved under this layer's prefix — confirm against
// the loader.
type Layer struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	SelfAttention *SelfAttention
	MLPNorm       *nn.RMSNorm `gguf:"ffn_norm"`
	MLP           *MLP
}
|
|
|
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
|
residual := hiddenState
|
|
|
|
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
|
|
|
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
|
// we need logits for.
|
|
if outputs != nil {
|
|
hiddenState = hiddenState.Rows(ctx, outputs)
|
|
residual = residual.Rows(ctx, outputs)
|
|
}
|
|
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
residual = hiddenState
|
|
|
|
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
|
return hiddenState.Add(ctx, residual)
|
|
}
|
|
|
|
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
|
|
|
// image embeddings
|
|
for _, image := range batch.Multimodal {
|
|
row := image.Multimodal.(*imageRow)
|
|
row.parent.dataOnce.Do(func() {
|
|
// use a new, throwaway context so the image tensor is not added to the graph
|
|
temp := m.Backend().NewContext()
|
|
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
|
|
row.parent.data = row.parent.tensor.Floats()
|
|
temp.Close()
|
|
})
|
|
|
|
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
|
|
}
|
|
|
|
for i, layer := range m.Layers {
|
|
cache.SetLayer(i)
|
|
|
|
var lastLayerOutputs ml.Tensor
|
|
if i == len(m.Layers)-1 {
|
|
lastLayerOutputs = outputs
|
|
}
|
|
|
|
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
|
}
|
|
|
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
|
return m.Output.Forward(ctx, hiddenState)
|
|
}
|
|
|
|
func NewTextModel(c fs.Config) (*TextModel, error) {
|
|
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
|
|
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
|
|
}
|
|
|
|
textModel := &TextModel{
|
|
Layers: make([]Layer, c.Uint("block_count")),
|
|
TextOptions: &TextOptions{
|
|
hiddenSize: int(c.Uint("embedding_length")),
|
|
numHeads: int(c.Uint("attention.head_count")),
|
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
|
headDim: int(c.Uint("attention.key_length")),
|
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
|
ropeBase: c.Float("rope.freq_base"),
|
|
ropeScale: c.Float("rope.freq_scale", 1),
|
|
ropeDim: c.Uint("rope.dimension_count"),
|
|
},
|
|
}
|
|
|
|
return textModel, nil
|
|
}
|