diff --git a/lfsapi/ntlm.go b/lfsapi/ntlm.go index b3b457b8..b4359b5f 100644 --- a/lfsapi/ntlm.go +++ b/lfsapi/ntlm.go @@ -7,9 +7,7 @@ import ( "io/ioutil" "net/http" "net/url" - "strings" - "github.com/ThomsonReutersEikon/go-ntlm/ntlm" "github.com/git-lfs/git-lfs/errors" ) @@ -28,25 +26,8 @@ func (c *Client) doWithNTLM(req *http.Request, credHelper CredentialHelper, cred // If the status is 401 then we need to re-authenticate func (c *Client) ntlmReAuth(req *http.Request, credHelper CredentialHelper, creds Creds, retry bool) (*http.Response, error) { - body, err := rewoundRequestBody(req) - if err != nil { - return nil, err - } - req.Body = body - - chRes, challengeMsg, err := c.ntlmNegotiate(req, ntlmNegotiateMessage) - if err != nil { - return chRes, err - } - - body, err = rewoundRequestBody(req) - if err != nil { - return nil, err - } - req.Body = body - - res, err := c.ntlmChallenge(req, challengeMsg, creds) - if err != nil { + res, err := c.ntlmAuthenticateRequest(req, creds) + if err != nil && !errors.IsAuthError(err) { return res, err } @@ -67,9 +48,8 @@ func (c *Client) ntlmReAuth(req *http.Request, credHelper CredentialHelper, cred return res, nil } -func (c *Client) ntlmNegotiate(req *http.Request, message string) (*http.Response, []byte, error) { - req.Header.Add("Authorization", message) - res, err := c.do(req) +func (c *Client) ntlmSendType1Message(req *http.Request, message []byte) (*http.Response, []byte, error) { + res, err := c.ntlmSendMessage(req, message) if err != nil && !errors.IsAuthError(err) { return res, nil, err } @@ -81,56 +61,20 @@ func (c *Client) ntlmNegotiate(req *http.Request, message string) (*http.Respons return res, by, err } -func (c *Client) ntlmChallenge(req *http.Request, challengeBytes []byte, creds Creds) (*http.Response, error) { - challenge, err := ntlm.ParseChallengeMessage(challengeBytes) - if err != nil { - return nil, err - } - - session, err := c.ntlmClientSession(creds) - if err != nil { - return nil, err - } - - session.ProcessChallengeMessage(challenge) - authenticate, err := session.GenerateAuthenticateMessage() - if err != nil { - return nil, err - } - - authMsg := base64.StdEncoding.EncodeToString(authenticate.Bytes()) - req.Header.Set("Authorization", "NTLM "+authMsg) - return c.do(req) +func (c *Client) ntlmSendType3Message(req *http.Request, authenticate []byte) (*http.Response, error) { + return c.ntlmSendMessage(req, authenticate) } -func (c *Client) ntlmClientSession(creds Creds) (ntlm.ClientSession, error) { - c.ntlmMu.Lock() - defer c.ntlmMu.Unlock() - - splits := strings.Split(creds["username"], "\\") - if len(splits) != 2 { - return nil, fmt.Errorf("Your user name must be of the form DOMAIN\\user. It is currently %s", creds["username"]) - } - - domain := strings.ToUpper(splits[0]) - username := splits[1] - - if c.ntlmSessions == nil { - c.ntlmSessions = make(map[string]ntlm.ClientSession) - } - - if ses, ok := c.ntlmSessions[domain]; ok { - return ses, nil - } - - session, err := ntlm.CreateClientSession(ntlm.Version2, ntlm.ConnectionOrientedMode) +func (c *Client) ntlmSendMessage(req *http.Request, message []byte) (*http.Response, error) { + body, err := rewoundRequestBody(req) if err != nil { return nil, err } + req.Body = body - session.SetUserInfo(username, creds["password"], strings.ToUpper(splits[0])) - c.ntlmSessions[domain] = session - return session, nil + msg := base64.StdEncoding.EncodeToString(message) + req.Header.Set("Authorization", "NTLM "+msg) + return c.do(req) } func parseChallengeResponse(res *http.Response) ([]byte, error) { @@ -162,5 +106,3 @@ func rewoundRequestBody(req *http.Request) (io.ReadCloser, error) { _, err := body.Seek(0, io.SeekStart) return body, err } - -const ntlmNegotiateMessage = "NTLM TlRMTVNTUAABAAAAB7IIogwADAAzAAAACwALACgAAAAKAAAoAAAAD1dJTExISS1NQUlOTk9SVEhBTUVSSUNB" diff --git a/lfsapi/ntlm_auth.go b/lfsapi/ntlm_auth.go new file mode 100644 index 00000000..48786dc4 --- /dev/null +++ b/lfsapi/ntlm_auth.go @@ -0,0 +1,72 @@ +package lfsapi + +import ( + "encoding/base64" + "fmt" + "net/http" + "strings" + + "github.com/ThomsonReutersEikon/go-ntlm/ntlm" +) + +func (c *Client) ntlmAuthenticateRequest(req *http.Request, creds Creds) (*http.Response, error) { + negotiate, err := base64.StdEncoding.DecodeString(ntlmNegotiateMessage) + if err != nil { + return nil, err + } + + chRes, challengeMsg, err := c.ntlmSendType1Message(req, negotiate) + if err != nil { + return chRes, err + } + + challenge, err := ntlm.ParseChallengeMessage(challengeMsg) + if err != nil { + return nil, err + } + + session, err := c.ntlmClientSession(creds) + if err != nil { + return nil, err + } + + session.ProcessChallengeMessage(challenge) + authenticate, err := session.GenerateAuthenticateMessage() + if err != nil { + return nil, err + } + + return c.ntlmSendType3Message(req, authenticate.Bytes()) +} + +func (c *Client) ntlmClientSession(creds Creds) (ntlm.ClientSession, error) { + c.ntlmMu.Lock() + defer c.ntlmMu.Unlock() + + splits := strings.Split(creds["username"], "\\") + if len(splits) != 2 { + return nil, fmt.Errorf("Your user name must be of the form DOMAIN\\user. It is currently %s", creds["username"]) + } + + domain := strings.ToUpper(splits[0]) + username := splits[1] + + if c.ntlmSessions == nil { + c.ntlmSessions = make(map[string]ntlm.ClientSession) + } + + if ses, ok := c.ntlmSessions[domain]; ok { + return ses, nil + } + + session, err := ntlm.CreateClientSession(ntlm.Version2, ntlm.ConnectionOrientedMode) + if err != nil { + return nil, err + } + + session.SetUserInfo(username, creds["password"], strings.ToUpper(splits[0])) + c.ntlmSessions[domain] = session + return session, nil +} + +const ntlmNegotiateMessage = "TlRMTVNTUAABAAAAB7IIogwADAAzAAAACwALACgAAAAKAAAoAAAAD1dJTExISS1NQUlOTk9SVEhBTUVSSUNB" diff --git a/lfsapi/ntlm_test.go b/lfsapi/ntlm_test.go index ab9eba94..89886fdb 100644 --- a/lfsapi/ntlm_test.go +++ b/lfsapi/ntlm_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestNTLMAuth(t *testing.T) { +func TestNtlmAuth(t *testing.T) { session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode) require.Nil(t, err) session.SetUserInfo("ntlmuser", "ntlmpass", "NTLMDOMAIN") @@ -36,12 +36,21 @@ func TestNTLMAuth(t *testing.T) { assert.Equal(t, "ntlm", string(by)) } - switch authHeader { - case "": + switch called { + case 1: w.Header().Set("Www-Authenticate", "ntlm") w.WriteHeader(401) - case ntlmNegotiateMessage: + case 2: assert.True(t, strings.HasPrefix(req.Header.Get("Authorization"), "NTLM ")) + neg := authHeader[5:] // strip "ntlm " prefix + _, err := base64.StdEncoding.DecodeString(neg) + if !assert.Nil(t, err) { + t.Logf("neg base64 error: %+v", err) + w.WriteHeader(500) + return + } + + // ntlm implementation can't currently parse the negotiate message ch, err := session.GenerateChallengeMessage() if !assert.Nil(t, err) { t.Logf("challenge gen error: %+v", err) @@ -62,13 +71,21 @@ func TestNTLMAuth(t *testing.T) { return } - _, err = ntlm.ParseAuthenticateMessage(val, 2) + authMsg, err := ntlm.ParseAuthenticateMessage(val, 2) if !assert.Nil(t, err) { t.Logf("auth parse error: %+v", err) w.WriteHeader(500) return } + + err = session.ProcessAuthenticateMessage(authMsg) + if !assert.Nil(t, err) { + t.Logf("auth process error: %+v", err) + w.WriteHeader(500) + return + } w.WriteHeader(200) + } })) defer srv.Close()