diff --git a/tasklog/percentage_task.go b/tasklog/percentage_task.go index 4e4d6a04..4b8d8f3a 100644 --- a/tasklog/percentage_task.go +++ b/tasklog/percentage_task.go @@ -78,6 +78,15 @@ func (c *PercentageTask) Entry(update string) { } } +// Complete notes that the task is completed by setting the number of +// completed elements to the total number of elements, and if necessary +// closing the Updates channel, which yields the logger to the next Task. +func (c *PercentageTask) Complete() { + if count := atomic.SwapUint64(&c.n, c.total); count < c.total { + close(c.ch) + } +} + // Updates implements Task.Updates and returns a channel which is written to // when the state of this task changes, and closed when the task is completed. func (c *PercentageTask) Updates() <-chan *Update { diff --git a/tasklog/percentage_task_test.go b/tasklog/percentage_task_test.go index 3159cd57..0d562dd0 100644 --- a/tasklog/percentage_task_test.go +++ b/tasklog/percentage_task_test.go @@ -50,6 +50,37 @@ func TestPercentageTaskCallsDoneWhenComplete(t *testing.T) { if _, ok := <-task.Updates(); ok { t.Fatalf("expected channel to be closed") } + + defer func() { + if err := recover(); err != nil { + t.Fatal("tasklog: expected *PercentageTask.Complete() to not panic") + } + }() + + task.Complete() +} + +func TestPercentageTaskCompleteClosesUpdates(t *testing.T) { + task := NewPercentageTask("example", 10) + + select { + case v, ok := <-task.Updates(): + if ok { + assert.Equal(t, "example: 0% (0/10)", v.S) + } else { + t.Fatal("expected channel to be open") + } + default: + } + + assert.EqualValues(t, 7, task.Count(7)) + assert.Equal(t, "example: 70% (7/10)", (<-task.Updates()).S) + + task.Complete() + + if _, ok := <-task.Updates(); ok { + t.Fatalf("expected channel to be closed") + } } func TestPercentageTaskIsThrottled(t *testing.T) {