diff --git a/api/client.go b/api/client.go index 778f64ec..50e6dec6 100644 --- a/api/client.go +++ b/api/client.go @@ -1 +1,62 @@ package api + +import ( + "net/http" + + "github.com/github/git-lfs/auth" + "github.com/github/git-lfs/config" + "github.com/github/git-lfs/errutil" + "github.com/github/git-lfs/httputil" +) + +// doLegacyApiRequest runs the request to the LFS legacy API. +func DoLegacyRequest(req *http.Request) (*http.Response, *ObjectResource, error) { + via := make([]*http.Request, 0, 4) + res, err := httputil.DoHttpRequestWithRedirects(req, via, true) + if err != nil { + return res, nil, err + } + + obj := &ObjectResource{} + err = httputil.DecodeResponse(res, obj) + + if err != nil { + httputil.SetErrorResponseContext(err, res) + return nil, nil, err + } + + return res, obj, nil +} + +// doApiBatchRequest runs the request to the LFS batch API. If the API returns a +// 401, the repo will be marked as having private access and the request will be +// re-run. When the repo is marked as having private access, credentials will +// be retrieved. +func DoBatchRequest(req *http.Request) (*http.Response, []*ObjectResource, error) { + res, err := DoRequest(req, config.Config.PrivateAccess(auth.GetOperationForRequest(req))) + + if err != nil { + if res != nil && res.StatusCode == 401 { + return res, nil, errutil.NewAuthError(err) + } + return res, nil, err + } + + var objs map[string][]*ObjectResource + err = httputil.DecodeResponse(res, &objs) + + if err != nil { + httputil.SetErrorResponseContext(err, res) + } + + return res, objs["objects"], err +} + +// DoRequest runs a request to the LFS API, without parsing the response +// body. If the API returns a 401, the repo will be marked as having private +// access and the request will be re-run. When the repo is marked as having +// private access, credentials will be retrieved. +func DoRequest(req *http.Request, useCreds bool) (*http.Response, error) { + via := make([]*http.Request, 0, 4) + return httputil.DoHttpRequestWithRedirects(req, via, useCreds) +} diff --git a/auth/credentials.go b/auth/credentials.go index cfc70fd8..39d40f1f 100644 --- a/auth/credentials.go +++ b/auth/credentials.go @@ -14,7 +14,6 @@ import ( "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" - "github.com/github/git-lfs/httputil" "github.com/github/git-lfs/vendor/_nuts/github.com/rubyist/tracerx" ) @@ -51,7 +50,7 @@ func GetCreds(req *http.Request) (Creds, error) { } func getCredURLForAPI(req *http.Request) (*url.URL, error) { - operation := httputil.GetOperationForRequest(req) + operation := GetOperationForRequest(req) apiUrl, err := url.Parse(config.Config.Endpoint(operation).Url) if err != nil { return nil, err @@ -120,7 +119,7 @@ func setCredURLFromNetrc(req *http.Request) bool { } func skipCredsCheck(req *http.Request) bool { - if config.Config.NtlmAccess(httputil.GetOperationForRequest(req)) { + if config.Config.NtlmAccess(GetOperationForRequest(req)) { return false } @@ -242,7 +241,7 @@ func execCredsCommand(input Creds, subCommand string) (Creds, error) { } func setRequestAuthFromUrl(req *http.Request, u *url.URL) bool { - if !config.Config.NtlmAccess(httputil.GetOperationForRequest(req)) && u.User != nil { + if !config.Config.NtlmAccess(GetOperationForRequest(req)) && u.User != nil { if pass, ok := u.User.Password(); ok { fmt.Fprintln(os.Stderr, "warning: current Git remote contains credentials") setRequestAuth(req, u.User.Username(), pass) @@ -254,7 +253,7 @@ func setRequestAuthFromUrl(req *http.Request, u *url.URL) bool { } func setRequestAuth(req *http.Request, user, pass string) { - if config.Config.NtlmAccess(httputil.GetOperationForRequest(req)) { + if config.Config.NtlmAccess(GetOperationForRequest(req)) { return } @@ -281,3 +280,12 @@ func SetCredentialsFunc(f CredentialFunc) CredentialFunc { execCreds = f return oldf } + +// GetOperationForRequest determines the operation type for a http.Request +func GetOperationForRequest(req *http.Request) string { + operation := "download" + if req.Method == "POST" || req.Method == "PUT" { + operation = "upload" + } + return operation +} diff --git a/httputil/http.go b/httputil/http.go index 35c1923f..aab3dd42 100644 --- a/httputil/http.go +++ b/httputil/http.go @@ -102,21 +102,6 @@ func (c *HttpClient) Do(req *http.Request) (*http.Response, error) { return res, err } -func NewHttpRequest(method, rawurl string, header map[string]string) (*http.Request, error) { - req, err := http.NewRequest(method, rawurl, nil) - if err != nil { - return nil, err - } - - for key, value := range header { - req.Header.Set(key, value) - } - - req.Header.Set("User-Agent", UserAgent) - - return req, nil -} - // NewHttpClient returns a new HttpClient for the given host (which may be "host:port") func NewHttpClient(c *config.Configuration, host string) *HttpClient { httpClientsMutex.Lock() @@ -354,15 +339,6 @@ func TraceHttpReq(req *http.Request) string { return fmt.Sprintf("%s %s", req.Method, strings.SplitN(req.URL.String(), "?", 2)[0]) } -// GetOperationForRequest determines the operation type for a http.Request -func GetOperationForRequest(req *http.Request) string { - operation := "download" - if req.Method == "POST" || req.Method == "PUT" { - operation = "upload" - } - return operation -} - func init() { UserAgent = config.VersionDesc } diff --git a/auth/ntlm.go b/httputil/ntlm.go similarity index 91% rename from auth/ntlm.go rename to httputil/ntlm.go index d9f0b88b..8d081717 100644 --- a/auth/ntlm.go +++ b/httputil/ntlm.go @@ -1,4 +1,4 @@ -package auth +package httputil import ( "bytes" @@ -12,12 +12,12 @@ import ( "strings" "sync/atomic" + "github.com/github/git-lfs/auth" "github.com/github/git-lfs/config" - "github.com/github/git-lfs/httputil" "github.com/github/git-lfs/vendor/_nuts/github.com/ThomsonReutersEikon/go-ntlm/ntlm" ) -func ntlmClientSession(c *config.Configuration, creds Creds) (ntlm.ClientSession, error) { +func ntlmClientSession(c *config.Configuration, creds auth.Creds) (ntlm.ClientSession, error) { if c.NtlmSession != nil { return c.NtlmSession, nil } @@ -39,13 +39,13 @@ func ntlmClientSession(c *config.Configuration, creds Creds) (ntlm.ClientSession return session, nil } -func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { +func doNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { handReq, err := cloneRequest(request) if err != nil { return nil, err } - res, err := httputil.NewHttpClient(config.Config, handReq.Host).Do(handReq) + res, err := NewHttpClient(config.Config, handReq.Host).Do(handReq) if err != nil && res == nil { return nil, err } @@ -53,7 +53,7 @@ func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { //If the status is 401 then we need to re-authenticate, otherwise it was successful if res.StatusCode == 401 { - creds, err := GetCreds(request) + creds, err := auth.GetCreds(request) if err != nil { return nil, err } @@ -80,10 +80,10 @@ func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { //If the status is 401 then we need to re-authenticate if res.StatusCode == 401 && retry == true { - return DoNTLMRequest(challengeReq, false) + return doNTLMRequest(challengeReq, false) } - SaveCredentials(creds, res) + auth.SaveCredentials(creds, res) return res, nil } @@ -92,7 +92,7 @@ func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { func negotiate(request *http.Request, message string) ([]byte, error) { request.Header.Add("Authorization", message) - res, err := httputil.NewHttpClient(config.Config, request.Host).Do(request) + res, err := NewHttpClient(config.Config, request.Host).Do(request) if res == nil && err != nil { return nil, err @@ -109,7 +109,7 @@ func negotiate(request *http.Request, message string) ([]byte, error) { return ret, nil } -func challenge(request *http.Request, challengeBytes []byte, creds Creds) (*http.Response, error) { +func challenge(request *http.Request, challengeBytes []byte, creds auth.Creds) (*http.Response, error) { challenge, err := ntlm.ParseChallengeMessage(challengeBytes) if err != nil { return nil, err @@ -128,7 +128,7 @@ func challenge(request *http.Request, challengeBytes []byte, creds Creds) (*http authMsg := base64.StdEncoding.EncodeToString(authenticate.Bytes()) request.Header.Add("Authorization", "NTLM "+authMsg) - return httputil.NewHttpClient(config.Config, request.Host).Do(request) + return NewHttpClient(config.Config, request.Host).Do(request) } func parseChallengeResponse(response *http.Response) ([]byte, error) { diff --git a/auth/ntlm_test.go b/httputil/ntlm_test.go similarity index 99% rename from auth/ntlm_test.go rename to httputil/ntlm_test.go index 05ab23e7..1f4e51f9 100644 --- a/auth/ntlm_test.go +++ b/httputil/ntlm_test.go @@ -1,4 +1,4 @@ -package auth +package httputil import ( "bytes" diff --git a/httputil/request.go b/httputil/request.go new file mode 100644 index 00000000..9952c26d --- /dev/null +++ b/httputil/request.go @@ -0,0 +1,186 @@ +package httputil + +import ( + "bytes" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + + "github.com/github/git-lfs/auth" + "github.com/github/git-lfs/config" + "github.com/github/git-lfs/errutil" + + "github.com/github/git-lfs/vendor/_nuts/github.com/rubyist/tracerx" +) + +type ClientError struct { + Message string `json:"message"` + DocumentationUrl string `json:"documentation_url,omitempty"` + RequestId string `json:"request_id,omitempty"` +} + +func (e *ClientError) Error() string { + msg := e.Message + if len(e.DocumentationUrl) > 0 { + msg += "\nDocs: " + e.DocumentationUrl + } + if len(e.RequestId) > 0 { + msg += "\nRequest ID: " + e.RequestId + } + return msg +} + +// Internal http request management +func doHttpRequest(req *http.Request, creds auth.Creds) (*http.Response, error) { + var ( + res *http.Response + err error + ) + + if config.Config.NtlmAccess(auth.GetOperationForRequest(req)) { + res, err = doNTLMRequest(req, true) + } else { + res, err = NewHttpClient(config.Config, req.Host).Do(req) + } + + if res == nil { + res = &http.Response{ + StatusCode: 0, + Header: make(http.Header), + Request: req, + Body: ioutil.NopCloser(bytes.NewBufferString("")), + } + } + + if err != nil { + if errutil.IsAuthError(err) { + SetAuthType(req, res) + doHttpRequest(req, creds) + } else { + err = errutil.Error(err) + } + } else { + // TODO(sinbad) stop handling the response here, separate response processing to api package + err = handleResponse(res, creds) + } + + if err != nil { + if res != nil { + SetErrorResponseContext(err, res) + } else { + setErrorRequestContext(err, req) + } + } + + return res, err +} + +// DoHttpRequest performs a single HTTP request +func DoHttpRequest(req *http.Request, useCreds bool) (*http.Response, error) { + var creds auth.Creds + if useCreds { + c, err := auth.GetCreds(req) + if err != nil { + return nil, err + } + creds = c + } + + return doHttpRequest(req, creds) +} + +// DoHttpRequestWithRedirects runs a HTTP request and responds to redirects +func DoHttpRequestWithRedirects(req *http.Request, via []*http.Request, useCreds bool) (*http.Response, error) { + var creds auth.Creds + if useCreds { + c, err := auth.GetCreds(req) + if err != nil { + return nil, err + } + creds = c + } + + res, err := doHttpRequest(req, creds) + if err != nil { + return res, err + } + + if res.StatusCode == 307 { + redirectTo := res.Header.Get("Location") + locurl, err := url.Parse(redirectTo) + if err == nil && !locurl.IsAbs() { + locurl = req.URL.ResolveReference(locurl) + redirectTo = locurl.String() + } + + redirectedReq, err := NewHttpRequest(req.Method, redirectTo, nil) + if err != nil { + return res, errutil.Errorf(err, err.Error()) + } + + via = append(via, req) + + // Avoid seeking and re-wrapping the CountingReadCloser, just get the "real" body + realBody := req.Body + if wrappedBody, ok := req.Body.(*CountingReadCloser); ok { + realBody = wrappedBody.ReadCloser + } + + seeker, ok := realBody.(io.Seeker) + if !ok { + return res, errutil.Errorf(nil, "Request body needs to be an io.Seeker to handle redirects.") + } + + if _, err := seeker.Seek(0, 0); err != nil { + return res, errutil.Error(err) + } + redirectedReq.Body = realBody + redirectedReq.ContentLength = req.ContentLength + + if err = CheckRedirect(redirectedReq, via); err != nil { + return res, errutil.Errorf(err, err.Error()) + } + + return DoHttpRequestWithRedirects(redirectedReq, via, useCreds) + } + + return res, nil +} + +// NewHttpRequest creates a template request, with the given headers & UserAgent supplied +func NewHttpRequest(method, rawurl string, header map[string]string) (*http.Request, error) { + req, err := http.NewRequest(method, rawurl, nil) + if err != nil { + return nil, err + } + + for key, value := range header { + req.Header.Set(key, value) + } + + req.Header.Set("User-Agent", UserAgent) + + return req, nil +} + +func SetAuthType(req *http.Request, res *http.Response) { + authType := GetAuthType(res) + operation := auth.GetOperationForRequest(req) + config.Config.SetAccess(operation, authType) + tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", authType) +} + +func GetAuthType(res *http.Response) string { + auth := res.Header.Get("Www-Authenticate") + if len(auth) < 1 { + auth = res.Header.Get("Lfs-Authenticate") + } + + if strings.HasPrefix(strings.ToLower(auth), "ntlm") { + return "ntlm" + } + + return "basic" +} diff --git a/lfs/client_error_test.go b/httputil/request_error_test.go similarity index 99% rename from lfs/client_error_test.go rename to httputil/request_error_test.go index 3816dcb6..aec638fe 100644 --- a/lfs/client_error_test.go +++ b/httputil/request_error_test.go @@ -1,4 +1,4 @@ -package lfs +package httputil import ( "bytes" diff --git a/httputil/response.go b/httputil/response.go new file mode 100644 index 00000000..59b67c2b --- /dev/null +++ b/httputil/response.go @@ -0,0 +1,127 @@ +package httputil + +import ( + "encoding/json" + "fmt" + "io" + "io/ioutil" + "net/http" + "regexp" + + "github.com/github/git-lfs/auth" + "github.com/github/git-lfs/config" + "github.com/github/git-lfs/errutil" +) + +var ( + lfsMediaTypeRE = regexp.MustCompile(`\Aapplication/vnd\.git\-lfs\+json(;|\z)`) + jsonMediaTypeRE = regexp.MustCompile(`\Aapplication/json(;|\z)`) + hiddenHeaders = map[string]bool{ + "Authorization": true, + } + + defaultErrors = map[int]string{ + 400: "Client error: %s", + 401: "Authorization error: %s\nCheck that you have proper access to the repository", + 403: "Authorization error: %s\nCheck that you have proper access to the repository", + 404: "Repository or object not found: %s\nCheck that it exists and that you have proper access to it", + 500: "Server error: %s", + } +) + +// DecodeResponse attempts to decode the contents of the response as a JSON object +func DecodeResponse(res *http.Response, obj interface{}) error { + ctype := res.Header.Get("Content-Type") + if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) { + return nil + } + + err := json.NewDecoder(res.Body).Decode(obj) + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + if err != nil { + return errutil.Errorf(err, "Unable to parse HTTP response for %s", TraceHttpReq(res.Request)) + } + + return nil +} + +// GetDefaultError returns the default text for standard error codes (blank if none) +func GetDefaultError(code int) string { + if s, ok := defaultErrors[code]; ok { + return s + } + return "" +} + +// Check the response from a HTTP request for problems +func handleResponse(res *http.Response, creds auth.Creds) error { + auth.SaveCredentials(creds, res) + + if res.StatusCode < 400 { + return nil + } + + defer func() { + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + }() + + cliErr := &ClientError{} + err := DecodeResponse(res, cliErr) + if err == nil { + if len(cliErr.Message) == 0 { + err = defaultError(res) + } else { + err = errutil.Error(cliErr) + } + } + + if res.StatusCode == 401 { + return errutil.NewAuthError(err) + } + + if res.StatusCode > 499 && res.StatusCode != 501 && res.StatusCode != 509 { + return errutil.NewFatalError(err) + } + + return err +} + +func defaultError(res *http.Response) error { + var msgFmt string + + if f, ok := defaultErrors[res.StatusCode]; ok { + msgFmt = f + } else if res.StatusCode < 500 { + msgFmt = defaultErrors[400] + fmt.Sprintf(" from HTTP %d", res.StatusCode) + } else { + msgFmt = defaultErrors[500] + fmt.Sprintf(" from HTTP %d", res.StatusCode) + } + + return errutil.Error(fmt.Errorf(msgFmt, res.Request.URL)) +} + +func SetErrorResponseContext(err error, res *http.Response) { + errutil.ErrorSetContext(err, "Status", res.Status) + setErrorHeaderContext(err, "Request", res.Header) + setErrorRequestContext(err, res.Request) +} + +func setErrorRequestContext(err error, req *http.Request) { + errutil.ErrorSetContext(err, "Endpoint", config.Config.Endpoint(auth.GetOperationForRequest(req)).Url) + errutil.ErrorSetContext(err, "URL", TraceHttpReq(req)) + setErrorHeaderContext(err, "Response", req.Header) +} + +func setErrorHeaderContext(err error, prefix string, head http.Header) { + for key, _ := range head { + contextKey := fmt.Sprintf("%s:%s", prefix, key) + if _, skip := hiddenHeaders[key]; skip { + errutil.ErrorSetContext(err, contextKey, "--") + } else { + errutil.ErrorSetContext(err, contextKey, head.Get(key)) + } + } +} diff --git a/lfs/client.go b/lfs/client.go index 3cf1e336..016c80e1 100644 --- a/lfs/client.go +++ b/lfs/client.go @@ -11,9 +11,7 @@ import ( "os" "path" "path/filepath" - "regexp" "strconv" - "strings" "github.com/github/git-lfs/api" "github.com/github/git-lfs/auth" @@ -29,39 +27,6 @@ const ( mediaType = "application/vnd.git-lfs+json; charset=utf-8" ) -var ( - lfsMediaTypeRE = regexp.MustCompile(`\Aapplication/vnd\.git\-lfs\+json(;|\z)`) - jsonMediaTypeRE = regexp.MustCompile(`\Aapplication/json(;|\z)`) - hiddenHeaders = map[string]bool{ - "Authorization": true, - } - - defaultErrors = map[int]string{ - 400: "Client error: %s", - 401: "Authorization error: %s\nCheck that you have proper access to the repository", - 403: "Authorization error: %s\nCheck that you have proper access to the repository", - 404: "Repository or object not found: %s\nCheck that it exists and that you have proper access to it", - 500: "Server error: %s", - } -) - -type ClientError struct { - Message string `json:"message"` - DocumentationUrl string `json:"documentation_url,omitempty"` - RequestId string `json:"request_id,omitempty"` -} - -func (e *ClientError) Error() string { - msg := e.Message - if len(e.DocumentationUrl) > 0 { - msg += "\nDocs: " + e.DocumentationUrl - } - if len(e.RequestId) > 0 { - msg += "\nRequest ID: " + e.RequestId - } - return msg -} - // Download will attempt to download the object with the given oid. The batched // API will be used, but if the server does not implement the batch operations // it will fall back to the legacy API. @@ -98,7 +63,7 @@ func DownloadLegacy(oid string) (io.ReadCloser, int64, error) { return nil, 0, errutil.Error(err) } - res, obj, err := doLegacyApiRequest(req) + res, obj, err := api.DoLegacyRequest(req) if err != nil { return nil, 0, err } @@ -108,7 +73,7 @@ func DownloadLegacy(oid string) (io.ReadCloser, int64, error) { return nil, 0, errutil.Error(err) } - res, err = doHttpRequestWithCreds(req) + res, err = httputil.DoHttpRequest(req, true) if err != nil { return nil, 0, err } @@ -127,7 +92,7 @@ func DownloadCheck(oid string) (*api.ObjectResource, error) { return nil, errutil.Error(err) } - res, obj, err := doLegacyApiRequest(req) + res, obj, err := api.DoLegacyRequest(req) if err != nil { return nil, err } @@ -147,7 +112,7 @@ func DownloadObject(obj *api.ObjectResource) (io.ReadCloser, int64, error) { return nil, 0, errutil.Error(err) } - res, err := doHttpRequestWithCreds(req) + res, err := httputil.DoHttpRequest(req, true) if err != nil { return nil, 0, errutil.NewRetriableError(err) } @@ -184,7 +149,7 @@ func Batch(objects []*api.ObjectResource, operation string) ([]*api.ObjectResour tracerx.Printf("api: batch %d files", len(objects)) - res, objs, err := doApiBatchRequest(req) + res, objs, err := api.DoBatchRequest(req) if err != nil { @@ -197,7 +162,7 @@ func Batch(objects []*api.ObjectResource, operation string) ([]*api.ObjectResour } if errutil.IsAuthError(err) { - setAuthType(req, res) + httputil.SetAuthType(req, res) return Batch(objects, operation) } @@ -248,11 +213,11 @@ func UploadCheck(oidPath string) (*api.ObjectResource, error) { req.Body = &byteCloser{bytes.NewReader(by)} tracerx.Printf("api: uploading (%s)", oid) - res, obj, err := doLegacyApiRequest(req) + res, obj, err := api.DoLegacyRequest(req) if err != nil { if errutil.IsAuthError(err) { - setAuthType(req, res) + httputil.SetAuthType(req, res) return UploadCheck(oidPath) } @@ -310,7 +275,7 @@ func UploadObject(o *api.ObjectResource, cb progress.CopyCallback) error { req.ContentLength = o.Size req.Body = ioutil.NopCloser(reader) - res, err := doHttpRequestWithCreds(req) + res, err := httputil.DoHttpRequest(req, true) if err != nil { return errutil.NewRetriableError(err) } @@ -347,7 +312,7 @@ func UploadObject(o *api.ObjectResource, cb progress.CopyCallback) error { req.Header.Set("Content-Length", strconv.Itoa(len(by))) req.ContentLength = int64(len(by)) req.Body = ioutil.NopCloser(bytes.NewReader(by)) - res, err = doAPIRequest(req, true) + res, err = api.DoRequest(req, true) if err != nil { return err } @@ -359,234 +324,6 @@ func UploadObject(o *api.ObjectResource, cb progress.CopyCallback) error { return err } -// doLegacyApiRequest runs the request to the LFS legacy API. -func doLegacyApiRequest(req *http.Request) (*http.Response, *api.ObjectResource, error) { - via := make([]*http.Request, 0, 4) - res, err := doApiRequestWithRedirects(req, via, true) - if err != nil { - return res, nil, err - } - - obj := &api.ObjectResource{} - err = decodeApiResponse(res, obj) - - if err != nil { - setErrorResponseContext(err, res) - return nil, nil, err - } - - return res, obj, nil -} - -// doApiBatchRequest runs the request to the LFS batch API. If the API returns a -// 401, the repo will be marked as having private access and the request will be -// re-run. When the repo is marked as having private access, credentials will -// be retrieved. -func doApiBatchRequest(req *http.Request) (*http.Response, []*api.ObjectResource, error) { - res, err := doAPIRequest(req, config.Config.PrivateAccess(httputil.GetOperationForRequest(req))) - - if err != nil { - if res != nil && res.StatusCode == 401 { - return res, nil, errutil.NewAuthError(err) - } - return res, nil, err - } - - var objs map[string][]*api.ObjectResource - err = decodeApiResponse(res, &objs) - - if err != nil { - setErrorResponseContext(err, res) - } - - return res, objs["objects"], err -} - -// doHttpRequestWithCreds performs doHttpRequest with creds added -func doHttpRequestWithCreds(req *http.Request) (*http.Response, error) { - creds, err := auth.GetCreds(req) - if err != nil { - return nil, err - } - - return doHttpRequest(req, creds) -} - -// doAPIRequest runs the request to the LFS API, without parsing the response -// body. If the API returns a 401, the repo will be marked as having private -// access and the request will be re-run. When the repo is marked as having -// private access, credentials will be retrieved. -func doAPIRequest(req *http.Request, useCreds bool) (*http.Response, error) { - via := make([]*http.Request, 0, 4) - return doApiRequestWithRedirects(req, via, useCreds) -} - -// doHttpRequest runs the given HTTP request. LFS or Storage API requests should -// use doApiBatchRequest() or doHttpRequestWithCreds() instead. -func doHttpRequest(req *http.Request, creds auth.Creds) (*http.Response, error) { - var ( - res *http.Response - err error - ) - - if config.Config.NtlmAccess(httputil.GetOperationForRequest(req)) { - res, err = auth.DoNTLMRequest(req, true) - } else { - res, err = httputil.NewHttpClient(config.Config, req.Host).Do(req) - } - - if res == nil { - res = &http.Response{ - StatusCode: 0, - Header: make(http.Header), - Request: req, - Body: ioutil.NopCloser(bytes.NewBufferString("")), - } - } - - if err != nil { - if errutil.IsAuthError(err) { - setAuthType(req, res) - doHttpRequest(req, creds) - } else { - err = errutil.Error(err) - } - } else { - err = handleResponse(res, creds) - } - - if err != nil { - if res != nil { - setErrorResponseContext(err, res) - } else { - setErrorRequestContext(err, req) - } - } - - return res, err -} - -func doApiRequestWithRedirects(req *http.Request, via []*http.Request, useCreds bool) (*http.Response, error) { - var creds auth.Creds - if useCreds { - c, err := auth.GetCreds(req) - if err != nil { - return nil, err - } - creds = c - } - - res, err := doHttpRequest(req, creds) - if err != nil { - return res, err - } - - if res.StatusCode == 307 { - redirectTo := res.Header.Get("Location") - locurl, err := url.Parse(redirectTo) - if err == nil && !locurl.IsAbs() { - locurl = req.URL.ResolveReference(locurl) - redirectTo = locurl.String() - } - - redirectedReq, err := newClientRequest(req.Method, redirectTo, nil) - if err != nil { - return res, errutil.Errorf(err, err.Error()) - } - - via = append(via, req) - - // Avoid seeking and re-wrapping the CountingReadCloser, just get the "real" body - realBody := req.Body - if wrappedBody, ok := req.Body.(*httputil.CountingReadCloser); ok { - realBody = wrappedBody.ReadCloser - } - - seeker, ok := realBody.(io.Seeker) - if !ok { - return res, errutil.Errorf(nil, "Request body needs to be an io.Seeker to handle redirects.") - } - - if _, err := seeker.Seek(0, 0); err != nil { - return res, errutil.Error(err) - } - redirectedReq.Body = realBody - redirectedReq.ContentLength = req.ContentLength - - if err = httputil.CheckRedirect(redirectedReq, via); err != nil { - return res, errutil.Errorf(err, err.Error()) - } - - return doApiRequestWithRedirects(redirectedReq, via, useCreds) - } - - return res, nil -} - -func handleResponse(res *http.Response, creds auth.Creds) error { - auth.SaveCredentials(creds, res) - - if res.StatusCode < 400 { - return nil - } - - defer func() { - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - }() - - cliErr := &ClientError{} - err := decodeApiResponse(res, cliErr) - if err == nil { - if len(cliErr.Message) == 0 { - err = defaultError(res) - } else { - err = errutil.Error(cliErr) - } - } - - if res.StatusCode == 401 { - return errutil.NewAuthError(err) - } - - if res.StatusCode > 499 && res.StatusCode != 501 && res.StatusCode != 509 { - return errutil.NewFatalError(err) - } - - return err -} - -func decodeApiResponse(res *http.Response, obj interface{}) error { - ctype := res.Header.Get("Content-Type") - if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) { - return nil - } - - err := json.NewDecoder(res.Body).Decode(obj) - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - - if err != nil { - return errutil.Errorf(err, "Unable to parse HTTP response for %s", httputil.TraceHttpReq(res.Request)) - } - - return nil -} - -func defaultError(res *http.Response) error { - var msgFmt string - - if f, ok := defaultErrors[res.StatusCode]; ok { - msgFmt = f - } else if res.StatusCode < 500 { - msgFmt = defaultErrors[400] + fmt.Sprintf(" from HTTP %d", res.StatusCode) - } else { - msgFmt = defaultErrors[500] + fmt.Sprintf(" from HTTP %d", res.StatusCode) - } - - return errutil.Error(fmt.Errorf(msgFmt, res.Request.URL)) -} - func newApiRequest(method, oid string) (*http.Request, error) { objectOid := oid operation := "download" @@ -615,7 +352,7 @@ func newApiRequest(method, oid string) (*http.Request, error) { return nil, err } - req, err := newClientRequest(method, u.String(), res.Header) + req, err := httputil.NewHttpRequest(method, u.String(), res.Header) if err != nil { return nil, err } @@ -624,21 +361,6 @@ func newApiRequest(method, oid string) (*http.Request, error) { return req, nil } -func newClientRequest(method, rawurl string, header map[string]string) (*http.Request, error) { - req, err := http.NewRequest(method, rawurl, nil) - if err != nil { - return nil, err - } - - for key, value := range header { - req.Header.Set(key, value) - } - - req.Header.Set("User-Agent", httputil.UserAgent) - - return req, nil -} - func newBatchApiRequest(operation string) (*http.Request, error) { endpoint := config.Config.Endpoint(operation) @@ -659,7 +381,7 @@ func newBatchApiRequest(operation string) (*http.Request, error) { return nil, err } - req, err := newBatchClientRequest("POST", u.String()) + req, err := httputil.NewHttpRequest("POST", u.String(), nil) if err != nil { return nil, err } @@ -674,60 +396,6 @@ func newBatchApiRequest(operation string) (*http.Request, error) { return req, nil } -func newBatchClientRequest(method, rawurl string) (*http.Request, error) { - req, err := http.NewRequest(method, rawurl, nil) - if err != nil { - return nil, err - } - - req.Header.Set("User-Agent", httputil.UserAgent) - - return req, nil -} - -func setAuthType(req *http.Request, res *http.Response) { - authType := getAuthType(res) - operation := httputil.GetOperationForRequest(req) - config.Config.SetAccess(operation, authType) - tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", authType) -} - -func getAuthType(res *http.Response) string { - auth := res.Header.Get("Www-Authenticate") - if len(auth) < 1 { - auth = res.Header.Get("Lfs-Authenticate") - } - - if strings.HasPrefix(strings.ToLower(auth), "ntlm") { - return "ntlm" - } - - return "basic" -} - -func setErrorResponseContext(err error, res *http.Response) { - errutil.ErrorSetContext(err, "Status", res.Status) - setErrorHeaderContext(err, "Request", res.Header) - setErrorRequestContext(err, res.Request) -} - -func setErrorRequestContext(err error, req *http.Request) { - errutil.ErrorSetContext(err, "Endpoint", config.Config.Endpoint(httputil.GetOperationForRequest(req)).Url) - errutil.ErrorSetContext(err, "URL", httputil.TraceHttpReq(req)) - setErrorHeaderContext(err, "Response", req.Header) -} - -func setErrorHeaderContext(err error, prefix string, head http.Header) { - for key, _ := range head { - contextKey := fmt.Sprintf("%s:%s", prefix, key) - if _, skip := hiddenHeaders[key]; skip { - errutil.ErrorSetContext(err, contextKey, "--") - } else { - errutil.ErrorSetContext(err, contextKey, head.Get(key)) - } - } -} - func ObjectUrl(endpoint config.Endpoint, oid string) (*url.URL, error) { u, err := url.Parse(endpoint.Url) if err != nil { diff --git a/lfs/download_test.go b/lfs/download_test.go index 9df019a4..b9cb7c7e 100644 --- a/lfs/download_test.go +++ b/lfs/download_test.go @@ -14,6 +14,7 @@ import ( "github.com/github/git-lfs/api" "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" + "github.com/github/git-lfs/httputil" ) func TestSuccessfulDownload(t *testing.T) { @@ -650,7 +651,7 @@ func TestDownloadAPIError(t *testing.T) { return } - if err.Error() != fmt.Sprintf(defaultErrors[404], server.URL+"/media/objects/oid") { + if err.Error() != fmt.Sprintf(httputil.GetDefaultError(404), server.URL+"/media/objects/oid") { t.Fatalf("Unexpected error: %s", err.Error()) } @@ -727,7 +728,7 @@ func TestDownloadStorageError(t *testing.T) { t.Fatal("should panic") } - if err.Error() != fmt.Sprintf(defaultErrors[500], server.URL+"/download") { + if err.Error() != fmt.Sprintf(httputil.GetDefaultError(500), server.URL+"/download") { t.Fatalf("Unexpected error: %s", err.Error()) } diff --git a/lfs/upload_test.go b/lfs/upload_test.go index 4378a920..bda271e6 100644 --- a/lfs/upload_test.go +++ b/lfs/upload_test.go @@ -15,6 +15,7 @@ import ( "github.com/github/git-lfs/api" "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" + "github.com/github/git-lfs/httputil" ) func TestExistingUpload(t *testing.T) { @@ -662,7 +663,7 @@ func TestUploadApiError(t *testing.T) { return } - if err.Error() != fmt.Sprintf(defaultErrors[404], server.URL+"/media/objects") { + if err.Error() != fmt.Sprintf(httputil.GetDefaultError(404), server.URL+"/media/objects") { t.Fatalf("Unexpected error: %s", err.Error()) } @@ -781,7 +782,7 @@ func TestUploadStorageError(t *testing.T) { t.Fatal("should not panic") } - if err.Error() != fmt.Sprintf(defaultErrors[404], server.URL+"/upload") { + if err.Error() != fmt.Sprintf(httputil.GetDefaultError(404), server.URL+"/upload") { t.Fatalf("Unexpected error: %s", err.Error()) } @@ -937,7 +938,7 @@ func TestUploadVerifyError(t *testing.T) { t.Fatal("should not panic") } - if err.Error() != fmt.Sprintf(defaultErrors[404], server.URL+"/verify") { + if err.Error() != fmt.Sprintf(httputil.GetDefaultError(404), server.URL+"/verify") { t.Fatalf("Unexpected error: %s", err.Error()) }