mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
This provides integration with the new Ollama engine (commit 5824541,
"next ollama runner (#7913)") and the rest of the Ollama infrastructure,
such as the runner and the Ollama server.
In addition, it also builds out the KV cache infrastructure to
support requirements of how Ollama runs models such as:
- Parallel processing
- Memory management for defragmentation and shifting
- Multi-modal models
Both old and new engines continue to be supported. By default, only
the old engine is used. To enable the new engine:
Start the server with the OLLAMA_NEW_ENGINE environment variable set:
OLLAMA_NEW_ENGINE=1 ./ollama serve
Start a model that is supported by the Ollama engine. This one is Llama 3.1 8b Q4_K_M:
./ollama run jessegross/llama3.1
241 lines
8.7 KiB
Go
241 lines
8.7 KiB
Go
package mllama
|
|
|
|
import (
|
|
"math"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
)
|
|
|
|
type TextSelfAttention struct {
|
|
Query *nn.Linear `gguf:"attn_q"`
|
|
Key *nn.Linear `gguf:"attn_k"`
|
|
Value *nn.Linear `gguf:"attn_v"`
|
|
Output *nn.Linear `gguf:"attn_output"`
|
|
}
|
|
|
|
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
batchSize := hiddenState.Dim(1)
|
|
headDim := opts.hiddenSize / opts.numHeads
|
|
|
|
query := sa.Query.Forward(ctx, hiddenState)
|
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
|
query = query.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
|
key := sa.Key.Forward(ctx, hiddenState)
|
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
key = key.RoPE(ctx, positions, opts.RopeFactors, opts.ropeDim, opts.ropeBase, opts.ropeScale)
|
|
|
|
value := sa.Value.Forward(ctx, hiddenState)
|
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
|
|
|
cache.Put(ctx, key, value)
|
|
key, value, mask := cache.Get(ctx)
|
|
|
|
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
scores := key.MulmatFullPrec(ctx, query)
|
|
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
scores = scores.Add(ctx, mask)
|
|
scores = scores.Softmax(ctx)
|
|
|
|
attention := value.Mulmat(ctx, scores)
|
|
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
|
|
|
return sa.Output.Forward(ctx, attention)
|
|
}
|
|
|
|
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
|
// This will only get called for layers in the cache, which are just the self attention layers
|
|
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
|
}
|
|
|
|
type TextMLP struct {
|
|
Up *nn.Linear `gguf:"ffn_up"`
|
|
Down *nn.Linear `gguf:"ffn_down"`
|
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
|
}
|
|
|
|
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
|
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
|
return mlp.Down.Forward(ctx, hiddenState)
|
|
}
|
|
|
|
type TextSelfAttentionDecoderLayer struct {
|
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
SelfAttention *TextSelfAttention
|
|
|
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
MLP *TextMLP
|
|
}
|
|
|
|
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
residual := hiddenState
|
|
|
|
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
residual = hiddenState
|
|
|
|
hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
|
|
return hiddenState.Add(ctx, residual)
|
|
}
|
|
|
|
type TextCrossAttention struct {
|
|
QueryNorm *nn.RMSNorm `gguf:"cross_attn_q_norm"`
|
|
Query *nn.Linear `gguf:"cross_attn_q_proj"`
|
|
KeyNorm *nn.RMSNorm `gguf:"cross_attn_k_norm"`
|
|
Key *nn.Linear `gguf:"cross_attn_k_proj"`
|
|
Value *nn.Linear `gguf:"cross_attn_v_proj"`
|
|
Output *nn.Linear `gguf:"cross_attn_o_proj"`
|
|
}
|
|
|
|
func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentionStates ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
batchSize := hiddenState.Dim(1)
|
|
headDim := opts.hiddenSize / opts.numHeads
|
|
|
|
query := ca.Query.Forward(ctx, hiddenState)
|
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
|
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
|
|
|
var key, value ml.Tensor
|
|
if crossAttentionStates != nil {
|
|
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
|
|
|
key = ca.Key.Forward(ctx, crossAttentionStates)
|
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
|
key = ca.KeyNorm.Forward(ctx, key, opts.eps)
|
|
|
|
value = ca.Value.Forward(ctx, crossAttentionStates)
|
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
|
|
|
cache.Put(ctx, key, value)
|
|
} else {
|
|
key, value, _ = cache.Get(ctx)
|
|
}
|
|
|
|
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
|
|
|
scores := key.Mulmat(ctx, query)
|
|
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
|
|
scores = scores.Softmax(ctx)
|
|
|
|
attention := value.Mulmat(ctx, scores)
|
|
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
|
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
|
|
|
return ca.Output.Forward(ctx, attention)
|
|
}
|
|
|
|
type TextCrossAttentionDecoderLayer struct {
|
|
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
|
CrossAttention *TextCrossAttention
|
|
AttentionGate ml.Tensor `gguf:"cross_attn_attn_gate"`
|
|
|
|
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
|
MLP *TextMLP
|
|
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
|
|
}
|
|
|
|
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
residual := hiddenState
|
|
|
|
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = d.CrossAttention.Forward(ctx, hiddenState, crossAttentionStates, cache, opts)
|
|
hiddenState = hiddenState.Mul(ctx, d.AttentionGate.Tanh(ctx))
|
|
hiddenState = hiddenState.Add(ctx, residual)
|
|
residual = hiddenState
|
|
|
|
hiddenState = d.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
|
hiddenState = d.MLP.Forward(ctx, hiddenState, opts)
|
|
hiddenState = hiddenState.Mul(ctx, d.MLPGate.Tanh(ctx))
|
|
return hiddenState.Add(ctx, residual)
|
|
}
|
|
|
|
type TextDecoderLayer interface {
|
|
Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
|
|
}
|
|
|
|
type TextDecoder struct {
|
|
Layers []TextDecoderLayer
|
|
}
|
|
|
|
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
|
|
for i, layer := range d.Layers {
|
|
layerType := selfAttentionLayer
|
|
if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
|
|
layerType = crossAttentionLayer
|
|
}
|
|
|
|
cache.SetLayer(i)
|
|
cache.SetLayerType(layerType)
|
|
|
|
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
|
|
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
|
|
}
|
|
}
|
|
|
|
return hiddenState
|
|
}
|
|
|
|
type TextModelOptions struct {
|
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
|
|
|
hiddenSize, numHeads, numKVHeads int
|
|
eps, ropeBase, ropeScale float32
|
|
ropeDim uint32
|
|
|
|
crossAttentionLayers []uint32
|
|
}
|
|
|
|
type TextModel struct {
|
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
|
Transformer *TextDecoder `gguf:"blk"`
|
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
|
Output *nn.Linear `gguf:"output"`
|
|
|
|
*TextModelOptions
|
|
}
|
|
|
|
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
|
|
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
|
|
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
|
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
|
return m.Output.Forward(ctx, hiddenState)
|
|
}
|
|
|
|
func newTextModel(c ml.Config) *TextModel {
|
|
var decoderLayers []TextDecoderLayer
|
|
for i := range c.Uint("block_count") {
|
|
var textDecoderLayer TextDecoderLayer
|
|
if slices.Contains(c.Uints("attention.cross_attention_layers"), i) {
|
|
textDecoderLayer = &TextCrossAttentionDecoderLayer{}
|
|
} else {
|
|
textDecoderLayer = &TextSelfAttentionDecoderLayer{}
|
|
}
|
|
|
|
decoderLayers = append(decoderLayers, textDecoderLayer)
|
|
}
|
|
|
|
return &TextModel{
|
|
Transformer: &TextDecoder{Layers: decoderLayers},
|
|
TextModelOptions: &TextModelOptions{
|
|
hiddenSize: int(c.Uint("embedding_length")),
|
|
numHeads: int(c.Uint("attention.head_count")),
|
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
|
ropeBase: c.Float("rope.freq_base"),
|
|
ropeScale: c.Float("rope.freq_scale", 1),
|
|
ropeDim: c.Uint("rope.dimension_count"),
|
|
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
|
|
},
|
|
}
|
|
}
|