diff --git a/lfs/client.go b/lfs/client.go index 767568e1..4ae02b9e 100644 --- a/lfs/client.go +++ b/lfs/client.go @@ -55,18 +55,18 @@ type objectResource struct { Error *objectError `json:"error,omitempty"` } -func (o *objectResource) NewRequest(relation, method string) (*http.Request, Creds, error) { +func (o *objectResource) NewRequest(relation, method string) (*http.Request, error) { rel, ok := o.Rel(relation) if !ok { - return nil, nil, errors.New("relation does not exist") + return nil, errors.New("relation does not exist") } - req, creds, err := newClientRequest(method, rel.Href, rel.Header) + req, err := newClientRequest(method, rel.Href, rel.Header) if err != nil { - return nil, nil, err + return nil, err } - return req, creds, nil + return req, nil } func (o *objectResource) Rel(name string) (*linkRelation, bool) { @@ -105,23 +105,22 @@ func (e *ClientError) Error() string { } func Download(oid string) (io.ReadCloser, int64, error) { - req, creds, err := newApiRequest("GET", oid) + req, err := newApiRequest("GET", oid) if err != nil { return nil, 0, Error(err) } - res, obj, err := doApiRequest(req, creds) + res, obj, err := doLegacyApiRequest(req) if err != nil { return nil, 0, err } LogTransfer("lfs.api.download", res) - - req, creds, err = obj.NewRequest("download", "GET") + req, err = obj.NewRequest("download", "GET") if err != nil { return nil, 0, Error(err) } - res, err = doHttpRequest(req, creds) + res, err = doStorageRequest(req) if err != nil { return nil, 0, err } @@ -135,18 +134,18 @@ type byteCloser struct { } func DownloadCheck(oid string) (*objectResource, error) { - req, creds, err := newApiRequest("GET", oid) + req, err := newApiRequest("GET", oid) if err != nil { return nil, Error(err) } - res, obj, err := doApiRequest(req, creds) + res, obj, err := doLegacyApiRequest(req) if err != nil { return nil, err } LogTransfer("lfs.api.download", res) - _, _, err = obj.NewRequest("download", "GET") + _, err = obj.NewRequest("download", "GET") if err != nil { return nil, Error(err) } @@ -155,12 +154,12 @@ func DownloadCheck(oid string) (*objectResource, error) { } func DownloadObject(obj *objectResource) (io.ReadCloser, int64, error) { - req, creds, err := obj.NewRequest("download", "GET") + req, err := obj.NewRequest("download", "GET") if err != nil { return nil, 0, Error(err) } - res, err := doHttpRequest(req, creds) + res, err := doStorageRequest(req) if err != nil { return nil, 0, err } @@ -185,7 +184,7 @@ func Batch(objects []*objectResource, operation string) ([]*objectResource, erro return nil, Error(err) } - req, creds, err := newBatchApiRequest() + req, err := newBatchApiRequest() if err != nil { return nil, Error(err) } @@ -196,7 +195,8 @@ func Batch(objects []*objectResource, operation string) ([]*objectResource, erro req.Body = &byteCloser{bytes.NewReader(by)} tracerx.Printf("api: batch %d files", len(objects)) - res, objs, err := doApiBatchRequest(req, creds) + + res, objs, err := doApiBatchRequest(req) if err != nil { if res == nil { return nil, err @@ -241,7 +241,7 @@ func UploadCheck(oidPath string) (*objectResource, error) { return nil, Error(err) } - req, creds, err := newApiRequest("POST", oid) + req, err := newApiRequest("POST", oid) if err != nil { return nil, Error(err) } @@ -252,7 +252,7 @@ func UploadCheck(oidPath string) (*objectResource, error) { req.Body = &byteCloser{bytes.NewReader(by)} tracerx.Printf("api: uploading (%s)", oid) - res, obj, err := doApiRequest(req, creds) + res, obj, err := doLegacyApiRequest(req) if err != nil { return nil, err } @@ -290,7 +290,7 @@ func UploadObject(o *objectResource, cb CopyCallback) error { Reader: file, } - req, creds, err := o.NewRequest("upload", "PUT") + req, err := o.NewRequest("upload", "PUT") if err != nil { return Error(err) } @@ -298,16 +298,17 @@ func UploadObject(o *objectResource, cb CopyCallback) error { if len(req.Header.Get("Content-Type")) == 0 { req.Header.Set("Content-Type", "application/octet-stream") } + if req.Header.Get("Transfer-Encoding") == "chunked" { req.TransferEncoding = []string{"chunked"} } else { req.Header.Set("Content-Length", strconv.FormatInt(o.Size, 10)) } - req.ContentLength = o.Size + req.ContentLength = o.Size req.Body = ioutil.NopCloser(reader) - res, err := doHttpRequest(req, creds) + res, err := doStorageRequest(req) if err != nil { return err } @@ -324,7 +325,7 @@ func UploadObject(o *objectResource, cb CopyCallback) error { return nil } - req, creds, err = o.NewRequest("verify", "POST") + req, err = o.NewRequest("verify", "POST") if err != nil { return Error(err) } @@ -338,7 +339,7 @@ func UploadObject(o *objectResource, cb CopyCallback) error { req.Header.Set("Content-Length", strconv.Itoa(len(by))) req.ContentLength = int64(len(by)) req.Body = ioutil.NopCloser(bytes.NewReader(by)) - res, err = doHttpRequest(req, creds) + res, err = doAPIRequest(req) if err != nil { return err } @@ -350,6 +351,72 @@ func UploadObject(o *objectResource, cb CopyCallback) error { return err } +// doLegacyApiRequest runs the request to the LFS legacy API. +func doLegacyApiRequest(req *http.Request) (*http.Response, *objectResource, error) { + via := make([]*http.Request, 0, 4) + res, err := doApiRequestWithRedirects(req, via, true) + if err != nil { + return res, nil, err + } + + obj := &objectResource{} + err = decodeApiResponse(res, obj) + + if err != nil { + setErrorResponseContext(err, res) + return nil, nil, err + } + + return res, obj, nil +} + +// doApiBatchRequest runs the request to the LFS batch API. If the API returns a +// 401, the repo will be marked as having private access and the request will be +// re-run. When the repo is marked as having private access, credentials will +// be retrieved. +func doApiBatchRequest(req *http.Request) (*http.Response, []*objectResource, error) { + res, err := doAPIRequest(req) + + if err != nil { + return res, nil, err + } + + var objs map[string][]*objectResource + err = decodeApiResponse(res, &objs) + + if err != nil { + setErrorResponseContext(err, res) + } + + return res, objs["objects"], err +} + +// doStorageREquest runs the request to the storage API from a link provided by +// the "actions" or "_links" properties an LFS API response. +func doStorageRequest(req *http.Request) (*http.Response, error) { + creds, err := getCreds(req) + if err != nil { + return nil, err + } + + return doHttpRequest(req, creds) +} + +// doAPIRequest runs the request to the LFS API, without parsing the response +// body. If the API returns a 401, the repo will be marked as having private +// access and the request will be re-run. When the repo is marked as having +// private access, credentials will be retrieved. +func doAPIRequest(req *http.Request) (*http.Response, error) { + via := make([]*http.Request, 0, 4) + useCreds := true + if req.Method == "GET" || req.Method == "HEAD" { + useCreds = Config.PrivateAccess() + } + return doApiRequestWithRedirects(req, via, useCreds) +} + +// doHttpRequest runs the given HTTP request. LFS or Storage API requests should +// use doApiBatchRequest() or doStorageRequest() instead. func doHttpRequest(req *http.Request, creds Creds) (*http.Response, error) { res, err := Config.HttpClient().Do(req) if res == nil { @@ -364,8 +431,7 @@ func doHttpRequest(req *http.Request, creds Creds) (*http.Response, error) { if err != nil { err = Errorf(err, "Error for %s %s", res.Request.Method, res.Request.URL) } else { - saveCredentials(creds, res) - err = handleResponse(res) + err = handleResponse(res, creds) } if err != nil { @@ -379,7 +445,16 @@ func doHttpRequest(req *http.Request, creds Creds) (*http.Response, error) { return res, err } -func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Request) (*http.Response, error) { +func doApiRequestWithRedirects(req *http.Request, via []*http.Request, useCreds bool) (*http.Response, error) { + var creds Creds + if useCreds { + c, err := getCredsForAPI(req) + if err != nil { + return nil, err + } + creds = c + } + res, err := doHttpRequest(req, creds) if err != nil { return res, err @@ -393,7 +468,7 @@ func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Reque redirectTo = locurl.String() } - redirectedReq, redirectedCreds, err := newClientRequest(req.Method, redirectTo, nil) + redirectedReq, err := newClientRequest(req.Method, redirectTo, nil) if err != nil { return res, Errorf(err, err.Error()) } @@ -421,53 +496,15 @@ func doApiRequestWithRedirects(req *http.Request, creds Creds, via []*http.Reque return res, Errorf(err, err.Error()) } - return doApiRequestWithRedirects(redirectedReq, redirectedCreds, via) + return doApiRequestWithRedirects(redirectedReq, via, useCreds) } return res, nil } -func doApiRequest(req *http.Request, creds Creds) (*http.Response, *objectResource, error) { - via := make([]*http.Request, 0, 4) - res, err := doApiRequestWithRedirects(req, creds, via) - if err != nil { - return res, nil, err - } +func handleResponse(res *http.Response, creds Creds) error { + saveCredentials(creds, res) - obj := &objectResource{} - err = decodeApiResponse(res, obj) - - if err != nil { - setErrorResponseContext(err, res) - return nil, nil, err - } - - return res, obj, nil -} - -// doApiBatchRequest runs the request to the batch API. If the API returns a 401, -// the repo will be marked as having private access and the request will be -// re-run. When the repo is marked as having private access, credentials will -// be retrieved. -func doApiBatchRequest(req *http.Request, creds Creds) (*http.Response, []*objectResource, error) { - via := make([]*http.Request, 0, 4) - res, err := doApiRequestWithRedirects(req, creds, via) - - if err != nil { - return res, nil, err - } - - var objs map[string][]*objectResource - err = decodeApiResponse(res, &objs) - - if err != nil { - setErrorResponseContext(err, res) - } - - return res, objs["objects"], err -} - -func handleResponse(res *http.Response) error { if res.StatusCode < 400 { return nil } @@ -525,19 +562,7 @@ func defaultError(res *http.Response) error { return Error(fmt.Errorf(msgFmt, res.Request.URL)) } -func saveCredentials(creds Creds, res *http.Response) { - if creds == nil { - return - } - - if res.StatusCode < 300 { - execCreds(creds, "approve") - } else if res.StatusCode == 401 { - execCreds(creds, "reject") - } -} - -func newApiRequest(method, oid string) (*http.Request, Creds, error) { +func newApiRequest(method, oid string) (*http.Request, error) { endpoint := Config.Endpoint() objectOid := oid operation := "download" @@ -561,22 +586,22 @@ func newApiRequest(method, oid string) (*http.Request, Creds, error) { u, err := ObjectUrl(endpoint, objectOid) if err != nil { - return nil, nil, err + return nil, err } - req, creds, err := newClientRequest(method, u.String(), res.Header) + req, err := newClientRequest(method, u.String(), res.Header) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("Accept", mediaType) - return req, creds, nil + return req, nil } -func newClientRequest(method, rawurl string, header map[string]string) (*http.Request, Creds, error) { +func newClientRequest(method, rawurl string, header map[string]string) (*http.Request, error) { req, err := http.NewRequest(method, rawurl, nil) if err != nil { - return nil, nil, err + return nil, err } for key, value := range header { @@ -584,15 +609,11 @@ func newClientRequest(method, rawurl string, header map[string]string) (*http.Re } req.Header.Set("User-Agent", UserAgent) - creds, err := getCreds(req) - if err != nil { - return nil, nil, err - } - return req, creds, nil + return req, nil } -func newBatchApiRequest() (*http.Request, Creds, error) { +func newBatchApiRequest() (*http.Request, error) { endpoint := Config.Endpoint() res, err := sshAuthenticate(endpoint, "download", "") @@ -608,12 +629,12 @@ func newBatchApiRequest() (*http.Request, Creds, error) { u, err := ObjectUrl(endpoint, "batch") if err != nil { - return nil, nil, err + return nil, err } - req, creds, err := newBatchClientRequest("POST", u.String()) + req, err := newBatchClientRequest("POST", u.String()) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("Accept", mediaType) @@ -623,78 +644,18 @@ func newBatchApiRequest() (*http.Request, Creds, error) { } } - return req, creds, nil + return req, nil } -func newBatchClientRequest(method, rawurl string) (*http.Request, Creds, error) { +func newBatchClientRequest(method, rawurl string) (*http.Request, error) { req, err := http.NewRequest(method, rawurl, nil) if err != nil { - return nil, nil, err + return nil, err } req.Header.Set("User-Agent", UserAgent) - // Get the creds if we're private - if Config.PrivateAccess() { - // The PrivateAccess() check can be pushed down and this block simplified - // once everything goes through the batch endpoint. - creds, err := getCreds(req) - if err != nil { - return nil, nil, err - } - - return req, creds, nil - } - - return req, nil, nil -} - -func getCreds(req *http.Request) (Creds, error) { - if len(req.Header.Get("Authorization")) > 0 { - return nil, nil - } - - apiUrl, err := Config.ObjectUrl("") - if err != nil { - return nil, err - } - - if req.URL.Scheme != apiUrl.Scheme || - req.URL.Host != apiUrl.Host { - return nil, nil - } - - if setRequestAuthFromUrl(req, apiUrl) { - return nil, nil - } - - credsUrl := apiUrl - if len(Config.CurrentRemote) > 0 { - if u, ok := Config.GitConfig("remote." + Config.CurrentRemote + ".url"); ok { - gitRemoteUrl, err := url.Parse(u) - if err != nil { - return nil, err - } - - if gitRemoteUrl.Scheme == apiUrl.Scheme && - gitRemoteUrl.Host == apiUrl.Host { - - if setRequestAuthFromUrl(req, gitRemoteUrl) { - return nil, nil - } - - credsUrl = gitRemoteUrl - } - } - } - - creds, err := credentials(credsUrl) - if err != nil { - return nil, err - } - - setRequestAuth(req, creds["username"], creds["password"]) - return creds, nil + return req, nil } func setRequestAuthFromUrl(req *http.Request, u *url.URL) bool { @@ -710,6 +671,10 @@ func setRequestAuthFromUrl(req *http.Request, u *url.URL) bool { } func setRequestAuth(req *http.Request, user, pass string) { + if len(user) == 0 && len(pass) == 0 { + return + } + token := fmt.Sprintf("%s:%s", user, pass) auth := "Basic " + base64.URLEncoding.EncodeToString([]byte(token)) req.Header.Set("Authorization", auth) diff --git a/lfs/client_error_test.go b/lfs/client_error_test.go index e4091666..cdf47599 100644 --- a/lfs/client_error_test.go +++ b/lfs/client_error_test.go @@ -13,7 +13,7 @@ import ( func TestSuccessStatus(t *testing.T) { for _, status := range []int{200, 201, 202} { res := &http.Response{StatusCode: status} - if err := handleResponse(res); err != nil { + if err := handleResponse(res, nil); err != nil { t.Errorf("Unexpected error for HTTP %d: %s", status, err.Error()) } } @@ -59,7 +59,7 @@ func TestErrorStatusWithCustomMessage(t *testing.T) { } res.Header.Set("Content-Type", "application/vnd.git-lfs+json; charset=utf-8") - err = handleResponse(res) + err = handleResponse(res, nil) if err == nil { t.Errorf("No error from HTTP %d", status) continue @@ -121,7 +121,7 @@ func TestErrorStatusWithDefaultMessage(t *testing.T) { // purposely wrong content type so it falls back to default res.Header.Set("Content-Type", "application/vnd.git-lfs+json2") - err = handleResponse(res) + err = handleResponse(res, nil) if err == nil { t.Errorf("No error from HTTP %d", status) continue diff --git a/lfs/config.go b/lfs/config.go index b1739b68..fcfb9a97 100644 --- a/lfs/config.go +++ b/lfs/config.go @@ -53,6 +53,7 @@ type Configuration struct { loading sync.Mutex // guards initialization of gitConfig and remotes gitConfig map[string]string + origConfig map[string]string remotes []string extensions map[string]Extension fetchIncludePaths []string @@ -234,11 +235,6 @@ func (c *Configuration) GitConfig(key string) (string, bool) { return value, ok } -func (c *Configuration) SetConfig(key, value string) { - c.loadGitConfig() - c.gitConfig[key] = value -} - func (c *Configuration) ObjectUrl(oid string) (*url.URL, error) { return ObjectUrl(c.Endpoint(), oid) } @@ -290,12 +286,12 @@ func (c *Configuration) FetchPruneConfig() *FetchPruneConfig { return c.fetchPruneConfig } -func (c *Configuration) loadGitConfig() { +func (c *Configuration) loadGitConfig() bool { c.loading.Lock() defer c.loading.Unlock() if c.gitConfig != nil { - return + return false } uniqRemotes := make(map[string]bool) @@ -373,4 +369,6 @@ func (c *Configuration) loadGitConfig() { } c.remotes = append(c.remotes, remote) } + + return true } diff --git a/lfs/config_test.go b/lfs/config_test.go index 580faabb..057486c6 100644 --- a/lfs/config_test.go +++ b/lfs/config_test.go @@ -184,6 +184,7 @@ func TestBareHTTPEndpointAddsLfsSuffix(t *testing.T) { } func TestObjectUrl(t *testing.T) { + defer Config.ResetConfig() tests := map[string]string{ "http://example.com": "http://example.com/objects/oid", "http://example.com/": "http://example.com/objects/oid", @@ -205,6 +206,8 @@ func TestObjectUrl(t *testing.T) { } func TestObjectsUrl(t *testing.T) { + defer Config.ResetConfig() + tests := map[string]string{ "http://example.com": "http://example.com/objects", "http://example.com/": "http://example.com/objects", @@ -391,6 +394,7 @@ func TestLoadInvalidExtension(t *testing.T) { assert.Equal(t, "", ext.Smudge) assert.Equal(t, 0, ext.Priority) } + func TestFetchPruneConfigDefault(t *testing.T) { config := &Configuration{} fp := config.FetchPruneConfig() @@ -416,5 +420,27 @@ func TestFetchPruneConfigCustom(t *testing.T) { assert.Equal(t, 9, fp.FetchRecentCommitsDays) assert.Equal(t, 30, fp.PruneOffsetDays) assert.Equal(t, false, fp.FetchRecentRefsIncludeRemotes) - +} + +// only used for tests +func (c *Configuration) SetConfig(key, value string) { + if c.loadGitConfig() { + c.loading.Lock() + c.origConfig = make(map[string]string) + for k, v := range c.gitConfig { + c.origConfig[k] = v + } + c.loading.Unlock() + } + + c.gitConfig[key] = value +} + +func (c *Configuration) ResetConfig() { + c.loading.Lock() + c.gitConfig = make(map[string]string) + for k, v := range c.origConfig { + c.gitConfig[k] = v + } + c.loading.Unlock() } diff --git a/lfs/credentials.go b/lfs/credentials.go index 7f595a6c..382906d7 100644 --- a/lfs/credentials.go +++ b/lfs/credentials.go @@ -3,69 +3,124 @@ package lfs import ( "bytes" "fmt" + "net/http" "net/url" "os/exec" "strings" ) -type credentialFetcher interface { - Credentials() Creds +// getCreds gets the credentials for the given request's URL, and sets its +// Authorization header with them using Basic Authentication. This is like +// getCredsForAPI(), but skips checking the LFS url or git remote. +func getCreds(req *http.Request) (Creds, error) { + if len(req.Header.Get("Authorization")) > 0 { + return nil, nil + } + + creds, err := fillCredentials(req.URL) + if err != nil { + return nil, Error(err) + } + + setRequestAuth(req, creds["username"], creds["password"]) + return creds, nil } -type credentialFunc func(Creds, string) (credentialFetcher, error) +// getCredsForAPI gets the credentials for LFS API requests and sets the given +// request's Authorization header with them using Basic Authentication. +// 1. Check the LFS URL for authentication. Ex: http://user:pass@example.com +// 2. Check the Git remote URL for authentication IF it's the same scheme and +// host of the LFS URL. +// 3. Ask 'git credential' to fill in the password from one of the above URLs. +// +// This prefers the Git remote URL for checking credentials so that users only +// have to enter their passwords once for Git and Git LFS. It uses the same +// URL path that Git does, in case 'useHttpPath' is enabled in the Git config. +func getCredsForAPI(req *http.Request) (Creds, error) { + if len(req.Header.Get("Authorization")) > 0 { + return nil, nil + } -var execCreds credentialFunc + credsUrl, err := getCredURLForAPI(req) + if err != nil { + return nil, Error(err) + } -func credentials(u *url.URL) (Creds, error) { - path := strings.TrimPrefix(u.Path, "/") - creds := Creds{"protocol": u.Scheme, "host": u.Host, "path": path} - cmd, err := execCreds(creds, "fill") + if credsUrl == nil { + return nil, nil + } + + creds, err := fillCredentials(credsUrl) + if err != nil { + return nil, Error(err) + } + + if creds != nil { + setRequestAuth(req, creds["username"], creds["password"]) + } + + return creds, nil +} + +func getCredURLForAPI(req *http.Request) (*url.URL, error) { + apiUrl, err := Config.ObjectUrl("") if err != nil { return nil, err } - return cmd.Credentials(), nil -} -type CredentialCmd struct { - output *bytes.Buffer - SubCommand string - *exec.Cmd -} - -func NewCommand(input Creds, subCommand string) *CredentialCmd { - buf1 := new(bytes.Buffer) - cmd := exec.Command("git", "credential", subCommand) - - cmd.Stdin = input.Buffer() - cmd.Stdout = buf1 - /* - There is a reason we don't hook up stderr here: - Git's credential cache daemon helper does not close its stderr, so if this - process is the process that fires up the daemon, it will wait forever - (until the daemon exits, really) trying to read from stderr. - - See https://github.com/github/git-lfs/issues/117 for more details. - */ - - return &CredentialCmd{buf1, subCommand, cmd} -} - -func (c *CredentialCmd) StdoutString() string { - return c.output.String() -} - -func (c *CredentialCmd) Credentials() Creds { - creds := make(Creds) - - for _, line := range strings.Split(c.StdoutString(), "\n") { - pieces := strings.SplitN(line, "=", 2) - if len(pieces) < 2 { - continue - } - creds[pieces[0]] = pieces[1] + // if the LFS request doesn't match the current LFS url, don't bother + // attempting to set the Authorization header from the LFS or Git remote URLs. + if req.URL.Scheme != apiUrl.Scheme || + req.URL.Host != apiUrl.Host { + return req.URL, nil } - return creds + if setRequestAuthFromUrl(req, apiUrl) { + return nil, nil + } + + credsUrl := apiUrl + if len(Config.CurrentRemote) > 0 { + if u, ok := Config.GitConfig("remote." + Config.CurrentRemote + ".url"); ok { + gitRemoteUrl, err := url.Parse(u) + if err != nil { + return nil, err + } + + if gitRemoteUrl.Scheme == apiUrl.Scheme && + gitRemoteUrl.Host == apiUrl.Host { + + if setRequestAuthFromUrl(req, gitRemoteUrl) { + return nil, nil + } + + credsUrl = gitRemoteUrl + } + } + } + + return credsUrl, nil +} + +func fillCredentials(u *url.URL) (Creds, error) { + path := strings.TrimPrefix(u.Path, "/") + creds := Creds{"protocol": u.Scheme, "host": u.Host, "path": path} + return execCreds(creds, "fill") +} + +func saveCredentials(creds Creds, res *http.Response) { + if creds == nil { + return + } + + switch res.StatusCode { + case 401, 403: + execCreds(creds, "reject") + default: + if res.StatusCode < 300 { + execCreds(creds, "approve") + } + } } type Creds map[string]string @@ -83,25 +138,54 @@ func (c Creds) Buffer() *bytes.Buffer { return buf } -func init() { - execCreds = func(input Creds, subCommand string) (credentialFetcher, error) { - cmd := NewCommand(input, subCommand) - err := cmd.Start() - if err == nil { - err = cmd.Wait() - } +type credentialFunc func(Creds, string) (Creds, error) - if exitErr, ok := err.(*exec.ExitError); ok { - if exitErr.ProcessState.Success() == false && !Config.GetenvBool("GIT_TERMINAL_PROMPT", true) { - return nil, fmt.Errorf("Change the GIT_TERMINAL_PROMPT env var to be prompted to enter your credentials for %s://%s.", - input["protocol"], input["host"]) - } - } +func execCredsCommand(input Creds, subCommand string) (Creds, error) { + output := new(bytes.Buffer) + cmd := exec.Command("git", "credential", subCommand) + cmd.Stdin = input.Buffer() + cmd.Stdout = output + /* + There is a reason we don't hook up stderr here: + Git's credential cache daemon helper does not close its stderr, so if this + process is the process that fires up the daemon, it will wait forever + (until the daemon exits, really) trying to read from stderr. - if err != nil { - return cmd, fmt.Errorf("'git credential %s' error: %s\n", cmd.SubCommand, err.Error()) - } + See https://github.com/github/git-lfs/issues/117 for more details. + */ - return cmd, nil + err := cmd.Start() + if err == nil { + err = cmd.Wait() } + + if _, ok := err.(*exec.ExitError); ok { + if !Config.GetenvBool("GIT_TERMINAL_PROMPT", true) { + return nil, fmt.Errorf("Change the GIT_TERMINAL_PROMPT env var to be prompted to enter your credentials for %s://%s.", + input["protocol"], input["host"]) + } + + // 'git credential' exits with 128 if the helper doesn't fill the username + // and password values. + if subCommand == "fill" && err.Error() == "exit status 128" { + return input, nil + } + } + + if err != nil { + return nil, fmt.Errorf("'git credential %s' error: %s\n", subCommand, err.Error()) + } + + creds := make(Creds) + for _, line := range strings.Split(output.String(), "\n") { + pieces := strings.SplitN(line, "=", 2) + if len(pieces) < 2 { + continue + } + creds[pieces[0]] = pieces[1] + } + + return creds, nil } + +var execCreds credentialFunc = execCredsCommand diff --git a/lfs/credentials_test.go b/lfs/credentials_test.go index c2ec7617..e5df1a23 100644 --- a/lfs/credentials_test.go +++ b/lfs/credentials_test.go @@ -2,156 +2,222 @@ package lfs import ( "encoding/base64" + "fmt" "net/http" "testing" ) +func TestGetCredentialsForApi(t *testing.T) { + checkGetCredentials(t, getCredsForAPI, []*getCredentialCheck{ + { + Desc: "simple", + Config: map[string]string{"lfs.url": "https://git-server.com"}, + Method: "GET", + Href: "https://git-server.com/foo", + Protocol: "https", + Host: "git-server.com", + Username: "git-server.com", + Password: "monkey", + }, + { + Desc: "auth header", + Config: map[string]string{"lfs.url": "https://git-server.com"}, + Header: map[string]string{"Authorization": "Test monkey"}, + Method: "GET", + Href: "https://git-server.com/foo", + Authorization: "Test monkey", + }, + { + Desc: "scheme mismatch", + Config: map[string]string{"lfs.url": "https://git-server.com"}, + Method: "GET", + Href: "http://git-server.com/foo", + Protocol: "http", + Host: "git-server.com", + Path: "foo", + Username: "git-server.com", + Password: "monkey", + }, + { + Desc: "host mismatch", + Config: map[string]string{"lfs.url": "https://git-server.com"}, + Method: "GET", + Href: "https://git-server2.com/foo", + Protocol: "https", + Host: "git-server2.com", + Path: "foo", + Username: "git-server2.com", + Password: "monkey", + }, + { + Desc: "port mismatch", + Config: map[string]string{"lfs.url": "https://git-server.com"}, + Method: "GET", + Href: "https://git-server.com:8080/foo", + Protocol: "https", + Host: "git-server.com:8080", + Path: "foo", + Username: "git-server.com:8080", + Password: "monkey", + }, + { + Desc: "api url auth", + Config: map[string]string{"lfs.url": "https://testuser:testpass@git-server.com"}, + Method: "GET", + Href: "https://git-server.com/foo", + Authorization: "Basic " + base64.URLEncoding.EncodeToString([]byte("testuser:testpass")), + }, + { + Desc: "git url auth", + CurrentRemote: "origin", + Config: map[string]string{ + "lfs.url": "https://git-server.com", + "remote.origin.url": "https://gituser:gitpass@git-server.com", + }, + Method: "GET", + Href: "https://git-server.com/foo", + Authorization: "Basic " + base64.URLEncoding.EncodeToString([]byte("gituser:gitpass")), + }, + }) +} + func TestGetCredentials(t *testing.T) { - Config.SetConfig("lfs.url", "https://lfs-server.com") - req, err := http.NewRequest("GET", "https://lfs-server.com/foo", nil) - if err != nil { - t.Fatal(err) + checks := []*getCredentialCheck{ + { + Desc: "git server", + Method: "GET", + Href: "https://git-server.com/foo", + Protocol: "https", + Host: "git-server.com", + Username: "git-server.com", + Password: "monkey", + }, + { + Desc: "separate lfs server", + Method: "GET", + Href: "https://lfs-server.com/foo", + Protocol: "https", + Host: "lfs-server.com", + Username: "lfs-server.com", + Password: "monkey", + }, } - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) + // these properties should not change the outcome + for _, check := range checks { + check.CurrentRemote = "origin" + check.Config = map[string]string{ + "lfs.url": "https://testuser:testuser@git-server.com", + "remote.origin.url": "https://gituser:gitpass@git-server.com", + } } - if value := creds["username"]; value != "lfs-server.com" { - t.Errorf("bad username: %s", value) - } + checkGetCredentials(t, getCreds, checks) +} - if value := creds["password"]; value != "monkey" { - t.Errorf("bad password: %s", value) - } +func checkGetCredentials(t *testing.T, getCredsFunc func(*http.Request) (Creds, error), checks []*getCredentialCheck) { + existingRemote := Config.CurrentRemote + for _, check := range checks { + t.Logf("Checking %q", check.Desc) + Config.CurrentRemote = check.CurrentRemote - expected := "Basic " + base64.URLEncoding.EncodeToString([]byte("lfs-server.com:monkey")) - if value := req.Header.Get("Authorization"); value != expected { - t.Errorf("Bad Authorization. Expected '%s', got '%s'", expected, value) + for key, value := range check.Config { + Config.SetConfig(key, value) + } + + req, err := http.NewRequest(check.Method, check.Href, nil) + if err != nil { + t.Errorf("[%s] %s", check.Desc, err) + continue + } + + for key, value := range check.Header { + req.Header.Set(key, value) + } + + creds, err := getCredsFunc(req) + if err != nil { + t.Errorf("[%s] %s", check.Desc, err) + continue + } + + if check.ExpectCreds() { + if creds == nil { + t.Errorf("[%s], no credentials returned", check.Desc) + continue + } + + if value := creds["protocol"]; len(check.Protocol) > 0 && value != check.Protocol { + t.Errorf("[%s] bad protocol: %q, expected: %q", check.Desc, value, check.Protocol) + } + + if value := creds["host"]; len(check.Host) > 0 && value != check.Host { + t.Errorf("[%s] bad host: %q, expected: %q", check.Desc, value, check.Host) + } + + if value := creds["username"]; len(check.Username) > 0 && value != check.Username { + t.Errorf("[%s] bad username: %q, expected: %q", check.Desc, value, check.Username) + } + + if value := creds["password"]; len(check.Password) > 0 && value != check.Password { + t.Errorf("[%s] bad password: %q, expected: %q", check.Desc, value, check.Password) + } + + if value := creds["path"]; len(check.Path) > 0 && value != check.Path { + t.Errorf("[%s] bad path: %q, expected: %q", check.Desc, value, check.Path) + } + } else { + if creds != nil { + t.Errorf("[%s], unexpected credentials: %v // %v", check.Desc, creds, check) + continue + } + } + + if len(check.Authorization) > 0 { + if actual := req.Header.Get("Authorization"); actual != check.Authorization { + t.Errorf("[%s] Unexpected Authorization header: %s", check.Desc, actual) + } + } else { + rawtoken := fmt.Sprintf("%s:%s", check.Username, check.Password) + expected := "Basic " + base64.URLEncoding.EncodeToString([]byte(rawtoken)) + if value := req.Header.Get("Authorization"); value != expected { + t.Errorf("[%s] Bad Authorization. Expected '%s', got '%s'", check.Desc, expected, value) + } + } + + Config.ResetConfig() + Config.CurrentRemote = existingRemote } } -func TestGetCredentialsWithExistingAuthorization(t *testing.T) { - Config.SetConfig("lfs.url", "https://lfs-server.com") - req, err := http.NewRequest("GET", "http://lfs-server.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - req.Header.Set("Authorization", "Test monkey") - - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) - } - - if creds != nil { - t.Errorf("Unexpected creds: %v", creds) - } - - if actual := req.Header.Get("Authorization"); actual != "Test monkey" { - t.Errorf("Unexpected Authorization header: %s", actual) - } +type getCredentialCheck struct { + Desc string + Config map[string]string + Header map[string]string + Method string + Href string + Protocol string + Host string + Username string + Password string + Path string + Authorization string + CurrentRemote string } -func TestGetCredentialsWithSchemeMismatch(t *testing.T) { - Config.SetConfig("lfs.url", "https://lfs-server.com") - req, err := http.NewRequest("GET", "http://lfs-server.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) - } - - if creds != nil { - t.Errorf("Unexpected creds: %v", creds) - } - - if actual := req.Header.Get("Authorization"); actual != "" { - t.Errorf("Unexpected Authorization header: %s", actual) - } -} - -func TestGetCredentialsWithHostMismatch(t *testing.T) { - Config.SetConfig("lfs.url", "https://lfs-server.com") - req, err := http.NewRequest("GET", "https://lfs-server2.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) - } - - if creds != nil { - t.Errorf("Unexpected creds: %v", creds) - } - - if actual := req.Header.Get("Authorization"); actual != "" { - t.Errorf("Unexpected Authorization header: %s", actual) - } -} - -func TestGetCredentialsWithPortMismatch(t *testing.T) { - Config.SetConfig("lfs.url", "https://lfs-server.com") - req, err := http.NewRequest("GET", "https://lfs-server:8080.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) - } - - if creds != nil { - t.Errorf("Unexpected creds: %v", creds) - } - - if actual := req.Header.Get("Authorization"); actual != "" { - t.Errorf("Unexpected Authorization header: %s", actual) - } -} - -func TestGetCredentialsWithRfc1738UsernameAndPassword(t *testing.T) { - Config.SetConfig("lfs.url", "https://testuser:testpass@lfs-server.com") - req, err := http.NewRequest("GET", "https://lfs-server.com/foo", nil) - if err != nil { - t.Fatal(err) - } - - creds, err := getCreds(req) - if err != nil { - t.Fatal(err) - } - - if creds != nil { - t.Errorf("unexpected creds: %v", creds) - } - - expected := "Basic " + base64.URLEncoding.EncodeToString([]byte("testuser:testpass")) - if value := req.Header.Get("Authorization"); value != expected { - t.Errorf("Bad Authorization. Expected '%s', got '%s'", expected, value) - } +func (c *getCredentialCheck) ExpectCreds() bool { + return len(c.Protocol) > 0 || len(c.Host) > 0 || len(c.Username) > 0 || + len(c.Password) > 0 || len(c.Path) > 0 } func init() { - execCreds = func(input Creds, subCommand string) (credentialFetcher, error) { - return &testCredentialFetcher{input}, nil + execCreds = func(input Creds, subCommand string) (Creds, error) { + output := make(Creds) + for key, value := range input { + output[key] = value + } + output["username"] = input["host"] + output["password"] = "monkey" + return output, nil } } - -type testCredentialFetcher struct { - Creds Creds -} - -func (c *testCredentialFetcher) Credentials() Creds { - c.Creds["username"] = c.Creds["host"] - c.Creds["password"] = "monkey" - return c.Creds -} diff --git a/lfs/download_test.go b/lfs/download_test.go index 69440841..092947ef 100644 --- a/lfs/download_test.go +++ b/lfs/download_test.go @@ -72,10 +72,6 @@ func TestSuccessfulDownload(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != expectedAuth(t, server) { - t.Error("Invalid Authorization") - } - if r.Header.Get("A") != "1" { t.Error("invalid A") } @@ -87,6 +83,7 @@ func TestSuccessfulDownload(t *testing.T) { w.Write([]byte("test")) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") reader, size, err := Download("oid") if err != nil { @@ -204,10 +201,6 @@ func TestSuccessfulDownloadWithRedirects(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != expectedAuth(t, server) { - t.Error("Invalid Authorization") - } - if r.Header.Get("A") != "1" { t.Error("invalid A") } @@ -219,6 +212,7 @@ func TestSuccessfulDownloadWithRedirects(t *testing.T) { w.Write([]byte("test")) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/redirect") for _, redirect := range redirectCodes { @@ -324,6 +318,7 @@ func TestSuccessfulDownloadWithAuthorization(t *testing.T) { w.Write([]byte("test")) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") reader, size, err := Download("oid") if err != nil { @@ -412,10 +407,6 @@ func TestSuccessfulDownloadFromSeparateHost(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != "" { - t.Error("Invalid Authorization") - } - if r.Header.Get("A") != "1" { t.Error("invalid A") } @@ -427,6 +418,7 @@ func TestSuccessfulDownloadFromSeparateHost(t *testing.T) { w.Write([]byte("test")) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") reader, size, err := Download("oid") if err != nil { @@ -546,10 +538,6 @@ func TestSuccessfulDownloadFromSeparateRedirectedHost(t *testing.T) { t.Error("Invalid Accept") } - if r.Header.Get("Authorization") != "" { - t.Error("Invalid Authorization") - } - if r.Header.Get("A") != "1" { t.Error("invalid A") } @@ -561,6 +549,7 @@ func TestSuccessfulDownloadFromSeparateRedirectedHost(t *testing.T) { w.Write([]byte("test")) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") for _, redirect := range redirectCodes { @@ -597,6 +586,7 @@ func TestDownloadAPIError(t *testing.T) { w.WriteHeader(404) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") _, _, err := Download("oid") if err == nil { @@ -664,6 +654,7 @@ func TestDownloadStorageError(t *testing.T) { w.WriteHeader(500) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") _, _, err := Download("oid") if err == nil { diff --git a/lfs/upload_test.go b/lfs/upload_test.go index e0fe1c49..74a4af00 100644 --- a/lfs/upload_test.go +++ b/lfs/upload_test.go @@ -103,6 +103,7 @@ func TestExistingUpload(t *testing.T) { w.WriteHeader(200) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -226,6 +227,7 @@ func TestUploadWithRedirect(t *testing.T) { w.Write(by) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/redirect") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -399,6 +401,7 @@ func TestSuccessfulUploadWithVerify(t *testing.T) { w.WriteHeader(200) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -559,6 +562,7 @@ func TestSuccessfulUploadWithoutVerify(t *testing.T) { w.WriteHeader(200) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -603,6 +607,7 @@ func TestUploadApiError(t *testing.T) { w.WriteHeader(404) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -710,6 +715,7 @@ func TestUploadStorageError(t *testing.T) { w.WriteHeader(404) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") @@ -858,6 +864,7 @@ func TestUploadVerifyError(t *testing.T) { w.WriteHeader(404) }) + defer Config.ResetConfig() Config.SetConfig("lfs.url", server.URL+"/media") oidPath, _ := LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") diff --git a/test/cmd/git-credential-lfstest.go b/test/cmd/git-credential-lfstest.go index 4bcddd03..2301123a 100644 --- a/test/cmd/git-credential-lfstest.go +++ b/test/cmd/git-credential-lfstest.go @@ -6,7 +6,6 @@ import ( "io/ioutil" "os" "path/filepath" - "regexp" "strings" ) @@ -17,10 +16,7 @@ var ( "erase": noop, } - delim = '\n' - - hostRE = regexp.MustCompile(`\A127.0.0.1:\d+\z`) - + delim = '\n' credsDir = "" ) @@ -74,12 +70,14 @@ func fill() { os.Exit(1) } - if _, ok := creds["username"]; !ok { - creds["username"] = user - } + if user != "skip" { + if _, ok := creds["username"]; !ok { + creds["username"] = user + } - if _, ok := creds["password"]; !ok { - creds["password"] = pass + if _, ok := creds["password"]; !ok { + creds["password"] = pass + } } for key, value := range creds { diff --git a/test/test-credentials.sh b/test/test-credentials.sh index c562148a..d705c6a2 100755 --- a/test/test-credentials.sh +++ b/test/test-credentials.sh @@ -2,54 +2,7 @@ . "test/testlib.sh" -begin_test "git credential" -( - set -e - - printf "git:server" > "$CREDSDIR/credential-test.com" - printf "git:path" > "$CREDSDIR/credential-test.com--some-path" - - mkdir empty - cd empty - git init - - echo "protocol=http -host=credential-test.com" | GIT_TERMINAL_PROMPT=0 git credential fill | tee cred.log - - expected="protocol=http -host=credential-test.com -username=git -password=server" - [ "$expected" = "$(cat cred.log)" ] - - echo "protocol=http -host=credential-test.com -path=some/path" | GIT_TERMINAL_PROMPT=0 git credential fill | tee cred.log - - expected="protocol=http -host=credential-test.com -username=git -password=server" - - [ "$expected" = "$(cat cred.log)" ] - - git config credential.useHttpPath true - - echo "protocol=http -host=credential-test.com -path=some/path" | GIT_TERMINAL_PROMPT=0 git credential fill | tee cred.log - - expected="protocol=http -host=credential-test.com -path=some/path -username=git -password=path" - - [ "$expected" = "$(cat cred.log)" ] -) -end_test - -begin_test "credentials without useHttpPath, with wrong password" +begin_test "credentials without useHttpPath, with wrong path password" ( set -e @@ -136,3 +89,53 @@ begin_test "credentials with useHttpPath, with correct password" grep "(1 of 1 files)" push.log ) end_test + +begin_test "git credential" +( + set -e + + printf "git:server" > "$CREDSDIR/credential-test.com" + printf "git:path" > "$CREDSDIR/credential-test.com--some-path" + + mkdir empty + cd empty + git init + + echo "protocol=http +host=credential-test.com" | GIT_TERMINAL_PROMPT=0 git credential fill > cred.log + cat cred.log + + expected="protocol=http +host=credential-test.com +username=git +password=server" + [ "$expected" = "$(cat cred.log)" ] + + echo "protocol=http +host=credential-test.com +path=some/path" | GIT_TERMINAL_PROMPT=0 git credential fill > cred.log + cat cred.log + + expected="protocol=http +host=credential-test.com +username=git +password=server" + + [ "$expected" = "$(cat cred.log)" ] + + git config credential.useHttpPath true + + echo "protocol=http +host=credential-test.com +path=some/path" | GIT_TERMINAL_PROMPT=0 git credential fill > cred.log + cat cred.log + + expected="protocol=http +host=credential-test.com +path=some/path +username=git +password=path" + + [ "$expected" = "$(cat cred.log)" ] +) +end_test