diff --git a/ml/backend.go b/ml/backend.go index 3abacbf19..641175f0f 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "fmt" "os" + "slices" "strconv" "strings" ) @@ -241,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) } shape := t.Shape() + slices.Reverse(shape) var sb strings.Builder var f func([]int, int) f = func(dims []int, stride int) { prefix := strings.Repeat(" ", len(shape)-len(dims)+1) - fmt.Fprint(&sb, "[") - defer func() { fmt.Fprint(&sb, "]") }() + sb.WriteString("[") + defer func() { sb.WriteString("]") }() for i := 0; i < dims[0]; i++ { if i >= items && i < dims[0]-items { - fmt.Fprint(&sb, "..., ") + sb.WriteString("..., ") // skip to next printable element skip := dims[0] - 2*items if len(dims) > 1 { @@ -265,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) } } else { - fmt.Fprint(&sb, fn(s[stride+i])) + text := fn(s[stride+i]) + if len(text) > 0 && text[0] != '-' { + sb.WriteString(" ") + } + + sb.WriteString(text) if i < dims[0]-1 { - fmt.Fprint(&sb, ", ") + sb.WriteString(", ") } } }