From d101bdb605f314dc66e73e08237f570de1287b8d Mon Sep 17 00:00:00 2001 From: Preben Ingvaldsen Date: Thu, 6 Sep 2018 14:42:41 -0700 Subject: [PATCH 1/2] lfsapi: extract new lfshttp package Extract more basic http-related functionality out of lfsapi and into a new package, lfshttp. Everything is currently functional aside from authorization. --- Makefile | 1 + commands/command_version.go | 4 +- commands/lockverifier.go | 4 +- lfs/lfs.go | 2 +- lfsapi/auth.go | 7 +- lfsapi/auth_test.go | 81 +++- lfsapi/client.go | 433 ++-------------------- lfsapi/creds.go | 4 +- lfsapi/endpoint_finder.go | 50 +-- lfsapi/endpoint_finder_test.go | 91 +++-- lfsapi/endpoint_test.go | 23 -- lfsapi/lfsapi.go | 178 +-------- lfsapi/ntlm_test.go | 3 +- lfshttp/auth.go | 9 + lfshttp/body.go | 39 ++ {lfsapi => lfshttp}/certs.go | 2 +- {lfsapi => lfshttp}/certs_darwin.go | 2 +- {lfsapi => lfshttp}/certs_freebsd.go | 2 +- {lfsapi => lfshttp}/certs_linux.go | 2 +- {lfsapi => lfshttp}/certs_openbsd.go | 2 +- {lfsapi => lfshttp}/certs_test.go | 2 +- {lfsapi => lfshttp}/certs_windows.go | 2 +- lfshttp/client.go | 533 +++++++++++++++++++++++++++ {lfsapi => lfshttp}/client_test.go | 90 +---- {lfsapi => lfshttp}/endpoint.go | 77 ++-- {lfsapi => lfshttp}/errors.go | 2 +- lfshttp/lfshttp.go | 91 +++++ {lfsapi => lfshttp}/proxy.go | 5 +- {lfsapi => lfshttp}/proxy_test.go | 2 +- {lfsapi => lfshttp}/retries.go | 2 +- {lfsapi => lfshttp}/retries_test.go | 2 +- {lfsapi => lfshttp}/ssh.go | 2 +- {lfsapi => lfshttp}/ssh_test.go | 52 +-- {lfsapi => lfshttp}/stats.go | 2 +- {lfsapi => lfshttp}/stats_test.go | 2 +- {lfsapi => lfshttp}/verbose.go | 2 +- {lfsapi => lfshttp}/verbose_test.go | 2 +- locking/api.go | 17 +- locking/api_test.go | 23 +- locking/locks_test.go | 5 +- tq/api.go | 11 +- tq/api_test.go | 5 +- tq/custom_test.go | 9 +- tq/manifest_test.go | 9 +- tq/transfer_queue.go | 7 +- tq/transfer_test.go | 3 +- tq/verify_test.go | 3 +- 47 files changed, 1020 insertions(+), 881 deletions(-) delete mode 100644 lfsapi/endpoint_test.go create mode 100644 lfshttp/auth.go create mode 100644 lfshttp/body.go rename {lfsapi => lfshttp}/certs.go (99%) rename {lfsapi => lfshttp}/certs_darwin.go (99%) rename {lfsapi => lfshttp}/certs_freebsd.go (91%) rename {lfsapi => lfshttp}/certs_linux.go (91%) rename {lfsapi => lfshttp}/certs_openbsd.go (91%) rename {lfsapi => lfshttp}/certs_test.go (99%) rename {lfsapi => lfshttp}/certs_windows.go (92%) create mode 100644 lfshttp/client.go rename {lfsapi => lfshttp}/client_test.go (77%) rename {lfsapi => lfshttp}/endpoint.go (82%) rename {lfsapi => lfshttp}/errors.go (99%) create mode 100644 lfshttp/lfshttp.go rename {lfsapi => lfshttp}/proxy.go (99%) rename {lfsapi => lfshttp}/proxy_test.go (99%) rename {lfsapi => lfshttp}/retries.go (98%) rename {lfsapi => lfshttp}/retries_test.go (99%) rename {lfsapi => lfshttp}/ssh.go (99%) rename {lfsapi => lfshttp}/ssh_test.go (93%) rename {lfsapi => lfshttp}/stats.go (99%) rename {lfsapi => lfshttp}/stats_test.go (99%) rename {lfsapi => lfshttp}/verbose.go (99%) rename {lfsapi => lfshttp}/verbose_test.go (99%) diff --git a/Makefile b/Makefile index b89f643e..666f38e5 100644 --- a/Makefile +++ b/Makefile @@ -89,6 +89,7 @@ PKGS += git/githistory PKGS += git PKGS += lfs PKGS += lfsapi +PKGS += lfshttp PKGS += locking PKGS += subprocess PKGS += tasklog diff --git a/commands/command_version.go b/commands/command_version.go index bde27cd1..c59aa892 100644 --- a/commands/command_version.go +++ b/commands/command_version.go @@ -1,7 +1,7 @@ package commands import ( - "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/spf13/cobra" ) @@ -10,7 +10,7 @@ var ( ) func versionCommand(cmd *cobra.Command, args []string) { - Print(lfsapi.UserAgent) + Print(lfshttp.UserAgent) if lovesComics { Print("Nothing may see Gah Lak Tus and survive!") diff --git a/commands/lockverifier.go b/commands/lockverifier.go index b402f9ca..60a322ce 100644 --- a/commands/lockverifier.go +++ b/commands/lockverifier.go @@ -9,7 +9,7 @@ import ( "github.com/git-lfs/git-lfs/config" "github.com/git-lfs/git-lfs/errors" "github.com/git-lfs/git-lfs/git" - "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/git-lfs/git-lfs/locking" "github.com/git-lfs/git-lfs/tq" ) @@ -30,7 +30,7 @@ func verifyLocksForUpdates(lv *lockVerifier, updates []*git.RefUpdate) { // lockVerifier verifies locked files before updating one or more refs. type lockVerifier struct { - endpoint lfsapi.Endpoint + endpoint lfshttp.Endpoint verifyState verifyState verifiedRefs map[string]bool diff --git a/lfs/lfs.go b/lfs/lfs.go index 9b4b43dd..45b0e086 100644 --- a/lfs/lfs.go +++ b/lfs/lfs.go @@ -44,7 +44,7 @@ func Environ(cfg *config.Configuration, manifest *tq.Manifest) []string { fmt.Sprintf("LocalMediaDir=%s", cfg.LFSObjectDir()), fmt.Sprintf("LocalReferenceDirs=%s", references), fmt.Sprintf("TempDir=%s", cfg.TempDir()), - fmt.Sprintf("ConcurrentTransfers=%d", api.ConcurrentTransfers), + fmt.Sprintf("ConcurrentTransfers=%d", api.ConcurrentTransfers()), fmt.Sprintf("TusTransfers=%v", cfg.TusTransfersAllowed()), fmt.Sprintf("BasicTransfersOnly=%v", cfg.BasicTransfersOnly()), fmt.Sprintf("SkipDownloadErrors=%v", cfg.SkipDownloadErrors()), diff --git a/lfsapi/auth.go b/lfsapi/auth.go index 7d53e49c..10f2a47d 100644 --- a/lfsapi/auth.go +++ b/lfsapi/auth.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/git-lfs/git-lfs/errors" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/git-lfs/go-netrc/netrc" "github.com/rubyist/tracerx" ) @@ -28,7 +29,7 @@ func (c *Client) DoWithAuth(remote string, req *http.Request) (*http.Response, e } func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) { - req.Header = c.extraHeadersFor(req) + req.Header = c.client.ExtraHeadersFor(req) apiEndpoint, access, credHelper, credsURL, creds, err := c.getCreds(remote, req) if err != nil { @@ -94,7 +95,7 @@ func (c *Client) doWithCreds(req *http.Request, credHelper CredentialHelper, cre // 3. The Git Remote URL, which should be something like "https://git.com/repo.git" // This URL is used for the Git Credential Helper. This way existing https // Git remote credentials can be re-used for LFS. -func (c *Client) getCreds(remote string, req *http.Request) (Endpoint, Access, CredentialHelper, *url.URL, Creds, error) { +func (c *Client) getCreds(remote string, req *http.Request) (lfshttp.Endpoint, Access, CredentialHelper, *url.URL, Creds, error) { ef := c.Endpoints if ef == nil { ef = defaultEndpointFinder @@ -198,7 +199,7 @@ func setAuthFromNetrc(netrcFinder NetrcFinder, req *http.Request) bool { return false } -func getCredURLForAPI(ef EndpointFinder, operation, remote string, apiEndpoint Endpoint, req *http.Request) (*url.URL, error) { +func getCredURLForAPI(ef EndpointFinder, operation, remote string, apiEndpoint lfshttp.Endpoint, req *http.Request) (*url.URL, error) { apiURL, err := url.Parse(apiEndpoint.Url) if err != nil { return nil, err diff --git a/lfsapi/auth_test.go b/lfsapi/auth_test.go index 3dfb1ca1..39253d13 100644 --- a/lfsapi/auth_test.go +++ b/lfsapi/auth_test.go @@ -12,6 +12,7 @@ import ( "github.com/git-lfs/git-lfs/errors" "github.com/git-lfs/git-lfs/git" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -73,7 +74,7 @@ func TestDoWithAuthApprove(t *testing.T) { defer srv.Close() creds := newMockCredentialHelper() - c, err := NewClient(NewContext(nil, nil, map[string]string{ + c, err := NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/repo/lfs", })) require.Nil(t, err) @@ -143,7 +144,7 @@ func TestDoWithAuthReject(t *testing.T) { c, _ := NewClient(nil) c.Credentials = creds - c.Endpoints = NewEndpointFinder(NewContext(nil, nil, map[string]string{ + c.Endpoints = NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL, })) @@ -567,7 +568,7 @@ func TestGetCreds(t *testing.T) { req.Header.Set(key, value) } - ctx := NewContext(git.NewConfig("", ""), nil, test.Config) + ctx := lfshttp.NewContext(git.NewConfig("", ""), nil, test.Config) client, _ := NewClient(ctx) client.Credentials = &fakeCredentialFiller{} client.Netrc = &fakeNetrc{} @@ -617,3 +618,77 @@ func (f *fakeCredentialFiller) Approve(creds Creds) error { func (f *fakeCredentialFiller) Reject(creds Creds) error { return errors.New("Not implemented") } + +// func TestClientRedirectReauthenticate(t *testing.T) { +// var srv1, srv2 *httptest.Server +// var called1, called2 uint32 +// var creds1, creds2 Creds + +// srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// atomic.AddUint32(&called1, 1) + +// if hdr := r.Header.Get("Authorization"); len(hdr) > 0 { +// parts := strings.SplitN(hdr, " ", 2) +// typ, b64 := parts[0], parts[1] + +// auth, err := base64.URLEncoding.DecodeString(b64) +// assert.Nil(t, err) +// assert.Equal(t, "Basic", typ) +// assert.Equal(t, "user1:pass1", string(auth)) + +// http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently) +// return +// } +// w.WriteHeader(http.StatusUnauthorized) +// })) + +// srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// atomic.AddUint32(&called2, 1) + +// parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) +// typ, b64 := parts[0], parts[1] + +// auth, err := base64.URLEncoding.DecodeString(b64) +// assert.Nil(t, err) +// assert.Equal(t, "Basic", typ) +// assert.Equal(t, "user2:pass2", string(auth)) +// })) + +// // Change the URL of srv2 to make it appears as if it is a different +// // host. +// srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1) + +// creds1 = Creds(map[string]string{ +// "protocol": "http", +// "host": strings.TrimPrefix(srv1.URL, "http://"), + +// "username": "user1", +// "password": "pass1", +// }) +// creds2 = Creds(map[string]string{ +// "protocol": "http", +// "host": strings.TrimPrefix(srv2.URL, "http://"), + +// "username": "user2", +// "password": "pass2", +// }) + +// defer srv1.Close() +// defer srv2.Close() + +// c, err := NewClient(lfshttp.NewContext(nil, nil, nil)) +// creds := newCredentialCacher() +// creds.Approve(creds1) +// creds.Approve(creds2) +// c.Credentials = creds + +// req, err := http.NewRequest("GET", srv1.URL, nil) +// require.Nil(t, err) + +// _, err = c.DoWithAuth("", req) +// assert.Nil(t, err) + +// // called1 is 2 since LFS tries an unauthenticated request first +// assert.EqualValues(t, 2, called1) +// assert.EqualValues(t, 1, called2) +// } diff --git a/lfsapi/client.go b/lfsapi/client.go index ede1f3b3..fa5f31bc 100644 --- a/lfsapi/client.go +++ b/lfsapi/client.go @@ -1,431 +1,50 @@ package lfsapi import ( - "context" - "crypto/tls" - "fmt" "io" - "net" "net/http" - "net/textproto" - "net/url" - "os" - "regexp" - "strconv" - "strings" - "time" "github.com/git-lfs/git-lfs/config" - "github.com/git-lfs/git-lfs/errors" - "github.com/git-lfs/git-lfs/tools" - "github.com/rubyist/tracerx" + "github.com/git-lfs/git-lfs/lfshttp" ) -const MediaType = "application/vnd.git-lfs+json; charset=utf-8" - -var ( - UserAgent = "git-lfs" - httpRE = regexp.MustCompile(`\Ahttps?://`) -) - -var hintFileUrl = strings.TrimSpace(` -hint: The remote resolves to a file:// URL, which can only work with a -hint: standalone transfer agent. See section "Using a Custom Transfer Type -hint: without the API server" in custom-transfers.md for details. -`) - -func (c *Client) NewRequest(method string, e Endpoint, suffix string, body interface{}) (*http.Request, error) { - if strings.HasPrefix(e.Url, "file://") { - // Initial `\n` to avoid overprinting `Downloading LFS...`. - fmt.Fprintf(os.Stderr, "\n%s\n", hintFileUrl) - } - - sshRes, err := c.sshResolveWithRetries(e, method) - if err != nil { - return nil, err - } - - prefix := e.Url - if len(sshRes.Href) > 0 { - prefix = sshRes.Href - } - - if !httpRE.MatchString(prefix) { - urlfragment := strings.SplitN(prefix, "?", 2)[0] - return nil, fmt.Errorf("missing protocol: %q", urlfragment) - } - - req, err := http.NewRequest(method, joinURL(prefix, suffix), nil) - if err != nil { - return req, err - } - - for key, value := range sshRes.Header { - req.Header.Set(key, value) - } - req.Header.Set("Accept", MediaType) - - if body != nil { - if merr := MarshalToRequest(req, body); merr != nil { - return req, merr - } - req.Header.Set("Content-Type", MediaType) - } - - return req, err -} - -const slash = "/" - -func joinURL(prefix, suffix string) string { - if strings.HasSuffix(prefix, slash) { - return prefix + suffix - } - return prefix + slash + suffix +func (c *Client) NewRequest(method string, e lfshttp.Endpoint, suffix string, body interface{}) (*http.Request, error) { + return c.client.NewRequest(method, e, suffix, body) } // Do sends an HTTP request to get an HTTP response. It wraps net/http, adding // extra headers, redirection handling, and error reporting. func (c *Client) Do(req *http.Request) (*http.Response, error) { - req.Header = c.extraHeadersFor(req) - - return c.do(req, "", nil) + return c.client.Do(req) } // do performs an *http.Request respecting redirects, and handles the response // as defined in c.handleResponse. Notably, it does not alter the headers for // the request argument in any way. func (c *Client) do(req *http.Request, remote string, via []*http.Request) (*http.Response, error) { - req.Header.Set("User-Agent", UserAgent) - - res, err := c.doWithRedirects(c.httpClient(req.Host), req, remote, via) - if err != nil { - return res, err - } - - return res, c.handleResponse(res) + return c.client.Do(req) +} + +func (c *Client) LogRequest(r *http.Request, reqKey string) *http.Request { + return c.client.LogRequest(r, reqKey) +} + +func (c *Client) GitEnv() config.Environment { + return c.client.GitEnv() +} + +func (c *Client) OSEnv() config.Environment { + return c.client.OSEnv() +} + +func (c *Client) ConcurrentTransfers() int { + return c.client.ConcurrentTransfers +} + +func (c *Client) LogHTTPStats(w io.WriteCloser) { + c.client.LogHTTPStats(w) } -// Close closes any resources that this client opened. func (c *Client) Close() error { - return c.httpLogger.Close() -} - -func (c *Client) sshResolveWithRetries(e Endpoint, method string) (*sshAuthResponse, error) { - var sshRes sshAuthResponse - var err error - - requests := tools.MaxInt(0, c.sshTries) + 1 - for i := 0; i < requests; i++ { - sshRes, err = c.SSH.Resolve(e, method) - if err == nil { - return &sshRes, nil - } - - tracerx.Printf( - "ssh: %s failed, error: %s, message: %s (try: %d/%d)", - e.SshUserAndHost, err.Error(), sshRes.Message, i, - requests, - ) - } - - if len(sshRes.Message) > 0 { - return nil, errors.Wrap(err, sshRes.Message) - } - return nil, err -} - -func (c *Client) extraHeadersFor(req *http.Request) http.Header { - extraHeaders := c.extraHeaders(req.URL) - if len(extraHeaders) == 0 { - return req.Header - } - - copy := make(http.Header, len(req.Header)) - for k, vs := range req.Header { - copy[k] = vs - } - - for k, vs := range extraHeaders { - for _, v := range vs { - copy[k] = append(copy[k], v) - } - } - return copy -} - -func (c *Client) extraHeaders(u *url.URL) map[string][]string { - hdrs := c.uc.GetAll("http", u.String(), "extraHeader") - m := make(map[string][]string, len(hdrs)) - - for _, hdr := range hdrs { - parts := strings.SplitN(hdr, ":", 2) - if len(parts) < 2 { - continue - } - - k, v := parts[0], strings.TrimSpace(parts[1]) - // If header keys are given in non-canonicalized form (e.g., - // "AUTHORIZATION" as opposed to "Authorization") they will not - // be returned in calls to net/http.Header.Get(). - // - // So, we avoid this problem by first canonicalizing header keys - // for extra headers. - k = textproto.CanonicalMIMEHeaderKey(k) - - m[k] = append(m[k], v) - } - return m -} - -func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Response, error) { - tracedReq, err := c.traceRequest(req) - if err != nil { - return nil, err - } - - var retries int - if n, ok := Retries(req); ok { - retries = n - } else { - retries = defaultRequestRetries - } - - var res *http.Response - - requests := tools.MaxInt(0, retries) + 1 - for i := 0; i < requests; i++ { - res, err = cli.Do(req) - if err == nil { - break - } - - if seek, ok := req.Body.(io.Seeker); ok { - seek.Seek(0, io.SeekStart) - } - - c.traceResponse(req, tracedReq, nil) - } - - if err != nil { - c.traceResponse(req, tracedReq, nil) - return nil, err - } - - if res == nil { - return nil, nil - } - - c.traceResponse(req, tracedReq, res) - - if res.StatusCode != 301 && - res.StatusCode != 302 && - res.StatusCode != 303 && - res.StatusCode != 307 && - res.StatusCode != 308 { - - // Above are the list of 3xx status codes that we know - // how to handle below. If the status code contained in - // the HTTP response was none of them, return the (res, - // err) tuple as-is, otherwise handle the redirect. - return res, err - } - - redirectTo := res.Header.Get("Location") - locurl, err := url.Parse(redirectTo) - if err == nil && !locurl.IsAbs() { - locurl = req.URL.ResolveReference(locurl) - redirectTo = locurl.String() - } - - via = append(via, req) - if len(via) >= 3 { - return res, errors.New("too many redirects") - } - - redirectedReq, err := newRequestForRetry(req, redirectTo) - if err != nil { - return res, err - } - - if len(req.Header.Get("Authorization")) > 0 { - // If the original request was authenticated (noted by the - // presence of the Authorization header), then recur through - // doWithAuth, retaining the requests via but only after - // authenticating the redirected request. - return c.doWithAuth(remote, redirectedReq, via) - } - return c.doWithRedirects(cli, redirectedReq, remote, via) -} - -func (c *Client) httpClient(host string) *http.Client { - c.clientMu.Lock() - defer c.clientMu.Unlock() - - if c.gitEnv == nil { - c.gitEnv = make(testEnv) - } - - if c.osEnv == nil { - c.osEnv = make(testEnv) - } - - if c.hostClients == nil { - c.hostClients = make(map[string]*http.Client) - } - - if client, ok := c.hostClients[host]; ok { - return client - } - - concurrentTransfers := c.ConcurrentTransfers - if concurrentTransfers < 1 { - concurrentTransfers = 8 - } - - dialtime := c.DialTimeout - if dialtime < 1 { - dialtime = 30 - } - - keepalivetime := c.KeepaliveTimeout - if keepalivetime < 1 { - keepalivetime = 1800 - } - - tlstime := c.TLSTimeout - if tlstime < 1 { - tlstime = 30 - } - - tr := &http.Transport{ - Proxy: proxyFromClient(c), - TLSHandshakeTimeout: time.Duration(tlstime) * time.Second, - MaxIdleConnsPerHost: concurrentTransfers, - } - - activityTimeout := 30 - if v, ok := c.uc.Get("lfs", fmt.Sprintf("https://%v", host), "activitytimeout"); ok { - if i, err := strconv.Atoi(v); err == nil { - activityTimeout = i - } else { - activityTimeout = 0 - } - } - - dialer := &net.Dialer{ - Timeout: time.Duration(dialtime) * time.Second, - KeepAlive: time.Duration(keepalivetime) * time.Second, - DualStack: true, - } - - if activityTimeout > 0 { - activityDuration := time.Duration(activityTimeout) * time.Second - tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - c, err := dialer.DialContext(ctx, network, addr) - if c == nil { - return c, err - } - if tc, ok := c.(*net.TCPConn); ok { - tc.SetKeepAlive(true) - tc.SetKeepAlivePeriod(dialer.KeepAlive) - } - return &deadlineConn{Timeout: activityDuration, Conn: c}, err - } - } else { - tr.DialContext = dialer.DialContext - } - - tr.TLSClientConfig = &tls.Config{} - - if isClientCertEnabledForHost(c, host) { - tracerx.Printf("http: client cert for %s", host) - tr.TLSClientConfig.Certificates = []tls.Certificate{getClientCertForHost(c, host)} - tr.TLSClientConfig.BuildNameToCertificate() - } - - if isCertVerificationDisabledForHost(c, host) { - tr.TLSClientConfig.InsecureSkipVerify = true - } else { - tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host) - } - - httpClient := &http.Client{ - Transport: tr, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, - } - - c.hostClients[host] = httpClient - if c.VerboseOut == nil { - c.VerboseOut = os.Stderr - } - - return httpClient -} - -func (c *Client) CurrentUser() (string, string) { - userName, _ := c.gitEnv.Get("user.name") - userEmail, _ := c.gitEnv.Get("user.email") - return userName, userEmail -} - -func newRequestForRetry(req *http.Request, location string) (*http.Request, error) { - newReq, err := http.NewRequest(req.Method, location, nil) - if err != nil { - return nil, err - } - - if req.URL.Scheme == "https" && newReq.URL.Scheme == "http" { - return nil, errors.New("lfsapi/client: refusing insecure redirect, https->http") - } - - sameHost := req.URL.Host == newReq.URL.Host - for key := range req.Header { - if key == "Authorization" { - if !sameHost { - continue - } - } - newReq.Header.Set(key, req.Header.Get(key)) - } - - oldestURL := strings.SplitN(req.URL.String(), "?", 2)[0] - newURL := strings.SplitN(newReq.URL.String(), "?", 2)[0] - tracerx.Printf("api: redirect %s %s to %s", req.Method, oldestURL, newURL) - - // This body will have already been rewound from a call to - // lfsapi.Client.traceRequest(). - newReq.Body = req.Body - newReq.ContentLength = req.ContentLength - - // Copy the request's context.Context, if any. - newReq = newReq.WithContext(req.Context()) - - return newReq, nil -} - -type deadlineConn struct { - Timeout time.Duration - net.Conn -} - -func (c *deadlineConn) Read(b []byte) (int, error) { - if err := c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { - return 0, err - } - return c.Conn.Read(b) -} - -func (c *deadlineConn) Write(b []byte) (int, error) { - if err := c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { - return 0, err - } - - return c.Conn.Write(b) -} - -func init() { - UserAgent = config.VersionDesc + return c.client.Close() } diff --git a/lfsapi/creds.go b/lfsapi/creds.go index 2cd3aa0f..76496e60 100644 --- a/lfsapi/creds.go +++ b/lfsapi/creds.go @@ -49,7 +49,7 @@ func (c *Client) getCredentialHelper(u *url.URL) (CredentialHelper, Creds) { if u.User != nil && u.User.Username() != "" { input["username"] = u.User.Username() } - if c.uc.Bool("credential", rawurl, "usehttppath", false) { + if c.client.URLConfig().Bool("credential", rawurl, "usehttppath", false) { input["path"] = strings.TrimPrefix(u.Path, "/") } @@ -62,7 +62,7 @@ func (c *Client) getCredentialHelper(u *url.URL) (CredentialHelper, Creds) { helpers = append(helpers, c.cachingCredHelper) } if c.askpassCredHelper != nil { - helper, _ := c.uc.Get("credential", rawurl, "helper") + helper, _ := c.client.URLConfig().Get("credential", rawurl, "helper") if len(helper) == 0 { helpers = append(helpers, c.askpassCredHelper) } diff --git a/lfsapi/endpoint_finder.go b/lfsapi/endpoint_finder.go index 11f93aa2..dbea5dd0 100644 --- a/lfsapi/endpoint_finder.go +++ b/lfsapi/endpoint_finder.go @@ -10,6 +10,7 @@ import ( "github.com/git-lfs/git-lfs/config" "github.com/git-lfs/git-lfs/git" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/rubyist/tracerx" ) @@ -26,10 +27,10 @@ const ( ) type EndpointFinder interface { - NewEndpointFromCloneURL(rawurl string) Endpoint - NewEndpoint(rawurl string) Endpoint - Endpoint(operation, remote string) Endpoint - RemoteEndpoint(operation, remote string) Endpoint + NewEndpointFromCloneURL(rawurl string) lfshttp.Endpoint + NewEndpoint(rawurl string) lfshttp.Endpoint + Endpoint(operation, remote string) lfshttp.Endpoint + RemoteEndpoint(operation, remote string) lfshttp.Endpoint GitRemoteURL(remote string, forpush bool) string AccessFor(rawurl string) Access SetAccess(rawurl string, access Access) @@ -49,9 +50,9 @@ type endpointGitFinder struct { urlConfig *config.URLConfig } -func NewEndpointFinder(ctx Context) EndpointFinder { +func NewEndpointFinder(ctx lfshttp.Context) EndpointFinder { if ctx == nil { - ctx = NewContext(nil, nil, nil) + ctx = lfshttp.NewContext(nil, nil, nil) } e := &endpointGitFinder{ @@ -71,15 +72,15 @@ func NewEndpointFinder(ctx Context) EndpointFinder { return e } -func (e *endpointGitFinder) Endpoint(operation, remote string) Endpoint { +func (e *endpointGitFinder) Endpoint(operation, remote string) lfshttp.Endpoint { ep := e.getEndpoint(operation, remote) ep.Operation = operation return ep } -func (e *endpointGitFinder) getEndpoint(operation, remote string) Endpoint { +func (e *endpointGitFinder) getEndpoint(operation, remote string) lfshttp.Endpoint { if e.gitEnv == nil { - return Endpoint{} + return lfshttp.Endpoint{} } if operation == "upload" { @@ -101,9 +102,9 @@ func (e *endpointGitFinder) getEndpoint(operation, remote string) Endpoint { return e.RemoteEndpoint(operation, defaultRemote) } -func (e *endpointGitFinder) RemoteEndpoint(operation, remote string) Endpoint { +func (e *endpointGitFinder) RemoteEndpoint(operation, remote string) lfshttp.Endpoint { if e.gitEnv == nil { - return Endpoint{} + return lfshttp.Endpoint{} } if len(remote) == 0 { @@ -125,7 +126,7 @@ func (e *endpointGitFinder) RemoteEndpoint(operation, remote string) Endpoint { return e.NewEndpointFromCloneURL(url) } - return Endpoint{} + return lfshttp.Endpoint{} } func (e *endpointGitFinder) GitRemoteURL(remote string, forpush bool) string { @@ -148,9 +149,9 @@ func (e *endpointGitFinder) GitRemoteURL(remote string, forpush bool) string { return "" } -func (e *endpointGitFinder) NewEndpointFromCloneURL(rawurl string) Endpoint { +func (e *endpointGitFinder) NewEndpointFromCloneURL(rawurl string) lfshttp.Endpoint { ep := e.NewEndpoint(rawurl) - if ep.Url == UrlUnknown { + if ep.Url == lfshttp.UrlUnknown { return ep } @@ -168,36 +169,36 @@ func (e *endpointGitFinder) NewEndpointFromCloneURL(rawurl string) Endpoint { return ep } -func (e *endpointGitFinder) NewEndpoint(rawurl string) Endpoint { +func (e *endpointGitFinder) NewEndpoint(rawurl string) lfshttp.Endpoint { rawurl = e.ReplaceUrlAlias(rawurl) if strings.HasPrefix(rawurl, "/") { - return endpointFromLocalPath(rawurl) + return lfshttp.EndpointFromLocalPath(rawurl) } u, err := url.Parse(rawurl) if err != nil { - return endpointFromBareSshUrl(rawurl) + return lfshttp.EndpointFromBareSshUrl(rawurl) } switch u.Scheme { case "ssh": - return endpointFromSshUrl(u) + return lfshttp.EndpointFromSshUrl(u) case "http", "https": - return endpointFromHttpUrl(u) + return lfshttp.EndpointFromHttpUrl(u) case "git": return endpointFromGitUrl(u, e) case "": - return endpointFromBareSshUrl(u.String()) + return lfshttp.EndpointFromBareSshUrl(u.String()) default: if strings.HasPrefix(rawurl, u.Scheme+"::") { // Looks like a remote helper; just pass it through. - return Endpoint{Url: rawurl} + return lfshttp.Endpoint{Url: rawurl} } // We probably got here because the "scheme" that was parsed is // a hostname (whether FQDN or single word) and the URL parser // didn't know what to do with it. Do what Git does and treat // it as an SSH URL. This ensures we handle SSH config aliases // properly. - return endpointFromBareSshUrl(u.String()) + return lfshttp.EndpointFromBareSshUrl(u.String()) } } @@ -305,3 +306,8 @@ func initAliases(e *endpointGitFinder, git config.Environment) { e.aliases[gitval[len(gitval)-1]] = gitkey[len(prefix) : len(gitkey)-len(suffix)] } } + +func endpointFromGitUrl(u *url.URL, e *endpointGitFinder) lfshttp.Endpoint { + u.Scheme = e.gitProtocol + return lfshttp.Endpoint{Url: u.String()} +} diff --git a/lfsapi/endpoint_finder_test.go b/lfsapi/endpoint_finder_test.go index 478c84d2..3283ce7a 100644 --- a/lfsapi/endpoint_finder_test.go +++ b/lfsapi/endpoint_finder_test.go @@ -3,11 +3,12 @@ package lfsapi import ( "testing" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" ) func TestEndpointDefaultsToOrigin(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.lfsurl": "abc", })) @@ -18,7 +19,7 @@ func TestEndpointDefaultsToOrigin(t *testing.T) { } func TestEndpointOverridesOrigin(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "abc", "remote.origin.lfsurl": "def", })) @@ -30,7 +31,7 @@ func TestEndpointOverridesOrigin(t *testing.T) { } func TestEndpointNoOverrideDefaultRemote(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.lfsurl": "abc", "remote.other.lfsurl": "def", })) @@ -42,7 +43,7 @@ func TestEndpointNoOverrideDefaultRemote(t *testing.T) { } func TestEndpointUseAlternateRemote(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.lfsurl": "abc", "remote.other.lfsurl": "def", })) @@ -54,7 +55,7 @@ func TestEndpointUseAlternateRemote(t *testing.T) { } func TestEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "https://example.com/foo/bar", })) @@ -65,7 +66,7 @@ func TestEndpointAddsLfsSuffix(t *testing.T) { } func TestBareEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "https://example.com/foo/bar.git", })) @@ -76,7 +77,7 @@ func TestBareEndpointAddsLfsSuffix(t *testing.T) { } func TestEndpointSeparateClonePushUrl(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "https://example.com/foo/bar.git", "remote.origin.pushurl": "https://readwrite.com/foo/bar.git", })) @@ -93,7 +94,7 @@ func TestEndpointSeparateClonePushUrl(t *testing.T) { } func TestEndpointOverriddenSeparateClonePushLfsUrl(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "https://example.com/foo/bar.git", "remote.origin.pushurl": "https://readwrite.com/foo/bar.git", "remote.origin.lfsurl": "https://examplelfs.com/foo/bar", @@ -112,7 +113,7 @@ func TestEndpointOverriddenSeparateClonePushLfsUrl(t *testing.T) { } func TestEndpointGlobalSeparateLfsPush(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "https://readonly.com/foo/bar", "lfs.pushurl": "https://write.com/foo/bar", })) @@ -129,7 +130,7 @@ func TestEndpointGlobalSeparateLfsPush(t *testing.T) { } func TestSSHEndpointOverridden(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "git@example.com:foo/bar", "remote.origin.lfsurl": "lfs", })) @@ -142,7 +143,7 @@ func TestSSHEndpointOverridden(t *testing.T) { } func TestSSHEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "ssh://git@example.com/foo/bar", })) @@ -154,7 +155,7 @@ func TestSSHEndpointAddsLfsSuffix(t *testing.T) { } func TestSSHCustomPortEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "ssh://git@example.com:9000/foo/bar", })) @@ -166,7 +167,7 @@ func TestSSHCustomPortEndpointAddsLfsSuffix(t *testing.T) { } func TestBareSSHEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "git@example.com:foo/bar.git", })) @@ -178,7 +179,7 @@ func TestBareSSHEndpointAddsLfsSuffix(t *testing.T) { } func TestBareSSSHEndpointWithCustomPortInBrackets(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "[git@example.com:2222]:foo/bar.git", })) @@ -190,7 +191,7 @@ func TestBareSSSHEndpointWithCustomPortInBrackets(t *testing.T) { } func TestSSHEndpointFromGlobalLfsUrl(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "git@example.com:foo/bar.git", })) @@ -202,7 +203,7 @@ func TestSSHEndpointFromGlobalLfsUrl(t *testing.T) { } func TestHTTPEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "http://example.com/foo/bar", })) @@ -214,7 +215,7 @@ func TestHTTPEndpointAddsLfsSuffix(t *testing.T) { } func TestBareHTTPEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "http://example.com/foo/bar.git", })) @@ -226,7 +227,7 @@ func TestBareHTTPEndpointAddsLfsSuffix(t *testing.T) { } func TestGitEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "git://example.com/foo/bar", })) @@ -238,7 +239,7 @@ func TestGitEndpointAddsLfsSuffix(t *testing.T) { } func TestGitEndpointAddsLfsSuffixWithCustomProtocol(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "git://example.com/foo/bar", "lfs.gitprotocol": "http", })) @@ -251,7 +252,7 @@ func TestGitEndpointAddsLfsSuffixWithCustomProtocol(t *testing.T) { } func TestBareGitEndpointAddsLfsSuffix(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "git://example.com/foo/bar.git", })) @@ -263,7 +264,7 @@ func TestBareGitEndpointAddsLfsSuffix(t *testing.T) { } func TestLocalPathEndpointAddsDotGitDir(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "/local/path", })) e := finder.Endpoint("download", "") @@ -271,7 +272,7 @@ func TestLocalPathEndpointAddsDotGitDir(t *testing.T) { } func TestLocalPathEndpointPreservesDotGit(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "remote.origin.url": "/local/path.git", })) e := finder.Endpoint("download", "") @@ -294,7 +295,7 @@ func TestAccessConfig(t *testing.T) { } for value, expected := range tests { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "http://example.com", "lfs.http://example.com.access": value, "lfs.https://example.com.access": "bad", @@ -313,7 +314,7 @@ func TestAccessConfig(t *testing.T) { // Test again but with separate push url for value, expected := range tests { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "http://example.com", "lfs.pushurl": "http://examplepush.com", "lfs.http://example.com.access": value, @@ -340,7 +341,7 @@ func TestAccessAbsentConfig(t *testing.T) { } func TestSetAccess(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{})) + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{})) assert.Equal(t, NoneAccess, finder.AccessFor("http://example.com")) finder.SetAccess("http://example.com", NTLMAccess) @@ -348,7 +349,7 @@ func TestSetAccess(t *testing.T) { } func TestChangeAccess(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.http://example.com.access": "basic", })) @@ -358,7 +359,7 @@ func TestChangeAccess(t *testing.T) { } func TestDeleteAccessWithNone(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.http://example.com.access": "basic", })) @@ -368,7 +369,7 @@ func TestDeleteAccessWithNone(t *testing.T) { } func TestDeleteAccessWithEmptyString(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.http://example.com.access": "basic", })) @@ -379,11 +380,11 @@ func TestDeleteAccessWithEmptyString(t *testing.T) { type EndpointParsingTestCase struct { Given string - Expected Endpoint + Expected lfshttp.Endpoint } func (c *EndpointParsingTestCase) Assert(t *testing.T) { - finder := NewEndpointFinder(NewContext(nil, nil, map[string]string{ + finder := NewEndpointFinder(lfshttp.NewContext(nil, nil, map[string]string{ "url.https://github.com/.insteadof": "gh:", })) actual := finder.NewEndpoint(c.Given) @@ -396,7 +397,7 @@ func TestEndpointParsing(t *testing.T) { for desc, c := range map[string]EndpointParsingTestCase{ "simple bare ssh": { "git@github.com:git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "https://github.com/git-lfs/git-lfs.git", SshUserAndHost: "git@github.com", SshPath: "git-lfs/git-lfs.git", @@ -406,7 +407,7 @@ func TestEndpointParsing(t *testing.T) { }, "port bare ssh": { "[git@ssh.github.com:443]:git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "https://ssh.github.com/git-lfs/git-lfs.git", SshUserAndHost: "git@ssh.github.com", SshPath: "git-lfs/git-lfs.git", @@ -416,7 +417,7 @@ func TestEndpointParsing(t *testing.T) { }, "no user bare ssh": { "github.com:git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "https://github.com/git-lfs/git-lfs.git", SshUserAndHost: "github.com", SshPath: "git-lfs/git-lfs.git", @@ -426,7 +427,7 @@ func TestEndpointParsing(t *testing.T) { }, "bare word bare ssh": { "github:git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "https://github/git-lfs/git-lfs.git", SshUserAndHost: "github", SshPath: "git-lfs/git-lfs.git", @@ -436,7 +437,7 @@ func TestEndpointParsing(t *testing.T) { }, "insteadof alias": { "gh:git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "https://github.com/git-lfs/git-lfs.git", SshUserAndHost: "", SshPath: "", @@ -446,7 +447,7 @@ func TestEndpointParsing(t *testing.T) { }, "remote helper": { "remote::git-lfs/git-lfs.git", - Endpoint{ + lfshttp.Endpoint{ Url: "remote::git-lfs/git-lfs.git", SshUserAndHost: "", SshPath: "", @@ -458,3 +459,21 @@ func TestEndpointParsing(t *testing.T) { t.Run(desc, c.Assert) } } + +func TestNewEndpointFromCloneURLWithConfig(t *testing.T) { + expected := "https://foo/bar.git/info/lfs" + tests := []string{ + "https://foo/bar", + "https://foo/bar/", + "https://foo/bar.git", + "https://foo/bar.git/", + } + + finder := NewEndpointFinder(nil) + for _, actual := range tests { + e := finder.NewEndpointFromCloneURL(actual) + if e.Url != expected { + t.Errorf("%s returned bad endpoint url %s", actual, e.Url) + } + } +} diff --git a/lfsapi/endpoint_test.go b/lfsapi/endpoint_test.go deleted file mode 100644 index 5bf208bc..00000000 --- a/lfsapi/endpoint_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package lfsapi - -import ( - "testing" -) - -func TestNewEndpointFromCloneURLWithConfig(t *testing.T) { - expected := "https://foo/bar.git/info/lfs" - tests := []string{ - "https://foo/bar", - "https://foo/bar/", - "https://foo/bar.git", - "https://foo/bar.git/", - } - - finder := NewEndpointFinder(nil) - for _, actual := range tests { - e := finder.NewEndpointFromCloneURL(actual) - if e.Url != expected { - t.Errorf("%s returned bad endpoint url %s", actual, e.Url) - } - } -} diff --git a/lfsapi/lfsapi.go b/lfsapi/lfsapi.go index f512fd16..2c4daed3 100644 --- a/lfsapi/lfsapi.go +++ b/lfsapi/lfsapi.go @@ -1,69 +1,32 @@ package lfsapi import ( - "encoding/json" "fmt" - "io" - "net/http" - "regexp" "sync" "github.com/ThomsonReutersEikon/go-ntlm/ntlm" - "github.com/git-lfs/git-lfs/config" "github.com/git-lfs/git-lfs/errors" - "github.com/git-lfs/git-lfs/git" -) - -var ( - lfsMediaTypeRE = regexp.MustCompile(`\Aapplication/vnd\.git\-lfs\+json(;|\z)`) - jsonMediaTypeRE = regexp.MustCompile(`\Aapplication/json(;|\z)`) + "github.com/git-lfs/git-lfs/lfshttp" ) type Client struct { Endpoints EndpointFinder Credentials CredentialHelper - SSH SSHResolver Netrc NetrcFinder - DialTimeout int - KeepaliveTimeout int - TLSTimeout int - ConcurrentTransfers int - SkipSSLVerify bool - - Verbose bool - DebuggingVerbose bool - VerboseOut io.Writer - - hostClients map[string]*http.Client - clientMu sync.Mutex - ntlmSessions map[string]ntlm.ClientSession ntlmMu sync.Mutex - httpLogger *syncLogger - - LoggingStats bool // DEPRECATED - commandCredHelper *commandCredentialHelper askpassCredHelper *AskPassCredentialHelper cachingCredHelper *credentialCacher - gitEnv config.Environment - osEnv config.Environment - uc *config.URLConfig - sshTries int + client *lfshttp.Client } -type Context interface { - GitConfig() *git.Configuration - OSEnv() config.Environment - GitEnv() config.Environment -} - -func NewClient(ctx Context) (*Client, error) { +func NewClient(ctx lfshttp.Context) (*Client, error) { if ctx == nil { - ctx = NewContext(nil, nil, nil) + ctx = lfshttp.NewContext(nil, nil, nil) } gitEnv := ctx.GitEnv() @@ -73,30 +36,18 @@ func NewClient(ctx Context) (*Client, error) { return nil, errors.Wrap(err, fmt.Sprintf("bad netrc file %s", netrcfile)) } - cacheCreds := gitEnv.Bool("lfs.cachecredentials", true) - var sshResolver SSHResolver = &sshAuthClient{os: osEnv, git: gitEnv} - if cacheCreds { - sshResolver = withSSHCache(sshResolver) + httpClient, err := lfshttp.NewClient(ctx) + if err != nil { + return nil, errors.Wrap(err, fmt.Sprintf("error creating http client")) } c := &Client{ - Endpoints: NewEndpointFinder(ctx), - SSH: sshResolver, - Netrc: netrc, - DialTimeout: gitEnv.Int("lfs.dialtimeout", 0), - KeepaliveTimeout: gitEnv.Int("lfs.keepalive", 0), - TLSTimeout: gitEnv.Int("lfs.tlstimeout", 0), - ConcurrentTransfers: gitEnv.Int("lfs.concurrenttransfers", 3), - SkipSSLVerify: !gitEnv.Bool("http.sslverify", true) || osEnv.Bool("GIT_SSL_NO_VERIFY", false), - Verbose: osEnv.Bool("GIT_CURL_VERBOSE", false), - DebuggingVerbose: osEnv.Bool("LFS_DEBUG_HTTP", false), + Endpoints: NewEndpointFinder(ctx), + Netrc: netrc, commandCredHelper: &commandCredentialHelper{ SkipPrompt: osEnv.Bool("GIT_TERMINAL_PROMPT", false), }, - gitEnv: gitEnv, - osEnv: osEnv, - uc: config.NewURLConfig(gitEnv), - sshTries: gitEnv.Int("lfs.ssh.retries", 5), + client: httpClient, } askpass, ok := osEnv.Get("GIT_ASKPASS") @@ -112,117 +63,10 @@ func NewClient(ctx Context) (*Client, error) { } } + cacheCreds := gitEnv.Bool("lfs.cachecredentials", true) if cacheCreds { c.cachingCredHelper = newCredentialCacher() } return c, nil } - -func (c *Client) GitEnv() config.Environment { - return c.gitEnv -} - -func (c *Client) OSEnv() config.Environment { - return c.osEnv -} - -func IsDecodeTypeError(err error) bool { - _, ok := err.(*decodeTypeError) - return ok -} - -type decodeTypeError struct { - Type string -} - -func (e *decodeTypeError) TypeError() {} - -func (e *decodeTypeError) Error() string { - return fmt.Sprintf("Expected json type, got: %q", e.Type) -} - -func DecodeJSON(res *http.Response, obj interface{}) error { - ctype := res.Header.Get("Content-Type") - if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) { - return &decodeTypeError{Type: ctype} - } - - err := json.NewDecoder(res.Body).Decode(obj) - res.Body.Close() - - if err != nil { - return errors.Wrapf(err, "Unable to parse HTTP response for %s %s", res.Request.Method, res.Request.URL) - } - - return nil -} - -type testContext struct { - gitConfig *git.Configuration - osEnv config.Environment - gitEnv config.Environment -} - -func (c *testContext) GitConfig() *git.Configuration { - return c.gitConfig -} - -func (c *testContext) OSEnv() config.Environment { - return c.osEnv -} - -func (c *testContext) GitEnv() config.Environment { - return c.gitEnv -} - -func NewContext(gitConf *git.Configuration, osEnv, gitEnv map[string]string) Context { - c := &testContext{gitConfig: gitConf} - if c.gitConfig == nil { - c.gitConfig = git.NewConfig("", "") - } - if osEnv != nil { - c.osEnv = testEnv(osEnv) - } else { - c.osEnv = make(testEnv) - } - - if gitEnv != nil { - c.gitEnv = testEnv(gitEnv) - } else { - c.gitEnv = make(testEnv) - } - return c -} - -type testEnv map[string]string - -func (e testEnv) Get(key string) (v string, ok bool) { - v, ok = e[key] - return -} - -func (e testEnv) GetAll(key string) []string { - if v, ok := e.Get(key); ok { - return []string{v} - } - return make([]string, 0) -} - -func (e testEnv) Int(key string, def int) int { - s, _ := e.Get(key) - return config.Int(s, def) -} - -func (e testEnv) Bool(key string, def bool) bool { - s, _ := e.Get(key) - return config.Bool(s, def) -} - -func (e testEnv) All() map[string][]string { - m := make(map[string][]string) - for k, _ := range e { - m[k] = e.GetAll(k) - } - return m -} diff --git a/lfsapi/ntlm_test.go b/lfsapi/ntlm_test.go index 8dced99c..5884c2ce 100644 --- a/lfsapi/ntlm_test.go +++ b/lfsapi/ntlm_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/ThomsonReutersEikon/go-ntlm/ntlm" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -94,7 +95,7 @@ func TestNtlmAuth(t *testing.T) { require.Nil(t, err) credHelper := newMockCredentialHelper() - cli, err := NewClient(NewContext(nil, nil, map[string]string{ + cli, err := NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/ntlm", "lfs." + srv.URL + "/ntlm.access": "ntlm", })) diff --git a/lfshttp/auth.go b/lfshttp/auth.go new file mode 100644 index 00000000..c84f7843 --- /dev/null +++ b/lfshttp/auth.go @@ -0,0 +1,9 @@ +package lfshttp + +import ( + "net/http" +) + +func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) { + return c.do(req, remote, via) +} diff --git a/lfshttp/body.go b/lfshttp/body.go new file mode 100644 index 00000000..6666377c --- /dev/null +++ b/lfshttp/body.go @@ -0,0 +1,39 @@ +package lfshttp + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "strconv" +) + +type ReadSeekCloser interface { + io.Seeker + io.ReadCloser +} + +func MarshalToRequest(req *http.Request, obj interface{}) error { + by, err := json.Marshal(obj) + if err != nil { + return err + } + + clen := len(by) + req.Header.Set("Content-Length", strconv.Itoa(clen)) + req.ContentLength = int64(clen) + req.Body = NewByteBody(by) + return nil +} + +func NewByteBody(by []byte) ReadSeekCloser { + return &closingByteReader{Reader: bytes.NewReader(by)} +} + +type closingByteReader struct { + *bytes.Reader +} + +func (r *closingByteReader) Close() error { + return nil +} diff --git a/lfsapi/certs.go b/lfshttp/certs.go similarity index 99% rename from lfsapi/certs.go rename to lfshttp/certs.go index d245ea77..940d0205 100644 --- a/lfsapi/certs.go +++ b/lfshttp/certs.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "crypto/tls" diff --git a/lfsapi/certs_darwin.go b/lfshttp/certs_darwin.go similarity index 99% rename from lfsapi/certs_darwin.go rename to lfshttp/certs_darwin.go index 513678fc..68b0efe2 100644 --- a/lfsapi/certs_darwin.go +++ b/lfshttp/certs_darwin.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "crypto/x509" diff --git a/lfsapi/certs_freebsd.go b/lfshttp/certs_freebsd.go similarity index 91% rename from lfsapi/certs_freebsd.go rename to lfshttp/certs_freebsd.go index 989f4ee9..4c8e80a5 100644 --- a/lfsapi/certs_freebsd.go +++ b/lfshttp/certs_freebsd.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import "crypto/x509" diff --git a/lfsapi/certs_linux.go b/lfshttp/certs_linux.go similarity index 91% rename from lfsapi/certs_linux.go rename to lfshttp/certs_linux.go index 989f4ee9..4c8e80a5 100644 --- a/lfsapi/certs_linux.go +++ b/lfshttp/certs_linux.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import "crypto/x509" diff --git a/lfsapi/certs_openbsd.go b/lfshttp/certs_openbsd.go similarity index 91% rename from lfsapi/certs_openbsd.go rename to lfshttp/certs_openbsd.go index 989f4ee9..4c8e80a5 100644 --- a/lfsapi/certs_openbsd.go +++ b/lfshttp/certs_openbsd.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import "crypto/x509" diff --git a/lfsapi/certs_test.go b/lfshttp/certs_test.go similarity index 99% rename from lfsapi/certs_test.go rename to lfshttp/certs_test.go index 2551723b..eb3f0374 100644 --- a/lfsapi/certs_test.go +++ b/lfshttp/certs_test.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "fmt" diff --git a/lfsapi/certs_windows.go b/lfshttp/certs_windows.go similarity index 92% rename from lfsapi/certs_windows.go rename to lfshttp/certs_windows.go index d2f8a890..5fb3b808 100644 --- a/lfsapi/certs_windows.go +++ b/lfshttp/certs_windows.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import "crypto/x509" diff --git a/lfshttp/client.go b/lfshttp/client.go new file mode 100644 index 00000000..42ce6c84 --- /dev/null +++ b/lfshttp/client.go @@ -0,0 +1,533 @@ +package lfshttp + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/textproto" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/git-lfs/git-lfs/config" + "github.com/git-lfs/git-lfs/errors" + "github.com/git-lfs/git-lfs/tools" + "github.com/rubyist/tracerx" +) + +const MediaType = "application/vnd.git-lfs+json; charset=utf-8" + +var ( + UserAgent = "git-lfs" + httpRE = regexp.MustCompile(`\Ahttps?://`) +) + +var hintFileUrl = strings.TrimSpace(` +hint: The remote resolves to a file:// URL, which can only work with a +hint: standalone transfer agent. See section "Using a Custom Transfer Type +hint: without the API server" in custom-transfers.md for details. +`) + +type Client struct { + SSH SSHResolver + + DialTimeout int + KeepaliveTimeout int + TLSTimeout int + ConcurrentTransfers int + SkipSSLVerify bool + + Verbose bool + DebuggingVerbose bool + VerboseOut io.Writer + + hostClients map[string]*http.Client + clientMu sync.Mutex + + httpLogger *syncLogger + + gitEnv config.Environment + osEnv config.Environment + uc *config.URLConfig + + sshTries int +} + +func NewClient(ctx Context) (*Client, error) { + if ctx == nil { + ctx = NewContext(nil, nil, nil) + } + + gitEnv := ctx.GitEnv() + osEnv := ctx.OSEnv() + + cacheCreds := gitEnv.Bool("lfs.cachecredentials", true) + var sshResolver SSHResolver = &sshAuthClient{os: osEnv, git: gitEnv} + if cacheCreds { + sshResolver = withSSHCache(sshResolver) + } + + c := &Client{ + SSH: sshResolver, + DialTimeout: gitEnv.Int("lfs.dialtimeout", 0), + KeepaliveTimeout: gitEnv.Int("lfs.keepalive", 0), + TLSTimeout: gitEnv.Int("lfs.tlstimeout", 0), + ConcurrentTransfers: gitEnv.Int("lfs.concurrenttransfers", 3), + SkipSSLVerify: !gitEnv.Bool("http.sslverify", true) || osEnv.Bool("GIT_SSL_NO_VERIFY", false), + Verbose: osEnv.Bool("GIT_CURL_VERBOSE", false), + DebuggingVerbose: osEnv.Bool("LFS_DEBUG_HTTP", false), + gitEnv: gitEnv, + osEnv: osEnv, + uc: config.NewURLConfig(gitEnv), + sshTries: gitEnv.Int("lfs.ssh.retries", 5), + } + + return c, nil +} + +func (c *Client) GitEnv() config.Environment { + return c.gitEnv +} + +func (c *Client) OSEnv() config.Environment { + return c.osEnv +} + +func (c *Client) URLConfig() *config.URLConfig { + return c.uc +} + +func (c *Client) NewRequest(method string, e Endpoint, suffix string, body interface{}) (*http.Request, error) { + if strings.HasPrefix(e.Url, "file://") { + // Initial `\n` to avoid overprinting `Downloading LFS...`. + fmt.Fprintf(os.Stderr, "\n%s\n", hintFileUrl) + } + + sshRes, err := c.sshResolveWithRetries(e, method) + if err != nil { + return nil, err + } + + prefix := e.Url + if len(sshRes.Href) > 0 { + prefix = sshRes.Href + } + + if !httpRE.MatchString(prefix) { + urlfragment := strings.SplitN(prefix, "?", 2)[0] + return nil, fmt.Errorf("missing protocol: %q", urlfragment) + } + + req, err := http.NewRequest(method, joinURL(prefix, suffix), nil) + if err != nil { + return req, err + } + + for key, value := range sshRes.Header { + req.Header.Set(key, value) + } + req.Header.Set("Accept", MediaType) + + if body != nil { + if merr := MarshalToRequest(req, body); merr != nil { + return req, merr + } + req.Header.Set("Content-Type", MediaType) + } + + return req, err +} + +const slash = "/" + +func joinURL(prefix, suffix string) string { + if strings.HasSuffix(prefix, slash) { + return prefix + suffix + } + return prefix + slash + suffix +} + +// Do sends an HTTP request to get an HTTP response. It wraps net/http, adding +// extra headers, redirection handling, and error reporting. +func (c *Client) Do(req *http.Request) (*http.Response, error) { + req.Header = c.ExtraHeadersFor(req) + + return c.do(req, "", nil) +} + +// do performs an *http.Request respecting redirects, and handles the response +// as defined in c.handleResponse. Notably, it does not alter the headers for +// the request argument in any way. +func (c *Client) do(req *http.Request, remote string, via []*http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", UserAgent) + + res, err := c.doWithRedirects(c.httpClient(req.Host), req, remote, via) + if err != nil { + return res, err + } + + return res, c.handleResponse(res) +} + +// Close closes any resources that this client opened. +func (c *Client) Close() error { + return c.httpLogger.Close() +} + +func (c *Client) sshResolveWithRetries(e Endpoint, method string) (*sshAuthResponse, error) { + var sshRes sshAuthResponse + var err error + + requests := tools.MaxInt(0, c.sshTries) + 1 + for i := 0; i < requests; i++ { + sshRes, err = c.SSH.Resolve(e, method) + if err == nil { + return &sshRes, nil + } + + tracerx.Printf( + "ssh: %s failed, error: %s, message: %s (try: %d/%d)", + e.SshUserAndHost, err.Error(), sshRes.Message, i, + requests, + ) + } + + if len(sshRes.Message) > 0 { + return nil, errors.Wrap(err, sshRes.Message) + } + return nil, err +} + +func (c *Client) ExtraHeadersFor(req *http.Request) http.Header { + extraHeaders := c.extraHeaders(req.URL) + if len(extraHeaders) == 0 { + return req.Header + } + + copy := make(http.Header, len(req.Header)) + for k, vs := range req.Header { + copy[k] = vs + } + + for k, vs := range extraHeaders { + for _, v := range vs { + copy[k] = append(copy[k], v) + } + } + return copy +} + +func (c *Client) extraHeaders(u *url.URL) map[string][]string { + hdrs := c.uc.GetAll("http", u.String(), "extraHeader") + m := make(map[string][]string, len(hdrs)) + + for _, hdr := range hdrs { + parts := strings.SplitN(hdr, ":", 2) + if len(parts) < 2 { + continue + } + + k, v := parts[0], strings.TrimSpace(parts[1]) + // If header keys are given in non-canonicalized form (e.g., + // "AUTHORIZATION" as opposed to "Authorization") they will not + // be returned in calls to net/http.Header.Get(). + // + // So, we avoid this problem by first canonicalizing header keys + // for extra headers. + k = textproto.CanonicalMIMEHeaderKey(k) + + m[k] = append(m[k], v) + } + return m +} + +func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Response, error) { + tracedReq, err := c.traceRequest(req) + if err != nil { + return nil, err + } + + var retries int + if n, ok := Retries(req); ok { + retries = n + } else { + retries = defaultRequestRetries + } + + var res *http.Response + + requests := tools.MaxInt(0, retries) + 1 + for i := 0; i < requests; i++ { + res, err = cli.Do(req) + if err == nil { + break + } + + if seek, ok := req.Body.(io.Seeker); ok { + seek.Seek(0, io.SeekStart) + } + + c.traceResponse(req, tracedReq, nil) + } + + if err != nil { + c.traceResponse(req, tracedReq, nil) + return nil, err + } + + if res == nil { + return nil, nil + } + + c.traceResponse(req, tracedReq, res) + + if res.StatusCode != 301 && + res.StatusCode != 302 && + res.StatusCode != 303 && + res.StatusCode != 307 && + res.StatusCode != 308 { + + // Above are the list of 3xx status codes that we know + // how to handle below. If the status code contained in + // the HTTP response was none of them, return the (res, + // err) tuple as-is, otherwise handle the redirect. + return res, err + } + + redirectTo := res.Header.Get("Location") + locurl, err := url.Parse(redirectTo) + if err == nil && !locurl.IsAbs() { + locurl = req.URL.ResolveReference(locurl) + redirectTo = locurl.String() + } + + via = append(via, req) + if len(via) >= 3 { + return res, errors.New("too many redirects") + } + + redirectedReq, err := newRequestForRetry(req, redirectTo) + if err != nil { + return res, err + } + + if len(req.Header.Get("Authorization")) > 0 { + // If the original request was authenticated (noted by the + // presence of the Authorization header), then recur through + // doWithAuth, retaining the requests via but only after + // authenticating the redirected request. + return c.doWithAuth(remote, redirectedReq, via) + } + return c.doWithRedirects(cli, redirectedReq, remote, via) +} + +func (c *Client) httpClient(host string) *http.Client { + c.clientMu.Lock() + defer c.clientMu.Unlock() + + if c.gitEnv == nil { + c.gitEnv = make(testEnv) + } + + if c.osEnv == nil { + c.osEnv = make(testEnv) + } + + if c.hostClients == nil { + c.hostClients = make(map[string]*http.Client) + } + + if client, ok := c.hostClients[host]; ok { + return client + } + + concurrentTransfers := c.ConcurrentTransfers + if concurrentTransfers < 1 { + concurrentTransfers = 8 + } + + dialtime := c.DialTimeout + if dialtime < 1 { + dialtime = 30 + } + + keepalivetime := c.KeepaliveTimeout + if keepalivetime < 1 { + keepalivetime = 1800 + } + + tlstime := c.TLSTimeout + if tlstime < 1 { + tlstime = 30 + } + + tr := &http.Transport{ + Proxy: proxyFromClient(c), + TLSHandshakeTimeout: time.Duration(tlstime) * time.Second, + MaxIdleConnsPerHost: concurrentTransfers, + } + + activityTimeout := 30 + if v, ok := c.uc.Get("lfs", fmt.Sprintf("https://%v", host), "activitytimeout"); ok { + if i, err := strconv.Atoi(v); err == nil { + activityTimeout = i + } else { + activityTimeout = 0 + } + } + + dialer := &net.Dialer{ + Timeout: time.Duration(dialtime) * time.Second, + KeepAlive: time.Duration(keepalivetime) * time.Second, + DualStack: true, + } + + if activityTimeout > 0 { + activityDuration := time.Duration(activityTimeout) * time.Second + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + c, err := dialer.DialContext(ctx, network, addr) + if c == nil { + return c, err + } + if tc, ok := c.(*net.TCPConn); ok { + tc.SetKeepAlive(true) + tc.SetKeepAlivePeriod(dialer.KeepAlive) + } + return &deadlineConn{Timeout: activityDuration, Conn: c}, err + } + } else { + tr.DialContext = dialer.DialContext + } + + tr.TLSClientConfig = &tls.Config{} + + if isClientCertEnabledForHost(c, host) { + tracerx.Printf("http: client cert for %s", host) + tr.TLSClientConfig.Certificates = []tls.Certificate{getClientCertForHost(c, host)} + tr.TLSClientConfig.BuildNameToCertificate() + } + + if isCertVerificationDisabledForHost(c, host) { + tr.TLSClientConfig.InsecureSkipVerify = true + } else { + tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host) + } + + httpClient := &http.Client{ + Transport: tr, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } + + c.hostClients[host] = httpClient + if c.VerboseOut == nil { + c.VerboseOut = os.Stderr + } + + return httpClient +} + +func (c *Client) CurrentUser() (string, string) { + userName, _ := c.gitEnv.Get("user.name") + userEmail, _ := c.gitEnv.Get("user.email") + return userName, userEmail +} + +func newRequestForRetry(req *http.Request, location string) (*http.Request, error) { + newReq, err := http.NewRequest(req.Method, location, nil) + if err != nil { + return nil, err + } + + if req.URL.Scheme == "https" && newReq.URL.Scheme == "http" { + return nil, errors.New("lfsapi/client: refusing insecure redirect, https->http") + } + + sameHost := req.URL.Host == newReq.URL.Host + for key := range req.Header { + if key == "Authorization" { + if !sameHost { + continue + } + } + newReq.Header.Set(key, req.Header.Get(key)) + } + + oldestURL := strings.SplitN(req.URL.String(), "?", 2)[0] + newURL := strings.SplitN(newReq.URL.String(), "?", 2)[0] + tracerx.Printf("api: redirect %s %s to %s", req.Method, oldestURL, newURL) + + // This body will have already been rewound from a call to + // lfsapi.Client.traceRequest(). + newReq.Body = req.Body + newReq.ContentLength = req.ContentLength + + // Copy the request's context.Context, if any. + newReq = newReq.WithContext(req.Context()) + + return newReq, nil +} + +type deadlineConn struct { + Timeout time.Duration + net.Conn +} + +func (c *deadlineConn) Read(b []byte) (int, error) { + if err := c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { + return 0, err + } + return c.Conn.Read(b) +} + +func (c *deadlineConn) Write(b []byte) (int, error) { + if err := c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { + return 0, err + } + + return c.Conn.Write(b) +} + +func init() { + UserAgent = config.VersionDesc +} + +type testEnv map[string]string + +func (e testEnv) Get(key string) (v string, ok bool) { + v, ok = e[key] + return +} + +func (e testEnv) GetAll(key string) []string { + if v, ok := e.Get(key); ok { + return []string{v} + } + return make([]string, 0) +} + +func (e testEnv) Int(key string, def int) int { + s, _ := e.Get(key) + return config.Int(s, def) +} + +func (e testEnv) Bool(key string, def bool) bool { + s, _ := e.Get(key) + return config.Bool(s, def) +} + +func (e testEnv) All() map[string][]string { + m := make(map[string][]string) + for k, _ := range e { + m[k] = e.GetAll(k) + } + return m +} diff --git a/lfsapi/client_test.go b/lfshttp/client_test.go similarity index 77% rename from lfsapi/client_test.go rename to lfshttp/client_test.go index f5500f80..e51e32b9 100644 --- a/lfsapi/client_test.go +++ b/lfshttp/client_test.go @@ -1,13 +1,11 @@ -package lfsapi +package lfshttp import ( - "encoding/base64" "encoding/json" "fmt" "net" "net/http" "net/http/httptest" - "strings" "sync/atomic" "testing" @@ -177,80 +175,6 @@ func TestClientRedirect(t *testing.T) { assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http") } -func TestClientRedirectReauthenticate(t *testing.T) { - var srv1, srv2 *httptest.Server - var called1, called2 uint32 - var creds1, creds2 Creds - - srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddUint32(&called1, 1) - - if hdr := r.Header.Get("Authorization"); len(hdr) > 0 { - parts := strings.SplitN(hdr, " ", 2) - typ, b64 := parts[0], parts[1] - - auth, err := base64.URLEncoding.DecodeString(b64) - assert.Nil(t, err) - assert.Equal(t, "Basic", typ) - assert.Equal(t, "user1:pass1", string(auth)) - - http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently) - return - } - w.WriteHeader(http.StatusUnauthorized) - })) - - srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddUint32(&called2, 1) - - parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) - typ, b64 := parts[0], parts[1] - - auth, err := base64.URLEncoding.DecodeString(b64) - assert.Nil(t, err) - assert.Equal(t, "Basic", typ) - assert.Equal(t, "user2:pass2", string(auth)) - })) - - // Change the URL of srv2 to make it appears as if it is a different - // host. - srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1) - - creds1 = Creds(map[string]string{ - "protocol": "http", - "host": strings.TrimPrefix(srv1.URL, "http://"), - - "username": "user1", - "password": "pass1", - }) - creds2 = Creds(map[string]string{ - "protocol": "http", - "host": strings.TrimPrefix(srv2.URL, "http://"), - - "username": "user2", - "password": "pass2", - }) - - defer srv1.Close() - defer srv2.Close() - - c, err := NewClient(NewContext(nil, nil, nil)) - creds := newCredentialCacher() - creds.Approve(creds1) - creds.Approve(creds2) - c.Credentials = creds - - req, err := http.NewRequest("GET", srv1.URL, nil) - require.Nil(t, err) - - _, err = c.DoWithAuth("", req) - assert.Nil(t, err) - - // called1 is 2 since LFS tries an unauthenticated request first - assert.EqualValues(t, 2, called1) - assert.EqualValues(t, 1, called2) -} - func TestNewClient(t *testing.T) { c, err := NewClient(NewContext(nil, nil, map[string]string{ "lfs.dialtimeout": "151", @@ -323,12 +247,10 @@ func TestNewRequest(t *testing.T) { } for _, test := range tests { - c, err := NewClient(NewContext(nil, nil, map[string]string{ - "lfs.url": test[0], - })) + c, err := NewClient(NewContext(nil, nil, nil)) require.Nil(t, err) - req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), test[1], nil) + req, err := c.NewRequest("POST", Endpoint{Url: test[0]}, test[1], nil) require.Nil(t, err) assert.Equal(t, "POST", req.Method) assert.Equal(t, test[2], req.URL.String(), fmt.Sprintf("endpoint: %s, suffix: %s, expected: %s", test[0], test[1], test[2])) @@ -336,15 +258,13 @@ func TestNewRequest(t *testing.T) { } func TestNewRequestWithBody(t *testing.T) { - c, err := NewClient(NewContext(nil, nil, map[string]string{ - "lfs.url": "https://example.com", - })) + c, err := NewClient(NewContext(nil, nil, nil)) require.Nil(t, err) body := struct { Test string }{Test: "test"} - req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), "body", body) + req, err := c.NewRequest("POST", Endpoint{Url: "https://example.com"}, "body", body) require.Nil(t, err) assert.NotNil(t, req.Body) diff --git a/lfsapi/endpoint.go b/lfshttp/endpoint.go similarity index 82% rename from lfsapi/endpoint.go rename to lfshttp/endpoint.go index e961a09b..4b96c9d5 100644 --- a/lfsapi/endpoint.go +++ b/lfshttp/endpoint.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "fmt" @@ -31,39 +31,8 @@ func endpointOperation(e Endpoint, method string) string { } } -// endpointFromBareSshUrl constructs a new endpoint from a bare SSH URL: -// -// user@host.com:path/to/repo.git or -// [user@host.com:port]:path/to/repo.git -// -func endpointFromBareSshUrl(rawurl string) Endpoint { - parts := strings.Split(rawurl, ":") - partsLen := len(parts) - if partsLen < 2 { - return Endpoint{Url: rawurl} - } - - // Treat presence of ':' as a bare URL - var newPath string - if len(parts) > 2 { // port included; really should only ever be 3 parts - // Correctly handle [host:port]:path URLs - parts[0] = strings.TrimPrefix(parts[0], "[") - parts[1] = strings.TrimSuffix(parts[1], "]") - newPath = fmt.Sprintf("%v:%v", parts[0], strings.Join(parts[1:], "/")) - } else { - newPath = strings.Join(parts, "/") - } - newrawurl := fmt.Sprintf("ssh://%v", newPath) - newu, err := url.Parse(newrawurl) - if err != nil { - return Endpoint{Url: UrlUnknown} - } - - return endpointFromSshUrl(newu) -} - -// endpointFromSshUrl constructs a new endpoint from an ssh:// URL -func endpointFromSshUrl(u *url.URL) Endpoint { +// EndpointFromSshUrl constructs a new endpoint from an ssh:// URL +func EndpointFromSshUrl(u *url.URL) Endpoint { var endpoint Endpoint // Pull out port now, we need it separately for SSH regex := regexp.MustCompile(`^([^\:]+)(?:\:(\d+))?$`) @@ -100,18 +69,44 @@ func endpointFromSshUrl(u *url.URL) Endpoint { return endpoint } +// EndpointFromBareSshUrl constructs a new endpoint from a bare SSH URL: +// +// user@host.com:path/to/repo.git or +// [user@host.com:port]:path/to/repo.git +// +func EndpointFromBareSshUrl(rawurl string) Endpoint { + parts := strings.Split(rawurl, ":") + partsLen := len(parts) + if partsLen < 2 { + return Endpoint{Url: rawurl} + } + + // Treat presence of ':' as a bare URL + var newPath string + if len(parts) > 2 { // port included; really should only ever be 3 parts + // Correctly handle [host:port]:path URLs + parts[0] = strings.TrimPrefix(parts[0], "[") + parts[1] = strings.TrimSuffix(parts[1], "]") + newPath = fmt.Sprintf("%v:%v", parts[0], strings.Join(parts[1:], "/")) + } else { + newPath = strings.Join(parts, "/") + } + newrawurl := fmt.Sprintf("ssh://%v", newPath) + newu, err := url.Parse(newrawurl) + if err != nil { + return Endpoint{Url: UrlUnknown} + } + + return EndpointFromSshUrl(newu) +} + // Construct a new endpoint from a HTTP URL -func endpointFromHttpUrl(u *url.URL) Endpoint { +func EndpointFromHttpUrl(u *url.URL) Endpoint { // just pass this straight through return Endpoint{Url: u.String()} } -func endpointFromGitUrl(u *url.URL, e *endpointGitFinder) Endpoint { - u.Scheme = e.gitProtocol - return Endpoint{Url: u.String()} -} - -func endpointFromLocalPath(path string) Endpoint { +func EndpointFromLocalPath(path string) Endpoint { if !strings.HasSuffix(path, ".git") { path = fmt.Sprintf("%s/.git", path) } diff --git a/lfsapi/errors.go b/lfshttp/errors.go similarity index 99% rename from lfsapi/errors.go rename to lfshttp/errors.go index 916633b1..cedb03b5 100644 --- a/lfsapi/errors.go +++ b/lfshttp/errors.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "fmt" diff --git a/lfshttp/lfshttp.go b/lfshttp/lfshttp.go new file mode 100644 index 00000000..cf241212 --- /dev/null +++ b/lfshttp/lfshttp.go @@ -0,0 +1,91 @@ +package lfshttp + +import ( + "encoding/json" + "fmt" + "net/http" + "regexp" + + "github.com/git-lfs/git-lfs/config" + "github.com/git-lfs/git-lfs/errors" + "github.com/git-lfs/git-lfs/git" +) + +var ( + lfsMediaTypeRE = regexp.MustCompile(`\Aapplication/vnd\.git\-lfs\+json(;|\z)`) + jsonMediaTypeRE = regexp.MustCompile(`\Aapplication/json(;|\z)`) +) + +type Context interface { + GitConfig() *git.Configuration + OSEnv() config.Environment + GitEnv() config.Environment +} + +func NewContext(gitConf *git.Configuration, osEnv, gitEnv map[string]string) Context { + c := &testContext{gitConfig: gitConf} + if c.gitConfig == nil { + c.gitConfig = git.NewConfig("", "") + } + if osEnv != nil { + c.osEnv = testEnv(osEnv) + } else { + c.osEnv = make(testEnv) + } + + if gitEnv != nil { + c.gitEnv = testEnv(gitEnv) + } else { + c.gitEnv = make(testEnv) + } + return c +} + +type testContext struct { + gitConfig *git.Configuration + osEnv config.Environment + gitEnv config.Environment +} + +func (c *testContext) GitConfig() *git.Configuration { + return c.gitConfig +} + +func (c *testContext) OSEnv() config.Environment { + return c.osEnv +} + +func (c *testContext) GitEnv() config.Environment { + return c.gitEnv +} + +func IsDecodeTypeError(err error) bool { + _, ok := err.(*decodeTypeError) + return ok +} + +type decodeTypeError struct { + Type string +} + +func (e *decodeTypeError) TypeError() {} + +func (e *decodeTypeError) Error() string { + return fmt.Sprintf("Expected json type, got: %q", e.Type) +} + +func DecodeJSON(res *http.Response, obj interface{}) error { + ctype := res.Header.Get("Content-Type") + if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) { + return &decodeTypeError{Type: ctype} + } + + err := json.NewDecoder(res.Body).Decode(obj) + res.Body.Close() + + if err != nil { + return errors.Wrapf(err, "Unable to parse HTTP response for %s %s", res.Request.Method, res.Request.URL) + } + + return nil +} diff --git a/lfsapi/proxy.go b/lfshttp/proxy.go similarity index 99% rename from lfsapi/proxy.go rename to lfshttp/proxy.go index c7410774..e697269f 100644 --- a/lfsapi/proxy.go +++ b/lfshttp/proxy.go @@ -1,13 +1,12 @@ -package lfsapi +package lfshttp import ( + "fmt" "net/http" "net/url" "strings" "github.com/git-lfs/git-lfs/config" - - "fmt" ) // Logic is copied, with small changes, from "net/http".ProxyFromEnvironment in the go std lib. diff --git a/lfsapi/proxy_test.go b/lfshttp/proxy_test.go similarity index 99% rename from lfsapi/proxy_test.go rename to lfshttp/proxy_test.go index ad98419b..dac61168 100644 --- a/lfsapi/proxy_test.go +++ b/lfshttp/proxy_test.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "net/http" diff --git a/lfsapi/retries.go b/lfshttp/retries.go similarity index 98% rename from lfsapi/retries.go rename to lfshttp/retries.go index df4f2503..8ee83178 100644 --- a/lfsapi/retries.go +++ b/lfshttp/retries.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "context" diff --git a/lfsapi/retries_test.go b/lfshttp/retries_test.go similarity index 99% rename from lfsapi/retries_test.go rename to lfshttp/retries_test.go index 129fb857..fd3cea8a 100644 --- a/lfsapi/retries_test.go +++ b/lfshttp/retries_test.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "encoding/json" diff --git a/lfsapi/ssh.go b/lfshttp/ssh.go similarity index 99% rename from lfsapi/ssh.go rename to lfshttp/ssh.go index 35495d7d..5fe718a3 100644 --- a/lfsapi/ssh.go +++ b/lfshttp/ssh.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "bytes" diff --git a/lfsapi/ssh_test.go b/lfshttp/ssh_test.go similarity index 93% rename from lfsapi/ssh_test.go rename to lfshttp/ssh_test.go index 86576733..18a0ee0a 100644 --- a/lfsapi/ssh_test.go +++ b/lfshttp/ssh_test.go @@ -1,12 +1,12 @@ -package lfsapi +package lfshttp import ( - "errors" "net/url" "path/filepath" "testing" "time" + "github.com/git-lfs/git-lfs/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -216,7 +216,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) { cli, err := NewClient(nil) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPath = "user/repo" @@ -262,7 +262,7 @@ func TestSSHGetExeAndArgsSsh(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -277,7 +277,7 @@ func TestSSHGetExeAndArgsSshCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" @@ -295,7 +295,7 @@ func TestSSHGetExeAndArgsPlink(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -312,7 +312,7 @@ func TestSSHGetExeAndArgsPlinkCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" @@ -330,7 +330,7 @@ func TestSSHGetExeAndArgsTortoisePlink(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -347,7 +347,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" @@ -363,7 +363,7 @@ func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -377,7 +377,7 @@ func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -391,7 +391,7 @@ func TestSSHGetExeAndArgsSshCommandArgsWithMixedQuotes(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -405,7 +405,7 @@ func TestSSHGetExeAndArgsSshCommandCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" @@ -423,7 +423,7 @@ func TestSSHGetLFSExeAndArgsWithCustomSSH(t *testing.T) { u, err := url.Parse("ssh://git@host.com:12345/repo") require.Nil(t, err) - e := endpointFromSshUrl(u) + e := EndpointFromSshUrl(u) t.Logf("ENDPOINT: %+v", e) assert.Equal(t, "12345", e.SshPort) assert.Equal(t, "git@host.com", e.SshUserAndHost) @@ -442,7 +442,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHost(t *testing.T) { require.Nil(t, err) assert.Equal(t, "-oProxyCommand=gnome-calculator", u.Host) - e := endpointFromSshUrl(u) + e := EndpointFromSshUrl(u) t.Logf("ENDPOINT: %+v", e) assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "repo", e.SshPath) @@ -462,7 +462,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHostWithCustomSSH(t *testing.T) { require.Nil(t, err) assert.Equal(t, "--oProxyCommand=gnome-calculator", u.Host) - e := endpointFromSshUrl(u) + e := EndpointFromSshUrl(u) t.Logf("ENDPOINT: %+v", e) assert.Equal(t, "--oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "repo", e.SshPath) @@ -480,7 +480,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsHost(t *testing.T) { require.Nil(t, err) assert.Equal(t, "-oProxyCommand=gnome-calculator", u.Host) - e := endpointFromSshUrl(u) + e := EndpointFromSshUrl(u) t.Logf("ENDPOINT: %+v", e) assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "", e.SshPath) @@ -499,7 +499,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsPath(t *testing.T) { require.Nil(t, err) assert.Equal(t, "git-host.com", u.Host) - e := endpointFromSshUrl(u) + e := EndpointFromSshUrl(u) t.Logf("ENDPOINT: %+v", e) assert.Equal(t, "git@git-host.com", e.SshUserAndHost) assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshPath) @@ -511,22 +511,22 @@ func TestSSHGetExeAndArgsInvalidOptionsAsPath(t *testing.T) { } func TestParseBareSSHUrl(t *testing.T) { - e := endpointFromBareSshUrl("git@git-host.com:repo.git") + e := EndpointFromBareSshUrl("git@git-host.com:repo.git") t.Logf("endpoint: %+v", e) assert.Equal(t, "git@git-host.com", e.SshUserAndHost) assert.Equal(t, "repo.git", e.SshPath) - e = endpointFromBareSshUrl("git@git-host.com/should-be-a-colon.git") + e = EndpointFromBareSshUrl("git@git-host.com/should-be-a-colon.git") t.Logf("endpoint: %+v", e) assert.Equal(t, "", e.SshUserAndHost) assert.Equal(t, "", e.SshPath) - e = endpointFromBareSshUrl("-oProxyCommand=gnome-calculator") + e = EndpointFromBareSshUrl("-oProxyCommand=gnome-calculator") t.Logf("endpoint: %+v", e) assert.Equal(t, "", e.SshUserAndHost) assert.Equal(t, "", e.SshPath) - e = endpointFromBareSshUrl("git@git-host.com:-oProxyCommand=gnome-calculator") + e = EndpointFromBareSshUrl("git@git-host.com:-oProxyCommand=gnome-calculator") t.Logf("endpoint: %+v", e) assert.Equal(t, "git@git-host.com", e.SshUserAndHost) assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshPath) @@ -540,7 +540,7 @@ func TestSSHGetExeAndArgsPlinkCommand(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -556,7 +556,7 @@ func TestSSHGetExeAndArgsPlinkCommandCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" @@ -573,7 +573,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommand(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) @@ -589,7 +589,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommandCustomPort(t *testing.T) { }, nil)) require.Nil(t, err) - endpoint := cli.Endpoints.Endpoint("download", "") + endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" diff --git a/lfsapi/stats.go b/lfshttp/stats.go similarity index 99% rename from lfsapi/stats.go rename to lfshttp/stats.go index 9d0b42d8..34afa1b7 100644 --- a/lfsapi/stats.go +++ b/lfshttp/stats.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "context" diff --git a/lfsapi/stats_test.go b/lfshttp/stats_test.go similarity index 99% rename from lfsapi/stats_test.go rename to lfshttp/stats_test.go index e8aca1eb..f1d9fd5f 100644 --- a/lfsapi/stats_test.go +++ b/lfshttp/stats_test.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "bytes" diff --git a/lfsapi/verbose.go b/lfshttp/verbose.go similarity index 99% rename from lfsapi/verbose.go rename to lfshttp/verbose.go index 2dcab78e..9ab59a69 100644 --- a/lfsapi/verbose.go +++ b/lfshttp/verbose.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "bufio" diff --git a/lfsapi/verbose_test.go b/lfshttp/verbose_test.go similarity index 99% rename from lfsapi/verbose_test.go rename to lfshttp/verbose_test.go index f0ed3129..c44b745a 100644 --- a/lfsapi/verbose_test.go +++ b/lfshttp/verbose_test.go @@ -1,4 +1,4 @@ -package lfsapi +package lfshttp import ( "bytes" diff --git a/locking/api.go b/locking/api.go index a0fac45a..80cb7c61 100644 --- a/locking/api.go +++ b/locking/api.go @@ -7,6 +7,7 @@ import ( "github.com/git-lfs/git-lfs/git" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" ) type lockClient struct { @@ -56,14 +57,14 @@ func (c *lockClient) Lock(remote string, lockReq *lockRequest) (*lockResponse, * return nil, nil, err } - req = c.LogRequest(req, "lfs.locks.lock") + req = c.Client.LogRequest(req, "lfs.locks.lock") res, err := c.DoWithAuth(remote, req) if err != nil { return nil, res, err } lockRes := &lockResponse{} - return lockRes, res, lfsapi.DecodeJSON(res, lockRes) + return lockRes, res, lfshttp.DecodeJSON(res, lockRes) } // UnlockRequest encapsulates the data sent in an API request to remove a lock. @@ -101,14 +102,14 @@ func (c *lockClient) Unlock(ref *git.Ref, remote, id string, force bool) (*unloc return nil, nil, err } - req = c.LogRequest(req, "lfs.locks.unlock") + req = c.Client.LogRequest(req, "lfs.locks.unlock") res, err := c.DoWithAuth(remote, req) if err != nil { return nil, res, err } unlockRes := &unlockResponse{} - err = lfsapi.DecodeJSON(res, unlockRes) + err = lfshttp.DecodeJSON(res, unlockRes) return unlockRes, res, err } @@ -191,7 +192,7 @@ func (c *lockClient) Search(remote string, searchReq *lockSearchRequest) (*lockL } req.URL.RawQuery = q.Encode() - req = c.LogRequest(req, "lfs.locks.search") + req = c.Client.LogRequest(req, "lfs.locks.search") res, err := c.DoWithAuth(remote, req) if err != nil { return nil, res, err @@ -199,7 +200,7 @@ func (c *lockClient) Search(remote string, searchReq *lockSearchRequest) (*lockL locks := &lockList{} if res.StatusCode == http.StatusOK { - err = lfsapi.DecodeJSON(res, locks) + err = lfshttp.DecodeJSON(res, locks) } return locks, res, err @@ -251,7 +252,7 @@ func (c *lockClient) SearchVerifiable(remote string, vreq *lockVerifiableRequest return nil, nil, err } - req = c.LogRequest(req, "lfs.locks.verify") + req = c.Client.LogRequest(req, "lfs.locks.verify") res, err := c.DoWithAuth(remote, req) if err != nil { return nil, res, err @@ -259,7 +260,7 @@ func (c *lockClient) SearchVerifiable(remote string, vreq *lockVerifiableRequest locks := &lockVerifiableList{} if res.StatusCode == http.StatusOK { - err = lfsapi.DecodeJSON(res, locks) + err = lfshttp.DecodeJSON(res, locks) } return locks, res, err diff --git a/locking/api_test.go b/locking/api_test.go index 3a2fbd24..8ae32f58 100644 --- a/locking/api_test.go +++ b/locking/api_test.go @@ -12,6 +12,7 @@ import ( "github.com/git-lfs/git-lfs/git" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xeipuuv/gojsonschema" @@ -28,8 +29,8 @@ func TestAPILock(t *testing.T) { } assert.Equal(t, "POST", r.Method) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Accept")) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Content-Type")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Accept")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Content-Type")) assert.Equal(t, "53", r.Header.Get("Content-Length")) reqLoader, body := gojsonschema.NewReaderLoader(r.Body) @@ -54,7 +55,7 @@ func TestAPILock(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) @@ -78,8 +79,8 @@ func TestAPIUnlock(t *testing.T) { } assert.Equal(t, "POST", r.Method) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Accept")) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Content-Type")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Accept")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Content-Type")) reqLoader, body := gojsonschema.NewReaderLoader(r.Body) unlockReq := &unlockRequest{} @@ -102,7 +103,7 @@ func TestAPIUnlock(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) @@ -129,7 +130,7 @@ func TestAPISearch(t *testing.T) { } assert.Equal(t, "GET", r.Method) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Accept")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Accept")) assert.Equal(t, "", r.Header.Get("Content-Type")) q := r.URL.Query() @@ -150,7 +151,7 @@ func TestAPISearch(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) @@ -180,8 +181,8 @@ func TestAPIVerifiableLocks(t *testing.T) { } assert.Equal(t, "POST", r.Method) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Accept")) - assert.Equal(t, lfsapi.MediaType, r.Header.Get("Content-Type")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Accept")) + assert.Equal(t, lfshttp.MediaType, r.Header.Get("Content-Type")) body := lockVerifiableRequest{} if assert.Nil(t, json.NewDecoder(r.Body).Decode(&body)) { @@ -205,7 +206,7 @@ func TestAPIVerifiableLocks(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) diff --git a/locking/locks_test.go b/locking/locks_test.go index 6cf55ad6..128afffe 100644 --- a/locking/locks_test.go +++ b/locking/locks_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -48,7 +49,7 @@ func TestRefreshCache(t *testing.T) { srv.Close() }() - lfsclient, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + lfsclient, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", "user.name": "Fred", "user.email": "fred@bloggs.com", @@ -120,7 +121,7 @@ func TestGetVerifiableLocks(t *testing.T) { defer srv.Close() - lfsclient, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + lfsclient, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", "user.name": "Fred", "user.email": "fred@bloggs.com", diff --git a/tq/api.go b/tq/api.go index cde86052..57fc688a 100644 --- a/tq/api.go +++ b/tq/api.go @@ -6,6 +6,7 @@ import ( "github.com/git-lfs/git-lfs/errors" "github.com/git-lfs/git-lfs/git" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/rubyist/tracerx" ) @@ -28,7 +29,7 @@ type batchRequest struct { type BatchResponse struct { Objects []*Transfer `json:"objects"` TransferAdapterName string `json:"transfer"` - endpoint lfsapi.Endpoint + endpoint lfshttp.Endpoint } func Batch(m *Manifest, dir Direction, remote string, remoteRef *git.Ref, objects []*Transfer) (*BatchResponse, error) { @@ -64,19 +65,19 @@ func (c *tqClient) Batch(remote string, bReq *batchRequest) (*BatchResponse, err tracerx.Printf("api: batch %d files", len(bReq.Objects)) - req = c.LogRequest(req, "lfs.batch") - res, err := c.DoWithAuth(remote, lfsapi.WithRetries(req, c.MaxRetries)) + req = c.Client.LogRequest(req, "lfs.batch") + res, err := c.DoWithAuth(remote, lfshttp.WithRetries(req, c.MaxRetries)) if err != nil { tracerx.Printf("api error: %s", err) return nil, errors.Wrap(err, "batch response") } - if err := lfsapi.DecodeJSON(res, bRes); err != nil { + if err := lfshttp.DecodeJSON(res, bRes); err != nil { return bRes, errors.Wrap(err, "batch response") } if res.StatusCode != 200 { - return nil, lfsapi.NewStatusCodeError(res) + return nil, lfshttp.NewStatusCodeError(res) } for _, obj := range bRes.Objects { diff --git a/tq/api_test.go b/tq/api_test.go index 6e20bd5b..62bd1f00 100644 --- a/tq/api_test.go +++ b/tq/api_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xeipuuv/gojsonschema" @@ -54,7 +55,7 @@ func TestAPIBatch(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) @@ -110,7 +111,7 @@ func TestAPIBatchOnlyBasic(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": srv.URL + "/api", })) require.Nil(t, err) diff --git a/tq/custom_test.go b/tq/custom_test.go index 00f88279..b94c41ff 100644 --- a/tq/custom_test.go +++ b/tq/custom_test.go @@ -4,13 +4,14 @@ import ( "testing" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCustomTransferBasicConfig(t *testing.T) { path := "/path/to/binary" - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.customtransfer.testsimple.path": path, })) require.Nil(t, err) @@ -36,7 +37,7 @@ func TestCustomTransferBasicConfig(t *testing.T) { func TestCustomTransferDownloadConfig(t *testing.T) { path := "/path/to/binary" args := "-c 1 --whatever" - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.customtransfer.testdownload.path": path, "lfs.customtransfer.testdownload.args": args, "lfs.customtransfer.testdownload.concurrent": "false", @@ -62,7 +63,7 @@ func TestCustomTransferDownloadConfig(t *testing.T) { func TestCustomTransferUploadConfig(t *testing.T) { path := "/path/to/binary" args := "-c 1 --whatever" - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.customtransfer.testupload.path": path, "lfs.customtransfer.testupload.args": args, "lfs.customtransfer.testupload.concurrent": "false", @@ -88,7 +89,7 @@ func TestCustomTransferUploadConfig(t *testing.T) { func TestCustomTransferBothConfig(t *testing.T) { path := "/path/to/binary" args := "-c 1 --whatever --yeah" - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.customtransfer.testboth.path": path, "lfs.customtransfer.testboth.args": args, "lfs.customtransfer.testboth.concurrent": "yes", diff --git a/tq/manifest_test.go b/tq/manifest_test.go index 85056e8e..680feb29 100644 --- a/tq/manifest_test.go +++ b/tq/manifest_test.go @@ -4,12 +4,13 @@ import ( "testing" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestManifestIsConfigurable(t *testing.T) { - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.transfer.maxretries": "3", })) require.Nil(t, err) @@ -19,7 +20,7 @@ func TestManifestIsConfigurable(t *testing.T) { } func TestManifestChecksNTLM(t *testing.T) { - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.url": "http://foo", "lfs.http://foo.access": "ntlm", "lfs.concurrenttransfers": "3", @@ -31,7 +32,7 @@ func TestManifestChecksNTLM(t *testing.T) { } func TestManifestClampsValidValues(t *testing.T) { - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.transfer.maxretries": "-1", })) require.Nil(t, err) @@ -41,7 +42,7 @@ func TestManifestClampsValidValues(t *testing.T) { } func TestManifestIgnoresNonInts(t *testing.T) { - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.transfer.maxretries": "not_an_int", })) require.Nil(t, err) diff --git a/tq/transfer_queue.go b/tq/transfer_queue.go index eba9f677..e53f1c50 100644 --- a/tq/transfer_queue.go +++ b/tq/transfer_queue.go @@ -9,6 +9,7 @@ import ( "github.com/git-lfs/git-lfs/errors" "github.com/git-lfs/git-lfs/git" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/git-lfs/git-lfs/tools" "github.com/rubyist/tracerx" ) @@ -537,7 +538,7 @@ func (q *TransferQueue) makeBatch() batch { return make(batch, 0, q.batchSize) } // closed. // // addToAdapter returns immediately, and does not block. -func (q *TransferQueue) addToAdapter(e lfsapi.Endpoint, pending []*Transfer) <-chan *objectTuple { +func (q *TransferQueue) addToAdapter(e lfshttp.Endpoint, pending []*Transfer) <-chan *objectTuple { retries := make(chan *objectTuple, len(pending)) if err := q.ensureAdapterBegun(e); err != nil { @@ -729,7 +730,7 @@ func (q *TransferQueue) Skip(size int64) { q.meter.Skip(size) } -func (q *TransferQueue) ensureAdapterBegun(e lfsapi.Endpoint) error { +func (q *TransferQueue) ensureAdapterBegun(e lfshttp.Endpoint) error { q.adapterInitMutex.Lock() defer q.adapterInitMutex.Unlock() @@ -760,7 +761,7 @@ func (q *TransferQueue) ensureAdapterBegun(e lfsapi.Endpoint) error { return nil } -func (q *TransferQueue) toAdapterCfg(e lfsapi.Endpoint) AdapterConfig { +func (q *TransferQueue) toAdapterCfg(e lfshttp.Endpoint) AdapterConfig { apiClient := q.manifest.APIClient() concurrency := q.manifest.ConcurrentTransfers() if apiClient.Endpoints.AccessFor(e.Url) == lfsapi.NTLMAccess { diff --git a/tq/transfer_test.go b/tq/transfer_test.go index 48ba763b..6e488397 100644 --- a/tq/transfer_test.go +++ b/tq/transfer_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -118,7 +119,7 @@ func testAdapterRegAndOverride(t *testing.T) { } func testAdapterRegButBasicOnly(t *testing.T) { - cli, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + cli, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.basictransfersonly": "yes", })) require.Nil(t, err) diff --git a/tq/verify_test.go b/tq/verify_test.go index 41ef8fb0..0475c668 100644 --- a/tq/verify_test.go +++ b/tq/verify_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/git-lfs/git-lfs/lfsapi" + "github.com/git-lfs/git-lfs/lfshttp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -44,7 +45,7 @@ func TestVerifySuccess(t *testing.T) { })) defer srv.Close() - c, err := lfsapi.NewClient(lfsapi.NewContext(nil, nil, map[string]string{ + c, err := lfsapi.NewClient(lfshttp.NewContext(nil, nil, map[string]string{ "lfs.transfer.maxverifies": "1", })) require.Nil(t, err) From e6c4c6386ba8cce6fb3cc31f105d0eefa43b444e Mon Sep 17 00:00:00 2001 From: Preben Ingvaldsen Date: Tue, 11 Sep 2018 14:14:37 -0700 Subject: [PATCH 2/2] lfsapi: implement auth redirects Refactor the client redirect code to allow lfsapi to re-authenticate redirected requests --- lfsapi/auth.go | 14 +++++- lfsapi/auth_test.go | 114 +++++++++++++++++++++--------------------- lfshttp/auth.go | 9 ---- lfshttp/certs_test.go | 16 +++--- lfshttp/client.go | 34 +++++++------ 5 files changed, 98 insertions(+), 89 deletions(-) delete mode 100644 lfshttp/auth.go diff --git a/lfsapi/auth.go b/lfsapi/auth.go index 10f2a47d..ec4c0d29 100644 --- a/lfsapi/auth.go +++ b/lfsapi/auth.go @@ -72,7 +72,19 @@ func (c *Client) doWithCreds(req *http.Request, credHelper CredentialHelper, cre if access == NTLMAccess { return c.doWithNTLM(req, credHelper, creds, credsURL) } - return c.do(req, "", via) + + req.Header.Set("User-Agent", lfshttp.UserAgent) + + redirectedReq, res, err := c.client.DoWithRedirect(c.client.HttpClient(req.Host), req, "", via) + if err != nil || res != nil { + return res, err + } + + if redirectedReq == nil { + return res, errors.New("failed to redirect request") + } + + return c.doWithAuth("", redirectedReq, via) } // getCreds fills the authorization header for the given request if possible, diff --git a/lfsapi/auth_test.go b/lfsapi/auth_test.go index 39253d13..68b485dd 100644 --- a/lfsapi/auth_test.go +++ b/lfsapi/auth_test.go @@ -619,76 +619,76 @@ func (f *fakeCredentialFiller) Reject(creds Creds) error { return errors.New("Not implemented") } -// func TestClientRedirectReauthenticate(t *testing.T) { -// var srv1, srv2 *httptest.Server -// var called1, called2 uint32 -// var creds1, creds2 Creds +func TestClientRedirectReauthenticate(t *testing.T) { + var srv1, srv2 *httptest.Server + var called1, called2 uint32 + var creds1, creds2 Creds -// srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// atomic.AddUint32(&called1, 1) + srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&called1, 1) -// if hdr := r.Header.Get("Authorization"); len(hdr) > 0 { -// parts := strings.SplitN(hdr, " ", 2) -// typ, b64 := parts[0], parts[1] + if hdr := r.Header.Get("Authorization"); len(hdr) > 0 { + parts := strings.SplitN(hdr, " ", 2) + typ, b64 := parts[0], parts[1] -// auth, err := base64.URLEncoding.DecodeString(b64) -// assert.Nil(t, err) -// assert.Equal(t, "Basic", typ) -// assert.Equal(t, "user1:pass1", string(auth)) + auth, err := base64.URLEncoding.DecodeString(b64) + assert.Nil(t, err) + assert.Equal(t, "Basic", typ) + assert.Equal(t, "user1:pass1", string(auth)) -// http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently) -// return -// } -// w.WriteHeader(http.StatusUnauthorized) -// })) + http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently) + return + } + w.WriteHeader(http.StatusUnauthorized) + })) -// srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { -// atomic.AddUint32(&called2, 1) + srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&called2, 1) -// parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) -// typ, b64 := parts[0], parts[1] + parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2) + typ, b64 := parts[0], parts[1] -// auth, err := base64.URLEncoding.DecodeString(b64) -// assert.Nil(t, err) -// assert.Equal(t, "Basic", typ) -// assert.Equal(t, "user2:pass2", string(auth)) -// })) + auth, err := base64.URLEncoding.DecodeString(b64) + assert.Nil(t, err) + assert.Equal(t, "Basic", typ) + assert.Equal(t, "user2:pass2", string(auth)) + })) -// // Change the URL of srv2 to make it appears as if it is a different -// // host. -// srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1) + // Change the URL of srv2 to make it appears as if it is a different + // host. + srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1) -// creds1 = Creds(map[string]string{ -// "protocol": "http", -// "host": strings.TrimPrefix(srv1.URL, "http://"), + creds1 = Creds(map[string]string{ + "protocol": "http", + "host": strings.TrimPrefix(srv1.URL, "http://"), -// "username": "user1", -// "password": "pass1", -// }) -// creds2 = Creds(map[string]string{ -// "protocol": "http", -// "host": strings.TrimPrefix(srv2.URL, "http://"), + "username": "user1", + "password": "pass1", + }) + creds2 = Creds(map[string]string{ + "protocol": "http", + "host": strings.TrimPrefix(srv2.URL, "http://"), -// "username": "user2", -// "password": "pass2", -// }) + "username": "user2", + "password": "pass2", + }) -// defer srv1.Close() -// defer srv2.Close() + defer srv1.Close() + defer srv2.Close() -// c, err := NewClient(lfshttp.NewContext(nil, nil, nil)) -// creds := newCredentialCacher() -// creds.Approve(creds1) -// creds.Approve(creds2) -// c.Credentials = creds + c, err := NewClient(lfshttp.NewContext(nil, nil, nil)) + creds := newCredentialCacher() + creds.Approve(creds1) + creds.Approve(creds2) + c.Credentials = creds -// req, err := http.NewRequest("GET", srv1.URL, nil) -// require.Nil(t, err) + req, err := http.NewRequest("GET", srv1.URL, nil) + require.Nil(t, err) -// _, err = c.DoWithAuth("", req) -// assert.Nil(t, err) + _, err = c.DoWithAuth("", req) + assert.Nil(t, err) -// // called1 is 2 since LFS tries an unauthenticated request first -// assert.EqualValues(t, 2, called1) -// assert.EqualValues(t, 1, called2) -// } + // called1 is 2 since LFS tries an unauthenticated request first + assert.EqualValues(t, 2, called1) + assert.EqualValues(t, 1, called2) +} diff --git a/lfshttp/auth.go b/lfshttp/auth.go deleted file mode 100644 index c84f7843..00000000 --- a/lfshttp/auth.go +++ /dev/null @@ -1,9 +0,0 @@ -package lfshttp - -import ( - "net/http" -) - -func (c *Client) doWithAuth(remote string, req *http.Request, via []*http.Request) (*http.Response, error) { - return c.do(req, remote, via) -} diff --git a/lfshttp/certs_test.go b/lfshttp/certs_test.go index eb3f0374..e32e9aee 100644 --- a/lfshttp/certs_test.go +++ b/lfshttp/certs_test.go @@ -158,7 +158,7 @@ func TestCertFromSSLCAPathEnv(t *testing.T) { func TestCertVerifyDisabledGlobalEnv(t *testing.T) { empty, _ := NewClient(nil) - httpClient := empty.httpClient("anyhost.com") + httpClient := empty.HttpClient("anyhost.com") tr, ok := httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.False(t, tr.TLSClientConfig.InsecureSkipVerify) @@ -170,7 +170,7 @@ func TestCertVerifyDisabledGlobalEnv(t *testing.T) { assert.Nil(t, err) - httpClient = c.httpClient("anyhost.com") + httpClient = c.HttpClient("anyhost.com") tr, ok = httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.True(t, tr.TLSClientConfig.InsecureSkipVerify) @@ -179,7 +179,7 @@ func TestCertVerifyDisabledGlobalEnv(t *testing.T) { func TestCertVerifyDisabledGlobalConfig(t *testing.T) { def, _ := NewClient(nil) - httpClient := def.httpClient("anyhost.com") + httpClient := def.HttpClient("anyhost.com") tr, ok := httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.False(t, tr.TLSClientConfig.InsecureSkipVerify) @@ -190,7 +190,7 @@ func TestCertVerifyDisabledGlobalConfig(t *testing.T) { })) assert.Nil(t, err) - httpClient = c.httpClient("anyhost.com") + httpClient = c.HttpClient("anyhost.com") tr, ok = httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.True(t, tr.TLSClientConfig.InsecureSkipVerify) @@ -199,13 +199,13 @@ func TestCertVerifyDisabledGlobalConfig(t *testing.T) { func TestCertVerifyDisabledHostConfig(t *testing.T) { def, _ := NewClient(nil) - httpClient := def.httpClient("specifichost.com") + httpClient := def.HttpClient("specifichost.com") tr, ok := httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.False(t, tr.TLSClientConfig.InsecureSkipVerify) } - httpClient = def.httpClient("otherhost.com") + httpClient = def.HttpClient("otherhost.com") tr, ok = httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.False(t, tr.TLSClientConfig.InsecureSkipVerify) @@ -216,13 +216,13 @@ func TestCertVerifyDisabledHostConfig(t *testing.T) { })) assert.Nil(t, err) - httpClient = c.httpClient("specifichost.com") + httpClient = c.HttpClient("specifichost.com") tr, ok = httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.True(t, tr.TLSClientConfig.InsecureSkipVerify) } - httpClient = c.httpClient("otherhost.com") + httpClient = c.HttpClient("otherhost.com") tr, ok = httpClient.Transport.(*http.Transport) if assert.True(t, ok) { assert.False(t, tr.TLSClientConfig.InsecureSkipVerify) diff --git a/lfshttp/client.go b/lfshttp/client.go index 42ce6c84..b58d0c24 100644 --- a/lfshttp/client.go +++ b/lfshttp/client.go @@ -168,7 +168,7 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) { func (c *Client) do(req *http.Request, remote string, via []*http.Request) (*http.Response, error) { req.Header.Set("User-Agent", UserAgent) - res, err := c.doWithRedirects(c.httpClient(req.Host), req, remote, via) + res, err := c.doWithRedirects(c.HttpClient(req.Host), req, remote, via) if err != nil { return res, err } @@ -248,10 +248,10 @@ func (c *Client) extraHeaders(u *url.URL) map[string][]string { return m } -func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Response, error) { +func (c *Client) DoWithRedirect(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Request, *http.Response, error) { tracedReq, err := c.traceRequest(req) if err != nil { - return nil, err + return nil, nil, err } var retries int @@ -279,11 +279,11 @@ func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote str if err != nil { c.traceResponse(req, tracedReq, nil) - return nil, err + return nil, nil, err } if res == nil { - return nil, nil + return nil, nil, nil } c.traceResponse(req, tracedReq, res) @@ -298,7 +298,7 @@ func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote str // how to handle below. If the status code contained in // the HTTP response was none of them, return the (res, // err) tuple as-is, otherwise handle the redirect. - return res, err + return nil, res, c.handleResponse(res) } redirectTo := res.Header.Get("Location") @@ -310,25 +310,31 @@ func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote str via = append(via, req) if len(via) >= 3 { - return res, errors.New("too many redirects") + return nil, res, errors.New("too many redirects") } redirectedReq, err := newRequestForRetry(req, redirectTo) if err != nil { + return nil, res, err + } + + return redirectedReq, nil, nil +} + +func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, remote string, via []*http.Request) (*http.Response, error) { + redirectedReq, res, err := c.DoWithRedirect(cli, req, remote, via) + if err != nil || res != nil { return res, err } - if len(req.Header.Get("Authorization")) > 0 { - // If the original request was authenticated (noted by the - // presence of the Authorization header), then recur through - // doWithAuth, retaining the requests via but only after - // authenticating the redirected request. - return c.doWithAuth(remote, redirectedReq, via) + if redirectedReq == nil { + return nil, errors.New("failed to redirect request") } + return c.doWithRedirects(cli, redirectedReq, remote, via) } -func (c *Client) httpClient(host string) *http.Client { +func (c *Client) HttpClient(host string) *http.Client { c.clientMu.Lock() defer c.clientMu.Unlock()