From 54d47159f7c9cc6ad3475c5c68b55b54fe2d3b49 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Fri, 25 Apr 2025 16:47:27 -0700
Subject: [PATCH] remove mllama integration, use ollama engine

---
 llama/llama.go               | 57 --------------------------------
 model/models/mllama/model.go |  5 ---
 runner/llamarunner/image.go  | 36 ++------------------
 runner/llamarunner/runner.go |  7 ----
 server/prompt.go             | 64 +++---------------------------------
 server/routes.go             | 36 +-------------------
 server/sched.go              |  9 ++---
 7 files changed, 14 insertions(+), 200 deletions(-)

diff --git a/llama/llama.go b/llama/llama.go
index 278022cc1..961393c3e 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -509,63 +509,6 @@ func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
 	return embed, nil
 }
 
-type MllamaContext struct {
-	c *C.struct_mllama_ctx
-}
-
-func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
-	mp := C.CString(modelPath)
-	defer C.free(unsafe.Pointer(mp))
-	c := C.mllama_model_load(mp, 1)
-	if c == nil {
-		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
-	}
-
-	projEmbedSize := int(C.mllama_n_embd(c))
-	modelEmbedSize := llamaContext.Model().NEmbd()
-	if projEmbedSize != modelEmbedSize {
-		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
-	}
-
-	return &MllamaContext{c: c}, nil
-}
-
-func (m *MllamaContext) Free() {
-	C.mllama_free(m.c)
-}
-
-func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
-	img := C.mllama_image_init()
-	defer C.mllama_image_free(img)
-
-	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
-	if !ok {
-		return nil, errors.New("unable to load mllama image data")
-	}
-
-	rows := make([]float32, m.EmbedSize(llamaContext))
-	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
-	if !ok {
-		return nil, errors.New("unable to make mllama embedding from image")
-	}
-
-	embed := make([][]float32, 1)
-	embed[0] = rows
-
-	return embed, nil
-}
-
-func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
-	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
-	numEmbed := llamaContext.Model().NEmbd()
-
-	return numTokens * numEmbed
-}
-
-func (c *Context) SetCrossAttention(state bool) {
-	C.llama_set_cross_attention(c.c, C.bool(state))
-}
-
 func (c *Context) Synchronize() {
 	C.llama_synchronize(c.c)
 }
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 149876c9c..62e21cf39 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -3,7 +3,6 @@ package mllama
 import (
 	"bytes"
 	"encoding/binary"
-	"fmt"
 	"hash/fnv"
 	"image"
 	"slices"
@@ -34,10 +33,6 @@ const (
 )
 
 func New(c fs.Config) (model.Model, error) {
-	// Verify unified config
-	if c.Uint("vision.block_count") == 0 {
-		return nil, fmt.Errorf("non-unified vision model not supported")
-	}
 	m := Model{
 		BytePairEncoding: model.NewBytePairEncoding(
 			c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
diff --git a/runner/llamarunner/image.go b/runner/llamarunner/image.go
index e7e30a4d8..09d95565e 100644
--- a/runner/llamarunner/image.go
+++ b/runner/llamarunner/image.go
@@ -5,7 +5,6 @@ import (
 	"fmt"
 	"hash/maphash"
 	"log/slog"
-	"slices"
 	"sync"
 	"time"
 
@@ -18,8 +17,7 @@ type ImageContext struct {
 	// mu is required to be held when generating embeddings or accessing the cache
 	mu sync.Mutex
 
-	clip *llama.ClipContext
-	mllama *llama.MllamaContext
+	clip *llama.ClipContext
 
 	// cache of images to embeddings
 	images []imageCache
@@ -35,8 +33,6 @@ func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageConte
 	var c ImageContext
 	if arch == "clip" {
 		c.clip, err = llama.NewClipContext(llamaContext, modelPath)
-	} else if arch == "mllama" {
-		c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
 	} else {
 		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
 	}
@@ -58,9 +54,6 @@ func (c *ImageContext) Free(modelPath string) {
 	if c.clip != nil {
 		c.clip.Free()
 	}
-	if c.mllama != nil {
-		c.mllama.Free()
-	}
 }
 
 func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
@@ -79,12 +72,7 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspect
 	embed, err := c.findImage(hash)
 	if err != nil {
-		if c.mllama != nil {
-			embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
-			if err != nil {
-				return nil, err
-			}
-		} else if c.clip != nil {
+		if c.clip != nil {
 			embed, err = c.clip.NewEmbed(llamaContext, data)
 			if err != nil {
 				return nil, err
 			}
@@ -109,29 +97,11 @@ func (c *ImageContext) BatchSize(configuredBatchSize int) int {
 	// and doesn't support more than a single image per request.
 	// The embeddings are large (100 MB), so allocating a big batch can fail
 	// on some systems
-	if c.mllama != nil {
-		return 1
-	}
-
 	return configuredBatchSize
 }
 
 func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
-	if c != nil && c.mllama != nil {
-		return c.mllama.EmbedSize(llamaContext)
-	} else {
-		return llamaContext.Model().NEmbd()
-	}
-}
-
-func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
-	if c == nil || c.mllama == nil {
-		return false
-	}
-
-	return slices.ContainsFunc(inputs, func(input input) bool {
-		return input.embed != nil
-	})
+	return llamaContext.Model().NEmbd()
 }
 
 type imageCache struct {
diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go
index 5341d4fb1..791d75ba6 100644
--- a/runner/llamarunner/runner.go
+++ b/runner/llamarunner/runner.go
@@ -413,9 +413,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			if batch == nil {
 				if !embedding {
 					batch = tokenBatch
-				} else {
-					batch = embedBatch
-					seq.crossAttention = s.image.NeedCrossAttention(input)
 				}
 			} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
 				s.nextSeq = seqIdx
@@ -439,8 +436,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		return nil
 	}
 
-	s.lc.SetCrossAttention(crossAttention)
-
 	err := s.lc.Decode(batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
@@ -621,8 +616,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 				return
 			}
 
-			seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
-
 			s.seqs[i] = seq
 			s.cond.Signal()
 			found = true
diff --git a/server/prompt.go b/server/prompt.go
index 5b5b958f1..a492ae733 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -3,7 +3,6 @@ package server
 import (
 	"bytes"
 	"context"
-	"encoding/binary"
 	"errors"
 	"fmt"
 	"log/slog"
@@ -11,7 +10,6 @@ import (
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/template"
 )
 
@@ -25,25 +23,14 @@ var errTooManyImages = errors.New("vision model only supports a single image per
 func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) {
 	var system []api.Message
 
-	isMllama := checkMllamaModelFamily(m)
-
 	var imageNumTokens int
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
-	if isMllama {
-		// Our mllama implementation packs all of the embeddings into a single token
-		imageNumTokens = 1
-	} else {
-		// Clip images are represented as 768 tokens, each an embedding
-		imageNumTokens = 768
-	}
+	// Clip images are represented as 768 tokens, each an embedding
+	imageNumTokens = 768
 
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
-		if isMllama && len(msgs[i].Images) > 1 {
-			return "", nil, errTooManyImages
-		}
-
 		// always include the last message
 		if i == n {
 			continue
@@ -91,41 +78,9 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 
 		for _, i := range msg.Images {
 			var imgData llm.ImageData
-			if isMllama {
-				if len(m.ProjectorPaths) == 0 {
-					imgData = llm.ImageData{
-						ID: len(images),
-						Data: i,
-					}
-				} else {
-					data, opts, err := mllama.Preprocess(bytes.NewReader(i))
-					if err != nil {
-						return "", nil, err
-					}
-
-					buf := new(bytes.Buffer)
-					err = binary.Write(buf, binary.LittleEndian, data)
-					if err != nil {
-						return "", nil, err
-					}
-
-					ar, ok := opts["aspectRatioIndex"].(int)
-					if !ok {
-						return "", nil, fmt.Errorf("missing aspect ratio for image")
-					}
-
-					imgData = llm.ImageData{
-						ID: len(images),
-						Data: buf.Bytes(),
-						AspectRatioID: ar,
-					}
-				}
-				imgPrompt = "<|image|>"
-			} else {
-				imgData = llm.ImageData{
-					ID: len(images),
-					Data: i,
-				}
+			imgData = llm.ImageData{
+				ID: len(images),
+				Data: i,
 			}
 
 			imgTag := fmt.Sprintf("[img-%d]", imgData.ID)
@@ -148,12 +103,3 @@
 
 	return b.String(), images, nil
 }
-
-func checkMllamaModelFamily(m *Model) bool {
-	for _, arch := range m.Config.ModelFamilies {
-		if arch == "mllama" {
-			return true
-		}
-	}
-	return false
-}
diff --git a/server/routes.go b/server/routes.go
index 8b0c7aca5..6b3da97b7 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -4,7 +4,6 @@ import (
 	"bytes"
 	"cmp"
 	"context"
-	"encoding/binary"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -33,7 +32,6 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/llm"
-	"github.com/ollama/ollama/model/models/mllama"
 	"github.com/ollama/ollama/openai"
 	"github.com/ollama/ollama/server/internal/client/ollama"
 	"github.com/ollama/ollama/server/internal/registry"
@@ -204,38 +202,9 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		return
 	}
 
-	isMllama := checkMllamaModelFamily(m)
-	if isMllama && len(req.Images) > 1 {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
-		return
-	}
-
 	images := make([]llm.ImageData, len(req.Images))
 	for i := range req.Images {
-		if isMllama && len(m.ProjectorPaths) > 0 {
-			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
-			if err != nil {
-				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
-				return
-			}
-
-			ar, ok := opts["aspectRatioIndex"].(int)
-			if !ok {
-				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
-				return
-			}
-
-			buf := new(bytes.Buffer)
-			err = binary.Write(buf, binary.LittleEndian, data)
-			if err != nil {
-				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
-				return
-			}
-
-			images[i] = llm.ImageData{ID: i, Data: buf.Bytes(), AspectRatioID: ar}
-		} else {
-			images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
-		}
+		images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
 	}
 
 	prompt := req.Prompt
@@ -267,9 +236,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 
 		for _, i := range images {
 			imgPrompt := ""
-			if isMllama {
-				imgPrompt = "<|image|>"
-			}
 			msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
 		}
 
diff --git a/server/sched.go b/server/sched.go
index 43da138e2..3fc54e55a 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -8,6 +8,7 @@ import (
 	"os"
 	"reflect"
 	"runtime"
+	"slices"
 	"sort"
 	"strconv"
 	"strings"
@@ -132,11 +133,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
 				continue
 			}
 			numParallel := int(envconfig.NumParallel())
-			// TODO (jmorganca): mllama doesn't support parallel yet
-			// see https://github.com/ollama/ollama/issues/4165
-			if checkMllamaModelFamily(pending.model) && numParallel != 1 {
+			// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
+			// ref: https://github.com/ollama/ollama/issues/4165
+			if slices.Contains(pending.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
 				numParallel = 1
-				slog.Warn("mllama doesn't support parallel requests yet")
+				slog.Warn("mllama does not currently support parallel requests")
 			}
 
 			for {