ggml: more accurate estimates for head count array case

Also standardized the approach by always treatting `HeadCount()` and
`HeadCountKV()` as arrays by filling them with the same value when
they're a scalar in the original GGUF
This commit is contained in:
Devon Rifkin 2025-04-10 16:28:34 -07:00
parent 0188c74c41
commit 7c94471d38
2 changed files with 208 additions and 62 deletions

View file

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"log/slog"
"reflect"
"slices"
"strings"
@ -52,32 +53,80 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCount() uint64 {
return uint64(kv.UintOrFirstArrayValue("attention.head_count"))
func (kv KV) HeadCounts() []uint64 {
return kv.UintOrArrayAsArray("attention.head_count", kv.BlockCount(), 1)
}
func (kv KV) HeadCountKV() uint64 {
return uint64(kv.UintOrFirstArrayValue("attention.head_count_kv", 1))
func (kv KV) HeadCountKVs() []uint64 {
return kv.UintOrArrayAsArray("attention.head_count_kv", kv.BlockCount(), 1)
}
func (kv KV) EmbeddingHeadCount() uint64 {
if heads := kv.HeadCount(); heads > 0 {
return kv.EmbeddingLength() / heads
func (kv KV) EmbeddingHeadCount() []uint64 {
headCount := kv.HeadCounts()
embeddingHeadCount := make([]uint64, len(headCount))
for i, heads := range headCount {
if heads == 0 {
embeddingHeadCount[i] = 0
} else {
embeddingHeadCount[i] = kv.EmbeddingLength() / heads
}
}
return 0
return embeddingHeadCount
}
func (kv KV) EmbeddingHeadCountK() uint64 {
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
func (kv KV) FillArrayOrDefault(key string, defaultValue []uint64) []uint64 {
length := len(defaultValue)
if v, ok := keyValueUntyped(kv, key); ok {
switch v := v.(type) {
case uint32:
return FillArray(uint64(v), length)
case uint64:
return FillArray(v, length)
case int32:
return FillArray(uint64(v), length)
default:
slog.Warn("unsupported type", "key", key, "type", reflect.TypeOf(v))
}
}
return defaultValue
}
func (kv KV) EmbeddingHeadCountV() uint64 {
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
func (kv KV) EmbeddingHeadCountK() []uint64 {
return kv.FillArrayOrDefault("attention.key_length", kv.EmbeddingHeadCount())
}
func (kv KV) GQA() uint64 {
return kv.HeadCount() / kv.HeadCountKV()
func (kv KV) EmbeddingHeadCountV() []uint64 {
return kv.FillArrayOrDefault("attention.value_length", kv.EmbeddingHeadCount())
}
func (kv KV) GQAMax() uint64 {
heads := kv.HeadCounts()
headsKV := kv.HeadCountKVs()
if len(heads) != len(headsKV) {
slog.Warn("head count and head count kv are not the same length")
return 0
}
if len(heads) == 0 {
slog.Warn("head count is empty")
return 0
}
maxGQA := uint64(0)
for i := range heads {
head := heads[i]
headKV := headsKV[i]
if head == 0 || headKV == 0 {
return 0
}
gqa := head / headKV
if gqa > maxGQA {
maxGQA = gqa
}
}
return maxGQA
}
func (kv KV) ContextLength() uint64 {
@ -104,20 +153,39 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) UintOrFirstArrayValue(key string, defaultValue ...uint32) uint32 {
func (kv KV) UintOrArrayAsArray(key string, n uint64, defaultSingleValue ...uint64) []uint64 {
var singleValue *uint64
if v, ok := keyValueUntyped(kv, key); ok {
if a, ok := v.(*array); ok {
signed := a.values[0].(int32)
if signed >= 0 {
return uint32(signed)
switch v := v.(type) {
case *array:
switch v.values[0].(type) {
case int32, uint32, uint64:
values, ok := AsUint64Array(v.values)
if ok {
return values
}
default:
slog.Warn("unexpected array value type", "key", key, "type", reflect.TypeOf(v))
}
// TODO(drifkin): indicate unexpected data somehow?
return defaultValue[0]
} else if v, ok := v.(uint32); ok {
return v
case uint32:
val := uint64(v)
singleValue = &val
case int32:
val := uint64(v)
singleValue = &val
}
}
return defaultValue[0]
if singleValue == nil {
slog.Warn("falling back to default")
singleValue = &defaultSingleValue[0]
}
values := make([]uint64, n)
for i := range values {
values[i] = *singleValue
}
return values
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@ -442,12 +510,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
heads := f.KV().HeadCounts()
headsKV := f.KV().HeadCountKVs()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
embeddingHeads := f.KV().EmbeddingHeadCount()
maxEmbeddingHeads, ok := MaxValue(embeddingHeads)
if !ok {
maxEmbeddingHeads = 1
slog.Warn("failed to get max embedding heads")
}
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
maxEmbeddingHeadsK, ok := MaxValue(embeddingHeadsK)
if !ok {
maxEmbeddingHeadsK = 1
slog.Warn("failed to get max embedding headsK")
}
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
layers := f.Tensors().GroupLayers()
@ -455,19 +533,30 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = make([]uint64, f.KV().BlockCount())
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
kv[i] = uint64(float64(context*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
}
maxHeads, ok := MaxValue(heads)
if !ok {
maxHeads = 1
slog.Warn("failed to get max heads")
}
maxHeadsKV, ok := MaxValue(headsKV)
if !ok {
maxHeadsKV = 1
slog.Warn("failed to get max headsKV")
}
switch f.KV().Architecture() {
case "llama":
fullOffload = max(
4*batch*(1+4*embedding+context*(1+heads)),
4*batch*(1+4*embedding+context*(1+maxHeads)),
4*batch*(embedding+vocab),
)
partialOffload = 4 * batch * embedding
partialOffload += max(
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*maxHeads+maxEmbeddingHeads*maxHeadsKV),
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
@ -475,16 +564,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// mixtral 8x22b
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
partialOffload = max(
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+maxHeadsKV+embedding+context+maxEmbeddingHeads*maxHeadsKV),
4*(context*batch*maxHeads+context*maxEmbeddingHeads*maxHeadsKV+batch*1024+maxEmbeddingHeads*maxHeadsKV*batch),
)
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
// mixtral 8x7b
ffnGateWeight1 := ffnGateWeight.Shape[1]
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+maxHeads) + 2*maxHeadsKV + ffnGateWeight1)
partialOffload = max(
4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
4*batch*(3+maxEmbeddingHeads*maxHeadsKV+embedding+context*(1+maxHeads)+ffnGateWeight1)+(embedding*embedding+3*embedding*maxHeadsKV*ffnGateWeight1)*9/16,
4*batch*(1+2*embedding+context*(1+maxHeads))+embedding*(6*context*maxHeadsKV/maxHeads+embedding*9/16),
)
}
case "mllama":
@ -493,7 +582,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, uint32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
kv[i] = headsKV[i] * (embeddingHeadsK[i] + embeddingHeadsV[i]) *
4 * // sizeof(float32)
visionTokens *
tiles
@ -501,7 +590,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
}
fullOffload = max(
4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
4*batch*(2+3*embedding+maxEmbeddingHeadsK*maxHeads+context*(1+maxHeads)),
// vocab graph
4*batch*(embedding+vocab),
)
@ -515,23 +604,23 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
partialOffload = max(
4*(batch*
(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
(2*embedding+1+context*(1+maxHeads)+maxEmbeddingHeadsK*maxHeads)+
ropeFreqsCount+
embeddingHeadsK*context*headsKV),
maxEmbeddingHeadsK*context*maxHeadsKV),
// vocab graph
4*batch*(embedding+vocab)+embedding*vocab*105/128,
)
case "gemma", "gemma2", "gemma3":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
4*batch*(2+context+context*maxHeads+2*embedding+2*maxEmbeddingHeadsK*maxHeads),
)
partialOffload = max(
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16,
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeads+context+context*maxHeads)+
4*maxEmbeddingHeadsK*context*8+
embedding*embedding*maxEmbeddingHeadsK*maxHeads*9/16,
)
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
@ -543,42 +632,42 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
// layers are the smaller local (sliding) layers.
if (i+1)%gemma3GlobalCacheCount != 0 {
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
}
}
}
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(2+4*embedding+context*(1+heads)),
4*batch*(2+4*embedding+context*(1+maxHeads)),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
4*batch*(1+2*embedding+context*(1+maxHeads))+4*embedding*context+embedding*embedding*9/16,
)
case "qwen2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+2*embedding+context+context*heads),
4*batch*(1+2*embedding+context+context*maxHeads),
)
partialOffload = max(
4*batch*(embedding+vocab)+embedding*vocab*105/128,
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
4*(batch*(1+2*embedding+context*(1+maxHeads))+embedding*(1+context)),
)
case "phi2":
fullOffload = max(
4*batch*(embedding+vocab),
4*batch*(1+4*embedding+context+context*heads),
4*batch*(1+4*embedding+context+context*maxHeads),
)
partialOffload = max(
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2+3*embedding+context+context*heads),
4*batch*(2+3*embedding+context+context*maxHeads),
)
case "stablelm":
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
fullOffload = 4 * batch * (context*(1+maxHeads) + 3*embedding + 2)
partialOffload = max(
4*batch*(vocab+2*embedding),
fullOffload,
@ -586,12 +675,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
case "deepseek2":
fullOffload = max(
4*batch*(3*embedding+vocab),
4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
4*batch*(3*embedding+2+context*(1+maxHeadsKV)+2*maxEmbeddingHeadsK*maxHeadsKV),
)
partialOffload = max(
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeadsKV+context+context*maxHeadsKV)+4*maxEmbeddingHeadsK*context*maxHeadsKV+embedding*embedding*maxEmbeddingHeadsK*maxHeadsKV*9/16,
)
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
@ -602,8 +691,8 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
context*maxHeads+
maxEmbeddingHeadsK*maxHeads+
qkvBias.Shape[0]),
)
@ -611,11 +700,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
maxEmbeddingHeadsK*maxHeads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
context*maxHeads)+
4*maxEmbeddingHeadsK*context+
4*context*maxEmbeddingHeadsK+
4*qkvBias.Shape[0],
)
}
@ -687,9 +776,15 @@ func (f GGML) SupportsFlashAttention() bool {
}
// Check head counts match and are non-zero
headCountK := f.KV().EmbeddingHeadCountK()
headCountV := f.KV().EmbeddingHeadCountV()
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
headCount := f.KV().HeadCounts()
embeddingHeadCountK := f.KV().EmbeddingHeadCountK()
embeddingHeadCountV := f.KV().EmbeddingHeadCountV()
for i := range headCount {
if embeddingHeadCountK[i] != embeddingHeadCountV[i] {
return false
}
}
return true
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
@ -703,3 +798,54 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 2 // f16 (default)
}
}
func AsUint64Array(v []any) ([]uint64, bool) {
switch v[0].(type) {
case uint32:
values := make([]uint64, len(v))
for i, v := range v {
values[i] = uint64(v.(uint32))
}
return values, true
case uint64:
values := make([]uint64, len(v))
for i, v := range v {
values[i] = v.(uint64)
}
return values, true
case int32:
values := make([]uint64, len(v))
for i, val := range v {
val := val.(int32)
if val < 0 {
slog.Warn("negative value in int32 array", "value", val)
return nil, false
}
values[i] = uint64(val)
}
return values, true
}
return nil, false
}
func MaxValue(values []uint64) (uint64, bool) {
if len(values) == 0 {
return 0, false
}
max := values[0]
for _, v := range values {
if v > max {
max = v
}
}
return max, true
}
func FillArray[T any](value T, n int) []T {
values := make([]T, n)
for i := range values {
values[i] = value
}
return values
}

View file

@ -149,7 +149,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
if graphPartialOffload == 0 {
graphPartialOffload = f.KV().GQA() * kvTotal / 6
graphPartialOffload = f.KV().GQAMax() * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload