feat: add threshold to dump options

ml.Dump will preserve default values if not specified
This commit is contained in:
Michael Yang 2025-05-09 15:37:03 -07:00
parent 0d6e35d3c6
commit 2b2a0d2308

View file

@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math"
"os" "os"
"slices" "slices"
"strconv" "strconv"
@ -214,35 +215,58 @@ func mul[T number](s ...T) T {
return p return p
} }
type DumpOptions struct { type DumpOptions func(*dumpOptions)
// Items is the number of elements to print at the beginning and end of each dimension.
Items int
// Precision is the number of decimal places to print. Applies to float32 and float64. // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
Precision int func DumpWithPrecision(n int) DumpOptions {
return func(opts *dumpOptions) {
opts.Precision = n
}
} }
func Dump(ctx Context, t Tensor, opts ...DumpOptions) string { // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
if len(opts) < 1 { // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
opts = append(opts, DumpOptions{ // beginning and end of each dimension will be printed.
Items: 3, func DumpWithThreshold(n int) DumpOptions {
Precision: 4, return func(opts *dumpOptions) {
}) opts.Threshold = n
}
}
// DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
func DumpWithEdgeItems(n int) DumpOptions {
return func(opts *dumpOptions) {
opts.EdgeItems = n
}
}
type dumpOptions struct {
Precision, Threshold, EdgeItems int
}
func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
for _, optsFunc := range optsFuncs {
optsFunc(&opts)
}
if mul(t.Shape()...) <= opts.Threshold {
opts.EdgeItems = math.MaxInt
} }
switch t.DType() { switch t.DType() {
case DTypeF32: case DTypeF32:
return dump[[]float32](ctx, t, opts[0].Items, func(f float32) string { return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
}) })
case DTypeF16, DTypeQ80, DTypeQ40: case DTypeF16, DTypeQ80, DTypeQ40:
f32 := ctx.Input().Empty(DTypeF32, t.Shape()...) f32 := ctx.Input().Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32) f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string { return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32) return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
}) })
case DTypeI32: case DTypeI32:
return dump[[]int32](ctx, t, opts[0].Items, func(i int32) string { return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
return strconv.FormatInt(int64(i), 10) return strconv.FormatInt(int64(i), 10)
}) })
default: default: