cmd: add retry/backoff (#10069)

This commit adds retry/backoff to the registry client for pull requests.

Also, revert progress indication to match original client's until we can
"get it right."

Also, make WithTrace wrap existing traces instead of clobbering them.
This allows clients to compose traces.
This commit is contained in:
Blake Mizerany 2025-04-15 23:24:44 -07:00 committed by GitHub
parent ccb7eb8135
commit 1e7f62cb42
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 233 additions and 158 deletions

View file

@ -808,13 +808,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
if resp.Completed == 0 {
// This is the initial status update for the
// layer, which the server sends before
// beginning the download, for clients to
// compute total size and prepare for
// downloads, if needed.
//
// Skipping this here to avoid showing a 0%
// progress bar, which *should* clue the user
// into the fact that many things are being
// downloaded and that the current active
// download is not that last. However, in rare
// cases it seems to be triggering to some, and
// it isn't worth explaining, so just ignore
// and regress to the old UI that keeps giving
// you the "But wait, there is more!" after
// each "100% done" bar, which is "better."
return nil
}
if spinner != nil {
spinner.Stop()
}
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
name = strings.TrimSpace(name)
if isDigest {
name = name[:min(12, len(name))]
}
bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
bars[resp.Digest] = bar
p.Add(resp.Digest, bar)
}
@ -834,11 +859,7 @@ func PullHandler(cmd *cobra.Command, args []string) error {
}
request := api.PullRequest{Name: args[0], Insecure: insecure}
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
return err
}
return nil
return client.Pull(cmd.Context(), &request, fn)
}
type generateContextKey string

View file

@ -107,15 +107,20 @@ func DefaultCache() (*blob.DiskCache, error) {
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"` // TODO(bmizerany): remove this
status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
// Temporary reports if the error is temporary (e.g. 5xx status code).
func (e *Error) Temporary() bool {
return e.status >= 500
}
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("registry responded with status ")
b.WriteString(strconv.Itoa(e.Status))
b.WriteString(strconv.Itoa(e.status))
if e.Code != "" {
b.WriteString(": code ")
b.WriteString(e.Code)
@ -129,7 +134,7 @@ func (e *Error) Error() string {
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.Int("status", e.status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
@ -428,12 +433,12 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
type trackingReader struct {
l *Layer
r io.Reader
update func(l *Layer, n int64, err error)
update func(n int64)
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(r.l, int64(n), nil)
r.update(int64(n))
return
}
@ -478,111 +483,120 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
expected += l.Size
}
var received atomic.Int64
var completed atomic.Int64
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range layers {
var received atomic.Int64
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
received.Add(l.Size)
completed.Add(l.Size)
t.update(l, l.Size, ErrCached)
continue
}
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
func() {
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
t.update(l, received.Load(), err)
return
}
defer func() {
// Close the chunked writer when all chunks are
// downloaded.
//
// This is done as a background task in the
// group to allow the next layer to start while
// we wait for the final chunk in this layer to
// complete. It also ensures this is done
// before we exit Pull.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}()
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, received.Load(), err)
break
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
recv := received.Add(cs.Chunk.Size())
completed.Add(cs.Chunk.Size())
t.update(l, recv, ErrCached)
continue
}
received.Add(cs.Chunk.Size())
} else {
t.update(l, 0, err)
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
} else {
t.update(l, received.Load(), err)
}
wg.Done()
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
wg.Done()
}()
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
tr := &trackingReader{
l: l,
r: res.Body,
update: func(n int64) {
completed.Add(n)
recv := received.Add(n)
t.update(l, recv, nil)
},
}
return chunked.Put(cs.Chunk, cs.Digest, tr)
})
}
}()
}
if err := g.Wait(); err != nil {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
if recv := completed.Load(); recv != expected {
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, recv, expected)
}
md := blob.DigestFromBytes(m.Data)
@ -973,7 +987,7 @@ func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error)
return nil, ErrModelNotFound
}
re.Status = res.StatusCode
re.status = res.StatusCode
return nil, &re
}
return res, nil

View file

@ -154,7 +154,7 @@ func okHandler(w http.ResponseWriter, r *http.Request) {
func checkErrCode(t *testing.T, err error, status int, code string) {
t.Helper()
var e *Error
if !errors.As(err, &e) || e.Status != status || e.Code != code {
if !errors.As(err, &e) || e.status != status || e.Code != code {
t.Errorf("err = %v; want %v %v", err, status, code)
}
}
@ -860,8 +860,8 @@ func TestPullChunksumStreaming(t *testing.T) {
// now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 1 {
t.Fatalf("got %d, want 1", g)
if g := <-update; g != 3 {
t.Fatalf("got %d, want 3", g)
}
csw.Close()
testutil.Check(t, <-errc)
@ -944,10 +944,10 @@ func TestPullChunksumsCached(t *testing.T) {
_, err = c.Cache.Resolve("o.com/library/abc:latest")
check(err)
if g := written.Load(); g != 3 {
if g := written.Load(); g != 5 {
t.Fatalf("wrote %d bytes, want 3", g)
}
if g := cached.Load(); g != 2 { // "ab" should have been cached
t.Fatalf("cached %d bytes, want 3", g)
t.Fatalf("cached %d bytes, want 5", g)
}
}

View file

@ -34,10 +34,27 @@ func (t *Trace) update(l *Layer, n int64, err error) {
type traceKey struct{}
// WithTrace returns a context derived from ctx that uses t to report trace
// events.
// WithTrace adds a trace to the context for transfer progress reporting.
func WithTrace(ctx context.Context, t *Trace) context.Context {
return context.WithValue(ctx, traceKey{}, t)
old := traceFromContext(ctx)
if old == t {
// No change, return the original context. This also prevents
// infinite recursion below, if the caller passes the same
// Trace.
return ctx
}
// Create a new Trace that wraps the old one, if any. If we used the
// same pointer t, we end up with a recursive structure.
composed := &Trace{
Update: func(l *Layer, n int64, err error) {
if old != nil {
old.update(l, n, err)
}
t.update(l, n, err)
},
}
return context.WithValue(ctx, traceKey{}, composed)
}
var emptyTrace = &Trace{}

View file

@ -9,13 +9,14 @@ import (
"fmt"
"io"
"log/slog"
"maps"
"net/http"
"slices"
"sync"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/internal/backoff"
)
// Local implements an http.Handler for handling local Ollama API model
@ -265,68 +266,81 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
}
return err
}
return enc.Encode(progressUpdateJSON{Status: "success"})
enc.Encode(progressUpdateJSON{Status: "success"})
return nil
}
maybeFlush := func() {
var mu sync.Mutex
var progress []progressUpdateJSON
flushProgress := func() {
mu.Lock()
progress := slices.Clone(progress) // make a copy and release lock before encoding to the wire
mu.Unlock()
for _, p := range progress {
enc.Encode(p)
}
fl, _ := w.(http.Flusher)
if fl != nil {
fl.Flush()
}
}
defer maybeFlush()
var mu sync.Mutex
progress := make(map[*ollama.Layer]int64)
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() {
defer maybeFlush()
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progressCopy {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
Completed: n,
})
}
}
defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
t := time.NewTicker(1<<63 - 1) // "unstarted" timer
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
if err != nil && !errors.Is(err, ollama.ErrCached) {
s.Logger.Error("pulling", "model", p.model(), "error", err)
return
}
mu.Lock()
progress[l] += n
mu.Unlock()
func() {
mu.Lock()
defer mu.Unlock()
for i, p := range progress {
if p.Digest == l.Digest {
progress[i].Completed = n
return
}
}
progress = append(progress, progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
})
}()
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
},
})
done := make(chan error, 1)
go func() {
done <- s.Client.Pull(ctx, p.model())
go func() (err error) {
defer func() { done <- err }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
err := s.Client.Pull(ctx, p.model())
var oe *ollama.Error
if errors.As(err, &oe) && oe.Temporary() {
continue // retry
}
return err
}
return nil
}()
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
for {
select {
case <-t.C:
@ -341,7 +355,13 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
status = fmt.Sprintf("error: %v", err)
}
enc.Encode(progressUpdateJSON{Status: status})
return nil
}
// Emulate old client pull progress (for now):
enc.Encode(progressUpdateJSON{Status: "verifying sha256 digest"})
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
enc.Encode(progressUpdateJSON{Status: "success"})
return nil
}
}

View file

@ -78,7 +78,12 @@ func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
ctx := ollama.WithTrace(t.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
t.Logf("update: %s %d %v", l.Digest, n, err)
},
})
req := httptest.NewRequestWithContext(ctx, method, path, strings.NewReader(body))
return s.sendRequest(t, req)
}
@ -184,36 +189,34 @@ func TestServerPull(t *testing.T) {
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
t.Helper()
if got.Code != 200 {
t.Errorf("Code = %d; want 200", got.Code)
}
gotlines := got.Body.String()
if strings.TrimSpace(gotlines) == "" {
gotlines = "<empty>"
}
t.Logf("got:\n%s", gotlines)
for want := range strings.Lines(wantlines) {
want = strings.TrimSpace(want)
want, unwanted := strings.CutPrefix(want, "!")
want = strings.TrimSpace(want)
if !unwanted && !strings.Contains(gotlines, want) {
t.Errorf("! missing %q in body", want)
t.Errorf("\t! missing %q in body", want)
}
if unwanted && strings.Contains(gotlines, want) {
t.Errorf("! unexpected %q in body", want)
t.Errorf("\t! unexpected %q in body", want)
}
}
}
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
got := s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
checkResponse(got, `
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling manifest"}
{"digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying sha256 digest"}
{"status":"writing manifest"}
{"status":"success"}
`)
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)