32da88f982
We convert several log messages to use the uppercase acronym "HTTP" instead of the lowercase variant. We also edit one message in lfshttp/client.go to remove an idiosyncratic prefix of the full source code file name, and edit the message slightly for greater clarity.
412 lines
11 KiB
Go
412 lines
11 KiB
Go
package lfshttp
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// redirectTest is the request payload used by TestClientRedirect. The test
// servers decode it from the JSON request body and compare the Test field to
// confirm that the body reached them intact.
type redirectTest struct {
	// Test labels the scenario that produced the request (for example
	// "Local" or "External").
	Test string
}
|
|
|
|
func TestClientRedirect(t *testing.T) {
|
|
var srv3Https, srv3Http string
|
|
|
|
var called1 uint32
|
|
var called2 uint32
|
|
var called3 uint32
|
|
srv3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&called3, 1)
|
|
t.Logf("srv3 req %s %s", r.Method, r.URL.Path)
|
|
assert.Equal(t, "POST", r.Method)
|
|
|
|
switch r.URL.Path {
|
|
case "/upgrade":
|
|
assert.Equal(t, "auth", r.Header.Get("Authorization"))
|
|
assert.Equal(t, "1", r.Header.Get("A"))
|
|
w.Header().Set("Location", srv3Https+"/upgraded")
|
|
w.WriteHeader(301)
|
|
case "/upgraded":
|
|
// Since srv3 listens on both a TLS-enabled socket and a
|
|
// TLS-disabled one, they are two different hosts.
|
|
// Ensure that, even though this is a "secure" upgrade,
|
|
// the authorization header is stripped.
|
|
assert.Equal(t, "", r.Header.Get("Authorization"))
|
|
assert.Equal(t, "1", r.Header.Get("A"))
|
|
|
|
case "/downgrade":
|
|
assert.Equal(t, "auth", r.Header.Get("Authorization"))
|
|
assert.Equal(t, "1", r.Header.Get("A"))
|
|
w.Header().Set("Location", srv3Http+"/404")
|
|
w.WriteHeader(301)
|
|
|
|
default:
|
|
w.WriteHeader(404)
|
|
}
|
|
}))
|
|
|
|
srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&called2, 1)
|
|
t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
|
|
assert.Equal(t, "POST", r.Method)
|
|
|
|
switch r.URL.Path {
|
|
case "/ok":
|
|
assert.Equal(t, "", r.Header.Get("Authorization"))
|
|
assert.Equal(t, "1", r.Header.Get("A"))
|
|
body := &redirectTest{}
|
|
err := json.NewDecoder(r.Body).Decode(body)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "External", body.Test)
|
|
|
|
w.WriteHeader(200)
|
|
default:
|
|
w.WriteHeader(404)
|
|
}
|
|
}))
|
|
|
|
srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&called1, 1)
|
|
t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
|
|
assert.Equal(t, "POST", r.Method)
|
|
|
|
switch r.URL.Path {
|
|
case "/local":
|
|
w.Header().Set("Location", "/ok")
|
|
w.WriteHeader(307)
|
|
case "/external":
|
|
w.Header().Set("Location", srv2.URL+"/ok")
|
|
w.WriteHeader(307)
|
|
case "/ok":
|
|
assert.Equal(t, "auth", r.Header.Get("Authorization"))
|
|
assert.Equal(t, "1", r.Header.Get("A"))
|
|
body := &redirectTest{}
|
|
err := json.NewDecoder(r.Body).Decode(body)
|
|
assert.Nil(t, err)
|
|
assert.Equal(t, "Local", body.Test)
|
|
|
|
w.WriteHeader(200)
|
|
default:
|
|
w.WriteHeader(404)
|
|
}
|
|
}))
|
|
defer srv1.Close()
|
|
defer srv2.Close()
|
|
defer srv3.Close()
|
|
|
|
srv3InsecureListener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.Nil(t, err)
|
|
|
|
go http.Serve(srv3InsecureListener, srv3.Config.Handler)
|
|
defer srv3InsecureListener.Close()
|
|
|
|
srv3Https = srv3.URL
|
|
srv3Http = fmt.Sprintf("http://%s", srv3InsecureListener.Addr().String())
|
|
|
|
c, err := NewClient(NewContext(nil, nil, map[string]string{
|
|
fmt.Sprintf("http.%s.sslverify", srv3Https): "false",
|
|
fmt.Sprintf("http.%s/.sslverify", srv3Https): "false",
|
|
fmt.Sprintf("http.%s.sslverify", srv3Http): "false",
|
|
fmt.Sprintf("http.%s/.sslverify", srv3Http): "false",
|
|
fmt.Sprintf("http.sslverify"): "false",
|
|
}))
|
|
require.Nil(t, err)
|
|
|
|
// local redirect
|
|
req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
|
|
require.Nil(t, err)
|
|
req.Header.Set("Authorization", "auth")
|
|
req.Header.Set("A", "1")
|
|
|
|
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
|
|
|
|
res, err := c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 2, called1)
|
|
assert.EqualValues(t, 0, called2)
|
|
|
|
// external redirect
|
|
req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
|
|
require.Nil(t, err)
|
|
req.Header.Set("Authorization", "auth")
|
|
req.Header.Set("A", "1")
|
|
|
|
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
|
|
|
|
res, err = c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 3, called1)
|
|
assert.EqualValues(t, 1, called2)
|
|
|
|
// http -> https (secure upgrade)
|
|
|
|
req, err = http.NewRequest("POST", srv3Http+"/upgrade", nil)
|
|
require.Nil(t, err)
|
|
req.Header.Set("Authorization", "auth")
|
|
req.Header.Set("A", "1")
|
|
|
|
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "http->https"}))
|
|
|
|
res, err = c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 2, atomic.LoadUint32(&called3))
|
|
|
|
// https -> http (insecure downgrade)
|
|
|
|
req, err = http.NewRequest("POST", srv3Https+"/downgrade", nil)
|
|
require.Nil(t, err)
|
|
req.Header.Set("Authorization", "auth")
|
|
req.Header.Set("A", "1")
|
|
|
|
require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "https->http"}))
|
|
|
|
_, err = c.Do(req)
|
|
assert.EqualError(t, err, "refusing insecure redirect: HTTPS to HTTP")
|
|
}
|
|
|
|
func TestNewClient(t *testing.T) {
|
|
c, err := NewClient(NewContext(nil, nil, map[string]string{
|
|
"lfs.dialtimeout": "151",
|
|
"lfs.keepalive": "152",
|
|
"lfs.tlstimeout": "153",
|
|
"lfs.concurrenttransfers": "154",
|
|
}))
|
|
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 151, c.DialTimeout)
|
|
assert.Equal(t, 152, c.KeepaliveTimeout)
|
|
assert.Equal(t, 153, c.TLSTimeout)
|
|
assert.Equal(t, 154, c.ConcurrentTransfers)
|
|
}
|
|
|
|
func TestNewClientWithGitSSLVerify(t *testing.T) {
|
|
c, err := NewClient(nil)
|
|
assert.Nil(t, err)
|
|
assert.False(t, c.SkipSSLVerify)
|
|
|
|
for _, value := range []string{"true", "1", "t"} {
|
|
c, err = NewClient(NewContext(nil, nil, map[string]string{
|
|
"http.sslverify": value,
|
|
}))
|
|
t.Logf("http.sslverify: %q", value)
|
|
assert.Nil(t, err)
|
|
assert.False(t, c.SkipSSLVerify)
|
|
}
|
|
|
|
for _, value := range []string{"false", "0", "f"} {
|
|
c, err = NewClient(NewContext(nil, nil, map[string]string{
|
|
"http.sslverify": value,
|
|
}))
|
|
t.Logf("http.sslverify: %q", value)
|
|
assert.Nil(t, err)
|
|
assert.True(t, c.SkipSSLVerify)
|
|
}
|
|
}
|
|
|
|
func TestNewClientWithOSSSLVerify(t *testing.T) {
|
|
c, err := NewClient(nil)
|
|
assert.Nil(t, err)
|
|
assert.False(t, c.SkipSSLVerify)
|
|
|
|
for _, value := range []string{"false", "0", "f"} {
|
|
c, err = NewClient(NewContext(nil, map[string]string{
|
|
"GIT_SSL_NO_VERIFY": value,
|
|
}, nil))
|
|
t.Logf("GIT_SSL_NO_VERIFY: %q", value)
|
|
assert.Nil(t, err)
|
|
assert.False(t, c.SkipSSLVerify)
|
|
}
|
|
|
|
for _, value := range []string{"true", "1", "t"} {
|
|
c, err = NewClient(NewContext(nil, map[string]string{
|
|
"GIT_SSL_NO_VERIFY": value,
|
|
}, nil))
|
|
t.Logf("GIT_SSL_NO_VERIFY: %q", value)
|
|
assert.Nil(t, err)
|
|
assert.True(t, c.SkipSSLVerify)
|
|
}
|
|
}
|
|
|
|
func TestNewRequest(t *testing.T) {
|
|
tests := [][]string{
|
|
{"https://example.com", "a", "https://example.com/a"},
|
|
{"https://example.com/", "a", "https://example.com/a"},
|
|
{"https://example.com/a", "b", "https://example.com/a/b"},
|
|
{"https://example.com/a/", "b", "https://example.com/a/b"},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
c, err := NewClient(NewContext(nil, nil, nil))
|
|
require.Nil(t, err)
|
|
|
|
req, err := c.NewRequest("POST", Endpoint{Url: test[0]}, test[1], nil)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, "POST", req.Method)
|
|
assert.Equal(t, test[2], req.URL.String(), fmt.Sprintf("endpoint: %s, suffix: %s, expected: %s", test[0], test[1], test[2]))
|
|
}
|
|
}
|
|
|
|
func TestNewRequestWithBody(t *testing.T) {
|
|
c, err := NewClient(NewContext(nil, nil, nil))
|
|
require.Nil(t, err)
|
|
|
|
body := struct {
|
|
Test string
|
|
}{Test: "test"}
|
|
req, err := c.NewRequest("POST", Endpoint{Url: "https://example.com"}, "body", body)
|
|
require.Nil(t, err)
|
|
|
|
assert.NotNil(t, req.Body)
|
|
assert.Equal(t, "15", req.Header.Get("Content-Length"))
|
|
assert.EqualValues(t, 15, req.ContentLength)
|
|
}
|
|
|
|
func TestMarshalToRequest(t *testing.T) {
|
|
req, err := http.NewRequest("POST", "https://foo/bar", nil)
|
|
require.Nil(t, err)
|
|
|
|
assert.Nil(t, req.Body)
|
|
assert.Equal(t, "", req.Header.Get("Content-Length"))
|
|
assert.EqualValues(t, 0, req.ContentLength)
|
|
|
|
body := struct {
|
|
Test string
|
|
}{Test: "test"}
|
|
require.Nil(t, MarshalToRequest(req, body))
|
|
|
|
assert.NotNil(t, req.Body)
|
|
assert.Equal(t, "15", req.Header.Get("Content-Length"))
|
|
assert.EqualValues(t, 15, req.ContentLength)
|
|
}
|
|
|
|
func TestHttp2(t *testing.T) {
|
|
var calledSrvTLS uint32
|
|
var calledSrv uint32
|
|
|
|
srvTLS := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&calledSrvTLS, 1)
|
|
assert.Equal(t, "GET", r.Method)
|
|
assert.Equal(t, "HTTP/2.0", r.Proto)
|
|
w.WriteHeader(200)
|
|
}))
|
|
srvTLS.TLS = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
|
srvTLS.StartTLS()
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&calledSrv, 1)
|
|
assert.Equal(t, "GET", r.Method)
|
|
assert.Equal(t, "HTTP/1.1", r.Proto)
|
|
w.WriteHeader(200)
|
|
}))
|
|
|
|
defer srvTLS.Close()
|
|
defer srv.Close()
|
|
|
|
c, err := NewClient(NewContext(nil, nil, map[string]string{
|
|
fmt.Sprintf("http.sslverify"): "false",
|
|
}))
|
|
require.Nil(t, err)
|
|
|
|
req, err := http.NewRequest("GET", srvTLS.URL, nil)
|
|
require.Nil(t, err)
|
|
|
|
res, err := c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 1, calledSrvTLS)
|
|
|
|
req, err = http.NewRequest("GET", srv.URL, nil)
|
|
require.Nil(t, err)
|
|
|
|
res, err = c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 1, calledSrv)
|
|
}
|
|
|
|
func TestHttpVersion(t *testing.T) {
|
|
testcases := []struct {
|
|
Proto string
|
|
Setting string
|
|
TLSOk bool
|
|
PlaintextOk bool
|
|
Error string
|
|
}{
|
|
{"HTTP/2.0", "HTTP/2", true, false, "HTTP/2 cannot be used except with TLS"},
|
|
{"HTTP/1.1", "HTTP/1.1", true, true, ""},
|
|
{"HTTP/2.0", "lalala", false, false, `Unknown HTTP version "lalala"`},
|
|
}
|
|
|
|
for _, test := range testcases {
|
|
var calledSrvTLS uint32
|
|
var calledSrv uint32
|
|
|
|
srvTLS := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&calledSrvTLS, 1)
|
|
assert.Equal(t, "GET", r.Method)
|
|
assert.Equal(t, test.Proto, r.Proto)
|
|
w.WriteHeader(200)
|
|
}))
|
|
srvTLS.TLS = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
|
|
srvTLS.StartTLS()
|
|
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
atomic.AddUint32(&calledSrv, 1)
|
|
assert.Equal(t, "GET", r.Method)
|
|
assert.Equal(t, "HTTP/1.1", r.Proto)
|
|
w.WriteHeader(200)
|
|
}))
|
|
|
|
defer srvTLS.Close()
|
|
defer srv.Close()
|
|
|
|
c, err := NewClient(NewContext(nil, nil, map[string]string{
|
|
"http.sslverify": "false",
|
|
"http.version": test.Setting,
|
|
}))
|
|
require.Nil(t, err)
|
|
|
|
req, err := http.NewRequest("GET", srvTLS.URL, nil)
|
|
require.Nil(t, err)
|
|
|
|
if test.TLSOk {
|
|
res, err := c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 1, calledSrvTLS)
|
|
} else {
|
|
_, err := c.Do(req)
|
|
require.NotNil(t, err)
|
|
assert.EqualValues(t, err.Error(), test.Error)
|
|
assert.EqualValues(t, 0, calledSrv)
|
|
}
|
|
|
|
req, err = http.NewRequest("GET", srv.URL, nil)
|
|
require.Nil(t, err)
|
|
|
|
if test.PlaintextOk {
|
|
res, err := c.Do(req)
|
|
require.Nil(t, err)
|
|
assert.Equal(t, 200, res.StatusCode)
|
|
assert.EqualValues(t, 1, calledSrv)
|
|
} else {
|
|
_, err := c.Do(req)
|
|
require.NotNil(t, err)
|
|
assert.EqualValues(t, err.Error(), test.Error)
|
|
assert.EqualValues(t, 0, calledSrv)
|
|
}
|
|
}
|
|
}
|