mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 02:16:36 +02:00
feat: add threshold to dump options
ml.Dump will preserve default values if not specified
This commit is contained in:
parent
0d6e35d3c6
commit
2b2a0d2308
1 changed files with 40 additions and 16 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
@ -214,35 +215,58 @@ func mul[T number](s ...T) T {
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
type DumpOptions struct {
|
type DumpOptions func(*dumpOptions)
|
||||||
// Items is the number of elements to print at the beginning and end of each dimension.
|
|
||||||
Items int
|
|
||||||
|
|
||||||
// Precision is the number of decimal places to print. Applies to float32 and float64.
|
// DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||||
Precision int
|
func DumpWithPrecision(n int) DumpOptions {
|
||||||
|
return func(opts *dumpOptions) {
|
||||||
|
opts.Precision = n
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
// DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||||
if len(opts) < 1 {
|
// is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||||
opts = append(opts, DumpOptions{
|
// beginning and end of each dimension will be printed.
|
||||||
Items: 3,
|
func DumpWithThreshold(n int) DumpOptions {
|
||||||
Precision: 4,
|
return func(opts *dumpOptions) {
|
||||||
})
|
opts.Threshold = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||||
|
func DumpWithEdgeItems(n int) DumpOptions {
|
||||||
|
return func(opts *dumpOptions) {
|
||||||
|
opts.EdgeItems = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type dumpOptions struct {
|
||||||
|
Precision, Threshold, EdgeItems int
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||||
|
opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||||
|
for _, optsFunc := range optsFuncs {
|
||||||
|
optsFunc(&opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
if mul(t.Shape()...) <= opts.Threshold {
|
||||||
|
opts.EdgeItems = math.MaxInt
|
||||||
}
|
}
|
||||||
|
|
||||||
switch t.DType() {
|
switch t.DType() {
|
||||||
case DTypeF32:
|
case DTypeF32:
|
||||||
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
|
return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||||
})
|
})
|
||||||
case DTypeF16, DTypeQ80, DTypeQ40:
|
case DTypeF16, DTypeQ80, DTypeQ40:
|
||||||
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
|
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
|
||||||
f32 = t.Copy(ctx, f32)
|
f32 = t.Copy(ctx, f32)
|
||||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||||
})
|
})
|
||||||
case DTypeI32:
|
case DTypeI32:
|
||||||
return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
|
return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||||
return strconv.FormatInt(int64(i), 10)
|
return strconv.FormatInt(int64(i), 10)
|
||||||
})
|
})
|
||||||
default:
|
default:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue