diff --git a/commands/command_push.go b/commands/command_push.go index 5699a997..670742f9 100644 --- a/commands/command_push.go +++ b/commands/command_push.go @@ -23,25 +23,42 @@ var ( ) func pushCommand(cmd *cobra.Command, args []string) { - refsData, err := ioutil.ReadAll(os.Stdin) - if err != nil { - Panic(err, "Error reading refs on stdin") - } + var left, right string - if len(refsData) == 0 { - return - } + if dryRun { + if len(args) != 2 { + Print("Usage: git media push --dry-run ") + return + } + + ref, err := gitmedia.CurrentRef() + if err != nil { + Panic(err, "Error getting current ref") + } + left = ref + right = fmt.Sprintf("^%s/%s", args[0], args[1]) + } else { + refsData, err := ioutil.ReadAll(os.Stdin) + if err != nil { + Panic(err, "Error reading refs on stdin") + } + + if len(refsData) == 0 { + return + } + + left, right = decodeRefs(string(refsData)) + if left == deleteBranch { + return + } - left, right := decodeRefs(string(refsData)) - if left == deleteBranch { - return } links := linksFromRefs(left, right) for i, link := range links { if dryRun { - fmt.Println("push", link.Oid, link.Name) + Print("push %s", link.Name) continue } if wErr := pushAsset(link.Oid, link.Name, i+1, len(links)); wErr != nil { diff --git a/gitmedia/gitmedia.go b/gitmedia/gitmedia.go index 0b5bfb33..447391e0 100644 --- a/gitmedia/gitmedia.go +++ b/gitmedia/gitmedia.go @@ -1,11 +1,13 @@ package gitmedia import ( + "errors" "fmt" "github.com/github/git-media/git" "io/ioutil" "os" "path/filepath" + "regexp" "runtime" "strings" ) @@ -83,6 +85,33 @@ func InRepo() bool { return LocalWorkingDir != "" } +var shaMatcher = regexp.MustCompile(`^[0-9a-f]{40}`) + +func CurrentRef() (string, error) { + head, err := ioutil.ReadFile(filepath.Join(LocalGitDir, "HEAD")) + if err != nil { + return "", err + } + + if shaMatcher.Match(head) { + return strings.TrimSpace(string(head)), nil + } + + headString := string(head) + parts := strings.Split(headString, " ") + if len(parts) != 2 { + return "", errors.New("Unable to parse HEAD") + } + + refFile := strings.TrimSpace(parts[1]) + sha, err := ioutil.ReadFile(filepath.Join(LocalGitDir, refFile)) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(sha)), nil +} + func init() { var err error LocalWorkingDir, LocalGitDir, err = resolveGitDir()