Merge pull request #1791 from git-lfs/api/redirections

lfsapi: teach `(*Client) Do()` how to handle 307 redirections
This commit is contained in:
risk danger olson 2016-12-21 10:32:13 -07:00 committed by GitHub
commit 4f9af9e708
5 changed files with 248 additions and 89 deletions

@ -129,7 +129,7 @@ func DoHttpRequestWithRedirects(cfg *config.Configuration, req *http.Request, vi
redirectedReq, err := NewHttpRequest(req.Method, redirectTo, nil)
if err != nil {
return res, errors.Wrapf(err, err.Error())
return res, err
}
via = append(via, req)
@ -152,7 +152,7 @@ func DoHttpRequestWithRedirects(cfg *config.Configuration, req *http.Request, vi
redirectedReq.ContentLength = req.ContentLength
if err = CheckRedirect(redirectedReq, via); err != nil {
return res, errors.Wrapf(err, err.Error())
return res, err
}
return DoHttpRequestWithRedirects(cfg, redirectedReq, via, useCreds)

@ -50,13 +50,14 @@ func TestDoWithAuthApprove(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
atomic.AddUint32(&called, 1)
w.Header().Set("Lfs-Authenticate", "Basic")
assert.Equal(t, "POST", req.Method)
body := &authRequest{}
err := json.NewDecoder(req.Body).Decode(body)
assert.Nil(t, err)
assert.Equal(t, "Approve", body.Test)
w.Header().Set("Lfs-Authenticate", "Basic")
actual := req.Header.Get("Authorization")
if len(actual) == 0 {
w.WriteHeader(http.StatusUnauthorized)
@ -80,7 +81,7 @@ func TestDoWithAuthApprove(t *testing.T) {
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL))
req, err := http.NewRequest("GET", srv.URL, nil)
req, err := http.NewRequest("POST", srv.URL, nil)
require.Nil(t, err)
err = MarshalToRequest(req, &authRequest{Test: "Approve"})
@ -106,7 +107,7 @@ func TestDoWithAuthReject(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
atomic.AddUint32(&called, 1)
w.Header().Set("Lfs-Authenticate", "Basic")
assert.Equal(t, "POST", req.Method)
body := &authRequest{}
err := json.NewDecoder(req.Body).Decode(body)
@ -118,6 +119,7 @@ func TestDoWithAuthReject(t *testing.T) {
base64.StdEncoding.EncodeToString([]byte("user:pass")),
)
w.Header().Set("Lfs-Authenticate", "Basic")
if actual != expected {
// Write http.StatuUnauthorized to force the credential
// helper to reject the credentials
@ -148,7 +150,7 @@ func TestDoWithAuthReject(t *testing.T) {
})),
}
req, err := http.NewRequest("GET", srv.URL, nil)
req, err := http.NewRequest("POST", srv.URL, nil)
require.Nil(t, err)
err = MarshalToRequest(req, &authRequest{Test: "Reject"})

150
lfsapi/client.go Normal file

@ -0,0 +1,150 @@
package lfsapi
import (
"crypto/tls"
"io"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/git-lfs/git-lfs/errors"
"github.com/rubyist/tracerx"
)
func (c *Client) Do(req *http.Request) (*http.Response, error) {
res, err := c.doWithRedirects(c.httpClient(req.Host), req, nil)
if err != nil {
return res, err
}
return res, c.handleResponse(res)
}
func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, via []*http.Request) (*http.Response, error) {
if seeker, ok := req.Body.(io.Seeker); ok {
seeker.Seek(0, io.SeekStart)
}
res, err := cli.Do(req)
if err != nil {
return res, err
}
if res.StatusCode != 307 {
return res, err
}
redirectTo := res.Header.Get("Location")
locurl, err := url.Parse(redirectTo)
if err == nil && !locurl.IsAbs() {
locurl = req.URL.ResolveReference(locurl)
redirectTo = locurl.String()
}
via = append(via, req)
if len(via) >= 3 {
return res, errors.New("too many redirects")
}
redirectedReq, err := newRequestForRetry(req, redirectTo)
if err != nil {
return res, err
}
return c.doWithRedirects(cli, redirectedReq, via)
}
func (c *Client) httpClient(host string) *http.Client {
c.clientMu.Lock()
defer c.clientMu.Unlock()
if c.gitEnv == nil {
c.gitEnv = make(testEnv)
}
if c.osEnv == nil {
c.osEnv = make(testEnv)
}
if c.hostClients == nil {
c.hostClients = make(map[string]*http.Client)
}
if client, ok := c.hostClients[host]; ok {
return client
}
concurrentTransfers := c.ConcurrentTransfers
if concurrentTransfers < 1 {
concurrentTransfers = 3
}
dialtime := c.DialTimeout
if dialtime < 1 {
dialtime = 30
}
keepalivetime := c.KeepaliveTimeout
if keepalivetime < 1 {
keepalivetime = 1800
}
tlstime := c.TLSTimeout
if tlstime < 1 {
tlstime = 30
}
tr := &http.Transport{
Proxy: ProxyFromClient(c),
Dial: (&net.Dialer{
Timeout: time.Duration(dialtime) * time.Second,
KeepAlive: time.Duration(keepalivetime) * time.Second,
}).Dial,
TLSHandshakeTimeout: time.Duration(tlstime) * time.Second,
MaxIdleConnsPerHost: concurrentTransfers,
}
tr.TLSClientConfig = &tls.Config{}
if isCertVerificationDisabledForHost(c, host) {
tr.TLSClientConfig.InsecureSkipVerify = true
} else {
tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host)
}
httpClient := &http.Client{
Transport: tr,
CheckRedirect: func(*http.Request, []*http.Request) error {
return http.ErrUseLastResponse
},
}
c.hostClients[host] = httpClient
return httpClient
}
func newRequestForRetry(req *http.Request, location string) (*http.Request, error) {
newReq, err := http.NewRequest(req.Method, location, nil)
if err != nil {
return nil, err
}
for key := range req.Header {
if key == "Authorization" {
if req.URL.Scheme != newReq.URL.Scheme || req.URL.Host != newReq.URL.Host {
continue
}
}
newReq.Header.Set(key, req.Header.Get(key))
}
oldestURL := strings.SplitN(req.URL.String(), "?", 2)[0]
newURL := strings.SplitN(newReq.URL.String(), "?", 2)[0]
tracerx.Printf("api: redirect %s %s to %s", req.Method, oldestURL, newURL)
newReq.Body = req.Body
newReq.ContentLength = req.ContentLength
return newReq, nil
}

@ -1,12 +1,102 @@
package lfsapi
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type redirectTest struct {
Test string
}
func TestClientRedirect(t *testing.T) {
var called1 uint32
var called2 uint32
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint32(&called2, 1)
t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
assert.Equal(t, "POST", r.Method)
switch r.URL.Path {
case "/ok":
assert.Equal(t, "", r.Header.Get("Authorization"))
assert.Equal(t, "1", r.Header.Get("A"))
body := &redirectTest{}
err := json.NewDecoder(r.Body).Decode(body)
assert.Nil(t, err)
assert.Equal(t, "External", body.Test)
w.WriteHeader(200)
default:
w.WriteHeader(404)
}
}))
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddUint32(&called1, 1)
t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
assert.Equal(t, "POST", r.Method)
switch r.URL.Path {
case "/local":
w.Header().Set("Location", "/ok")
w.WriteHeader(307)
case "/external":
w.Header().Set("Location", srv2.URL+"/ok")
w.WriteHeader(307)
case "/ok":
assert.Equal(t, "auth", r.Header.Get("Authorization"))
assert.Equal(t, "1", r.Header.Get("A"))
body := &redirectTest{}
err := json.NewDecoder(r.Body).Decode(body)
assert.Nil(t, err)
assert.Equal(t, "Local", body.Test)
w.WriteHeader(200)
default:
w.WriteHeader(404)
}
}))
defer srv1.Close()
defer srv2.Close()
c := &Client{}
// local redirect
req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "auth")
req.Header.Set("A", "1")
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
res, err := c.Do(req)
require.Nil(t, err)
assert.Equal(t, 200, res.StatusCode)
assert.EqualValues(t, 2, called1)
assert.EqualValues(t, 0, called2)
// external redirect
req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "auth")
req.Header.Set("A", "1")
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
res, err = c.Do(req)
require.Nil(t, err)
assert.Equal(t, 200, res.StatusCode)
assert.EqualValues(t, 3, called1)
assert.EqualValues(t, 1, called2)
}
func TestNewClient(t *testing.T) {
c, err := NewClient(testEnv(map[string]string{}), testEnv(map[string]string{
"lfs.dialtimeout": "151",

@ -1,16 +1,12 @@
package lfsapi
import (
"crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"regexp"
"strconv"
"strings"
"sync"
"time"
"github.com/git-lfs/git-lfs/errors"
)
@ -79,85 +75,6 @@ func NewClient(osEnv env, gitEnv env) (*Client, error) {
return c, nil
}
func (c *Client) Do(req *http.Request) (*http.Response, error) {
if seeker, ok := req.Body.(io.Seeker); ok {
seeker.Seek(0, io.SeekStart)
}
res, err := c.httpClient(req.Host).Do(req)
if err != nil {
return res, err
}
return res, c.handleResponse(res)
}
func (c *Client) httpClient(host string) *http.Client {
c.clientMu.Lock()
defer c.clientMu.Unlock()
if c.gitEnv == nil {
c.gitEnv = make(testEnv)
}
if c.osEnv == nil {
c.osEnv = make(testEnv)
}
if c.hostClients == nil {
c.hostClients = make(map[string]*http.Client)
}
if client, ok := c.hostClients[host]; ok {
return client
}
concurrentTransfers := c.ConcurrentTransfers
if concurrentTransfers < 1 {
concurrentTransfers = 3
}
dialtime := c.DialTimeout
if dialtime < 1 {
dialtime = 30
}
keepalivetime := c.KeepaliveTimeout
if keepalivetime < 1 {
keepalivetime = 1800
}
tlstime := c.TLSTimeout
if tlstime < 1 {
tlstime = 30
}
tr := &http.Transport{
Proxy: ProxyFromClient(c),
Dial: (&net.Dialer{
Timeout: time.Duration(dialtime) * time.Second,
KeepAlive: time.Duration(keepalivetime) * time.Second,
}).Dial,
TLSHandshakeTimeout: time.Duration(tlstime) * time.Second,
MaxIdleConnsPerHost: concurrentTransfers,
}
tr.TLSClientConfig = &tls.Config{}
if isCertVerificationDisabledForHost(c, host) {
tr.TLSClientConfig.InsecureSkipVerify = true
} else {
tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host)
}
httpClient := &http.Client{
Transport: tr,
}
c.hostClients[host] = httpClient
return httpClient
}
func decodeResponse(res *http.Response, obj interface{}) error {
ctype := res.Header.Get("Content-Type")
if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) {