lfsapi: add new Access type

Add a new Access type returned by the EndpointFinder. This type
contains a URL and the AccessMode associated with it. It also
contains an `Upgrade()` function, which returns a copy of the
Access object but with a new AccessMode.
This commit is contained in:
Preben Ingvaldsen 2018-09-24 11:50:09 -07:00
parent 1719c21087
commit 25990575e7
7 changed files with 109 additions and 65 deletions

@ -25,7 +25,7 @@ func envCommand(cmd *cobra.Command, args []string) {
endpoint := getAPIClient().Endpoints.Endpoint("download", defaultRemote)
if len(endpoint.Url) > 0 {
access := getAPIClient().Endpoints.AccessFor(endpoint.Url)
Print("Endpoint=%s (auth=%s)", endpoint.Url, access)
Print("Endpoint=%s (auth=%s)", endpoint.Url, access.GetMode())
if len(endpoint.SshUserAndHost) > 0 {
Print(" SSH=%s:%s", endpoint.SshUserAndHost, endpoint.SshPath)
}
@ -38,7 +38,7 @@ func envCommand(cmd *cobra.Command, args []string) {
}
remoteEndpoint := getAPIClient().Endpoints.RemoteEndpoint("download", remote)
remoteAccess := getAPIClient().Endpoints.AccessFor(remoteEndpoint.Url)
Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess)
Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess.GetMode())
if len(remoteEndpoint.SshUserAndHost) > 0 {
Print(" SSH=%s:%s", remoteEndpoint.SshUserAndHost, remoteEndpoint.SshPath)
}

@ -56,8 +56,8 @@ func Environ(cfg *config.Configuration, manifest *tq.Manifest) []string {
fmt.Sprintf("PruneVerifyRemoteAlways=%v", fetchPruneConfig.PruneVerifyRemoteAlways),
fmt.Sprintf("PruneRemoteName=%s", fetchPruneConfig.PruneRemoteName),
fmt.Sprintf("LfsStorageDir=%s", cfg.LFSStorageDir()),
fmt.Sprintf("AccessDownload=%s", download),
fmt.Sprintf("AccessUpload=%s", upload),
fmt.Sprintf("AccessDownload=%s", download.GetMode()),
fmt.Sprintf("AccessUpload=%s", upload.GetMode()),
fmt.Sprintf("DownloadTransfers=%s", strings.Join(dltransfers, ",")),
fmt.Sprintf("UploadTransfers=%s", strings.Join(ultransfers, ",")),
)

@ -31,7 +31,7 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e
func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) {
req.Header = c.client.ExtraHeadersFor(req)
apiEndpoint, access, credHelper, credsURL, creds, err := c.getCreds(remote, req)
access, credHelper, credsURL, creds, err := c.getCreds(remote, req)
if err != nil {
return nil, err
}
@ -39,13 +39,13 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques
res, err := c.doWithCreds(req, credHelper, creds, credsURL, access, via)
if err != nil {
if errors.IsAuthError(err) {
newAccess := getAuthAccess(res)
if newAccess != access {
c.Endpoints.SetAccess(apiEndpoint.Url, newAccess)
newMode := getAuthAccess(res)
if newMode != access.GetMode() {
c.Endpoints.SetAccess(access.Upgrade(newMode))
}
if creds != nil || (access == NoneAccess && len(req.Header.Get("Authorization")) == 0) {
tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newAccess)
if creds != nil || (access.GetMode() == NoneAccess && len(req.Header.Get("Authorization")) == 0) {
tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newMode)
if creds != nil {
req.Header.Del("Authorization")
credHelper.Reject(creds)
@ -68,8 +68,8 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques
return res, err
}
func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access AccessMode, via []*http.Request) (*http.Response, error) {
if access == NTLMAccess {
func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access Access, via []*http.Request) (*http.Response, error) {
if access.GetMode() == NTLMAccess {
return c.doWithNTLM(req, credHelper, creds, credsURL)
}
@ -107,7 +107,7 @@ func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelpe
// 3. The Git Remote URL, which should be something like "https://git.com/repo.git"
// This URL is used for the Git Credential Helper. This way existing https
// Git remote credentials can be re-used for LFS.
func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, AccessMode, creds.CredentialHelper, *url.URL, creds.Creds, error) {
func (c *Client) getCreds(remote string, req *http.Request) (Access, creds.CredentialHelper, *url.URL, creds.Creds, error) {
ef := c.Endpoints
if ef == nil {
ef = defaultEndpointFinder
@ -122,18 +122,18 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
apiEndpoint := ef.Endpoint(operation, remote)
access := ef.AccessFor(apiEndpoint.Url)
if access != NTLMAccess {
if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access == NoneAccess {
return apiEndpoint, access, creds.NullCreds, nil, nil, nil
if access.GetMode() != NTLMAccess {
if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access.GetMode() == NoneAccess {
return access, creds.NullCreds, nil, nil, nil
}
credsURL, err := getCredURLForAPI(ef, operation, remote, apiEndpoint, req)
if err != nil {
return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
}
if credsURL == nil {
return apiEndpoint, access, creds.NullCreds, nil, nil, nil
return access, creds.NullCreds, nil, nil, nil
}
credHelper, creds, err := c.getGitCreds(ef, req, credsURL)
@ -141,14 +141,14 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
tracerx.Printf("Filled credentials for %s", credsURL)
setRequestAuth(req, creds["username"], creds["password"])
}
return apiEndpoint, access, credHelper, credsURL, creds, err
return access, credHelper, credsURL, creds, err
}
// NTLM ONLY
credsURL, err := url.Parse(apiEndpoint.Url)
if err != nil {
return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
}
if netrcMachine := getAuthFromNetrc(netrcFinder, req); netrcMachine != nil {
@ -160,12 +160,12 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
"source": "netrc",
}
return apiEndpoint, access, creds.NullCreds, credsURL, cred, nil
return access, creds.NullCreds, credsURL, cred, nil
}
// NTLM uses creds to create the session
credHelper, creds, err := c.getGitCreds(ef, req, credsURL)
return apiEndpoint, access, credHelper, credsURL, creds, err
return access, credHelper, credsURL, creds, err
}
func (c *Client) getGitCreds(ef EndpointFinder, req *http.Request, u *url.URL) (creds.CredentialHelper, creds.Creds, error) {

@ -81,7 +81,7 @@ func TestDoWithAuthApprove(t *testing.T) {
require.Nil(t, err)
c.Credentials = cred
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs"))
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode)
req, err := http.NewRequest("POST", srv.URL+"/repo/lfs/foo", nil)
require.Nil(t, err)
@ -99,7 +99,7 @@ func TestDoWithAuthApprove(t *testing.T) {
"protocol": "http",
"host": srv.Listener.Addr().String(),
})))
assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs"))
assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode)
assert.EqualValues(t, 2, called)
}
@ -407,7 +407,7 @@ func TestGetCreds(t *testing.T) {
},
Expected: getCredsExpected{
Access: BasicAccess,
Endpoint: "https://user@git-server.com/repo/lfs",
Endpoint: "https://git-server.com/repo/lfs",
Authorization: basicAuth("user", "monkey"),
CredsURL: "https://user@git-server.com/repo/lfs",
Creds: map[string]string{
@ -450,7 +450,7 @@ func TestGetCreds(t *testing.T) {
},
Expected: getCredsExpected{
Access: BasicAccess,
Endpoint: "https://user:pass@git-server.com/repo",
Endpoint: "https://git-server.com/repo",
Authorization: basicAuth("user", "pass"),
},
},
@ -574,12 +574,12 @@ func TestGetCreds(t *testing.T) {
client.Credentials = &fakeCredentialFiller{}
client.Netrc = &fakeNetrc{}
client.Endpoints = NewEndpointFinder(ctx)
endpoint, access, _, credsURL, creds, err := client.getCreds(test.Remote, req)
access, _, credsURL, creds, err := client.getCreds(test.Remote, req)
if !assert.Nil(t, err) {
continue
}
assert.Equal(t, test.Expected.Endpoint, endpoint.Url, "endpoint")
assert.Equal(t, test.Expected.Access, access, "access")
assert.Equal(t, test.Expected.Endpoint, access.url, "endpoint")
assert.Equal(t, test.Expected.Access, access.mode, "access")
assert.Equal(t, test.Expected.Authorization, req.Header.Get("Authorization"), "authorization")
if test.Expected.Creds != nil {

@ -26,14 +26,28 @@ const (
defaultRemote = "origin"
)
type Access struct {
mode AccessMode
url string
}
// Returns a copy of an AccessMode with the mode upgraded to newMode
func (a *Access) Upgrade(newMode AccessMode) Access {
return Access{url: a.url, mode: newMode}
}
func (a *Access) GetMode() AccessMode {
return a.mode
}
type EndpointFinder interface {
NewEndpointFromCloneURL(rawurl string) lfshttp.Endpoint
NewEndpoint(rawurl string) lfshttp.Endpoint
Endpoint(operation, remote string) lfshttp.Endpoint
RemoteEndpoint(operation, remote string) lfshttp.Endpoint
GitRemoteURL(remote string, forpush bool) string
AccessFor(rawurl string) AccessMode
SetAccess(rawurl string, access AccessMode)
AccessFor(rawurl string) Access
SetAccess(access Access)
GitProtocol() string
}
@ -202,39 +216,38 @@ func (e *endpointGitFinder) NewEndpoint(rawurl string) lfshttp.Endpoint {
}
}
func (e *endpointGitFinder) AccessFor(rawurl string) AccessMode {
if e.gitEnv == nil {
return NoneAccess
}
func (e *endpointGitFinder) AccessFor(rawurl string) Access {
accessurl := urlWithoutAuth(rawurl)
if e.gitEnv == nil {
return Access{mode: NoneAccess, url: accessurl}
}
e.accessMu.Lock()
defer e.accessMu.Unlock()
if cached, ok := e.urlAccess[accessurl]; ok {
return cached
return Access{mode: cached, url: accessurl}
}
e.urlAccess[accessurl] = e.fetchGitAccess(accessurl)
return e.urlAccess[accessurl]
return Access{mode: e.urlAccess[accessurl], url: accessurl}
}
func (e *endpointGitFinder) SetAccess(rawurl string, access AccessMode) {
accessurl := urlWithoutAuth(rawurl)
key := fmt.Sprintf("lfs.%s.access", accessurl)
tracerx.Printf("setting repository access to %s", access)
func (e *endpointGitFinder) SetAccess(access Access) {
key := fmt.Sprintf("lfs.%s.access", access.url)
tracerx.Printf("setting repository access to %s", access.mode)
e.accessMu.Lock()
defer e.accessMu.Unlock()
switch access {
switch access.mode {
case emptyAccess, NoneAccess:
e.gitConfig.UnsetLocalKey(key)
e.urlAccess[accessurl] = NoneAccess
e.urlAccess[access.url] = NoneAccess
default:
e.gitConfig.SetLocal(key, string(access))
e.urlAccess[accessurl] = access
e.gitConfig.SetLocal(key, string(access.mode))
e.urlAccess[access.url] = access.mode
}
}

@ -304,10 +304,10 @@ func TestAccessConfig(t *testing.T) {
dl := finder.Endpoint("upload", "")
ul := finder.Endpoint("download", "")
if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) {
if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) {
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
}
if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) {
if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) {
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
}
}
@ -325,10 +325,10 @@ func TestAccessConfig(t *testing.T) {
dl := finder.Endpoint("upload", "")
ul := finder.Endpoint("download", "")
if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) {
if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) {
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
}
if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) {
if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) {
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
}
}
@ -336,16 +336,23 @@ func TestAccessConfig(t *testing.T) {
func TestAccessAbsentConfig(t *testing.T) {
finder := NewEndpointFinder(nil)
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url))
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url))
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url).mode)
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url).mode)
}
func TestSetAccess(t *testing.T) {
finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{}))
url := "http://example.com"
access := finder.AccessFor(url)
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
finder.SetAccess("http://example.com", NTLMAccess)
assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com"))
assert.Equal(t, NoneAccess, access.mode)
assert.Equal(t, url, access.url)
finder.SetAccess(access.Upgrade(NTLMAccess))
newAccess := finder.AccessFor(url)
assert.Equal(t, NTLMAccess, newAccess.mode)
assert.Equal(t, url, newAccess.url)
}
func TestChangeAccess(t *testing.T) {
@ -353,9 +360,16 @@ func TestChangeAccess(t *testing.T) {
"lfs.http://example.com.access": "basic",
}))
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
finder.SetAccess("http://example.com", NTLMAccess)
assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com"))
url := "http://example.com"
access := finder.AccessFor(url)
assert.Equal(t, BasicAccess, access.mode)
assert.Equal(t, url, access.url)
finder.SetAccess(access.Upgrade(NTLMAccess))
newAccess := finder.AccessFor(url)
assert.Equal(t, NTLMAccess, newAccess.mode)
assert.Equal(t, url, newAccess.url)
}
func TestDeleteAccessWithNone(t *testing.T) {
@ -363,9 +377,17 @@ func TestDeleteAccessWithNone(t *testing.T) {
"lfs.http://example.com.access": "basic",
}))
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
finder.SetAccess("http://example.com", NoneAccess)
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
url := "http://example.com"
access := finder.AccessFor(url)
assert.Equal(t, BasicAccess, access.mode)
assert.Equal(t, url, access.url)
finder.SetAccess(access.Upgrade(NoneAccess))
newAccess := finder.AccessFor(url)
assert.Equal(t, NoneAccess, newAccess.mode)
assert.Equal(t, url, newAccess.url)
}
func TestDeleteAccessWithEmptyString(t *testing.T) {
@ -373,9 +395,17 @@ func TestDeleteAccessWithEmptyString(t *testing.T) {
"lfs.http://example.com.access": "basic",
}))
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
finder.SetAccess("http://example.com", AccessMode(""))
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
url := "http://example.com"
access := finder.AccessFor(url)
assert.Equal(t, BasicAccess, access.mode)
assert.Equal(t, url, access.url)
finder.SetAccess(access.Upgrade(AccessMode("")))
newAccess := finder.AccessFor(url)
assert.Equal(t, NoneAccess, newAccess.mode)
assert.Equal(t, url, newAccess.url)
}
type EndpointParsingTestCase struct {

@ -764,7 +764,8 @@ func (q *TransferQueue) ensureAdapterBegun(e lfshttp.Endpoint) error {
func (q *TransferQueue) toAdapterCfg(e lfshttp.Endpoint) AdapterConfig {
apiClient := q.manifest.APIClient()
concurrency := q.manifest.ConcurrentTransfers()
if apiClient.Endpoints.AccessFor(e.Url) == lfsapi.NTLMAccess {
access := apiClient.Endpoints.AccessFor(e.Url)
if access.GetMode() == lfsapi.NTLMAccess {
concurrency = 1
}