lfs: implement CatFileBatchScanner.ContentsSha()

This commit is contained in:
Taylor Blau 2017-03-22 15:15:54 -06:00
parent f0ccaf88a2
commit 90c474979c

@ -3,6 +3,7 @@ package lfs
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/sha256"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -64,9 +65,10 @@ type CatFileBatchScanner struct {
w io.Writer w io.Writer
closeFn func() error closeFn func() error
blobSha string blobSha string
pointer *WrappedPointer contentsSha string
err error pointer *WrappedPointer
err error
} }
func NewCatFileBatchScanner() (*CatFileBatchScanner, error) { func NewCatFileBatchScanner() (*CatFileBatchScanner, error) {
@ -99,6 +101,10 @@ func (s *CatFileBatchScanner) BlobSHA() string {
return s.blobSha return s.blobSha
} }
func (s *CatFileBatchScanner) ContentsSha() string {
return s.contentsSha
}
func (s *CatFileBatchScanner) Pointer() *WrappedPointer { func (s *CatFileBatchScanner) Pointer() *WrappedPointer {
return s.pointer return s.pointer
} }
@ -109,6 +115,7 @@ func (s *CatFileBatchScanner) Err() error {
func (s *CatFileBatchScanner) Scan(sha []byte) bool { func (s *CatFileBatchScanner) Scan(sha []byte) bool {
s.pointer, s.err = nil, nil s.pointer, s.err = nil, nil
s.blobSha, s.contentsSha = "", ""
if s.w != nil && sha != nil { if s.w != nil && sha != nil {
if _, err := fmt.Fprintf(s.w, "%s\n", sha); err != nil { if _, err := fmt.Fprintf(s.w, "%s\n", sha); err != nil {
@ -117,8 +124,9 @@ func (s *CatFileBatchScanner) Scan(sha []byte) bool {
} }
} }
b, p, err := s.next() b, c, p, err := s.next()
s.blobSha = b s.blobSha = b
s.contentsSha = c
s.pointer = p s.pointer = p
if err != nil { if err != nil {
@ -138,41 +146,46 @@ func (s *CatFileBatchScanner) Close() error {
return s.closeFn() return s.closeFn()
} }
func (s *CatFileBatchScanner) next() (string, *WrappedPointer, error) { func (s *CatFileBatchScanner) next() (string, string, *WrappedPointer, error) {
l, err := s.r.ReadBytes('\n') l, err := s.r.ReadBytes('\n')
if err != nil { if err != nil {
return "", nil, err return "", "", nil, err
} }
// Line is formatted: // Line is formatted:
// <sha1> <type> <size> // <sha1> <type> <size>
fields := bytes.Fields(l) fields := bytes.Fields(l)
if len(fields) < 3 { if len(fields) < 3 {
return "", nil, errors.Wrap(fmt.Errorf("Invalid: %q", string(l)), "git cat-file --batch") return "", "", nil, errors.Wrap(fmt.Errorf("Invalid: %q", string(l)), "git cat-file --batch")
} }
blobSha := string(fields[0]) blobSha := string(fields[0])
size, _ := strconv.Atoi(string(fields[2])) size, _ := strconv.Atoi(string(fields[2]))
sha := sha256.New()
buf := make([]byte, size) buf := make([]byte, size)
read, err := io.ReadFull(s.r, buf) read, err := io.ReadFull(io.TeeReader(s.r, sha), buf)
if err != nil { if err != nil {
return blobSha, nil, err return blobSha, "", nil, err
} }
if size != read { if size != read {
return blobSha, nil, fmt.Errorf("expected %d bytes, read %d bytes", size, read) return blobSha, "", nil, fmt.Errorf("expected %d bytes, read %d bytes", size, read)
} }
p, err := DecodePointer(bytes.NewBuffer(buf[:read])) p, err := DecodePointer(bytes.NewBuffer(buf[:read]))
var pointer *WrappedPointer var pointer *WrappedPointer
var contentsSha string
if err == nil { if err == nil {
contentsSha = p.Oid
pointer = &WrappedPointer{ pointer = &WrappedPointer{
Sha1: blobSha, Sha1: blobSha,
Pointer: p, Pointer: p,
} }
} else {
contentsSha = fmt.Sprintf("%x", sha.Sum(nil))
} }
_, err = s.r.ReadBytes('\n') // Extra \n inserted by cat-file _, err = s.r.ReadBytes('\n') // Extra \n inserted by cat-file
return blobSha, pointer, err return blobSha, contentsSha, pointer, err
} }