diff --git a/lfs/client.go b/lfs/client.go index ceb9f1d1..b686b883 100644 --- a/lfs/client.go +++ b/lfs/client.go @@ -56,18 +56,18 @@ type objectResource struct { Error *objectError `json:"error,omitempty"` } -func (o *objectResource) NewRequest(relation, method string) (*http.Request, Creds, error) { +func (o *objectResource) NewRequest(relation, method string) (*http.Request, error) { rel, ok := o.Rel(relation) if !ok { - return nil, nil, objectRelationDoesNotExist + return nil, objectRelationDoesNotExist } - req, creds, err := newClientRequest(method, rel.Href, rel.Header) + req, err := newClientRequest(method, rel.Href, rel.Header) if err != nil { - return nil, nil, err + return nil, err } - return req, creds, nil + return req, nil } func (o *objectResource) Rel(name string) (*linkRelation, bool) { @@ -106,23 +106,22 @@ func (e *ClientError) Error() string { } func Download(oid string) (io.ReadCloser, int64, *WrappedError) { - req, creds, err := newApiRequest("GET", oid) + req, err := newApiRequest("GET", oid) if err != nil { return nil, 0, Error(err) } - res, obj, wErr := doApiRequest(req, creds) + res, obj, wErr := doApiRequest(req) if wErr != nil { return nil, 0, wErr } LogTransfer("lfs.api.download", res) - - req, creds, err = obj.NewRequest("download", "GET") + req, err = obj.NewRequest("download", "GET") if err != nil { return nil, 0, Error(err) } - res, wErr = doHttpRequest(req, creds) + res, wErr = doHttpRequest(req) if wErr != nil { return nil, 0, wErr } @@ -136,18 +135,18 @@ type byteCloser struct { } func DownloadCheck(oid string) (*objectResource, *WrappedError) { - req, creds, err := newApiRequest("GET", oid) + req, err := newApiRequest("GET", oid) if err != nil { return nil, Error(err) } - res, obj, wErr := doApiRequest(req, creds) + res, obj, wErr := doApiRequest(req) if wErr != nil { return nil, wErr } LogTransfer("lfs.api.download", res) - _, _, err = obj.NewRequest("download", "GET") + _, err = obj.NewRequest("download", "GET") if err != nil { return nil, Error(err) } @@ -156,12 +155,12 @@ func DownloadCheck(oid string) (*objectResource, *WrappedError) { } func DownloadObject(obj *objectResource) (io.ReadCloser, int64, *WrappedError) { - req, creds, err := obj.NewRequest("download", "GET") + req, err := obj.NewRequest("download", "GET") if err != nil { return nil, 0, Error(err) } - res, wErr := doHttpRequest(req, creds) + res, wErr := doHttpRequest(req) if wErr != nil { return nil, 0, wErr } @@ -186,7 +185,7 @@ func Batch(objects []*objectResource, operation string) ([]*objectResource, *Wra return nil, Error(err) } - req, creds, err := newBatchApiRequest() + req, err := newBatchApiRequest() if err != nil { return nil, Error(err) } @@ -197,7 +196,7 @@ func Batch(objects []*objectResource, operation string) ([]*objectResource, *Wra req.Body = &byteCloser{bytes.NewReader(by)} tracerx.Printf("api: batch %d files", len(objects)) - res, objs, wErr := doApiBatchRequest(req, creds) + res, objs, wErr := doApiBatchRequest(req) if wErr != nil { if res == nil { return nil, wErr @@ -242,7 +241,7 @@ func UploadCheck(oidPath string) (*objectResource, *WrappedError) { return nil, Error(err) } - req, creds, err := newApiRequest("POST", oid) + req, err := newApiRequest("POST", oid) if err != nil { return nil, Error(err) } @@ -253,7 +252,7 @@ func UploadCheck(oidPath string) (*objectResource, *WrappedError) { req.Body = &byteCloser{bytes.NewReader(by)} tracerx.Printf("api: uploading (%s)", oid) - res, obj, wErr := doApiRequest(req, creds) + res, obj, wErr := doApiRequest(req) if wErr != nil { return nil, wErr } @@ -291,7 +290,7 @@ func UploadObject(o *objectResource, cb CopyCallback) *WrappedError { Reader: file, } - req, creds, err := o.NewRequest("upload", "PUT") + req, err := o.NewRequest("upload", "PUT") if err != nil { return Error(err) } @@ -308,7 +307,7 @@ func UploadObject(o *objectResource, cb CopyCallback) *WrappedError { req.Body = ioutil.NopCloser(reader) - res, wErr := doHttpRequest(req, creds) + res, wErr := doHttpRequest(req) if wErr != nil { return wErr } @@ -321,7 +320,7 @@ func UploadObject(o *objectResource, cb CopyCallback) *WrappedError { io.Copy(ioutil.Discard, res.Body) res.Body.Close() - req, creds, err = o.NewRequest("verify", "POST") + req, err = o.NewRequest("verify", "POST") if err == objectRelationDoesNotExist { return nil } else if err != nil { @@ -337,7 +336,7 @@ func UploadObject(o *objectResource, cb CopyCallback) *WrappedError { req.Header.Set("Content-Length", strconv.Itoa(len(by))) req.ContentLength = int64(len(by)) req.Body = ioutil.NopCloser(bytes.NewReader(by)) - res, wErr = doHttpRequest(req, creds) + res, wErr = doHttpRequest(req) if wErr != nil { return wErr } @@ -349,7 +348,7 @@ func UploadObject(o *objectResource, cb CopyCallback) *WrappedError { return wErr } -func doHttpRequest(req *http.Request, creds Creds) (*http.Response, *WrappedError) { +func doHttpRequest(req *http.Request) (*http.Response, *WrappedError) { res, err := Config.HttpClient().Do(req) if res == nil { res = &http.Response{ @@ -365,7 +364,6 @@ func doHttpRequest(req *http.Request, creds Creds) (*http.Response, *WrappedErro if err != nil { wErr = Errorf(err, "Error for %s %s", res.Request.Method, res.Request.URL) } else { - saveCredentials(creds, res) wErr = handleResponse(res) } @@ -380,10 +378,21 @@ func doHttpRequest(req *http.Request, creds Creds) (*http.Response, *WrappedErro return res, wErr } -func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Request) (*http.Response, *WrappedError) { - res, wErr := doHttpRequest(req, creds) +func doApiRequestWithRedirects(req *http.Request, via []*http.Request, useCreds bool) (*http.Response, *WrappedError) { + var creds Creds + if useCreds { + c, err := getCreds(req) + if err != nil { + return nil, Error(err) + } + creds = c + } + + res, wErr := doHttpRequest(req) if wErr != nil { return res, wErr + } else { + saveCredentials(creds, res) } if res.StatusCode == 307 { @@ -394,7 +403,7 @@ func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Reque redirectTo = locurl.String() } - redirectedReq, redirectedCreds, err := newClientRequest(req.Method, redirectTo, nil) + redirectedReq, err := newClientRequest(req.Method, redirectTo, nil) if err != nil { return res, Errorf(err, err.Error()) } @@ -422,15 +431,15 @@ func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Reque return res, Errorf(err, err.Error()) } - return doApiRequestWithRedirects(redirectedReq, redirectedCreds, via) + return doApiRequestWithRedirects(redirectedReq, via, useCreds) } return res, nil } -func doApiRequest(req *http.Request, creds Creds) (*http.Response, *objectResource, *WrappedError) { +func doApiRequest(req *http.Request) (*http.Response, *objectResource, *WrappedError) { via := make([]*http.Request, 0, 4) - res, wErr := doApiRequestWithRedirects(req, creds, via) + res, wErr := doApiRequestWithRedirects(req, via, true) if wErr != nil { return res, nil, wErr } @@ -450,9 +459,9 @@ func doApiRequest(req *http.Request, creds Creds) (*http.Response, *objectResour // the repo will be marked as having private access and the request will be // re-run. When the repo is marked as having private access, credentials will // be retrieved. -func doApiBatchRequest(req *http.Request, creds Creds) (*http.Response, []*objectResource, *WrappedError) { +func doApiBatchRequest(req *http.Request) (*http.Response, []*objectResource, *WrappedError) { via := make([]*http.Request, 0, 4) - res, wErr := doApiRequestWithRedirects(req, creds, via) + res, wErr := doApiRequestWithRedirects(req, via, Config.PrivateAccess()) if wErr != nil { return res, nil, wErr @@ -535,7 +544,7 @@ func saveCredentials(creds Creds, res *http.Response) { } } -func newApiRequest(method, oid string) (*http.Request, Creds, error) { +func newApiRequest(method, oid string) (*http.Request, error) { endpoint := Config.Endpoint() objectOid := oid operation := "download" @@ -559,22 +568,22 @@ func newApiRequest(method, oid string) (*http.Request, Creds, error) { u, err := ObjectUrl(endpoint, objectOid) if err != nil { - return nil, nil, err + return nil, err } - req, creds, err := newClientRequest(method, u.String(), res.Header) + req, err := newClientRequest(method, u.String(), res.Header) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("Accept", mediaType) - return req, creds, nil + return req, nil } -func newClientRequest(method, rawurl string, header map[string]string) (*http.Request, Creds, error) { +func newClientRequest(method, rawurl string, header map[string]string) (*http.Request, error) { req, err := http.NewRequest(method, rawurl, nil) if err != nil { - return nil, nil, err + return nil, err } for key, value := range header { @@ -582,15 +591,11 @@ func newClientRequest(method, rawurl string, header map[string]string) (*http.Re } req.Header.Set("User-Agent", UserAgent) - creds, err := getCreds(req) - if err != nil { - return nil, nil, err - } - return req, creds, nil + return req, nil } -func newBatchApiRequest() (*http.Request, Creds, error) { +func newBatchApiRequest() (*http.Request, error) { endpoint := Config.Endpoint() res, err := sshAuthenticate(endpoint, "download", "") @@ -606,12 +611,12 @@ func newBatchApiRequest() (*http.Request, Creds, error) { u, err := ObjectUrl(endpoint, "batch") if err != nil { - return nil, nil, err + return nil, err } - req, creds, err := newBatchClientRequest("POST", u.String()) + req, err := newBatchClientRequest("POST", u.String()) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("Accept", mediaType) @@ -621,30 +626,18 @@ func newBatchApiRequest() (*http.Request, Creds, error) { } } - return req, creds, nil + return req, nil } -func newBatchClientRequest(method, rawurl string) (*http.Request, Creds, error) { +func newBatchClientRequest(method, rawurl string) (*http.Request, error) { req, err := http.NewRequest(method, rawurl, nil) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("User-Agent", UserAgent) - // Get the creds if we're private - if Config.PrivateAccess() { - // The PrivateAccess() check can be pushed down and this block simplified - // once everything goes through the batch endpoint. - creds, err := getCreds(req) - if err != nil { - return nil, nil, err - } - - return req, creds, nil - } - - return req, nil, nil + return req, nil } func getCreds(req *http.Request) (Creds, error) { diff --git a/lfs/download_test.go b/lfs/download_test.go index bcb7e5db..4f163e98 100644 --- a/lfs/download_test.go +++ b/lfs/download_test.go @@ -72,7 +72,7 @@ func TestSuccessfulDownload(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != expectedAuth(t, server) { + if r.Header.Get("Authorization") != "" { t.Error("Invalid Authorization") } @@ -204,7 +204,7 @@ func TestSuccessfulDownloadWithRedirects(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != expectedAuth(t, server) { + if r.Header.Get("Authorization") != "" { t.Error("Invalid Authorization") }