diff --git a/ssh/ssh.go b/ssh/ssh.go index 1f2234f1..46308c2e 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -12,6 +12,15 @@ import ( "github.com/rubyist/tracerx" ) +type sshVariant string + +const ( + variantSSH = sshVariant("ssh") + variantSimple = sshVariant("simple") + variantPutty = sshVariant("putty") + variantTortoise = sshVariant("tortoiseplink") +) + type SSHMetadata struct { UserAndHost string Port string @@ -47,14 +56,52 @@ func parseShellCommand(command string, existing string) (ssh string, cmd string, return } +func findVariant(variant string) (bool, sshVariant) { + switch variant { + case "ssh", "simple", "putty", "tortoiseplink": + return false, sshVariant(variant) + case "plink": + return false, variantPutty + case "auto": + return true, "" + default: + return false, variantSSH + } +} + +func autodetectVariant(osEnv config.Environment, gitEnv config.Environment, basessh string) sshVariant { + if basessh != defaultSSHCmd { + // Strip extension for easier comparison + if ext := filepath.Ext(basessh); len(ext) > 0 { + basessh = basessh[:len(basessh)-len(ext)] + } + if strings.EqualFold(basessh, "plink") { + return variantPutty + } + if strings.EqualFold(basessh, "tortoiseplink") { + return variantTortoise + } + } + return "ssh" +} + +func getVariant(osEnv config.Environment, gitEnv config.Environment, basessh string) sshVariant { + variant, ok := osEnv.Get("GIT_SSH_VARIANT") + if !ok { + variant, ok = gitEnv.Get("ssh.variant") + } + autodetect, val := findVariant(variant) + if ok && !autodetect { + return val + } + return autodetectVariant(osEnv, gitEnv, basessh) +} + // Return the executable name for ssh on this machine and the base args // Base args includes port settings, user/host, everything pre the command to execute func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SSHMetadata) (exe string, baseargs []string, needShell bool) { var cmd string - isPlink := false - isTortoise := false - ssh, _ := osEnv.Get("GIT_SSH") sshCmd, _ := osEnv.Get("GIT_SSH_COMMAND") ssh, cmd, needShell = parseShellCommand(sshCmd, ssh) @@ -69,25 +116,17 @@ func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SS } basessh := filepath.Base(ssh) - - if basessh != defaultSSHCmd { - // Strip extension for easier comparison - if ext := filepath.Ext(basessh); len(ext) > 0 { - basessh = basessh[:len(basessh)-len(ext)] - } - isPlink = strings.EqualFold(basessh, "plink") - isTortoise = strings.EqualFold(basessh, "tortoiseplink") - } + variant := getVariant(osEnv, gitEnv, basessh) args := make([]string, 0, 7) - if isTortoise { + if variant == variantTortoise { // TortoisePlink requires the -batch argument to behave like ssh/plink args = append(args, "-batch") } if len(meta.Port) > 0 { - if isPlink || isTortoise { + if variant == variantPutty || variant == variantTortoise { args = append(args, "-P") } else { args = append(args, "-p") @@ -95,10 +134,10 @@ func GetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, meta *SS args = append(args, meta.Port) } - if sep, ok := sshSeparators[basessh]; ok { + if variant == variantSSH { // inserts a separator between cli -options and host/cmd commands // example: $ ssh -p 12345 -- user@host.com git-lfs-authenticate ... - args = append(args, sep, meta.UserAndHost) + args = append(args, "--", meta.UserAndHost) } else { // no prefix supported, strip leading - off host to prevent cmd like: // $ git config lfs.url ssh://-proxycmd=whatever @@ -116,8 +155,4 @@ const defaultSSHCmd = "ssh" var ( sshOptPrefixRE = regexp.MustCompile(`\A\-+`) - sshSeparators = map[string]string{ - "ssh": "--", - "lfs-ssh-echo": "--", // used in lfs integration tests only - } ) diff --git a/ssh/ssh_test.go b/ssh/ssh_test.go index 45fc1336..a025957e 100644 --- a/ssh/ssh_test.go +++ b/ssh/ssh_test.go @@ -102,6 +102,63 @@ func TestSSHGetExeAndArgsPlinkCustomPort(t *testing.T) { assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) } +func TestSSHGetExeAndArgsPlinkCustomPortExplicitEnvironment(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + "GIT_SSH_VARIANT": "plink", + }, nil)) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlinkCustomPortExplicitEnvironmentPutty(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + "GIT_SSH_VARIANT": "putty", + }, nil)) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlinkCustomPortExplicitEnvironmentSsh(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + "GIT_SSH_VARIANT": "ssh", + }, nil)) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-p", "8888", "--", "user@foo.com"}, args) +} + func TestSSHGetExeAndArgsTortoisePlink(t *testing.T) { plink := filepath.Join("Users", "joebloggs", "bin", "tortoiseplink.exe") @@ -114,7 +171,7 @@ func TestSSHGetExeAndArgsTortoisePlink(t *testing.T) { meta := ssh.SSHMetadata{} meta.UserAndHost = "user@foo.com" - exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta, false)) assert.Equal(t, plink, exe) assert.Equal(t, []string{"-batch", "user@foo.com"}, args) } @@ -137,10 +194,71 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPort(t *testing.T) { assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) } +func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitEnvironment(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + "GIT_SSH_VARIANT": "tortoiseplink", + }, nil)) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitConfig(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + "GIT_SSH_VARIANT": "tortoiseplink", + }, map[string]string{ + "ssh.variant": "tortoiseplink", + })) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlinkCustomPortExplicitConfigOverride(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "ssh") + + cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + }, map[string]string{ + "ssh.variant": "putty", + })) + require.Nil(t, err) + + meta := ssh.SSHMetadata{} + meta.UserAndHost = "user@foo.com" + meta.Port = "8888" + + exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) +} + func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) { cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ "GIT_SSH_COMMAND": "sshcmd", "GIT_SSH": "bad", + "GIT_SSH_VARIANT": "simple", }, nil)) require.Nil(t, err) @@ -155,6 +273,7 @@ func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) { func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) { cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ "GIT_SSH_COMMAND": "sshcmd --args 1", + "GIT_SSH_VARIANT": "simple", }, nil)) require.Nil(t, err) @@ -169,6 +288,7 @@ func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) { func TestSSHGetExeAndArgsSshCommandArgsWithMixedQuotes(t *testing.T) { cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ "GIT_SSH_COMMAND": "sshcmd foo 'bar \"baz\"'", + "GIT_SSH_VARIANT": "simple", }, nil)) require.Nil(t, err) @@ -192,7 +312,7 @@ func TestSSHGetExeAndArgsSshCommandCustomPort(t *testing.T) { exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) assert.Equal(t, "sh", exe) - assert.Equal(t, []string{"-c", "sshcmd -p 8888 user@foo.com"}, args) + assert.Equal(t, []string{"-c", "sshcmd -p 8888 -- user@foo.com"}, args) } func TestSSHGetExeAndArgsCoreSshCommand(t *testing.T) { @@ -208,7 +328,7 @@ func TestSSHGetExeAndArgsCoreSshCommand(t *testing.T) { exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) assert.Equal(t, "sh", exe) - assert.Equal(t, []string{"-c", "sshcmd --args 2 user@foo.com"}, args) + assert.Equal(t, []string{"-c", "sshcmd --args 2 -- user@foo.com"}, args) } func TestSSHGetExeAndArgsCoreSshCommandArgsWithMixedQuotes(t *testing.T) { @@ -222,7 +342,7 @@ func TestSSHGetExeAndArgsCoreSshCommandArgsWithMixedQuotes(t *testing.T) { exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) assert.Equal(t, "sh", exe) - assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' user@foo.com"}, args) + assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' -- user@foo.com"}, args) } func TestSSHGetExeAndArgsConfigVersusEnv(t *testing.T) { @@ -236,7 +356,7 @@ func TestSSHGetExeAndArgsConfigVersusEnv(t *testing.T) { exe, args := ssh.FormatArgs(ssh.GetExeAndArgs(cli.OSEnv(), cli.GitEnv(), &meta)) assert.Equal(t, "sh", exe) - assert.Equal(t, []string{"-c", "sshcmd --args 1 user@foo.com"}, args) + assert.Equal(t, []string{"-c", "sshcmd --args 1 -- user@foo.com"}, args) } func TestSSHGetExeAndArgsPlinkCommand(t *testing.T) { @@ -307,7 +427,8 @@ func TestSSHGetExeAndArgsTortoisePlinkCommandCustomPort(t *testing.T) { func TestSSHGetLFSExeAndArgsWithCustomSSH(t *testing.T) { cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ - "GIT_SSH": "not-ssh", + "GIT_SSH": "not-ssh", + "GIT_SSH_VARIANT": "simple", }, nil)) require.Nil(t, err) @@ -345,7 +466,8 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHost(t *testing.T) { func TestSSHGetLFSExeAndArgsInvalidOptionsAsHostWithCustomSSH(t *testing.T) { cli, err := lfshttp.NewClient(lfshttp.NewContext(nil, map[string]string{ - "GIT_SSH": "not-ssh", + "GIT_SSH": "not-ssh", + "GIT_SSH_VARIANT": "simple", }, nil)) require.Nil(t, err)