diff --git a/git/odb/object_writer.go b/git/odb/object_writer.go
new file mode 100644
index 00000000..d0909417
--- /dev/null
+++ b/git/odb/object_writer.go
@@ -0,0 +1,126 @@
+package odb
+
+import (
+	"compress/zlib"
+	"crypto/sha1"
+	"fmt"
+	"hash"
+	"io"
+	"sync/atomic"
+)
+
+// ObjectWriter provides an implementation of io.Writer that compresses and
+// writes data given to it, and keeps track of the SHA1 hash of the data as it
+// is written.
+type ObjectWriter struct {
+	// w is the underlying writer that this ObjectWriter is writing to.
+	w io.Writer
+	// sum is the in-progress hash calculation.
+	sum hash.Hash
+
+	// wroteHeader is a uint32 managed by the sync/atomic package. It is 1
+	// if the header was written, and 0 otherwise.
+	wroteHeader uint32
+
+	// closeFn supplies an optional function that, when called, frees any
+	// resources (open files, memory, etc) held by this instance of the
+	// *ObjectWriter.
+	//
+	// closeFn returns any error encountered when closing/freeing resources
+	// held.
+	//
+	// It is allowed to be nil.
+	closeFn func() error
+}
+
+var _ io.WriteCloser = (*ObjectWriter)(nil)
+
+// nopCloser provides a no-op implementation of the io.WriteCloser interface by
+// taking an io.Writer and wrapping it with a Close() method that returns nil.
+type nopCloser struct {
+	// Writer is an embedded io.Writer that receives the Write() method
+	// call.
+	io.Writer
+}
+
+var _ io.WriteCloser = (*nopCloser)(nil)
+
+// Close implements the io.Closer interface by returning nil.
+func (n *nopCloser) Close() error {
+	return nil
+}
+
+// NewObjectWriter returns a new *ObjectWriter instance that drains incoming
+// writes into the io.Writer given, "w".
+func NewObjectWriter(w io.Writer) *ObjectWriter {
+	return NewObjectWriteCloser(&nopCloser{w})
+}
+
+// NewObjectWriteCloser returns a new *ObjectWriter instance that drains
+// incoming writes into the io.WriteCloser given, "w".
+//
+// Upon closing, it calls the given Close() function of the io.WriteCloser.
+func NewObjectWriteCloser(w io.WriteCloser) *ObjectWriter {
+	zw := zlib.NewWriter(w)
+	sum := sha1.New()
+
+	return &ObjectWriter{
+		w:   io.MultiWriter(zw, sum),
+		sum: sum,
+
+		closeFn: func() error {
+			// Always close "w", even when flushing the zlib stream
+			// fails, so that the resources it holds are not leaked.
+			zerr := zw.Close()
+			werr := w.Close()
+
+			if zerr != nil {
+				return zerr
+			}
+			return werr
+		},
+	}
+}
+
+// WriteHeader writes object header information and returns the number of
+// uncompressed bytes written, or any error that was encountered along the way.
+//
+// WriteHeader MUST be called only once, or a panic() will occur.
+func (w *ObjectWriter) WriteHeader(typ ObjectType, len int64) (n int, err error) {
+	if !atomic.CompareAndSwapUint32(&w.wroteHeader, 0, 1) {
+		panic("git/odb: cannot write headers more than once")
+	}
+	return fmt.Fprintf(w, "%s %d\x00", typ, len)
+}
+
+// Write writes the given buffer "p" of uncompressed bytes into the underlying
+// data-stream, returning the number of uncompressed bytes written, along with
+// any error encountered along the way.
+//
+// A call to WriteHeader MUST occur before calling Write, or a panic() will
+// occur.
+func (w *ObjectWriter) Write(p []byte) (n int, err error) {
+	if atomic.LoadUint32(&w.wroteHeader) != 1 {
+		panic("git/odb: cannot write data without header")
+	}
+	return w.w.Write(p)
+}
+
+// Sha returns the in-progress SHA1 of the uncompressed object contents
+// (header and data) written so far.
+func (w *ObjectWriter) Sha() []byte {
+	return w.sum.Sum(nil)
+}
+
+// Close closes the ObjectWriter and frees any resources held by it, including
+// flushing the zlib-compressed content to the underlying writer. It must be
+// called before discarding the ObjectWriter instance.
+//
+// If flushing the zlib stream fails, that error is returned; otherwise any
+// error from closing the underlying writer is returned, or nil.
+func (w *ObjectWriter) Close() error {
+	if w.closeFn == nil {
+		return nil
+	}
+	return w.closeFn()
+}
diff --git a/git/odb/object_writer_test.go b/git/odb/object_writer_test.go
new file mode 100644
index 00000000..561a7736
--- /dev/null
+++ b/git/odb/object_writer_test.go
@@ -0,0 +1,148 @@
+package odb
+
+import (
+	"bytes"
+	"compress/zlib"
+	"encoding/hex"
+	"errors"
+	"io"
+	"io/ioutil"
+	"sync/atomic"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestObjectWriterWritesHeaders(t *testing.T) {
+	var buf bytes.Buffer
+
+	w := NewObjectWriter(&buf)
+
+	n, err := w.WriteHeader(BlobObjectType, 1)
+	assert.Equal(t, 7, n)
+	assert.Nil(t, err)
+
+	assert.Nil(t, w.Close())
+
+	r, err := zlib.NewReader(&buf)
+	assert.Nil(t, err)
+
+	all, err := ioutil.ReadAll(r)
+	assert.Nil(t, err)
+	assert.Equal(t, []byte("blob 1\x00"), all)
+
+	assert.Nil(t, r.Close())
+}
+
+func TestObjectWriterWritesData(t *testing.T) {
+	var buf bytes.Buffer
+
+	w := NewObjectWriter(&buf)
+	w.WriteHeader(BlobObjectType, 1)
+
+	n, err := w.Write([]byte{0x31})
+	assert.Equal(t, 1, n)
+	assert.Nil(t, err)
+
+	assert.Nil(t, w.Close())
+
+	r, err := zlib.NewReader(&buf)
+	assert.Nil(t, err)
+
+	all, err := ioutil.ReadAll(r)
+	assert.Nil(t, err)
+	assert.Equal(t, []byte("blob 1\x001"), all)
+
+	assert.Nil(t, r.Close())
+}
+
+func TestObjectWriterPanicsOnWritesWithoutHeader(t *testing.T) {
+	defer func() {
+		err := recover()
+
+		assert.NotNil(t, err)
+		assert.Equal(t, "git/odb: cannot write data without header", err)
+	}()
+
+	w := NewObjectWriter(new(bytes.Buffer))
+	w.Write(nil)
+}
+
+func TestObjectWriterPanicsOnMultipleHeaderWrites(t *testing.T) {
+	defer func() {
+		err := recover()
+
+		assert.NotNil(t, err)
+		assert.Equal(t, "git/odb: cannot write headers more than once", err)
+	}()
+
+	w := NewObjectWriter(new(bytes.Buffer))
+	w.WriteHeader(BlobObjectType, 1)
+	w.WriteHeader(TreeObjectType, 2)
+}
+
+func TestObjectWriterKeepsTrackOfHash(t *testing.T) {
+	w := NewObjectWriter(new(bytes.Buffer))
+	n, err := w.WriteHeader(BlobObjectType, 1)
+
+	assert.Nil(t, err)
+	assert.Equal(t, 7, n)
+
+	assert.Equal(t, "bb6ca78b66403a67c6281df142de5ef472186283", hex.EncodeToString(w.Sha()))
+}
+
+// WriteCloserFn pairs an io.Writer with a caller-supplied Close()
+// implementation, so tests can observe and control Close() calls.
+type WriteCloserFn struct {
+	io.Writer
+	closeFn func() error
+}
+
+var _ io.WriteCloser = (*WriteCloserFn)(nil)
+
+func (r *WriteCloserFn) Close() error { return r.closeFn() }
+
+func TestObjectWriterCallsClose(t *testing.T) {
+	var calls uint32
+
+	expected := errors.New("close error")
+
+	w := NewObjectWriteCloser(&WriteCloserFn{
+		Writer: new(bytes.Buffer),
+		closeFn: func() error {
+			atomic.AddUint32(&calls, 1)
+			return expected
+		},
+	})
+
+	got := w.Close()
+
+	assert.EqualValues(t, 1, calls)
+	assert.Equal(t, expected, got)
+}
+
+// errWriter is an io.Writer whose Write always fails with err.
+type errWriter struct {
+	err error
+}
+
+func (w *errWriter) Write(p []byte) (int, error) { return 0, w.err }
+
+func TestObjectWriterClosesWriterWhenFlushFails(t *testing.T) {
+	var calls uint32
+
+	expected := errors.New("write error")
+
+	w := NewObjectWriteCloser(&WriteCloserFn{
+		Writer: &errWriter{err: expected},
+		closeFn: func() error {
+			atomic.AddUint32(&calls, 1)
+			return nil
+		},
+	})
+
+	got := w.Close()
+
+	assert.EqualValues(t, 1, calls)
+	assert.Equal(t, expected, got)
+}