image processing for llama3.2 (#6963)

Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
Patrick Devine 2024-10-18 16:12:35 -07:00 committed by GitHub
parent bf4018b9ec
commit c7cb0f0602
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
35 changed files with 3851 additions and 203 deletions

View file

@ -119,20 +119,21 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
s.sched.expireRunner(model)
c.JSON(http.StatusOK, api.GenerateResponse{
@ -169,6 +170,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
// load the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@ -179,6 +181,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
isMllama := checkMllamaModelFamily(model)
if isMllama && len(req.Images) > 1 {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "this model only supports one image: more than one image sent"})
return
}
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
@ -212,7 +220,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
if isMllama {
msgs = append(msgs, api.Message{Role: "user", Content: "<|image|>"})
} else {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})