mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
parent
36d111e788
commit
b1fd7fef86
8 changed files with 123 additions and 38 deletions
|
@ -601,7 +601,7 @@ func ListHandler(cmd *cobra.Command, args []string) error {
|
|||
var data [][]string
|
||||
|
||||
for _, m := range models.Models {
|
||||
if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
|
||||
if len(args) == 0 || strings.HasPrefix(strings.ToLower(m.Name), strings.ToLower(args[0])) {
|
||||
data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), format.HumanTime(m.ModifiedAt, "Never")})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -376,6 +376,10 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
|||
switch command {
|
||||
case "model", "adapter":
|
||||
if name := model.ParseName(c.Args); name.IsValid() && command == "model" {
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
baseLayers, err = parseFromModel(ctx, name, fn)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
@ -10,6 +11,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type ModelPath struct {
|
||||
|
@ -93,11 +95,16 @@ func (mp ModelPath) GetShortTagname() string {
|
|||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
func (mp ModelPath) GetManifestPath() (string, error) {
|
||||
if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
|
||||
return filepath.Join(envconfig.Models(), "manifests", p), nil
|
||||
name := model.Name{
|
||||
Host: mp.Registry,
|
||||
Namespace: mp.Namespace,
|
||||
Model: mp.Repository,
|
||||
Tag: mp.Tag,
|
||||
}
|
||||
|
||||
return "", errModelPathInvalid
|
||||
if !name.IsValid() {
|
||||
return "", fs.ErrNotExist
|
||||
}
|
||||
return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
|
||||
}
|
||||
|
||||
func (mp ModelPath) BaseURL() *url.URL {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
@ -155,10 +154,3 @@ func TestParseModelPath(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsecureModelpath(t *testing.T) {
|
||||
mp := ParseModelPath("../../..:something")
|
||||
if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
|
||||
t.Errorf("expected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
|
112
server/routes.go
112
server/routes.go
|
@ -9,6 +9,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
|
@ -120,10 +121,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(req.Model)
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
// what the API currently returns until we can change it.
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// We cannot currently consolidate this into GetModel because all we'll
|
||||
// induce infinite recursion given the current code structure.
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
model, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
case errors.Is(err, fs.ErrNotExist):
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
case err.Error() == "invalid model name":
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
@ -157,7 +174,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
caps = append(caps, CapabilityInsert)
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
return
|
||||
|
@ -386,7 +403,13 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
|||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
name, err := getExistingName(model.ParseName(req.Model))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
|
@ -489,7 +512,13 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
return
|
||||
|
@ -582,11 +611,11 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
var model string
|
||||
var mname string
|
||||
if req.Model != "" {
|
||||
model = req.Model
|
||||
mname = req.Model
|
||||
} else if req.Name != "" {
|
||||
model = req.Name
|
||||
mname = req.Name
|
||||
} else {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
|
@ -606,7 +635,13 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
if err := PushModel(ctx, model, regOpts, fn); err != nil {
|
||||
name, err := getExistingName(model.ParseName(mname))
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
if err := PushModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
@ -619,17 +654,29 @@ func (s *Server) PushHandler(c *gin.Context) {
|
|||
streamResponse(c, ch)
|
||||
}
|
||||
|
||||
// getExistingName returns the original, on disk name if the input name is a
|
||||
// case-insensitive match, otherwise it returns the input name.
|
||||
// getExistingName searches the models directory for the longest prefix match of
|
||||
// the input name and returns the input name with all existing parts replaced
|
||||
// with each part found. If no parts are found, the input name is returned as
|
||||
// is.
|
||||
func getExistingName(n model.Name) (model.Name, error) {
|
||||
var zero model.Name
|
||||
existing, err := Manifests(true)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
var set model.Name // tracks parts already canonicalized
|
||||
for e := range existing {
|
||||
if n.EqualFold(e) {
|
||||
return e, nil
|
||||
if set.Host == "" && strings.EqualFold(e.Host, n.Host) {
|
||||
n.Host = e.Host
|
||||
}
|
||||
if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) {
|
||||
n.Namespace = e.Namespace
|
||||
}
|
||||
if set.Model == "" && strings.EqualFold(e.Model, n.Model) {
|
||||
n.Model = e.Model
|
||||
}
|
||||
if set.Tag == "" && strings.EqualFold(e.Tag, n.Tag) {
|
||||
n.Tag = e.Tag
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
|
@ -658,7 +705,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
if r.Path == "" && r.Modelfile == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or Modelfile are required"})
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -722,6 +769,12 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
n, err := getExistingName(n)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
}
|
||||
|
||||
m, err := ParseNamedManifest(n)
|
||||
if err != nil {
|
||||
switch {
|
||||
|
@ -782,7 +835,16 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
m, err := GetModel(req.Model)
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
return nil, errModelPathInvalid
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err := GetModel(name.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -805,12 +867,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
|
||||
}
|
||||
|
||||
n := model.ParseName(req.Model)
|
||||
if !n.IsValid() {
|
||||
return nil, errors.New("invalid model name")
|
||||
}
|
||||
|
||||
manifest, err := ParseNamedManifest(n)
|
||||
manifest, err := ParseNamedManifest(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1431,7 +1488,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
caps = append(caps, CapabilityTools)
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
name, err := getExistingName(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
return
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
|
||||
return
|
||||
|
|
|
@ -719,7 +719,7 @@ func TestGenerate(t *testing.T) {
|
|||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support insert"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
|
|
@ -514,6 +514,8 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
|||
|
||||
wantStableName := name()
|
||||
|
||||
t.Logf("stable name: %s", wantStableName)
|
||||
|
||||
// checkManifestList tests that there is strictly one manifest in the
|
||||
// models directory, and that the manifest is for the model under test.
|
||||
checkManifestList := func() {
|
||||
|
@ -601,6 +603,18 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
|||
Destination: name(),
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
t.Logf("pushing")
|
||||
rr := createRequest(t, s.PushHandler, api.PushRequest{
|
||||
Model: name(),
|
||||
Insecure: true,
|
||||
Username: "alice",
|
||||
Password: "x",
|
||||
})
|
||||
checkOK(rr)
|
||||
if !strings.Contains(rr.Body.String(), `"status":"success"`) {
|
||||
t.Errorf("got = %q, want success", rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestShow(t *testing.T) {
|
||||
|
|
|
@ -223,12 +223,12 @@ func (n Name) String() string {
|
|||
func (n Name) DisplayShortest() string {
|
||||
var sb strings.Builder
|
||||
|
||||
if n.Host != defaultHost {
|
||||
if !strings.EqualFold(n.Host, defaultHost) {
|
||||
sb.WriteString(n.Host)
|
||||
sb.WriteByte('/')
|
||||
sb.WriteString(n.Namespace)
|
||||
sb.WriteByte('/')
|
||||
} else if n.Namespace != defaultNamespace {
|
||||
} else if !strings.EqualFold(n.Namespace, defaultNamespace) {
|
||||
sb.WriteString(n.Namespace)
|
||||
sb.WriteByte('/')
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue