lfsapi: add new Access type
Add a new Access type returned by the EndpointFinder. This type contains a URL and the AccessMode associated with it. It also contains an `Upgrade()` function, which returns a copy of the Access object but with a new AccessMode.
This commit is contained in:
parent
1719c21087
commit
25990575e7
@ -25,7 +25,7 @@ func envCommand(cmd *cobra.Command, args []string) {
|
||||
endpoint := getAPIClient().Endpoints.Endpoint("download", defaultRemote)
|
||||
if len(endpoint.Url) > 0 {
|
||||
access := getAPIClient().Endpoints.AccessFor(endpoint.Url)
|
||||
Print("Endpoint=%s (auth=%s)", endpoint.Url, access)
|
||||
Print("Endpoint=%s (auth=%s)", endpoint.Url, access.GetMode())
|
||||
if len(endpoint.SshUserAndHost) > 0 {
|
||||
Print(" SSH=%s:%s", endpoint.SshUserAndHost, endpoint.SshPath)
|
||||
}
|
||||
@ -38,7 +38,7 @@ func envCommand(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
remoteEndpoint := getAPIClient().Endpoints.RemoteEndpoint("download", remote)
|
||||
remoteAccess := getAPIClient().Endpoints.AccessFor(remoteEndpoint.Url)
|
||||
Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess)
|
||||
Print("Endpoint (%s)=%s (auth=%s)", remote, remoteEndpoint.Url, remoteAccess.GetMode())
|
||||
if len(remoteEndpoint.SshUserAndHost) > 0 {
|
||||
Print(" SSH=%s:%s", remoteEndpoint.SshUserAndHost, remoteEndpoint.SshPath)
|
||||
}
|
||||
|
@ -56,8 +56,8 @@ func Environ(cfg *config.Configuration, manifest *tq.Manifest) []string {
|
||||
fmt.Sprintf("PruneVerifyRemoteAlways=%v", fetchPruneConfig.PruneVerifyRemoteAlways),
|
||||
fmt.Sprintf("PruneRemoteName=%s", fetchPruneConfig.PruneRemoteName),
|
||||
fmt.Sprintf("LfsStorageDir=%s", cfg.LFSStorageDir()),
|
||||
fmt.Sprintf("AccessDownload=%s", download),
|
||||
fmt.Sprintf("AccessUpload=%s", upload),
|
||||
fmt.Sprintf("AccessDownload=%s", download.GetMode()),
|
||||
fmt.Sprintf("AccessUpload=%s", upload.GetMode()),
|
||||
fmt.Sprintf("DownloadTransfers=%s", strings.Join(dltransfers, ",")),
|
||||
fmt.Sprintf("UploadTransfers=%s", strings.Join(ultransfers, ",")),
|
||||
)
|
||||
|
@ -31,7 +31,7 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e
|
||||
func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) {
|
||||
req.Header = c.client.ExtraHeadersFor(req)
|
||||
|
||||
apiEndpoint, access, credHelper, credsURL, creds, err := c.getCreds(remote, req)
|
||||
access, credHelper, credsURL, creds, err := c.getCreds(remote, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -39,13 +39,13 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques
|
||||
res, err := c.doWithCreds(req, credHelper, creds, credsURL, access, via)
|
||||
if err != nil {
|
||||
if errors.IsAuthError(err) {
|
||||
newAccess := getAuthAccess(res)
|
||||
if newAccess != access {
|
||||
c.Endpoints.SetAccess(apiEndpoint.Url, newAccess)
|
||||
newMode := getAuthAccess(res)
|
||||
if newMode != access.GetMode() {
|
||||
c.Endpoints.SetAccess(access.Upgrade(newMode))
|
||||
}
|
||||
|
||||
if creds != nil || (access == NoneAccess && len(req.Header.Get("Authorization")) == 0) {
|
||||
tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newAccess)
|
||||
if creds != nil || (access.GetMode() == NoneAccess && len(req.Header.Get("Authorization")) == 0) {
|
||||
tracerx.Printf("api: http response indicates %q authentication. Resubmitting...", newMode)
|
||||
if creds != nil {
|
||||
req.Header.Del("Authorization")
|
||||
credHelper.Reject(creds)
|
||||
@ -68,8 +68,8 @@ func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Reques
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access AccessMode, via []*http.Request) (*http.Response, error) {
|
||||
if access == NTLMAccess {
|
||||
func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelper, creds creds.Creds, credsURL *url.URL, access Access, via []*http.Request) (*http.Response, error) {
|
||||
if access.GetMode() == NTLMAccess {
|
||||
return c.doWithNTLM(req, credHelper, creds, credsURL)
|
||||
}
|
||||
|
||||
@ -107,7 +107,7 @@ func (c *Client) doWithCreds(req *http.Request, credHelper creds.CredentialHelpe
|
||||
// 3. The Git Remote URL, which should be something like "https://git.com/repo.git"
|
||||
// This URL is used for the Git Credential Helper. This way existing https
|
||||
// Git remote credentials can be re-used for LFS.
|
||||
func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, AccessMode, creds.CredentialHelper, *url.URL, creds.Creds, error) {
|
||||
func (c *Client) getCreds(remote string, req *http.Request) (Access, creds.CredentialHelper, *url.URL, creds.Creds, error) {
|
||||
ef := c.Endpoints
|
||||
if ef == nil {
|
||||
ef = defaultEndpointFinder
|
||||
@ -122,18 +122,18 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
|
||||
apiEndpoint := ef.Endpoint(operation, remote)
|
||||
access := ef.AccessFor(apiEndpoint.Url)
|
||||
|
||||
if access != NTLMAccess {
|
||||
if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access == NoneAccess {
|
||||
return apiEndpoint, access, creds.NullCreds, nil, nil, nil
|
||||
if access.GetMode() != NTLMAccess {
|
||||
if requestHasAuth(req) || setAuthFromNetrc(netrcFinder, req) || access.GetMode() == NoneAccess {
|
||||
return access, creds.NullCreds, nil, nil, nil
|
||||
}
|
||||
|
||||
credsURL, err := getCredURLForAPI(ef, operation, remote, apiEndpoint, req)
|
||||
if err != nil {
|
||||
return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
|
||||
return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
|
||||
}
|
||||
|
||||
if credsURL == nil {
|
||||
return apiEndpoint, access, creds.NullCreds, nil, nil, nil
|
||||
return access, creds.NullCreds, nil, nil, nil
|
||||
}
|
||||
|
||||
credHelper, creds, err := c.getGitCreds(ef, req, credsURL)
|
||||
@ -141,14 +141,14 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
|
||||
tracerx.Printf("Filled credentials for %s", credsURL)
|
||||
setRequestAuth(req, creds["username"], creds["password"])
|
||||
}
|
||||
return apiEndpoint, access, credHelper, credsURL, creds, err
|
||||
return access, credHelper, credsURL, creds, err
|
||||
}
|
||||
|
||||
// NTLM ONLY
|
||||
|
||||
credsURL, err := url.Parse(apiEndpoint.Url)
|
||||
if err != nil {
|
||||
return apiEndpoint, access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
|
||||
return access, creds.NullCreds, nil, nil, errors.Wrap(err, "creds")
|
||||
}
|
||||
|
||||
if netrcMachine := getAuthFromNetrc(netrcFinder, req); netrcMachine != nil {
|
||||
@ -160,12 +160,12 @@ func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, A
|
||||
"source": "netrc",
|
||||
}
|
||||
|
||||
return apiEndpoint, access, creds.NullCreds, credsURL, cred, nil
|
||||
return access, creds.NullCreds, credsURL, cred, nil
|
||||
}
|
||||
|
||||
// NTLM uses creds to create the session
|
||||
credHelper, creds, err := c.getGitCreds(ef, req, credsURL)
|
||||
return apiEndpoint, access, credHelper, credsURL, creds, err
|
||||
return access, credHelper, credsURL, creds, err
|
||||
}
|
||||
|
||||
func (c *Client) getGitCreds(ef EndpointFinder, req *http.Request, u *url.URL) (creds.CredentialHelper, creds.Creds, error) {
|
||||
|
@ -81,7 +81,7 @@ func TestDoWithAuthApprove(t *testing.T) {
|
||||
require.Nil(t, err)
|
||||
c.Credentials = cred
|
||||
|
||||
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs"))
|
||||
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode)
|
||||
|
||||
req, err := http.NewRequest("POST", srv.URL+"/repo/lfs/foo", nil)
|
||||
require.Nil(t, err)
|
||||
@ -99,7 +99,7 @@ func TestDoWithAuthApprove(t *testing.T) {
|
||||
"protocol": "http",
|
||||
"host": srv.Listener.Addr().String(),
|
||||
})))
|
||||
assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs"))
|
||||
assert.Equal(t, BasicAccess, c.Endpoints.AccessFor(srv.URL+"/repo/lfs").mode)
|
||||
assert.EqualValues(t, 2, called)
|
||||
}
|
||||
|
||||
@ -407,7 +407,7 @@ func TestGetCreds(t *testing.T) {
|
||||
},
|
||||
Expected: getCredsExpected{
|
||||
Access: BasicAccess,
|
||||
Endpoint: "https://user@git-server.com/repo/lfs",
|
||||
Endpoint: "https://git-server.com/repo/lfs",
|
||||
Authorization: basicAuth("user", "monkey"),
|
||||
CredsURL: "https://user@git-server.com/repo/lfs",
|
||||
Creds: map[string]string{
|
||||
@ -450,7 +450,7 @@ func TestGetCreds(t *testing.T) {
|
||||
},
|
||||
Expected: getCredsExpected{
|
||||
Access: BasicAccess,
|
||||
Endpoint: "https://user:pass@git-server.com/repo",
|
||||
Endpoint: "https://git-server.com/repo",
|
||||
Authorization: basicAuth("user", "pass"),
|
||||
},
|
||||
},
|
||||
@ -574,12 +574,12 @@ func TestGetCreds(t *testing.T) {
|
||||
client.Credentials = &fakeCredentialFiller{}
|
||||
client.Netrc = &fakeNetrc{}
|
||||
client.Endpoints = NewEndpointFinder(ctx)
|
||||
endpoint, access, _, credsURL, creds, err := client.getCreds(test.Remote, req)
|
||||
access, _, credsURL, creds, err := client.getCreds(test.Remote, req)
|
||||
if !assert.Nil(t, err) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, test.Expected.Endpoint, endpoint.Url, "endpoint")
|
||||
assert.Equal(t, test.Expected.Access, access, "access")
|
||||
assert.Equal(t, test.Expected.Endpoint, access.url, "endpoint")
|
||||
assert.Equal(t, test.Expected.Access, access.mode, "access")
|
||||
assert.Equal(t, test.Expected.Authorization, req.Header.Get("Authorization"), "authorization")
|
||||
|
||||
if test.Expected.Creds != nil {
|
||||
|
@ -26,14 +26,28 @@ const (
|
||||
defaultRemote = "origin"
|
||||
)
|
||||
|
||||
type Access struct {
|
||||
mode AccessMode
|
||||
url string
|
||||
}
|
||||
|
||||
// Returns a copy of an AccessMode with the mode upgraded to newMode
|
||||
func (a *Access) Upgrade(newMode AccessMode) Access {
|
||||
return Access{url: a.url, mode: newMode}
|
||||
}
|
||||
|
||||
func (a *Access) GetMode() AccessMode {
|
||||
return a.mode
|
||||
}
|
||||
|
||||
type EndpointFinder interface {
|
||||
NewEndpointFromCloneURL(rawurl string) lfshttp.Endpoint
|
||||
NewEndpoint(rawurl string) lfshttp.Endpoint
|
||||
Endpoint(operation, remote string) lfshttp.Endpoint
|
||||
RemoteEndpoint(operation, remote string) lfshttp.Endpoint
|
||||
GitRemoteURL(remote string, forpush bool) string
|
||||
AccessFor(rawurl string) AccessMode
|
||||
SetAccess(rawurl string, access AccessMode)
|
||||
AccessFor(rawurl string) Access
|
||||
SetAccess(access Access)
|
||||
GitProtocol() string
|
||||
}
|
||||
|
||||
@ -202,39 +216,38 @@ func (e *endpointGitFinder) NewEndpoint(rawurl string) lfshttp.Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
func (e *endpointGitFinder) AccessFor(rawurl string) AccessMode {
|
||||
if e.gitEnv == nil {
|
||||
return NoneAccess
|
||||
}
|
||||
|
||||
func (e *endpointGitFinder) AccessFor(rawurl string) Access {
|
||||
accessurl := urlWithoutAuth(rawurl)
|
||||
|
||||
if e.gitEnv == nil {
|
||||
return Access{mode: NoneAccess, url: accessurl}
|
||||
}
|
||||
|
||||
e.accessMu.Lock()
|
||||
defer e.accessMu.Unlock()
|
||||
|
||||
if cached, ok := e.urlAccess[accessurl]; ok {
|
||||
return cached
|
||||
return Access{mode: cached, url: accessurl}
|
||||
}
|
||||
|
||||
e.urlAccess[accessurl] = e.fetchGitAccess(accessurl)
|
||||
return e.urlAccess[accessurl]
|
||||
return Access{mode: e.urlAccess[accessurl], url: accessurl}
|
||||
}
|
||||
|
||||
func (e *endpointGitFinder) SetAccess(rawurl string, access AccessMode) {
|
||||
accessurl := urlWithoutAuth(rawurl)
|
||||
key := fmt.Sprintf("lfs.%s.access", accessurl)
|
||||
tracerx.Printf("setting repository access to %s", access)
|
||||
func (e *endpointGitFinder) SetAccess(access Access) {
|
||||
key := fmt.Sprintf("lfs.%s.access", access.url)
|
||||
tracerx.Printf("setting repository access to %s", access.mode)
|
||||
|
||||
e.accessMu.Lock()
|
||||
defer e.accessMu.Unlock()
|
||||
|
||||
switch access {
|
||||
switch access.mode {
|
||||
case emptyAccess, NoneAccess:
|
||||
e.gitConfig.UnsetLocalKey(key)
|
||||
e.urlAccess[accessurl] = NoneAccess
|
||||
e.urlAccess[access.url] = NoneAccess
|
||||
default:
|
||||
e.gitConfig.SetLocal(key, string(access))
|
||||
e.urlAccess[accessurl] = access
|
||||
e.gitConfig.SetLocal(key, string(access.mode))
|
||||
e.urlAccess[access.url] = access.mode
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -304,10 +304,10 @@ func TestAccessConfig(t *testing.T) {
|
||||
dl := finder.Endpoint("upload", "")
|
||||
ul := finder.Endpoint("download", "")
|
||||
|
||||
if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) {
|
||||
if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) {
|
||||
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
|
||||
}
|
||||
if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) {
|
||||
if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) {
|
||||
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
|
||||
}
|
||||
}
|
||||
@ -325,10 +325,10 @@ func TestAccessConfig(t *testing.T) {
|
||||
dl := finder.Endpoint("upload", "")
|
||||
ul := finder.Endpoint("download", "")
|
||||
|
||||
if access := finder.AccessFor(dl.Url); access != AccessMode(expected.AccessMode) {
|
||||
if access := finder.AccessFor(dl.Url); access.mode != AccessMode(expected.AccessMode) {
|
||||
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
|
||||
}
|
||||
if access := finder.AccessFor(ul.Url); access != AccessMode(expected.AccessMode) {
|
||||
if access := finder.AccessFor(ul.Url); access.mode != AccessMode(expected.AccessMode) {
|
||||
t.Errorf("Expected AccessMode() with value %q to be %v, got %v", value, expected.AccessMode, access)
|
||||
}
|
||||
}
|
||||
@ -336,16 +336,23 @@ func TestAccessConfig(t *testing.T) {
|
||||
|
||||
func TestAccessAbsentConfig(t *testing.T) {
|
||||
finder := NewEndpointFinder(nil)
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url))
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url))
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("download", "").Url).mode)
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor(finder.Endpoint("upload", "").Url).mode)
|
||||
}
|
||||
|
||||
func TestSetAccess(t *testing.T) {
|
||||
finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{}))
|
||||
url := "http://example.com"
|
||||
access := finder.AccessFor(url)
|
||||
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
|
||||
finder.SetAccess("http://example.com", NTLMAccess)
|
||||
assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com"))
|
||||
assert.Equal(t, NoneAccess, access.mode)
|
||||
assert.Equal(t, url, access.url)
|
||||
|
||||
finder.SetAccess(access.Upgrade(NTLMAccess))
|
||||
|
||||
newAccess := finder.AccessFor(url)
|
||||
assert.Equal(t, NTLMAccess, newAccess.mode)
|
||||
assert.Equal(t, url, newAccess.url)
|
||||
}
|
||||
|
||||
func TestChangeAccess(t *testing.T) {
|
||||
@ -353,9 +360,16 @@ func TestChangeAccess(t *testing.T) {
|
||||
"lfs.http://example.com.access": "basic",
|
||||
}))
|
||||
|
||||
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
|
||||
finder.SetAccess("http://example.com", NTLMAccess)
|
||||
assert.Equal(t, NTLMAccess, finder.AccessFor("http://example.com"))
|
||||
url := "http://example.com"
|
||||
access := finder.AccessFor(url)
|
||||
assert.Equal(t, BasicAccess, access.mode)
|
||||
assert.Equal(t, url, access.url)
|
||||
|
||||
finder.SetAccess(access.Upgrade(NTLMAccess))
|
||||
|
||||
newAccess := finder.AccessFor(url)
|
||||
assert.Equal(t, NTLMAccess, newAccess.mode)
|
||||
assert.Equal(t, url, newAccess.url)
|
||||
}
|
||||
|
||||
func TestDeleteAccessWithNone(t *testing.T) {
|
||||
@ -363,9 +377,17 @@ func TestDeleteAccessWithNone(t *testing.T) {
|
||||
"lfs.http://example.com.access": "basic",
|
||||
}))
|
||||
|
||||
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
|
||||
finder.SetAccess("http://example.com", NoneAccess)
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
|
||||
url := "http://example.com"
|
||||
|
||||
access := finder.AccessFor(url)
|
||||
assert.Equal(t, BasicAccess, access.mode)
|
||||
assert.Equal(t, url, access.url)
|
||||
|
||||
finder.SetAccess(access.Upgrade(NoneAccess))
|
||||
|
||||
newAccess := finder.AccessFor(url)
|
||||
assert.Equal(t, NoneAccess, newAccess.mode)
|
||||
assert.Equal(t, url, newAccess.url)
|
||||
}
|
||||
|
||||
func TestDeleteAccessWithEmptyString(t *testing.T) {
|
||||
@ -373,9 +395,17 @@ func TestDeleteAccessWithEmptyString(t *testing.T) {
|
||||
"lfs.http://example.com.access": "basic",
|
||||
}))
|
||||
|
||||
assert.Equal(t, BasicAccess, finder.AccessFor("http://example.com"))
|
||||
finder.SetAccess("http://example.com", AccessMode(""))
|
||||
assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com"))
|
||||
url := "http://example.com"
|
||||
|
||||
access := finder.AccessFor(url)
|
||||
assert.Equal(t, BasicAccess, access.mode)
|
||||
assert.Equal(t, url, access.url)
|
||||
|
||||
finder.SetAccess(access.Upgrade(AccessMode("")))
|
||||
|
||||
newAccess := finder.AccessFor(url)
|
||||
assert.Equal(t, NoneAccess, newAccess.mode)
|
||||
assert.Equal(t, url, newAccess.url)
|
||||
}
|
||||
|
||||
type EndpointParsingTestCase struct {
|
||||
|
@ -764,7 +764,8 @@ func (q *TransferQueue) ensureAdapterBegun(e lfshttp.Endpoint) error {
|
||||
func (q *TransferQueue) toAdapterCfg(e lfshttp.Endpoint) AdapterConfig {
|
||||
apiClient := q.manifest.APIClient()
|
||||
concurrency := q.manifest.ConcurrentTransfers()
|
||||
if apiClient.Endpoints.AccessFor(e.Url) == lfsapi.NTLMAccess {
|
||||
access := apiClient.Endpoints.AccessFor(e.Url)
|
||||
if access.GetMode() == lfsapi.NTLMAccess {
|
||||
concurrency = 1
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user