Mirror of https://github.com/ollama/ollama.git (synced 2025-05-11 18:36:41 +02:00)
image processing for llama3.2 (#6963)
Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
Parent: bf4018b9ec
Commit: c7cb0f0602
35 changed files with 3851 additions and 203 deletions
@@ -60,7 +60,9 @@ package llama
 #include <stdlib.h>
 #include "llama.h"
 #include "clip.h"
+#include "ggml.h"
 #include "llava.h"
+#include "mllama.h"
 #include "sampling_ext.h"
 
 bool llamaProgressCallback(float progress, void *user_data);
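The cgo preamble above declares llamaProgressCallback so C code can report model-loading progress back into Go. Below is a minimal, self-contained sketch of that //export pattern; the handle-based body and the direct call in main are illustrative assumptions, not the repository's actual implementation.

package main

/*
#include <stdbool.h>

// Matches the declaration in the patched preamble: C code uses this as a
// progress hook that calls back into Go.
bool llamaProgressCallback(float progress, void *user_data);
*/
import "C"

import (
	"fmt"
	"runtime/cgo"
	"unsafe"
)

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	// Recover the Go closure passed through the opaque void* argument.
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func main() {
	h := cgo.NewHandle(func(p float32) { fmt.Printf("progress: %.0f%%\n", p*100) })
	defer h.Delete()

	// In a real wrapper the C side would invoke the callback during model
	// loading; here we simply call the exported function directly.
	llamaProgressCallback(C.float(0.5), unsafe.Pointer(&h))
}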
@@ -410,18 +412,60 @@ func Quantize(infile, outfile string, ftype uint32) error {
 
 // llava
 type ClipContext struct {
-	c *C.struct_clip_ctx
+	c        *C.struct_clip_ctx
+	m        *C.struct_mllama_ctx
+	IsMllama bool
+	embedPin runtime.Pinner
+	pinned   bool
 }
 
-func NewClipContext(modelPath string) *ClipContext {
+func getVisionArch(mp *C.char) (string, error) {
+	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
+	if gguf_ctx == nil {
+		return "", errors.New("unable to load vision projector")
+	}
+	defer C.gguf_free(gguf_ctx)
+
+	arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
+	if int(arch_index) < 0 {
+		return "", errors.New("unknown vision model architecture")
+	}
+
+	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
+
+	return C.GoString(arch), nil
+}
+
+func NewClipContext(modelPath string) (*ClipContext, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
-	cc := C.clip_model_load(mp, 1)
-	return &ClipContext{c: cc}
+
+	arch, err := getVisionArch(mp)
+	if err != nil {
+		return nil, err
+	}
+
+	var cc ClipContext
+	if arch == "clip" {
+		cc.c = C.clip_model_load(mp, 1)
+	} else if arch == "mllama" {
+		cc.m = C.mllama_model_load(mp, 1)
+		cc.IsMllama = true
+	} else {
+		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
+	}
+
+	// XXX: check embedding size?
+	return &cc, nil
 }
 
 func (c *ClipContext) Free() {
-	C.clip_free(c.c)
+	if c.c != nil {
+		C.clip_free(c.c)
+	}
+	if c.m != nil {
+		C.mllama_free(c.m)
+	}
 }
 
 func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
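NewClipContext now probes the projector's general.architecture key and returns an error for unknown architectures, so callers must handle the error and can branch on IsMllama. A minimal sketch of how a caller might use it; the import path and the projector filename are assumptions for illustration.

package main

import (
	"fmt"
	"log"

	// Assumed import path of the wrapper package shown in this diff.
	"github.com/ollama/ollama/llama"
)

func main() {
	// Hypothetical projector path; any CLIP or mllama vision projector GGUF.
	clipCtx, err := llama.NewClipContext("vision-projector.gguf")
	if err != nil {
		log.Fatal(err)
	}
	defer clipCtx.Free()

	if clipCtx.IsMllama {
		fmt.Println("loaded an mllama (llama3.2 vision) projector")
	} else {
		fmt.Println("loaded a clip/llava projector")
	}
}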
@@ -445,6 +489,48 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
 	return embed
 }
 
+func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
+	img := C.mllama_image_init()
+	defer C.mllama_image_free(img)
+
+	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
+
+	numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
+	numEmbed := llamaContext.Model().NEmbd()
+
+	rows := make([]float32, numEmbed*numTokens)
+	C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
+
+	embed := make([][]float32, numTokens)
+	for i := range embed {
+		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+	}
+
+	return embed
+}
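In NewMllamaImageEmbed the encoder writes all tile embeddings into one contiguous float buffer, which is then re-sliced into one row per token without copying. A standalone sketch of that reshaping with toy sizes; the real dimensions come from mllama_n_positions * mllama_n_tiles and the language model's embedding width.

package main

import "fmt"

// reshape mirrors the slicing in NewMllamaImageEmbed: a flat buffer of
// numTokens*numEmbed floats becomes one slice header per token, all sharing
// the same backing array.
func reshape(rows []float32, numTokens, numEmbed int) [][]float32 {
	embed := make([][]float32, numTokens)
	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}
	return embed
}

func main() {
	rows := make([]float32, 4*8) // toy sizes
	embed := reshape(rows, 4, 8)
	fmt.Println(len(embed), len(embed[0])) // 4 8
}

The remainder of the hunk adds the cross-attention state setter: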
+
+// This really needs to be set on a batch instead
+func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
+	if embed != nil {
+		if clipContext.pinned {
+			panic("Cross attention state already pinned")
+		}
+
+		embedData := &embed[0][0]
+		clipContext.embedPin.Pin(embedData)
+		clipContext.pinned = true
+
+		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
+	} else {
+		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
+
+		if clipContext.pinned {
+			clipContext.embedPin.Unpin()
+			clipContext.pinned = false
+		}
+	}
+}
+
 // sampling
 // TODO: this is a temporary wrapper to allow calling C++ code from CGo
 type SamplingContext struct {
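MllamaSetCrossAttn hands C a raw pointer into a Go slice and expects it to stay valid across later decode calls, so it pins the buffer with runtime.Pinner and only unpins it once the state is cleared. A standalone sketch of that pin/unpin discipline; the C functions here are stand-ins, not llama.cpp's API.

package main

/*
#include <stddef.h>

// Stand-in for a C API that stashes a pointer and reads it on later calls,
// the way llama_set_cross_attn_state retains the embedding buffer.
static const float *saved = NULL;
static void set_state(const float *p) { saved = p; }
static float read_state(void) { return saved ? saved[0] : -1.0f; }
*/
import "C"

import (
	"fmt"
	"runtime"
	"unsafe"
)

func main() {
	embed := []float32{0.25, 0.5, 0.75}

	var pin runtime.Pinner
	// Pin the backing array so the GC will neither free nor move it while the
	// C side holds the raw pointer across calls.
	pin.Pin(&embed[0])
	C.set_state((*C.float)(unsafe.Pointer(&embed[0])))

	fmt.Println(C.read_state()) // 0.25

	// Clear the C-side pointer before unpinning, mirroring the nil branch of
	// MllamaSetCrossAttn.
	C.set_state(nil)
	pin.Unpin()
}

Clearing the C-side pointer before Unpin matters: once the buffer is unpinned, the garbage collector is again free to move or reclaim it.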