diff --git a/hawser/client.go b/hawser/client.go index 5fac15b2..755c19b7 100644 --- a/hawser/client.go +++ b/hawser/client.go @@ -226,7 +226,7 @@ func callExternalPut(filehash, filename string, lm *linkMeta, cb CopyCallback) e req.ContentLength = fileSize tracerx.Printf("external_put: %s %s", filepath.Base(filehash), req.URL) - res, err := http.DefaultClient.Do(req) + res, err := DoHTTP(Config, req) if err != nil { return Error(err) } @@ -248,7 +248,7 @@ func callExternalPut(filehash, filename string, lm *linkMeta, cb CopyCallback) e cbreq.Body = ioutil.NopCloser(bytes.NewBufferString(d)) tracerx.Printf("verify: %s %s", oid, cb.Href) - cbres, err := http.DefaultClient.Do(cbreq) + cbres, err := DoHTTP(Config, cbreq) if err != nil { return Error(err) } @@ -335,7 +335,7 @@ func validateMediaHeader(contentType string, reader io.Reader) (bool, int, *Wrap } func doRequest(req *http.Request, creds Creds) (*http.Response, *WrappedError) { - res, err := HttpClient().Do(req) + res, err := DoHTTP(Config, req) var wErr *WrappedError diff --git a/hawser/client_download_test.go b/hawser/client_download_test.go index 3fe9455e..7d2e682c 100644 --- a/hawser/client_download_test.go +++ b/hawser/client_download_test.go @@ -48,3 +48,55 @@ func TestDownload(t *testing.T) { t.Errorf("unexpected body: %s", body) } } + +func TestDownloadWithRedirect(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + tmp := tempdir(t) + defer server.Close() + defer os.RemoveAll(tmp) + + mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(405) + return + } + + head := w.Header() + head.Set("Location", server.URL+"/media/objects/redirect") + w.WriteHeader(302) + }) + + mux.HandleFunc("/media/objects/redirect", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(405) + return + } + + head := w.Header() + head.Set("Content-Type", "application/octet-stream") + head.Set("Content-Length", "4") + w.WriteHeader(200) + w.Write([]byte("test")) + }) + + Config.SetConfig("hawser.url", server.URL+"/media") + reader, size, wErr := Download("whatever/oid") + if wErr != nil { + t.Fatalf("unexpected error: %s", wErr) + } + defer reader.Close() + + if size != 4 { + t.Errorf("unexpected size: %d", size) + } + + by, err := ioutil.ReadAll(reader) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if body := string(by); body != "test" { + t.Errorf("unexpected body: %s", body) + } +} diff --git a/hawser/config.go b/hawser/config.go index 81859fbf..5a1bfdb3 100644 --- a/hawser/config.go +++ b/hawser/config.go @@ -1,7 +1,6 @@ package hawser import ( - "crypto/tls" "fmt" "github.com/hawser/git-hawser/git" "net/http" @@ -13,9 +12,10 @@ import ( ) type Configuration struct { - gitConfig map[string]string - remotes []string - httpClient *http.Client + gitConfig map[string]string + remotes []string + httpClient *http.Client + redirectingHttpClient *http.Client } var ( @@ -24,25 +24,6 @@ var ( RedirectError = fmt.Errorf("Unexpected redirection") ) -func HttpClient() *http.Client { - return Config.HttpClient() -} - -func (c *Configuration) HttpClient() *http.Client { - if c.httpClient == nil { - tr := &http.Transport{} - sslVerify, _ := c.GitConfig("http.sslverify") - if len(os.Getenv("GIT_SSL_NO_VERIFY")) > 0 || sslVerify == "false" { - tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - c.httpClient = &http.Client{Transport: tr} - c.httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return RedirectError - } - } - return c.httpClient -} - func (c *Configuration) Endpoint() string { if url, ok := c.GitConfig("hawser.url"); ok { return url diff --git a/hawser/http.go b/hawser/http.go new file mode 100644 index 00000000..5967f2f0 --- /dev/null +++ b/hawser/http.go @@ -0,0 +1,46 @@ +package hawser + +import ( + "crypto/tls" + "net/http" + "os" +) + +func DoHTTP(c *Configuration, req *http.Request) (*http.Response, error) { + switch req.Method { + case "GET", "HEAD": + return c.RedirectingHttpClient().Do(req) + default: + return c.HttpClient().Do(req) + } +} + +func (c *Configuration) HttpClient() *http.Client { + if c.httpClient == nil { + c.httpClient = &http.Client{ + Transport: c.RedirectingHttpClient().Transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return RedirectError + }, + } + } + return c.httpClient +} + +func (c *Configuration) RedirectingHttpClient() *http.Client { + if c.redirectingHttpClient == nil { + c.redirectingHttpClient = &http.Client{ + Transport: httpTransportFor(c), + } + } + return c.redirectingHttpClient +} + +func httpTransportFor(c *Configuration) *http.Transport { + tr := &http.Transport{} + sslVerify, _ := c.GitConfig("http.sslverify") + if len(os.Getenv("GIT_SSL_NO_VERIFY")) > 0 || sslVerify == "false" { + tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + return tr +}