diff --git a/pkg/tasks/tasks.go b/pkg/tasks/tasks.go index 4d9fe67e3..4a987039c 100644 --- a/pkg/tasks/tasks.go +++ b/pkg/tasks/tasks.go @@ -19,18 +19,14 @@ const THROTTLE_TIME = time.Millisecond * 30 // we use this to check if the system is under stress right now. Hopefully this makes sense on other machines const COMMAND_START_THRESHOLD = time.Millisecond * 10 -type Task struct { - stop chan struct{} - stopped bool - stopMutex sync.Mutex - notifyStopped chan struct{} - Log *logrus.Entry - f func(chan struct{}) error -} - type ViewBufferManager struct { - writer io.Writer - currentTask *Task + // this blocks until the task has been properly stopped + stopCurrentTask func() + + // this is what we write the output of the task to. It's typically a view + writer io.Writer + + // this is for when we wait to get waitingMutex sync.Mutex taskIDMutex sync.Mutex Log *logrus.Entry @@ -80,8 +76,15 @@ func (m *ViewBufferManager) ReadLines(n int) { }) } +// note: onDone may be called twice func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), prefix string, linesToRead int, onDone func()) func(chan struct{}) error { return func(stop chan struct{}) error { + var once sync.Once + var onDoneWrapper func() + if onDone != nil { + onDoneWrapper = func() { once.Do(onDone) } + } + if m.throttle { m.Log.Info("throttling task") time.Sleep(THROTTLE_TIME) @@ -110,8 +113,9 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref } } - if onDone != nil { - onDone() + // for pty's we need to call onDone here so that cmd.Wait() doesn't block forever + if onDoneWrapper != nil { + onDoneWrapper() } }) @@ -122,34 +126,42 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref done := make(chan struct{}) + scanner := bufio.NewScanner(r) + scanner.Split(bufio.ScanLines) + + loaded := false + go utils.Safe(func() { - scanner := bufio.NewScanner(r) - scanner.Split(bufio.ScanLines) - - loaded := false - - go utils.Safe(func() { - ticker := time.NewTicker(time.Millisecond * 200) - defer ticker.Stop() - select { - case <-ticker.C: - loadingMutex.Lock() - if !loaded { - m.beforeStart() - _, _ = m.writer.Write([]byte("loading...")) - m.refreshView() - } - loadingMutex.Unlock() - case <-stop: - return + ticker := time.NewTicker(time.Millisecond * 200) + defer ticker.Stop() + select { + case <-stop: + return + case <-ticker.C: + loadingMutex.Lock() + if !loaded { + m.beforeStart() + _, _ = m.writer.Write([]byte("loading...")) + m.refreshView() } - }) + loadingMutex.Unlock() + } + }) + go utils.Safe(func() { outer: for { select { + case <-stop: + break outer case linesToRead := <-m.readLines: for i := 0; i < linesToRead; i++ { + select { + case <-stop: + break outer + default: + } + ok := scanner.Scan() loadingMutex.Lock() if !loaded { @@ -161,11 +173,6 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref } loadingMutex.Unlock() - select { - case <-stop: - break outer - default: - } if !ok { // if we're here then there's nothing left to scan from the source // so we're at the EOF and can flush the stale content @@ -175,8 +182,6 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref _, _ = m.writer.Write(append(scanner.Bytes(), '\n')) } m.refreshView() - case <-stop: - break outer } } @@ -189,8 +194,9 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref } } - if onDone != nil { - onDone() + // calling onDoneWrapper here again in case the program ended on its own accord + if onDoneWrapper != nil { + onDoneWrapper() } close(done) @@ -206,14 +212,14 @@ func (m *ViewBufferManager) NewCmdTask(start func() (*exec.Cmd, io.Reader), pref // Close closes the task manager, killing whatever task may currently be running func (t *ViewBufferManager) Close() { - if t.currentTask == nil { + if t.stopCurrentTask == nil { return } c := make(chan struct{}) go utils.Safe(func() { - t.currentTask.Stop() + t.stopCurrentTask() c <- struct{}{} }) @@ -249,19 +255,20 @@ func (m *ViewBufferManager) NewTask(f func(stop chan struct{}) error, key string return } + if m.stopCurrentTask != nil { + m.stopCurrentTask() + } + stop := make(chan struct{}) notifyStopped := make(chan struct{}) - if m.currentTask != nil { - m.currentTask.Stop() + var once sync.Once + onStop := func() { + close(stop) + <-notifyStopped } - m.currentTask = &Task{ - stop: stop, - notifyStopped: notifyStopped, - Log: m.Log, - f: f, - } + m.stopCurrentTask = func() { once.Do(onStop) } go utils.Safe(func() { if err := f(stop); err != nil { @@ -274,14 +281,3 @@ func (m *ViewBufferManager) NewTask(f func(stop chan struct{}) error, key string return nil } - -func (t *Task) Stop() { - t.stopMutex.Lock() - defer t.stopMutex.Unlock() - if t.stopped { - return - } - close(t.stop) - <-t.notifyStopped - t.stopped = true -} diff --git a/pkg/tasks/tasks_test.go b/pkg/tasks/tasks_test.go new file mode 100644 index 000000000..d580c95f5 --- /dev/null +++ b/pkg/tasks/tasks_test.go @@ -0,0 +1,136 @@ +package tasks + +import ( + "bytes" + "io" + "os/exec" + "sync" + "testing" + "time" + + "github.com/jesseduffield/lazygit/pkg/secureexec" + "github.com/jesseduffield/lazygit/pkg/utils" +) + +func getCounter() (func(), func() int) { + counter := 0 + return func() { counter++ }, func() int { return counter } +} + +func TestNewCmdTaskInstantStop(t *testing.T) { + writer := bytes.NewBuffer(nil) + beforeStart, getBeforeStartCallCount := getCounter() + refreshView, getRefreshViewCallCount := getCounter() + onEndOfInput, getOnEndOfInputCallCount := getCounter() + onNewKey, getOnNewKeyCallCount := getCounter() + onDone, getOnDoneCallCount := getCounter() + + manager := NewViewBufferManager( + utils.NewDummyLog(), + writer, + beforeStart, + refreshView, + onEndOfInput, + onNewKey, + ) + + stop := make(chan struct{}) + reader := bytes.NewBufferString("test") + start := func() (*exec.Cmd, io.Reader) { + // not actually starting this because it's not necessary + cmd := secureexec.Command("blah blah") + + close(stop) + + return cmd, reader + } + + fn := manager.NewCmdTask(start, "prefix\n", 20, onDone) + + _ = fn(stop) + + callCountExpectations := []struct { + expected int + actual int + name string + }{ + {0, getBeforeStartCallCount(), "beforeStart"}, + {1, getRefreshViewCallCount(), "refreshView"}, + {0, getOnEndOfInputCallCount(), "onEndOfInput"}, + {0, getOnNewKeyCallCount(), "onNewKey"}, + {1, getOnDoneCallCount(), "onDone"}, + } + for _, expectation := range callCountExpectations { + if expectation.actual != expectation.expected { + t.Errorf("expected %s to be called %d times, got %d", expectation.name, expectation.expected, expectation.actual) + } + } + + expectedContent := "" + actualContent := writer.String() + if actualContent != expectedContent { + t.Errorf("expected writer to receive the following content: \n%s\n. But instead it recevied: %s", expectedContent, actualContent) + } +} + +func TestNewCmdTask(t *testing.T) { + writer := bytes.NewBuffer(nil) + beforeStart, getBeforeStartCallCount := getCounter() + refreshView, getRefreshViewCallCount := getCounter() + onEndOfInput, getOnEndOfInputCallCount := getCounter() + onNewKey, getOnNewKeyCallCount := getCounter() + onDone, getOnDoneCallCount := getCounter() + + manager := NewViewBufferManager( + utils.NewDummyLog(), + writer, + beforeStart, + refreshView, + onEndOfInput, + onNewKey, + ) + + stop := make(chan struct{}) + reader := bytes.NewBufferString("test") + start := func() (*exec.Cmd, io.Reader) { + // not actually starting this because it's not necessary + cmd := secureexec.Command("blah blah") + + return cmd, reader + } + + fn := manager.NewCmdTask(start, "prefix\n", 20, onDone) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + time.Sleep(100 * time.Millisecond) + close(stop) + wg.Done() + }() + _ = fn(stop) + + wg.Wait() + + callCountExpectations := []struct { + expected int + actual int + name string + }{ + {1, getBeforeStartCallCount(), "beforeStart"}, + {1, getRefreshViewCallCount(), "refreshView"}, + {1, getOnEndOfInputCallCount(), "onEndOfInput"}, + {0, getOnNewKeyCallCount(), "onNewKey"}, + {1, getOnDoneCallCount(), "onDone"}, + } + for _, expectation := range callCountExpectations { + if expectation.actual != expectation.expected { + t.Errorf("expected %s to be called %d times, got %d", expectation.name, expectation.expected, expectation.actual) + } + } + + expectedContent := "prefix\ntest\n" + actualContent := writer.String() + if actualContent != expectedContent { + t.Errorf("expected writer to receive the following content: \n%s\n. But instead it recevied: %s", expectedContent, actualContent) + } +}