diff --git a/commands/command_env.go b/commands/command_env.go index 6d76528e..f0bef668 100644 --- a/commands/command_env.go +++ b/commands/command_env.go @@ -25,7 +25,7 @@ func envCommand(cmd *cobra.Command, args []string) { endpoint := getAPIClient().Endpoints.Endpoint("download", defaultRemote) if len(endpoint.Url) > 0 { access := getAPIClient().Endpoints.AccessFor(endpoint.Url) - Print("Endpoint=%s (auth=%s)", endpoint.Url, access) + Print("Endpoint=%s (auth=%s)", endpoint.Url, access.GetMode()) if len(endpoint.SshUserAndHost) > 0 { Print(" SSH=%s:%s", endpoint.SshUserAndHost, endpoint.SshPath) } @@ -38,7 +38,7 @@ func envCommand(cmd *cobra.Command, args []string) { } remoteEndpoint := getAPIClient().Endpoints.RemoteEndpoint("download", remote) remoteAccess := getAPIClient().Endpoints.AccessFor(remoteEndpoint.Url) - Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess) + Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess.GetMode()) if len(remoteEndpoint.SshUserAndHost) > 0 { Print(" SSH=%s:%s", remoteEndpoint.SshUserAndHost, remoteEndpoint.SshPath) } diff --git a/lfs/lfs.go b/lfs/lfs.go index 45b0e086..cef34c84 100644 --- a/lfs/lfs.go +++ b/lfs/lfs.go @@ -56,8 +56,8 @@ func Environ(cfg *config.Configuration, manifest *tq.Manifest) []string { fmt.Sprintf("PruneVerifyRemoteAlways=%v", fetchPruneConfig.PruneVerifyRemoteAlways), fmt.Sprintf("PruneRemoteName=%s", fetchPruneConfig.PruneRemoteName), fmt.Sprintf("LfsStorageDir=%s", cfg.LFSStorageDir()), - fmt.Sprintf("AccessDownload=%s", download), - fmt.Sprintf("AccessUpload=%s", upload), + fmt.Sprintf("AccessDownload=%s", download.GetMode()), + fmt.Sprintf("AccessUpload=%s", upload.GetMode()), fmt.Sprintf("DownloadTransfers=%s", strings.Join(dltransfers, ",")), fmt.Sprintf("UploadTransfers=%s", strings.Join(ultransfers, ",")), ) diff --git a/lfsapi/auth.go b/lfsapi/auth.go index 2c35f848..ec2a07cf 100644 --- a/lfsapi/auth.go +++ b/lfsapi/auth.go @@ -31,7 +31,7 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) { req.Header = c.client.ExtraHeadersFor(req) - apiEndpoint, access, credHelper, credsURL, creds, err := c.getCreds(remote, req) + access, credHelper, credsURL, creds, err := c.getCreds(remote, req) if err != nil { return nil, err } @@ -39,13 +39,13 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques res, err := c.doWithCreds(req, credHelper, creds, credsURL, access, via) if err != nil { if errors.IsAuthError(err) { - newAccess := getAuthAccess(res) - if newAccess != access { - c.Endpoints.SetAccess(apiEndpoint.Url, newAccess) + newMode := getAuthAccess(res) + if newMode != access.GetMode() { + c.Endpoints.SetAccess(access.Upgrade(newMode)) } - if creds != nil || (access == NoneAccess && len(req.Header.Get("Authorization")) == 0) { - tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newAccess) + if creds != nil || (access.GetMode() == NoneAccess && len(req.Header.Get("Authorization")) == 0) { + tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newMode) if creds != nil { req.Header.Del("Authorization") credHelper.Reject(creds) @@ -68,8 +68,8 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques return res, err } -func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access AccessMode, via []*http.Request) (*http.Response, error) { - if access == NTLMAccess { +func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access Access, via []*http.Request) (*http.Response, error) { + if access.GetMode() == NTLMAccess { return c.doWithNTLM(req, credHelper, creds, credsURL) } @@ -107,7 +107,7 @@ func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelpe // 3. The Git Remote URL, which should be something like "https://git.com/repo.git" // This URL is used for the Git Credential Helper. This way existing https // Git remote credentials can be re-used for LFS. -func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, AccessMode, creds.CredentialHelper, *url.URL, creds.Creds, error) { +func (c *Client) getCreds(remote string, req *http.Request) (Access, creds.CredentialHelper, *url.URL, creds.Creds, error) { ef := c.Endpoints if ef == nil { ef = defaultEndpointFinder @@ -122,18 +122,18 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A apiEndpoint := ef.Endpoint(operation, remote) access := ef.AccessFor(apiEndpoint.Url) - if access != NTLMAccess { - if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access == NoneAccess { - return apiEndpoint, access, creds.NullCreds, nil, nil, nil + if access.GetMode() != NTLMAccess { + if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access.GetMode() == NoneAccess { + return access, creds.NullCreds, nil, nil, nil } credsURL, err := getCredURLForAPI(ef, operation, remote, apiEndpoint, req) if err != nil { - return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds") + return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds") } if credsURL == nil { - return apiEndpoint, access, creds.NullCreds, nil, nil, nil + return access, creds.NullCreds, nil, nil, nil } credHelper, creds, err := c.getGitCreds(ef, req, credsURL) @@ -141,14 +141,14 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A tracerx.Printf("Filled credentials for %s", credsURL) setRequestAuth(req, creds["username"], creds["password"]) } - return apiEndpoint, access, credHelper, credsURL, creds, err + return access, credHelper, credsURL, creds, err } // NTLM ONLY credsURL, err := url.Parse(apiEndpoint.Url) if err != nil { - return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds") + return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds") } if netrcMachine := getAuthFromNetrc(netrcFinder, req); netrcMachine != nil { @@ -160,12 +160,12 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A "source": "netrc", } - return apiEndpoint, access, creds.NullCreds, credsURL, cred, nil + return access, creds.NullCreds, credsURL, cred, nil } // NTLM uses creds to create the session credHelper, creds, err := c.getGitCreds(ef, req, credsURL) - return apiEndpoint, access, credHelper, credsURL, creds, err + return access, credHelper, credsURL, creds, err } func (c *Client) getGitCreds(ef EndpointFinder, req *http.Request, u *url.URL) (creds.CredentialHelper, creds.Creds, error) { diff --git a/lfsapi/auth_test.go b/lfsapi/auth_test.go index 67e75266..198cd027 100644 --- a/lfsapi/auth_test.go +++ b/lfsapi/auth_test.go @@ -81,7 +81,7 @@ func TestDoWithAuthApprove(t *testing.T) { require.Nil(t, err) c.Credentials = cred - assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs")) + assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode) req, err := http.NewRequest("POST", srv.URL+"/repo/lfs/foo", nil) require.Nil(t, err) @@ -99,7 +99,7 @@ func TestDoWithAuthApprove(t *testing.T) { "protocol": "http", "host": srv.Listener.Addr().String(), }))) - assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs")) + assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode) assert.EqualValues(t, 2, called) } @@ -407,7 +407,7 @@ func TestGetCreds(t *testing.T) { }, Expected: getCredsExpected{ Access: BasicAccess, - Endpoint: "https://user@git-server.com/repo/lfs", + Endpoint: "https://git-server.com/repo/lfs", Authorization: basicAuth("user", "monkey"), CredsURL: "https://user@git-server.com/repo/lfs", Creds: map[string]string{ @@ -450,7 +450,7 @@ func TestGetCreds(t *testing.T) { }, Expected: getCredsExpected{ Access: BasicAccess, - Endpoint: "https://user:pass@git-server.com/repo", + Endpoint: "https://git-server.com/repo", Authorization: basicAuth("user", "pass"), }, }, @@ -574,12 +574,12 @@ func TestGetCreds(t *testing.T) { client.Credentials = &fakeCredentialFiller{} client.Netrc = &fakeNetrc{} client.Endpoints = NewEndpointFinder(ctx) - endpoint, access, _, credsURL, creds, err := client.getCreds(test.Remote, req) + access, _, credsURL, creds, err := client.getCreds(test.Remote, req) if !assert.Nil(t, err) { continue } - assert.Equal(t, test.Expected.Endpoint, endpoint.Url, "endpoint") - assert.Equal(t, test.Expected.Access, access, "access") + assert.Equal(t, test.Expected.Endpoint, access.url, "endpoint") + assert.Equal(t, test.Expected.Access, access.mode, "access") assert.Equal(t, test.Expected.Authorization, req.Header.Get("Authorization"), "authorization") if test.Expected.Creds != nil { diff --git a/lfsapi/endpoint_finder.go b/lfsapi/endpoint_finder.go index d0b346b5..c0740a22 100644 --- a/lfsapi/endpoint_finder.go +++ b/lfsapi/endpoint_finder.go @@ -26,14 +26,28 @@ const ( defaultRemote = "origin" ) +type Access struct { + mode AccessMode + url string +} + +// Returns a copy of an AccessMode with the mode upgraded to newMode +func (a *Access) Upgrade(newMode AccessMode) Access { + return Access{url: a.url, mode: newMode} +} + +func (a *Access) GetMode() AccessMode { + return a.mode +} + type EndpointFinder interface { NewEndpointFromCloneURL(rawurl string) lfshttp.Endpoint NewEndpoint(rawurl string) lfshttp.Endpoint Endpoint(operation, remote string) lfshttp.Endpoint RemoteEndpoint(operation, remote string) lfshttp.Endpoint GitRemoteURL(remote string, forpush bool) string - AccessFor(rawurl string) AccessMode - SetAccess(rawurl string, access AccessMode) + AccessFor(rawurl string) Access + SetAccess(access Access) GitProtocol() string } @@ -202,39 +216,38 @@ func (e *endpointGitFinder) NewEndpoint(rawurl string) lfshttp.Endpoint { } } -func (e *endpointGitFinder) AccessFor(rawurl string) AccessMode { - if e.gitEnv == nil { - return NoneAccess - } - +func (e *endpointGitFinder) AccessFor(rawurl string) Access { accessurl := urlWithoutAuth(rawurl) + if e.gitEnv == nil { + return Access{mode: NoneAccess, url: accessurl} + } + e.accessMu.Lock() defer e.accessMu.Unlock() if cached, ok := e.urlAccess[accessurl]; ok { - return cached + return Access{mode: cached, url: accessurl} } e.urlAccess[accessurl] = e.fetchGitAccess(accessurl) - return e.urlAccess[accessurl] + return Access{mode: e.urlAccess[accessurl], url: accessurl} } -func (e *endpointGitFinder) SetAccess(rawurl string, access AccessMode) { - accessurl := urlWithoutAuth(rawurl) - key := fmt.Sprintf("lfs.%s.access", accessurl) - tracerx.Printf("setting repository access to %s", access) +func (e *endpointGitFinder) SetAccess(access Access) { + key := fmt.Sprintf("lfs.%s.access", access.url) + tracerx.Printf("setting repository access to %s", access.mode) e.accessMu.Lock() defer e.accessMu.Unlock() - switch access { + switch access.mode { case emptyAccess, NoneAccess: e.gitConfig.UnsetLocalKey(key) - e.urlAccess[accessurl] = NoneAccess + e.urlAccess[access.url] = NoneAccess default: - e.gitConfig.SetLocal(key, string(access)) - e.urlAccess[accessurl] = access + e.gitConfig.SetLocal(key, string(access.mode)) + e.urlAccess[access.url] = access.mode } } diff --git a/lfsapi/endpoint_finder_test.go b/lfsapi/endpoint_finder_test.go index 15e722b5..0e2e0772 100644 --- a/lfsapi/endpoint_finder_test.go +++ b/lfsapi/endpoint_finder_test.go @@ -304,10 +304,10 @@ func TestAccessConfig(t *testing.T) { dl := finder.Endpoint("upload", "") ul := finder.Endpoint("download", "") - if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) { + if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) { t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access) } - if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) { + if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) { t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access) } } @@ -325,10 +325,10 @@ func TestAccessConfig(t *testing.T) { dl := finder.Endpoint("upload", "") ul := finder.Endpoint("download", "") - if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) { + if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) { t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access) } - if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) { + if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) { t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access) } } @@ -336,16 +336,23 @@ func TestAccessConfig(t *testing.T) { func TestAccessAbsentConfig(t *testing.T) { finder := NewEndpointFinder(nil) - assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url)) - assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url)) + assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url).mode) + assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url).mode) } func TestSetAccess(t *testing.T) { finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{})) + url := "http://example.com" + access := finder.AccessFor(url) - assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com")) - finder.SetAccess("http://example.com", NTLMAccess) - assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com")) + assert.Equal(t, NoneAccess, access.mode) + assert.Equal(t, url, access.url) + + finder.SetAccess(access.Upgrade(NTLMAccess)) + + newAccess := finder.AccessFor(url) + assert.Equal(t, NTLMAccess, newAccess.mode) + assert.Equal(t, url, newAccess.url) } func TestChangeAccess(t *testing.T) { @@ -353,9 +360,16 @@ func TestChangeAccess(t *testing.T) { "lfs.http://example.com.access": "basic", })) - assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com")) - finder.SetAccess("http://example.com", NTLMAccess) - assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com")) + url := "http://example.com" + access := finder.AccessFor(url) + assert.Equal(t, BasicAccess, access.mode) + assert.Equal(t, url, access.url) + + finder.SetAccess(access.Upgrade(NTLMAccess)) + + newAccess := finder.AccessFor(url) + assert.Equal(t, NTLMAccess, newAccess.mode) + assert.Equal(t, url, newAccess.url) } func TestDeleteAccessWithNone(t *testing.T) { @@ -363,9 +377,17 @@ func TestDeleteAccessWithNone(t *testing.T) { "lfs.http://example.com.access": "basic", })) - assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com")) - finder.SetAccess("http://example.com", NoneAccess) - assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com")) + url := "http://example.com" + + access := finder.AccessFor(url) + assert.Equal(t, BasicAccess, access.mode) + assert.Equal(t, url, access.url) + + finder.SetAccess(access.Upgrade(NoneAccess)) + + newAccess := finder.AccessFor(url) + assert.Equal(t, NoneAccess, newAccess.mode) + assert.Equal(t, url, newAccess.url) } func TestDeleteAccessWithEmptyString(t *testing.T) { @@ -373,9 +395,17 @@ func TestDeleteAccessWithEmptyString(t *testing.T) { "lfs.http://example.com.access": "basic", })) - assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com")) - finder.SetAccess("http://example.com", AccessMode("")) - assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com")) + url := "http://example.com" + + access := finder.AccessFor(url) + assert.Equal(t, BasicAccess, access.mode) + assert.Equal(t, url, access.url) + + finder.SetAccess(access.Upgrade(AccessMode(""))) + + newAccess := finder.AccessFor(url) + assert.Equal(t, NoneAccess, newAccess.mode) + assert.Equal(t, url, newAccess.url) } type EndpointParsingTestCase struct { diff --git a/tq/transfer_queue.go b/tq/transfer_queue.go index e53f1c50..fb88715a 100644 --- a/tq/transfer_queue.go +++ b/tq/transfer_queue.go @@ -764,7 +764,8 @@ func (q *TransferQueue) ensureAdapterBegun(e lfshttp.Endpoint) error { func (q *TransferQueue) toAdapterCfg(e lfshttp.Endpoint) AdapterConfig { apiClient := q.manifest.APIClient() concurrency := q.manifest.ConcurrentTransfers() - if apiClient.Endpoints.AccessFor(e.Url) == lfsapi.NTLMAccess { + access := apiClient.Endpoints.AccessFor(e.Url) + if access.GetMode() == lfsapi.NTLMAccess { concurrency = 1 }