diff --git a/convert/convert_llama4.go b/convert/convert_llama4.go index 14463a179..9aa0382ff 100644 --- a/convert/convert_llama4.go +++ b/convert/convert_llama4.go @@ -19,6 +19,7 @@ type llama4Model struct { InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"` UseQKNorm bool `json:"use_qk_norm"` IntermediateSizeMLP uint32 `json:"intermediate_size_mlp"` + AttentionChunkSize uint32 `json:"attention_chunk_size"` } `json:"text_config"` VisionModel struct { NumHiddenLayers uint32 `json:"num_hidden_layers"` @@ -51,6 +52,7 @@ func (p *llama4Model) KV(t *Tokenizer) ggml.KV { kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm + kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize diff --git a/kvcache/causal.go b/kvcache/causal.go index 466722845..ea07932cd 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -21,6 +21,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e type Causal struct { DType ml.DType windowSize int32 + chunkSize int32 opts CausalOptions @@ -97,6 +98,17 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal { } } +func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal { + return &Causal{ + windowSize: math.MaxInt32, + chunkSize: chunkSize, + shiftFn: shift, + ctxs: make(map[int]ml.Context), + keys: make(map[int]ml.Tensor), + values: make(map[int]ml.Tensor), + } +} + func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) { if c.config == nil { var config ml.CacheConfig @@ -300,6 +312,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { for j := c.curCellRange.min; j <= c.curCellRange.max; j++ { if !slices.Contains(c.cells[j].sequences, 
c.curSequences[i]) || (enabled && c.cells[j].pos > c.curPositions[i]) || + c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize || c.cells[j].pos < c.curPositions[i]-c.windowSize { mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1)) } diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index af01bb6f1..796987088 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -86,6 +86,64 @@ func TestSWA(t *testing.T) { testCache(t, backend, cache, tests) } +func TestChunkedAttention(t *testing.T) { + cache := NewChunkedAttentionCache(2, nil) + defer cache.Close() + + var b testBackend + cache.Init(&b, ml.DTypeF16, 1, 16, 16) + + x := float32(math.Inf(-1)) + + testCache( + t, &b, cache, + []testCase{ + { + name: "FirstBatch", + in: []float32{1, 2, 3, 4}, + inShape: []int{1, 1, 4}, + seqs: []int{0, 0, 0, 0}, + pos: []int32{0, 1, 2, 3}, + expected: []float32{1, 2, 3, 4}, + expectedShape: []int{1, 1, 4}, + expectedMask: []float32{ + 0, x, x, x, + 0, 0, x, x, + x, x, 0, x, + x, x, 0, 0, + }, + }, + { + name: "SecondBatch", + in: []float32{5, 6, 7}, + inShape: []int{1, 1, 3}, + seqs: []int{0, 0, 0}, + pos: []int32{4, 5, 6}, + expected: []float32{1, 2, 3, 4, 5, 6, 7}, + expectedShape: []int{1, 1, 7}, + expectedMask: []float32{ + x, x, x, x, 0, x, x, + x, x, x, x, 0, 0, x, + x, x, x, x, x, x, 0, + }, + }, + { + name: "ThirdBatch", + in: []float32{8, 9}, + inShape: []int{1, 1, 2}, + seqs: []int{0, 0}, + pos: []int32{7, 8}, + expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, + expectedShape: []int{1, 1, 9}, + expectedMask: []float32{ + x, x, x, x, x, x, 0, 0, x, + x, x, x, x, x, x, x, x, 0, + }, + }, + }, + ) +} + func TestSequences(t *testing.T) { backend := &testBackend{} cache := NewCausalCache(nil) @@ -293,8 +351,16 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context.Forward(out, mask).Compute(out, mask) - if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), 
test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) { - t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask) + if !slices.Equal(out.Floats(), test.expected) { + t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected) + } + + if !slices.Equal(out.Shape(), test.expectedShape) { + t.Errorf("TestCache: have shape %v; want %v", out.Shape(), test.expectedShape) + } + + if !slices.Equal(mask.Floats(), test.expectedMask) { + t.Errorf("TestCache: have mask %v; want %v", mask.Floats(), test.expectedMask) + } }) } diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go index d3ed45ead..8f80c1dd4 100644 --- a/model/models/llama4/model.go +++ b/model/models/llama4/model.go @@ -52,8 +52,7 @@ func New(c fs.Config) (model.Model, error) { } m.Cache = kvcache.NewWrapperCache( - // TODO: pretend this is chunked attention for now - kvcache.NewSWACache(8192, m.Shift), + kvcache.NewChunkedAttentionCache(int32(c.Uint("attention.chunk_size")), m.Shift), kvcache.NewCausalCache(m.Shift), )