Previously, using a Registry required a DiskCache to be passed in for use in various methods. This was a bit cumbersome, as the DiskCache is required for most operations, and the DefaultCache is used in most of those cases. This change makes the DiskCache an optional field on the Registry struct. This also changes DefaultCache to initialize on first use. This is to not burden clients with the cost of creating a new cache per use, or having to hold onto a cache for the lifetime of the Registry. Also, slip in some minor docs updates for Trace.
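As a quick illustration, here is a minimal sketch of the simplified call pattern (a hypothetical example, not part of this file; it uses only the ollama package calls that appear below, with error handling kept terse):

	package main

	import (
		"context"
		"log"

		"github.com/ollama/ollama/server/internal/client/ollama"
	)

	func main() {
		// No DiskCache is constructed or passed in here; DefaultRegistry is
		// expected to set up the default cache lazily, on first use.
		rc, err := ollama.DefaultRegistry()
		if err != nil {
			log.Fatal(err)
		}
		if err := rc.Pull(context.Background(), "library/llama3.2:latest"); err != nil {
			log.Fatal(err)
		}
	}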
package main

import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log"
	"mime"
	"net/http"
	"os"
	"runtime"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
	"golang.org/x/sync/errgroup"
)

var stdout io.Writer = os.Stdout

const usage = `Opp is a tool for pushing and pulling Ollama models.

Usage:

    opp [flags] <push|pull|import>

Commands:

    push     Upload a model to the Ollama server.
    pull     Download a model from the Ollama server.
    import   Import a model from a local safetensor directory.

Examples:

    # Pull a model from the Ollama server.
    opp pull library/llama3.2:latest

    # Push a model to the Ollama server.
    opp push username/my_model:8b

    # Import a model from a local safetensor directory.
    opp import /path/to/safetensor

Environment Variables:

    OLLAMA_MODELS
        The directory where models are pushed and pulled from
        (default ~/.ollama/models).
`

func main() {
	flag.Usage = func() {
		fmt.Fprint(os.Stderr, usage)
	}
	flag.Parse()

	ctx := context.Background()

	err := func() error {
		switch cmd := flag.Arg(0); cmd {
		case "pull":
			rc, err := ollama.DefaultRegistry()
			if err != nil {
				log.Fatal(err)
			}

			return cmdPull(ctx, rc)
		case "push":
			rc, err := ollama.DefaultRegistry()
			if err != nil {
				log.Fatal(err)
			}
			return cmdPush(ctx, rc)
		case "import":
			c, err := ollama.DefaultCache()
			if err != nil {
				log.Fatal(err)
			}
			return cmdImport(ctx, c)
		default:
			if cmd == "" {
				flag.Usage()
			} else {
				fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
			}
			os.Exit(2)
			return errors.New("unreachable")
		}
	}()
	if err != nil {
		fmt.Fprintf(os.Stderr, "opp: %v\n", err)
		os.Exit(1)
	}
}

// cmdPull downloads a model from the registry, printing per-layer progress
// roughly once per second until the pull completes.
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
	model := flag.Arg(1)
	if model == "" {
		flag.Usage()
		os.Exit(1)
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()
	// TODO(bmizerany): configure transport?
	rc.HTTPClient = &http.Client{Transport: tr}

	var mu sync.Mutex
	p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]

	var pb bytes.Buffer
	printProgress := func() {
		pb.Reset()
		mu.Lock()
		for d, s := range p {
			// Write progress to a buffer first to avoid blocking
			// on stdout while holding the lock.
			stamp := time.Now().Format("2006/01/02 15:04:05")
			fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
			if s[0] == s[1] {
				delete(p, d)
			}
		}
		mu.Unlock()
		io.Copy(stdout, &pb)
	}

	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
				return
			}

			mu.Lock()
			p[l.Digest] = [2]int64{l.Size, n}
			mu.Unlock()
		},
	})

	errc := make(chan error)
	go func() {
		errc <- rc.Pull(ctx, model)
	}()

	t := time.NewTicker(time.Second)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			printProgress()
		case err := <-errc:
			printProgress()
			return err
		}
	}
}

// cmdPush uploads a model to the registry. With -from, the manifest of
// another local model is pushed under the target name.
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
	args := flag.Args()[1:]
	flag := flag.NewFlagSet("push", flag.ExitOnError)
	flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)

	model := flag.Arg(0)
	if model == "" {
		return fmt.Errorf("missing model argument")
	}

	from := cmp.Or(*flagFrom, model)
	m, err := rc.ResolveLocal(from)
	if err != nil {
		return err
	}

	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				fmt.Fprintf(stdout, "opp: uploading %s %d (existed)\n", l.Digest.Short(), n)
			case err != nil:
				fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
			case n == 0:
				l := m.Layer(l.Digest)
				mt, p, _ := mime.ParseMediaType(l.MediaType)
				mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
				switch mt {
				case "tensor":
					fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
				default:
					fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
				}
			}
		},
	})

	return rc.Push(ctx, model, &ollama.PushParams{
		From: from,
	})
}

// trackingReader wraps an io.Reader and adds every byte read to n.
type trackingReader struct {
	io.Reader
	n *atomic.Int64
}

func (r *trackingReader) Read(p []byte) (n int, err error) {
	n, err = r.Reader.Read(p)
	r.n.Add(int64(n))
	return n, err
}

// cmdImport imports tensors from a local safetensors directory into the
// cache and links the resulting manifest under the name given by -as.
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
	args := flag.Args()[1:]
	flag := flag.NewFlagSet("import", flag.ExitOnError)
	flagAs := flag.String("as", "", "Import using the provided name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)
	if *flagAs == "" {
		return fmt.Errorf("missing -as flag")
	}
	as := ollama.CompleteName(*flagAs)

	dir := cmp.Or(flag.Arg(0), ".")
	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)

	m, err := safetensors.Read(os.DirFS(dir))
	if err != nil {
		return err
	}

	var total int64
	var tt []*safetensors.Tensor
	for t, err := range m.Tensors() {
		if err != nil {
			return err
		}
		tt = append(tt, t)
		total += t.Size()
	}

	var n atomic.Int64
	done := make(chan error)
	go func() {
		layers := make([]*ollama.Layer, len(tt))
		var g errgroup.Group
		g.SetLimit(runtime.GOMAXPROCS(0))
		var ctxErr error
		for i, t := range tt {
			if ctx.Err() != nil {
				// The context may cancel AFTER we exit the
				// loop, and so if we use ctx.Err() after the
				// loop we may report it as the error that
				// broke the loop, when it was not. This can
				// manifest as a false-negative, leading the
				// user to think their import failed when it
				// did not, so capture it if and only if we
				// exit the loop because of a ctx.Err() and
				// report it.
				ctxErr = ctx.Err()
				break
			}
			g.Go(func() (err error) {
				rc, err := t.Reader()
				if err != nil {
					return err
				}
				defer rc.Close()
				tr := &trackingReader{rc, &n}
				d, err := c.Import(tr, t.Size())
				if err != nil {
					return err
				}
				if err := rc.Close(); err != nil {
					return err
				}

				layers[i] = &ollama.Layer{
					Digest: d,
					Size:   t.Size(),
					MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
						"name":  t.Name(),
						"dtype": t.DataType(),
						"shape": t.Shape().String(),
					}),
				}

				return nil
			})
		}

		done <- func() error {
			if err := errors.Join(g.Wait(), ctxErr); err != nil {
				return err
			}
			m := &ollama.Manifest{Layers: layers}
			data, err := json.MarshalIndent(m, "", " ")
			if err != nil {
				return err
			}
			d := blob.DigestFromBytes(data)
			err = blob.PutBytes(c, d, data)
			if err != nil {
				return err
			}
			return c.Link(as, d)
		}()
	}()

	fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)

	csiHideCursor(stdout)
	defer csiShowCursor(stdout)

	csiSavePos(stdout)
	writeProgress := func() {
		csiRestorePos(stdout)
		nn := n.Load()
		fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
			formatNatural(nn),
			formatNatural(total),
			nn*100/total,
			ansiClearToEndOfLine,
		)
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			writeProgress()
		case err := <-done:
			writeProgress()
			fmt.Println()
			fmt.Println("Successfully imported", as)
			return err
		}
	}
}

// formatNatural renders n as a human-readable byte count (B, KB, MB, GB).
func formatNatural(n int64) string {
	switch {
	case n < 1024:
		return fmt.Sprintf("%d B", n)
	case n < 1024*1024:
		return fmt.Sprintf("%.1f KB", float64(n)/1024)
	case n < 1024*1024*1024:
		return fmt.Sprintf("%.1f MB", float64(n)/(1024*1024))
	default:
		return fmt.Sprintf("%.1f GB", float64(n)/(1024*1024*1024))
	}
}

const ansiClearToEndOfLine = "\033[K"

// Minimal ANSI/CSI cursor helpers used to redraw the import progress line in place.
func csiSavePos(w io.Writer)    { fmt.Fprint(w, "\033[s") }
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }