diff --git a/api/api.go b/api/api.go index be87efa7..5ce48e1a 100644 --- a/api/api.go +++ b/api/api.go @@ -6,10 +6,6 @@ import ( "bytes" "encoding/json" "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" "strconv" "github.com/github/git-lfs/config" @@ -21,6 +17,36 @@ import ( "github.com/rubyist/tracerx" ) +// BatchOrLegacy calls the Batch API and falls back on the Legacy API +// This is for simplicity, legacy route is not most optimal (serial) +// TODO LEGACY API: remove when legacy API removed +func BatchOrLegacy(objects []*ObjectResource, operation string) ([]*ObjectResource, error) { + if !config.Config.BatchTransfer() { + return Legacy(objects, operation) + } + objs, err := Batch(objects, operation) + if err != nil { + if errutil.IsNotImplementedError(err) { + git.Config.SetLocal("", "lfs.batch", "false") + return Legacy(objects, operation) + } + return nil, err + } + return objs, nil +} + +func BatchOrLegacySingle(inobj *ObjectResource, operation string) (*ObjectResource, error) { + objs, err := BatchOrLegacy([]*ObjectResource{inobj}, operation) + if err != nil { + return nil, err + } + if len(objs) > 0 { + return objs[0], nil + } + return nil, fmt.Errorf("Object not found") +} + +// Batch calls the batch API and returns object results func Batch(objects []*ObjectResource, operation string) ([]*ObjectResource, error) { if len(objects) == 0 { return nil, nil @@ -80,61 +106,30 @@ func Batch(objects []*ObjectResource, operation string) ([]*ObjectResource, erro return objs, nil } -// Download will attempt to download the object with the given oid. The batched -// API will be used, but if the server does not implement the batch operations -// it will fall back to the legacy API. 
-func Download(oid string, size int64) (io.ReadCloser, int64, error) { - if !config.Config.BatchTransfer() { - return DownloadLegacy(oid) - } - - objects := []*ObjectResource{ - &ObjectResource{Oid: oid, Size: size}, - } - - objs, err := Batch(objects, "download") - if err != nil { - if errutil.IsNotImplementedError(err) { - git.Config.SetLocal("", "lfs.batch", "false") - return DownloadLegacy(oid) +// Legacy calls the legacy API serially and returns ObjectResources +// TODO LEGACY API: remove when legacy API removed +func Legacy(objects []*ObjectResource, operation string) ([]*ObjectResource, error) { + retobjs := make([]*ObjectResource, 0, len(objects)) + dl := operation == "download" + var globalErr error + for _, o := range objects { + var ret *ObjectResource + var err error + if dl { + ret, err = DownloadCheck(o.Oid) + } else { + ret, err = UploadCheck(o.Oid, o.Size) } - return nil, 0, err + if err != nil { + // Store for the end, likely only one + globalErr = err + } + retobjs = append(retobjs, ret) } - - if len(objs) != 1 { // Expecting to find one object - return nil, 0, errutil.Error(fmt.Errorf("Object not found: %s", oid)) - } - - return DownloadObject(objs[0]) -} - -// DownloadLegacy attempts to download the object for the given oid using the -// legacy API. 
-func DownloadLegacy(oid string) (io.ReadCloser, int64, error) { - req, err := NewRequest("GET", oid) - if err != nil { - return nil, 0, errutil.Error(err) - } - - res, obj, err := DoLegacyRequest(req) - if err != nil { - return nil, 0, err - } - httputil.LogTransfer("lfs.download", res) - req, err = obj.NewRequest("download", "GET") - if err != nil { - return nil, 0, errutil.Error(err) - } - - res, err = httputil.DoHttpRequest(req, true) - if err != nil { - return nil, 0, err - } - httputil.LogTransfer("lfs.data.download", res) - - return res.Body, res.ContentLength, nil + return retobjs, globalErr } +// TODO LEGACY API: remove when legacy API removed func DownloadCheck(oid string) (*ObjectResource, error) { req, err := NewRequest("GET", oid) if err != nil { @@ -155,32 +150,12 @@ func DownloadCheck(oid string) (*ObjectResource, error) { return obj, nil } -func DownloadObject(obj *ObjectResource) (io.ReadCloser, int64, error) { - req, err := obj.NewRequest("download", "GET") - if err != nil { - return nil, 0, errutil.Error(err) - } - - res, err := httputil.DoHttpRequest(req, true) - if err != nil { - return nil, 0, errutil.NewRetriableError(err) - } - httputil.LogTransfer("lfs.data.download", res) - - return res.Body, res.ContentLength, nil -} - -func UploadCheck(oidPath string) (*ObjectResource, error) { - oid := filepath.Base(oidPath) - - stat, err := os.Stat(oidPath) - if err != nil { - return nil, errutil.Error(err) - } +// TODO LEGACY API: remove when legacy API removed +func UploadCheck(oid string, size int64) (*ObjectResource, error) { reqObj := &ObjectResource{ Oid: oid, - Size: stat.Size(), + Size: size, } by, err := json.Marshal(reqObj) @@ -204,7 +179,7 @@ func UploadCheck(oidPath string) (*ObjectResource, error) { if err != nil { if errutil.IsAuthError(err) { httputil.SetAuthType(req, res) - return UploadCheck(oidPath) + return UploadCheck(oid, size) } return nil, errutil.NewRetriableError(err) @@ -224,72 +199,3 @@ func UploadCheck(oidPath string) 
(*ObjectResource, error) { return obj, nil } - -func UploadObject(obj *ObjectResource, reader io.Reader) error { - - req, err := obj.NewRequest("upload", "PUT") - if err != nil { - return errutil.Error(err) - } - - if len(req.Header.Get("Content-Type")) == 0 { - req.Header.Set("Content-Type", "application/octet-stream") - } - - if req.Header.Get("Transfer-Encoding") == "chunked" { - req.TransferEncoding = []string{"chunked"} - } else { - req.Header.Set("Content-Length", strconv.FormatInt(obj.Size, 10)) - } - - req.ContentLength = obj.Size - req.Body = ioutil.NopCloser(reader) - - res, err := httputil.DoHttpRequest(req, true) - if err != nil { - return errutil.NewRetriableError(err) - } - httputil.LogTransfer("lfs.data.upload", res) - - // A status code of 403 likely means that an authentication token for the - // upload has expired. This can be safely retried. - if res.StatusCode == 403 { - return errutil.NewRetriableError(err) - } - - if res.StatusCode > 299 { - return errutil.Errorf(nil, "Invalid status for %s: %d", httputil.TraceHttpReq(req), res.StatusCode) - } - - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - - if _, ok := obj.Rel("verify"); !ok { - return nil - } - - req, err = obj.NewRequest("verify", "POST") - if err != nil { - return errutil.Error(err) - } - - by, err := json.Marshal(obj) - if err != nil { - return errutil.Error(err) - } - - req.Header.Set("Content-Type", MediaType) - req.Header.Set("Content-Length", strconv.Itoa(len(by))) - req.ContentLength = int64(len(by)) - req.Body = ioutil.NopCloser(bytes.NewReader(by)) - res, err = DoRequest(req, true) - if err != nil { - return err - } - - httputil.LogTransfer("lfs.data.verify", res) - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - - return err -} diff --git a/api/download_test.go b/api/download_test.go index 5bebebff..87b9744f 100644 --- a/api/download_test.go +++ b/api/download_test.go @@ -73,55 +73,22 @@ func TestSuccessfulDownload(t *testing.T) { w.Write(by) }) - 
mux.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != "" { - t.Error("Invalid Accept") - } - - if r.Header.Get("A") != "1" { - t.Error("invalid A") - } - - head := w.Header() - head.Set("Content-Type", "application/octet-stream") - head.Set("Content-Length", "4") - w.WriteHeader(200) - w.Write([]byte("test")) - }) - defer config.Config.ResetConfig() config.Config.SetConfig("lfs.batch", "false") config.Config.SetConfig("lfs.url", server.URL+"/media") - reader, size, err := api.Download("oid", 0) + obj, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: "oid"}, "download") if err != nil { if isDockerConnectionError(err) { return } t.Fatalf("unexpected error: %s", err) } - defer reader.Close() - if size != 4 { - t.Errorf("unexpected size: %d", size) + if obj.Size != 4 { + t.Errorf("unexpected size: %d", obj.Size) } - by, err := ioutil.ReadAll(reader) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if body := string(by); body != "test" { - t.Errorf("unexpected body: %s", body) - } } // nearly identical to TestSuccessfulDownload @@ -212,36 +179,12 @@ func TestSuccessfulDownloadWithRedirects(t *testing.T) { w.Write(by) }) - mux.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != "" { - t.Error("Invalid Accept") - } - - if r.Header.Get("A") != "1" { - t.Error("invalid A") - } - - head := w.Header() - head.Set("Content-Type", "application/octet-stream") - head.Set("Content-Length", "4") - w.WriteHeader(200) - w.Write([]byte("test")) - }) - defer config.Config.ResetConfig() config.Config.SetConfig("lfs.batch", "false") config.Config.SetConfig("lfs.url", 
server.URL+"/redirect") for _, redirect := range redirectCodes { - reader, size, err := api.Download("oid", 0) + obj, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: "oid"}, "download") if err != nil { if isDockerConnectionError(err) { return @@ -249,19 +192,10 @@ func TestSuccessfulDownloadWithRedirects(t *testing.T) { t.Fatalf("unexpected error for %d status: %s", redirect, err) } - if size != 4 { - t.Errorf("unexpected size for %d status: %d", redirect, size) + if obj.Size != 4 { + t.Errorf("unexpected size for %d status: %d", redirect, obj.Size) } - by, err := ioutil.ReadAll(reader) - reader.Close() - if err != nil { - t.Fatalf("unexpected error for %d status: %s", redirect, err) - } - - if body := string(by); body != "test" { - t.Errorf("unexpected body for %d status: %s", redirect, body) - } } } @@ -323,310 +257,21 @@ func TestSuccessfulDownloadWithAuthorization(t *testing.T) { w.Write(by) }) - mux.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != "" { - t.Error("Invalid Accept") - } - - if r.Header.Get("Authorization") != "custom" { - t.Error("Invalid Authorization") - } - - if r.Header.Get("A") != "1" { - t.Error("invalid A") - } - - head := w.Header() - head.Set("Content-Type", "application/octet-stream") - head.Set("Content-Length", "4") - w.WriteHeader(200) - w.Write([]byte("test")) - }) - defer config.Config.ResetConfig() config.Config.SetConfig("lfs.batch", "false") config.Config.SetConfig("lfs.url", server.URL+"/media") - reader, size, err := api.Download("oid", 0) + obj, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: "oid"}, "download") if err != nil { if isDockerConnectionError(err) { return } t.Fatalf("unexpected error: %s", err) } - defer reader.Close() - if size != 4 { - t.Errorf("unexpected size: %d", size) + if obj.Size != 4 { + 
t.Errorf("unexpected size: %d", obj.Size) } - by, err := ioutil.ReadAll(reader) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if body := string(by); body != "test" { - t.Errorf("unexpected body: %s", body) - } -} - -// nearly identical to TestSuccessfulDownload -// download is served from a second server -func TestSuccessfulDownloadFromSeparateHost(t *testing.T) { - SetupTestCredentialsFunc() - defer func() { - RestoreCredentialsFunc() - }() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - defer server.Close() - - mux2 := http.NewServeMux() - server2 := httptest.NewServer(mux2) - defer server2.Close() - - tmp := tempdir(t) - defer os.RemoveAll(tmp) - - mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - t.Error("Invalid Accept") - } - - if r.Header.Get("Authorization") != expectedAuth(t, server) { - t.Error("Invalid Authorization") - } - - obj := &api.ObjectResource{ - Oid: "oid", - Size: 4, - Actions: map[string]*api.LinkRelation{ - "download": &api.LinkRelation{ - Href: server2.URL + "/download", - Header: map[string]string{"A": "1"}, - }, - }, - } - - by, err := json.Marshal(obj) - if err != nil { - t.Fatal(err) - } - - head := w.Header() - head.Set("Content-Type", api.MediaType) - head.Set("Content-Length", strconv.Itoa(len(by))) - w.WriteHeader(200) - w.Write(by) - }) - - mux2.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != "" { - t.Error("Invalid Accept") - } - - if r.Header.Get("A") != "1" { - t.Error("invalid A") - } - - head := w.Header() - head.Set("Content-Type", "application/octet-stream") - 
head.Set("Content-Length", "4") - w.WriteHeader(200) - w.Write([]byte("test")) - }) - - defer config.Config.ResetConfig() - config.Config.SetConfig("lfs.batch", "false") - config.Config.SetConfig("lfs.url", server.URL+"/media") - reader, size, err := api.Download("oid", 0) - if err != nil { - if isDockerConnectionError(err) { - return - } - t.Fatalf("unexpected error: %s", err) - } - defer reader.Close() - - if size != 4 { - t.Errorf("unexpected size: %d", size) - } - - by, err := ioutil.ReadAll(reader) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - - if body := string(by); body != "test" { - t.Errorf("unexpected body: %s", body) - } -} - -// nearly identical to TestSuccessfulDownload -// download is served from a second server -func TestSuccessfulDownloadFromSeparateRedirectedHost(t *testing.T) { - SetupTestCredentialsFunc() - defer func() { - RestoreCredentialsFunc() - }() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - defer server.Close() - - mux2 := http.NewServeMux() - server2 := httptest.NewServer(mux2) - defer server2.Close() - - mux3 := http.NewServeMux() - server3 := httptest.NewServer(mux3) - defer server3.Close() - - tmp := tempdir(t) - defer os.RemoveAll(tmp) - - // all of these should work for GET requests - redirectCodes := []int{301, 302, 303, 307} - redirectIndex := 0 - - mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server 1: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - t.Error("Invalid Accept") - } - - if r.Header.Get("Authorization") != expectedAuth(t, server) { - t.Error("Invalid Authorization") - } - - w.Header().Set("Location", server2.URL+"/media/objects/oid") - w.WriteHeader(redirectCodes[redirectIndex]) - t.Logf("redirect with %d", redirectCodes[redirectIndex]) - redirectIndex += 1 - }) - - 
mux2.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server 2: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - t.Error("Invalid Accept") - } - - if r.Header.Get("Authorization") != "" { - t.Error("Invalid Authorization") - } - - obj := &api.ObjectResource{ - Oid: "oid", - Size: 4, - Actions: map[string]*api.LinkRelation{ - "download": &api.LinkRelation{ - Href: server3.URL + "/download", - Header: map[string]string{"A": "1"}, - }, - }, - } - - by, err := json.Marshal(obj) - if err != nil { - t.Fatal(err) - } - - head := w.Header() - head.Set("Content-Type", api.MediaType) - head.Set("Content-Length", strconv.Itoa(len(by))) - w.WriteHeader(200) - w.Write(by) - }) - - mux3.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server 3: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != "" { - t.Error("Invalid Accept") - } - - if r.Header.Get("A") != "1" { - t.Error("invalid A") - } - - head := w.Header() - head.Set("Content-Type", "application/octet-stream") - head.Set("Content-Length", "4") - w.WriteHeader(200) - w.Write([]byte("test")) - }) - - defer config.Config.ResetConfig() - config.Config.SetConfig("lfs.batch", "false") - config.Config.SetConfig("lfs.url", server.URL+"/media") - - for _, redirect := range redirectCodes { - reader, size, err := api.Download("oid", 0) - if err != nil { - if isDockerConnectionError(err) { - return - } - t.Fatalf("unexpected error for %d status: %s", redirect, err) - } - - if size != 4 { - t.Errorf("unexpected size for %d status: %d", redirect, size) - } - - by, err := ioutil.ReadAll(reader) - reader.Close() - if err != nil { - t.Fatalf("unexpected error for %d status: %s", redirect, err) - } - - if body := string(by); body != 
"test" { - t.Errorf("unexpected body for %d status: %s", redirect, body) - } - } } func TestDownloadAPIError(t *testing.T) { @@ -649,7 +294,7 @@ func TestDownloadAPIError(t *testing.T) { defer config.Config.ResetConfig() config.Config.SetConfig("lfs.batch", "false") config.Config.SetConfig("lfs.url", server.URL+"/media") - _, _, err := api.Download("oid", 0) + _, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: "oid"}, "download") if err == nil { t.Fatal("no error?") } @@ -668,84 +313,6 @@ func TestDownloadAPIError(t *testing.T) { } -func TestDownloadStorageError(t *testing.T) { - SetupTestCredentialsFunc() - defer func() { - RestoreCredentialsFunc() - }() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - defer server.Close() - - tmp := tempdir(t) - defer os.RemoveAll(tmp) - - mux.HandleFunc("/media/objects/oid", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - t.Logf("request header: %v", r.Header) - - if r.Method != "GET" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - t.Error("Invalid Accept") - } - - if r.Header.Get("Authorization") != expectedAuth(t, server) { - t.Error("Invalid Authorization") - } - - obj := &api.ObjectResource{ - Oid: "oid", - Size: 4, - Actions: map[string]*api.LinkRelation{ - "download": &api.LinkRelation{ - Href: server.URL + "/download", - Header: map[string]string{"A": "1"}, - }, - }, - } - - by, err := json.Marshal(obj) - if err != nil { - t.Fatal(err) - } - - head := w.Header() - head.Set("Content-Type", api.MediaType) - head.Set("Content-Length", strconv.Itoa(len(by))) - w.WriteHeader(200) - w.Write(by) - }) - - mux.HandleFunc("/download", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - }) - - defer config.Config.ResetConfig() - config.Config.SetConfig("lfs.batch", "false") - config.Config.SetConfig("lfs.url", server.URL+"/media") - _, _, err := api.Download("oid", 0) - if err == nil { - t.Fatal("no 
error?") - } - - if isDockerConnectionError(err) { - return - } - - if !errutil.IsFatalError(err) { - t.Fatal("should panic") - } - - if err.Error() != fmt.Sprintf(httputil.GetDefaultError(500), server.URL+"/download") { - t.Fatalf("Unexpected error: %s", err.Error()) - } -} - // guards against connection errors that only seem to happen on debian docker // images. func isDockerConnectionError(err error) bool { diff --git a/api/object.go b/api/object.go index c9b7cb12..159e1a0c 100644 --- a/api/object.go +++ b/api/object.go @@ -25,6 +25,7 @@ type ObjectResource struct { Error *ObjectError `json:"error,omitempty"` } +// TODO LEGACY API: remove when legacy API removed func (o *ObjectResource) NewRequest(relation, method string) (*http.Request, error) { rel, ok := o.Rel(relation) if !ok { diff --git a/api/upload_test.go b/api/upload_test.go index ad7d5303..750846df 100644 --- a/api/upload_test.go +++ b/api/upload_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "os" + "path/filepath" "strconv" "testing" @@ -17,7 +18,6 @@ import ( "github.com/github/git-lfs/errutil" "github.com/github/git-lfs/httputil" "github.com/github/git-lfs/lfs" - "github.com/github/git-lfs/progress" "github.com/github/git-lfs/test" ) @@ -38,8 +38,6 @@ func TestExistingUpload(t *testing.T) { defer os.RemoveAll(tmp) postCalled := false - putCalled := false - verifyCalled := false mux.HandleFunc("/media/objects", func(w http.ResponseWriter, r *http.Request) { t.Logf("Server: %s %s", r.Method, r.URL) @@ -103,18 +101,6 @@ func TestExistingUpload(t *testing.T) { w.Write(by) }) - mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - putCalled = true - w.WriteHeader(200) - }) - - mux.HandleFunc("/verify", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - verifyCalled = true - w.WriteHeader(200) - }) - defer config.Config.ResetConfig() config.Config.SetConfig("lfs.url", 
server.URL+"/media") @@ -123,7 +109,9 @@ func TestExistingUpload(t *testing.T) { t.Fatal(err) } - o, err := api.UploadCheck(oidPath) + oid := filepath.Base(oidPath) + stat, _ := os.Stat(oidPath) + o, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: oid, Size: stat.Size()}, "upload") if err != nil { if isDockerConnectionError(err) { return @@ -138,13 +126,6 @@ func TestExistingUpload(t *testing.T) { t.Errorf("POST not called") } - if putCalled { - t.Errorf("PUT not skipped") - } - - if verifyCalled { - t.Errorf("verify not skipped") - } } func TestUploadWithRedirect(t *testing.T) { @@ -254,7 +235,9 @@ func TestUploadWithRedirect(t *testing.T) { t.Fatal(err) } - obj, err := api.UploadCheck(oidPath) + oid := filepath.Base(oidPath) + stat, _ := os.Stat(oidPath) + o, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: oid, Size: stat.Size()}, "upload") if err != nil { if isDockerConnectionError(err) { return @@ -262,7 +245,7 @@ func TestUploadWithRedirect(t *testing.T) { t.Fatal(err) } - if obj != nil { + if o != nil { t.Fatal("Received an object") } } @@ -284,7 +267,6 @@ func TestSuccessfulUploadWithVerify(t *testing.T) { defer os.RemoveAll(tmp) postCalled := false - putCalled := false verifyCalled := false mux.HandleFunc("/media/objects", func(w http.ResponseWriter, r *http.Request) { @@ -349,46 +331,6 @@ func TestSuccessfulUploadWithVerify(t *testing.T) { w.Write(by) }) - mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - - if r.Method != "PUT" { - w.WriteHeader(405) - return - } - - if r.Header.Get("A") != "1" { - t.Error("Invalid A") - } - - if r.Header.Get("Content-Type") != "application/octet-stream" { - t.Error("Invalid Content-Type") - } - - if r.Header.Get("Content-Length") != "4" { - t.Error("Invalid Content-Length") - } - - if r.Header.Get("Transfer-Encoding") != "" { - t.Fatal("Transfer-Encoding is set") - } - - by, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Error(err) - 
} - - t.Logf("request header: %v", r.Header) - t.Logf("request body: %s", string(by)) - - if str := string(by); str != "test" { - t.Errorf("unexpected body: %s", str) - } - - putCalled = true - w.WriteHeader(200) - }) - mux.HandleFunc("/verify", func(w http.ResponseWriter, r *http.Request) { t.Logf("Server: %s %s", r.Method, r.URL) @@ -435,193 +377,25 @@ func TestSuccessfulUploadWithVerify(t *testing.T) { t.Fatal(err) } - // stores callbacks - calls := make([][]int64, 0, 5) - cb := func(total int64, written int64, current int) error { - calls = append(calls, []int64{total, written}) - return nil - } - - obj, err := api.UploadCheck(oidPath) + oid := filepath.Base(oidPath) + stat, _ := os.Stat(oidPath) + o, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: oid, Size: stat.Size()}, "upload") if err != nil { if isDockerConnectionError(err) { return } t.Fatal(err) } - err = uploadObject(obj, cb) - if err != nil { - t.Fatal(err) - } + api.VerifyUpload(o) if !postCalled { t.Errorf("POST not called") } - if !putCalled { - t.Errorf("PUT not called") - } - if !verifyCalled { t.Errorf("verify not called") } - t.Logf("CopyCallback: %v", calls) - - if len(calls) < 1 { - t.Errorf("CopyCallback was not used") - } - - lastCall := calls[len(calls)-1] - if lastCall[0] != 4 || lastCall[1] != 4 { - t.Errorf("Last CopyCallback call should be the total") - } -} - -func TestSuccessfulUploadWithoutVerify(t *testing.T) { - SetupTestCredentialsFunc() - repo := test.NewRepo(t) - repo.Pushd() - defer func() { - repo.Popd() - repo.Cleanup() - RestoreCredentialsFunc() - }() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - tmp := tempdir(t) - defer server.Close() - defer os.RemoveAll(tmp) - - postCalled := false - putCalled := false - - mux.HandleFunc("/media/objects", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - - if r.Method != "POST" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - 
t.Errorf("Invalid Accept") - } - - if r.Header.Get("Content-Type") != api.MediaType { - t.Errorf("Invalid Content-Type") - } - - buf := &bytes.Buffer{} - tee := io.TeeReader(r.Body, buf) - reqObj := &api.ObjectResource{} - err := json.NewDecoder(tee).Decode(reqObj) - t.Logf("request header: %v", r.Header) - t.Logf("request body: %s", buf.String()) - if err != nil { - t.Fatal(err) - } - - if reqObj.Oid != "988881adc9fc3655077dc2d4d757d480b5ea0e11" { - t.Errorf("invalid oid from request: %s", reqObj.Oid) - } - - if reqObj.Size != 4 { - t.Errorf("invalid size from request: %d", reqObj.Size) - } - - obj := &api.ObjectResource{ - Oid: reqObj.Oid, - Size: reqObj.Size, - Actions: map[string]*api.LinkRelation{ - "upload": &api.LinkRelation{ - Href: server.URL + "/upload", - Header: map[string]string{"A": "1"}, - }, - }, - } - - by, err := json.Marshal(obj) - if err != nil { - t.Fatal(err) - } - - postCalled = true - head := w.Header() - head.Set("Content-Type", api.MediaType) - head.Set("Content-Length", strconv.Itoa(len(by))) - w.WriteHeader(202) - w.Write(by) - }) - - mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - - if r.Method != "PUT" { - w.WriteHeader(405) - return - } - - if a := r.Header.Get("A"); a != "1" { - t.Errorf("Invalid A: %s", a) - } - - if r.Header.Get("Content-Type") != "application/octet-stream" { - t.Error("Invalid Content-Type") - } - - if r.Header.Get("Content-Length") != "4" { - t.Error("Invalid Content-Length") - } - - if r.Header.Get("Transfer-Encoding") != "" { - t.Fatal("Transfer-Encoding is set") - } - - by, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Error(err) - } - - t.Logf("request header: %v", r.Header) - t.Logf("request body: %s", string(by)) - - if str := string(by); str != "test" { - t.Errorf("unexpected body: %s", str) - } - - putCalled = true - w.WriteHeader(200) - }) - - defer config.Config.ResetConfig() - config.Config.SetConfig("lfs.url", 
server.URL+"/media") - - oidPath, _ := lfs.LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") - if err := ioutil.WriteFile(oidPath, []byte("test"), 0744); err != nil { - t.Fatal(err) - } - - obj, err := api.UploadCheck(oidPath) - if err != nil { - if isDockerConnectionError(err) { - return - } - t.Fatal(err) - } - err = uploadObject(obj, nil) - if err != nil { - t.Fatal(err) - } - - if !postCalled { - t.Errorf("POST not called") - } - - if !putCalled { - t.Errorf("PUT not called") - } } func TestUploadApiError(t *testing.T) { @@ -655,7 +429,9 @@ func TestUploadApiError(t *testing.T) { t.Fatal(err) } - _, err := api.UploadCheck(oidPath) + oid := filepath.Base(oidPath) + stat, _ := os.Stat(oidPath) + _, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: oid, Size: stat.Size()}, "upload") if err == nil { t.Fatal(err) } @@ -677,129 +453,6 @@ func TestUploadApiError(t *testing.T) { } } -func TestUploadStorageError(t *testing.T) { - SetupTestCredentialsFunc() - repo := test.NewRepo(t) - repo.Pushd() - defer func() { - repo.Popd() - repo.Cleanup() - RestoreCredentialsFunc() - }() - - mux := http.NewServeMux() - server := httptest.NewServer(mux) - tmp := tempdir(t) - defer server.Close() - defer os.RemoveAll(tmp) - - postCalled := false - putCalled := false - - mux.HandleFunc("/media/objects", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - - if r.Method != "POST" { - w.WriteHeader(405) - return - } - - if r.Header.Get("Accept") != api.MediaType { - t.Errorf("Invalid Accept") - } - - if r.Header.Get("Content-Type") != api.MediaType { - t.Errorf("Invalid Content-Type") - } - - buf := &bytes.Buffer{} - tee := io.TeeReader(r.Body, buf) - reqObj := &api.ObjectResource{} - err := json.NewDecoder(tee).Decode(reqObj) - t.Logf("request header: %v", r.Header) - t.Logf("request body: %s", buf.String()) - if err != nil { - t.Fatal(err) - } - - if reqObj.Oid != "988881adc9fc3655077dc2d4d757d480b5ea0e11" { - t.Errorf("invalid oid 
from request: %s", reqObj.Oid) - } - - if reqObj.Size != 4 { - t.Errorf("invalid size from request: %d", reqObj.Size) - } - - obj := &api.ObjectResource{ - Oid: reqObj.Oid, - Size: reqObj.Size, - Actions: map[string]*api.LinkRelation{ - "upload": &api.LinkRelation{ - Href: server.URL + "/upload", - Header: map[string]string{"A": "1"}, - }, - "verify": &api.LinkRelation{ - Href: server.URL + "/verify", - Header: map[string]string{"B": "2"}, - }, - }, - } - - by, err := json.Marshal(obj) - if err != nil { - t.Fatal(err) - } - - postCalled = true - head := w.Header() - head.Set("Content-Type", api.MediaType) - head.Set("Content-Length", strconv.Itoa(len(by))) - w.WriteHeader(202) - w.Write(by) - }) - - mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { - putCalled = true - w.WriteHeader(404) - }) - - defer config.Config.ResetConfig() - config.Config.SetConfig("lfs.url", server.URL+"/media") - - oidPath, _ := lfs.LocalMediaPath("988881adc9fc3655077dc2d4d757d480b5ea0e11") - if err := ioutil.WriteFile(oidPath, []byte("test"), 0744); err != nil { - t.Fatal(err) - } - - obj, err := api.UploadCheck(oidPath) - if err != nil { - if isDockerConnectionError(err) { - return - } - t.Fatal(err) - } - err = uploadObject(obj, nil) - if err == nil { - t.Fatal("Expected an error") - } - - if errutil.IsFatalError(err) { - t.Fatal("should not panic") - } - - if err.Error() != fmt.Sprintf(httputil.GetDefaultError(404), server.URL+"/upload") { - t.Fatalf("Unexpected error: %s", err.Error()) - } - - if !postCalled { - t.Errorf("POST not called") - } - - if !putCalled { - t.Errorf("PUT not called") - } -} - func TestUploadVerifyError(t *testing.T) { SetupTestCredentialsFunc() repo := test.NewRepo(t) @@ -817,7 +470,6 @@ func TestUploadVerifyError(t *testing.T) { defer os.RemoveAll(tmp) postCalled := false - putCalled := false verifyCalled := false mux.HandleFunc("/media/objects", func(w http.ResponseWriter, r *http.Request) { @@ -882,38 +534,6 @@ func 
TestUploadVerifyError(t *testing.T) { w.Write(by) }) - mux.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { - t.Logf("Server: %s %s", r.Method, r.URL) - - if r.Method != "PUT" { - w.WriteHeader(405) - return - } - - if r.Header.Get("A") != "1" { - t.Error("Invalid A") - } - - if r.Header.Get("Content-Type") != "application/octet-stream" { - t.Error("Invalid Content-Type") - } - - by, err := ioutil.ReadAll(r.Body) - if err != nil { - t.Error(err) - } - - t.Logf("request header: %v", r.Header) - t.Logf("request body: %s", string(by)) - - if str := string(by); str != "test" { - t.Errorf("unexpected body: %s", str) - } - - putCalled = true - w.WriteHeader(200) - }) - mux.HandleFunc("/verify", func(w http.ResponseWriter, r *http.Request) { verifyCalled = true w.WriteHeader(404) @@ -927,16 +547,18 @@ func TestUploadVerifyError(t *testing.T) { t.Fatal(err) } - obj, err := api.UploadCheck(oidPath) + oid := filepath.Base(oidPath) + stat, _ := os.Stat(oidPath) + o, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: oid, Size: stat.Size()}, "upload") if err != nil { if isDockerConnectionError(err) { return } t.Fatal(err) } - err = uploadObject(obj, nil) + err = api.VerifyUpload(o) if err == nil { - t.Fatal("Expected an error") + t.Fatal("verify should fail") } if errutil.IsFatalError(err) { @@ -951,34 +573,8 @@ func TestUploadVerifyError(t *testing.T) { t.Errorf("POST not called") } - if !putCalled { - t.Errorf("PUT not called") - } - if !verifyCalled { t.Errorf("verify not called") } } - -func uploadObject(o *api.ObjectResource, cb progress.CopyCallback) error { - path, err := lfs.LocalMediaPath(o.Oid) - if err != nil { - return errutil.Error(err) - } - - file, err := os.Open(path) - if err != nil { - return errutil.Error(err) - } - defer file.Close() - - reader := &progress.CallbackReader{ - C: cb, - TotalSize: o.Size, - Reader: file, - } - - return api.UploadObject(o, reader) - -} diff --git a/api/verify.go b/api/verify.go new file mode 100644 
index 00000000..63f5262a --- /dev/null +++ b/api/verify.go @@ -0,0 +1,46 @@ +package api + +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + "strconv" + + "github.com/github/git-lfs/errutil" + "github.com/github/git-lfs/httputil" +) + +// VerifyUpload calls the "verify" API link relation on obj if it exists +func VerifyUpload(obj *ObjectResource) error { + + // Do we need to do verify? + if _, ok := obj.Rel("verify"); !ok { + return nil + } + + req, err := obj.NewRequest("verify", "POST") + if err != nil { + return errutil.Error(err) + } + + by, err := json.Marshal(obj) + if err != nil { + return errutil.Error(err) + } + + req.Header.Set("Content-Type", MediaType) + req.Header.Set("Content-Length", strconv.Itoa(len(by))) + req.ContentLength = int64(len(by)) + req.Body = ioutil.NopCloser(bytes.NewReader(by)) + res, err := DoRequest(req, true) + if err != nil { + return err + } + + httputil.LogTransfer("lfs.data.verify", res) + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + return err +} diff --git a/commands/command_prune.go b/commands/command_prune.go index fcaa8c13..9f95c2a2 100644 --- a/commands/command_prune.go +++ b/commands/command_prune.go @@ -122,7 +122,7 @@ func prune(verifyRemote, dryRun, verbose bool) { if verifyRemote { config.Config.CurrentRemote = config.Config.FetchPruneConfig().PruneRemoteName // build queue now, no estimates or progress output - verifyQueue = lfs.NewDownloadCheckQueue(0, 0, true) + verifyQueue = lfs.NewDownloadCheckQueue(0, 0) verifiedObjects = lfs.NewStringSetWithCapacity(len(localObjects) / 2) } for _, file := range localObjects { @@ -136,7 +136,7 @@ func prune(verifyRemote, dryRun, verbose bool) { if verifyRemote { tracerx.Printf("VERIFYING: %v", file.Oid) pointer := lfs.NewPointer(file.Oid, file.Size, nil) - verifyQueue.Add(lfs.NewDownloadCheckable(&lfs.WrappedPointer{Pointer: pointer})) + verifyQueue.Add(lfs.NewDownloadable(&lfs.WrappedPointer{Pointer: pointer})) } } } diff --git a/commands/uploader.go 
b/commands/uploader.go index 2c793354..43f24b1a 100644 --- a/commands/uploader.go +++ b/commands/uploader.go @@ -88,9 +88,9 @@ func (c *uploadContext) checkMissing(missing []*lfs.WrappedPointer, missingSize return } - checkQueue := lfs.NewDownloadCheckQueue(numMissing, missingSize, true) + checkQueue := lfs.NewDownloadCheckQueue(numMissing, missingSize) for _, p := range missing { - checkQueue.Add(lfs.NewDownloadCheckable(p)) + checkQueue.Add(lfs.NewDownloadable(p)) } // this channel is filled with oids for which Check() succeeded & Transfer() was called diff --git a/httputil/request.go b/httputil/request.go index 1d9aa125..25a4c25f 100644 --- a/httputil/request.go +++ b/httputil/request.go @@ -62,7 +62,6 @@ func doHttpRequest(req *http.Request, creds auth.Creds) (*http.Response, error) err = errutil.Error(err) } } else { - // TODO(sinbad) stop handling the response here, separate response processing to api package err = handleResponse(res, creds) } diff --git a/lfs/batcher.go b/lfs/batcher.go index c0ad8c98..e90263e7 100644 --- a/lfs/batcher.go +++ b/lfs/batcher.go @@ -11,16 +11,16 @@ import "sync/atomic" type Batcher struct { exited uint32 batchSize int - input chan Transferable - batchReady chan []Transferable + input chan interface{} + batchReady chan []interface{} } // NewBatcher creates a Batcher with the batchSize. func NewBatcher(batchSize int) *Batcher { b := &Batcher{ batchSize: batchSize, - input: make(chan Transferable, batchSize), - batchReady: make(chan []Transferable), + input: make(chan interface{}, batchSize), + batchReady: make(chan []interface{}), } go b.acceptInput() @@ -29,9 +29,9 @@ func NewBatcher(batchSize int) *Batcher { // Add adds an item to the batcher. Add is safe to call from multiple // goroutines. 
-func (b *Batcher) Add(t Transferable) { +func (b *Batcher) Add(t interface{}) { if atomic.CompareAndSwapUint32(&b.exited, 1, 0) { - b.input = make(chan Transferable, b.batchSize) + b.input = make(chan interface{}, b.batchSize) go b.acceptInput() } @@ -40,7 +40,7 @@ func (b *Batcher) Add(t Transferable) { // Next will wait for the one of the above batch triggers to occur and return // the accumulated batch. -func (b *Batcher) Next() []Transferable { +func (b *Batcher) Next() []interface{} { return <-b.batchReady } @@ -58,7 +58,7 @@ func (b *Batcher) acceptInput() { exit := false for { - batch := make([]Transferable, 0, b.batchSize) + batch := make([]interface{}, 0, b.batchSize) Loop: for len(batch) < b.batchSize { t, ok := <-b.input diff --git a/lfs/download_queue.go b/lfs/download_queue.go index b9fb8475..ebc43975 100644 --- a/lfs/download_queue.go +++ b/lfs/download_queue.go @@ -2,78 +2,55 @@ package lfs import ( "github.com/github/git-lfs/api" - "github.com/github/git-lfs/errutil" - "github.com/github/git-lfs/progress" + "github.com/github/git-lfs/transfer" ) -// The ability to check that a file can be downloaded -type DownloadCheckable struct { - Pointer *WrappedPointer +type Downloadable struct { + pointer *WrappedPointer object *api.ObjectResource } -func NewDownloadCheckable(p *WrappedPointer) *DownloadCheckable { - return &DownloadCheckable{Pointer: p} -} - -func (d *DownloadCheckable) Check() (*api.ObjectResource, error) { - return api.DownloadCheck(d.Pointer.Oid) -} - -func (d *DownloadCheckable) Transfer(cb progress.CopyCallback) error { - // just report completion of check but don't do anything - cb(d.Size(), d.Size(), int(d.Size())) - return nil -} - -func (d *DownloadCheckable) Object() *api.ObjectResource { +func (d *Downloadable) Object() *api.ObjectResource { return d.object } -func (d *DownloadCheckable) Oid() string { - return d.Pointer.Oid +func (d *Downloadable) Oid() string { + return d.pointer.Oid } -func (d *DownloadCheckable) Size() int64 { 
- return d.Pointer.Size +func (d *Downloadable) Size() int64 { + return d.pointer.Size } -func (d *DownloadCheckable) Name() string { - return d.Pointer.Name +func (d *Downloadable) Name() string { + return d.pointer.Name } -func (d *DownloadCheckable) SetObject(o *api.ObjectResource) { +func (d *Downloadable) Path() string { + p, _ := LocalMediaPath(d.pointer.Oid) + return p +} + +func (d *Downloadable) SetObject(o *api.ObjectResource) { d.object = o } -// NewDownloadCheckQueue builds a checking queue, allowing `workers` concurrent check operations. -func NewDownloadCheckQueue(files int, size int64, dryRun bool) *TransferQueue { - q := newTransferQueue(files, size, dryRun) - // API operation is still download, but it will only perform the API call (check) - q.transferKind = "download" - return q -} - -// The ability to actually download -type Downloadable struct { - *DownloadCheckable +// TODO remove this legacy method & only support batch +func (d *Downloadable) LegacyCheck() (*api.ObjectResource, error) { + return api.DownloadCheck(d.pointer.Oid) } func NewDownloadable(p *WrappedPointer) *Downloadable { - return &Downloadable{DownloadCheckable: NewDownloadCheckable(p)} + return &Downloadable{pointer: p} } -func (d *Downloadable) Transfer(cb progress.CopyCallback) error { - err := PointerSmudgeObject(d.Pointer.Pointer, d.object, cb) - if err != nil { - return errutil.Error(err) - } - return nil +// NewDownloadCheckQueue builds a checking queue, checks that objects are there but doesn't download +func NewDownloadCheckQueue(files int, size int64) *TransferQueue { + // Always dry run + return newTransferQueue(files, size, true, transfer.NewDownloadAdapter(transfer.BasicAdapterName)) } -// NewDownloadQueue builds a DownloadQueue, allowing `workers` concurrent downloads. +// NewDownloadQueue builds a DownloadQueue, allowing concurrent downloads. 
func NewDownloadQueue(files int, size int64, dryRun bool) *TransferQueue { - q := newTransferQueue(files, size, dryRun) - q.transferKind = "download" - return q + return newTransferQueue(files, size, dryRun, transfer.NewDownloadAdapter(transfer.BasicAdapterName)) } diff --git a/lfs/pointer_clean.go b/lfs/pointer_clean.go index 11723a6f..ba368a03 100644 --- a/lfs/pointer_clean.go +++ b/lfs/pointer_clean.go @@ -10,6 +10,7 @@ import ( "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" "github.com/github/git-lfs/progress" + "github.com/github/git-lfs/tools" ) type cleanedAsset struct { @@ -82,7 +83,7 @@ func copyToTemp(reader io.Reader, fileSize int64, cb progress.CopyCallback) (oid } multi := io.MultiReader(bytes.NewReader(by), reader) - size, err = CopyWithCallback(writer, multi, fileSize, cb) + size, err = tools.CopyWithCallback(writer, multi, fileSize, cb) if err != nil { return diff --git a/lfs/pointer_smudge.go b/lfs/pointer_smudge.go index 6e0a2da9..b49735cf 100644 --- a/lfs/pointer_smudge.go +++ b/lfs/pointer_smudge.go @@ -1,16 +1,15 @@ package lfs import ( - "crypto/sha256" - "encoding/hex" "fmt" - "hash" "io" - "io/ioutil" "os" "path/filepath" "github.com/cheggaaa/pb" + "github.com/github/git-lfs/tools" + "github.com/github/git-lfs/transfer" + "github.com/github/git-lfs/api" "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" @@ -74,142 +73,39 @@ func PointerSmudge(writer io.Writer, ptr *Pointer, workingfile string, download return nil } -// PointerSmudgeObject uses a Pointer and ObjectResource to download the object to the -// media directory. It does not write the file to the working directory. 
-func PointerSmudgeObject(ptr *Pointer, obj *api.ObjectResource, cb progress.CopyCallback) error { - mediafile, err := LocalMediaPath(obj.Oid) - if err != nil { - return err - } - - stat, statErr := os.Stat(mediafile) - if statErr == nil && stat != nil { - fileSize := stat.Size() - if fileSize == 0 || fileSize != obj.Size { - tracerx.Printf("Removing %s, size %d is invalid", mediafile, fileSize) - os.RemoveAll(mediafile) - stat = nil - } - } - - if statErr != nil || stat == nil { - err := downloadObject(ptr, obj, mediafile, cb) - - if err != nil { - return errutil.NewSmudgeError(err, obj.Oid, mediafile) - } - } - - return nil -} - -func downloadObject(ptr *Pointer, obj *api.ObjectResource, mediafile string, cb progress.CopyCallback) error { - reader, size, err := api.DownloadObject(obj) - if reader != nil { - defer reader.Close() - } - - if err != nil { - return errutil.Errorf(err, "Error downloading %s", mediafile) - } - - if ptr.Size == 0 { - ptr.Size = size - } - - if err := bufferDownloadedFile(mediafile, reader, ptr.Size, cb); err != nil { - return errutil.Errorf(err, "Error buffering media file: %s", err) - } - - return nil -} - func downloadFile(writer io.Writer, ptr *Pointer, workingfile, mediafile string, cb progress.CopyCallback) error { fmt.Fprintf(os.Stderr, "Downloading %s (%s)\n", workingfile, pb.FormatBytes(ptr.Size)) - reader, size, err := api.Download(filepath.Base(mediafile), ptr.Size) - if reader != nil { - defer reader.Close() - } + obj, err := api.BatchOrLegacySingle(&api.ObjectResource{Oid: ptr.Oid, Size: ptr.Size}, "download") if err != nil { return errutil.Errorf(err, "Error downloading %s: %s", filepath.Base(mediafile), err) } if ptr.Size == 0 { - ptr.Size = size + ptr.Size = obj.Size } - if err := bufferDownloadedFile(mediafile, reader, ptr.Size, cb); err != nil { - return errutil.Errorf(err, "Error buffering media file: %s", err) + adapter := transfer.NewDownloadAdapter(transfer.BasicAdapterName) + var tcb 
transfer.TransferProgressCallback + if cb != nil { + tcb = func(name string, totalSize, readSoFar int64, readSinceLast int) error { + return cb(totalSize, readSoFar, readSinceLast) + } + } + // Single download + adapterResultChan := make(chan transfer.TransferResult, 1) + adapter.Begin(1, tcb, adapterResultChan) + adapter.Add(transfer.NewTransfer(filepath.Base(workingfile), obj, mediafile)) + adapter.End() + res := <-adapterResultChan + + if res.Error != nil { + return errutil.Errorf(err, "Error buffering media file: %s", res.Error) } return readLocalFile(writer, ptr, mediafile, workingfile, nil) } -// Writes the content of reader to filename atomically by writing to a temp file -// first, and confirming the content SHA-256 is valid. This is basically a copy -// of atomic.WriteFile() at: -// -// https://github.com/natefinch/atomic/blob/a62ce929ffcc871a51e98c6eba7b20321e3ed62d/atomic.go#L12-L17 -// -// filename - Absolute path to a file to write, with the filename a 64 character -// SHA-256 hex signature. -// reader - Any io.Reader -// size - Expected byte size of the content. Used for the progress bar in -// the optional CopyCallback. -// cb - Optional CopyCallback object for providing download progress to -// external Git LFS tools. -func bufferDownloadedFile(filename string, reader io.Reader, size int64, cb progress.CopyCallback) error { - oid := filepath.Base(filename) - f, err := ioutil.TempFile(LocalObjectTempDir(), oid+"-") - if err != nil { - return fmt.Errorf("cannot create temp file: %v", err) - } - - defer func() { - if err != nil { - // Don't leave the temp file lying around on error. - _ = os.Remove(f.Name()) // yes, ignore the error, not much we can do about it. - } - }() - - hasher := newHashingReader(reader) - - // ensure we always close f. Note that this does not conflict with the - // close below, as close is idempotent. 
- defer f.Close() - name := f.Name() - written, err := CopyWithCallback(f, hasher, size, cb) - if err != nil { - return fmt.Errorf("cannot write data to tempfile %q: %v", name, err) - } - if err := f.Close(); err != nil { - return fmt.Errorf("can't close tempfile %q: %v", name, err) - } - - if actual := hasher.Hash(); actual != oid { - return fmt.Errorf("Expected OID %s, got %s after %d bytes written", oid, actual, written) - } - - // get the file mode from the original file and use that for the replacement - // file, too. - info, err := os.Stat(filename) - if os.IsNotExist(err) { - // no original file - } else if err != nil { - return err - } else { - if err := os.Chmod(name, info.Mode()); err != nil { - return fmt.Errorf("can't set filemode on tempfile %q: %v", name, err) - } - } - - if err := os.Rename(name, filename); err != nil { - return fmt.Errorf("cannot replace %q with tempfile %q: %v", filename, name, err) - } - return nil -} - func readLocalFile(writer io.Writer, ptr *Pointer, mediafile string, workingfile string, cb progress.CopyCallback) error { reader, err := os.Open(mediafile) if err != nil { @@ -286,35 +182,10 @@ func readLocalFile(writer io.Writer, ptr *Pointer, mediafile string, workingfile defer reader.Close() } - _, err = CopyWithCallback(writer, reader, ptr.Size, cb) + _, err = tools.CopyWithCallback(writer, reader, ptr.Size, cb) if err != nil { return errutil.Errorf(err, "Error reading from media file: %s", err) } return nil } - -type hashingReader struct { - reader io.Reader - hasher hash.Hash -} - -func newHashingReader(r io.Reader) *hashingReader { - return &hashingReader{r, sha256.New()} -} - -func (r *hashingReader) Hash() string { - return hex.EncodeToString(r.hasher.Sum(nil)) -} - -func (r *hashingReader) Read(b []byte) (int, error) { - w, err := r.reader.Read(b) - if err == nil || err == io.EOF { - _, e := r.hasher.Write(b[0:w]) - if e != nil && err == nil { - return w, e - } - } - - return w, err -} diff --git a/lfs/transfer_queue.go 
b/lfs/transfer_queue.go index a4d580d6..1b2bb8b8 100644 --- a/lfs/transfer_queue.go +++ b/lfs/transfer_queue.go @@ -9,6 +9,7 @@ import ( "github.com/github/git-lfs/errutil" "github.com/github/git-lfs/git" "github.com/github/git-lfs/progress" + "github.com/github/git-lfs/transfer" "github.com/rubyist/tracerx" ) @@ -17,45 +18,52 @@ const ( ) type Transferable interface { - Check() (*api.ObjectResource, error) - Transfer(progress.CopyCallback) error - Object() *api.ObjectResource Oid() string Size() int64 Name() string + Path() string + Object() *api.ObjectResource SetObject(*api.ObjectResource) + // Legacy API check - TODO remove this and only support batch + LegacyCheck() (*api.ObjectResource, error) } -// TransferQueue provides a queue that will allow concurrent transfers. +// TransferQueue organises the wider process of uploading and downloading, +// including calling the API, passing the actual transfer request to transfer +// adapters, and dealing with progress, errors and retries type TransferQueue struct { - retrying uint32 - meter *progress.ProgressMeter - workers int // Number of transfer workers to spawn - transferKind string - errors []error - transferables map[string]Transferable - retries []Transferable - batcher *Batcher - apic chan Transferable // Channel for processing individual API requests - transferc chan Transferable // Channel for processing transfers - retriesc chan Transferable // Channel for processing retries - errorc chan error // Channel for processing errors - watchers []chan string - trMutex *sync.Mutex - errorwait sync.WaitGroup - retrywait sync.WaitGroup - wait sync.WaitGroup + adapter transfer.TransferAdapter + adapterInProgress bool + adapterResultChan chan transfer.TransferResult + adapterInitMutex sync.Mutex + dryRun bool + retrying uint32 + meter *progress.ProgressMeter + errors []error + transferables map[string]Transferable + retries []Transferable + batcher *Batcher + apic chan Transferable // Channel for processing individual 
API requests + retriesc chan Transferable // Channel for processing retries + errorc chan error // Channel for processing errors + watchers []chan string + trMutex *sync.Mutex + errorwait sync.WaitGroup + retrywait sync.WaitGroup + wait sync.WaitGroup // Incremented on Add(), decremented on transfer complete or skip + oldApiWorkers int // Number of non-batch API workers to spawn (deprecated) } -// newTransferQueue builds a TransferQueue, allowing `workers` concurrent transfers. -func newTransferQueue(files int, size int64, dryRun bool) *TransferQueue { +// newTransferQueue builds a TransferQueue, direction and underlying mechanism determined by adapter +func newTransferQueue(files int, size int64, dryRun bool, adapter transfer.TransferAdapter) *TransferQueue { q := &TransferQueue{ + adapter: adapter, + dryRun: dryRun, meter: progress.NewProgressMeter(files, size, dryRun, config.Config.Getenv("GIT_LFS_PROGRESS")), apic: make(chan Transferable, batchSize), - transferc: make(chan Transferable, batchSize), retriesc: make(chan Transferable, batchSize), errorc: make(chan error), - workers: config.Config.ConcurrentTransfers(), + oldApiWorkers: config.Config.ConcurrentTransfers(), transferables: make(map[string]Transferable), trMutex: &sync.Mutex{}, } @@ -83,10 +91,88 @@ func (q *TransferQueue) Add(t Transferable) { q.apic <- t } +func (q *TransferQueue) addToAdapter(t Transferable) { + + tr := transfer.NewTransfer(t.Name(), t.Object(), t.Path()) + + if q.dryRun { + // Don't actually transfer + res := transfer.TransferResult{tr, nil} + q.handleTransferResult(res) + return + } + q.ensureAdapterBegun() + q.adapter.Add(tr) +} + func (q *TransferQueue) Skip(size int64) { q.meter.Skip(size) } +func (q *TransferQueue) transferKind() string { + if q.adapter.Direction() == transfer.Download { + return "download" + } else { + return "upload" + } +} + +func (q *TransferQueue) ensureAdapterBegun() { + q.adapterInitMutex.Lock() + defer q.adapterInitMutex.Unlock() + + if 
q.adapterInProgress { + return + } + + adapterResultChan := make(chan transfer.TransferResult, 20) + + // Progress callback - receives byte updates + cb := func(name string, total, read int64, current int) error { + q.meter.TransferBytes(q.transferKind(), name, read, total, current) + return nil + } + + tracerx.Printf("tq: starting transfer adapter %q", q.adapter.Name()) + q.adapter.Begin(config.Config.ConcurrentTransfers(), cb, adapterResultChan) + q.adapterInProgress = true + + // Collector for completed transfers + // q.wait.Done() in handleTransferResult is enough to know when this is complete for all transfers + go func() { + for res := range adapterResultChan { + q.handleTransferResult(res) + } + }() + +} + +func (q *TransferQueue) handleTransferResult(res transfer.TransferResult) { + if res.Error != nil { + if q.canRetry(res.Error) { + tracerx.Printf("tq: retrying object %s", res.Transfer.Object.Oid) + q.trMutex.Lock() + t, ok := q.transferables[res.Transfer.Object.Oid] + q.trMutex.Unlock() + if ok { + q.retry(t) + } else { + q.errorc <- res.Error + } + } else { + q.errorc <- res.Error + } + } else { + oid := res.Transfer.Object.Oid + for _, c := range q.watchers { + c <- oid + } + } + q.meter.FinishTransfer(res.Transfer.Name) + q.wait.Done() + +} + // Wait waits for the queue to finish processing all transfers. Once Wait is // called, Add will no longer add transferables to the queue. Any failed // transfers will be automatically retried once. @@ -116,7 +202,10 @@ func (q *TransferQueue) Wait() { atomic.StoreUint32(&q.retrying, 0) close(q.apic) - close(q.transferc) + if q.adapterInProgress { + q.adapter.End() + q.adapterInProgress = false + } close(q.errorc) for _, watcher := range q.watchers { @@ -139,9 +228,10 @@ func (q *TransferQueue) Watch() chan string { // a POST call for each object, feeding the results to the transfer workers. 
// If configured, the object transfers can still happen concurrently, the // sequential nature here is only for the meta POST calls. +// TODO LEGACY API: remove when legacy API removed func (q *TransferQueue) individualApiRoutine(apiWaiter chan interface{}) { for t := range q.apic { - obj, err := t.Check() + obj, err := t.LegacyCheck() if err != nil { if q.canRetry(err) { q.retry(t) @@ -163,7 +253,7 @@ func (q *TransferQueue) individualApiRoutine(apiWaiter chan interface{}) { if obj != nil { t.SetObject(obj) q.meter.Add(t.Name()) - q.transferc <- t + q.addToAdapter(t) } else { q.Skip(t.Size()) q.wait.Done() @@ -174,13 +264,14 @@ func (q *TransferQueue) individualApiRoutine(apiWaiter chan interface{}) { // legacyFallback is used when a batch request is made to a server that does // not support the batch endpoint. When this happens, the Transferables are // fed from the batcher into apic to be processed individually. -func (q *TransferQueue) legacyFallback(failedBatch []Transferable) { +// TODO LEGACY API: remove when legacy API removed +func (q *TransferQueue) legacyFallback(failedBatch []interface{}) { tracerx.Printf("tq: batch api not implemented, falling back to individual") q.launchIndividualApiRoutines() for _, t := range failedBatch { - q.apic <- t + q.apic <- t.(Transferable) } for { @@ -190,7 +281,7 @@ func (q *TransferQueue) legacyFallback(failedBatch []Transferable) { } for _, t := range batch { - q.apic <- t + q.apic <- t.(Transferable) } } } @@ -210,11 +301,12 @@ func (q *TransferQueue) batchApiRoutine() { tracerx.Printf("tq: sending batch of size %d", len(batch)) transfers := make([]*api.ObjectResource, 0, len(batch)) - for _, t := range batch { + for _, i := range batch { + t := i.(Transferable) transfers = append(transfers, &api.ObjectResource{Oid: t.Oid(), Size: t.Size()}) } - objects, err := api.Batch(transfers, q.transferKind) + objects, err := api.Batch(transfers, q.transferKind()) if err != nil { if errutil.IsNotImplementedError(err) { 
git.Config.SetLocal("", "lfs.batch", "false") @@ -225,7 +317,7 @@ func (q *TransferQueue) batchApiRoutine() { if q.canRetry(err) { for _, t := range batch { - q.retry(t) + q.retry(t.(Transferable)) } } else { q.errorc <- err @@ -245,7 +337,7 @@ func (q *TransferQueue) batchApiRoutine() { continue } - if _, ok := o.Rel(q.transferKind); ok { + if _, ok := o.Rel(q.transferKind()); ok { // This object needs to be transferred q.trMutex.Lock() transfer, ok := q.transferables[o.Oid] @@ -254,7 +346,7 @@ func (q *TransferQueue) batchApiRoutine() { if ok { transfer.SetObject(o) q.meter.Add(transfer.Name()) - q.transferc <- transfer + q.addToAdapter(transfer) } else { q.Skip(transfer.Size()) q.wait.Done() @@ -282,33 +374,6 @@ func (q *TransferQueue) retryCollector() { q.retrywait.Done() } -func (q *TransferQueue) transferWorker() { - for transfer := range q.transferc { - cb := func(total, read int64, current int) error { - q.meter.TransferBytes(q.transferKind, transfer.Name(), read, total, current) - return nil - } - - if err := transfer.Transfer(cb); err != nil { - if q.canRetry(err) { - tracerx.Printf("tq: retrying object %s", transfer.Oid()) - q.retry(transfer) - } else { - q.errorc <- err - } - } else { - oid := transfer.Oid() - for _, c := range q.watchers { - c <- oid - } - } - - q.meter.FinishTransfer(transfer.Name()) - - q.wait.Done() - } -} - // launchIndividualApiRoutines first launches a single api worker. When it // receives the first successful api request it launches workers - 1 more // workers. 
This prevents being prompted for credentials multiple times at once @@ -320,7 +385,7 @@ func (q *TransferQueue) launchIndividualApiRoutines() { <-apiWaiter - for i := 0; i < q.workers-1; i++ { + for i := 0; i < q.oldApiWorkers-1; i++ { go q.individualApiRoutine(nil) } }() @@ -333,11 +398,6 @@ func (q *TransferQueue) run() { go q.errorCollector() go q.retryCollector() - tracerx.Printf("tq: starting %d transfer workers", q.workers) - for i := 0; i < q.workers; i++ { - go q.transferWorker() - } - if config.Config.BatchTransfer() { tracerx.Printf("tq: running as batched queue, batch size of %d", batchSize) q.batcher = NewBatcher(batchSize) diff --git a/lfs/upload_queue.go b/lfs/upload_queue.go index 39eb06ef..1df6d6c3 100644 --- a/lfs/upload_queue.go +++ b/lfs/upload_queue.go @@ -8,7 +8,7 @@ import ( "github.com/github/git-lfs/api" "github.com/github/git-lfs/config" "github.com/github/git-lfs/errutil" - "github.com/github/git-lfs/progress" + "github.com/github/git-lfs/transfer" ) // Uploadable describes a file that can be uploaded. @@ -20,6 +20,35 @@ type Uploadable struct { object *api.ObjectResource } +func (u *Uploadable) Object() *api.ObjectResource { + return u.object +} + +func (u *Uploadable) Oid() string { + return u.oid +} + +func (u *Uploadable) Size() int64 { + return u.size +} + +func (u *Uploadable) Name() string { + return u.Filename +} + +func (u *Uploadable) SetObject(o *api.ObjectResource) { + u.object = o +} + +func (u *Uploadable) Path() string { + return u.OidPath +} + +// TODO LEGACY API: remove when legacy API removed +func (u *Uploadable) LegacyCheck() (*api.ObjectResource, error) { + return api.UploadCheck(u.Oid(), u.Size()) +} + // NewUploadable builds the Uploadable from the given information. 
// "filename" can be empty if a raw object is pushed (see "object-id" flag in push command)/ func NewUploadable(oid, filename string) (*Uploadable, error) { @@ -42,61 +71,9 @@ func NewUploadable(oid, filename string) (*Uploadable, error) { return &Uploadable{oid: oid, OidPath: localMediaPath, Filename: filename, size: fi.Size()}, nil } -func (u *Uploadable) Check() (*api.ObjectResource, error) { - return api.UploadCheck(u.OidPath) -} - -func (u *Uploadable) Transfer(cb progress.CopyCallback) error { - wcb := func(total, read int64, current int) error { - cb(total, read, current) - return nil - } - - path, err := LocalMediaPath(u.object.Oid) - if err != nil { - return errutil.Error(err) - } - - file, err := os.Open(path) - if err != nil { - return errutil.Error(err) - } - defer file.Close() - - reader := &progress.CallbackReader{ - C: wcb, - TotalSize: u.object.Size, - Reader: file, - } - - return api.UploadObject(u.object, reader) -} - -func (u *Uploadable) Object() *api.ObjectResource { - return u.object -} - -func (u *Uploadable) Oid() string { - return u.oid -} - -func (u *Uploadable) Size() int64 { - return u.size -} - -func (u *Uploadable) Name() string { - return u.Filename -} - -func (u *Uploadable) SetObject(o *api.ObjectResource) { - u.object = o -} - // NewUploadQueue builds an UploadQueue, allowing `workers` concurrent uploads. func NewUploadQueue(files int, size int64, dryRun bool) *TransferQueue { - q := newTransferQueue(files, size, dryRun) - q.transferKind = "upload" - return q + return newTransferQueue(files, size, dryRun, transfer.NewUploadAdapter(transfer.BasicAdapterName)) } // ensureFile makes sure that the cleanPath exists before pushing it. 
If it diff --git a/lfs/util.go b/lfs/util.go index 3e69150e..0f460052 100644 --- a/lfs/util.go +++ b/lfs/util.go @@ -26,25 +26,6 @@ const ( var currentPlatform = PlatformUndetermined -func CopyWithCallback(writer io.Writer, reader io.Reader, totalSize int64, cb progress.CopyCallback) (int64, error) { - if success, _ := CloneFile(writer, reader); success { - if cb != nil { - cb(totalSize, totalSize, 0) - } - return totalSize, nil - } - if cb == nil { - return io.Copy(writer, reader) - } - - cbReader := &progress.CallbackReader{ - C: cb, - TotalSize: totalSize, - Reader: reader, - } - return io.Copy(writer, cbReader) -} - func CopyCallbackFile(event, filename string, index, totalFiles int) (progress.CopyCallback, *os.File, error) { logPath := config.Config.Getenv("GIT_LFS_PROGRESS") if len(logPath) == 0 || len(filename) == 0 || len(event) == 0 { diff --git a/lfs/util_test.go b/lfs/util_test.go index d9777298..b2e776e4 100644 --- a/lfs/util_test.go +++ b/lfs/util_test.go @@ -2,7 +2,6 @@ package lfs import ( "bytes" - "io/ioutil" "strings" "testing" @@ -40,26 +39,6 @@ func TestWriterWithCallback(t *testing.T) { assert.Equal(t, 5, int(calledRead[1])) } -func TestCopyWithCallback(t *testing.T) { - buf := bytes.NewBufferString("BOOYA") - - called := 0 - calledWritten := make([]int64, 0, 2) - - n, err := CopyWithCallback(ioutil.Discard, buf, 5, func(total int64, written int64, current int) error { - called += 1 - calledWritten = append(calledWritten, written) - assert.Equal(t, 5, int(total)) - return nil - }) - assert.Nil(t, err) - assert.Equal(t, 5, int(n)) - - assert.Equal(t, 1, called) - assert.Len(t, calledWritten, 1) - assert.Equal(t, 5, int(calledWritten[0])) -} - type TestIncludeExcludeCase struct { expectedResult bool includes []string diff --git a/tools/filetools.go b/tools/filetools.go index cad65a29..9bae4fbe 100644 --- a/tools/filetools.go +++ b/tools/filetools.go @@ -3,6 +3,7 @@ package tools import ( + "fmt" "os" "path/filepath" "strings" @@ -54,6 +55,26 @@ 
func ResolveSymlinks(path string) string { return path }

// RenameFileCopyPermissions moves srcfile to destfile, replacing destfile if
// necessary and also copying the permissions of destfile if it already exists
func RenameFileCopyPermissions(srcfile, destfile string) error {
	info, err := os.Stat(destfile)
	if os.IsNotExist(err) {
		// destination doesn't exist yet: nothing to copy permissions from
	} else if err != nil {
		return err
	} else if err := os.Chmod(srcfile, info.Mode()); err != nil {
		return fmt.Errorf("can't set filemode on file %q: %v", srcfile, err)
	}

	if err := os.Rename(srcfile, destfile); err != nil {
		return fmt.Errorf("cannot replace %q with %q: %v", destfile, srcfile, err)
	}
	return nil
}

// CleanPaths splits the given `paths` argument by the delimiter argument, and
// then "cleans" that path according to the filepath.Clean function (see
// https://golang.org/pkg/file/filepath#Clean).
diff --git a/tools/iotools.go b/tools/iotools.go
index b7f65c14..bb400272 100644
--- a/tools/iotools.go
+++ b/tools/iotools.go
@@ -1,6 +1,13 @@
package tools

import (
	"crypto/sha256"
	"encoding/hex"
	"hash"
	"io"

	"github.com/github/git-lfs/progress"
)

type readSeekCloserWrapper struct {
	readSeeker io.ReadSeeker
@@ -23,3 +30,49 @@ func (r *readSeekCloserWrapper) Close() error {
func NewReadSeekCloserWrapper(r io.ReadSeeker) io.ReadCloser {
	return &readSeekCloserWrapper{r}
}

// CopyWithCallback copies reader to writer while performing a progress callback
func CopyWithCallback(writer io.Writer, reader io.Reader, totalSize int64, cb progress.CopyCallback) (int64, error) {
	// Fast path: a kernel-level clone needs no copy loop; report 100% at once.
	if success, _ := CloneFile(writer, reader); success {
		if cb != nil {
			cb(totalSize, totalSize, 0)
		}
		return totalSize, nil
	}
	if cb == nil {
		return io.Copy(writer, reader)
	}

	cbReader := &progress.CallbackReader{
		C:         cb,
		TotalSize: totalSize,
		Reader:    reader,
	}
	return io.Copy(writer, cbReader)
}

// HashingReader wraps a reader and calculates the SHA-256 hash of the data as it is read
type HashingReader struct {
	reader io.Reader
	hasher hash.Hash
}

func NewHashingReader(r io.Reader) *HashingReader {
	return &HashingReader{r, sha256.New()}
}

// Hash returns the hex-encoded SHA-256 of all bytes read so far
func (r *HashingReader) Hash() string {
	return hex.EncodeToString(r.hasher.Sum(nil))
}

func (r *HashingReader) Read(b []byte) (int, error) {
	n, err := r.reader.Read(b)
	// Per the io.Reader contract, the n > 0 bytes returned are valid even when
	// err is non-nil (not only on io.EOF), so always fold them into the hash;
	// the previous version skipped hashing on any non-EOF error, yielding a
	// silently wrong digest for readers that return data alongside an error.
	if n > 0 {
		if _, herr := r.hasher.Write(b[0:n]); herr != nil && err == nil {
			return n, herr
		}
	}
	return n, err
}
diff --git a/tools/util_generic.go b/tools/util_generic.go
new file mode 100644
index 00000000..52943958
--- /dev/null
+++ b/tools/util_generic.go
@@ -0,0 +1,11 @@
// +build !linux !cgo

package tools

import (
	"io"
)

// CloneFile is a stub on platforms without copy-on-write clone support; it
// reports that no clone took place so callers fall back to a normal copy.
func CloneFile(writer io.Writer, reader io.Reader) (bool, error) {
	return false, nil
}
diff --git a/lfs/util_linux.go b/tools/util_linux.go
similarity index 97%
rename from lfs/util_linux.go
rename to tools/util_linux.go
index 776739dc..e43bb1d4 100644
--- a/lfs/util_linux.go
+++ b/tools/util_linux.go
@@ -1,6 +1,6 @@
// +build linux,cgo

-package lfs
+package tools

/*
#include
diff --git a/tools/util_test.go b/tools/util_test.go
new file mode 100644
index 00000000..a467cb31
--- /dev/null
+++ b/tools/util_test.go
@@ -0,0 +1,29 @@
package tools

import (
	"bytes"
	"io/ioutil"
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestCopyWithCallback(t *testing.T) {
	buf := bytes.NewBufferString("BOOYA")

	called := 0
	calledWritten := make([]int64, 0, 2)

	n, err := CopyWithCallback(ioutil.Discard, buf, 5, func(total int64, written int64, current int) error {
		called += 1
		calledWritten = append(calledWritten, written)
		assert.Equal(t, 5, int(total))
		return nil
	})
	assert.Nil(t, err)
	assert.Equal(t, 5, int(n))

	assert.Equal(t, 1, called)
	assert.Len(t, calledWritten, 1)
	assert.Equal(t, 5, int(calledWritten[0]))
}
diff --git a/transfer/basic.go b/transfer/basic.go
new file mode 100644
index
00000000..e6a855ec
--- /dev/null
+++ b/transfer/basic.go
@@ -0,0 +1,303 @@
package transfer

import (
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"os"
	"path/filepath"
	"strconv"
	"sync"

	"github.com/github/git-lfs/api"
	"github.com/github/git-lfs/errutil"
	"github.com/github/git-lfs/httputil"
	"github.com/github/git-lfs/progress"
	"github.com/github/git-lfs/tools"
	"github.com/rubyist/tracerx"
)

const (
	BasicAdapterName = "basic"
)

// basicAdapter is the base implementation of the basic all-or-nothing HTTP
// upload / download adapter
type basicAdapter struct {
	direction Direction
	jobChan   chan *Transfer
	cb        TransferProgressCallback
	outChan   chan TransferResult
	// WaitGroup to sync the completion of all workers
	workerWait sync.WaitGroup
	// WaitGroup to serialise the first transfer response to perform login if needed
	authWait sync.WaitGroup
	// authOnce guarantees authWait is released exactly once, no matter how many
	// success or failure paths attempt to signal it
	authOnce sync.Once
}

func (a *basicAdapter) Direction() Direction {
	return a.direction
}

func (a *basicAdapter) Name() string {
	return BasicAdapterName
}

// signalAuth releases the workers waiting for the first transfer's auth
// exchange to complete. Idempotent: safe to call any number of times.
func (a *basicAdapter) signalAuth() {
	a.authOnce.Do(a.authWait.Done)
}

func (a *basicAdapter) Begin(maxConcurrency int, cb TransferProgressCallback, completion chan TransferResult) error {
	a.cb = cb
	a.outChan = completion
	a.jobChan = make(chan *Transfer, 100)

	tracerx.Printf("xfer: adapter %q Begin() with %d workers", a.Name(), maxConcurrency)

	a.workerWait.Add(maxConcurrency)
	a.authWait.Add(1)
	for i := 0; i < maxConcurrency; i++ {
		go a.worker(i)
	}
	tracerx.Printf("xfer: adapter %q started", a.Name())
	return nil
}

func (a *basicAdapter) Add(t *Transfer) {
	tracerx.Printf("xfer: adapter %q Add() for %q", a.Name(), t.Object.Oid)
	a.jobChan <- t
}

func (a *basicAdapter) End() {
	tracerx.Printf("xfer: adapter %q End()", a.Name())
	close(a.jobChan)
	// wait for all transfers to complete
	a.workerWait.Wait()
	if a.outChan != nil {
		close(a.outChan)
	}
	tracerx.Printf("xfer: adapter %q stopped", a.Name())
}

func (a *basicAdapter) ClearTempStorage() error {
	// Should be empty already but also remove dir
	return os.RemoveAll(a.tempDir())
}

// worker function, many of these run per adapter
func (a *basicAdapter) worker(workerNum int) {
	tracerx.Printf("xfer: adapter %q worker %d starting", a.Name(), workerNum)
	waitForAuth := workerNum > 0
	signalAuthOnResponse := workerNum == 0

	// First worker is the only one allowed to start immediately
	// The rest wait until successful response from 1st worker to
	// make sure only 1 login prompt is presented if necessary
	// Deliberately outside jobChan processing so we know worker 0 will process 1st item
	if waitForAuth {
		tracerx.Printf("xfer: adapter %q worker %d waiting for Auth", a.Name(), workerNum)
		a.authWait.Wait()
		tracerx.Printf("xfer: adapter %q worker %d auth signal received", a.Name(), workerNum)
	}

	for t := range a.jobChan {
		tracerx.Printf("xfer: adapter %q worker %d processing job for %q", a.Name(), workerNum, t.Object.Oid)
		var err error
		switch a.Direction() {
		case Download:
			err = a.download(t, signalAuthOnResponse)
		case Upload:
			err = a.upload(t, signalAuthOnResponse)
		}

		if signalAuthOnResponse {
			// If the first transfer failed before an HTTP response arrived it
			// never reached its authWait.Done(); release the other workers here
			// (idempotent via signalAuth) so they don't deadlock waiting for a
			// signal that will never come.
			a.signalAuth()
			// Only need to signal for auth once
			signalAuthOnResponse = false
		}

		if a.outChan != nil {
			a.outChan <- TransferResult{t, err}
		}

		tracerx.Printf("xfer: adapter %q worker %d finished job for %q", a.Name(), workerNum, t.Object.Oid)
	}
	// This will only happen if no jobs were submitted; just wake up all workers to finish
	if signalAuthOnResponse {
		a.signalAuth()
	}
	tracerx.Printf("xfer: adapter %q worker %d stopping", a.Name(), workerNum)
	a.workerWait.Done()
}

func (a *basicAdapter) tempDir() string {
	// Must be dedicated to this adapter as deleted by ClearTempStorage
	d := filepath.Join(os.TempDir(), "git-lfs-basic-temp")
	if err := os.MkdirAll(d, 0755); err != nil {
		return os.TempDir()
	}
	return d
}

// download fetches the object content into a temp file, verifies its SHA-256
// against the expected OID, then moves it into place. The named return err is
// deliberate: the cleanup defer inspects it so the temp file is removed on
// EVERY failure path (previously the close-failure, hash-mismatch and rename
// paths left the temp file behind because they never assigned the outer err).
func (a *basicAdapter) download(t *Transfer, signalAuthOnResponse bool) (err error) {
	rel, ok := t.Object.Rel("download")
	if !ok {
		return errors.New("Object not found on the server.")
	}

	req, err := httputil.NewHttpRequest("GET", rel.Href, rel.Header)
	if err != nil {
		return err
	}

	res, err := httputil.DoHttpRequest(req, true)
	if err != nil {
		return errutil.NewRetriableError(err)
	}
	httputil.LogTransfer("lfs.data.download", res)
	defer res.Body.Close()

	// Signal auth OK on success response, before starting download to free up
	// other workers immediately
	if signalAuthOnResponse {
		a.signalAuth()
	}

	// Now do transfer of content
	f, err := ioutil.TempFile(a.tempDir(), t.Object.Oid+"-")
	if err != nil {
		return fmt.Errorf("cannot create temp file: %v", err)
	}

	defer func() {
		if err != nil {
			// Don't leave the temp file lying around on error.
			_ = os.Remove(f.Name()) // yes, ignore the error, not much we can do about it.
		}
	}()
	// ensure we always close f; the explicit Close below makes this a harmless
	// second Close on the success path
	defer f.Close()

	hasher := tools.NewHashingReader(res.Body)
	tempfilename := f.Name()

	// Wrap callback to give name context
	ccb := func(totalSize int64, readSoFar int64, readSinceLast int) error {
		if a.cb != nil {
			return a.cb(t.Name, totalSize, readSoFar, readSinceLast)
		}
		return nil
	}

	written, err := tools.CopyWithCallback(f, hasher, res.ContentLength, ccb)
	if err != nil {
		return fmt.Errorf("cannot write data to tempfile %q: %v", tempfilename, err)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("can't close tempfile %q: %v", tempfilename, err)
	}

	if actual := hasher.Hash(); actual != t.Object.Oid {
		return fmt.Errorf("Expected OID %s, got %s after %d bytes written", t.Object.Oid, actual, written)
	}

	return tools.RenameFileCopyPermissions(tempfilename, t.Path)
}

// upload PUTs the local file content to the server and then verifies the
// upload via the API.
func (a *basicAdapter) upload(t *Transfer, signalAuthOnResponse bool) error {
	rel, ok := t.Object.Rel("upload")
	if !ok {
		return fmt.Errorf("No upload action for this object.")
	}

	req, err := httputil.NewHttpRequest("PUT", rel.Href, rel.Header)
	if err != nil {
		return err
	}

	if len(req.Header.Get("Content-Type")) == 0 {
		req.Header.Set("Content-Type", "application/octet-stream")
	}

	if req.Header.Get("Transfer-Encoding") == "chunked" {
		req.TransferEncoding = []string{"chunked"}
	} else {
		req.Header.Set("Content-Length", strconv.FormatInt(t.Object.Size, 10))
	}

	req.ContentLength = t.Object.Size

	f, err := os.OpenFile(t.Path, os.O_RDONLY, 0644)
	if err != nil {
		return errutil.Error(err)
	}
	defer f.Close()

	// Ensure progress callbacks made while uploading
	// Wrap callback to give name context
	ccb := func(totalSize int64, readSoFar int64, readSinceLast int) error {
		if a.cb != nil {
			return a.cb(t.Name, totalSize, readSoFar, readSinceLast)
		}
		return nil
	}
	var reader io.Reader = &progress.CallbackReader{
		C:         ccb,
		TotalSize: t.Object.Size,
		Reader:    f,
	}

	if signalAuthOnResponse {
		// Signal auth was ok on first read; this frees up other workers to start
		reader = newStartCallbackReader(reader, func(*startCallbackReader) {
			a.signalAuth()
		})
	}

	req.Body = ioutil.NopCloser(reader)

	res, err := httputil.DoHttpRequest(req, true)
	if err != nil {
		return errutil.NewRetriableError(err)
	}
	httputil.LogTransfer("lfs.data.upload", res)

	// Always drain and close the body, whichever status branch we take below,
	// so the underlying connection can be reused (previously the 403 and >299
	// paths returned without closing it).
	defer func() {
		io.Copy(ioutil.Discard, res.Body)
		res.Body.Close()
	}()

	// A status code of 403 likely means that an authentication token for the
	// upload has expired. This can be safely retried. Note err is necessarily
	// nil here (checked above), so construct a real error rather than wrapping
	// nil as the previous code did.
	if res.StatusCode == 403 {
		return errutil.NewRetriableError(fmt.Errorf("http: received status %d", res.StatusCode))
	}

	if res.StatusCode > 299 {
		return errutil.Errorf(nil, "Invalid status for %s: %d", httputil.TraceHttpReq(req), res.StatusCode)
	}

	return api.VerifyUpload(t.Object)
}

// startCallbackReader is a reader wrapper which calls a function as soon as the
// first Read() call is made. This callback is only made once
type startCallbackReader struct {
	r      io.Reader
	cb     func(*startCallbackReader)
	cbDone bool
}

func (s *startCallbackReader) Read(p []byte) (n int, err error) {
	if !s.cbDone && s.cb != nil {
		s.cb(s)
		s.cbDone = true
	}
	return s.r.Read(p)
}

func newStartCallbackReader(r io.Reader, cb func(*startCallbackReader)) *startCallbackReader {
	return &startCallbackReader{r, cb, false}
}

func init() {
	newfunc := func(name string, dir Direction) TransferAdapter {
		return &basicAdapter{
			direction: dir,
		}
	}
	RegisterNewTransferAdapterFunc(BasicAdapterName, Upload, newfunc)
	RegisterNewTransferAdapterFunc(BasicAdapterName, Download, newfunc)
}
diff --git a/transfer/transfer.go b/transfer/transfer.go
new file mode 100644
index 00000000..0eec59a4
--- /dev/null
+++ b/transfer/transfer.go
@@ -0,0 +1,169 @@
// Package transfer collects together adapters for uploading and downloading LFS content
// NOTE: Subject to change, do not rely on this package from outside git-lfs source
package transfer

import (
	"sync"

	"github.com/github/git-lfs/api"
)

// Direction distinguishes upload adapter instances from download instances
type Direction int

const (
	Upload   = Direction(iota)
	Download = Direction(iota)
)

// NewTransferAdapterFunc creates new instances of TransferAdapter. Code that wishes
// to provide new TransferAdapter instances should pass an implementation of this
// function to RegisterNewTransferAdapterFunc
// name and dir are to provide context if one func implements many instances
type NewTransferAdapterFunc func(name string, dir Direction) TransferAdapter

var (
	// funcMutex guards both adapter-factory registries below
	funcMutex            sync.Mutex
	downloadAdapterFuncs = make(map[string]NewTransferAdapterFunc)
	uploadAdapterFuncs   = make(map[string]NewTransferAdapterFunc)
)

// TransferProgressCallback reports progress for one named transfer item
type TransferProgressCallback func(name string, totalSize, readSoFar int64, readSinceLast int) error

// TransferAdapter is implemented by types which can upload and/or download LFS
// file content to a remote store. Each TransferAdapter accepts one or more requests
// which it may schedule and parallelise in whatever way it chooses, clients of
// this interface will receive notifications of progress and completion asynchronously.
// TransferAdapters support transfers in one direction; if an implementation
// provides support for upload and download, it should be instantiated twice,
// advertising support for each direction separately.
// Note that TransferAdapter only implements the actual upload/download of content
// itself; organising the wider process including calling the API to get URLs,
// handling progress reporting and retries is the job of the core TransferQueue.
// This is so that the orchestration remains core & standard but TransferAdapter
// can be changed to physically transfer to different hosts with less code.
type TransferAdapter interface {
	// Name returns the name of this adapter, which is the same for all instances
	// of this type of adapter
	Name() string
	// Direction returns whether this instance is an upload or download instance
	// TransferAdapter instances can only be one or the other, although the same
	// type may be instantiated for each direction
	Direction() Direction
	// Begin a new batch of uploads or downloads. Call this first, followed by
	// one or more Add calls. maxConcurrency controls the number of transfers
	// that may be done at once. The passed in callback will receive updates on
	// progress, and the completion channel will receive completion notifications
	// Either argument may be nil if not required by the client
	Begin(maxConcurrency int, cb TransferProgressCallback, completion chan TransferResult) error
	// Add queues a download/upload, which will complete asynchronously and
	// notify the callbacks given to Begin()
	Add(t *Transfer)
	// Indicate that all transfers have been scheduled and resources can be released
	// once the queued items have completed.
	// This call blocks until all items have been processed
	End()
	// ClearTempStorage clears any temporary files, such as unfinished downloads that
	// would otherwise be resumed
	ClearTempStorage() error
}

// Transfer is the general struct for both uploads and downloads
type Transfer struct {
	// Name of the file that triggered this transfer
	Name string
	// Object from API which provides the core data for this transfer
	Object *api.ObjectResource
	// Path for uploads is the source of data to send, for downloads is the
	// location to place the final result
	Path string
}

// NewTransfer creates a new Transfer instance
func NewTransfer(name string, obj *api.ObjectResource, path string) *Transfer {
	return &Transfer{name, obj, path}
}

// TransferResult is the result of a transfer returned through the completion channel
type TransferResult struct {
	Transfer *Transfer
	// This will be non-nil if there was an error transferring this item
	Error error
}

// GetAdapterNames returns a list of the names of adapters available to be created
func GetAdapterNames(dir Direction) []string {
	switch dir {
	case Upload:
		return GetUploadAdapterNames()
	case Download:
		return GetDownloadAdapterNames()
	}
	return nil
}

// GetDownloadAdapterNames returns a list of the names of download adapters available to be created
func GetDownloadAdapterNames() []string {
	funcMutex.Lock()
	defer funcMutex.Unlock()

	ret := make([]string, 0, len(downloadAdapterFuncs))
	for n := range downloadAdapterFuncs {
		ret = append(ret, n)
	}
	return ret
}

// GetUploadAdapterNames returns a list of the names of upload adapters available to be created
func GetUploadAdapterNames() []string {
	funcMutex.Lock()
	defer funcMutex.Unlock()

	ret := make([]string, 0, len(uploadAdapterFuncs))
	for n := range uploadAdapterFuncs {
		ret = append(ret, n)
	}
	return ret
}

// RegisterNewTransferAdapterFunc registers a new function for creating upload
// or download adapters. If a function with that name & direction is already
// registered, it is overridden
func RegisterNewTransferAdapterFunc(name string, dir Direction, f NewTransferAdapterFunc) {
	funcMutex.Lock()
	defer funcMutex.Unlock()

	switch dir {
	case Upload:
		uploadAdapterFuncs[name] = f
	case Download:
		downloadAdapterFuncs[name] = f
	}
}

// NewAdapter creates a new adapter by name and direction, or nil if it doesn't exist
func NewAdapter(name string, dir Direction) TransferAdapter {
	funcMutex.Lock()
	defer funcMutex.Unlock()

	switch dir {
	case Upload:
		if u, ok := uploadAdapterFuncs[name]; ok {
			return u(name, dir)
		}
	case Download:
		if d, ok := downloadAdapterFuncs[name]; ok {
			return d(name, dir)
		}
	}
	return nil
}

// NewDownloadAdapter creates a new download adapter by name, or nil if it doesn't exist
func NewDownloadAdapter(name string) TransferAdapter {
	return NewAdapter(name, Download)
}

// NewUploadAdapter creates a new upload adapter by name, or nil if it doesn't exist
func NewUploadAdapter(name string) TransferAdapter {
	return NewAdapter(name, Upload)
}
diff --git a/transfer/transfer_test.go b/transfer/transfer_test.go
new file mode 100644
index 00000000..fde3eb59
--- /dev/null
+++ b/transfer/transfer_test.go
@@ -0,0 +1,105 @@
package transfer

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// testAdapter is a minimal TransferAdapter stub for registry tests
type testAdapter struct {
	name string
	dir  Direction
}

func (a *testAdapter) Name() string {
	return a.name
}

func (a *testAdapter) Direction() Direction {
	return a.dir
}

func (a *testAdapter) Begin(maxConcurrency int, cb TransferProgressCallback, completion chan TransferResult) error {
	return nil
}

func (a *testAdapter) Add(t *Transfer) {
}

func (a *testAdapter) End() {
}

func (a *testAdapter) ClearTempStorage() error {
	return nil
}

func newTestAdapter(name string, dir Direction) TransferAdapter {
	return &testAdapter{name, dir}
}

func newRenamedTestAdapter(name string, dir Direction) TransferAdapter {
	return &testAdapter{"RENAMED", dir}
}

// resetAdapters empties both registries; kept as a helper for tests that need
// a clean slate (currently unused by the tests below, which rely on running in
// source order)
func resetAdapters() {
	uploadAdapterFuncs = make(map[string]NewTransferAdapterFunc)
	downloadAdapterFuncs = make(map[string]NewTransferAdapterFunc)
}

// NOTE(review): these two tests were previously named with a lowercase "test"
// prefix, so `go test` silently never ran them; renamed to TestXxx. They run
// in source order: this one must run before TestAdapterRegAndOverride, which
// mutates the registries.
func TestBasicAdapterExists(t *testing.T) {
	assert := assert.New(t)

	dls := GetDownloadAdapterNames()
	if assert.NotNil(dls) {
		assert.Equal([]string{"basic"}, dls)
	}
	uls := GetUploadAdapterNames()
	if assert.NotNil(uls) {
		assert.Equal([]string{"basic"}, uls)
	}
	da := NewDownloadAdapter("basic")
	if assert.NotNil(da) {
		assert.Equal("basic", da.Name())
		assert.Equal(Download, da.Direction())
	}
	ua := NewUploadAdapter("basic")
	if assert.NotNil(ua) {
		assert.Equal("basic", ua.Name())
		assert.Equal(Upload, ua.Direction())
	}
}

func TestAdapterRegAndOverride(t *testing.T) {
	assert := assert.New(t)

	assert.Nil(NewDownloadAdapter("test"))
	assert.Nil(NewUploadAdapter("test"))

	RegisterNewTransferAdapterFunc("test", Upload, newTestAdapter)
	assert.Nil(NewDownloadAdapter("test"))
	assert.NotNil(NewUploadAdapter("test"))

	RegisterNewTransferAdapterFunc("test", Download, newTestAdapter)
	da := NewDownloadAdapter("test")
	if assert.NotNil(da) {
		assert.Equal("test", da.Name())
		assert.Equal(Download, da.Direction())
	}
	ua := NewUploadAdapter("test")
	if assert.NotNil(ua) {
		assert.Equal("test", ua.Name())
		assert.Equal(Upload, ua.Direction())
	}

	// Test override
	RegisterNewTransferAdapterFunc("test", Upload, newRenamedTestAdapter)
	ua = NewUploadAdapter("test")
	if assert.NotNil(ua) {
		assert.Equal("RENAMED", ua.Name())
		assert.Equal(Upload, ua.Direction())
	}
	da = NewDownloadAdapter("test")
	if assert.NotNil(da) {
		assert.Equal("test", da.Name())
		assert.Equal(Download, da.Direction())
	}
	RegisterNewTransferAdapterFunc("test", Download, newRenamedTestAdapter)
	da = NewDownloadAdapter("test")
	if assert.NotNil(da) {
		assert.Equal("RENAMED", da.Name())
		assert.Equal(Download, da.Direction())
	}
}