diff --git a/lfsapi/auth.go b/lfsapi/auth.go index 0070a2f6..7d53e49c 100644 --- a/lfsapi/auth.go +++ b/lfsapi/auth.go @@ -24,6 +24,10 @@ var ( // authentication from netrc or git's credential helpers if necessary, // supporting basic and ntlm authentication. func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, error) { + return c.doWithAuth(remote, req, nil) +} + +func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) { req.Header = c.extraHeadersFor(req) apiEndpoint, access, credHelper, credsURL, creds, err := c.getCreds(remote, req) @@ -31,7 +35,7 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e return nil, err } - res, err := c.doWithCreds(req, credHelper, creds, credsURL, access) + res, err := c.doWithCreds(req, credHelper, creds, credsURL, access, via) if err != nil { if errors.IsAuthError(err) { newAccess := getAuthAccess(res) @@ -45,6 +49,12 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e req.Header.Del("Authorization") credHelper.Reject(creds) } + + // This case represents a rejected request that + // should have been authenticated but wasn't. Do + // not count this against our redirection + // maximum, so do not recur through doWithAuth + // and instead call DoWithAuth. return c.DoWithAuth(remote, req) } } @@ -57,11 +67,11 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e return res, err } -func (c *Client) doWithCreds(req *http.Request, credHelper CredentialHelper, creds Creds, credsURL *url.URL, access Access) (*http.Response, error) { +func (c *Client) doWithCreds(req *http.Request, credHelper CredentialHelper, creds Creds, credsURL *url.URL, access Access, via []*http.Request) (*http.Response, error) { if access == NTLMAccess { return c.doWithNTLM(req, credHelper, creds, credsURL) } - return c.do(req) + return c.do(req, "", via) } // getCreds fills the authorization header for the given request if possible, diff --git a/lfsapi/client.go b/lfsapi/client.go index 222a9a5c..bea88a0b 100644 --- a/lfsapi/client.go +++ b/lfsapi/client.go @@ -96,16 +96,16 @@ func joinURL(prefix, suffix string) string { func (c *Client) Do(req *http.Request) (*http.Response, error) { req.Header = c.extraHeadersFor(req) - return c.do(req) + return c.do(req, "", nil) } // do performs an *http.Request respecting redirects, and handles the response // as defined in c.handleResponse. Notably, it does not alter the headers for // the request argument in any way. -func (c *Client) do(req *http.Request) (*http.Response, error) { +func (c *Client) do(req *http.Request, remote string, via []*http.Request) (*http.Response, error) { req.Header.Set("User-Agent", UserAgent) - res, err := c.doWithRedirects(c.httpClient(req.Host), req, nil) + res, err := c.doWithRedirects(c.httpClient(req.Host), req, remote, via) if err != nil { return res, err } @@ -161,7 +161,7 @@ func (c *Client) extraHeaders(u *url.URL) map[string][]string { return m } -func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, via []*http.Request) (*http.Response, error) { +func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Response, error) { tracedReq, err := c.traceRequest(req) if err != nil { return nil, err @@ -231,7 +231,14 @@ func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, via []*htt return res, err } - return c.doWithRedirects(cli, redirectedReq, via) + if len(req.Header.Get("Authorization")) > 0 { + // If the original request was authenticated (noted by the + // presence of the Authorization header), then recur through + // doWithAuth, retaining the requests via but only after + // authenticating the redirected request. + return c.doWithAuth(remote, redirectedReq, via) + } + return c.doWithRedirects(cli, redirectedReq, remote, via) } func (c *Client) httpClient(host string) *http.Client { diff --git a/lfsapi/client_test.go b/lfsapi/client_test.go index e31a7e92..f5500f80 100644 --- a/lfsapi/client_test.go +++ b/lfsapi/client_test.go @@ -1,11 +1,13 @@ package lfsapi import ( + "encoding/base64" "encoding/json" "fmt" "net" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" @@ -175,6 +177,80 @@ func TestClientRedirect(t *testing.T) { assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http") } +func TestClientRedirectReauthenticate(t *testing.T) { + var srv1, srv2 *httptest.Server + var called1, called2 uint32 + var creds1, creds2 Creds + + srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&called1, 1) + + if hdr := r.Header.Get("Authorization"); len(hdr) > 0 { + parts := strings.SplitN(hdr, " ", 2) + typ, b64 := parts[0], parts[1] + + auth, err := base64.URLEncoding.DecodeString(b64) + assert.Nil(t, err) + assert.Equal(t, "Basic", typ) + assert.Equal(t, "user1:pass1", string(auth)) + + http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) + + srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&called2, 1) + + parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + typ, b64 := parts[0], parts[1] + + auth, err := base64.URLEncoding.DecodeString(b64) + assert.Nil(t, err) + assert.Equal(t, "Basic", typ) + assert.Equal(t, "user2:pass2", string(auth)) + })) + + // Change the URL of srv2 to make it appears as if it is a different + // host. + srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1) + + creds1 = Creds(map[string]string{ + "protocol": "http", + "host": strings.TrimPrefix(srv1.URL, "http://"), + + "username": "user1", + "password": "pass1", + }) + creds2 = Creds(map[string]string{ + "protocol": "http", + "host": strings.TrimPrefix(srv2.URL, "http://"), + + "username": "user2", + "password": "pass2", + }) + + defer srv1.Close() + defer srv2.Close() + + c, err := NewClient(NewContext(nil, nil, nil)) + creds := newCredentialCacher() + creds.Approve(creds1) + creds.Approve(creds2) + c.Credentials = creds + + req, err := http.NewRequest("GET", srv1.URL, nil) + require.Nil(t, err) + + _, err = c.DoWithAuth("", req) + assert.Nil(t, err) + + // called1 is 2 since LFS tries an unauthenticated request first + assert.EqualValues(t, 2, called1) + assert.EqualValues(t, 1, called2) +} + func TestNewClient(t *testing.T) { c, err := NewClient(NewContext(nil, nil, map[string]string{ "lfs.dialtimeout": "151", diff --git a/lfsapi/ntlm.go b/lfsapi/ntlm.go index 10b1ae7c..0170bb93 100644 --- a/lfsapi/ntlm.go +++ b/lfsapi/ntlm.go @@ -19,7 +19,7 @@ type ntmlCredentials struct { } func (c *Client) doWithNTLM(req *http.Request, credHelper CredentialHelper, creds Creds, credsURL *url.URL) (*http.Response, error) { - res, err := c.do(req) + res, err := c.do(req, "", nil) if err != nil && !errors.IsAuthError(err) { return res, err } @@ -86,7 +86,7 @@ func (c *Client) ntlmSendMessage(req *http.Request, message []byte) (*http.Respo msg := base64.StdEncoding.EncodeToString(message) req.Header.Set("Authorization", "NTLM "+msg) - return c.do(req) + return c.do(req, "", nil) } func parseChallengeResponse(res *http.Response) ([]byte, error) {