diff --git a/ssh/connection.go b/ssh/connection.go index 445af425..9b499d31 100644 --- a/ssh/connection.go +++ b/ssh/connection.go @@ -1,16 +1,38 @@ package ssh import ( + "sync" + "github.com/git-lfs/git-lfs/config" "github.com/git-lfs/git-lfs/subprocess" "github.com/git-lfs/pktline" ) type SSHTransfer struct { - conn *PktlineConnection + lock *sync.RWMutex + conn []*PktlineConnection + osEnv config.Environment + gitEnv config.Environment + meta *SSHMetadata + operation string } func NewSSHTransfer(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, operation string) (*SSHTransfer, error) { + conn, err := startConnection(osEnv, gitEnv, meta, operation) + if err != nil { + return nil, err + } + return &SSHTransfer{ + lock: &sync.RWMutex{}, + osEnv: osEnv, + gitEnv: gitEnv, + meta: meta, + operation: operation, + conn: []*PktlineConnection{conn}, + }, nil +} + +func startConnection(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, operation string) (*PktlineConnection, error) { exe, args := GetLFSExeAndArgs(osEnv, gitEnv, meta, "git-lfs-transfer", operation) cmd := subprocess.ExecCommand(exe, args...) r, err := cmd.StdoutPipe() @@ -37,14 +59,63 @@ func NewSSHTransfer(osEnv config.Environment, gitEnv config.Environment, meta *S pl: pl, } err = conn.Start() - if err != nil { - return nil, err - } - return &SSHTransfer{ - conn: conn, - }, nil + return conn, err } -func (tr *SSHTransfer) Connection() *PktlineConnection { - return tr.conn +// Connection returns the nth connection (starting from 0) in this transfer +// instance or nil if there is no such item. +func (tr *SSHTransfer) Connection(n int) *PktlineConnection { + tr.lock.RLock() + defer tr.lock.RUnlock() + if n >= len(tr.conn) { + return nil + } + return tr.conn[n] +} + +// ConnectionCount returns the number of connections this object has. +func (tr *SSHTransfer) ConnectionCount() int { + tr.lock.RLock() + defer tr.lock.RUnlock() + return len(tr.conn) +} + +// SetConnectionCount sets the number of connections to the specified number. +func (tr *SSHTransfer) SetConnectionCount(n int) error { + tr.lock.Lock() + defer tr.lock.Unlock() + return tr.setConnectionCount(n) +} + +// SetConnectionCountAtLeast sets the number of connections to be not less than +// the specified number. +func (tr *SSHTransfer) SetConnectionCountAtLeast(n int) error { + tr.lock.Lock() + defer tr.lock.Unlock() + count := len(tr.conn) + if n <= count { + return nil + } + return tr.setConnectionCount(n) +} + +func (tr *SSHTransfer) setConnectionCount(n int) error { + count := len(tr.conn) + if n < count { + for _, item := range tr.conn[n:count] { + if err := item.End(); err != nil { + return err + } + } + tr.conn = tr.conn[0:n] + } else if n > count { + for i := count; i < n; i++ { + conn, err := startConnection(tr.osEnv, tr.gitEnv, tr.meta, tr.operation) + if err != nil { + return err + } + tr.conn = append(tr.conn, conn) + } + } + return nil } diff --git a/tq/ssh.go b/tq/ssh.go index 41b8b0a3..bbf9512d 100644 --- a/tq/ssh.go +++ b/tq/ssh.go @@ -25,7 +25,7 @@ type SSHBatchClient struct { } func (a *SSHBatchClient) batchInternal(args []string, batchLines []string) (int, []string, error) { - conn := a.transfer.Connection() + conn := a.transfer.Connection(0) conn.Lock() defer conn.Unlock() err := conn.SendMessageWithLines("batch", args, batchLines) @@ -196,7 +196,7 @@ func (a *SSHAdapter) download(t *Transfer, cb ProgressCallback) error { // doDownload starts a download. f is expected to be an existing file open in RW mode func (a *SSHAdapter) doDownload(t *Transfer, f *os.File, cb ProgressCallback) error { - conn := a.transfer.Connection() + conn := a.transfer.Connection(0) args := a.argumentsForTransfer(t, "download") conn.Lock() defer conn.Unlock() @@ -266,7 +266,7 @@ func (a *SSHAdapter) doDownload(t *Transfer, f *os.File, cb ProgressCallback) er } func (a *SSHAdapter) verifyUpload(t *Transfer) error { - conn := a.transfer.Connection() + conn := a.transfer.Connection(0) args := a.argumentsForTransfer(t, "upload") conn.Lock() defer conn.Unlock() @@ -288,7 +288,7 @@ func (a *SSHAdapter) verifyUpload(t *Transfer) error { } func (a *SSHAdapter) doUpload(t *Transfer, f *os.File, cb ProgressCallback) (int, []string, []string, error) { - conn := a.transfer.Connection() + conn := a.transfer.Connection(0) args := a.argumentsForTransfer(t, "upload") // Ensure progress callbacks made while uploading