rerefactor

This commit is contained in:
Michael Yang 2024-02-14 11:29:49 -08:00 committed by jmorganca
parent 823a520266
commit e43648afe5
9 changed files with 224 additions and 251 deletions

View file

@ -18,7 +18,6 @@ import (
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
)
@ -50,7 +49,7 @@ const (
maxUploadPartSize int64 = 1000 * format.MegaByte
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
@ -122,7 +121,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *aut
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@ -213,7 +212,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
b.done = true
}
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@ -228,7 +227,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
@ -278,9 +277,8 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(authenticate)
token, err := auth.GetAuthToken(ctx, authRedir)
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
if err != nil {
return err
}
@ -365,7 +363,7 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)