mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
create blobs in parallel (#10135)
* default max term height * error on out of tree files
This commit is contained in:
parent
7073600797
commit
d931ee8f22
6 changed files with 110 additions and 22 deletions
42
cmd/cmd.go
42
cmd/cmd.go
|
@ -31,6 +31,7 @@ import (
|
|||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
@ -41,6 +42,7 @@ import (
|
|||
"github.com/ollama/ollama/runner"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/types/syncmap"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
|
@ -106,7 +108,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
}
|
||||
spinner.Stop()
|
||||
|
||||
req.Name = args[0]
|
||||
req.Model = args[0]
|
||||
quantize, _ := cmd.Flags().GetString("quantize")
|
||||
if quantize != "" {
|
||||
req.Quantize = quantize
|
||||
|
@ -117,28 +119,44 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if len(req.Files) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
|
||||
files := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Files {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
}
|
||||
req.Files = fileMap
|
||||
}
|
||||
|
||||
if len(req.Adapters) > 0 {
|
||||
fileMap := map[string]string{}
|
||||
// TODO: this is incorrect since the file might be in a subdirectory
|
||||
// instead this should take the path relative to the model directory
|
||||
// but the current implementation does not allow this
|
||||
files.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
adapters := syncmap.NewSyncMap[string, string]()
|
||||
for f, digest := range req.Adapters {
|
||||
g.Go(func() error {
|
||||
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
|
||||
return err
|
||||
}
|
||||
fileMap[filepath.Base(f)] = digest
|
||||
|
||||
// TODO: same here
|
||||
adapters.Store(filepath.Base(f), digest)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
req.Adapters = fileMap
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Files = files.Items()
|
||||
req.Adapters = adapters.Items()
|
||||
|
||||
bars := make(map[string]*progress.Bar)
|
||||
fn := func(resp api.ProgressResponse) error {
|
||||
if resp.Digest != "" {
|
||||
|
@ -213,7 +231,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
|
|||
}
|
||||
}()
|
||||
|
||||
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return digest, nil
|
||||
|
|
|
@ -690,7 +690,7 @@ func TestCreateHandler(t *testing.T) {
|
|||
return
|
||||
}
|
||||
|
||||
if req.Name != "test-model" {
|
||||
if req.Model != "test-model" {
|
||||
t.Errorf("expected model name 'test-model', got %s", req.Name)
|
||||
}
|
||||
|
||||
|
|
|
@ -139,10 +139,28 @@ func fileDigestMap(path string) (map[string]string, error) {
|
|||
|
||||
var files []string
|
||||
if fi.IsDir() {
|
||||
files, err = filesForModel(path)
|
||||
fs, err := filesForModel(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, f := range fs {
|
||||
f, err := filepath.EvalSymlinks(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rel, err := filepath.Rel(path, f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !filepath.IsLocal(rel) {
|
||||
return nil, fmt.Errorf("insecure path: %s", rel)
|
||||
}
|
||||
|
||||
files = append(files, f)
|
||||
}
|
||||
} else {
|
||||
files = []string{path}
|
||||
}
|
||||
|
@ -215,11 +233,11 @@ func filesForModel(path string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
for _, safetensor := range matches {
|
||||
if ct, err := detectContentType(safetensor); err != nil {
|
||||
for _, match := range matches {
|
||||
if ct, err := detectContentType(match); err != nil {
|
||||
return nil, err
|
||||
} else if ct != contentType {
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ func formatDuration(d time.Duration) string {
|
|||
func (b *Bar) String() string {
|
||||
termWidth, _, err := term.GetSize(int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
termWidth = 80
|
||||
termWidth = defaultTermWidth
|
||||
}
|
||||
|
||||
var pre strings.Builder
|
||||
|
|
|
@ -4,8 +4,16 @@ import (
|
|||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTermWidth = 80
|
||||
defaultTermHeight = 24
|
||||
)
|
||||
|
||||
type State interface {
|
||||
|
@ -83,6 +91,11 @@ func (p *Progress) Add(key string, state State) {
|
|||
}
|
||||
|
||||
func (p *Progress) render() {
|
||||
_, termHeight, err := term.GetSize(int(os.Stderr.Fd()))
|
||||
if err != nil {
|
||||
termHeight = defaultTermHeight
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
|
@ -102,8 +115,9 @@ func (p *Progress) render() {
|
|||
fmt.Fprint(p.w, "\033[1G")
|
||||
|
||||
// render progress lines
|
||||
for i, state := range p.states {
|
||||
fmt.Fprint(p.w, state.String(), "\033[K")
|
||||
maxHeight := min(len(p.states), termHeight)
|
||||
for i := len(p.states) - maxHeight; i < len(p.states); i++ {
|
||||
fmt.Fprint(p.w, p.states[i].String(), "\033[K")
|
||||
if i < len(p.states)-1 {
|
||||
fmt.Fprint(p.w, "\n")
|
||||
}
|
||||
|
|
38
types/syncmap/syncmap.go
Normal file
38
types/syncmap/syncmap.go
Normal file
|
@ -0,0 +1,38 @@
|
|||
package syncmap
|
||||
|
||||
import (
|
||||
"maps"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SyncMap is a simple, generic thread-safe map implementation.
|
||||
type SyncMap[K comparable, V any] struct {
|
||||
mu sync.RWMutex
|
||||
m map[K]V
|
||||
}
|
||||
|
||||
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
|
||||
return &SyncMap[K, V]{
|
||||
m: make(map[K]V),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SyncMap[K, V]) Load(key K) (V, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
val, ok := s.m[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (s *SyncMap[K, V]) Store(key K, value V) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.m[key] = value
|
||||
}
|
||||
|
||||
func (s *SyncMap[K, V]) Items() map[K]V {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
// shallow copy map items
|
||||
return maps.Clone(s.m)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue