mirror of
https://github.com/ollama/ollama.git
synced 2025-05-10 18:06:33 +02:00
server: remove internal cmd (#10595)
This commit is contained in:
parent
424810450f
commit
92ce438de0
2 changed files with 0 additions and 599 deletions
|
@ -1,224 +0,0 @@
|
|||
// Package safetensors provides a reader for the safetensor directories and files.
package safetensors
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Tensor represents a single tensor in a safetensors file.
//
// Its zero value is not valid. Use [Model.Tensors] to get valid tensors.
//
// It is not safe for use across multiple goroutines.
type Tensor struct {
	name     string  // tensor name as recorded in the file header
	dataType string  // safetensors dtype string from the header (e.g. "F32")
	shape    []int64 // dimension sizes, in header order

	fsys   fs.FS
	fname  string // entry name in fsys
	offset int64  // byte offset of the tensor data within fname
	size   int64  // length in bytes of the tensor data
}
|
||||
|
||||
// Model is a collection of tensors backed by the *.safetensors files
// found in a filesystem. Use [Read] to construct one.
type Model struct {
	fsys fs.FS // filesystem the .safetensors files are read from
}
|
||||
|
||||
func Read(fsys fs.FS) (*Model, error) {
|
||||
return &Model{fsys: fsys}, nil
|
||||
}
|
||||
|
||||
func (m *Model) Tensors() iter.Seq2[*Tensor, error] {
|
||||
return func(yield func(*Tensor, error) bool) {
|
||||
entries, err := fs.Glob(m.fsys, "*.safetensors")
|
||||
if err != nil {
|
||||
yield(nil, err)
|
||||
return
|
||||
}
|
||||
for _, e := range entries {
|
||||
tt, err := m.readTensors(e)
|
||||
if err != nil {
|
||||
yield(nil, err)
|
||||
return
|
||||
}
|
||||
for _, t := range tt {
|
||||
if !yield(t, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
f, err := m.fsys.Open(fname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
finfo, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headerSize, err := readInt64(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := make([]byte, headerSize)
|
||||
_, err = io.ReadFull(f, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var raws map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raws); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
|
||||
|
||||
// TODO(bmizerany): do something with metadata? This could be another
|
||||
// header read if needed. We also need to figure out if the metadata is
|
||||
// present in only one .safetensors file or if each file may have their
|
||||
// own and if it needs to follow each tensor. Currently, I (bmizerany)
|
||||
// am only seeing them show up with one entry for file type which is
|
||||
// always "pt".
|
||||
|
||||
tt := make([]*Tensor, 0, len(raws))
|
||||
for name, raw := range raws {
|
||||
if name == "__metadata__" {
|
||||
// TODO(bmizerany): do something with metadata?
|
||||
continue
|
||||
}
|
||||
var v struct {
|
||||
DataType string `json:"dtype"`
|
||||
Shape []int64 `json:"shape"`
|
||||
Offsets []int64 `json:"data_offsets"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &v); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshalling layer %q: %w", name, err)
|
||||
}
|
||||
if len(v.Offsets) != 2 {
|
||||
return nil, fmt.Errorf("invalid offsets for %q: %v", name, v.Offsets)
|
||||
}
|
||||
|
||||
// TODO(bmizerany): after collecting, validate all offests make
|
||||
// tensors contiguous?
|
||||
begin := endOfHeader + v.Offsets[0]
|
||||
end := endOfHeader + v.Offsets[1]
|
||||
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): just yield.. don't be silly and make a slice :)
|
||||
tt = append(tt, &Tensor{
|
||||
name: name,
|
||||
dataType: v.DataType,
|
||||
shape: v.Shape,
|
||||
fsys: m.fsys,
|
||||
fname: fname,
|
||||
offset: begin,
|
||||
size: end - begin,
|
||||
})
|
||||
}
|
||||
return tt, nil
|
||||
}
|
||||
|
||||
// checkBeginEnd validates that begin and end describe a well-formed
// byte range that lies entirely within a file of the given size.
func checkBeginEnd(size, begin, end int64) error {
	switch {
	case begin < 0:
		return fmt.Errorf("begin must not be negative: %d", begin)
	case end < 0:
		return fmt.Errorf("end must not be negative: %d", end)
	case end < begin:
		return fmt.Errorf("end must be >= begin: %d < %d", end, begin)
	case end > size:
		return fmt.Errorf("end must be <= size: %d > %d", end, size)
	}
	return nil
}
|
||||
|
||||
func readInt64(r io.Reader) (int64, error) {
|
||||
var v uint64
|
||||
var buf [8]byte
|
||||
if _, err := io.ReadFull(r, buf[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for i := range buf {
|
||||
v |= uint64(buf[i]) << (8 * i)
|
||||
}
|
||||
return int64(v), nil
|
||||
}
|
||||
|
||||
// Shape is a tensor shape: the size of each dimension, in order.
type Shape []int64

// String renders the shape as a bracketed, comma-separated list,
// e.g. "[2,3,4]".
func (s Shape) String() string {
	var sb strings.Builder
	sb.WriteByte('[')
	for i, dim := range s {
		if i != 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(strconv.FormatInt(dim, 10))
	}
	sb.WriteByte(']')
	return sb.String()
}
|
||||
|
||||
// Name returns the tensor's name as recorded in the file header.
func (t *Tensor) Name() string { return t.name }

// DataType returns the safetensors dtype string for the tensor.
func (t *Tensor) DataType() string { return t.dataType }

// Size returns the size of the tensor data in bytes.
func (t *Tensor) Size() int64 { return t.size }

// Shape returns a copy of the tensor's dimension sizes.
func (t *Tensor) Shape() Shape { return slices.Clone(t.shape) }
|
||||
|
||||
func (t *Tensor) Reader() (io.ReadCloser, error) {
|
||||
f, err := t.fsys.Open(t.fname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := newSectionReader(f, t.offset, t.size)
|
||||
rc := struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}{r, f}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
// newSectionReader returns a new io.Reader that reads from r starting at
|
||||
// offset. It is a convenience function for creating a io.SectionReader when r
|
||||
// may not be an io.ReaderAt.
|
||||
//
|
||||
// If r is already a ReaderAt, it is returned directly, otherwise if r is an
|
||||
// io.Seeker, a new io.ReaderAt is returned that wraps r after seeking to the
|
||||
// beginning of the file.
|
||||
//
|
||||
// If r is an io.Seeker,
|
||||
// or slow path. The slow path is used when r does not implement io.ReaderAt,
|
||||
// in which case it must discard the data it reads.
|
||||
func newSectionReader(r io.Reader, offset, n int64) io.Reader {
|
||||
if r, ok := r.(io.ReaderAt); ok {
|
||||
return io.NewSectionReader(r, offset, n)
|
||||
}
|
||||
if r, ok := r.(io.ReadSeeker); ok {
|
||||
r.Seek(offset, io.SeekStart)
|
||||
return io.LimitReader(r, n)
|
||||
}
|
||||
// Discard to offset and return a limited reader.
|
||||
_, err := io.CopyN(io.Discard, r, offset)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return io.LimitReader(r, n)
|
||||
}
|
|
@ -1,375 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/cmd/opp/internal/safetensors"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
|
||||
var stdout io.Writer = os.Stdout
|
||||
|
||||
const usage = `Opp is a tool for pushing and pulling Ollama models.
|
||||
|
||||
Usage:
|
||||
|
||||
opp [flags] <push|pull|import>
|
||||
|
||||
Commands:
|
||||
|
||||
push Upload a model to the Ollama server.
|
||||
pull Download a model from the Ollama server.
|
||||
import Import a model from a local safetensor directory.
|
||||
|
||||
Examples:
|
||||
|
||||
# Pull a model from the Ollama server.
|
||||
opp pull library/llama3.2:latest
|
||||
|
||||
# Push a model to the Ollama server.
|
||||
opp push username/my_model:8b
|
||||
|
||||
# Import a model from a local safetensor directory.
|
||||
opp import /path/to/safetensor
|
||||
|
||||
Envionment Variables:
|
||||
|
||||
OLLAMA_MODELS
|
||||
The directory where models are pushed and pulled from
|
||||
(default ~/.ollama/models).
|
||||
`
|
||||
|
||||
func main() {
|
||||
flag.Usage = func() {
|
||||
fmt.Fprint(os.Stderr, usage)
|
||||
}
|
||||
flag.Parse()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := func() error {
|
||||
switch cmd := flag.Arg(0); cmd {
|
||||
case "pull":
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return cmdPull(ctx, rc)
|
||||
case "push":
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdPush(ctx, rc)
|
||||
case "import":
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return cmdImport(ctx, c)
|
||||
default:
|
||||
if cmd == "" {
|
||||
flag.Usage()
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "unknown command %q\n", cmd)
|
||||
}
|
||||
os.Exit(2)
|
||||
return errors.New("unreachable")
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "opp: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// cmdPull downloads the model named by the second CLI argument from the
// registry, printing per-layer progress to stdout once per second until
// the pull completes.
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
	model := flag.Arg(1)
	if model == "" {
		flag.Usage()
		os.Exit(1)
	}

	tr := http.DefaultTransport.(*http.Transport).Clone()
	// TODO(bmizerany): configure transport?
	rc.HTTPClient = &http.Client{Transport: tr}

	// p tracks, per layer digest, [total bytes, bytes downloaded so far].
	// mu guards p against concurrent updates from the trace callback.
	var mu sync.Mutex
	p := make(map[blob.Digest][2]int64) // digest -> [total, downloaded]

	var pb bytes.Buffer
	printProgress := func() {
		pb.Reset()
		mu.Lock()
		for d, s := range p {
			// Write progress to a buffer first to avoid blocking
			// on stdout while holding the lock.
			stamp := time.Now().Format("2006/01/02 15:04:05")
			fmt.Fprintf(&pb, "%s %s pulling %d/%d (%.1f%%)\n", stamp, d.Short(), s[1], s[0], 100*float64(s[1])/float64(s[0]))
			if s[0] == s[1] {
				// Layer fully downloaded; stop reporting it.
				delete(p, d)
			}
		}
		mu.Unlock()
		io.Copy(stdout, &pb)
	}

	// The registry reports per-layer progress through this trace hook.
	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil && !errors.Is(err, ollama.ErrCached) {
				fmt.Fprintf(stdout, "opp: pull %s ! %v\n", l.Digest.Short(), err)
				return
			}

			mu.Lock()
			p[l.Digest] = [2]int64{l.Size, n}
			mu.Unlock()
		},
	})

	// Run the pull in the background; the loop below polls progress
	// once a second until the pull finishes.
	errc := make(chan error)
	go func() {
		errc <- rc.Pull(ctx, model)
	}()

	t := time.NewTicker(time.Second)
	defer t.Stop()
	for {
		select {
		case <-t.C:
			printProgress()
		case err := <-errc:
			printProgress() // one final update before returning
			return err
		}
	}
}
|
||||
|
||||
// cmdPush uploads a model to the registry. The model name is the first
// argument after the "push" subcommand; the -from flag lets the push
// reuse the manifest of a model stored locally under another name.
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
	args := flag.Args()[1:]
	// Deliberately shadows the package-level flag set with a
	// subcommand-local one for parsing push's own flags.
	flag := flag.NewFlagSet("push", flag.ExitOnError)
	flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp push <model>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)

	model := flag.Arg(0)
	if model == "" {
		return fmt.Errorf("missing model argument")
	}

	// Resolve the local manifest we are pushing from (-from overrides
	// the destination name).
	from := cmp.Or(*flagFrom, model)
	m, err := rc.ResolveLocal(from)
	if err != nil {
		return err
	}

	// Report upload progress through the registry's trace hook.
	ctx = ollama.WithTrace(ctx, &ollama.Trace{
		Update: func(l *ollama.Layer, n int64, err error) {
			switch {
			case errors.Is(err, ollama.ErrCached):
				fmt.Fprintf(stdout, "opp: uploading %s %d (existed)", l.Digest.Short(), n)
			case err != nil:
				fmt.Fprintf(stdout, "opp: uploading %s %d ! %v\n", l.Digest.Short(), n, err)
			case n == 0:
				// First event for this layer: announce what is
				// being uploaded, using the tensor name when the
				// media type carries one.
				l := m.Layer(l.Digest)
				mt, p, _ := mime.ParseMediaType(l.MediaType)
				mt, _ = strings.CutPrefix(mt, "application/vnd.ollama.image.")
				switch mt {
				case "tensor":
					fmt.Fprintf(stdout, "opp: uploading tensor %s %s\n", l.Digest.Short(), p["name"])
				default:
					fmt.Fprintf(stdout, "opp: uploading %s %s\n", l.Digest.Short(), l.MediaType)
				}
			}
		},
	})

	return rc.Push(ctx, model, &ollama.PushParams{
		From: from,
	})
}
|
||||
|
||||
type trackingReader struct {
|
||||
io.Reader
|
||||
n *atomic.Int64
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.Reader.Read(p)
|
||||
r.n.Add(int64(n))
|
||||
return n, err
|
||||
}
|
||||
|
||||
// cmdImport imports a local safetensors directory (flag.Arg(0), default
// ".") into the disk cache c and links the resulting manifest under the
// name given by the required -as flag. Tensors are imported concurrently
// (bounded by GOMAXPROCS) while progress is redrawn on one terminal line
// once per second.
func cmdImport(ctx context.Context, c *blob.DiskCache) error {
	args := flag.Args()[1:]
	// Deliberately shadows the package-level flag set with a
	// subcommand-local one for parsing import's own flags.
	flag := flag.NewFlagSet("import", flag.ExitOnError)
	flagAs := flag.String("as", "", "Import using the provided name.")
	flag.Usage = func() {
		fmt.Fprintf(os.Stderr, "Usage: opp import <SafetensorDir>\n")
		flag.PrintDefaults()
	}
	flag.Parse(args)
	if *flagAs == "" {
		return fmt.Errorf("missing -as flag")
	}
	as := ollama.CompleteName(*flagAs)

	dir := cmp.Or(flag.Arg(0), ".")
	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)

	m, err := safetensors.Read(os.DirFS(dir))
	if err != nil {
		return err
	}

	// Collect all tensors up front so the total byte count is known for
	// the progress display.
	var total int64
	var tt []*safetensors.Tensor
	for t, err := range m.Tensors() {
		if err != nil {
			return err
		}
		tt = append(tt, t)
		total += t.Size()
	}
	// NOTE(review): if the directory contains no tensors, total stays 0
	// and writeProgress below divides by zero — confirm empty imports
	// are rejected (or guard the division).

	var n atomic.Int64 // bytes imported so far, across all workers
	done := make(chan error)
	go func() {
		layers := make([]*ollama.Layer, len(tt))
		var g errgroup.Group
		g.SetLimit(runtime.GOMAXPROCS(0))
		var ctxErr error
		for i, t := range tt {
			if ctx.Err() != nil {
				// The context may cancel AFTER we exit the
				// loop, and so if we use ctx.Err() after the
				// loop we may report it as the error that
				// broke the loop, when it was not. This can
				// manifest as a false-negative, leading the
				// user to think their import failed when it
				// did not, so capture it if and only if we
				// exit the loop because of a ctx.Err() and
				// report it.
				ctxErr = ctx.Err()
				break
			}
			g.Go(func() (err error) {
				rc, err := t.Reader()
				if err != nil {
					return err
				}
				defer rc.Close()
				// Count bytes as they stream into the cache.
				tr := &trackingReader{rc, &n}
				d, err := c.Import(tr, t.Size())
				if err != nil {
					return err
				}
				// Explicit Close to surface close errors; the
				// deferred Close above is then a harmless no-op
				// fallback for the error paths.
				if err := rc.Close(); err != nil {
					return err
				}

				layers[i] = &ollama.Layer{
					Digest: d,
					Size:   t.Size(),
					MediaType: mime.FormatMediaType("application/vnd.ollama.image.tensor", map[string]string{
						"name":  t.Name(),
						"dtype": t.DataType(),
						"shape": t.Shape().String(),
					}),
				}

				return nil
			})
		}

		// Wait for all workers, then write and link the manifest;
		// send the final outcome to the progress loop below.
		done <- func() error {
			if err := errors.Join(g.Wait(), ctxErr); err != nil {
				return err
			}
			m := &ollama.Manifest{Layers: layers}
			data, err := json.MarshalIndent(m, "", " ")
			if err != nil {
				return err
			}
			d := blob.DigestFromBytes(data)
			err = blob.PutBytes(c, d, data)
			if err != nil {
				return err
			}
			return c.Link(as, d)
		}()
	}()

	fmt.Fprintf(stdout, "Importing %d tensors from %s\n", len(tt), dir)

	// Redraw a single progress line in place using ANSI cursor
	// save/restore; hide the cursor while drawing.
	csiHideCursor(stdout)
	defer csiShowCursor(stdout)

	csiSavePos(stdout)
	writeProgress := func() {
		csiRestorePos(stdout)
		nn := n.Load()
		fmt.Fprintf(stdout, "Imported %s/%s bytes (%d%%)%s\n",
			formatNatural(nn),
			formatNatural(total),
			nn*100/total,
			ansiClearToEndOfLine,
		)
	}

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			writeProgress()
		case err := <-done:
			// NOTE(review): the success message below prints even
			// when err is non-nil — confirm whether that is
			// intentional.
			writeProgress()
			fmt.Println()
			fmt.Println("Successfully imported", as)
			return err
		}
	}
}
|
||||
|
||||
// formatNatural renders a byte count in human-readable form (B, KB, MB,
// or GB) using powers of 1024, with one decimal place above bytes.
func formatNatural(n int64) string {
	const (
		kb = 1 << 10
		mb = 1 << 20
		gb = 1 << 30
	)
	f := float64(n)
	switch {
	case n < kb:
		return fmt.Sprintf("%d B", n)
	case n < mb:
		return fmt.Sprintf("%.1f KB", f/kb)
	case n < gb:
		return fmt.Sprintf("%.1f MB", f/mb)
	default:
		return fmt.Sprintf("%.1f GB", f/gb)
	}
}
|
||||
|
||||
// ansiClearToEndOfLine erases from the cursor to the end of the line.
const ansiClearToEndOfLine = "\033[K"

// csiSavePos saves the current cursor position (ANSI CSI s).
func csiSavePos(w io.Writer) { fmt.Fprint(w, "\033[s") }

// csiRestorePos restores the cursor position saved by csiSavePos (CSI u).
func csiRestorePos(w io.Writer) { fmt.Fprint(w, "\033[u") }

// csiHideCursor hides the terminal cursor (CSI ?25l).
func csiHideCursor(w io.Writer) { fmt.Fprint(w, "\033[?25l") }

// csiShowCursor makes the terminal cursor visible again (CSI ?25h).
func csiShowCursor(w io.Writer) { fmt.Fprint(w, "\033[?25h") }
|
Loading…
Add table
Add a link
Reference in a new issue