diff --git a/lfsapi/client.go b/lfsapi/client.go index 72c604f3..1368934b 100644 --- a/lfsapi/client.go +++ b/lfsapi/client.go @@ -18,11 +18,31 @@ var UserAgent = "git-lfs" const MediaType = "application/vnd.git-lfs+json; charset=utf-8" func (c *Client) NewRequest(method string, e Endpoint, suffix string, body interface{}) (*http.Request, error) { - req, err := http.NewRequest(method, joinURL(e.Url, suffix), nil) + sshRes, err := c.resolveSSHEndpoint(e, method) + if err != nil { + tracerx.Printf("ssh: %s failed, error: %s, message: %s", + e.SshUserAndHost, err.Error(), sshRes.Message, + ) + + if len(sshRes.Message) > 0 { + return nil, errors.Wrap(err, sshRes.Message) + } + return nil, err + } + + prefix := e.Url + if len(sshRes.Href) > 0 { + prefix = sshRes.Href + } + + req, err := http.NewRequest(method, joinURL(prefix, suffix), nil) if err != nil { return req, err } + for key, value := range sshRes.Header { + req.Header.Set(key, value) + } req.Header.Set("Accept", MediaType) if body != nil { diff --git a/lfsapi/ssh.go b/lfsapi/ssh.go new file mode 100644 index 00000000..d59e3e4c --- /dev/null +++ b/lfsapi/ssh.go @@ -0,0 +1,110 @@ +package lfsapi + +import ( + "bytes" + "encoding/json" + "fmt" + "os/exec" + "path/filepath" + "strings" + + "github.com/rubyist/tracerx" +) + +func (c *Client) resolveSSHEndpoint(e Endpoint, method string) (sshAuthResponse, error) { + res := sshAuthResponse{} + if len(e.SshUserAndHost) == 0 { + return res, nil + } + + operation := "upload" + switch method { + case "GET", "HEAD": + operation = "download" + } + + tracerx.Printf("ssh: %s git-lfs-authenticate %s %s", + e.SshUserAndHost, e.SshPath, operation) + + exe, args := sshGetExeAndArgs(c.osEnv, e) + args = append(args, + fmt.Sprintf("git-lfs-authenticate %s %s", e.SshPath, operation)) + + cmd := exec.Command(exe, args...) + + // Save stdout and stderr in separate buffers + var outbuf, errbuf bytes.Buffer + cmd.Stdout = &outbuf + cmd.Stderr = &errbuf + + // Execute command + err := cmd.Start() + if err == nil { + err = cmd.Wait() + } + + // Processing result + if err != nil { + res.Message = strings.TrimSpace(errbuf.String()) + } else { + err = json.Unmarshal(outbuf.Bytes(), &res) + } + + return res, err +} + +type sshAuthResponse struct { + Message string `json:"-"` + Href string `json:"href"` + Header map[string]string `json:"header"` + ExpiresAt string `json:"expires_at"` +} + +// Return the executable name for ssh on this machine and the base args +// Base args includes port settings, user/host, everything pre the command to execute +func sshGetExeAndArgs(osEnv env, e Endpoint) (exe string, baseargs []string) { + isPlink := false + isTortoise := false + + ssh, _ := osEnv.Get("GIT_SSH") + sshCmd, _ := osEnv.Get("GIT_SSH_COMMAND") + cmdArgs := strings.Fields(sshCmd) + if len(cmdArgs) > 0 { + ssh = cmdArgs[0] + cmdArgs = cmdArgs[1:] + } + + if ssh == "" { + ssh = "ssh" + } else { + basessh := filepath.Base(ssh) + // Strip extension for easier comparison + if ext := filepath.Ext(basessh); len(ext) > 0 { + basessh = basessh[:len(basessh)-len(ext)] + } + isPlink = strings.EqualFold(basessh, "plink") + isTortoise = strings.EqualFold(basessh, "tortoiseplink") + } + + args := make([]string, 0, 4+len(cmdArgs)) + if len(cmdArgs) > 0 { + args = append(args, cmdArgs...) + } + + if isTortoise { + // TortoisePlink requires the -batch argument to behave like ssh/plink + args = append(args, "-batch") + } + + if len(e.SshPort) > 0 { + if isPlink || isTortoise { + args = append(args, "-P") + } else { + args = append(args, "-p") + } + args = append(args, e.SshPort) + } + args = append(args, e.SshUserAndHost) + + return ssh, args +} diff --git a/lfsapi/ssh_test.go b/lfsapi/ssh_test.go new file mode 100644 index 00000000..bfe03428 --- /dev/null +++ b/lfsapi/ssh_test.go @@ -0,0 +1,220 @@ +package lfsapi + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSHGetExeAndArgsSsh(t *testing.T) { + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": "", + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, "ssh", exe) + assert.Equal(t, []string{"user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsSshCustomPort(t *testing.T) { + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": "", + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, "ssh", exe) + assert.Equal(t, []string{"-p", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlink(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "plink.exe") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlinkCustomPort(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "plink") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlink(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "tortoiseplink.exe") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlinkCustomPort(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "tortoiseplink") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "", + "GIT_SSH": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsSshCommandPrecedence(t *testing.T) { + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "sshcmd", + "GIT_SSH": "bad", + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, "sshcmd", exe) + assert.Equal(t, []string{"user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsSshCommandArgs(t *testing.T) { + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "sshcmd --args 1", + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, "sshcmd", exe) + assert.Equal(t, []string{"--args", "1", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsSshCommandCustomPort(t *testing.T) { + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": "sshcmd", + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, "sshcmd", exe) + assert.Equal(t, []string{"-p", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlinkCommand(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "plink.exe") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsPlinkCommandCustomPort(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "plink") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-P", "8888", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlinkCommand(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "tortoiseplink.exe") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "user@foo.com"}, args) +} + +func TestSSHGetExeAndArgsTortoisePlinkCommandCustomPort(t *testing.T) { + plink := filepath.Join("Users", "joebloggs", "bin", "tortoiseplink") + + cli, err := NewClient(Env(map[string]string{ + "GIT_SSH_COMMAND": plink, + }), nil) + require.Nil(t, err) + + endpoint := cli.Endpoints.Endpoint("download", "") + endpoint.SshUserAndHost = "user@foo.com" + endpoint.SshPort = "8888" + + exe, args := sshGetExeAndArgs(cli.OSEnv(), endpoint) + assert.Equal(t, plink, exe) + assert.Equal(t, []string{"-batch", "-P", "8888", "user@foo.com"}, args) +}