support GET redirects
This commit is contained in:
parent
474ba75e89
commit
6b6a718511
@ -107,6 +107,126 @@ func TestSuccessfulDownload(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// nearly identical to TestSuccessfulDownload
|
||||
// called multiple times to return different 3xx status codes
|
||||
func TestSuccessfulDownloadWithRedirects(t *testing.T) {
|
||||
mux := http.NewServeMux()
|
||||
server := httptest.NewServer(mux)
|
||||
tmp := tempdir(t)
|
||||
defer server.Close()
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
// all of these should work for GET requests
|
||||
redirectCodes := []int{301, 302, 303, 307}
|
||||
redirectIndex := 0
|
||||
|
||||
mux.HandleFunc("/redirect/objects/oid", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("Server: %s %s", r.Method, r.URL)
|
||||
t.Logf("request header: %v", r.Header)
|
||||
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(405)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Location", server.URL+"/media/objects/oid")
|
||||
w.WriteHeader(redirectCodes[redirectIndex])
|
||||
t.Logf("redirect with %d", redirectCodes[redirectIndex])
|
||||
redirectIndex += 1
|
||||
})
|
||||
|
||||
mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("Server: %s %s", r.Method, r.URL)
|
||||
t.Logf("request header: %v", r.Header)
|
||||
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(405)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("Accept") != mediaType {
|
||||
t.Error("Invalid Accept")
|
||||
}
|
||||
|
||||
if r.Header.Get("Authorization") != expectedAuth(t, server) {
|
||||
t.Error("Invalid Authorization")
|
||||
}
|
||||
|
||||
obj := &objectResource{
|
||||
Oid: "oid",
|
||||
Size: 4,
|
||||
Links: map[string]*linkRelation{
|
||||
"download": &linkRelation{
|
||||
Href: server.URL + "/download",
|
||||
Header: map[string]string{"A": "1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
by, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
head := w.Header()
|
||||
head.Set("Content-Type", mediaType)
|
||||
head.Set("Content-Length", strconv.Itoa(len(by)))
|
||||
w.WriteHeader(200)
|
||||
w.Write(by)
|
||||
})
|
||||
|
||||
mux.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("Server: %s %s", r.Method, r.URL)
|
||||
t.Logf("request header: %v", r.Header)
|
||||
|
||||
if r.Method != "GET" {
|
||||
w.WriteHeader(405)
|
||||
return
|
||||
}
|
||||
|
||||
if r.Header.Get("Accept") != "" {
|
||||
t.Error("Invalid Accept")
|
||||
}
|
||||
|
||||
if r.Header.Get("Authorization") != expectedAuth(t, server) {
|
||||
t.Error("Invalid Authorization")
|
||||
}
|
||||
|
||||
if r.Header.Get("A") != "1" {
|
||||
t.Error("invalid A")
|
||||
}
|
||||
|
||||
head := w.Header()
|
||||
head.Set("Content-Type", "application/octet-stream")
|
||||
head.Set("Content-Length", "4")
|
||||
w.WriteHeader(200)
|
||||
w.Write([]byte("test"))
|
||||
})
|
||||
|
||||
Config.SetConfig("lfs.url", server.URL+"/redirect")
|
||||
|
||||
for _, redirect := range redirectCodes {
|
||||
reader, size, wErr := Download("oid")
|
||||
if wErr != nil {
|
||||
t.Fatalf("unexpected error for %d status: %s", redirect, wErr)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
if size != 4 {
|
||||
t.Errorf("unexpected size for %d status: %d", redirect, size)
|
||||
}
|
||||
|
||||
by, err := ioutil.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for %d status: %s", redirect, err)
|
||||
}
|
||||
|
||||
if body := string(by); body != "test" {
|
||||
t.Errorf("unexpected body for %d status: %s", redirect, body)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nearly identical to TestSuccessfulDownload
|
||||
// the api request returns a custom Authorization header
|
||||
func TestSuccessfulDownloadWithAuthorization(t *testing.T) {
|
||||
|
19
lfs/http.go
19
lfs/http.go
@ -2,6 +2,7 @@ package lfs
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/rubyist/tracerx"
|
||||
"io"
|
||||
@ -24,11 +25,27 @@ func (c *Configuration) HttpClient() *http.Client {
|
||||
if sslVerify == "false" || len(os.Getenv("GIT_SSL_NO_VERIFY")) > 0 {
|
||||
tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
}
|
||||
c.httpClient = &http.Client{Transport: tr}
|
||||
c.httpClient = &http.Client{
|
||||
Transport: tr,
|
||||
CheckRedirect: checkRedirect,
|
||||
}
|
||||
}
|
||||
return c.httpClient
|
||||
}
|
||||
|
||||
func checkRedirect(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 3 {
|
||||
return errors.New("stopped after 3 redirects")
|
||||
}
|
||||
|
||||
oldest := via[0]
|
||||
for key, _ := range oldest.Header {
|
||||
req.Header.Set(key, oldest.Header.Get(key))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var tracedTypes = []string{"json", "text", "xml", "html"}
|
||||
|
||||
func traceHttpRequest(c *Configuration, req *http.Request) {
|
||||
|
Loading…
Reference in New Issue
Block a user