create blobs in parallel (#10135)

* default max term height
* error on out of tree files
This commit is contained in:
Michael Yang 2025-05-05 11:59:26 -07:00 committed by GitHub
parent 7073600797
commit d931ee8f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 110 additions and 22 deletions

View file

@ -31,6 +31,7 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
@ -41,6 +42,7 @@ import (
"github.com/ollama/ollama/runner"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/syncmap"
"github.com/ollama/ollama/version"
)
@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
}
spinner.Stop()
req.Name = args[0]
req.Model = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
@ -117,28 +119,44 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
var g errgroup.Group
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
files := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Files {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
// TODO: this is incorrect since the file might be in a subdirectory
// instead this should take the path relative to the model directory
// but the current implementation does not allow this
files.Store(filepath.Base(f), digest)
return nil
})
}
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
adapters := syncmap.NewSyncMap[string, string]()
for f, digest := range req.Adapters {
g.Go(func() error {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
// TODO: same here
adapters.Store(filepath.Base(f), digest)
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
req.Files = files.Items()
req.Adapters = adapters.Items()
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
@ -213,7 +231,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
}
}()
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}
return digest, nil

View file

@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
return
}
if req.Name != "test-model" {
if req.Model != "test-model" {
t.Errorf("expected model name 'test-model', got %s", req.Name)
}

View file

@ -139,10 +139,28 @@ func fileDigestMap(path string) (map[string]string, error) {
var files []string
if fi.IsDir() {
files, err = filesForModel(path)
fs, err := filesForModel(path)
if err != nil {
return nil, err
}
for _, f := range fs {
f, err := filepath.EvalSymlinks(f)
if err != nil {
return nil, err
}
rel, err := filepath.Rel(path, f)
if err != nil {
return nil, err
}
if !filepath.IsLocal(rel) {
return nil, fmt.Errorf("insecure path: %s", rel)
}
files = append(files, f)
}
} else {
files = []string{path}
}
@ -215,11 +233,11 @@ func filesForModel(path string) ([]string, error) {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}

View file

@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string {
func (b *Bar) String() string {
termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termWidth = 80
termWidth = defaultTermWidth
}
var pre strings.Builder

View file

@ -4,8 +4,16 @@ import (
"bufio"
"fmt"
"io"
"os"
"sync"
"time"
"golang.org/x/term"
)
const (
defaultTermWidth = 80
defaultTermHeight = 24
)
type State interface {
@ -83,6 +91,11 @@ func (p *Progress) Add(key string, state State) {
}
func (p *Progress) render() {
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
if err != nil {
termHeight = defaultTermHeight
}
p.mu.Lock()
defer p.mu.Unlock()
@ -102,8 +115,9 @@ func (p *Progress) render() {
fmt.Fprint(p.w, "\033[1G")
// render progress lines
for i, state := range p.states {
fmt.Fprint(p.w, state.String(), "\033[K")
maxHeight := min(len(p.states), termHeight)
for i := len(p.states) - maxHeight; i < len(p.states); i++ {
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
if i < len(p.states)-1 {
fmt.Fprint(p.w, "\n")
}

38
types/syncmap/syncmap.go Normal file
View file

@ -0,0 +1,38 @@
package syncmap
import (
"maps"
"sync"
)
// SyncMap is a simple, generic thread-safe map implementation.
type SyncMap[K comparable, V any] struct {
mu sync.RWMutex
m map[K]V
}
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
return &SyncMap[K, V]{
m: make(map[K]V),
}
}
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
val, ok := s.m[key]
return val, ok
}
func (s *SyncMap[K, V]) Store(key K, value V) {
s.mu.Lock()
defer s.mu.Unlock()
s.m[key] = value
}
func (s *SyncMap[K, V]) Items() map[K]V {
s.mu.RLock()
defer s.mu.RUnlock()
// shallow copy map items
return maps.Clone(s.m)
}