mirror of
https://github.com/ollama/ollama.git
synced 2025-05-11 18:36:41 +02:00
image processing
Co-authored-by: Patrick Devine <patrick@infrahq.com>
This commit is contained in:
parent
f0c66e6dea
commit
178761aef3
3 changed files with 494 additions and 4 deletions
|
@ -15,6 +15,7 @@ import (
|
|||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*Projector `gguf:"mm"`
|
||||
|
@ -43,8 +44,9 @@ func New(c fs.Config) (model.Model, error) {
|
|||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
|
@ -66,21 +68,42 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
f32s, aspectRatio, err := m.ProcessImage(ctx, img)
|
||||
pixelsLocal, pixelsGlobal, size, err := m.ProcessImage(img)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
|
||||
tilesLocal, err := ctx.Input().FromFloatSlice(pixelsLocal, size.X, size.Y, m.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ratioW, ratioH := int(size.X/m.imageSize), int(size.Y/m.imageSize)
|
||||
|
||||
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, ratioW, size.Y, m.numChannels).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW*size.Y/ratioH, ratioH, ratioW, m.numChannels).Permute(ctx, 0, 3, 2, 1).Contiguous(ctx)
|
||||
tilesLocal = tilesLocal.Reshape(ctx, size.X/ratioW, size.Y/ratioH, m.numChannels, ratioH*ratioW)
|
||||
|
||||
pixelValues := tilesLocal
|
||||
|
||||
if len(pixelsGlobal) > 0 {
|
||||
tilesGlobal, err := ctx.Input().FromFloatSlice(pixelsGlobal, m.imageSize, m.imageSize, m.numChannels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pixelValues = pixelValues.Concat(ctx, tilesGlobal, 3)
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
|
||||
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
|
||||
return m.Projector.Forward(ctx, visionOutputs), nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
|
|
167
model/models/llama4/process_image.go
Normal file
167
model/models/llama4/process_image.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package llama4
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"image"
|
||||
"math"
|
||||
"slices"
|
||||
"sort"
|
||||
|
||||
"golang.org/x/image/draw"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
type ImageProcessor struct {
|
||||
imageSize, patchSize, numChannels, maxUpscalingSize int
|
||||
}
|
||||
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
return ImageProcessor{
|
||||
imageSize: int(c.Uint("vision.image_size")),
|
||||
patchSize: int(c.Uint("vision.patch_size")),
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)),
|
||||
maxUpscalingSize: int(c.Uint("vision.max_upscaling_size", 448)),
|
||||
}
|
||||
}
|
||||
|
||||
func factors(n int) []int {
|
||||
var result []int
|
||||
seen := make(map[int]bool)
|
||||
|
||||
for i := 1; i <= n/2; i++ {
|
||||
if n%i == 0 && !seen[i] {
|
||||
result = append(result, i)
|
||||
seen[i] = true
|
||||
}
|
||||
}
|
||||
|
||||
result = append(result, n)
|
||||
sort.Ints(result)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (p ImageProcessor) supportedResolutions() []image.Point {
|
||||
var resolutions []image.Point
|
||||
|
||||
aspectMap := make(map[float64][]image.Point)
|
||||
for i := p.patchSize; i >= 1; i-- {
|
||||
for _, f := range factors(i) {
|
||||
x := f
|
||||
y := i / f
|
||||
k := float64(y) / float64(x)
|
||||
aspectMap[k] = append(aspectMap[k], image.Point{x, y})
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range aspectMap {
|
||||
for _, i := range v {
|
||||
resolutions = append(resolutions, image.Point{i.X * p.imageSize, i.Y * p.imageSize})
|
||||
}
|
||||
}
|
||||
|
||||
return resolutions
|
||||
}
|
||||
|
||||
func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []image.Point, resizeToMaxCanvas bool) image.Point {
|
||||
w, h := img.X, img.Y
|
||||
|
||||
scales := make([]float64, len(possibleResolutions))
|
||||
|
||||
for i, res := range possibleResolutions {
|
||||
scaleW := float64(res.X) / float64(w)
|
||||
scaleH := float64(res.Y) / float64(h)
|
||||
scale := math.Min(scaleW, scaleH)
|
||||
|
||||
scales[i] = scale
|
||||
}
|
||||
|
||||
minAboveOne := func(scales []float64) (float64, bool) {
|
||||
min := math.MaxFloat64
|
||||
found := false
|
||||
|
||||
for _, s := range scales {
|
||||
if s >= 1.0 && s < min {
|
||||
min = s
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
return min, found
|
||||
}
|
||||
|
||||
bestScale, ok := minAboveOne(scales)
|
||||
if resizeToMaxCanvas || !ok {
|
||||
bestScale = slices.Max(scales)
|
||||
}
|
||||
|
||||
var bestOptions []image.Point
|
||||
for i, scale := range scales {
|
||||
if math.Abs(scale-bestScale) < 1e-6 {
|
||||
bestOptions = append(bestOptions, possibleResolutions[i])
|
||||
}
|
||||
}
|
||||
|
||||
var chosenResolution image.Point
|
||||
if len(bestOptions) > 1 {
|
||||
chosenResolution = slices.MinFunc(bestOptions, func(a, b image.Point) int {
|
||||
return cmp.Compare(a.X*a.Y, b.X*b.Y)
|
||||
})
|
||||
} else {
|
||||
chosenResolution = bestOptions[0]
|
||||
}
|
||||
|
||||
return chosenResolution
|
||||
}
|
||||
|
||||
func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Point {
|
||||
scaleW := float64(targetRes.X) / float64(imageRes.X)
|
||||
scaleH := float64(targetRes.Y) / float64(imageRes.Y)
|
||||
|
||||
var newRes image.Point
|
||||
if scaleW < scaleH {
|
||||
newRes = image.Point{
|
||||
targetRes.X,
|
||||
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
}
|
||||
} else {
|
||||
newRes = image.Point{
|
||||
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
targetRes.Y,
|
||||
}
|
||||
}
|
||||
|
||||
return newRes
|
||||
}
|
||||
|
||||
func (p ImageProcessor) pad(src image.Image, outputSize image.Point) image.Image {
|
||||
dst := image.NewRGBA(image.Rect(0, 0, outputSize.X, outputSize.Y))
|
||||
draw.Draw(dst, src.Bounds(), src, image.Point{}, draw.Over)
|
||||
return dst
|
||||
}
|
||||
|
||||
func (p ImageProcessor) ProcessImage(img image.Image) (pixelsLocal, pixelsGlobal []float32, targetSize image.Point, _ error) {
|
||||
img = imageproc.Composite(img)
|
||||
|
||||
targetSize = p.bestResolution(img.Bounds().Max, p.supportedResolutions(), false)
|
||||
targetSizeWithoutDistortion := targetSize
|
||||
if p.maxUpscalingSize > 0 {
|
||||
targetSizeWithoutDistortion = p.maxResolution(img.Bounds().Max, targetSize)
|
||||
targetSizeWithoutDistortion.X = min(max(img.Bounds().Max.X, p.maxUpscalingSize), targetSize.X)
|
||||
targetSizeWithoutDistortion.Y = min(max(img.Bounds().Max.Y, p.maxUpscalingSize), targetSize.Y)
|
||||
}
|
||||
|
||||
newSizeWithoutDistortion := p.maxResolution(img.Bounds().Max, targetSizeWithoutDistortion)
|
||||
|
||||
padded := p.pad(imageproc.Resize(img, newSizeWithoutDistortion, imageproc.ResizeBilinear), targetSize)
|
||||
pixelsLocal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
|
||||
|
||||
if targetSize.X/p.imageSize*targetSize.Y/p.imageSize > 1 {
|
||||
padded := imageproc.Resize(img, image.Point{p.imageSize, p.imageSize}, imageproc.ResizeBilinear)
|
||||
pixelsGlobal = imageproc.Normalize(padded, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)
|
||||
}
|
||||
|
||||
return pixelsLocal, pixelsGlobal, targetSize, nil
|
||||
}
|
300
model/models/llama4/process_image_test.go
Normal file
300
model/models/llama4/process_image_test.go
Normal file
|
@ -0,0 +1,300 @@
|
|||
package llama4
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"image"
|
||||
"image/color"
|
||||
"reflect"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
gocmp "github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestFactors(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input int
|
||||
expected []int
|
||||
}{
|
||||
{
|
||||
name: "factors of 1",
|
||||
input: 1,
|
||||
expected: []int{1},
|
||||
},
|
||||
{
|
||||
name: "factors of 2",
|
||||
input: 2,
|
||||
expected: []int{1, 2},
|
||||
},
|
||||
{
|
||||
name: "factors of 6",
|
||||
input: 6,
|
||||
expected: []int{1, 2, 3, 6},
|
||||
},
|
||||
{
|
||||
name: "factors of 28",
|
||||
input: 28,
|
||||
expected: []int{1, 2, 4, 7, 14, 28},
|
||||
},
|
||||
{
|
||||
name: "factors of 49",
|
||||
input: 49,
|
||||
expected: []int{1, 7, 49},
|
||||
},
|
||||
{
|
||||
name: "factors of 97 (prime)",
|
||||
input: 97,
|
||||
expected: []int{1, 97},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := factors(tt.input)
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("factors(%d) = %v; want %v", tt.input, actual, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedResolutions(t *testing.T) {
|
||||
expectedResolutions := []image.Point{
|
||||
{X: 3360, Y: 336},
|
||||
{X: 672, Y: 2688},
|
||||
{X: 336, Y: 1344},
|
||||
{X: 336, Y: 4032},
|
||||
{X: 1008, Y: 1344},
|
||||
{X: 1344, Y: 1008},
|
||||
{X: 336, Y: 1680},
|
||||
{X: 1680, Y: 336},
|
||||
{X: 336, Y: 5040},
|
||||
{X: 4032, Y: 336},
|
||||
{X: 2352, Y: 336},
|
||||
{X: 2688, Y: 672},
|
||||
{X: 1344, Y: 336},
|
||||
{X: 5376, Y: 336},
|
||||
{X: 2352, Y: 672},
|
||||
{X: 672, Y: 1008},
|
||||
{X: 1008, Y: 672},
|
||||
{X: 336, Y: 5376},
|
||||
{X: 1680, Y: 1008},
|
||||
{X: 5040, Y: 336},
|
||||
{X: 336, Y: 3024},
|
||||
{X: 3024, Y: 336},
|
||||
{X: 336, Y: 2688},
|
||||
{X: 672, Y: 1344},
|
||||
{X: 336, Y: 672},
|
||||
{X: 336, Y: 2352},
|
||||
{X: 2016, Y: 672},
|
||||
{X: 1008, Y: 336},
|
||||
{X: 336, Y: 3360},
|
||||
{X: 336, Y: 4368},
|
||||
{X: 1008, Y: 1680},
|
||||
{X: 336, Y: 4704},
|
||||
{X: 4704, Y: 336},
|
||||
{X: 1344, Y: 672},
|
||||
{X: 672, Y: 336},
|
||||
{X: 2688, Y: 336},
|
||||
{X: 3696, Y: 336},
|
||||
{X: 2016, Y: 336},
|
||||
{X: 1344, Y: 1344},
|
||||
{X: 1008, Y: 1008},
|
||||
{X: 672, Y: 672},
|
||||
{X: 336, Y: 336},
|
||||
{X: 4368, Y: 336},
|
||||
{X: 672, Y: 2016},
|
||||
{X: 336, Y: 1008},
|
||||
{X: 336, Y: 3696},
|
||||
{X: 672, Y: 1680},
|
||||
{X: 1680, Y: 672},
|
||||
{X: 336, Y: 2016},
|
||||
{X: 672, Y: 2352},
|
||||
}
|
||||
|
||||
sortResolutionFunc := func(a, b image.Point) int {
|
||||
return cmp.Or(cmp.Compare(a.X, b.X), cmp.Compare(a.Y, b.Y))
|
||||
}
|
||||
|
||||
slices.SortStableFunc(expectedResolutions, sortResolutionFunc)
|
||||
|
||||
imgProc := ImageProcessor{
|
||||
imageSize: 336,
|
||||
patchSize: 16,
|
||||
numChannels: 3,
|
||||
maxUpscalingSize: 448,
|
||||
}
|
||||
|
||||
actualResolutions := imgProc.supportedResolutions()
|
||||
slices.SortStableFunc(actualResolutions, sortResolutionFunc)
|
||||
|
||||
if diff := gocmp.Diff(expectedResolutions, actualResolutions); diff != "" {
|
||||
t.Errorf("supportedResolutions() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBestResolution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
size image.Point
|
||||
resolutions []image.Point
|
||||
max bool
|
||||
expected image.Point
|
||||
}{
|
||||
{
|
||||
"normal",
|
||||
image.Point{800, 600},
|
||||
[]image.Point{
|
||||
{300, 200},
|
||||
{640, 480},
|
||||
{800, 600},
|
||||
{1024, 768},
|
||||
{1600, 1200},
|
||||
},
|
||||
false,
|
||||
image.Point{800, 600},
|
||||
},
|
||||
{
|
||||
"max",
|
||||
image.Point{800, 600},
|
||||
[]image.Point{
|
||||
{300, 200},
|
||||
{640, 480},
|
||||
{800, 600},
|
||||
{1024, 768},
|
||||
{1600, 1200},
|
||||
},
|
||||
true,
|
||||
image.Point{1600, 1200},
|
||||
},
|
||||
{
|
||||
"mid",
|
||||
image.Point{1000, 700},
|
||||
[]image.Point{
|
||||
{300, 200},
|
||||
{640, 480},
|
||||
{800, 600},
|
||||
{1024, 768},
|
||||
{1600, 1200},
|
||||
},
|
||||
false,
|
||||
image.Point{1024, 768},
|
||||
},
|
||||
{
|
||||
"smol",
|
||||
image.Point{100, 100},
|
||||
[]image.Point{
|
||||
{300, 200},
|
||||
{640, 480},
|
||||
{800, 600},
|
||||
{1024, 768},
|
||||
{1600, 1200},
|
||||
},
|
||||
false,
|
||||
image.Point{300, 200},
|
||||
},
|
||||
{
|
||||
"huge",
|
||||
image.Point{10000, 10000},
|
||||
[]image.Point{
|
||||
{300, 200},
|
||||
{640, 480},
|
||||
{800, 600},
|
||||
{1024, 768},
|
||||
{1600, 1200},
|
||||
},
|
||||
false,
|
||||
image.Point{1600, 1200},
|
||||
},
|
||||
}
|
||||
|
||||
p := ImageProcessor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := p.bestResolution(tt.size, tt.resolutions, tt.max)
|
||||
if diff := gocmp.Diff(tt.expected, actual); diff != "" {
|
||||
t.Errorf("best resolution mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMaxResolution(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origRes image.Point
|
||||
targetRes image.Point
|
||||
expected image.Point
|
||||
}{
|
||||
{
|
||||
"normal",
|
||||
image.Point{800, 600},
|
||||
image.Point{800, 600},
|
||||
image.Point{800, 600},
|
||||
},
|
||||
{
|
||||
"skew",
|
||||
image.Point{800, 600},
|
||||
image.Point{1100, 700},
|
||||
image.Point{933, 700},
|
||||
},
|
||||
}
|
||||
|
||||
p := ImageProcessor{}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := p.maxResolution(tt.origRes, tt.targetRes)
|
||||
if !reflect.DeepEqual(actual, tt.expected) {
|
||||
t.Errorf("max resolution; got %v want %v", actual, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessImage(t *testing.T) {
|
||||
imgProc := ImageProcessor{
|
||||
imageSize: 336,
|
||||
patchSize: 16,
|
||||
numChannels: 3,
|
||||
maxUpscalingSize: 448,
|
||||
}
|
||||
|
||||
generateImage := func(seed int) image.Image {
|
||||
width, height := 20, 10
|
||||
img := image.NewRGBA(image.Rect(0, 0, width, height))
|
||||
|
||||
for x := range width {
|
||||
// Use the seed to vary color generation
|
||||
r := uint8((seed + x*11) % 256)
|
||||
g := uint8((seed + x*17) % 256)
|
||||
b := uint8((seed + x*23) % 256)
|
||||
|
||||
c := color.RGBA{R: r, G: g, B: b, A: 255}
|
||||
for y := range height {
|
||||
img.Set(x, y, c)
|
||||
}
|
||||
}
|
||||
|
||||
return img
|
||||
}
|
||||
|
||||
pixelsLocal, pixelsGlobal, targetSize, err := imgProc.ProcessImage(generateImage(12))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if n := len(pixelsLocal); n != 336*336*3 {
|
||||
t.Errorf("unexpected size of f32s: %d", n)
|
||||
}
|
||||
|
||||
if n := len(pixelsGlobal); n > 0 {
|
||||
t.Errorf("unexpected size of f32s: %d", n)
|
||||
}
|
||||
|
||||
if !targetSize.Eq(image.Point{336, 336}) {
|
||||
t.Errorf("unexpected target size: %v", targetSize)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue