fix: pad tensor item if ge zero

this produces a nicer output since both positive and negative values
produces the same width
This commit is contained in:
Michael Yang 2025-03-07 18:04:16 -08:00
parent 7e34f4fbfa
commit 9926eae015

View file

@ -5,6 +5,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"os" "os"
"slices"
"strconv" "strconv"
"strings" "strings"
) )
@ -241,16 +242,17 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
} }
shape := t.Shape() shape := t.Shape()
slices.Reverse(shape)
var sb strings.Builder var sb strings.Builder
var f func([]int, int) var f func([]int, int)
f = func(dims []int, stride int) { f = func(dims []int, stride int) {
prefix := strings.Repeat(" ", len(shape)-len(dims)+1) prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
fmt.Fprint(&sb, "[") sb.WriteString("[")
defer func() { fmt.Fprint(&sb, "]") }() defer func() { sb.WriteString("]") }()
for i := 0; i < dims[0]; i++ { for i := 0; i < dims[0]; i++ {
if i >= items && i < dims[0]-items { if i >= items && i < dims[0]-items {
fmt.Fprint(&sb, "..., ") sb.WriteString("..., ")
// skip to next printable element // skip to next printable element
skip := dims[0] - 2*items skip := dims[0] - 2*items
if len(dims) > 1 { if len(dims) > 1 {
@ -265,9 +267,14 @@ func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string)
fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix) fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
} }
} else { } else {
fmt.Fprint(&sb, fn(s[stride+i])) text := fn(s[stride+i])
if len(text) > 0 && text[0] != '-' {
sb.WriteString(" ")
}
sb.WriteString(text)
if i < dims[0]-1 { if i < dims[0]-1 {
fmt.Fprint(&sb, ", ") sb.WriteString(", ")
} }
} }
} }