lfsapi: teach lfsapi.Client to retry requests
This commit is contained in:
parent
8fa1985b90
commit
e15e149764
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
@ -122,10 +123,31 @@ func (c *Client) doWithRedirects(cli *http.Client, req *http.Request, via []*htt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res, err := cli.Do(req)
|
||||
var retries int
|
||||
if n, ok := Retries(req); ok {
|
||||
retries = n
|
||||
} else {
|
||||
retries = defaultRequestRetries
|
||||
}
|
||||
|
||||
var res *http.Response
|
||||
|
||||
for i := 0; i < retries; i++ {
|
||||
res, err = cli.Do(req)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if seek, ok := req.Body.(io.Seeker); ok {
|
||||
seek.Seek(0, io.SeekStart)
|
||||
}
|
||||
|
||||
c.traceResponse(req, tracedReq, nil)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.traceResponse(req, tracedReq, nil)
|
||||
return res, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.traceResponse(req, tracedReq, res)
|
||||
|
@ -12,6 +12,10 @@ const (
|
||||
// contextKeyRetries is a context.Context key for storing the desired
|
||||
// number of retries for a given request.
|
||||
contextKeyRetries ckey = "retries"
|
||||
|
||||
// defaultRequestRetries is the default number of retries to perform on
|
||||
// a given HTTP request.
|
||||
defaultRequestRetries = 1
|
||||
)
|
||||
|
||||
// WithRetries stores the desired number of retries "n" on the given
|
||||
|
@ -1,10 +1,14 @@
|
||||
package lfsapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestWithRetries(t *testing.T) {
|
||||
@ -23,3 +27,60 @@ func TestRetriesOnUnannotatedRequest(t *testing.T) {
|
||||
assert.False(t, ok)
|
||||
assert.Equal(t, 0, n)
|
||||
}
|
||||
|
||||
func TestRequestWithRetries(t *testing.T) {
|
||||
type T struct {
|
||||
S string `json:"s"`
|
||||
}
|
||||
|
||||
var hasRaw bool = true
|
||||
var requests uint32
|
||||
var berr error
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload T
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
berr = err
|
||||
}
|
||||
|
||||
assert.Equal(t, "Hello, world!", payload.S)
|
||||
|
||||
if atomic.AddUint32(&requests, 1) < 3 {
|
||||
raw, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
hasRaw = false
|
||||
return
|
||||
}
|
||||
|
||||
conn, _, err := raw.Hijack()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
return
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c, err := NewClient(nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
req, err := http.NewRequest("POST", srv.URL, nil)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, MarshalToRequest(req, &T{"Hello, world!"}))
|
||||
|
||||
if !hasRaw {
|
||||
// Skip tests where the implementation of
|
||||
// net/http/httptest.Server does not provide raw access to the
|
||||
// connection.
|
||||
//
|
||||
// Defer the skip outside of the server, since t.Skip halts the
|
||||
// running goroutine.
|
||||
t.Skip("lfsapi: net/http/httptest.Server does not provide raw access")
|
||||
}
|
||||
|
||||
res, err := c.Do(WithRetries(req, 8))
|
||||
assert.NoError(t, berr)
|
||||
assert.NoError(t, err)
|
||||
require.NotNil(t, res, "lfsapi: expected response")
|
||||
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user