git: split protocol low/high-level

This commit is contained in:
Taylor Blau 2016-10-26 15:39:17 -06:00
parent 3147b5783a
commit 8bc726dc87
2 changed files with 132 additions and 107 deletions

114
git/filter_protocol.go Normal file

@ -0,0 +1,114 @@
package git
import (
"bufio"
"errors"
"fmt"
"io"
"io/ioutil"
"strconv"
"strings"
)
type protocol struct {
r *bufio.Reader
w *bufio.Writer
}
func newProtocolRW(r io.Reader, w io.Writer) *protocol {
return &protocol{
r: bufio.NewReader(r),
w: bufio.NewWriter(w),
}
}
func (p *protocol) readPacket() ([]byte, error) {
pktLenHex, err := ioutil.ReadAll(io.LimitReader(p.r, 4))
if err != nil || len(pktLenHex) != 4 { // TODO check pktLenHex length
return nil, err
}
pktLen, err := strconv.ParseInt(string(pktLenHex), 16, 0)
if err != nil {
return nil, err
}
if pktLen == 0 {
return nil, nil
}
if pktLen <= 4 {
return nil, errors.New("Invalid packet length.")
}
return ioutil.ReadAll(io.LimitReader(p.r, pktLen-4))
}
func (p *protocol) readPacketText() (string, error) {
data, err := p.readPacket()
return strings.TrimSuffix(string(data), "\n"), err
}
func (p *protocol) readPacketList() ([]string, error) {
var list []string
for {
data, err := p.readPacketText()
if err != nil {
return nil, err
}
if len(data) == 0 {
break
}
list = append(list, data)
}
return list, nil
}
func (p *protocol) writePacket(data []byte) error {
if len(data) > MaxPacketLength {
return errors.New("Packet length exceeds maximal length")
}
if _, err := p.w.WriteString(fmt.Sprintf("%04x", len(data)+4)); err != nil {
return err
}
if _, err := p.w.Write(data); err != nil {
return err
}
if err := p.w.Flush(); err != nil {
return err
}
return nil
}
func (p *protocol) writeFlush() error {
if _, err := p.w.WriteString(fmt.Sprintf("%04x", 0)); err != nil {
return err
}
if err := p.w.Flush(); err != nil {
return err
}
return nil
}
func (p *protocol) writePacketText(data string) error {
//TODO: there is probably a more efficient way to do this. worth it?
return p.writePacket([]byte(data + "\n"))
}
func (p *protocol) writePacketList(list []string) error {
for _, i := range list {
if err := p.writePacketText(i); err != nil {
return err
}
}
return p.writeFlush()
}

@ -3,20 +3,16 @@
package git
import (
"bufio"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
"github.com/github/git-lfs/errors"
"github.com/rubyist/tracerx"
)
const (
MaxPacketLenght = 65516
MaxPacketLength = 65516
)
// Private function copied from "github.com/xeipuuv/gojsonschema/utils.go"
@ -31,109 +27,20 @@ func isStringInSlice(s []string, what string) bool {
}
type ObjectScanner struct {
r *bufio.Reader
w *bufio.Writer
p *protocol
}
func NewObjectScanner(r io.Reader, w io.Writer) *ObjectScanner {
return &ObjectScanner{
r: bufio.NewReader(r),
w: bufio.NewWriter(w),
p: newProtocolRW(r, w),
}
}
func (o *ObjectScanner) readPacket() ([]byte, error) {
pktLenHex, err := ioutil.ReadAll(io.LimitReader(o.r, 4))
if err != nil || len(pktLenHex) != 4 { // TODO check pktLenHex length
return nil, err
}
pktLen, err := strconv.ParseInt(string(pktLenHex), 16, 0)
if err != nil {
return nil, err
}
if pktLen == 0 {
return nil, nil
} else if pktLen <= 4 {
return nil, errors.New("Invalid packet length.")
}
return ioutil.ReadAll(io.LimitReader(o.r, pktLen-4))
}
func (o *ObjectScanner) readPacketText() (string, error) {
data, err := o.readPacket()
return strings.TrimSuffix(string(data), "\n"), err
}
func (o *ObjectScanner) readPacketList() ([]string, error) {
var list []string
for {
data, err := o.readPacketText()
if err != nil {
return nil, err
}
if len(data) == 0 {
break
}
list = append(list, data)
}
return list, nil
}
func (o *ObjectScanner) writePacket(data []byte) error {
if len(data) > MaxPacketLenght {
return errors.New("Packet length exceeds maximal length")
}
_, err := o.w.WriteString(fmt.Sprintf("%04x", len(data)+4))
if err != nil {
return err
}
_, err = o.w.Write(data)
if err != nil {
return err
}
err = o.w.Flush()
if err != nil {
return err
}
return nil
}
func (o *ObjectScanner) writeFlush() error {
_, err := o.w.WriteString(fmt.Sprintf("%04x", 0))
if err != nil {
return err
}
err = o.w.Flush()
if err != nil {
return err
}
return nil
}
func (o *ObjectScanner) writePacketText(data string) error {
//TODO: there is probably a more efficient way to do this. worth it?
return o.writePacket([]byte(data + "\n"))
}
func (o *ObjectScanner) writePacketList(list []string) error {
for _, i := range list {
err := o.writePacketText(i)
if err != nil {
return err
}
}
return o.writeFlush()
}
func (o *ObjectScanner) writeStatus(status string) error {
return o.writePacketList([]string{"status=" + status})
}
func (o *ObjectScanner) Init() bool {
tracerx.Printf("Initialize filter")
reqVer := "version=2"
initMsg, err := o.readPacketText()
initMsg, err := o.p.readPacketText()
if err != nil {
fmt.Fprintf(os.Stderr,
"Error: reading filter initialization failed with %s\n", err)
@ -145,7 +52,7 @@ func (o *ObjectScanner) Init() bool {
return false
}
supVers, err := o.readPacketList()
supVers, err := o.p.readPacketList()
if err != nil {
fmt.Fprintf(os.Stderr,
"Error: reading filter versions failed with %s\n", err)
@ -158,7 +65,7 @@ func (o *ObjectScanner) Init() bool {
return false
}
err = o.writePacketList([]string{"git-filter-server", reqVer})
err = o.p.writePacketList([]string{"git-filter-server", reqVer})
if err != nil {
fmt.Fprintf(os.Stderr,
"Error: writing filter initialization failed with %s\n", err)
@ -170,7 +77,7 @@ func (o *ObjectScanner) Init() bool {
func (o *ObjectScanner) NegotiateCapabilities() bool {
reqCaps := []string{"capability=clean", "capability=smudge"}
supCaps, err := o.readPacketList()
supCaps, err := o.p.readPacketList()
if err != nil {
fmt.Fprintf(os.Stderr,
"Error: reading filter capabilities failed with %s\n", err)
@ -185,7 +92,7 @@ func (o *ObjectScanner) NegotiateCapabilities() bool {
}
}
err = o.writePacketList(reqCaps)
err = o.p.writePacketList(reqCaps)
if err != nil {
fmt.Fprintf(os.Stderr,
"Error: writing filter capabilities failed with %s\n", err)
@ -198,7 +105,7 @@ func (o *ObjectScanner) NegotiateCapabilities() bool {
func (o *ObjectScanner) ReadRequest() (map[string]string, []byte, error) {
tracerx.Printf("Process filter command.")
requestList, err := o.readPacketList()
requestList, err := o.p.readPacketList()
if err != nil {
return nil, nil, err
}
@ -211,7 +118,7 @@ func (o *ObjectScanner) ReadRequest() (map[string]string, []byte, error) {
var data []byte
for {
chunk, err := o.readPacket()
chunk, err := o.p.readPacket()
if err != nil {
// TODO: should we check the err of this call, to?!
o.writeStatus("error")
@ -230,12 +137,12 @@ func (o *ObjectScanner) WriteResponse(outputData []byte) error {
for {
chunkSize := len(outputData)
if chunkSize == 0 {
o.writeFlush()
o.p.writeFlush()
break
} else if chunkSize > MaxPacketLenght {
chunkSize = MaxPacketLenght // TODO check packets with the exact size
} else if chunkSize > MaxPacketLength {
chunkSize = MaxPacketLength // TODO check packets with the exact size
}
err := o.writePacket(outputData[:chunkSize])
err := o.p.writePacket(outputData[:chunkSize])
if err != nil {
// TODO: should we check the err of this call, to?!
o.writeStatus("error")
@ -246,3 +153,7 @@ func (o *ObjectScanner) WriteResponse(outputData []byte) error {
o.writeStatus("success")
return nil
}
func (o *ObjectScanner) writeStatus(status string) error {
return o.p.writePacketList([]string{"status=" + status})
}