Per-layer attention head counts for GGML metadata (fs/ggml, llm/memory).

Summary of the change
  * KV.HeadCounts and KV.HeadCountKVs return one value per block. Both call
    UintOrArrayAsArray(key, kv.BlockCount(), 1): an array value is converted
    with AsUint64Array and returned as stored; a uint32/int32 scalar, or the
    default when the key is missing or unusable, is repeated n times.
  * KV.EmbeddingHeadCount returns EmbeddingLength()/heads per block, with 0
    for a block whose head count is 0.
  * KV.FillArrayOrDefault repeats a scalar uint32/uint64/int32 value
    len(defaultValue) times, otherwise returns defaultValue unchanged.
    EmbeddingHeadCountK / EmbeddingHeadCountV are built on it.
  * KV.GQAMax replaces KV.GQA: the largest heads/headsKV ratio over all
    blocks, or 0 when the slices are empty, differ in length, or contain 0.
  * GGML.GraphSize sizes the KV cache per block and uses the per-model
    maximum (MaxValue) of each head count in the graph-size formulas.
  * New helpers: AsUint64Array, MaxValue, FillArray.

Review corrections folded into this revision (hunk line counts unchanged)
  1. gemma/gemma2/gemma3 and deepseek2 partialOffload: the last term had
     gained an extra "embedding*" factor compared with the line it replaces
     ("embedding*embeddingHeadsK*heads*9/16"); the extra factor is removed.
  2. SupportsFlashAttention: the non-zero check performed by the removed
     code (and promised by the comment above it) is restored.
  3. UintOrArrayAsArray: the "unexpected array value type" warning now logs
     the element type instead of the type of the array wrapper.
  4. MaxValue: the local variable no longer shadows the builtin max.

Notes for whoever applies this
  * Line breaks and indentation were rebuilt from the hunk headers and gofmt
    conventions; re-check context lines against the tree before applying.
  * The post-image hash on the first "index" line predates the corrections
    above.
  * Not changed here, worth a follow-up: UintOrArrayAsArray and
    AsUint64Array index element 0 without a length check, and GraphSize
    indexes the per-block slices by block number, so it presumably relies on
    metadata arrays having BlockCount() entries - TODO confirm.

diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index de4ed4d54..b6e7d6f41 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"reflect"
 	"slices"
 	"strings"
 
@@ -52,32 +53,80 @@ func (kv KV) EmbeddingLength() uint64 {
 	return uint64(kv.Uint("embedding_length"))
 }
 
-func (kv KV) HeadCount() uint64 {
-	return uint64(kv.UintOrFirstArrayValue("attention.head_count"))
+func (kv KV) HeadCounts() []uint64 {
+	return kv.UintOrArrayAsArray("attention.head_count", kv.BlockCount(), 1)
 }
 
-func (kv KV) HeadCountKV() uint64 {
-	return uint64(kv.UintOrFirstArrayValue("attention.head_count_kv", 1))
+func (kv KV) HeadCountKVs() []uint64 {
+	return kv.UintOrArrayAsArray("attention.head_count_kv", kv.BlockCount(), 1)
 }
 
-func (kv KV) EmbeddingHeadCount() uint64 {
-	if heads := kv.HeadCount(); heads > 0 {
-		return kv.EmbeddingLength() / heads
+func (kv KV) EmbeddingHeadCount() []uint64 {
+	headCount := kv.HeadCounts()
+	embeddingHeadCount := make([]uint64, len(headCount))
+	for i, heads := range headCount {
+		if heads == 0 {
+			embeddingHeadCount[i] = 0
+		} else {
+			embeddingHeadCount[i] = kv.EmbeddingLength() / heads
+		}
 	}
 
-	return 0
+	return embeddingHeadCount
 }
 
-func (kv KV) EmbeddingHeadCountK() uint64 {
-	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
+func (kv KV) FillArrayOrDefault(key string, defaultValue []uint64) []uint64 {
+	length := len(defaultValue)
+	if v, ok := keyValueUntyped(kv, key); ok {
+		switch v := v.(type) {
+		case uint32:
+			return FillArray(uint64(v), length)
+		case uint64:
+			return FillArray(v, length)
+		case int32:
+			return FillArray(uint64(v), length)
+		default:
+			slog.Warn("unsupported type", "key", key, "type", reflect.TypeOf(v))
+		}
+	}
+
+	return defaultValue
 }
 
-func (kv KV) EmbeddingHeadCountV() uint64 {
-	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
+func (kv KV) EmbeddingHeadCountK() []uint64 {
+	return kv.FillArrayOrDefault("attention.key_length", kv.EmbeddingHeadCount())
 }
 
-func (kv KV) GQA() uint64 {
-	return kv.HeadCount() / kv.HeadCountKV()
+func (kv KV) EmbeddingHeadCountV() []uint64 {
+	return kv.FillArrayOrDefault("attention.value_length", kv.EmbeddingHeadCount())
+}
+
+func (kv KV) GQAMax() uint64 {
+	heads := kv.HeadCounts()
+	headsKV := kv.HeadCountKVs()
+	if len(heads) != len(headsKV) {
+		slog.Warn("head count and head count kv are not the same length")
+		return 0
+	}
+	if len(heads) == 0 {
+		slog.Warn("head count is empty")
+		return 0
+	}
+
+	maxGQA := uint64(0)
+	for i := range heads {
+		head := heads[i]
+		headKV := headsKV[i]
+		if head == 0 || headKV == 0 {
+			return 0
+		}
+		gqa := head / headKV
+		if gqa > maxGQA {
+			maxGQA = gqa
+		}
+	}
+
+	return maxGQA
 }
 
 func (kv KV) ContextLength() uint64 {
@@ -104,20 +153,39 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
 	return keyValue(kv, key, append(defaultValue, false)...)
 }
 
-func (kv KV) UintOrFirstArrayValue(key string, defaultValue ...uint32) uint32 {
+func (kv KV) UintOrArrayAsArray(key string, n uint64, defaultSingleValue ...uint64) []uint64 {
+	var singleValue *uint64
 	if v, ok := keyValueUntyped(kv, key); ok {
-		if a, ok := v.(*array); ok {
-			signed := a.values[0].(int32)
-			if signed >= 0 {
-				return uint32(signed)
+		switch v := v.(type) {
+		case *array:
+			switch v.values[0].(type) {
+			case int32, uint32, uint64:
+				values, ok := AsUint64Array(v.values)
+				if ok {
+					return values
+				}
+			default:
+				slog.Warn("unexpected array value type", "key", key, "type", reflect.TypeOf(v.values[0]))
 			}
-			// TODO(drifkin): indicate unexpected data somehow?
-			return defaultValue[0]
-		} else if v, ok := v.(uint32); ok {
-			return v
+		case uint32:
+			val := uint64(v)
+			singleValue = &val
+		case int32:
+			val := uint64(v)
+			singleValue = &val
 		}
 	}
-	return defaultValue[0]
+	if singleValue == nil {
+		slog.Warn("falling back to default")
+		singleValue = &defaultSingleValue[0]
+	}
+
+	values := make([]uint64, n)
+	for i := range values {
+		values[i] = *singleValue
+	}
+
+	return values
 }
 
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@@ -442,12 +510,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 
 func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
 	embedding := f.KV().EmbeddingLength()
-	heads := f.KV().HeadCount()
-	headsKV := f.KV().HeadCountKV()
+	heads := f.KV().HeadCounts()
+	headsKV := f.KV().HeadCountKVs()
 	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
 
 	embeddingHeads := f.KV().EmbeddingHeadCount()
+	maxEmbeddingHeads, ok := MaxValue(embeddingHeads)
+	if !ok {
+		maxEmbeddingHeads = 1
+		slog.Warn("failed to get max embedding heads")
+	}
 	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
+	maxEmbeddingHeadsK, ok := MaxValue(embeddingHeadsK)
+	if !ok {
+		maxEmbeddingHeadsK = 1
+		slog.Warn("failed to get max embedding headsK")
+	}
 	embeddingHeadsV := f.KV().EmbeddingHeadCountV()
 
 	layers := f.Tensors().GroupLayers()
@@ -455,19 +533,30 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
 	kv = make([]uint64, f.KV().BlockCount())
 	for i := range kv {
-		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+		kv[i] = uint64(float64(context*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
+	}
+
+	maxHeads, ok := MaxValue(heads)
+	if !ok {
+		maxHeads = 1
+		slog.Warn("failed to get max heads")
+	}
+	maxHeadsKV, ok := MaxValue(headsKV)
+	if !ok {
+		maxHeadsKV = 1
+		slog.Warn("failed to get max headsKV")
 	}
 
 	switch f.KV().Architecture() {
 	case "llama":
 		fullOffload = max(
-			4*batch*(1+4*embedding+context*(1+heads)),
+			4*batch*(1+4*embedding+context*(1+maxHeads)),
 			4*batch*(embedding+vocab),
 		)
 
 		partialOffload = 4 * batch * embedding
 		partialOffload += max(
-			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
+			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*maxHeads+maxEmbeddingHeads*maxHeadsKV),
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 
@@ -475,16 +564,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 			// mixtral 8x22b
 			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
 			partialOffload = max(
-				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
-				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
+				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+maxHeadsKV+embedding+context+maxEmbeddingHeads*maxHeadsKV),
+				4*(context*batch*maxHeads+context*maxEmbeddingHeads*maxHeadsKV+batch*1024+maxEmbeddingHeads*maxHeadsKV*batch),
 			)
 		} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
 			// mixtral 8x7b
 			ffnGateWeight1 := ffnGateWeight.Shape[1]
-			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
+			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+maxHeads) + 2*maxHeadsKV + ffnGateWeight1)
 			partialOffload = max(
-				4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
-				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
+				4*batch*(3+maxEmbeddingHeads*maxHeadsKV+embedding+context*(1+maxHeads)+ffnGateWeight1)+(embedding*embedding+3*embedding*maxHeadsKV*ffnGateWeight1)*9/16,
+				4*batch*(1+2*embedding+context*(1+maxHeads))+embedding*(6*context*maxHeadsKV/maxHeads+embedding*9/16),
 			)
 		}
 	case "mllama":
@@ -493,7 +582,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 		crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
 		for i := range kv {
 			if slices.Contains(crossAttentionLayers, uint32(i)) {
-				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+				kv[i] = headsKV[i] * (embeddingHeadsK[i] + embeddingHeadsV[i]) *
 					4 * // sizeof(float32)
 					visionTokens *
 					tiles
@@ -501,7 +590,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 		}
 
 		fullOffload = max(
-			4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
+			4*batch*(2+3*embedding+maxEmbeddingHeadsK*maxHeads+context*(1+maxHeads)),
 			// vocab graph
 			4*batch*(embedding+vocab),
 		)
@@ -515,23 +604,23 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 
 		partialOffload = max(
 			4*(batch*
-				(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
+				(2*embedding+1+context*(1+maxHeads)+maxEmbeddingHeadsK*maxHeads)+
 				ropeFreqsCount+
-				embeddingHeadsK*context*headsKV),
+				maxEmbeddingHeadsK*context*maxHeadsKV),
 			// vocab graph
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
 	case "gemma", "gemma2", "gemma3":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
-			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
+			4*batch*(2+context+context*maxHeads+2*embedding+2*maxEmbeddingHeadsK*maxHeads),
 		)
 
 		partialOffload = max(
 			4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
-			4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
-				4*embeddingHeadsK*context*8+
-				embedding*embeddingHeadsK*heads*9/16,
+			4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeads+context+context*maxHeads)+
+				4*maxEmbeddingHeadsK*context*8+
+				embedding*maxEmbeddingHeadsK*maxHeads*9/16,
 		)
 
 		// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
@@ -543,42 +632,42 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				// Every 6th layer is a global layer, which is the full context size that has already been set. The other
 				// layers are the smaller local (sliding) layers.
 				if (i+1)%gemma3GlobalCacheCount != 0 {
-					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
 				}
 			}
 		}
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
-			4*batch*(2+4*embedding+context*(1+heads)),
+			4*batch*(2+4*embedding+context*(1+maxHeads)),
 		)
 
 		partialOffload = max(
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
+			4*batch*(1+2*embedding+context*(1+maxHeads))+4*embedding*context+embedding*embedding*9/16,
 		)
 	case "qwen2":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
-			4*batch*(1+2*embedding+context+context*heads),
+			4*batch*(1+2*embedding+context+context*maxHeads),
 		)
 
 		partialOffload = max(
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
-			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
+			4*(batch*(1+2*embedding+context*(1+maxHeads))+embedding*(1+context)),
 		)
 	case "phi2":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
-			4*batch*(1+4*embedding+context+context*heads),
+			4*batch*(1+4*embedding+context+context*maxHeads),
 		)
 
 		partialOffload = max(
 			4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(2+3*embedding+context+context*heads),
+			4*batch*(2+3*embedding+context+context*maxHeads),
 		)
 	case "stablelm":
-		fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
+		fullOffload = 4 * batch * (context*(1+maxHeads) + 3*embedding + 2)
 		partialOffload = max(
 			4*batch*(vocab+2*embedding),
 			fullOffload,
@@ -586,12 +675,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 	case "deepseek2":
 		fullOffload = max(
 			4*batch*(3*embedding+vocab),
-			4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
+			4*batch*(3*embedding+2+context*(1+maxHeadsKV)+2*maxEmbeddingHeadsK*maxHeadsKV),
 		)
 
 		partialOffload = max(
 			4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
-			4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
+			4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeadsKV+context+context*maxHeadsKV)+4*maxEmbeddingHeadsK*context*maxHeadsKV+embedding*maxEmbeddingHeadsK*maxHeadsKV*9/16,
 		)
 	case "chatglm":
 		fullOffload = 4 * batch * (embedding + vocab)
@@ -602,8 +691,8 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				4*batch*(2+
 					2*embedding+
 					context+
-					context*heads+
-					embeddingHeadsK*heads+
+					context*maxHeads+
+					maxEmbeddingHeadsK*maxHeads+
 					qkvBias.Shape[0]),
 			)
 
@@ -611,11 +700,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
 				partialOffload,
 				4*batch*(1+
 					2*embedding+
-					embeddingHeadsK*heads+
+					maxEmbeddingHeadsK*maxHeads+
 					context+
-					context*heads)+
-					4*embeddingHeadsK*context+
-					4*context*embeddingHeadsK+
+					context*maxHeads)+
+					4*maxEmbeddingHeadsK*context+
+					4*context*maxEmbeddingHeadsK+
 					4*qkvBias.Shape[0],
 			)
 		}
@@ -687,9 +776,15 @@ func (f GGML) SupportsFlashAttention() bool {
 	}
 
 	// Check head counts match and are non-zero
-	headCountK := f.KV().EmbeddingHeadCountK()
-	headCountV := f.KV().EmbeddingHeadCountV()
-	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
+	headCount := f.KV().HeadCounts()
+	embeddingHeadCountK := f.KV().EmbeddingHeadCountK()
+	embeddingHeadCountV := f.KV().EmbeddingHeadCountV()
+	for i := range headCount {
+		if embeddingHeadCountK[i] == 0 || embeddingHeadCountK[i] != embeddingHeadCountV[i] {
+			return false
+		}
+	}
+	return true
 }
 
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
@@ -703,3 +798,54 @@ func kvCacheBytesPerElement(cacheType string) float64 {
 		return 2 // f16 (default)
 	}
 }
+
+func AsUint64Array(v []any) ([]uint64, bool) {
+	switch v[0].(type) {
+	case uint32:
+		values := make([]uint64, len(v))
+		for i, v := range v {
+			values[i] = uint64(v.(uint32))
+		}
+		return values, true
+	case uint64:
+		values := make([]uint64, len(v))
+		for i, v := range v {
+			values[i] = v.(uint64)
+		}
+		return values, true
+	case int32:
+		values := make([]uint64, len(v))
+		for i, val := range v {
+			val := val.(int32)
+			if val < 0 {
+				slog.Warn("negative value in int32 array", "value", val)
+				return nil, false
+			}
+			values[i] = uint64(val)
+		}
+		return values, true
+	}
+	return nil, false
+}
+
+func MaxValue(values []uint64) (uint64, bool) {
+	if len(values) == 0 {
+		return 0, false
+	}
+
+	largest := values[0]
+	for _, v := range values {
+		if v > largest {
+			largest = v
+		}
+	}
+	return largest, true
+}
+
+func FillArray[T any](value T, n int) []T {
+	values := make([]T, n)
+	for i := range values {
+		values[i] = value
+	}
+	return values
+}
diff --git a/llm/memory.go b/llm/memory.go
index 85a0fabd3..e7db6eca1 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -149,7 +149,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
 	}
 
 	if graphPartialOffload == 0 {
-		graphPartialOffload = f.KV().GQA() * kvTotal / 6
+		graphPartialOffload = f.KV().GQAMax() * kvTotal / 6
 	}
 	if graphFullOffload == 0 {
 		graphFullOffload = graphPartialOffload