mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
feat: add threshold to dump options
ml.Dump will preserve default values if not specified
This commit is contained in:
parent
0d6e35d3c6
commit
2b2a0d2308
1 changed files with 40 additions and 16 deletions
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
@ -214,35 +215,58 @@ func mul[T number](s ...T) T {
|
|||
return p
|
||||
}
|
||||
|
||||
type DumpOptions struct {
|
||||
// Items is the number of elements to print at the beginning and end of each dimension.
|
||||
Items int
|
||||
type DumpOptions func(*dumpOptions)
|
||||
|
||||
// Precision is the number of decimal places to print. Applies to float32 and float64.
|
||||
Precision int
|
||||
// DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
|
||||
func DumpWithPrecision(n int) DumpOptions {
|
||||
return func(opts *dumpOptions) {
|
||||
opts.Precision = n
|
||||
}
|
||||
}
|
||||
|
||||
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
if len(opts) < 1 {
|
||||
opts = append(opts, DumpOptions{
|
||||
Items: 3,
|
||||
Precision: 4,
|
||||
})
|
||||
// DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
|
||||
// is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
|
||||
// beginning and end of each dimension will be printed.
|
||||
func DumpWithThreshold(n int) DumpOptions {
|
||||
return func(opts *dumpOptions) {
|
||||
opts.Threshold = n
|
||||
}
|
||||
}
|
||||
|
||||
// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
|
||||
func DumpWithEdgeItems(n int) DumpOptions {
|
||||
return func(opts *dumpOptions) {
|
||||
opts.EdgeItems = n
|
||||
}
|
||||
}
|
||||
|
||||
type dumpOptions struct {
|
||||
Precision, Threshold, EdgeItems int
|
||||
}
|
||||
|
||||
func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
|
||||
opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
|
||||
for _, optsFunc := range optsFuncs {
|
||||
optsFunc(&opts)
|
||||
}
|
||||
|
||||
if mul(t.Shape()...) <= opts.Threshold {
|
||||
opts.EdgeItems = math.MaxInt
|
||||
}
|
||||
|
||||
switch t.DType() {
|
||||
case DTypeF32:
|
||||
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
})
|
||||
case DTypeF16, DTypeQ80, DTypeQ40:
|
||||
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
|
||||
f32 = t.Copy(ctx, f32)
|
||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
|
||||
})
|
||||
case DTypeI32:
|
||||
return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
|
||||
return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
|
||||
return strconv.FormatInt(int64(i), 10)
|
||||
})
|
||||
default:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue