feat: add threshold to dump options

ml.Dump will preserve default values if not specified
This commit is contained in:
Michael Yang 2025-05-09 15:37:03 -07:00
parent 0d6e35d3c6
commit 2b2a0d2308

View file

@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"fmt"
"math"
"os"
"slices"
"strconv"
@ -214,35 +215,58 @@ func mul[T number](s ...T) T {
return p
}
type DumpOptions struct {
// Items is the number of elements to print at the beginning and end of each dimension.
Items int
type DumpOptions func(*dumpOptions)
// Precision is the number of decimal places to print. Applies to float32 and float64.
Precision int
// DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
func DumpWithPrecision(n int) DumpOptions {
return func(opts *dumpOptions) {
opts.Precision = n
}
}
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
if len(opts) < 1 {
opts = append(opts, DumpOptions{
Items: 3,
Precision: 4,
})
// DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// beginning and end of each dimension will be printed.
func DumpWithThreshold(n int) DumpOptions {
return func(opts *dumpOptions) {
opts.Threshold = n
}
}
// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
func DumpWithEdgeItems(n int) DumpOptions {
return func(opts *dumpOptions) {
opts.EdgeItems = n
}
}
type dumpOptions struct {
Precision, Threshold, EdgeItems int
}
func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
for _, optsFunc := range optsFuncs {
optsFunc(&opts)
}
if mul(t.Shape()...) <= opts.Threshold {
opts.EdgeItems = math.MaxInt
}
switch t.DType() {
case DTypeF32:
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
})
case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
})
case DTypeI32:
return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string {
return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
return strconv.FormatInt(int64(i), 10)
})
default: