diff --git a/lfs/client.go b/lfs/client.go index 4afca852..a134860b 100644 --- a/lfs/client.go +++ b/lfs/client.go @@ -127,14 +127,18 @@ func Download(oid string) (io.ReadCloser, int64, error) { } LogTransfer("lfs.data.download", res) - buf, _ := ioutil.ReadAll(res.Body) - body := myReader{bytes.NewBuffer(buf)} - - //We most close the body to ensure the http connection is kept alive - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - - return body, res.ContentLength, nil + if(Config.NtlmAccess()){ + buf, _ := ioutil.ReadAll(res.Body) + body := myReader{bytes.NewBuffer(buf)} + + //We must close the body to ensure the http connection is kept alive + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + return body, res.ContentLength, nil + } + + return res.Body, res.ContentLength, nil } type byteCloser struct { @@ -175,15 +179,19 @@ func DownloadObject(obj *objectResource) (io.ReadCloser, int64, error) { return nil, 0, newRetriableError(err) } LogTransfer("lfs.data.download", res) + + if(Config.NtlmAccess()){ + buf, _ := ioutil.ReadAll(res.Body) + body := myReader{bytes.NewBuffer(buf)} + + //We must close the body to ensure the http connection is kept alive + io.Copy(ioutil.Discard, res.Body) + res.Body.Close() + + return body, res.ContentLength, nil + } - buf, _ := ioutil.ReadAll(res.Body) - body := myReader{bytes.NewBuffer(buf)} - - //We most close the body to ensure the http connection is kept alive - io.Copy(ioutil.Discard, res.Body) - res.Body.Close() - - return body, res.ContentLength, nil + return res.Body, res.ContentLength, nil } func (b *byteCloser) Close() error { @@ -229,13 +237,11 @@ func Batch(objects []*objectResource, operation string) ([]*objectResource, erro } if IsAuthError(err){ - if strings.ToLower(res.Header["Www-Authenticate"][0][0:4]) == "ntlm" { - Config.SetAccess("ntlm") - tracerx.Printf("api: response indicates ntlm, submitting with ntlm auth") + if isNtlmRequest(res) { + toggleAuthType("ntlm", "api: response indicates ntlm, submitting with %s auth") return Batch(objects, operation) } else { - Config.SetAccess("basic") - tracerx.Printf("api: batch not authorized, submitting with auth") + toggleAuthType("basic", "api: batch not authorized, submitting with %s auth") return Batch(objects, operation) } } @@ -293,13 +299,11 @@ func UploadCheck(oidPath string) (*objectResource, error) { if err != nil { if IsAuthError(err) { - if strings.ToLower(res.Header["Www-Authenticate"][0][0:4]) == "ntlm" { - Config.SetAccess("ntlm") - tracerx.Printf("api: response indicates ntlm, submitting with ntlm auth") + if isNtlmRequest(res) { + toggleAuthType("ntlm", "api: response indicates ntlm, submitting with %s auth") return UploadCheck(oidPath) } else{ - Config.SetAccess("basic") - tracerx.Printf("api: upload check not authorized, submitting with auth") + toggleAuthType("basic", "api: batch not authorized, submitting with %s auth") return UploadCheck(oidPath) } } @@ -494,9 +498,8 @@ func doHttpRequest(req *http.Request, creds Creds) (*http.Response, error) { } if err != nil { - if IsAuthError(err) && res.Header["Www-Authenticate"][0][0:4] == "ntlm" { - Config.SetAccess("ntlm") - tracerx.Printf("api: response indicates ntlm, submitting with ntlm auth") + if IsAuthError(err) && isNtlmRequest(res) { + toggleAuthType("ntlm", "api: response indicates ntlm, submitting with %s auth") doHttpRequest(req, creds) } else { err = Error(err) @@ -745,18 +748,29 @@ func setRequestAuthFromUrl(req *http.Request, u *url.URL) bool { return false } +func isNtlmRequest(res *http.Response)(bool){ + + header := res.Header.Get("Www-Authenticate") + return strings.HasPrefix(header, "ntlm") +} + +func toggleAuthType(authType string, message string){ + Config.SetAccess(authType) + tracerx.Printf(message, authType) +} + func setRequestAuth(req *http.Request, user, pass string) { if(Config.NtlmAccess()){ - //no-op. The NTLM manager will handle auth headers - } else { - if len(user) == 0 && len(pass) == 0 { - return - } - - token := fmt.Sprintf("%s:%s", user, pass) - auth := "Basic " + base64.URLEncoding.EncodeToString([]byte(token)) - req.Header.Set("Authorization", auth) + return } + + if len(user) == 0 && len(pass) == 0 { + return + } + + token := fmt.Sprintf("%s:%s", user, pass) + auth := "Basic " + base64.URLEncoding.EncodeToString([]byte(token)) + req.Header.Set("Authorization", auth) } func setErrorResponseContext(err error, res *http.Response) { diff --git a/lfs/config.go b/lfs/config.go index 9e9b2427..ad6e2eca 100644 --- a/lfs/config.go +++ b/lfs/config.go @@ -144,7 +144,7 @@ func (c *Configuration) BatchTransfer() bool { } func (c *Configuration) NtlmAccess() bool { - return c.Access() == "ntlm" + return c.Access() == "none" } // PrivateAccess will retrieve the access value and return true if @@ -152,7 +152,7 @@ func (c *Configuration) NtlmAccess() bool { // access, the http requests for the batch api will fetch the credentials // before running, otherwise the request will run without credentials. func (c *Configuration) PrivateAccess() bool { - return c.Access() != "none" && c.Access() != "ntlm" + return c.Access() != "none" } // Access returns the access auth type. diff --git a/lfs/ntlm.go b/lfs/ntlm.go index 7ceab522..dd7aed3d 100644 --- a/lfs/ntlm.go +++ b/lfs/ntlm.go @@ -3,6 +3,7 @@ package lfs import ( "bytes" "encoding/base64" + "errors" "github.com/ThomsonReutersEikon/go-ntlm/ntlm" "io" "io/ioutil" @@ -10,35 +11,71 @@ import ( "strings" ) -func (c *Configuration) NTLMSession(creds Creds) ntlm.ClientSession { +func (c *Configuration) ntlmClientSession(creds Creds) (ntlm.ClientSession, error) { if c.ntlmSession != nil { - return c.ntlmSession + return c.ntlmSession, nil } splits := strings.Split(creds["username"], "\\") - var session, _ = ntlm.CreateClientSession(ntlm.Version2, ntlm.ConnectionOrientedMode) + + if(len(splits) != 2){ + return nil, errors.New("Your user name must be of the form DOMAIN\\user.") + } + + var session, err = ntlm.CreateClientSession(ntlm.Version2, ntlm.ConnectionOrientedMode) + + if(err != nil){ + return nil, err + } + session.SetUserInfo(splits[1], creds["password"], strings.ToUpper(splits[0])) c.ntlmSession = session - return session + return session, nil } func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { - handReq := cloneRequest(request) - res, nil := InitHandShake(handReq) + handReq, err := cloneRequest(request) + if(err != nil){ + return nil, err + } + + + res, err := InitHandShake(handReq) + + if(err != nil && res == nil){ + return res, err + } //If the status is 401 then we need to re-authenticate, otherwise it was successful if res.StatusCode == 401 { - creds, _ := getCredsForNTLM(request) + creds, err := getCredsForNTLM(request) + if(err != nil){ + return nil, err + } - negotiateReq := cloneRequest(request) - challengeMessage := negotiate(negotiateReq, getNegotiateMessage()) + negotiateReq, err := cloneRequest(request) + if(err != nil){ + return nil, err + } - challengeReq := cloneRequest(request) - res, _ := challenge(challengeReq, challengeMessage, creds) + challengeMessage, err := negotiate(negotiateReq, getNegotiateMessage()) + if(err != nil){ + return nil, err + } + + challengeReq, err := cloneRequest(request) + if(err != nil){ + return nil, err + } + + res, err := challenge(challengeReq, challengeMessage, creds) + if(err != nil){ + return res, err + } //If the status is 401 then we need to re-authenticate if res.StatusCode == 401 && retry == true { @@ -49,91 +86,93 @@ func DoNTLMRequest(request *http.Request, retry bool) (*http.Response, error) { return res, nil } - return res, nil + return res, err } func InitHandShake(request *http.Request) (*http.Response, error){ - var response, err = Config.HttpClient().Do(request) - - if err != nil { - return nil, Error(err) - } - - return response, nil + return Config.HttpClient().Do(request) } -func negotiate(request *http.Request, message string) []byte{ +func negotiate(request *http.Request, message string) ([]byte, error){ request.Header.Add("Authorization", message) - var response, err = Config.HttpClient().Do(request) + var res, err = Config.HttpClient().Do(request) + defer io.Copy(ioutil.Discard, res.Body) + defer res.Body.Close() if err != nil{ - panic(err.Error()) + return nil, err } - ret := parseChallengeResponse(response) - - //Always close negotiate to keep the connection alive - //We never return the response from negotiate so we - //can't trust decodeApiResponse to close it - io.Copy(ioutil.Discard, response.Body) - response.Body.Close() - - return ret; + return parseChallengeResponse(res) } func challenge(request *http.Request, challengeBytes []byte, creds Creds) (*http.Response, error){ challenge, err := ntlm.ParseChallengeMessage(challengeBytes) if err != nil { - return nil, Error(err) + return nil, err } - Config.NTLMSession(creds).ProcessChallengeMessage(challenge) - authenticate, err := Config.NTLMSession(creds).GenerateAuthenticateMessage() + session, err := Config.ntlmClientSession(creds) + if err != nil { + return nil, err + } + + session.ProcessChallengeMessage(challenge) + authenticate, err := session.GenerateAuthenticateMessage() if err != nil { - return nil, Error(err) + return nil, err } authenticateMessage := concatS("NTLM ", base64.StdEncoding.EncodeToString(authenticate.Bytes())) - request.Header.Add("Authorization", authenticateMessage) - response, err := Config.HttpClient().Do(request) - - return response, nil + return Config.HttpClient().Do(request) } -func parseChallengeResponse(response *http.Response) []byte{ - if headers, ok := response.Header["Www-Authenticate"]; ok{ +func parseChallengeResponse(response *http.Response) ([]byte, error){ + header := response.Header.Get("Www-Authenticate") - //parse out the "NTLM " at the beginning of the response - challenge := headers[0][5:] - val, err := base64.StdEncoding.DecodeString(challenge) - - if err != nil{ - panic(err.Error()) - } - return []byte(val) + //parse out the "NTLM " at the beginning of the response + challenge := header[5:] + val, err := base64.StdEncoding.DecodeString(challenge) + + if err != nil{ + return nil, err } - - panic("www-Authenticate header is not present") + return []byte(val), nil } -func cloneRequest(request *http.Request) *http.Request { +func cloneRequest(request *http.Request) (*http.Request, error) { var rdr1, rdr2 myReader var clonedReq *http.Request + var err error if request.Body != nil { //If we have a body (POST/PUT etc.) //We need to do some magic to copy the request without closing the body stream - buf, _ := ioutil.ReadAll(request.Body) + buf, err := ioutil.ReadAll(request.Body) + + if(err != nil){ + return nil, err + } + rdr1 = myReader{bytes.NewBuffer(buf)} rdr2 = myReader{bytes.NewBuffer(buf)} request.Body = rdr2 // OK since rdr2 implements the io.ReadCloser interface - clonedReq, _ = http.NewRequest(request.Method, request.URL.String(), rdr1) - }else{ - clonedReq, _ = http.NewRequest(request.Method, request.URL.String(), nil) + clonedReq, err = http.NewRequest(request.Method, request.URL.String(), rdr1) + + if(err != nil){ + return nil, err + } + + } else { + clonedReq, err = http.NewRequest(request.Method, request.URL.String(), nil) + + if(err != nil){ + return nil, err + } } for k, v := range request.Header { @@ -142,7 +181,7 @@ func cloneRequest(request *http.Request) *http.Request { clonedReq.ContentLength = request.ContentLength - return clonedReq + return clonedReq, nil } func getNegotiateMessage() string{