diff --git a/lfshttp/ssh.go b/lfshttp/ssh.go index 5fe718a3..8fe8f3df 100644 --- a/lfshttp/ssh.go +++ b/lfshttp/ssh.go @@ -81,7 +81,7 @@ func (c *sshAuthClient) Resolve(e Endpoint, method string) (sshAuthResponse, err return res, nil } - exe, args := sshGetLFSExeAndArgs(c.os, e, method) + exe, args := sshGetLFSExeAndArgs(c.os, c.git, e, method) cmd := exec.Command(exe, args...) // Save stdout and stderr in separate buffers @@ -124,8 +124,8 @@ func sshFormatArgs(cmd string, args []string, needShell bool) (string, []string) return "sh", []string{"-c", joined} } -func sshGetLFSExeAndArgs(osEnv config.Environment, e Endpoint, method string) (string, []string) { - exe, args, needShell := sshGetExeAndArgs(osEnv, e) +func sshGetLFSExeAndArgs(osEnv config.Environment, gitEnv config.Environment, e Endpoint, method string) (string, []string) { + exe, args, needShell := sshGetExeAndArgs(osEnv, gitEnv, e) operation := endpointOperation(e, method) args = append(args, fmt.Sprintf("git-lfs-authenticate %s %s", e.SshPath, operation)) exe, args = sshFormatArgs(exe, args, needShell) @@ -133,9 +133,22 @@ func sshGetLFSExeAndArgs(osEnv config.Environment, e Endpoint, method string) (s return exe, args } +// Parse command, and if it looks like a valid command, return the ssh binary +// name, the command to run, and whether we need a shell. If not, return +// existing as the ssh binary name. +func sshParseShellCommand(command string, existing string) (ssh string, cmd string, needShell bool) { + ssh = existing + if cmdArgs := tools.QuotedFields(command); len(cmdArgs) > 0 { + needShell = true + ssh = cmdArgs[0] + cmd = command + } + return +} + // Return the executable name for ssh on this machine and the base args // Base args includes port settings, user/host, everything pre the command to execute -func sshGetExeAndArgs(osEnv config.Environment, e Endpoint) (exe string, baseargs []string, needShell bool) { +func sshGetExeAndArgs(osEnv config.Environment, gitEnv config.Environment, e Endpoint) (exe string, baseargs []string, needShell bool) { var cmd string isPlink := false @@ -143,15 +156,11 @@ func sshGetExeAndArgs(osEnv config.Environment, e Endpoint) (exe string, basearg ssh, _ := osEnv.Get("GIT_SSH") sshCmd, _ := osEnv.Get("GIT_SSH_COMMAND") - cmdArgs := tools.QuotedFields(sshCmd) - if len(cmdArgs) > 0 { - needShell = true - ssh = cmdArgs[0] - cmd = sshCmd - } + ssh, cmd, needShell = sshParseShellCommand(sshCmd, ssh) if ssh == "" { - ssh = defaultSSHCmd + sshCmd, _ := gitEnv.Get("core.sshcommand") + ssh, cmd, needShell = sshParseShellCommand(sshCmd, defaultSSHCmd) } if cmd == "" { diff --git a/lfshttp/ssh_test.go b/lfshttp/ssh_test.go index 18a0ee0a..dd455e0b 100644 --- a/lfshttp/ssh_test.go +++ b/lfshttp/ssh_test.go @@ -220,7 +220,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPath = "user/repo" - exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), endpoint, "GET") + exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint, "GET") assert.Equal(t, "ssh", exe) assert.Equal(t, []string{ "--", @@ -228,7 +228,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) { "git-lfs-authenticate user/repo download", }, args) - exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), endpoint, "HEAD") + exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint, "HEAD") assert.Equal(t, "ssh", exe) assert.Equal(t, []string{ "--", @@ -237,7 +237,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) { }, args) // this is going by endpoint.Operation, implicitly set by Endpoint() on L15. - exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), endpoint, "POST") + exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint, "POST") assert.Equal(t, "ssh", exe) assert.Equal(t, []string{ "--", @@ -246,7 +246,7 @@ func TestSSHGetLFSExeAndArgs(t *testing.T) { }, args) endpoint.Operation = "upload" - exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), endpoint, "POST") + exe, args = sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint, "POST") assert.Equal(t, "ssh", exe) assert.Equal(t, []string{ "--", @@ -265,7 +265,7 @@ func TestSSHGetExeAndArgsSsh(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "ssh", exe) assert.Equal(t, []string{"--", "user@foo.com"}, args) } @@ -281,7 +281,7 @@ func TestSSHGetExeAndArgsSshCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "ssh", exe) assert.Equal(t, []string{"-p", "8888", "--", "user@foo.com"}, args) } @@ -298,7 +298,7 @@ func TestSSHGetExeAndArgsPlink(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, plink, exe) assert.Equal(t, []string{"user@foo.com"}, args) } @@ -316,7 +316,7 @@ func TestSSHGetExeAndArgsPlinkCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, plink, exe) assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) } @@ -333,7 +333,7 @@ func TestSSHGetExeAndArgsTortoisePlink(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, plink, exe) assert.Equal(t, []string{"-batch", "user@foo.com"}, args) } @@ -351,7 +351,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, plink, exe) assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) } @@ -366,7 +366,7 @@ func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", "sshcmd user@foo.com"}, args) } @@ -380,7 +380,7 @@ func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", "sshcmd --args 1 user@foo.com"}, args) } @@ -394,7 +394,7 @@ func TestSSHGetExeAndArgsSshCommandArgsWithMixedQuotes(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' user@foo.com"}, args) } @@ -409,11 +409,55 @@ func TestSSHGetExeAndArgsSshCommandCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", "sshcmd -p 8888 user@foo.com"}, args) } +func TestSSHGetExeAndArgsCoreSshCommand(t *testing.T) { + cli, err := NewClient(NewContext(nil, map[string]string{ + "GIT_SSH_COMMAND": "sshcmd --args 2", + }, map[string]string{ + "core.sshcommand": "sshcmd --args 1", + })) + require.Nil(t, err) + + endpoint := Endpoint{Operation: "download"} + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) + assert.Equal(t, "sh", exe) + assert.Equal(t, []string{"-c", "sshcmd --args 2 user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsCoreSshCommandArgsWithMixedQuotes(t *testing.T) { + cli, err := NewClient(NewContext(nil, nil, map[string]string{ + "core.sshcommand": "sshcmd foo 'bar \"baz\"'", + })) + require.Nil(t, err) + + endpoint := Endpoint{Operation: "download"} + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) + assert.Equal(t, "sh", exe) + assert.Equal(t, []string{"-c", "sshcmd foo 'bar \"baz\"' user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsConfigVersusEnv(t *testing.T) { + cli, err := NewClient(NewContext(nil, nil, map[string]string{ + "core.sshcommand": "sshcmd --args 1", + })) + require.Nil(t, err) + + endpoint := Endpoint{Operation: "download"} + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) + assert.Equal(t, "sh", exe) + assert.Equal(t, []string{"-c", "sshcmd --args 1 user@foo.com"}, args) +} + func TestSSHGetLFSExeAndArgsWithCustomSSH(t *testing.T) { cli, err := NewClient(NewContext(nil, map[string]string{ "GIT_SSH": "not-ssh", @@ -429,7 +473,7 @@ func TestSSHGetLFSExeAndArgsWithCustomSSH(t *testing.T) { assert.Equal(t, "git@host.com", e.SshUserAndHost) assert.Equal(t, "repo", e.SshPath) - exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), e, "GET") + exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), e, "GET") assert.Equal(t, "not-ssh", exe) assert.Equal(t, []string{"-p", "12345", "git@host.com", "git-lfs-authenticate repo download"}, args) } @@ -447,7 +491,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHost(t *testing.T) { assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "repo", e.SshPath) - exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), e, "GET") + exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), e, "GET") assert.Equal(t, "ssh", exe) assert.Equal(t, []string{"--", "-oProxyCommand=gnome-calculator", "git-lfs-authenticate repo download"}, args) } @@ -467,7 +511,7 @@ func TestSSHGetLFSExeAndArgsInvalidOptionsAsHostWithCustomSSH(t *testing.T) { assert.Equal(t, "--oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "repo", e.SshPath) - exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), e, "GET") + exe, args := sshGetLFSExeAndArgs(cli.OSEnv(), cli.GitEnv(), e, "GET") assert.Equal(t, "not-ssh", exe) assert.Equal(t, []string{"oProxyCommand=gnome-calculator", "git-lfs-authenticate repo download"}, args) } @@ -485,7 +529,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsHost(t *testing.T) { assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshUserAndHost) assert.Equal(t, "", e.SshPath) - exe, args, needShell := sshGetExeAndArgs(cli.OSEnv(), e) + exe, args, needShell := sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), e) assert.Equal(t, "ssh", exe) assert.Equal(t, []string{"--", "-oProxyCommand=gnome-calculator"}, args) assert.Equal(t, false, needShell) @@ -504,7 +548,7 @@ func TestSSHGetExeAndArgsInvalidOptionsAsPath(t *testing.T) { assert.Equal(t, "git@git-host.com", e.SshUserAndHost) assert.Equal(t, "-oProxyCommand=gnome-calculator", e.SshPath) - exe, args, needShell := sshGetExeAndArgs(cli.OSEnv(), e) + exe, args, needShell := sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), e) assert.Equal(t, "ssh", exe) assert.Equal(t, []string{"--", "git@git-host.com"}, args) assert.Equal(t, false, needShell) @@ -543,7 +587,7 @@ func TestSSHGetExeAndArgsPlinkCommand(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", plink + " user@foo.com"}, args) } @@ -560,7 +604,7 @@ func TestSSHGetExeAndArgsPlinkCommandCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", plink + " -P 8888 user@foo.com"}, args) } @@ -576,7 +620,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommand(t *testing.T) { endpoint := Endpoint{Operation: "download"} endpoint.SshUserAndHost = "user@foo.com" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", plink + " -batch user@foo.com"}, args) } @@ -593,7 +637,7 @@ func TestSSHGetExeAndArgsTortoisePlinkCommandCustomPort(t *testing.T) { endpoint.SshUserAndHost = "user@foo.com" endpoint.SshPort = "8888" - exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), endpoint)) + exe, args := sshFormatArgs(sshGetExeAndArgs(cli.OSEnv(), cli.GitEnv(), endpoint)) assert.Equal(t, "sh", exe) assert.Equal(t, []string{"-c", plink + " -batch -P 8888 user@foo.com"}, args) } diff --git a/t/cmd/lfstest-customadapter.go b/t/cmd/lfstest-customadapter.go index e0e06dff..c2decdf6 100644 --- a/t/cmd/lfstest-customadapter.go +++ b/t/cmd/lfstest-customadapter.go @@ -117,7 +117,11 @@ func performDownload(apiClient *lfsapi.Client, oid string, size int64, a *action res, err := apiClient.DoWithAuth("origin", req) if err != nil { - sendTransferError(oid, res.StatusCode, err.Error(), writer, errWriter) + statusCode := 6 + if res != nil { + statusCode = res.StatusCode + } + sendTransferError(oid, statusCode, err.Error(), writer, errWriter) return } defer res.Body.Close()