git-lfs/tq/ssh.go

402 lines
11 KiB
Go
Raw Normal View History

package tq
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"time"
"github.com/git-lfs/git-lfs/errors"
"github.com/git-lfs/git-lfs/lfshttp"
"github.com/git-lfs/git-lfs/ssh"
"github.com/git-lfs/git-lfs/tools"
"github.com/rubyist/tracerx"
)
type SSHBatchClient struct {
maxRetries int
transfer *ssh.SSHTransfer
}
func (a *SSHBatchClient) batchInternal(args []string, batchLines []string) (int, []string, error) {
conn := a.transfer.Connection(0)
conn.Lock()
defer conn.Unlock()
err := conn.SendMessageWithLines("batch", args, batchLines)
if err != nil {
return 0, nil, errors.Wrap(err, "batch request")
}
status, _, lines, err := conn.ReadStatusWithLines()
if err != nil {
return status, nil, errors.Wrap(err, "batch response")
}
return status, lines, err
}
func (a *SSHBatchClient) Batch(remote string, bReq *batchRequest) (*BatchResponse, error) {
bRes := &BatchResponse{TransferAdapterName: "ssh"}
if len(bReq.Objects) == 0 {
return bRes, nil
}
missing := make(map[string]bool)
batchLines := make([]string, 0, len(bReq.Objects))
for _, obj := range bReq.Objects {
missing[obj.Oid] = obj.Missing
batchLines = append(batchLines, fmt.Sprintf("%s %d", obj.Oid, obj.Size))
}
tracerx.Printf("api: batch %d files", len(bReq.Objects))
requestedAt := time.Now()
args := []string{"transfer=ssh"}
if bReq.Ref != nil {
args = append(args, fmt.Sprintf("refname=%s", bReq.Ref.Name))
}
status, lines, err := a.batchInternal(args, batchLines)
if err != nil {
return nil, err
}
if status != 200 {
msg := "no message provided"
if len(lines) > 0 {
msg = lines[0]
}
return nil, fmt.Errorf("batch response: status %d from server (%s)", status, msg)
}
sort.Strings(lines)
for _, line := range lines {
entries := strings.Split(line, " ")
if len(entries) < 3 {
return nil, fmt.Errorf("batch response: malformed response: %q", line)
}
length := len(bRes.Objects)
if length == 0 || bRes.Objects[length-1].Oid != entries[0] {
bRes.Objects = append(bRes.Objects, &Transfer{Actions: make(map[string]*Action)})
}
transfer := bRes.Objects[len(bRes.Objects)-1]
transfer.Oid = entries[0]
transfer.Size, err = strconv.ParseInt(entries[1], 10, 64)
if err != nil {
return nil, fmt.Errorf("batch response: invalid size: %s", entries[1])
}
if entries[2] == "noop" {
continue
}
transfer.Actions[entries[2]] = &Action{}
if len(entries) > 3 {
for _, entry := range entries[3:] {
if strings.HasPrefix(entry, "id=") {
transfer.Actions[entries[2]].Id = entry[3:]
} else if strings.HasPrefix(entry, "token=") {
transfer.Actions[entries[2]].Token = entry[6:]
} else if strings.HasPrefix(entry, "expires-in=") {
transfer.Actions[entries[2]].ExpiresIn, err = strconv.Atoi(entry[11:])
if err != nil {
return nil, fmt.Errorf("batch response: invalid expires-in: %s", entry)
}
} else if strings.HasPrefix(entry, "expires-at=") {
transfer.Actions[entries[2]].ExpiresAt, err = time.Parse(time.RFC3339, entry[11:])
if err != nil {
return nil, fmt.Errorf("batch response: invalid expires-at: %s", entry)
}
}
}
}
}
for _, obj := range bRes.Objects {
obj.Missing = missing[obj.Oid]
for _, a := range obj.Actions {
a.createdAt = requestedAt
}
}
return bRes, nil
}
func (a *SSHBatchClient) MaxRetries() int {
return a.maxRetries
}
func (a *SSHBatchClient) SetMaxRetries(n int) {
a.maxRetries = n
}
type SSHAdapter struct {
*adapterBase
ctx lfshttp.Context
transfer *ssh.SSHTransfer
}
// WorkerStarting is called when a worker goroutine starts to process jobs
// Implementations can run some startup logic here & return some context if needed
func (a *SSHAdapter) WorkerStarting(workerNum int) (interface{}, error) {
a.transfer.SetConnectionCountAtLeast(workerNum + 1)
return a.transfer.Connection(workerNum), nil
}
// WorkerEnding is called when a worker goroutine is shutting down
// Implementations can clean up per-worker resources here, context is as returned from WorkerStarting
func (a *SSHAdapter) WorkerEnding(workerNum int, ctx interface{}) {
}
func (a *SSHAdapter) tempDir() string {
// Shared with the basic download adapter.
d := filepath.Join(a.fs.LFSStorageDir, "incomplete")
if err := tools.MkdirAll(d, a.fs); err != nil {
return os.TempDir()
}
return d
}
// DoTransfer performs a single transfer within a worker. ctx is any context returned from WorkerStarting
func (a *SSHAdapter) DoTransfer(ctx interface{}, t *Transfer, cb ProgressCallback, authOkFunc func()) error {
if authOkFunc != nil {
authOkFunc()
}
conn := ctx.(*ssh.PktlineConnection)
if a.adapterBase.direction == Upload {
return a.upload(t, conn, cb)
} else {
return a.download(t, conn, cb)
}
}
func (a *SSHAdapter) download(t *Transfer, conn *ssh.PktlineConnection, cb ProgressCallback) error {
rel, err := t.Rel("download")
if err != nil {
return err
}
if rel == nil {
return errors.Errorf("No download action for object: %s", t.Oid)
}
// Reserve a temporary filename. We need to make sure nobody operates on the file simultaneously with us.
f, err := tools.TempFile(a.tempDir(), t.Oid, a.fs)
if err != nil {
return err
}
tmpName := f.Name()
defer func() {
if f != nil {
f.Close()
}
os.Remove(tmpName)
}()
return a.doDownload(t, conn, f, cb)
}
// doDownload starts a download. f is expected to be an existing file open in RW mode
func (a *SSHAdapter) doDownload(t *Transfer, conn *ssh.PktlineConnection, f *os.File, cb ProgressCallback) error {
args := a.argumentsForTransfer(t, "download")
conn.Lock()
defer conn.Unlock()
err := conn.SendMessage(fmt.Sprintf("get-object %s", t.Oid), args)
if err != nil {
return err
}
status, args, data, err := conn.ReadStatusWithData()
if err != nil {
return err
}
if status < 200 || status > 299 {
buffer := &bytes.Buffer{}
if data != nil {
io.CopyN(buffer, data, 1024)
io.Copy(ioutil.Discard, data)
}
return errors.NewRetriableError(fmt.Errorf("got status %d when fetching OID %s: %s", status, t.Oid, buffer.String()))
}
var actualSize int64
seenSize := false
for _, arg := range args {
if strings.HasPrefix(arg, "size=") {
if seenSize {
return errors.NewProtocolError("unexpected size argument", nil)
}
actualSize, err = strconv.ParseInt(arg[5:], 10, 64)
if err != nil || actualSize < 0 {
return errors.NewProtocolError(fmt.Sprintf("expected valid size, got %q", arg[5:]), err)
}
seenSize = true
}
}
if !seenSize {
return errors.NewProtocolError("no size argument seen", nil)
}
dlfilename := f.Name()
// Wrap callback to give name context
ccb := func(totalSize int64, readSoFar int64, readSinceLast int) error {
if cb != nil {
return cb(t.Name, totalSize, readSoFar, readSinceLast)
}
return nil
}
hasher := tools.NewHashingReader(data)
written, err := tools.CopyWithCallback(f, hasher, t.Size, ccb)
if err != nil {
return errors.Wrapf(err, "cannot write data to tempfile %q", dlfilename)
}
if actual := hasher.Hash(); actual != t.Oid {
return fmt.Errorf("expected OID %s, got %s after %d bytes written", t.Oid, actual, written)
}
if err := f.Close(); err != nil {
return fmt.Errorf("can't close tempfile %q: %v", dlfilename, err)
}
err = tools.RenameFileCopyPermissions(dlfilename, t.Path)
if _, err2 := os.Stat(t.Path); err2 == nil {
// Target file already exists, possibly was downloaded by other git-lfs process
return nil
}
return err
}
func (a *SSHAdapter) verifyUpload(t *Transfer, conn *ssh.PktlineConnection) error {
args := a.argumentsForTransfer(t, "upload")
conn.Lock()
defer conn.Unlock()
err := conn.SendMessage(fmt.Sprintf("verify-object %s", t.Oid), args)
if err != nil {
return err
}
status, _, lines, err := conn.ReadStatusWithLines()
if err != nil {
return err
}
if status < 200 || status > 299 {
if len(lines) > 0 {
return fmt.Errorf("got status %d when verifying upload OID %s: %s", status, t.Oid, lines[0])
}
return fmt.Errorf("got status %d when verifying upload OID %s", status, t.Oid)
}
return nil
}
func (a *SSHAdapter) doUpload(t *Transfer, conn *ssh.PktlineConnection, f *os.File, cb ProgressCallback) (int, []string, []string, error) {
args := a.argumentsForTransfer(t, "upload")
// Ensure progress callbacks made while uploading
// Wrap callback to give name context
ccb := func(totalSize int64, readSoFar int64, readSinceLast int) error {
if cb != nil {
return cb(t.Name, totalSize, readSoFar, readSinceLast)
}
return nil
}
cbr := tools.NewFileBodyWithCallback(f, t.Size, ccb)
conn.Lock()
defer conn.Unlock()
defer cbr.Close()
err := conn.SendMessageWithData(fmt.Sprintf("put-object %s", t.Oid), args, cbr)
if err != nil {
return 0, nil, nil, err
}
return conn.ReadStatusWithLines()
}
// upload starts an upload.
func (a *SSHAdapter) upload(t *Transfer, conn *ssh.PktlineConnection, cb ProgressCallback) error {
rel, err := t.Rel("upload")
if err != nil {
return err
}
if rel == nil {
return errors.Errorf("No upload action for object: %s", t.Oid)
}
f, err := os.OpenFile(t.Path, os.O_RDONLY, 0644)
if err != nil {
return errors.Wrap(err, "SSH upload")
}
defer f.Close()
status, _, lines, err := a.doUpload(t, conn, f, cb)
if err != nil {
return err
}
if status < 200 || status > 299 {
// A status code of 403 likely means that an authentication token for the
// upload has expired. This can be safely retried.
if status == 403 {
err = errors.New("http: received status 403")
return errors.NewRetriableError(err)
}
if status == 429 {
return errors.NewRetriableError(fmt.Errorf("got status %d when uploading OID %s", status, t.Oid))
}
if len(lines) > 0 {
return fmt.Errorf("got status %d when uploading OID %s: %s", status, t.Oid, lines[0])
}
return fmt.Errorf("got status %d when uploading OID %s", status, t.Oid)
}
return a.verifyUpload(t, conn)
}
func (a *SSHAdapter) argumentsForTransfer(t *Transfer, action string) []string {
args := make([]string, 0, 3)
set, ok := t.Actions[action]
if !ok {
return nil
}
args = append(args, fmt.Sprintf("size=%d", t.Size))
if set.Id != "" {
args = append(args, fmt.Sprintf("id=%s", set.Id))
}
if set.Token != "" {
args = append(args, fmt.Sprintf("token=%s", set.Token))
}
return args
}
// Begin a new batch of uploads or downloads. Call this first, followed by one
// or more Add calls. The passed in callback will receive updates on progress.
func (a *SSHAdapter) Begin(cfg AdapterConfig, cb ProgressCallback) error {
if err := a.adapterBase.Begin(cfg, cb); err != nil {
return err
}
a.ctx = a.adapterBase.apiClient.Context()
a.debugging = a.ctx.OSEnv().Bool("GIT_TRANSFER_TRACE", false)
return nil
}
func (a *SSHAdapter) Trace(format string, args ...interface{}) {
if !a.adapterBase.debugging {
return
}
tracerx.Printf(format, args...)
}
func configureSSHAdapter(m *Manifest) {
m.RegisterNewAdapterFunc("ssh", Upload, func(name string, dir Direction) Adapter {
a := &SSHAdapter{newAdapterBase(m.fs, name, dir, nil), nil, m.sshTransfer}
a.transferImpl = a
return a
})
m.RegisterNewAdapterFunc("ssh", Download, func(name string, dir Direction) Adapter {
a := &SSHAdapter{newAdapterBase(m.fs, name, dir, nil), nil, m.sshTransfer}
a.transferImpl = a
return a
})
}