attention: Remove unnecessary contiguous operations

Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the contiguity requirements of
mulmat, so those Contiguous calls can simply be removed.
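
As a rough sketch (assuming the ml.Tensor interface shown in the diff below, with
Permute, MulmatFullPrec and Scale; the helper name attentionScores is made up for
illustration), the query/key half of attention now looks like this — the permutes
return strided views and no copy is made:

    // Minimal sketch, not the exact code from this commit: query and key are
    // permuted into the layout mulmat expects, with no Contiguous in between.
    func attentionScores(ctx ml.Context, query, key ml.Tensor, scale float64) ml.Tensor {
        query = query.Permute(ctx, 0, 2, 1, 3)
        key = key.Permute(ctx, 0, 2, 1, 3)

        kq := key.MulmatFullPrec(ctx, query) // accepts the permuted (non-contiguous) views
        return kq.Scale(ctx, scale)
    }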

Value requires a different permutation and genuinely does need Contiguous.
However, we can let the copy into the cache perform this for us: the cache stores
values in the already-permuted layout, so the copy that Put performs anyway
produces the contiguous result without further overhead.
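
A minimal sketch of the idea (the destination view arithmetic from the real Put
is elided; cacheValues and the helper name putPermutedValue are stand-ins for
illustration):

    // Sketch only: with PermutedV, Put permutes the incoming values and lets the
    // copy into the cache produce the contiguous, permuted layout, so no separate
    // Contiguous call is needed.
    func putPermutedValue(ctx ml.Context, cacheValues, value ml.Tensor) {
        value = value.Permute(ctx, 1, 2, 0, 3)
        ctx.Forward(value.Copy(ctx, cacheValues))
    }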

To support this, and to avoid models seeing unexpected tensor shapes, we need
tighter integration between attention, the cache and the backend. Future
optimizations will likely need this structure as well; for example, flash
attention has special padding requirements for the cache, and other backends
may have their own needs.
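
For instance, the cache padding added in this commit rounds both the cache
capacity and the returned history range up to a backend-supplied multiple (the
GGML backend requests 32); a sketch of the arithmetic:

    // roundUp as defined in the causal cache below; e.g. roundUp(1000, 32) == 1024,
    // so a requested capacity of 1000 tokens is padded to 1024 cells.
    func roundUp(length, pad int) int {
        return ((length + pad - 1) / pad) * pad
    }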

This change also further encapsulates the operations that make up attention, so
that these and other optimizations can be handled transparently. Models that
have special requirements for attention can still implement their own version
of it.
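
For example, mllama's encoder cache opts out of the backend transformations and
then performs attention by hand; a sketch (the helper name is made up, the two
calls are taken from the diff below):

    // Request vanilla cache output by setting an empty config before Init runs.
    func newVanillaEncoderCache() *kvcache.EncoderCache {
        c := kvcache.NewEncoderCache()
        c.SetConfig(ml.CacheConfig{})
        return c
    }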
Jesse Gross 2025-02-22 21:34:10 -08:00 committed by Jesse Gross
parent 96a97adf9b
commit 854a9195f3
10 changed files with 270 additions and 86 deletions

View file

@@ -29,6 +29,17 @@ type Cache interface {
 	// cache implementation used.
 	Put(ctx ml.Context, key, value ml.Tensor)
 
+	// SetConfig controls optimizations (mostly backend-specific) that may transform
+	// the output of the cache to work better with specific kernels. If not called,
+	// the backend settings will be used. This works well when calling Attention.
+	//
+	// The config can be overridden by models, especially if they require vanilla
+	// output when implementing their own version of attention. To do this, pass
+	// an empty ml.CacheConfig.
+	//
+	// Most models will not need to use this.
+	SetConfig(ml.CacheConfig)
+
 	// ** cache management **
 
 	// Init sets up runtime parameters

View file

@@ -22,6 +22,9 @@ type Causal struct {
 	Capacity   int32
 	windowSize int32
 
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -75,14 +78,34 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
 }
 
 func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding == 0 {
+		c.config.CachePadding = 1
+	}
+
 	c.DType = dtype
-	c.Capacity = capacity
-	c.cells = make([]cacheCell, capacity)
+	c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
+	c.cells = make([]cacheCell, c.Capacity)
 	c.cellRanges = make(map[int]cellRange)
 	c.backend = backend
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *Causal) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *Causal) Close() {
 	c.cacheCtx.Close()
 }
@@ -157,36 +180,73 @@ func (c *Causal) findStartLoc() (int, error) {
 	return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
 }
 
+func roundDown(length, pad int) int {
+	return (length / pad) * pad
+}
+
+func roundUp(length, pad int) int {
+	return ((length + pad - 1) / pad) * pad
+}
+
 // Builds a mask of history x batch indicating whether for each token in the batch the
 // token in the history should apply. This is based on both the sequence and causality (the
 // position of the history is not ahead of the token in the batch).
 func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
-	// TODO(jessegross): This does not do padding, which is required for flash attention
-	len := c.curCellRange.max - c.curCellRange.min + 1
-	mask := make([]float32, c.curBatchSize*len)
+	// TODO(jessegross): This does not do mask padding, which is required for flash attention
+
+	// Align and pad the cache range as required by the backend
+	c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
+	c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
+
+	length := c.curCellRange.max - c.curCellRange.min + 1
+	mask := make([]float32, c.curBatchSize*length)
 
 	for i := range c.curBatchSize {
 		for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
 			if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
 				c.cells[j].pos < positions[i]-c.windowSize {
-				mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
+				mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
 			}
 		}
 	}
 
-	return ctx.FromFloatSlice(mask, len, c.curBatchSize)
+	return ctx.FromFloatSlice(mask, length, c.curBatchSize)
 }
 
-func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
-	for _, obj := range objs {
-		if obj == nil {
+func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
+	for i := range c.keys {
+		if c.keys[i] == nil {
 			continue
 		}
 
-		srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
-		dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
-
-		ctx.Forward(srcView.Copy(ctx, dstView))
+		key := c.keys[i]
+
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
+		kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
+
+		value := c.values[i]
+		var vSrcView, vDstView ml.Tensor
+		if c.config.PermutedV {
+			vHeadDim := value.Dim(1)
+			elemSize := value.Stride(0)
+
+			vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+			vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
+		} else {
+			vHeadDim := value.Dim(0)
+			rowSize := value.Stride(2)
+
+			vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
+			vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
+		}
+
+		ctx.Forward(
+			kSrcView.Copy(ctx, kDstView),
+			vSrcView.Copy(ctx, vDstView),
+		)
 	}
 }
@@ -238,8 +298,7 @@ func (c *Causal) defrag() {
 				pendingLen++
 				break
 			} else {
-				moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-				moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+				c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 				moves++
 			}
 		}
@@ -263,8 +322,7 @@ func (c *Causal) defrag() {
 	}
 
 	if pendingLen > 0 {
-		moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
-		moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
+		c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
 		moves++
 	}
@@ -305,35 +363,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
 	key := c.keys[c.curLayer]
 	value := c.values[c.curLayer]
 
-	key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
-		key.Dim(0), key.Stride(1),
-		key.Dim(1), key.Stride(2),
-		c.curMask.Dim(0),
+	kHeadDim := key.Dim(0)
+	numKVHeads := key.Dim(1)
+	rowSize := key.Stride(2)
+	cachedSize := c.curMask.Dim(0)
+
+	key = key.View(ctx, rowSize*c.curCellRange.min,
+		kHeadDim, key.Stride(1),
+		numKVHeads, key.Stride(2),
+		cachedSize,
 	)
 
-	value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
-		value.Dim(0), value.Stride(1),
-		value.Dim(1), value.Stride(2),
-		c.curMask.Dim(0),
-	)
+	if c.config.PermutedV {
+		vHeadDim := value.Dim(1)
+		elemSize := value.Stride(0)
+
+		value = value.View(ctx, elemSize*c.curCellRange.min,
+			cachedSize, value.Stride(1),
+			vHeadDim, value.Stride(2),
+			numKVHeads,
+		)
+	} else {
+		vHeadDim := value.Dim(0)
+		rowSize := value.Stride(2)
+
+		value = value.View(ctx, rowSize*c.curCellRange.min,
+			vHeadDim, value.Stride(1),
+			numKVHeads, value.Stride(2),
+			cachedSize,
+		)
+	}
 
 	return key, value, c.curMask
 }
 
 func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-	if c.curBatchSize != key.Dim(2) {
-		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
+	kHeadDim := key.Dim(0)
+	vHeadDim := value.Dim(0)
+	numKVHeads := key.Dim(1)
+	batchSize := key.Dim(2)
+
+	if c.curBatchSize != batchSize {
+		panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
 	}
 
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
-		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
-		c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
+		c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
+
+		if c.config.PermutedV {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
+		} else {
+			c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
+		}
 	}
 
-	ctx.Forward(
-		key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
-		value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
-	)
+	rowSize := c.keys[c.curLayer].Stride(2)
+	ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
+
+	if c.config.PermutedV {
+		elemSize := c.values[c.curLayer].Stride(0)
+
+		value = value.Permute(ctx, 1, 2, 0, 3)
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
+	} else {
+		rowSize := c.values[c.curLayer].Stride(2)
+
+		ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
+	}
 }
 
 func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -389,9 +485,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
 			continue
 		}
 
-		key = key.View(ctx, key.Stride(2)*seqRange.min,
-			key.Dim(0), key.Stride(1),
-			key.Dim(1), key.Stride(2),
+		kHeadDim := key.Dim(0)
+		numKVHeads := key.Dim(1)
+		rowSize := key.Stride(2)
+
+		key = key.View(ctx, rowSize*seqRange.min,
+			kHeadDim, key.Stride(1),
+			numKVHeads, key.Stride(2),
 			size,
 		)

View file

@@ -1,6 +1,8 @@
 package kvcache
 
 import (
+	"fmt"
+
 	"github.com/ollama/ollama/ml"
 )
@@ -11,6 +13,9 @@ import (
 //
 // Not currently safe for multiple sequences
 type EncoderCache struct {
+	// config controls mostly backend-specific optimizations
+	config *ml.CacheConfig
+
 	// ** current forward pass **
 
 	// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
 }
 
 func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
+	if c.config == nil {
+		var config ml.CacheConfig
+		if cc, ok := backend.(ml.BackendCacheConfig); ok {
+			config = cc.CacheConfig()
+		}
+		c.config = &config
+	}
+
+	if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
+		panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
+	}
+
 	c.cacheCtx = backend.NewContext()
 }
 
+func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
+	if c.config != nil {
+		panic("config cannot be changed after being previously set, either by the model or backend")
+	}
+
+	c.config = &config
+}
+
 func (c *EncoderCache) Close() {
 	c.cacheCtx.Close()
 }
@@ -75,6 +100,10 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
 	c.encoderPos = c.curPos
 	c.encoderCached = true
 
+	if c.config.PermutedV {
+		value = value.Permute(ctx, 1, 2, 0, 3)
+	}
+
 	if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
 		c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
 		c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)

View file

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
 	}
 }
 
+func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
+	for _, cache := range c.caches {
+		cache.SetConfig(config)
+	}
+}
+
 func (c *WrapperCache) Close() {
 	for _, cache := range c.caches {
 		cache.Close()

View file

@@ -27,6 +27,27 @@ type Backend interface {
 	SystemInfo() string
 }
 
+// BackendCacheConfig should be implemented by backends that need special output
+// from the cache to meet specific requirements. It is frequently implemented in
+// conjunction with ScaledDotProductAttention.
+type BackendCacheConfig interface {
+	CacheConfig() CacheConfig
+}
+
+// CacheConfig controls optimizations (mostly backend-specific) that may transform
+// the output of the cache to work better with specific kernels.
+type CacheConfig struct {
+	// CachePadding specifies the multiple for the number of tokens of cache history
+	// that will be returned from cache Get for k, v and mask. The capacity of the
+	// cache itself will also be increased to a multiple of this size if needed.
+	CachePadding int
+
+	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
+	// and returns the permuted version via Get. This uses the cache copy operation
+	// to avoid a Contiguous call on the permuted tensor.
+	PermutedV bool
+}
+
 // BackendParams controls how the backend loads and executes models
 type BackendParams struct {
 	// NumThreads sets the number of threads to use if running on the CPU
@@ -116,6 +137,10 @@ type Tensor interface {
 	// operation equivalent to following code on a tensor named
 	// query:
 	//
+	//	query = query.Permute(ctx, 0, 2, 1, 3)
+	//	key = key.Permute(ctx, 0, 2, 1, 3)
+	//	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	//
 	// kq := key.MulmatFullPrec(ctx, query)
 	//
 	// kq = kq.Scale(ctx, scale)

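As an illustration of the BackendCacheConfig contract above, a hypothetical
backend (not part of this commit) could advertise different settings; the GGML
backend's real values appear in the next file:

    // Hypothetical example only: a backend with no padding requirement that
    // keeps values unpermuted would return the config below.
    func (b *exampleBackend) CacheConfig() ml.CacheConfig {
        return ml.CacheConfig{CachePadding: 1, PermutedV: false}
    }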
View file

@@ -247,6 +247,10 @@ func (b *Backend) NewContext() ml.Context {
 	}
 }
 
+func (b *Backend) CacheConfig() ml.CacheConfig {
+	return ml.CacheConfig{CachePadding: 32, PermutedV: true}
+}
+
 type Context struct {
 	b   *Backend
 	ctx *C.struct_ggml_context
@@ -661,7 +665,10 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
 		kqMask = mask.(*Tensor).t
 	}
 
-	kq := key.MulmatFullPrec(ctx, t)
+	query := t.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+
+	kq := key.MulmatFullPrec(ctx, query)
 
 	kq = &Tensor{
 		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
 	}

View file

@@ -3,6 +3,7 @@ package nn
 import (
 	"fmt"
 
+	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 )
@@ -11,40 +12,50 @@ import (
 //
 // Parameters:
 //   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-//   - mask: Optional attention mask that is added to the attention score. If
-//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
+//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
+//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
+//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
 //   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
+//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
 //
 // Returns:
 //
 //	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-	if query.Dim(0) != key.Dim(0) {
-		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
-	}
+func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	if key != nil && value != nil {
+		if query.Dim(0) != key.Dim(0) {
+			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
+		}
 
-	if mask != nil && query.Dim(1) != mask.Dim(1) {
-		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
-	}
+		if key.Dim(1) != value.Dim(1) {
+			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
+		}
 
-	if key.Dim(1) != value.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
-	}
+		if key.Dim(2) != value.Dim(2) {
+			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
+		}
 
-	if mask != nil && key.Dim(1) != mask.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
-	}
+		if cache != nil {
+			cache.Put(ctx, key, value)
+		}
+	} else if cache == nil {
+		panic("key & value tensors must be provided if cache is nil")
+	}
 
-	if key.Dim(2) != value.Dim(2) {
-		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
-	}
+	var mask ml.Tensor
+	if cache != nil {
+		key, value, mask = cache.Get(ctx)
+	}
 
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
+	// Only use the fast SDPA implementation if we have a cache, since that's what
+	// will do any expected backend-specific transformations for us
+	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
 		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
 	} else {
+		query = query.Permute(ctx, 0, 2, 1, 3)
+		key = key.Permute(ctx, 0, 2, 1, 3)
+		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
 		kq := key.MulmatFullPrec(ctx, query)
 
 		kq = kq.Scale(ctx, scale)

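The call-site effect of the new signature is visible in the model changes that
follow; a model now does roughly this (lines taken from the llama diff below)
instead of calling cache.Put/Get and permuting q, k and v itself:

    scaleFactor := 1.0 / math.Sqrt(float64(headDim))
    kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
    kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)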
View file

@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	v := sa.Value.Forward(ctx, hiddenState)
 	v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, k, v)
-	k, v, mask := cache.Get(ctx)
-
-	q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+	kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, kqv)

View file

@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
 		TextModel: newTextModel(c),
 	}
 
-	m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
+	encoderCache := kvcache.NewEncoderCache()
+	encoderCache.SetConfig(ml.CacheConfig{})
+	m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
 
 	return &m, nil
 }

View file

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
 
-	cache.Put(ctx, key, value)
-	key, value, mask := cache.Get(ctx)
-
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+	attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return sa.Output.Forward(ctx, attention)
 }
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-	// This will only get called for layers in the cache, which are just the self attention layers
+	// This will only get called for layers in the causal cache, which are just the self attention layers
 	return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
 }
 
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 	query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
 	query = ca.QueryNorm.Forward(ctx, query, opts.eps)
 
-	var key, value, mask ml.Tensor
+	var key, value ml.Tensor
 	if crossAttentionStates != nil {
 		numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
 
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
 		value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
 
 		cache.Put(ctx, key, value)
-	} else {
-		key, value, mask = cache.Get(ctx)
 	}
 
-	query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+	key, value, _ = cache.Get(ctx)
 
 	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
+
+	query = query.Permute(ctx, 0, 2, 1, 3)
+	key = key.Permute(ctx, 0, 2, 1, 3)
+	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+	kq := key.MulmatFullPrec(ctx, query)
+
+	kq = kq.Scale(ctx, scaleFactor)
+	kq = kq.Softmax(ctx)
+
+	kqv := value.Mulmat(ctx, kq)
+	attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
 
 	return ca.Output.Forward(ctx, attention)