diff --git a/ml/backend.go b/ml/backend.go index 0cd33bd8a..ba24ecb45 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "fmt" + "math" "os" "slices" "strconv" @@ -214,35 +215,58 @@ func mul[T number](s ...T) T { return p } -type DumpOptions struct { - // Items is the number of elements to print at the beginning and end of each dimension. - Items int +type DumpOptions func(*dumpOptions) - // Precision is the number of decimal places to print. Applies to float32 and float64. - Precision int +// DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64. +func DumpWithPrecision(n int) DumpOptions { + return func(opts *dumpOptions) { + opts.Precision = n + } } -func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { - if len(opts) < 1 { - opts = append(opts, DumpOptions{ - Items: 3, - Precision: 4, - }) +// DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements +// is less than or equal to this value, the entire tensor will be printed. Otherwise, only the +// beginning and end of each dimension will be printed. +func DumpWithThreshold(n int) DumpOptions { + return func(opts *dumpOptions) { + opts.Threshold = n + } +} + +// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension. +func DumpWithEdgeItems(n int) DumpOptions { + return func(opts *dumpOptions) { + opts.EdgeItems = n + } +} + +type dumpOptions struct { + Precision, Threshold, EdgeItems int +} + +func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string { + opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3} + for _, optsFunc := range optsFuncs { + optsFunc(&opts) + } + + if mul(t.Shape()...) <= opts.Threshold { + opts.EdgeItems = math.MaxInt } switch t.DType() { case DTypeF32: - return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { - return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) + return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string { + return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) }) case DTypeF16, DTypeQ80, DTypeQ40: f32 := ctx.Input().Empty(DTypeF32, t.Shape()...) f32 = t.Copy(ctx, f32) - return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { - return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) + return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string { + return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32) }) case DTypeI32: - return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string { + return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string { return strconv.FormatInt(int64(i), 10) }) default: