diff --git a/git/filter_protocol.go b/git/filter_protocol.go new file mode 100644 index 00000000..33956004 --- /dev/null +++ b/git/filter_protocol.go @@ -0,0 +1,114 @@ +package git + +import ( + "bufio" + "errors" + "fmt" + "io" + "io/ioutil" + "strconv" + "strings" +) + +type protocol struct { + r *bufio.Reader + w *bufio.Writer +} + +func newProtocolRW(r io.Reader, w io.Writer) *protocol { + return &protocol{ + r: bufio.NewReader(r), + w: bufio.NewWriter(w), + } +} + +func (p *protocol) readPacket() ([]byte, error) { + pktLenHex, err := ioutil.ReadAll(io.LimitReader(p.r, 4)) + if err != nil || len(pktLenHex) != 4 { // TODO check pktLenHex length + return nil, err + } + + pktLen, err := strconv.ParseInt(string(pktLenHex), 16, 0) + if err != nil { + return nil, err + } + + if pktLen == 0 { + return nil, nil + } + if pktLen <= 4 { + return nil, errors.New("Invalid packet length.") + } + + return ioutil.ReadAll(io.LimitReader(p.r, pktLen-4)) +} + +func (p *protocol) readPacketText() (string, error) { + data, err := p.readPacket() + return strings.TrimSuffix(string(data), "\n"), err +} + +func (p *protocol) readPacketList() ([]string, error) { + var list []string + for { + data, err := p.readPacketText() + if err != nil { + return nil, err + } + + if len(data) == 0 { + break + } + + list = append(list, data) + } + + return list, nil +} + +func (p *protocol) writePacket(data []byte) error { + if len(data) > MaxPacketLength { + return errors.New("Packet length exceeds maximal length") + } + + if _, err := p.w.WriteString(fmt.Sprintf("%04x", len(data)+4)); err != nil { + return err + } + + if _, err := p.w.Write(data); err != nil { + return err + } + + if err := p.w.Flush(); err != nil { + return err + } + + return nil +} + +func (p *protocol) writeFlush() error { + if _, err := p.w.WriteString(fmt.Sprintf("%04x", 0)); err != nil { + return err + } + + if err := p.w.Flush(); err != nil { + return err + } + + return nil +} + +func (p *protocol) writePacketText(data string) error { + //TODO: there is probably a more efficient way to do this. worth it? + return p.writePacket([]byte(data + "\n")) +} + +func (p *protocol) writePacketList(list []string) error { + for _, i := range list { + if err := p.writePacketText(i); err != nil { + return err + } + } + + return p.writeFlush() +} diff --git a/git/git_filter_protocol.go b/git/git_filter_protocol.go index 4f9a8efb..cb720995 100644 --- a/git/git_filter_protocol.go +++ b/git/git_filter_protocol.go @@ -3,20 +3,16 @@ package git import ( - "bufio" "fmt" "io" - "io/ioutil" "os" - "strconv" "strings" - "github.com/github/git-lfs/errors" "github.com/rubyist/tracerx" ) const ( - MaxPacketLenght = 65516 + MaxPacketLength = 65516 ) // Private function copied from "github.com/xeipuuv/gojsonschema/utils.go" @@ -31,109 +27,20 @@ func isStringInSlice(s []string, what string) bool { } type ObjectScanner struct { - r *bufio.Reader - w *bufio.Writer + p *protocol } func NewObjectScanner(r io.Reader, w io.Writer) *ObjectScanner { return &ObjectScanner{ - r: bufio.NewReader(r), - w: bufio.NewWriter(w), + p: newProtocolRW(r, w), } } -func (o *ObjectScanner) readPacket() ([]byte, error) { - pktLenHex, err := ioutil.ReadAll(io.LimitReader(o.r, 4)) - if err != nil || len(pktLenHex) != 4 { // TODO check pktLenHex length - return nil, err - } - pktLen, err := strconv.ParseInt(string(pktLenHex), 16, 0) - if err != nil { - return nil, err - } - if pktLen == 0 { - return nil, nil - } else if pktLen <= 4 { - return nil, errors.New("Invalid packet length.") - } - return ioutil.ReadAll(io.LimitReader(o.r, pktLen-4)) -} - -func (o *ObjectScanner) readPacketText() (string, error) { - data, err := o.readPacket() - return strings.TrimSuffix(string(data), "\n"), err -} - -func (o *ObjectScanner) readPacketList() ([]string, error) { - var list []string - for { - data, err := o.readPacketText() - if err != nil { - return nil, err - } - if len(data) == 0 { - break - } - list = append(list, data) - } - return list, nil -} - -func (o *ObjectScanner) writePacket(data []byte) error { - if len(data) > MaxPacketLenght { - return errors.New("Packet length exceeds maximal length") - } - _, err := o.w.WriteString(fmt.Sprintf("%04x", len(data)+4)) - if err != nil { - return err - } - _, err = o.w.Write(data) - if err != nil { - return err - } - err = o.w.Flush() - if err != nil { - return err - } - return nil -} - -func (o *ObjectScanner) writeFlush() error { - _, err := o.w.WriteString(fmt.Sprintf("%04x", 0)) - if err != nil { - return err - } - err = o.w.Flush() - if err != nil { - return err - } - return nil -} - -func (o *ObjectScanner) writePacketText(data string) error { - //TODO: there is probably a more efficient way to do this. worth it? - return o.writePacket([]byte(data + "\n")) -} - -func (o *ObjectScanner) writePacketList(list []string) error { - for _, i := range list { - err := o.writePacketText(i) - if err != nil { - return err - } - } - return o.writeFlush() -} - -func (o *ObjectScanner) writeStatus(status string) error { - return o.writePacketList([]string{"status=" + status}) -} - func (o *ObjectScanner) Init() bool { tracerx.Printf("Initialize filter") reqVer := "version=2" - initMsg, err := o.readPacketText() + initMsg, err := o.p.readPacketText() if err != nil { fmt.Fprintf(os.Stderr, "Error: reading filter initialization failed with %s\n", err) @@ -145,7 +52,7 @@ func (o *ObjectScanner) Init() bool { return false } - supVers, err := o.readPacketList() + supVers, err := o.p.readPacketList() if err != nil { fmt.Fprintf(os.Stderr, "Error: reading filter versions failed with %s\n", err) @@ -158,7 +65,7 @@ func (o *ObjectScanner) Init() bool { return false } - err = o.writePacketList([]string{"git-filter-server", reqVer}) + err = o.p.writePacketList([]string{"git-filter-server", reqVer}) if err != nil { fmt.Fprintf(os.Stderr, "Error: writing filter initialization failed with %s\n", err) @@ -170,7 +77,7 @@ func (o *ObjectScanner) Init() bool { func (o *ObjectScanner) NegotiateCapabilities() bool { reqCaps := []string{"capability=clean", "capability=smudge"} - supCaps, err := o.readPacketList() + supCaps, err := o.p.readPacketList() if err != nil { fmt.Fprintf(os.Stderr, "Error: reading filter capabilities failed with %s\n", err) @@ -185,7 +92,7 @@ func (o *ObjectScanner) NegotiateCapabilities() bool { } } - err = o.writePacketList(reqCaps) + err = o.p.writePacketList(reqCaps) if err != nil { fmt.Fprintf(os.Stderr, "Error: writing filter capabilities failed with %s\n", err) @@ -198,7 +105,7 @@ func (o *ObjectScanner) NegotiateCapabilities() bool { func (o *ObjectScanner) ReadRequest() (map[string]string, []byte, error) { tracerx.Printf("Process filter command.") - requestList, err := o.readPacketList() + requestList, err := o.p.readPacketList() if err != nil { return nil, nil, err } @@ -211,7 +118,7 @@ func (o *ObjectScanner) ReadRequest() (map[string]string, []byte, error) { var data []byte for { - chunk, err := o.readPacket() + chunk, err := o.p.readPacket() if err != nil { // TODO: should we check the err of this call, to?! o.writeStatus("error") @@ -230,12 +137,12 @@ func (o *ObjectScanner) WriteResponse(outputData []byte) error { for { chunkSize := len(outputData) if chunkSize == 0 { - o.writeFlush() + o.p.writeFlush() break - } else if chunkSize > MaxPacketLenght { - chunkSize = MaxPacketLenght // TODO check packets with the exact size + } else if chunkSize > MaxPacketLength { + chunkSize = MaxPacketLength // TODO check packets with the exact size } - err := o.writePacket(outputData[:chunkSize]) + err := o.p.writePacket(outputData[:chunkSize]) if err != nil { // TODO: should we check the err of this call, to?! o.writeStatus("error") @@ -246,3 +153,7 @@ func (o *ObjectScanner) WriteResponse(outputData []byte) error { o.writeStatus("success") return nil } + +func (o *ObjectScanner) writeStatus(status string) error { + return o.p.writePacketList([]string{"status=" + status}) +}