Merge pull request #1791 from git-lfs/api/redirections
lfsapi: teach `(*Client) Do()` how to handle 307 redirections
This commit is contained in:
commit
4f9af9e708
@ -129,7 +129,7 @@ func DoHttpRequestWithRedirects(cfg *config.Configuration, req *http.Request, vi
|
||||
|
||||
redirectedReq, err := NewHttpRequest(req.Method, redirectTo, nil)
|
||||
if err != nil {
|
||||
return res, errors.Wrapf(err, err.Error())
|
||||
return res, err
|
||||
}
|
||||
|
||||
via = append(via, req)
|
||||
@ -152,7 +152,7 @@ func DoHttpRequestWithRedirects(cfg *config.Configuration, req *http.Request, vi
|
||||
redirectedReq.ContentLength = req.ContentLength
|
||||
|
||||
if err = CheckRedirect(redirectedReq, via); err != nil {
|
||||
return res, errors.Wrapf(err, err.Error())
|
||||
return res, err
|
||||
}
|
||||
|
||||
return DoHttpRequestWithRedirects(cfg, redirectedReq, via, useCreds)
|
||||
|
@ -50,13 +50,14 @@ func TestDoWithAuthApprove(t *testing.T) {
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
atomic.AddUint32(&called, 1)
|
||||
w.Header().Set("Lfs-Authenticate", "Basic")
|
||||
assert.Equal(t, "POST", req.Method)
|
||||
|
||||
body := &authRequest{}
|
||||
err := json.NewDecoder(req.Body).Decode(body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "Approve", body.Test)
|
||||
|
||||
w.Header().Set("Lfs-Authenticate", "Basic")
|
||||
actual := req.Header.Get("Authorization")
|
||||
if len(actual) == 0 {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
@ -80,7 +81,7 @@ func TestDoWithAuthApprove(t *testing.T) {
|
||||
|
||||
assert.Equal(t, NoneAccess, c.Endpoints.AccessFor(srv.URL))
|
||||
|
||||
req, err := http.NewRequest("GET", srv.URL, nil)
|
||||
req, err := http.NewRequest("POST", srv.URL, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = MarshalToRequest(req, &authRequest{Test: "Approve"})
|
||||
@ -106,7 +107,7 @@ func TestDoWithAuthReject(t *testing.T) {
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
||||
atomic.AddUint32(&called, 1)
|
||||
w.Header().Set("Lfs-Authenticate", "Basic")
|
||||
assert.Equal(t, "POST", req.Method)
|
||||
|
||||
body := &authRequest{}
|
||||
err := json.NewDecoder(req.Body).Decode(body)
|
||||
@ -118,6 +119,7 @@ func TestDoWithAuthReject(t *testing.T) {
|
||||
base64.StdEncoding.EncodeToString([]byte("user:pass")),
|
||||
)
|
||||
|
||||
w.Header().Set("Lfs-Authenticate", "Basic")
|
||||
if actual != expected {
|
||||
// Write http.StatuUnauthorized to force the credential
|
||||
// helper to reject the credentials
|
||||
@ -148,7 +150,7 @@ func TestDoWithAuthReject(t *testing.T) {
|
||||
})),
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", srv.URL, nil)
|
||||
req, err := http.NewRequest("POST", srv.URL, nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
err = MarshalToRequest(req, &authRequest{Test: "Reject"})
|
||||
|
150
lfsapi/client.go
Normal file
150
lfsapi/client.go
Normal file
@ -0,0 +1,150 @@
|
||||
package lfsapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/git-lfs/git-lfs/errors"
|
||||
"github.com/rubyist/tracerx"
|
||||
)
|
||||
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
res, err := c.doWithRedirects(c.httpClient(req.Host), req, nil)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, c.handleResponse(res)
|
||||
}
|
||||
|
||||
func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, via []*http.Request) (*http.Response, error) {
|
||||
if seeker, ok := req.Body.(io.Seeker); ok {
|
||||
seeker.Seek(0, io.SeekStart)
|
||||
}
|
||||
|
||||
res, err := cli.Do(req)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
if res.StatusCode != 307 {
|
||||
return res, err
|
||||
}
|
||||
|
||||
redirectTo := res.Header.Get("Location")
|
||||
locurl, err := url.Parse(redirectTo)
|
||||
if err == nil && !locurl.IsAbs() {
|
||||
locurl = req.URL.ResolveReference(locurl)
|
||||
redirectTo = locurl.String()
|
||||
}
|
||||
|
||||
via = append(via, req)
|
||||
if len(via) >= 3 {
|
||||
return res, errors.New("too many redirects")
|
||||
}
|
||||
|
||||
redirectedReq, err := newRequestForRetry(req, redirectTo)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return c.doWithRedirects(cli, redirectedReq, via)
|
||||
}
|
||||
|
||||
func (c *Client) httpClient(host string) *http.Client {
|
||||
c.clientMu.Lock()
|
||||
defer c.clientMu.Unlock()
|
||||
|
||||
if c.gitEnv == nil {
|
||||
c.gitEnv = make(testEnv)
|
||||
}
|
||||
|
||||
if c.osEnv == nil {
|
||||
c.osEnv = make(testEnv)
|
||||
}
|
||||
|
||||
if c.hostClients == nil {
|
||||
c.hostClients = make(map[string]*http.Client)
|
||||
}
|
||||
|
||||
if client, ok := c.hostClients[host]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
concurrentTransfers := c.ConcurrentTransfers
|
||||
if concurrentTransfers < 1 {
|
||||
concurrentTransfers = 3
|
||||
}
|
||||
|
||||
dialtime := c.DialTimeout
|
||||
if dialtime < 1 {
|
||||
dialtime = 30
|
||||
}
|
||||
|
||||
keepalivetime := c.KeepaliveTimeout
|
||||
if keepalivetime < 1 {
|
||||
keepalivetime = 1800
|
||||
}
|
||||
|
||||
tlstime := c.TLSTimeout
|
||||
if tlstime < 1 {
|
||||
tlstime = 30
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
Proxy: ProxyFromClient(c),
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: time.Duration(dialtime) * time.Second,
|
||||
KeepAlive: time.Duration(keepalivetime) * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: time.Duration(tlstime) * time.Second,
|
||||
MaxIdleConnsPerHost: concurrentTransfers,
|
||||
}
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{}
|
||||
if isCertVerificationDisabledForHost(c, host) {
|
||||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||
} else {
|
||||
tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: tr,
|
||||
CheckRedirect: func(*http.Request, []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
|
||||
c.hostClients[host] = httpClient
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
func newRequestForRetry(req *http.Request, location string) (*http.Request, error) {
|
||||
newReq, err := http.NewRequest(req.Method, location, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for key := range req.Header {
|
||||
if key == "Authorization" {
|
||||
if req.URL.Scheme != newReq.URL.Scheme || req.URL.Host != newReq.URL.Host {
|
||||
continue
|
||||
}
|
||||
}
|
||||
newReq.Header.Set(key, req.Header.Get(key))
|
||||
}
|
||||
|
||||
oldestURL := strings.SplitN(req.URL.String(), "?", 2)[0]
|
||||
newURL := strings.SplitN(newReq.URL.String(), "?", 2)[0]
|
||||
tracerx.Printf("api: redirect %s %s to %s", req.Method, oldestURL, newURL)
|
||||
|
||||
newReq.Body = req.Body
|
||||
newReq.ContentLength = req.ContentLength
|
||||
return newReq, nil
|
||||
}
|
@ -1,12 +1,102 @@
|
||||
package lfsapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type redirectTest struct {
|
||||
Test string
|
||||
}
|
||||
|
||||
func TestClientRedirect(t *testing.T) {
|
||||
var called1 uint32
|
||||
var called2 uint32
|
||||
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddUint32(&called2, 1)
|
||||
t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/ok":
|
||||
assert.Equal(t, "", r.Header.Get("Authorization"))
|
||||
assert.Equal(t, "1", r.Header.Get("A"))
|
||||
body := &redirectTest{}
|
||||
err := json.NewDecoder(r.Body).Decode(body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "External", body.Test)
|
||||
|
||||
w.WriteHeader(200)
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
|
||||
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
atomic.AddUint32(&called1, 1)
|
||||
t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
|
||||
assert.Equal(t, "POST", r.Method)
|
||||
|
||||
switch r.URL.Path {
|
||||
case "/local":
|
||||
w.Header().Set("Location", "/ok")
|
||||
w.WriteHeader(307)
|
||||
case "/external":
|
||||
w.Header().Set("Location", srv2.URL+"/ok")
|
||||
w.WriteHeader(307)
|
||||
case "/ok":
|
||||
assert.Equal(t, "auth", r.Header.Get("Authorization"))
|
||||
assert.Equal(t, "1", r.Header.Get("A"))
|
||||
body := &redirectTest{}
|
||||
err := json.NewDecoder(r.Body).Decode(body)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "Local", body.Test)
|
||||
|
||||
w.WriteHeader(200)
|
||||
default:
|
||||
w.WriteHeader(404)
|
||||
}
|
||||
}))
|
||||
defer srv1.Close()
|
||||
defer srv2.Close()
|
||||
|
||||
c := &Client{}
|
||||
|
||||
// local redirect
|
||||
req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("Authorization", "auth")
|
||||
req.Header.Set("A", "1")
|
||||
|
||||
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
|
||||
|
||||
res, err := c.Do(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 200, res.StatusCode)
|
||||
assert.EqualValues(t, 2, called1)
|
||||
assert.EqualValues(t, 0, called2)
|
||||
|
||||
// external redirect
|
||||
req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("Authorization", "auth")
|
||||
req.Header.Set("A", "1")
|
||||
|
||||
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
|
||||
|
||||
res, err = c.Do(req)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, 200, res.StatusCode)
|
||||
assert.EqualValues(t, 3, called1)
|
||||
assert.EqualValues(t, 1, called2)
|
||||
}
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
c, err := NewClient(testEnv(map[string]string{}), testEnv(map[string]string{
|
||||
"lfs.dialtimeout": "151",
|
||||
|
@ -1,16 +1,12 @@
|
||||
package lfsapi
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/git-lfs/git-lfs/errors"
|
||||
)
|
||||
@ -79,85 +75,6 @@ func NewClient(osEnv env, gitEnv env) (*Client, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
if seeker, ok := req.Body.(io.Seeker); ok {
|
||||
seeker.Seek(0, io.SeekStart)
|
||||
}
|
||||
|
||||
res, err := c.httpClient(req.Host).Do(req)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
|
||||
return res, c.handleResponse(res)
|
||||
}
|
||||
|
||||
func (c *Client) httpClient(host string) *http.Client {
|
||||
c.clientMu.Lock()
|
||||
defer c.clientMu.Unlock()
|
||||
|
||||
if c.gitEnv == nil {
|
||||
c.gitEnv = make(testEnv)
|
||||
}
|
||||
|
||||
if c.osEnv == nil {
|
||||
c.osEnv = make(testEnv)
|
||||
}
|
||||
|
||||
if c.hostClients == nil {
|
||||
c.hostClients = make(map[string]*http.Client)
|
||||
}
|
||||
|
||||
if client, ok := c.hostClients[host]; ok {
|
||||
return client
|
||||
}
|
||||
|
||||
concurrentTransfers := c.ConcurrentTransfers
|
||||
if concurrentTransfers < 1 {
|
||||
concurrentTransfers = 3
|
||||
}
|
||||
|
||||
dialtime := c.DialTimeout
|
||||
if dialtime < 1 {
|
||||
dialtime = 30
|
||||
}
|
||||
|
||||
keepalivetime := c.KeepaliveTimeout
|
||||
if keepalivetime < 1 {
|
||||
keepalivetime = 1800
|
||||
}
|
||||
|
||||
tlstime := c.TLSTimeout
|
||||
if tlstime < 1 {
|
||||
tlstime = 30
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
Proxy: ProxyFromClient(c),
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: time.Duration(dialtime) * time.Second,
|
||||
KeepAlive: time.Duration(keepalivetime) * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: time.Duration(tlstime) * time.Second,
|
||||
MaxIdleConnsPerHost: concurrentTransfers,
|
||||
}
|
||||
|
||||
tr.TLSClientConfig = &tls.Config{}
|
||||
if isCertVerificationDisabledForHost(c, host) {
|
||||
tr.TLSClientConfig.InsecureSkipVerify = true
|
||||
} else {
|
||||
tr.TLSClientConfig.RootCAs = getRootCAsForHost(c, host)
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: tr,
|
||||
}
|
||||
|
||||
c.hostClients[host] = httpClient
|
||||
|
||||
return httpClient
|
||||
}
|
||||
|
||||
func decodeResponse(res *http.Response, obj interface{}) error {
|
||||
ctype := res.Header.Get("Content-Type")
|
||||
if !(lfsMediaTypeRE.MatchString(ctype) || jsonMediaTypeRE.MatchString(ctype)) {
|
||||
|
Loading…
Reference in New Issue
Block a user