diff --git a/hawser/client.go b/hawser/client.go index 42e16abb..5fac15b2 100644 --- a/hawser/client.go +++ b/hawser/client.go @@ -66,12 +66,13 @@ func Download(oidPath string) (io.ReadCloser, int64, *WrappedError) { return nil, 0, wErr } - if ok, wErr := validateMediaHeader(contentType, res.Body); !ok { + ok, headerSize, wErr := validateMediaHeader(contentType, res.Body) + if !ok { setErrorResponseContext(wErr, res) return nil, 0, wErr } - return res.Body, res.ContentLength, nil + return res.Body, res.ContentLength - int64(headerSize), nil } func Upload(oidPath, filename string, cb CopyCallback) *WrappedError { @@ -302,32 +303,35 @@ func callPost(filehash, filename string) (*linkMeta, int, error) { return nil, res.StatusCode, nil } -func validateMediaHeader(contentType string, reader io.Reader) (bool, *WrappedError) { +func validateMediaHeader(contentType string, reader io.Reader) (bool, int, *WrappedError) { mediaType, params, err := mime.ParseMediaType(contentType) + var headerSize int + if err != nil { - return false, Errorf(err, "Invalid Media Type: %s", contentType) + return false, headerSize, Errorf(err, "Invalid Media Type: %s", contentType) } if mediaType == gitMediaType { givenHeader, ok := params["header"] if !ok { - return false, Error(fmt.Errorf("Missing Git Media header in %s", contentType)) + return false, headerSize, Error(fmt.Errorf("Missing Git Media header in %s", contentType)) } fullGivenHeader := "--" + givenHeader + "\n" + headerSize = len(fullGivenHeader) - header := make([]byte, len(fullGivenHeader)) + header := make([]byte, headerSize) _, err = io.ReadAtLeast(reader, header, len(fullGivenHeader)) if err != nil { - return false, Errorf(err, "Error reading response body.") + return false, headerSize, Errorf(err, "Error reading response body.") } if string(header) != fullGivenHeader { - return false, Error(fmt.Errorf("Invalid header: %s expected, got %s", fullGivenHeader, header)) + return false, headerSize, Error(fmt.Errorf("Invalid header: %s expected, got %s", fullGivenHeader, header)) } } - return true, nil + return true, headerSize, nil } func doRequest(req *http.Request, creds Creds) (*http.Response, *WrappedError) { diff --git a/hawser/client_legacy_test.go b/hawser/client_legacy_test.go index 09a96c86..3b846942 100644 --- a/hawser/client_legacy_test.go +++ b/hawser/client_legacy_test.go @@ -9,6 +9,47 @@ import ( "testing" ) +func TestDownloadWithMediaHeader(t *testing.T) { + mux := http.NewServeMux() + server := httptest.NewServer(mux) + tmp := tempdir(t) + defer server.Close() + defer os.RemoveAll(tmp) + + mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + w.WriteHeader(405) + return + } + + head := w.Header() + head.Set("Content-Type", "application/vnd.git-media; header=download-header") + head.Set("Content-Length", "22") + w.WriteHeader(200) + w.Write([]byte("--download-header\ntest")) + }) + + Config.SetConfig("hawser.url", server.URL+"/media") + reader, size, wErr := Download("whatever/oid") + if wErr != nil { + t.Fatalf("unexpected error: %s", wErr) + } + defer reader.Close() + + if size != 4 { + t.Errorf("unexpected size: %d", size) + } + + by, err := ioutil.ReadAll(reader) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if body := string(by); body != "test" { + t.Errorf("unexpected body: %s", body) + } +} + func TestPut(t *testing.T) { mux := http.NewServeMux() server := httptest.NewServer(mux)