mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
ml: Abstract attention out of model definitions
There are two benefits to doing this: - Provide a library function that models can use, reducing code for each model implementation - Enables a single place to drop in optimized implementations of attention based on the backend or other factors. One is provided for GGML. On CUDA this improves token generation rate by about 3%. It does not have a significant effect on Metal. Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
This commit is contained in:
parent
2192a28eed
commit
f53f4198c3
5 changed files with 102 additions and 22 deletions
|
@ -111,6 +111,26 @@ type Tensor interface {
|
|||
Copy(ctx Context, t2 Tensor) Tensor
|
||||
}
|
||||
|
||||
// ScaledDotProductAttention implements a fused attention
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
//
|
||||
// if mask != nil {
|
||||
// kq = kq.Add(ctx, mask)
|
||||
// }
|
||||
//
|
||||
// kq = kq.Softmax(ctx)
|
||||
//
|
||||
// kqv := value.Mulmat(ctx, kq)
|
||||
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
type ScaledDotProductAttention interface {
|
||||
ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
|
||||
}
|
||||
|
||||
type number interface {
|
||||
~int | ~int8 | ~int16 | ~int32 | ~int64 |
|
||||
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue