ml: Abstract attention out of model definitions

There are two benefits to doing this:
 - Provides a library function that models can use, reducing the code
   needed for each model implementation
 - Provides a single place to drop in optimized implementations of
   attention based on the backend or other factors. One such
   implementation is provided for GGML.

On CUDA this improves the token generation rate by about 3%. It has no
significant effect on Metal.
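
To make the first benefit concrete, here is a minimal sketch of what such a
shared helper could look like. It is an illustration, not necessarily the
code this commit adds: the function name, package, and import path are
assumptions, while the fallback path mirrors the doc comment on the
ScaledDotProductAttention interface in the diff below. The helper dispatches
to a backend's fused implementation when the tensor type provides one, and
otherwise builds the attention graph op by op.

package nn

import "github.com/ollama/ollama/ml"

// Attention is a hypothetical shared helper: models call this one
// function instead of writing out the attention steps themselves.
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
	// Prefer a fused kernel when the backend's tensor type offers one
	// (e.g. the GGML implementation mentioned above).
	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
	}

	// Unfused fallback, following the equivalent code documented on
	// the interface.
	kq := key.MulmatFullPrec(ctx, query)
	kq = kq.Scale(ctx, scale)
	if mask != nil {
		kq = kq.Add(ctx, mask)
	}
	kq = kq.Softmax(ctx)
	kqv := value.Mulmat(ctx, kq)
	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}

A model's forward pass then reduces to a single Attention call, and
backend-specific optimizations live behind the interface rather than in
every model definition.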

Co-authored-by: Daniel Hiltgen <daniel@ollama.com>
Jesse Gross 2025-02-14 20:51:44 -08:00 committed by Jesse Gross
parent 2192a28eed
commit f53f4198c3
5 changed files with 102 additions and 22 deletions

@@ -111,6 +111,26 @@ type Tensor interface {
	Copy(ctx Context, t2 Tensor) Tensor
}

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
// if mask != nil {
// 	kq = kq.Add(ctx, mask)
// }
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
type ScaledDotProductAttention interface {
	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
}

type number interface {
	~int | ~int8 | ~int16 | ~int32 | ~int64 |
		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |