ssh: support concurrent transfers using the pure SSH protocol

When using the pure SSH-based protocol, we can get much higher speeds by
multiplexing multiple connections on the same SSH connection.  If we're
using OpenSSH, let's enable the ControlMaster option unless
lfs.ssh.automultiplex is set to false, and multiplex these shell
operations over one connection.

We prefer XDG_RUNTIME_DIR because it's guaranteed to be private and we
can share many connections over one socket, but if that's not set, let's
default to creating a new temporary directory for the socket.  On
Windows, where the native SSH client doesn't support ControlMaster,
we should fall back to using multiple connections since we use
ControlMaster=auto.

Note that the option exists because users may already be using SSH
multiplexing and we would want to provide a way for them to disable
this, in addition to the case where users have an old or broken OpenSSH
which cannot support this option.

We pass the connection object into each worker and adjust our transfer
code to pass it into each function we invoke.  We also make sure to
properly terminate each connection at the end by reducing our connection
count to 0, which closes the extra (i.e., all) connections.

Co-authored-by: Chris Darroch <chrisd8088@github.com>
This commit is contained in:
brian m. carlson 2021-04-06 14:08:06 +00:00
parent 9ff6739b6b
commit 9c46a38281
No known key found for this signature in database
GPG Key ID: 2D0C9BC12F82B3A1
8 changed files with 100 additions and 48 deletions

@ -76,6 +76,12 @@ be scoped inside the configuration for a remote.
Sets the maximum time, in seconds, for the HTTP client to maintain keepalive
connections. Default: 30 minutes.
* `lfs.ssh.automultiplex`
When using the pure SSH-based protocol, whether to multiplex requests over a
single connection when possible. This option requires the use of OpenSSH or a
compatible SSH client. Default: true.
* `lfs.ssh.retries`
Specifies the number of times Git LFS will attempt to obtain authorization via

@ -79,7 +79,7 @@ func (c *sshAuthClient) Resolve(e Endpoint, method string) (sshAuthResponse, err
return res, nil
}
exe, args := ssh.GetLFSExeAndArgs(c.os, c.git, &e.SSHMetadata, "git-lfs-authenticate", endpointOperation(e, method))
exe, args := ssh.GetLFSExeAndArgs(c.os, c.git, &e.SSHMetadata, "git-lfs-authenticate", endpointOperation(e, method), false)
cmd := subprocess.ExecCommand(exe, args...)
// Save stdout and stderr in separate buffers

@ -33,7 +33,7 @@ func NewSSHTransfer(osEnv config.Environment, gitEnv config.Environment, meta *S
}
func startConnection(id int, osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, operation string) (*PktlineConnection, error) {
exe, args := GetLFSExeAndArgs(osEnv, gitEnv, meta, "git-lfs-transfer", operation)
exe, args := GetLFSExeAndArgs(osEnv, gitEnv, meta, "git-lfs-transfer", operation, true)
cmd := subprocess.ExecCommand(exe, args...)
r, err := cmd.StdoutPipe()
if err != nil {
@ -119,3 +119,7 @@ func (tr *SSHTransfer) setConnectionCount(n int) error {
}
return nil
}
func (tr *SSHTransfer) Shutdown() error {
return tr.SetConnectionCount(0)
}

@ -2,9 +2,12 @@ package ssh
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
"syscall"
"github.com/git-lfs/git-lfs/config"
"github.com/git-lfs/git-lfs/subprocess"
@ -35,8 +38,8 @@ func FormatArgs(cmd string, args []string, needShell bool) (string, []string) {
return subprocess.FormatForShellQuotedArgs(cmd, args)
}
func GetLFSExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, command, operation string) (string, []string) {
exe, args, needShell := GetExeAndArgs(osEnv, gitEnv, meta)
func GetLFSExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, command, operation string, multiplexDesired bool) (string, []string) {
exe, args, needShell := GetExeAndArgs(osEnv, gitEnv, meta, multiplexDesired)
args = append(args, fmt.Sprintf("%s %s %s", command, meta.Path, operation))
exe, args = FormatArgs(exe, args, needShell)
tracerx.Printf("run_command: %s %s", exe, strings.Join(args, " "))
@ -97,9 +100,36 @@ func getVariant(osEnv config.Environment, gitEnv config.Environment, basessh str
return autodetectVariant(osEnv, gitEnv, basessh)
}
// findRuntimeDir returns a path to the runtime directory if one exists and is
// guaranteed to be private.
func findRuntimeDir(osEnv config.Environment) string {
if dir, ok := osEnv.Get("XDG_RUNTIME_DIR"); ok {
return dir
}
return ""
}
func getControlDir(osEnv config.Environment) (string, error) {
dir := findRuntimeDir(osEnv)
if dir == "" {
return ioutil.TempDir("", "sock-*")
}
dir = filepath.Join(dir, "git-lfs")
err := os.Mkdir(dir, 0700)
if err != nil {
// Ideally we would use errors.Is here to check against
// os.ErrExist, but that's not available on Go 1.11.
perr, ok := err.(*os.PathError)
if !ok || perr.Err != syscall.EEXIST {
return ioutil.TempDir("", "sock-*")
}
}
return dir, nil
}
// Return the executable name for ssh on this machine and the base args
// Base args includes port settings, user/host, everything pre the command to execute
func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata) (exe string, baseargs []string, needShell bool) {
func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata, multiplexDesired bool) (exe string, baseargs []string, needShell bool) {
var cmd string
ssh, _ := osEnv.Get("GIT_SSH")
@ -125,6 +155,15 @@ func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SS
args = append(args, "-batch")
}
multiplexEnabled := gitEnv.Bool("lfs.ssh.automultiplex", true)
if variant == variantSSH && multiplexDesired && multiplexEnabled {
controlPath, err := getControlDir(osEnv)
if err != nil {
controlPath = filepath.Join(controlPath, "sock-%C")
args = append(args, "-oControlMaster=auto", fmt.Sprintf("-oControlPath=%s", controlPath))
}
}
if len(meta.Port) > 0 {
if variant == variantPutty || variant == variantTortoise {
args = append(args, "-P")

@ -19,7 +19,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Path = "user/repo"
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, "git-lfs-authenticate", "download")
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, "git-lfs-authenticate", "download", false)
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{
"--",
@ -27,7 +27,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) {
"git-lfs-authenticate user/repo download",
}, args)
exe, args = ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, "git-lfs-authenticate", "upload")
exe, args = ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, "git-lfs-authenticate", "upload", false)
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{
"--",
@ -46,7 +46,7 @@ func TestSSHGetExeAndArgsSsh(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{"--", "user@foo.com"}, args)
}
@ -62,7 +62,7 @@ func TestSSHGetExeAndArgsSshCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{"-p", "8888", "--", "user@foo.com"}, args)
}
@ -79,7 +79,7 @@ func TestSSHGetExeAndArgsPlink(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"user@foo.com"}, args)
}
@ -97,7 +97,7 @@ func TestSSHGetExeAndArgsPlinkCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args)
}
@ -116,7 +116,7 @@ func TestSSHGetExeAndArgsPlinkCustomPortExplicitEnvironment(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args)
}
@ -135,7 +135,7 @@ func TestSSHGetExeAndArgsPlinkCustomPortExplicitEnvironmentPutty(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args)
}
@ -189,7 +189,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args)
}
@ -208,7 +208,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitEnvironment(t *testing.T
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args)
}
@ -229,7 +229,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitConfig(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args)
}
@ -249,7 +249,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitConfigOverride(t *testin
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, plink, exe)
assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args)
}
@ -265,7 +265,7 @@ func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd user@foo.com"}, args)
}
@ -280,7 +280,7 @@ func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd --args 1 user@foo.com"}, args)
}
@ -295,7 +295,7 @@ func TestSSHGetExeAndArgsSshCommandArgsWithMixedQuotes(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' user@foo.com"}, args)
}
@ -310,7 +310,7 @@ func TestSSHGetExeAndArgsSshCommandCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd -p 8888 -- user@foo.com"}, args)
}
@ -326,7 +326,7 @@ func TestSSHGetExeAndArgsCoreSshCommand(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd --args 2 -- user@foo.com"}, args)
}
@ -340,7 +340,7 @@ func TestSSHGetExeAndArgsCoreSshCommandArgsWithMixedQuotes(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' -- user@foo.com"}, args)
}
@ -354,7 +354,7 @@ func TestSSHGetExeAndArgsConfigVersusEnv(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", "sshcmd --args 1 -- user@foo.com"}, args)
}
@ -370,7 +370,7 @@ func TestSSHGetExeAndArgsPlinkCommand(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", plink + " user@foo.com"}, args)
}
@ -387,7 +387,7 @@ func TestSSHGetExeAndArgsPlinkCommandCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", plink + " -P 8888 user@foo.com"}, args)
}
@ -403,7 +403,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommand(t *testing.T) {
meta := ssh.SSHMetadata{}
meta.UserAndHost = "user@foo.com"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", plink + " -batch user@foo.com"}, args)
}
@ -420,7 +420,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommandCustomPort(t *testing.T) {
meta.UserAndHost = "user@foo.com"
meta.Port = "8888"
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta))
exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false))
assert.Equal(t, "sh", exe)
assert.Equal(t, []string{"-c", plink + " -batch -P 8888 user@foo.com"}, args)
}
@ -441,7 +441,7 @@ func TestSSHGetLFSExeAndArgsWithCustomSSH(t *testing.T) {
assert.Equal(t, "git@host.com", e.SSHMetadata.UserAndHost)
assert.Equal(t, "repo", e.SSHMetadata.Path)
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download")
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download", false)
assert.Equal(t, "not-ssh", exe)
assert.Equal(t, []string{"-p", "12345", "git@host.com", "git-lfs-authenticate repo download"}, args)
}
@ -459,7 +459,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHost(t *testing.T) {
assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SSHMetadata.UserAndHost)
assert.Equal(t, "repo", e.SSHMetadata.Path)
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download")
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download", false)
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{"--", "-oProxyCommand=gnome-calculator", "git-lfs-authenticate repo download"}, args)
}
@ -480,7 +480,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHostWithCustomSSH(t *testing.T) {
assert.Equal(t, "--oProxyCommand=gnome-calculator", e.SSHMetadata.UserAndHost)
assert.Equal(t, "repo", e.SSHMetadata.Path)
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download")
exe, args := ssh.GetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, "git-lfs-authenticate", "download", false)
assert.Equal(t, "not-ssh", exe)
assert.Equal(t, []string{"oProxyCommand=gnome-calculator", "git-lfs-authenticate repo download"}, args)
}
@ -498,7 +498,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsHost(t *testing.T) {
assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SSHMetadata.UserAndHost)
assert.Equal(t, "", e.SSHMetadata.Path)
exe, args, needShell := ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata)
exe, args, needShell := ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, false)
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{"--", "-oProxyCommand=gnome-calculator"}, args)
assert.Equal(t, false, needShell)
@ -517,7 +517,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsPath(t *testing.T) {
assert.Equal(t, "git@git-host.com", e.SSHMetadata.UserAndHost)
assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SSHMetadata.Path)
exe, args, needShell := ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata)
exe, args, needShell := ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &e.SSHMetadata, false)
assert.Equal(t, "ssh", exe)
assert.Equal(t, []string{"--", "git@git-host.com"}, args)
assert.Equal(t, false, needShell)

@ -126,7 +126,6 @@ func NewManifest(f *fs.Filesystem, apiClient *lfsapi.Client, operation, remote s
if sshTransfer != nil {
// Multiple concurrent transfers are not yet supported.
m.concurrentTransfers = 1
m.batchClientAdapter = &SSHBatchClient{
maxRetries: m.maxRetries,
transfer: sshTransfer,

@ -141,7 +141,8 @@ type SSHAdapter struct {
// WorkerStarting is called when a worker goroutine starts to process jobs
// Implementations can run some startup logic here & return some context if needed
func (a *SSHAdapter) WorkerStarting(workerNum int) (interface{}, error) {
return nil, nil
a.transfer.SetConnectionCountAtLeast(workerNum + 1)
return a.transfer.Connection(workerNum), nil
}
// WorkerEnding is called when a worker goroutine is shutting down
@ -163,14 +164,16 @@ func (a *SSHAdapter) DoTransfer(ctx interface{}, t *Transfer, cb ProgressCallbac
if authOkFunc != nil {
authOkFunc()
}
conn := ctx.(*ssh.PktlineConnection)
if a.adapterBase.direction == Upload {
return a.upload(t, cb)
return a.upload(t, conn, cb)
} else {
return a.download(t, cb)
return a.download(t, conn, cb)
}
}
func (a *SSHAdapter) download(t *Transfer, cb ProgressCallback) error {
func (a *SSHAdapter) download(t *Transfer, conn *ssh.PktlineConnection, cb ProgressCallback) error {
// Reserve a temporary filename. We need to make sure nobody operates on the file simultaneously with us.
rel, err := t.Rel("download")
if err != nil {
return err
@ -191,12 +194,11 @@ func (a *SSHAdapter) download(t *Transfer, cb ProgressCallback) error {
os.Remove(tmpName)
}()
return a.doDownload(t, f, cb)
return a.doDownload(t, conn, f, cb)
}
// doDownload starts a download. f is expected to be an existing file open in RW mode
func (a *SSHAdapter) doDownload(t *Transfer, f *os.File, cb ProgressCallback) error {
conn := a.transfer.Connection(0)
func (a *SSHAdapter) doDownload(t *Transfer, conn *ssh.PktlineConnection, f *os.File, cb ProgressCallback) error {
args := a.argumentsForTransfer(t, "download")
conn.Lock()
defer conn.Unlock()
@ -265,8 +267,7 @@ func (a *SSHAdapter) doDownload(t *Transfer, f *os.File, cb ProgressCallback) er
return err
}
func (a *SSHAdapter) verifyUpload(t *Transfer) error {
conn := a.transfer.Connection(0)
func (a *SSHAdapter) verifyUpload(t *Transfer, conn *ssh.PktlineConnection) error {
args := a.argumentsForTransfer(t, "upload")
conn.Lock()
defer conn.Unlock()
@ -287,8 +288,7 @@ func (a *SSHAdapter) verifyUpload(t *Transfer) error {
return nil
}
func (a *SSHAdapter) doUpload(t *Transfer, f *os.File, cb ProgressCallback) (int, []string, []string, error) {
conn := a.transfer.Connection(0)
func (a *SSHAdapter) doUpload(t *Transfer, conn *ssh.PktlineConnection, f *os.File, cb ProgressCallback) (int, []string, []string, error) {
args := a.argumentsForTransfer(t, "upload")
// Ensure progress callbacks made while uploading
@ -313,7 +313,7 @@ func (a *SSHAdapter) doUpload(t *Transfer, f *os.File, cb ProgressCallback) (int
}
// upload starts an upload.
func (a *SSHAdapter) upload(t *Transfer, cb ProgressCallback) error {
func (a *SSHAdapter) upload(t *Transfer, conn *ssh.PktlineConnection, cb ProgressCallback) error {
rel, err := t.Rel("upload")
if err != nil {
return err
@ -328,7 +328,7 @@ func (a *SSHAdapter) upload(t *Transfer, cb ProgressCallback) error {
}
defer f.Close()
status, _, lines, err := a.doUpload(t, f, cb)
status, _, lines, err := a.doUpload(t, conn, f, cb)
if err != nil {
return err
}
@ -351,7 +351,7 @@ func (a *SSHAdapter) upload(t *Transfer, cb ProgressCallback) error {
}
return a.verifyUpload(t)
return a.verifyUpload(t, conn)
}
func (a *SSHAdapter) argumentsForTransfer(t *Transfer, action string) []string {

@ -943,6 +943,10 @@ func (q *TransferQueue) Wait() {
q.meter.Flush()
q.errorwait.Wait()
if q.manifest.sshTransfer != nil {
q.manifest.sshTransfer.Shutdown()
}
if q.unsupportedContentType {
for _, line := range contentTypeWarning {
fmt.Fprintf(os.Stderr, "info: %s\n", line)