diff --git a/cmd/cmd.go b/cmd/cmd.go index 79ff87ac8..58c6dbf22 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -31,6 +31,7 @@ import ( "github.com/olekukonko/tablewriter" "github.com/spf13/cobra" "golang.org/x/crypto/ssh" + "golang.org/x/sync/errgroup" "golang.org/x/term" "github.com/ollama/ollama/api" @@ -41,6 +42,7 @@ import ( "github.com/ollama/ollama/runner" "github.com/ollama/ollama/server" "github.com/ollama/ollama/types/model" + "github.com/ollama/ollama/types/syncmap" "github.com/ollama/ollama/version" ) @@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { } spinner.Stop() - req.Name = args[0] + req.Model = args[0] quantize, _ := cmd.Flags().GetString("quantize") if quantize != "" { req.Quantize = quantize @@ -117,28 +119,44 @@ func CreateHandler(cmd *cobra.Command, args []string) error { return err } - if len(req.Files) > 0 { - fileMap := map[string]string{} - for f, digest := range req.Files { + var g errgroup.Group + g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) + + files := syncmap.NewSyncMap[string, string]() + for f, digest := range req.Files { + g.Go(func() error { if _, err := createBlob(cmd, client, f, digest, p); err != nil { return err } - fileMap[filepath.Base(f)] = digest - } - req.Files = fileMap + + // TODO: this is incorrect since the file might be in a subdirectory + // instead this should take the path relative to the model directory + // but the current implementation does not allow this + files.Store(filepath.Base(f), digest) + return nil + }) } - if len(req.Adapters) > 0 { - fileMap := map[string]string{} - for f, digest := range req.Adapters { + adapters := syncmap.NewSyncMap[string, string]() + for f, digest := range req.Adapters { + g.Go(func() error { if _, err := createBlob(cmd, client, f, digest, p); err != nil { return err } - fileMap[filepath.Base(f)] = digest - } - req.Adapters = fileMap + + // TODO: same here + adapters.Store(filepath.Base(f), digest) + return nil + }) } + if err := g.Wait(); err != nil { + return err + } + + req.Files = files.Items() + req.Adapters = adapters.Items() + bars := make(map[string]*progress.Bar) fn := func(resp api.ProgressResponse) error { if resp.Digest != "" { @@ -213,7 +231,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri } }() - if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { + if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { return "", err } return digest, nil diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 367a35b6b..1cd6ddb40 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) { return } - if req.Name != "test-model" { + if req.Model != "test-model" { t.Errorf("expected model name 'test-model', got %s", req.Name) } diff --git a/parser/parser.go b/parser/parser.go index a14ac5ff4..0a732653c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -139,10 +139,28 @@ func fileDigestMap(path string) (map[string]string, error) { var files []string if fi.IsDir() { - files, err = filesForModel(path) + fs, err := filesForModel(path) if err != nil { return nil, err } + + for _, f := range fs { + f, err := filepath.EvalSymlinks(f) + if err != nil { + return nil, err + } + + rel, err := filepath.Rel(path, f) + if err != nil { + return nil, err + } + + if !filepath.IsLocal(rel) { + return nil, fmt.Errorf("insecure path: %s", rel) + } + + files = append(files, f) + } } else { files = []string{path} } @@ -215,11 +233,11 @@ func filesForModel(path string) ([]string, error) { return nil, err } - for _, safetensor := range matches { - if ct, err := detectContentType(safetensor); err != nil { + for _, match := range matches { + if ct, err := detectContentType(match); err != nil { return nil, err } else if ct != contentType { - return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) + return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match) } } diff --git a/progress/bar.go b/progress/bar.go index 410b6e23f..f3d21a8fd 100644 --- a/progress/bar.go +++ b/progress/bar.go @@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string { func (b *Bar) String() string { termWidth, _, err := term.GetSize(int(os.Stderr.Fd())) if err != nil { - termWidth = 80 + termWidth = defaultTermWidth } var pre strings.Builder diff --git a/progress/progress.go b/progress/progress.go index 0cd0ea1f9..9f54275ec 100644 --- a/progress/progress.go +++ b/progress/progress.go @@ -4,8 +4,16 @@ import ( "bufio" "fmt" "io" + "os" "sync" "time" + + "golang.org/x/term" +) + +const ( + defaultTermWidth = 80 + defaultTermHeight = 24 ) type State interface { @@ -83,6 +91,11 @@ func (p *Progress) Add(key string, state State) { } func (p *Progress) render() { + _, termHeight, err := term.GetSize(int(os.Stderr.Fd())) + if err != nil { + termHeight = defaultTermHeight + } + p.mu.Lock() defer p.mu.Unlock() @@ -102,8 +115,9 @@ func (p *Progress) render() { fmt.Fprint(p.w, "\033[1G") // render progress lines - for i, state := range p.states { - fmt.Fprint(p.w, state.String(), "\033[K") + maxHeight := min(len(p.states), termHeight) + for i := len(p.states) - maxHeight; i < len(p.states); i++ { + fmt.Fprint(p.w, p.states[i].String(), "\033[K") if i < len(p.states)-1 { fmt.Fprint(p.w, "\n") } diff --git a/types/syncmap/syncmap.go b/types/syncmap/syncmap.go new file mode 100644 index 000000000..ff21cd999 --- /dev/null +++ b/types/syncmap/syncmap.go @@ -0,0 +1,38 @@ +package syncmap + +import ( + "maps" + "sync" +) + +// SyncMap is a simple, generic thread-safe map implementation. +type SyncMap[K comparable, V any] struct { + mu sync.RWMutex + m map[K]V +} + +func NewSyncMap[K comparable, V any]() *SyncMap[K, V] { + return &SyncMap[K, V]{ + m: make(map[K]V), + } +} + +func (s *SyncMap[K, V]) Load(key K) (V, bool) { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.m[key] + return val, ok +} + +func (s *SyncMap[K, V]) Store(key K, value V) { + s.mu.Lock() + defer s.mu.Unlock() + s.m[key] = value +} + +func (s *SyncMap[K, V]) Items() map[K]V { + s.mu.RLock() + defer s.mu.RUnlock() + // shallow copy map items + return maps.Clone(s.m) +}