mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 10:26:53 +02:00
ggml: more accurate estimates for head count array case
Also standardized the approach by always treatting `HeadCount()` and `HeadCountKV()` as arrays by filling them with the same value when they're a scalar in the original GGUF
This commit is contained in:
parent
0188c74c41
commit
7c94471d38
2 changed files with 208 additions and 62 deletions
268
fs/ggml/ggml.go
268
fs/ggml/ggml.go
|
@ -6,6 +6,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"reflect"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -52,32 +53,80 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||||
return uint64(kv.Uint("embedding_length"))
|
return uint64(kv.Uint("embedding_length"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCount() uint64 {
|
func (kv KV) HeadCounts() []uint64 {
|
||||||
return uint64(kv.UintOrFirstArrayValue("attention.head_count"))
|
return kv.UintOrArrayAsArray("attention.head_count", kv.BlockCount(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) HeadCountKV() uint64 {
|
func (kv KV) HeadCountKVs() []uint64 {
|
||||||
return uint64(kv.UintOrFirstArrayValue("attention.head_count_kv", 1))
|
return kv.UintOrArrayAsArray("attention.head_count_kv", kv.BlockCount(), 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
func (kv KV) EmbeddingHeadCount() []uint64 {
|
||||||
if heads := kv.HeadCount(); heads > 0 {
|
headCount := kv.HeadCounts()
|
||||||
return kv.EmbeddingLength() / heads
|
embeddingHeadCount := make([]uint64, len(headCount))
|
||||||
|
for i, heads := range headCount {
|
||||||
|
if heads == 0 {
|
||||||
|
embeddingHeadCount[i] = 0
|
||||||
|
} else {
|
||||||
|
embeddingHeadCount[i] = kv.EmbeddingLength() / heads
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return embeddingHeadCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
func (kv KV) FillArrayOrDefault(key string, defaultValue []uint64) []uint64 {
|
||||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
length := len(defaultValue)
|
||||||
|
if v, ok := keyValueUntyped(kv, key); ok {
|
||||||
|
switch v := v.(type) {
|
||||||
|
case uint32:
|
||||||
|
return FillArray(uint64(v), length)
|
||||||
|
case uint64:
|
||||||
|
return FillArray(v, length)
|
||||||
|
case int32:
|
||||||
|
return FillArray(uint64(v), length)
|
||||||
|
default:
|
||||||
|
slog.Warn("unsupported type", "key", key, "type", reflect.TypeOf(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return defaultValue
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
func (kv KV) EmbeddingHeadCountK() []uint64 {
|
||||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
return kv.FillArrayOrDefault("attention.key_length", kv.EmbeddingHeadCount())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) GQA() uint64 {
|
func (kv KV) EmbeddingHeadCountV() []uint64 {
|
||||||
return kv.HeadCount() / kv.HeadCountKV()
|
return kv.FillArrayOrDefault("attention.value_length", kv.EmbeddingHeadCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kv KV) GQAMax() uint64 {
|
||||||
|
heads := kv.HeadCounts()
|
||||||
|
headsKV := kv.HeadCountKVs()
|
||||||
|
if len(heads) != len(headsKV) {
|
||||||
|
slog.Warn("head count and head count kv are not the same length")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if len(heads) == 0 {
|
||||||
|
slog.Warn("head count is empty")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
maxGQA := uint64(0)
|
||||||
|
for i := range heads {
|
||||||
|
head := heads[i]
|
||||||
|
headKV := headsKV[i]
|
||||||
|
if head == 0 || headKV == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
gqa := head / headKV
|
||||||
|
if gqa > maxGQA {
|
||||||
|
maxGQA = gqa
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxGQA
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) ContextLength() uint64 {
|
func (kv KV) ContextLength() uint64 {
|
||||||
|
@ -104,20 +153,39 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||||
return keyValue(kv, key, append(defaultValue, false)...)
|
return keyValue(kv, key, append(defaultValue, false)...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) UintOrFirstArrayValue(key string, defaultValue ...uint32) uint32 {
|
func (kv KV) UintOrArrayAsArray(key string, n uint64, defaultSingleValue ...uint64) []uint64 {
|
||||||
|
var singleValue *uint64
|
||||||
if v, ok := keyValueUntyped(kv, key); ok {
|
if v, ok := keyValueUntyped(kv, key); ok {
|
||||||
if a, ok := v.(*array); ok {
|
switch v := v.(type) {
|
||||||
signed := a.values[0].(int32)
|
case *array:
|
||||||
if signed >= 0 {
|
switch v.values[0].(type) {
|
||||||
return uint32(signed)
|
case int32, uint32, uint64:
|
||||||
|
values, ok := AsUint64Array(v.values)
|
||||||
|
if ok {
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
slog.Warn("unexpected array value type", "key", key, "type", reflect.TypeOf(v))
|
||||||
}
|
}
|
||||||
// TODO(drifkin): indicate unexpected data somehow?
|
case uint32:
|
||||||
return defaultValue[0]
|
val := uint64(v)
|
||||||
} else if v, ok := v.(uint32); ok {
|
singleValue = &val
|
||||||
return v
|
case int32:
|
||||||
|
val := uint64(v)
|
||||||
|
singleValue = &val
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return defaultValue[0]
|
if singleValue == nil {
|
||||||
|
slog.Warn("falling back to default")
|
||||||
|
singleValue = &defaultSingleValue[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
values := make([]uint64, n)
|
||||||
|
for i := range values {
|
||||||
|
values[i] = *singleValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return values
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||||
|
@ -442,12 +510,22 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||||
|
|
||||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||||
embedding := f.KV().EmbeddingLength()
|
embedding := f.KV().EmbeddingLength()
|
||||||
heads := f.KV().HeadCount()
|
heads := f.KV().HeadCounts()
|
||||||
headsKV := f.KV().HeadCountKV()
|
headsKV := f.KV().HeadCountKVs()
|
||||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||||
|
|
||||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||||
|
maxEmbeddingHeads, ok := MaxValue(embeddingHeads)
|
||||||
|
if !ok {
|
||||||
|
maxEmbeddingHeads = 1
|
||||||
|
slog.Warn("failed to get max embedding heads")
|
||||||
|
}
|
||||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||||
|
maxEmbeddingHeadsK, ok := MaxValue(embeddingHeadsK)
|
||||||
|
if !ok {
|
||||||
|
maxEmbeddingHeadsK = 1
|
||||||
|
slog.Warn("failed to get max embedding headsK")
|
||||||
|
}
|
||||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||||
|
|
||||||
layers := f.Tensors().GroupLayers()
|
layers := f.Tensors().GroupLayers()
|
||||||
|
@ -455,19 +533,30 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||||
kv = make([]uint64, f.KV().BlockCount())
|
kv = make([]uint64, f.KV().BlockCount())
|
||||||
for i := range kv {
|
for i := range kv {
|
||||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
kv[i] = uint64(float64(context*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxHeads, ok := MaxValue(heads)
|
||||||
|
if !ok {
|
||||||
|
maxHeads = 1
|
||||||
|
slog.Warn("failed to get max heads")
|
||||||
|
}
|
||||||
|
maxHeadsKV, ok := MaxValue(headsKV)
|
||||||
|
if !ok {
|
||||||
|
maxHeadsKV = 1
|
||||||
|
slog.Warn("failed to get max headsKV")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch f.KV().Architecture() {
|
switch f.KV().Architecture() {
|
||||||
case "llama":
|
case "llama":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(1+4*embedding+context*(1+heads)),
|
4*batch*(1+4*embedding+context*(1+maxHeads)),
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = 4 * batch * embedding
|
partialOffload = 4 * batch * embedding
|
||||||
partialOffload += max(
|
partialOffload += max(
|
||||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
|
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*maxHeads+maxEmbeddingHeads*maxHeadsKV),
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -475,16 +564,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
// mixtral 8x22b
|
// mixtral 8x22b
|
||||||
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+maxHeadsKV+embedding+context+maxEmbeddingHeads*maxHeadsKV),
|
||||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
4*(context*batch*maxHeads+context*maxEmbeddingHeads*maxHeadsKV+batch*1024+maxEmbeddingHeads*maxHeadsKV*batch),
|
||||||
)
|
)
|
||||||
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
||||||
// mixtral 8x7b
|
// mixtral 8x7b
|
||||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+maxHeads) + 2*maxHeadsKV + ffnGateWeight1)
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
4*batch*(3+maxEmbeddingHeads*maxHeadsKV+embedding+context*(1+maxHeads)+ffnGateWeight1)+(embedding*embedding+3*embedding*maxHeadsKV*ffnGateWeight1)*9/16,
|
||||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
4*batch*(1+2*embedding+context*(1+maxHeads))+embedding*(6*context*maxHeadsKV/maxHeads+embedding*9/16),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
case "mllama":
|
case "mllama":
|
||||||
|
@ -493,7 +582,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||||
for i := range kv {
|
for i := range kv {
|
||||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
kv[i] = headsKV[i] * (embeddingHeadsK[i] + embeddingHeadsV[i]) *
|
||||||
4 * // sizeof(float32)
|
4 * // sizeof(float32)
|
||||||
visionTokens *
|
visionTokens *
|
||||||
tiles
|
tiles
|
||||||
|
@ -501,7 +590,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
}
|
}
|
||||||
|
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
|
4*batch*(2+3*embedding+maxEmbeddingHeadsK*maxHeads+context*(1+maxHeads)),
|
||||||
// vocab graph
|
// vocab graph
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
)
|
)
|
||||||
|
@ -515,23 +604,23 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*(batch*
|
4*(batch*
|
||||||
(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
|
(2*embedding+1+context*(1+maxHeads)+maxEmbeddingHeadsK*maxHeads)+
|
||||||
ropeFreqsCount+
|
ropeFreqsCount+
|
||||||
embeddingHeadsK*context*headsKV),
|
maxEmbeddingHeadsK*context*maxHeadsKV),
|
||||||
// vocab graph
|
// vocab graph
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
)
|
)
|
||||||
case "gemma", "gemma2", "gemma3":
|
case "gemma", "gemma2", "gemma3":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
4*batch*(2+context+context*maxHeads+2*embedding+2*maxEmbeddingHeadsK*maxHeads),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
||||||
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
|
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeads+context+context*maxHeads)+
|
||||||
4*embeddingHeadsK*context*8+
|
4*maxEmbeddingHeadsK*context*8+
|
||||||
embedding*embeddingHeadsK*heads*9/16,
|
embedding*embedding*maxEmbeddingHeadsK*maxHeads*9/16,
|
||||||
)
|
)
|
||||||
|
|
||||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||||
|
@ -543,42 +632,42 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||||
// layers are the smaller local (sliding) layers.
|
// layers are the smaller local (sliding) layers.
|
||||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "command-r":
|
case "command-r":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(2+4*embedding+context*(1+heads)),
|
4*batch*(2+4*embedding+context*(1+maxHeads)),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
|
4*batch*(1+2*embedding+context*(1+maxHeads))+4*embedding*context+embedding*embedding*9/16,
|
||||||
)
|
)
|
||||||
case "qwen2":
|
case "qwen2":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(1+2*embedding+context+context*heads),
|
4*batch*(1+2*embedding+context+context*maxHeads),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||||
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
|
4*(batch*(1+2*embedding+context*(1+maxHeads))+embedding*(1+context)),
|
||||||
)
|
)
|
||||||
case "phi2":
|
case "phi2":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(embedding+vocab),
|
4*batch*(embedding+vocab),
|
||||||
4*batch*(1+4*embedding+context+context*heads),
|
4*batch*(1+4*embedding+context+context*maxHeads),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
|
||||||
4*batch*(2+3*embedding+context+context*heads),
|
4*batch*(2+3*embedding+context+context*maxHeads),
|
||||||
)
|
)
|
||||||
case "stablelm":
|
case "stablelm":
|
||||||
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
|
fullOffload = 4 * batch * (context*(1+maxHeads) + 3*embedding + 2)
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(vocab+2*embedding),
|
4*batch*(vocab+2*embedding),
|
||||||
fullOffload,
|
fullOffload,
|
||||||
|
@ -586,12 +675,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
case "deepseek2":
|
case "deepseek2":
|
||||||
fullOffload = max(
|
fullOffload = max(
|
||||||
4*batch*(3*embedding+vocab),
|
4*batch*(3*embedding+vocab),
|
||||||
4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
|
4*batch*(3*embedding+2+context*(1+maxHeadsKV)+2*maxEmbeddingHeadsK*maxHeadsKV),
|
||||||
)
|
)
|
||||||
|
|
||||||
partialOffload = max(
|
partialOffload = max(
|
||||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeadsKV+context+context*maxHeadsKV)+4*maxEmbeddingHeadsK*context*maxHeadsKV+embedding*embedding*maxEmbeddingHeadsK*maxHeadsKV*9/16,
|
||||||
)
|
)
|
||||||
case "chatglm":
|
case "chatglm":
|
||||||
fullOffload = 4 * batch * (embedding + vocab)
|
fullOffload = 4 * batch * (embedding + vocab)
|
||||||
|
@ -602,8 +691,8 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
4*batch*(2+
|
4*batch*(2+
|
||||||
2*embedding+
|
2*embedding+
|
||||||
context+
|
context+
|
||||||
context*heads+
|
context*maxHeads+
|
||||||
embeddingHeadsK*heads+
|
maxEmbeddingHeadsK*maxHeads+
|
||||||
qkvBias.Shape[0]),
|
qkvBias.Shape[0]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -611,11 +700,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||||
partialOffload,
|
partialOffload,
|
||||||
4*batch*(1+
|
4*batch*(1+
|
||||||
2*embedding+
|
2*embedding+
|
||||||
embeddingHeadsK*heads+
|
maxEmbeddingHeadsK*maxHeads+
|
||||||
context+
|
context+
|
||||||
context*heads)+
|
context*maxHeads)+
|
||||||
4*embeddingHeadsK*context+
|
4*maxEmbeddingHeadsK*context+
|
||||||
4*context*embeddingHeadsK+
|
4*context*maxEmbeddingHeadsK+
|
||||||
4*qkvBias.Shape[0],
|
4*qkvBias.Shape[0],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -687,9 +776,15 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check head counts match and are non-zero
|
// Check head counts match and are non-zero
|
||||||
headCountK := f.KV().EmbeddingHeadCountK()
|
headCount := f.KV().HeadCounts()
|
||||||
headCountV := f.KV().EmbeddingHeadCountV()
|
embeddingHeadCountK := f.KV().EmbeddingHeadCountK()
|
||||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
embeddingHeadCountV := f.KV().EmbeddingHeadCountV()
|
||||||
|
for i := range headCount {
|
||||||
|
if embeddingHeadCountK[i] != embeddingHeadCountV[i] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||||
|
@ -703,3 +798,54 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
||||||
return 2 // f16 (default)
|
return 2 // f16 (default)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AsUint64Array(v []any) ([]uint64, bool) {
|
||||||
|
switch v[0].(type) {
|
||||||
|
case uint32:
|
||||||
|
values := make([]uint64, len(v))
|
||||||
|
for i, v := range v {
|
||||||
|
values[i] = uint64(v.(uint32))
|
||||||
|
}
|
||||||
|
return values, true
|
||||||
|
case uint64:
|
||||||
|
values := make([]uint64, len(v))
|
||||||
|
for i, v := range v {
|
||||||
|
values[i] = v.(uint64)
|
||||||
|
}
|
||||||
|
return values, true
|
||||||
|
case int32:
|
||||||
|
values := make([]uint64, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
val := val.(int32)
|
||||||
|
if val < 0 {
|
||||||
|
slog.Warn("negative value in int32 array", "value", val)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
values[i] = uint64(val)
|
||||||
|
}
|
||||||
|
return values, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func MaxValue(values []uint64) (uint64, bool) {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
max := values[0]
|
||||||
|
for _, v := range values {
|
||||||
|
if v > max {
|
||||||
|
max = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return max, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func FillArray[T any](value T, n int) []T {
|
||||||
|
values := make([]T, n)
|
||||||
|
for i := range values {
|
||||||
|
values[i] = value
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
|
@ -149,7 +149,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||||
}
|
}
|
||||||
|
|
||||||
if graphPartialOffload == 0 {
|
if graphPartialOffload == 0 {
|
||||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
graphPartialOffload = f.KV().GQAMax() * kvTotal / 6
|
||||||
}
|
}
|
||||||
if graphFullOffload == 0 {
|
if graphFullOffload == 0 {
|
||||||
graphFullOffload = graphPartialOffload
|
graphFullOffload = graphPartialOffload
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue