server: more support for mixed-case model names (#8017)

Fixes #7944
This commit is contained in:
Blake Mizerany 2024-12-11 15:29:59 -08:00 committed by GitHub
parent 36d111e788
commit b1fd7fef86
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 123 additions and 38 deletions

View file

@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
var data [][]string
for _, m := range models.Models {
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
}
}

View file

@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
switch command {
case "model", "adapter":
if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
name, err := getExistingName(name)
if err != nil {
return err
}
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
return err

View file

@ -3,6 +3,7 @@ package server
import (
"errors"
"fmt"
"io/fs"
"net/url"
"os"
"path/filepath"
@ -10,6 +11,7 @@ import (
"strings"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model"
)
type ModelPath struct {
@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {
// GetManifestPath returns the path to the manifest file for the given model path. It is up to the caller to create the directory if it does not exist.
func (mp ModelPath) GetManifestPath() (string, error) {
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
return filepath.Join(envconfig.Models(), "manifests", p), nil
name := model.Name{
Host: mp.Registry,
Namespace: mp.Namespace,
Model: mp.Repository,
Tag: mp.Tag,
}
return "", errModelPathInvalid
if !name.IsValid() {
return "", fs.ErrNotExist
}
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
}
func (mp ModelPath) BaseURL() *url.URL {

View file

@ -1,7 +1,6 @@
package server
import (
"errors"
"os"
"path/filepath"
"testing"
@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
})
}
}
// TestInsecureModelpath checks that asking for the manifest path of a model
// path parsed from "../../..:something" yields errModelPathInvalid.
func TestInsecureModelpath(t *testing.T) {
	insecure := ParseModelPath("../../..:something")
	_, err := insecure.GetManifestPath()
	if errors.Is(err, errModelPathInvalid) {
		// Got the sentinel error the test expects; nothing to report.
		return
	}
	t.Errorf("expected error: %v", err)
}

View file

@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"math"
"net"
@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
model, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
// what the API currently returns until we can change it.
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
// We cannot currently consolidate this into GetModel because doing so
// would induce infinite recursion given the current code structure.
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
model, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):
case errors.Is(err, fs.ErrNotExist):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
case err.Error() == "invalid model name":
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
return
@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name, err := getExistingName(model.ParseName(req.Model))
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
return
}
var model string
var mname string
if req.Model != "" {
model = req.Model
mname = req.Model
} else if req.Name != "" {
model = req.Name
mname = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
name, err := getExistingName(model.ParseName(mname))
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
streamResponse(c, ch)
}
// getExistingName returns the original, on disk name if the input name is a
// case-insensitive match, otherwise it returns the input name.
// getExistingName searches the models directory for the longest prefix match of
// the input name and returns the input name with all existing parts replaced
// with each part found. If no parts are found, the input name is returned as
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
existing, err := Manifests(true)
if err != nil {
return zero, err
}
var set model.Name // tracks parts already canonicalized
for e := range existing {
if n.EqualFold(e) {
return e, nil
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
n.Host = e.Host
}
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
n.Namespace = e.Namespace
}
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
n.Model = e.Model
}
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
n.Tag = e.Tag
}
}
return n, nil
@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
return
}
@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
n, err := getExistingName(n)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
return
}
m, err := ParseNamedManifest(n)
if err != nil {
switch {
@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model)
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {
return nil, err
}
m, err := GetModel(name.String())
if err != nil {
return nil, err
}
@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
n := model.ParseName(req.Model)
if !n.IsValid() {
return nil, errors.New("invalid model name")
}
manifest, err := ParseNamedManifest(n)
manifest, err := ParseNamedManifest(name)
if err != nil {
return nil, err
}
@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
name := model.ParseName(req.Model)
if !name.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
name, err := getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
return

View file

@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})

View file

@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {
wantStableName := name()
t.Logf("stable name: %s", wantStableName)
// checkManifestList tests that there is strictly one manifest in the
// models directory, and that the manifest is for the model under test.
checkManifestList := func() {
@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
Destination: name(),
}))
checkManifestList()
t.Logf("pushing")
rr := createRequest(t, s.PushHandler, api.PushRequest{
Model: name(),
Insecure: true,
Username: "alice",
Password: "x",
})
checkOK(rr)
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
t.Errorf("got = %q, want success", rr.Body.String())
}
}
func TestShow(t *testing.T) {

View file

@ -223,12 +223,12 @@ func (n Name) String() string {
func (n Name) DisplayShortest() string {
var sb strings.Builder
if n.Host != defaultHost {
if !strings.EqualFold(n.Host, defaultHost) {
sb.WriteString(n.Host)
sb.WriteByte('/')
sb.WriteString(n.Namespace)
sb.WriteByte('/')
} else if n.Namespace != defaultNamespace {
} else if !strings.EqualFold(n.Namespace, defaultNamespace) {
sb.WriteString(n.Namespace)
sb.WriteByte('/')
}