create blobs in parallel (#10135)

* default max term height
* error on out of tree files
This commit is contained in:
Michael Yang 2025-05-05 11:59:26 -07:00 committed by GitHub
parent 7073600797
commit d931ee8f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 110 additions and 22 deletions

View file

@ -31,6 +31,7 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@ -41,6 +42,7 @@ import (
"github.com/ollama/ollama/runner" "github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
spinner.Stop() spinner.Stop()
req.Name = args[0] req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize") quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" { if quantize != "" {
req.Quantize = quantize req.Quantize = quantize
@ -117,28 +119,44 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err return err
} }
if len(req.Files) > 0 { var g errgroup.Group
fileMap := map[string]string{} g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
for f, digest := range req.Files {
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil { if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err return err
} }
fileMap[filepath.Base(f)] = digest
} // TODO: this is incorrect since the file might be in a subdirectory
req.Files = fileMap // instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
} }
if len(req.Adapters) > 0 { adapters := syncmap.NewSyncMap[string, string]()
fileMap := map[string]string{} for f, digest := range req.Adapters {
for f, digest := range req.Adapters { g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil { if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err return err
} }
fileMap[filepath.Base(f)] = digest
} // TODO: same here
req.Adapters = fileMap adapters.Store(filepath.Base(f), digest)
return nil
})
} }
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
@ -213,7 +231,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
} }
}() }()
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil { if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil

View file

@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
return return
} }
if req.Name != "test-model" { if req.Model != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name) t.Errorf("expected model name 'test-model', got %s", req.Name)
} }

View file

@ -139,10 +139,28 @@ func fileDigestMap(path string) (map[string]string, error) {
var files []string var files []string
if fi.IsDir() { if fi.IsDir() {
files, err = filesForModel(path) fs, err := filesForModel(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else { } else {
files = []string{path} files = []string{path}
} }
@ -215,11 +233,11 @@ func filesForModel(path string) ([]string, error) {
return nil, err return nil, err
} }
for _, safetensor := range matches { for _, match := range matches {
if ct, err := detectContentType(safetensor); err != nil { if ct, err := detectContentType(match); err != nil {
return nil, err return nil, err
} else if ct != contentType { } else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor) return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
} }
} }

View file

@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string {
func (b *Bar) String() string { func (b *Bar) String() string {
termWidth, _, err := term.GetSize(int(os.Stderr.Fd())) termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil { if err != nil {
termWidth = 80 termWidth = defaultTermWidth
} }
var pre strings.Builder var pre strings.Builder

View file

@ -4,8 +4,16 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"io" "io"
"os"
"sync" "sync"
"time" "time"
"golang.org/x/term"
)
const (
defaultTermWidth = 80
defaultTermHeight = 24
) )
type State interface { type State interface {
@ -83,6 +91,11 @@ func (p *Progress) Add(key string, state State) {
} }
func (p *Progress) render() { func (p *Progress) render() {
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termHeight = defaultTermHeight
}
p.mu.Lock() p.mu.Lock()
defer p.mu.Unlock() defer p.mu.Unlock()
@ -102,8 +115,9 @@ func (p *Progress) render() {
fmt.Fprint(p.w, "\033[1G") fmt.Fprint(p.w, "\033[1G")
// render progress lines // render progress lines
for i, state := range p.states { maxHeight := min(len(p.states), termHeight)
fmt.Fprint(p.w, state.String(), "\033[K") for i := len(p.states) - maxHeight; i < len(p.states); i++ {
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
if i < len(p.states)-1 { if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n") fmt.Fprint(p.w, "\n")
} }

38
types/syncmap/syncmap.go Normal file
View file

@ -0,0 +1,38 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a simple, generic map that is safe for concurrent use.
//
// Every operation is guarded by an internal read/write mutex. The zero value
// is an empty map that is ready to use; NewSyncMap remains available for
// callers that prefer an explicit constructor. Because it embeds a mutex, a
// SyncMap must not be copied after first use; pass it by pointer.
type SyncMap[K comparable, V any] struct {
	mu sync.RWMutex // guards m
	m  map[K]V      // allocated by NewSyncMap, or lazily by Store
}

// NewSyncMap returns an empty SyncMap whose backing map is already allocated.
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
	return &SyncMap[K, V]{
		m: make(map[K]V),
	}
}

// Load returns the value stored under key and reports whether the key was
// present. When the key is absent it returns the zero value of V and false.
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Indexing a nil map is legal in Go, so this is safe on a zero-value SyncMap.
	val, ok := s.m[key]
	return val, ok
}

// Store sets the value for key, replacing any existing entry.
func (s *SyncMap[K, V]) Store(key K, value V) {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Allocate on first write so a zero-value SyncMap does not panic with
	// "assignment to entry in nil map".
	if s.m == nil {
		s.m = make(map[K]V)
	}
	s.m[key] = value
}

// Items returns a shallow copy of the current contents. The returned map is
// owned by the caller: modifying it does not affect the SyncMap, and later
// calls to Store are not reflected in it. Calling Items on a zero-value
// SyncMap that has never been written to returns a nil map.
func (s *SyncMap[K, V]) Items() map[K]V {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// maps.Clone copies keys and values by plain assignment (shallow copy).
	return maps.Clone(s.m)
}