diff --git a/git/filter_process_scanner_test.go b/git/filter_process_scanner_test.go index 6acc5506..50fc87ed 100644 --- a/git/filter_process_scanner_test.go +++ b/git/filter_process_scanner_test.go @@ -35,6 +35,7 @@ func TestFilterProcessScannerRejectsUnrecognizedInitializationMessages(t *testin pl := newPktline(nil, &from) require.Nil(t, pl.writePacketText("git-filter-client-unknown")) + require.Nil(t, pl.writeFlush()) fps := NewFilterProcessScanner(&from, &to) err := fps.Init() @@ -106,8 +107,7 @@ func TestFilterProcessScannerReadsRequestHeadersAndPayload(t *testing.T) { // Multi-line packet require.Nil(t, pl.writePacketText("first")) require.Nil(t, pl.writePacketText("second")) - _, err := from.Write([]byte{0x30, 0x30, 0x30, 0x30}) // flush packet - assert.Nil(t, err) + require.Nil(t, pl.writeFlush()) req, err := readRequest(NewFilterProcessScanner(&from, &to)) @@ -121,13 +121,11 @@ func TestFilterProcessScannerReadsRequestHeadersAndPayload(t *testing.T) { } func TestFilterProcessScannerRejectsInvalidHeaderPackets(t *testing.T) { - var from bytes.Buffer + from := bytes.NewBuffer([]byte{ + 0x30, 0x30, 0x30, 0x34, // 0004 (invalid packet length) + }) - pl := newPktline(nil, &from) - // (Invalid) headers - require.Nil(t, pl.writePacket([]byte{})) - - req, err := readRequest(NewFilterProcessScanner(&from, nil)) + req, err := readRequest(NewFilterProcessScanner(from, nil)) require.NotNil(t, err) assert.Equal(t, "Invalid packet length.", err.Error()) diff --git a/git/pkt_line.go b/git/pkt_line.go index 71fce485..4ca100d4 100644 --- a/git/pkt_line.go +++ b/git/pkt_line.go @@ -68,11 +68,17 @@ func (p *pktline) readPacket() ([]byte, error) { return payload, err } +// readPacketText follows identical semantics to the `readPacket()` function, +// but additionally removes the trailing `\n` LF from the end of the packet, if +// present. func (p *pktline) readPacketText() (string, error) { data, err := p.readPacket() return strings.TrimSuffix(string(data), "\n"), err } +// readPacketList reads as many packets as possible using the `readPacketText` +// function before encountering a flush packet. It returns a slice of all the +// packets it read, or an error if one was encountered. func (p *pktline) readPacketList() ([]string, error) { var list []string for { @@ -91,6 +97,15 @@ func (p *pktline) readPacketList() ([]string, error) { return list, nil } +// writePacket writes the given data in "data" to the underlying data stream +// using Git's `pkt-line` format. +// +// If the data was longer than MaxPacketLength, an error will be returned. If +// there was any error encountered while writing any component of the packet +// (hdr, payload), it will be returned. +// +// NB: writePacket does _not_ flush the underlying buffered writer. See instead: +// `writeFlush()`. func (p *pktline) writePacket(data []byte) error { if len(data) > MaxPacketLength { return errors.New("Packet length exceeds maximal length") @@ -104,13 +119,13 @@ func (p *pktline) writePacket(data []byte) error { return err } - if err := p.w.Flush(); err != nil { - return err - } - return nil } +// writeFlush writes the terminating "flush" packet and then flushes the +// underlying buffered writer. +// +// If any error was encountered along the way, it will be returned immediately func (p *pktline) writeFlush() error { if _, err := p.w.WriteString(fmt.Sprintf("%04x", 0)); err != nil { return err @@ -123,10 +138,16 @@ func (p *pktline) writeFlush() error { return nil } +// writePacketText follows the same semantics as `writePacket`, but appends a +// trailing "\n" LF character to the end of the data. func (p *pktline) writePacketText(data string) error { return p.writePacket([]byte(data + "\n")) } +// writePacketList writes a slice of strings using the semantics of +// and then writes a terminating flush sequence afterwords. +// +// If any error was encountered, it will be returned immediately. func (p *pktline) writePacketList(list []string) error { for _, i := range list { if err := p.writePacketText(i); err != nil { diff --git a/git/pkt_line_reader_test.go b/git/pkt_line_reader_test.go index 64676df1..2c573b32 100644 --- a/git/pkt_line_reader_test.go +++ b/git/pkt_line_reader_test.go @@ -2,28 +2,29 @@ package git import ( "bytes" - "fmt" "io" "io/ioutil" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // writePackets -func writePacket(w io.Writer, datas ...[]byte) { - for _, data := range datas { - io.WriteString(w, fmt.Sprintf("%04x", len(data)+4)) - w.Write(data) +func writePacket(t *testing.T, w io.Writer, datas ...[]byte) { + pl := newPktline(nil, w) + for _, data := range datas { + require.Nil(t, pl.writePacket(data)) } - io.WriteString(w, fmt.Sprintf("%04x", 0)) + + require.Nil(t, pl.writeFlush()) } func TestPktlineReaderReadsSinglePacketsInOneCall(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("asdf")) + writePacket(t, &buf, []byte("asdf")) pr := &pktlineReader{pl: newPktline(&buf, nil)} @@ -36,7 +37,7 @@ func TestPktlineReaderReadsSinglePacketsInOneCall(t *testing.T) { func TestPktlineReaderReadsManyPacketsInOneCall(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("first\n"), []byte("second")) + writePacket(t, &buf, []byte("first\n"), []byte("second")) pr := &pktlineReader{pl: newPktline(&buf, nil)} @@ -49,7 +50,7 @@ func TestPktlineReaderReadsManyPacketsInOneCall(t *testing.T) { func TestPktlineReaderReadsSinglePacketsInMultipleCallsWithUnevenBuffering(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("asdf")) + writePacket(t, &buf, []byte("asdf")) pr := &pktlineReader{pl: newPktline(&buf, nil)} @@ -70,7 +71,7 @@ func TestPktlineReaderReadsSinglePacketsInMultipleCallsWithUnevenBuffering(t *te func TestPktlineReaderReadsManyPacketsInMultipleCallsWithUnevenBuffering(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("first"), []byte("second")) + writePacket(t, &buf, []byte("first"), []byte("second")) pr := &pktlineReader{pl: newPktline(&buf, nil)} @@ -97,7 +98,7 @@ func TestPktlineReaderReadsManyPacketsInMultipleCallsWithUnevenBuffering(t *test func TestPktlineReaderReadsSinglePacketsInMultipleCallsWithEvenBuffering(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("firstother")) + writePacket(t, &buf, []byte("firstother")) pr := &pktlineReader{pl: newPktline(&buf, nil)} @@ -118,7 +119,7 @@ func TestPktlineReaderReadsSinglePacketsInMultipleCallsWithEvenBuffering(t *test func TestPktlineReaderReadsManyPacketsInMultipleCallsWithEvenBuffering(t *testing.T) { var buf bytes.Buffer - writePacket(&buf, []byte("first"), []byte("other")) + writePacket(t, &buf, []byte("first"), []byte("other")) pr := &pktlineReader{pl: newPktline(&buf, nil)} diff --git a/git/pkt_line_test.go b/git/pkt_line_test.go index e84329d1..201baa63 100644 --- a/git/pkt_line_test.go +++ b/git/pkt_line_test.go @@ -169,14 +169,15 @@ func TestPktLineWritesPackets(t *testing.T) { var buf bytes.Buffer rw := newPktline(nil, &buf) - err := rw.writePacket([]byte{ + require.Nil(t, rw.writePacket([]byte{ 0x1, 0x2, 0x3, 0x4, - }) + })) + require.Nil(t, rw.writeFlush()) - assert.Nil(t, err) assert.Equal(t, []byte{ 0x30, 0x30, 0x30, 0x38, // 0008 (hex. length) 0x1, 0x2, 0x3, 0x4, // payload + 0x30, 0x30, 0x30, 0x30, // 0000 (flush packet) }, buf.Bytes()) } @@ -205,12 +206,14 @@ func TestPktLineWritesPacketText(t *testing.T) { var buf bytes.Buffer rw := newPktline(nil, &buf) - err := rw.writePacketText("abcd") - assert.Nil(t, err) + require.Nil(t, rw.writePacketText("abcd")) + require.Nil(t, rw.writeFlush()) + assert.Equal(t, []byte{ 0x30, 0x30, 0x30, 0x39, // 0009 (hex. length) 0x61, 0x62, 0x63, 0x64, 0xa, // "abcd\n" (payload) + 0x30, 0x30, 0x30, 0x30, // 0000 (flush packet) }, buf.Bytes()) }