diff --git a/pkg/gui/gui.go b/pkg/gui/gui.go index 6ab2a916b..1d14c9f50 100644 --- a/pkg/gui/gui.go +++ b/pkg/gui/gui.go @@ -185,7 +185,7 @@ type GuiRepoState struct { // WindowViewNameMap is a mapping of windows to the current view of that window. // Some views move between windows for example the commitFiles view and when cycling through // side windows we need to know which view to give focus to for a given window - WindowViewNameMap map[string]string + WindowViewNameMap *utils.ThreadSafeMap[string, string] // tells us whether we've set up our views for the current repo. We'll need to // do this whenever we switch back and forth between repos to get the views diff --git a/pkg/gui/window.go b/pkg/gui/window.go index 12cd31868..1d69c420c 100644 --- a/pkg/gui/window.go +++ b/pkg/gui/window.go @@ -6,6 +6,7 @@ import ( "github.com/jesseduffield/gocui" "github.com/jesseduffield/lazygit/pkg/gui/context" "github.com/jesseduffield/lazygit/pkg/gui/types" + "github.com/jesseduffield/lazygit/pkg/utils" "github.com/samber/lo" ) @@ -15,18 +16,18 @@ import ( // space. Right now most windows are 1:1 with views, except for commitFiles which // is a view that moves between windows -func (gui *Gui) initialWindowViewNameMap(contextTree *context.ContextTree) map[string]string { - result := map[string]string{} +func (gui *Gui) initialWindowViewNameMap(contextTree *context.ContextTree) *utils.ThreadSafeMap[string, string] { + result := utils.NewThreadSafeMap[string, string]() for _, context := range contextTree.Flatten() { - result[context.GetWindowName()] = context.GetViewName() + result.Set(context.GetWindowName(), context.GetViewName()) } return result } func (gui *Gui) getViewNameForWindow(window string) string { - viewName, ok := gui.State.WindowViewNameMap[window] + viewName, ok := gui.State.WindowViewNameMap.Get(window) if !ok { panic(fmt.Sprintf("Viewname not found for window: %s", window)) } @@ -51,7 +52,7 @@ func (gui *Gui) setWindowContext(c types.Context) { gui.resetWindowContext(c) } - gui.State.WindowViewNameMap[c.GetWindowName()] = c.GetViewName() + gui.State.WindowViewNameMap.Set(c.GetWindowName(), c.GetViewName()) } func (gui *Gui) currentWindow() string { @@ -60,11 +61,15 @@ func (gui *Gui) currentWindow() string { // assumes the context's windowName has been set to the new window if necessary func (gui *Gui) resetWindowContext(c types.Context) { - for windowName, viewName := range gui.State.WindowViewNameMap { + for _, windowName := range gui.State.WindowViewNameMap.Keys() { + viewName, ok := gui.State.WindowViewNameMap.Get(windowName) + if !ok { + continue + } if viewName == c.GetViewName() && windowName != c.GetWindowName() { for _, context := range gui.State.Contexts.Flatten() { if context.GetKey() != c.GetKey() && context.GetWindowName() == windowName { - gui.State.WindowViewNameMap[windowName] = context.GetViewName() + gui.State.WindowViewNameMap.Set(windowName, context.GetViewName()) } } } diff --git a/pkg/utils/thread_safe_map.go b/pkg/utils/thread_safe_map.go new file mode 100644 index 000000000..a70cefc7d --- /dev/null +++ b/pkg/utils/thread_safe_map.go @@ -0,0 +1,90 @@ +package utils + +import "sync" + +type ThreadSafeMap[K comparable, V any] struct { + mutex sync.RWMutex + + innerMap map[K]V +} + +func NewThreadSafeMap[K comparable, V any]() *ThreadSafeMap[K, V] { + return &ThreadSafeMap[K, V]{ + innerMap: make(map[K]V), + } +} + +func (m *ThreadSafeMap[K, V]) Get(key K) (V, bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + + value, ok := m.innerMap[key] + return value, ok +} + +func (m *ThreadSafeMap[K, V]) Set(key K, value V) { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.innerMap[key] = value +} + +func (m *ThreadSafeMap[K, V]) Delete(key K) { + m.mutex.Lock() + defer m.mutex.Unlock() + + delete(m.innerMap, key) +} + +func (m *ThreadSafeMap[K, V]) Keys() []K { + m.mutex.RLock() + defer m.mutex.RUnlock() + + keys := make([]K, 0, len(m.innerMap)) + for key := range m.innerMap { + keys = append(keys, key) + } + + return keys +} + +func (m *ThreadSafeMap[K, V]) Values() []V { + m.mutex.RLock() + defer m.mutex.RUnlock() + + values := make([]V, 0, len(m.innerMap)) + for _, value := range m.innerMap { + values = append(values, value) + } + + return values +} + +func (m *ThreadSafeMap[K, V]) Len() int { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return len(m.innerMap) +} + +func (m *ThreadSafeMap[K, V]) Clear() { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.innerMap = make(map[K]V) +} + +func (m *ThreadSafeMap[K, V]) IsEmpty() bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + return len(m.innerMap) == 0 +} + +func (m *ThreadSafeMap[K, V]) Has(key K) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + _, ok := m.innerMap[key] + return ok +} diff --git a/pkg/utils/thread_safe_map_test.go b/pkg/utils/thread_safe_map_test.go new file mode 100644 index 000000000..9676cfe5f --- /dev/null +++ b/pkg/utils/thread_safe_map_test.go @@ -0,0 +1,59 @@ +package utils + +import ( + "testing" +) + +func TestThreadSafeMap(t *testing.T) { + m := NewThreadSafeMap[int, int]() + + m.Set(1, 1) + m.Set(2, 2) + m.Set(3, 3) + + if m.Len() != 3 { + t.Errorf("Expected length to be 3, got %d", m.Len()) + } + + if !m.Has(1) { + t.Errorf("Expected to have key 1") + } + + if m.Has(4) { + t.Errorf("Expected to not have key 4") + } + + if _, ok := m.Get(1); !ok { + t.Errorf("Expected to have key 1") + } + + if _, ok := m.Get(4); ok { + t.Errorf("Expected to not have key 4") + } + + m.Delete(1) + + if m.Has(1) { + t.Errorf("Expected to not have key 1") + } + + m.Clear() + + if m.Len() != 0 { + t.Errorf("Expected length to be 0, got %d", m.Len()) + } +} + +func TestThreadSafeMapConcurrentReadWrite(t *testing.T) { + m := NewThreadSafeMap[int, int]() + + go func() { + for i := 0; i < 10000; i++ { + m.Set(0, 0) + } + }() + + for i := 0; i < 10000; i++ { + m.Get(0) + } +}