mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
ggml: more accurate estimates for head count array case
Also standardized the approach by always treatting `HeadCount()` and `HeadCountKV()` as arrays by filling them with the same value when they're a scalar in the original GGUF
This commit is contained in:
parent
0188c74c41
commit
7c94471d38
2 changed files with 208 additions and 62 deletions
268
fs/ggml/ggml.go
268
fs/ggml/ggml.go
|
@ -6,6 +6,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
|
@ -52,32 +53,80 @@ func (kv KV) EmbeddingLength() uint64 {
|
|||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.UintOrFirstArrayValue("attention.head_count"))
|
||||
func (kv KV) HeadCounts() []uint64 {
|
||||
return kv.UintOrArrayAsArray("attention.head_count", kv.BlockCount(), 1)
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.UintOrFirstArrayValue("attention.head_count_kv", 1))
|
||||
func (kv KV) HeadCountKVs() []uint64 {
|
||||
return kv.UintOrArrayAsArray("attention.head_count_kv", kv.BlockCount(), 1)
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
func (kv KV) EmbeddingHeadCount() []uint64 {
|
||||
headCount := kv.HeadCounts()
|
||||
embeddingHeadCount := make([]uint64, len(headCount))
|
||||
for i, heads := range headCount {
|
||||
if heads == 0 {
|
||||
embeddingHeadCount[i] = 0
|
||||
} else {
|
||||
embeddingHeadCount[i] = kv.EmbeddingLength() / heads
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
return embeddingHeadCount
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
func (kv KV) FillArrayOrDefault(key string, defaultValue []uint64) []uint64 {
|
||||
length := len(defaultValue)
|
||||
if v, ok := keyValueUntyped(kv, key); ok {
|
||||
switch v := v.(type) {
|
||||
case uint32:
|
||||
return FillArray(uint64(v), length)
|
||||
case uint64:
|
||||
return FillArray(v, length)
|
||||
case int32:
|
||||
return FillArray(uint64(v), length)
|
||||
default:
|
||||
slog.Warn("unsupported type", "key", key, "type", reflect.TypeOf(v))
|
||||
}
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
func (kv KV) EmbeddingHeadCountK() []uint64 {
|
||||
return kv.FillArrayOrDefault("attention.key_length", kv.EmbeddingHeadCount())
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
func (kv KV) EmbeddingHeadCountV() []uint64 {
|
||||
return kv.FillArrayOrDefault("attention.value_length", kv.EmbeddingHeadCount())
|
||||
}
|
||||
|
||||
func (kv KV) GQAMax() uint64 {
|
||||
heads := kv.HeadCounts()
|
||||
headsKV := kv.HeadCountKVs()
|
||||
if len(heads) != len(headsKV) {
|
||||
slog.Warn("head count and head count kv are not the same length")
|
||||
return 0
|
||||
}
|
||||
if len(heads) == 0 {
|
||||
slog.Warn("head count is empty")
|
||||
return 0
|
||||
}
|
||||
|
||||
maxGQA := uint64(0)
|
||||
for i := range heads {
|
||||
head := heads[i]
|
||||
headKV := headsKV[i]
|
||||
if head == 0 || headKV == 0 {
|
||||
return 0
|
||||
}
|
||||
gqa := head / headKV
|
||||
if gqa > maxGQA {
|
||||
maxGQA = gqa
|
||||
}
|
||||
}
|
||||
|
||||
return maxGQA
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
|
@ -104,20 +153,39 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
|||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
}
|
||||
|
||||
func (kv KV) UintOrFirstArrayValue(key string, defaultValue ...uint32) uint32 {
|
||||
func (kv KV) UintOrArrayAsArray(key string, n uint64, defaultSingleValue ...uint64) []uint64 {
|
||||
var singleValue *uint64
|
||||
if v, ok := keyValueUntyped(kv, key); ok {
|
||||
if a, ok := v.(*array); ok {
|
||||
signed := a.values[0].(int32)
|
||||
if signed >= 0 {
|
||||
return uint32(signed)
|
||||
switch v := v.(type) {
|
||||
case *array:
|
||||
switch v.values[0].(type) {
|
||||
case int32, uint32, uint64:
|
||||
values, ok := AsUint64Array(v.values)
|
||||
if ok {
|
||||
return values
|
||||
}
|
||||
default:
|
||||
slog.Warn("unexpected array value type", "key", key, "type", reflect.TypeOf(v))
|
||||
}
|
||||
// TODO(drifkin): indicate unexpected data somehow?
|
||||
return defaultValue[0]
|
||||
} else if v, ok := v.(uint32); ok {
|
||||
return v
|
||||
case uint32:
|
||||
val := uint64(v)
|
||||
singleValue = &val
|
||||
case int32:
|
||||
val := uint64(v)
|
||||
singleValue = &val
|
||||
}
|
||||
}
|
||||
return defaultValue[0]
|
||||
if singleValue == nil {
|
||||
slog.Warn("falling back to default")
|
||||
singleValue = &defaultSingleValue[0]
|
||||
}
|
||||
|
||||
values := make([]uint64, n)
|
||||
for i := range values {
|
||||
values[i] = *singleValue
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
|
@ -442,12 +510,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
|||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
heads := f.KV().HeadCounts()
|
||||
headsKV := f.KV().HeadCountKVs()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
maxEmbeddingHeads, ok := MaxValue(embeddingHeads)
|
||||
if !ok {
|
||||
maxEmbeddingHeads = 1
|
||||
slog.Warn("failed to get max embedding heads")
|
||||
}
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
maxEmbeddingHeadsK, ok := MaxValue(embeddingHeadsK)
|
||||
if !ok {
|
||||
maxEmbeddingHeadsK = 1
|
||||
slog.Warn("failed to get max embedding headsK")
|
||||
}
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
@ -455,19 +533,30 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||
}
|
||||
|
||||
maxHeads, ok := MaxValue(heads)
|
||||
if !ok {
|
||||
maxHeads = 1
|
||||
slog.Warn("failed to get max heads")
|
||||
}
|
||||
maxHeadsKV, ok := MaxValue(headsKV)
|
||||
if !ok {
|
||||
maxHeadsKV = 1
|
||||
slog.Warn("failed to get max headsKV")
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
fullOffload = max(
|
||||
4*batch*(1+4*embedding+context*(1+heads)),
|
||||
4*batch*(1+4*embedding+context*(1+maxHeads)),
|
||||
4*batch*(embedding+vocab),
|
||||
)
|
||||
|
||||
partialOffload = 4 * batch * embedding
|
||||
partialOffload += max(
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*maxHeads+maxEmbeddingHeads*maxHeadsKV),
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
|
||||
|
@ -475,16 +564,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
// mixtral 8x22b
|
||||
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||
partialOffload = max(
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+maxHeadsKV+embedding+context+maxEmbeddingHeads*maxHeadsKV),
|
||||
4*(context*batch*maxHeads+context*maxEmbeddingHeads*maxHeadsKV+batch*1024+maxEmbeddingHeads*maxHeadsKV*batch),
|
||||
)
|
||||
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
||||
// mixtral 8x7b
|
||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+maxHeads) + 2*maxHeadsKV + ffnGateWeight1)
|
||||
partialOffload = max(
|
||||
4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||
4*batch*(3+maxEmbeddingHeads*maxHeadsKV+embedding+context*(1+maxHeads)+ffnGateWeight1)+(embedding*embedding+3*embedding*maxHeadsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+maxHeads))+embedding*(6*context*maxHeadsKV/maxHeads+embedding*9/16),
|
||||
)
|
||||
}
|
||||
case "mllama":
|
||||
|
@ -493,7 +582,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
kv[i] = headsKV[i] * (embeddingHeadsK[i] + embeddingHeadsV[i]) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
|
@ -501,7 +590,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
}
|
||||
|
||||
fullOffload = max(
|
||||
4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
|
||||
4*batch*(2+3*embedding+maxEmbeddingHeadsK*maxHeads+context*(1+maxHeads)),
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab),
|
||||
)
|
||||
|
@ -515,23 +604,23 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
|
||||
partialOffload = max(
|
||||
4*(batch*
|
||||
(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
|
||||
(2*embedding+1+context*(1+maxHeads)+maxEmbeddingHeadsK*maxHeads)+
|
||||
ropeFreqsCount+
|
||||
embeddingHeadsK*context*headsKV),
|
||||
maxEmbeddingHeadsK*context*maxHeadsKV),
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
4*batch*(2+context+context*maxHeads+2*embedding+2*maxEmbeddingHeadsK*maxHeads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeads+context+context*maxHeads)+
|
||||
4*maxEmbeddingHeadsK*context*8+
|
||||
embedding*embedding*maxEmbeddingHeadsK*maxHeads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
|
@ -543,42 +632,42 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+4*embedding+context*(1+heads)),
|
||||
4*batch*(2+4*embedding+context*(1+maxHeads)),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+maxHeads))+4*embedding*context+embedding*embedding*9/16,
|
||||
)
|
||||
case "qwen2":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(1+2*embedding+context+context*heads),
|
||||
4*batch*(1+2*embedding+context+context*maxHeads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
|
||||
4*(batch*(1+2*embedding+context*(1+maxHeads))+embedding*(1+context)),
|
||||
)
|
||||
case "phi2":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(1+4*embedding+context+context*heads),
|
||||
4*batch*(1+4*embedding+context+context*maxHeads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2+3*embedding+context+context*heads),
|
||||
4*batch*(2+3*embedding+context+context*maxHeads),
|
||||
)
|
||||
case "stablelm":
|
||||
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
|
||||
fullOffload = 4 * batch * (context*(1+maxHeads) + 3*embedding + 2)
|
||||
partialOffload = max(
|
||||
4*batch*(vocab+2*embedding),
|
||||
fullOffload,
|
||||
|
@ -586,12 +675,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
case "deepseek2":
|
||||
fullOffload = max(
|
||||
4*batch*(3*embedding+vocab),
|
||||
4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
|
||||
4*batch*(3*embedding+2+context*(1+maxHeadsKV)+2*maxEmbeddingHeadsK*maxHeadsKV),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeadsKV+context+context*maxHeadsKV)+4*maxEmbeddingHeadsK*context*maxHeadsKV+embedding*embedding*maxEmbeddingHeadsK*maxHeadsKV*9/16,
|
||||
)
|
||||
case "chatglm":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
|
@ -602,8 +691,8 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
4*batch*(2+
|
||||
2*embedding+
|
||||
context+
|
||||
context*heads+
|
||||
embeddingHeadsK*heads+
|
||||
context*maxHeads+
|
||||
maxEmbeddingHeadsK*maxHeads+
|
||||
qkvBias.Shape[0]),
|
||||
)
|
||||
|
||||
|
@ -611,11 +700,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
partialOffload,
|
||||
4*batch*(1+
|
||||
2*embedding+
|
||||
embeddingHeadsK*heads+
|
||||
maxEmbeddingHeadsK*maxHeads+
|
||||
context+
|
||||
context*heads)+
|
||||
4*embeddingHeadsK*context+
|
||||
4*context*embeddingHeadsK+
|
||||
context*maxHeads)+
|
||||
4*maxEmbeddingHeadsK*context+
|
||||
4*context*maxEmbeddingHeadsK+
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
|
@ -687,9 +776,15 @@ func (f GGML) SupportsFlashAttention() bool {
|
|||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
headCount := f.KV().HeadCounts()
|
||||
embeddingHeadCountK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadCountV := f.KV().EmbeddingHeadCountV()
|
||||
for i := range headCount {
|
||||
if embeddingHeadCountK[i] != embeddingHeadCountV[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
|
@ -703,3 +798,54 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
|||
return 2 // f16 (default)
|
||||
}
|
||||
}
|
||||
|
||||
func AsUint64Array(v []any) ([]uint64, bool) {
|
||||
switch v[0].(type) {
|
||||
case uint32:
|
||||
values := make([]uint64, len(v))
|
||||
for i, v := range v {
|
||||
values[i] = uint64(v.(uint32))
|
||||
}
|
||||
return values, true
|
||||
case uint64:
|
||||
values := make([]uint64, len(v))
|
||||
for i, v := range v {
|
||||
values[i] = v.(uint64)
|
||||
}
|
||||
return values, true
|
||||
case int32:
|
||||
values := make([]uint64, len(v))
|
||||
for i, val := range v {
|
||||
val := val.(int32)
|
||||
if val < 0 {
|
||||
slog.Warn("negative value in int32 array", "value", val)
|
||||
return nil, false
|
||||
}
|
||||
values[i] = uint64(val)
|
||||
}
|
||||
return values, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func MaxValue(values []uint64) (uint64, bool) {
|
||||
if len(values) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
max := values[0]
|
||||
for _, v := range values {
|
||||
if v > max {
|
||||
max = v
|
||||
}
|
||||
}
|
||||
return max, true
|
||||
}
|
||||
|
||||
func FillArray[T any](value T, n int) []T {
|
||||
values := make([]T, n)
|
||||
for i := range values {
|
||||
values[i] = value
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
|
|
@ -149,7 +149,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
graphPartialOffload = f.KV().GQAMax() * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue