diff --git a/lfs/gitscanner_catfilebatch.go b/lfs/gitscanner_catfilebatch.go index f7b0c2e8..8e486186 100644 --- a/lfs/gitscanner_catfilebatch.go +++ b/lfs/gitscanner_catfilebatch.go @@ -162,24 +162,35 @@ func (s *CatFileBatchScanner) next() (string, string, *WrappedPointer, error) { blobSha := string(fields[0]) size, _ := strconv.Atoi(string(fields[2])) sha := sha256.New() - buf := make([]byte, size) - read, err := io.ReadFull(io.TeeReader(s.r, sha), buf) + + var buf *bytes.Buffer + var to io.Writer = sha + if size <= blobSizeCutoff { + buf = bytes.NewBuffer(make([]byte, 0, size)) + to = io.MultiWriter(to, buf) + } + + read, err := io.CopyN(to, s.r, int64(size)) if err != nil { return blobSha, "", nil, err } - if size != read { + if int64(size) != read { return blobSha, "", nil, fmt.Errorf("expected %d bytes, read %d bytes", size, read) } - p, err := DecodePointer(bytes.NewBuffer(buf[:read])) var pointer *WrappedPointer var contentsSha string - if err == nil { - contentsSha = p.Oid - pointer = &WrappedPointer{ - Sha1: blobSha, - Pointer: p, + + if size <= blobSizeCutoff { + if p, err := DecodePointer(bytes.NewReader(buf.Bytes())); err != nil { + contentsSha = fmt.Sprintf("%x", sha.Sum(nil)) + } else { + pointer = &WrappedPointer{ + Sha1: blobSha, + Pointer: p, + } + contentsSha = p.Oid } } else { contentsSha = fmt.Sprintf("%x", sha.Sum(nil)) diff --git a/lfs/gitscanner_catfilebatchscanner_test.go b/lfs/gitscanner_catfilebatchscanner_test.go index 1c0c24f3..59b4226f 100644 --- a/lfs/gitscanner_catfilebatchscanner_test.go +++ b/lfs/gitscanner_catfilebatchscanner_test.go @@ -3,12 +3,14 @@ package lfs import ( "bufio" "bytes" + "crypto/sha256" "fmt" "io" "math/rand" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCatFileBatchScannerWithValidOutput(t *testing.T) { @@ -55,6 +57,28 @@ func TestCatFileBatchScannerWithValidOutput(t *testing.T) { assert.Nil(t, scanner.Pointer()) } +func TestCatFileBatchScannerWithLargeBlobs(t *testing.T) { + buf := bytes.NewBuffer(make([]byte, 0, 1025)) + sha := sha256.New() + rng := rand.New(rand.NewSource(0)) + + _, err := io.CopyN(io.MultiWriter(sha, buf), rng, 1025) + require.Nil(t, err) + + fake := bytes.NewBuffer(nil) + writeFakeBuffer(t, fake, buf.Bytes(), buf.Len()) + + scanner := &CatFileBatchScanner{r: bufio.NewReader(fake)} + + require.True(t, scanner.Scan(nil)) + assert.Nil(t, scanner.Pointer()) + assert.Equal(t, fmt.Sprintf("%x", sha.Sum(nil)), scanner.ContentsSha()) + + assert.False(t, scanner.Scan(nil)) + assert.Nil(t, scanner.Err()) + assert.Nil(t, scanner.Pointer()) +} + func assertNextPointer(t *testing.T, scanner *CatFileBatchScanner, oid string) { assert.True(t, scanner.Scan(nil)) assert.Nil(t, scanner.Err())