git-lfs/transfer/basic.go
2016-05-25 11:34:58 +01:00

261 lines
6.3 KiB
Go

package transfer
import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"sync"
"github.com/github/git-lfs/api"
"github.com/github/git-lfs/config"
"github.com/github/git-lfs/errutil"
"github.com/github/git-lfs/httputil"
"github.com/github/git-lfs/tools"
"github.com/github/git-lfs/progress"
)
// Base implementation of basic all-or-nothing HTTP upload / download adapter.
// Each object is moved with one HTTP GET (download) or PUT (upload); init
// registers one instance per Direction.
type basicAdapter struct {
	// direction selects whether worker calls download or upload.
	direction Direction
	// jobChan is the queue of pending transfers; Begin creates it and every
	// worker ranges over it.
	jobChan chan *Transfer
	// cb is the progress callback supplied to Begin; it is fed as object
	// bytes are copied.
	cb progress.CopyCallback
	// outChan receives one TransferResult per processed transfer; it is the
	// completion channel supplied to Begin.
	outChan chan TransferResult
	// WaitGroup to sync the completion of all workers
	workerWait sync.WaitGroup
	// WaitGroup to serialise the first transfer response to perform login if needed.
	// Begin adds 1; worker 0's first transfer releases it, and the other
	// workers wait on it so that only one login prompt is presented.
	authWait sync.WaitGroup
}
// newBasicAdapter returns an adapter that transfers objects in direction d.
// Its queue, callback and workers are only set up later, by Begin.
func newBasicAdapter(d Direction) *basicAdapter {
	adapter := new(basicAdapter)
	adapter.direction = d
	return adapter
}
// Direction reports the direction fixed when the adapter was constructed.
func (a *basicAdapter) Direction() Direction {
	dir := a.direction
	return dir
}

// Name returns this adapter's identifier, "basic".
func (a *basicAdapter) Name() string {
	const adapterName = "basic"
	return adapterName
}
// Begin prepares the adapter for a batch of transfers. It records the
// progress callback and the channel on which results are reported, creates
// a job queue with a buffer of 100, and launches as many workers as
// config.Config.ConcurrentTransfers() asks for. It always returns nil.
func (a *basicAdapter) Begin(cb progress.CopyCallback, completion chan TransferResult) error {
	a.cb, a.outChan = cb, completion
	a.jobChan = make(chan *Transfer, 100)

	workers := config.Config.ConcurrentTransfers()
	a.workerWait.Add(workers)
	// Held until worker 0's first transfer releases it; the other workers
	// wait on it so only one login prompt can appear.
	a.authWait.Add(1)

	for n := 0; n < workers; n++ {
		go a.worker(n)
	}
	return nil
}
// Add queues t for one of the workers started by Begin. The send blocks
// while the job queue's buffer is full.
func (a *basicAdapter) Add(t *Transfer) {
	queue := a.jobChan
	queue <- t
}
// End signals that no more transfers will be added and blocks until every
// worker has finished its remaining transfers and exited.
//
// Calling Add after End panics, because Add sends on the closed job queue.
func (a *basicAdapter) End() {
	// Channels have no Close method; the builtin close ends each worker's
	// range loop once the queued jobs have been drained.
	close(a.jobChan)
	// wait for all transfers to complete
	a.workerWait.Wait()
}
// ClearTempStorage is currently a placeholder: it removes nothing and always
// returns nil. (Presumably it is meant to clean up files under tempDir —
// confirm before implementing.)
func (a *basicAdapter) ClearTempStorage() error {
	// TODO @sinbad: not implemented yet.
	var noErr error
	return noErr
}
// worker processes transfers from jobChan until it is closed, sending one
// TransferResult per job to outChan. Begin starts several of these.
//
// Worker 0 is the only one allowed to start immediately. The rest wait on
// authWait until worker 0's first transfer releases it, so that only one
// login prompt is presented if credentials are needed.
func (a *basicAdapter) worker(workerNum int) {
	defer a.workerWait.Done()

	signalAuthOnResponse := workerNum == 0
	if !signalAuthOnResponse {
		// Wait *before* taking a job. If a non-first worker took a job and then
		// blocked here while worker 0 received none, nothing would ever release
		// authWait, and both that job and End would hang.
		a.authWait.Wait()
	}

	for t := range a.jobChan {
		var err error
		switch a.Direction() {
		case Download:
			err = a.download(t, signalAuthOnResponse)
		case Upload:
			err = a.upload(t, signalAuthOnResponse)
		}
		a.outChan <- TransferResult{t, err}
		// Only worker 0's first transfer is responsible for the auth signal.
		signalAuthOnResponse = false
	}

	// Worker 0 never received a job, so no transfer released authWait; do it
	// here so the waiting workers can see the closed queue and exit.
	if signalAuthOnResponse {
		a.authWait.Done()
	}
}
// tempDir names the directory where download stages its temporary files: a
// "git-lfs-basic" subdirectory of os.TempDir(). It does not create the
// directory.
func (a *basicAdapter) tempDir() string {
	base := os.TempDir()
	return filepath.Join(base, "git-lfs-basic")
}
// download fetches the object described by t and stores it at t.Path.
//
// The response body is streamed into a temporary file under tempDir() and
// hashed on the way through. It is renamed into place only if the hash
// matches t.Object.Oid. On any error after the temp file is created, the temp
// file is removed.
//
// If signalAuthOnResponse is true, this call releases authWait exactly once.
// Normally that happens as soon as the server responds, so the other workers
// can start while this body is still streaming. If the request fails before
// any response, authWait is released when this call returns, because
// otherwise the waiting workers would block forever.
func (a *basicAdapter) download(t *Transfer, signalAuthOnResponse bool) (err error) {
	var authOnce sync.Once
	signalAuth := func() {
		if signalAuthOnResponse {
			authOnce.Do(a.authWait.Done)
		}
	}
	defer signalAuth()

	rel, ok := t.Object.Rel("download")
	if !ok {
		return errors.New("Object not found on the server.")
	}
	req, err := httputil.NewHttpRequest("GET", rel.Href, rel.Header)
	if err != nil {
		return err
	}
	res, err := httputil.DoHttpRequest(req, true)
	if err != nil {
		return errutil.NewRetriableError(err)
	}
	httputil.LogTransfer("lfs.data.download", res)
	defer res.Body.Close()

	// Signal auth OK on success response, before starting download to free up
	// other workers immediately.
	signalAuth()

	// Now do transfer of content. The staging directory may not exist yet, and
	// ioutil.TempFile does not create it.
	dir := a.tempDir()
	if err = os.MkdirAll(dir, 0755); err != nil {
		return fmt.Errorf("cannot create temp dir %q: %v", dir, err)
	}
	f, err := ioutil.TempFile(dir, t.Object.Oid+"-")
	if err != nil {
		return fmt.Errorf("cannot create temp file: %v", err)
	}
	tempfilename := f.Name()
	// err is the named result, so this sees every failing return below (copy,
	// close, hash mismatch, rename). It is registered before the Close defer,
	// so it runs after the file is closed.
	defer func() {
		if err != nil {
			// Don't leave the temp file lying around on error.
			_ = os.Remove(tempfilename) // best effort; nothing more we can do.
		}
	}()
	// Ensure f is closed on every path. On the success path f is closed
	// explicitly below; the repeated Close's error is ignored.
	defer f.Close()

	hasher := tools.NewHashingReader(res.Body)
	written, err := tools.CopyWithCallback(f, hasher, res.ContentLength, a.cb)
	if err != nil {
		return fmt.Errorf("cannot write data to tempfile %q: %v", tempfilename, err)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("can't close tempfile %q: %v", tempfilename, err)
	}
	if actual := hasher.Hash(); actual != t.Object.Oid {
		return fmt.Errorf("Expected OID %s, got %s after %d bytes written", t.Object.Oid, actual, written)
	}
	return tools.RenameFileCopyPermissions(tempfilename, t.Path)
}
// upload sends the file at t.Path to the object's "upload" action with a
// single PUT. It then asks api.VerifyUpload to confirm the object.
// Network failures and 403 responses are returned as retriable errors. Any
// other status above 299 is returned as a plain error.
//
// If signalAuthOnResponse is true, this call releases authWait exactly once.
// Normally that happens when the HTTP client first reads the request body. If
// the upload fails before that, authWait is released when this call returns,
// so the waiting workers are never stuck. The first-read callback may run on
// the HTTP transport's goroutine, so the release is guarded by a sync.Once.
func (a *basicAdapter) upload(t *Transfer, signalAuthOnResponse bool) error {
	var authOnce sync.Once
	signalAuth := func() {
		if signalAuthOnResponse {
			authOnce.Do(a.authWait.Done)
		}
	}
	defer signalAuth()

	rel, ok := t.Object.Rel("upload")
	if !ok {
		return errors.New("No upload action for this object.")
	}
	req, err := httputil.NewHttpRequest("PUT", rel.Href, rel.Header)
	if err != nil {
		return err
	}
	if len(req.Header.Get("Content-Type")) == 0 {
		req.Header.Set("Content-Type", "application/octet-stream")
	}
	if req.Header.Get("Transfer-Encoding") == "chunked" {
		req.TransferEncoding = []string{"chunked"}
	} else {
		req.Header.Set("Content-Length", strconv.FormatInt(t.Object.Size, 10))
	}
	req.ContentLength = t.Object.Size

	f, err := os.Open(t.Path)
	if err != nil {
		return errutil.Error(err)
	}
	defer f.Close()

	// Ensure progress callbacks made while uploading
	var reader io.Reader = &progress.CallbackReader{
		C:         a.cb,
		TotalSize: t.Object.Size,
		Reader:    f,
	}
	if signalAuthOnResponse {
		// Signal auth was ok on first read; this frees up other workers to start
		reader = newStartCallbackReader(reader, func(*startCallbackReader) {
			signalAuth()
		})
	}
	req.Body = ioutil.NopCloser(reader)

	res, err := httputil.DoHttpRequest(req, true)
	if err != nil {
		return errutil.NewRetriableError(err)
	}
	// Close the body on every path, including the error statuses below.
	defer res.Body.Close()
	httputil.LogTransfer("lfs.data.upload", res)

	// A status code of 403 likely means that an authentication token for the
	// upload has expired. This can be safely retried. err is nil here, so
	// build a real error rather than wrapping nil.
	if res.StatusCode == 403 {
		return errutil.NewRetriableError(fmt.Errorf("http: received status 403 for %s", httputil.TraceHttpReq(req)))
	}
	if res.StatusCode > 299 {
		return errutil.Errorf(nil, "Invalid status for %s: %d", httputil.TraceHttpReq(req), res.StatusCode)
	}

	// Drain the body so the underlying connection can be reused.
	io.Copy(ioutil.Discard, res.Body)
	return api.VerifyUpload(t.Object)
}
// startCallbackReader wraps an io.Reader and invokes cb once, immediately
// before the first Read is forwarded to the wrapped reader. cb receives the
// wrapper itself. With a nil cb it is a plain pass-through.
type startCallbackReader struct {
	r      io.Reader
	cb     func(*startCallbackReader)
	cbDone bool
}

var _ io.Reader = (*startCallbackReader)(nil)

// Read fires the start callback on the first call (if one is set), then
// delegates to the wrapped reader.
func (s *startCallbackReader) Read(p []byte) (n int, err error) {
	pending := s.cb != nil && !s.cbDone
	if pending {
		s.cb(s)
		s.cbDone = true
	}
	return s.r.Read(p)
}

// newStartCallbackReader wraps r so that cb runs before the first Read.
func newStartCallbackReader(r io.Reader, cb func(*startCallbackReader)) *startCallbackReader {
	return &startCallbackReader{
		r:  r,
		cb: cb,
	}
}
func init() {
ul := newBasicAdapter(Upload)
RegisterAdapter(ul)
dl := newBasicAdapter(Download)
RegisterAdapter(dl)
}