package llama4 import ( "cmp" "image" "image/color" "reflect" "slices" "testing" gocmp "github.com/google/go-cmp/cmp" ) func TestFactors(t *testing.T) { tests := []struct { name string input int expected []int }{ { name: "factors of 1", input: 1, expected: []int{1}, }, { name: "factors of 2", input: 2, expected: []int{1, 2}, }, { name: "factors of 6", input: 6, expected: []int{1, 2, 3, 6}, }, { name: "factors of 28", input: 28, expected: []int{1, 2, 4, 7, 14, 28}, }, { name: "factors of 49", input: 49, expected: []int{1, 7, 49}, }, { name: "factors of 97 (prime)", input: 97, expected: []int{1, 97}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := factors(tt.input) if !reflect.DeepEqual(actual, tt.expected) { t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected) } }) } } func TestSupportedResolutions(t *testing.T) { expectedResolutions := []image.Point{ {X: 3360, Y: 336}, {X: 672, Y: 2688}, {X: 336, Y: 1344}, {X: 336, Y: 4032}, {X: 1008, Y: 1344}, {X: 1344, Y: 1008}, {X: 336, Y: 1680}, {X: 1680, Y: 336}, {X: 336, Y: 5040}, {X: 4032, Y: 336}, {X: 2352, Y: 336}, {X: 2688, Y: 672}, {X: 1344, Y: 336}, {X: 5376, Y: 336}, {X: 2352, Y: 672}, {X: 672, Y: 1008}, {X: 1008, Y: 672}, {X: 336, Y: 5376}, {X: 1680, Y: 1008}, {X: 5040, Y: 336}, {X: 336, Y: 3024}, {X: 3024, Y: 336}, {X: 336, Y: 2688}, {X: 672, Y: 1344}, {X: 336, Y: 672}, {X: 336, Y: 2352}, {X: 2016, Y: 672}, {X: 1008, Y: 336}, {X: 336, Y: 3360}, {X: 336, Y: 4368}, {X: 1008, Y: 1680}, {X: 336, Y: 4704}, {X: 4704, Y: 336}, {X: 1344, Y: 672}, {X: 672, Y: 336}, {X: 2688, Y: 336}, {X: 3696, Y: 336}, {X: 2016, Y: 336}, {X: 1344, Y: 1344}, {X: 1008, Y: 1008}, {X: 672, Y: 672}, {X: 336, Y: 336}, {X: 4368, Y: 336}, {X: 672, Y: 2016}, {X: 336, Y: 1008}, {X: 336, Y: 3696}, {X: 672, Y: 1680}, {X: 1680, Y: 672}, {X: 336, Y: 2016}, {X: 672, Y: 2352}, } sortResolutionFunc := func(a, b image.Point) int { return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y)) } slices.SortStableFunc(expectedResolutions, sortResolutionFunc) imgProc := ImageProcessor{ imageSize: 336, patchSize: 16, numChannels: 3, maxUpscalingSize: 448, } actualResolutions := imgProc.supportedResolutions() slices.SortStableFunc(actualResolutions, sortResolutionFunc) if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" { t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff) } } func TestBestResolution(t *testing.T) { tests := []struct { name string size image.Point resolutions []image.Point max bool expected image.Point }{ { "normal", image.Point{800, 600}, []image.Point{ {300, 200}, {640, 480}, {800, 600}, {1024, 768}, {1600, 1200}, }, false, image.Point{800, 600}, }, { "max", image.Point{800, 600}, []image.Point{ {300, 200}, {640, 480}, {800, 600}, {1024, 768}, {1600, 1200}, }, true, image.Point{1600, 1200}, }, { "mid", image.Point{1000, 700}, []image.Point{ {300, 200}, {640, 480}, {800, 600}, {1024, 768}, {1600, 1200}, }, false, image.Point{1024, 768}, }, { "smol", image.Point{100, 100}, []image.Point{ {300, 200}, {640, 480}, {800, 600}, {1024, 768}, {1600, 1200}, }, false, image.Point{300, 200}, }, { "huge", image.Point{10000, 10000}, []image.Point{ {300, 200}, {640, 480}, {800, 600}, {1024, 768}, {1600, 1200}, }, false, image.Point{1600, 1200}, }, } p := ImageProcessor{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := p.bestResolution(tt.size, tt.resolutions, tt.max) if diff := gocmp.Diff(tt.expected, actual); diff != "" { t.Errorf("best resolution mismatch (-want +got):\n%s", diff) } }) } } func TestMaxResolution(t *testing.T) { tests := []struct { name string origRes image.Point targetRes image.Point expected image.Point }{ { "normal", image.Point{800, 600}, image.Point{800, 600}, image.Point{800, 600}, }, { "skew", image.Point{800, 600}, image.Point{1100, 700}, image.Point{933, 700}, }, } p := ImageProcessor{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := p.maxResolution(tt.origRes, tt.targetRes) if !reflect.DeepEqual(actual, tt.expected) { t.Errorf("max resolution; got %v want %v", actual, tt.expected) } }) } } func TestProcessImage(t *testing.T) { imgProc := ImageProcessor{ imageSize: 336, patchSize: 16, numChannels: 3, maxUpscalingSize: 448, } generateImage := func(seed int) image.Image { width, height := 20, 10 img := image.NewRGBA(image.Rect(0, 0, width, height)) for x := range width { // Use the seed to vary color generation r := uint8((seed + x*11) % 256) g := uint8((seed + x*17) % 256) b := uint8((seed + x*23) % 256) c := color.RGBA{R: r, G: g, B: b, A: 255} for y := range height { img.Set(x, y, c) } } return img } pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12)) if err != nil { t.Error(err) } if n := len(pixelsLocal); n != 336*336*3 { t.Errorf("unexpected size of f32s: %d", n) } if n := len(pixelsGlobal); n > 0 { t.Errorf("unexpected size of f32s: %d", n) } if !targetSize.Eq(image.Point{336, 336}) { t.Errorf("unexpected target size: %v", targetSize) } }