From 0162e10c53a50a64316ff94c557b513a76e42819 Mon Sep 17 00:00:00 2001 From: Jacky Date: Mon, 5 May 2025 01:36:39 +0000 Subject: [PATCH] fix: pass context to cert and cache --- api/system/install.go | 2 +- internal/cache/cache.go | 5 ++-- internal/cache/index.go | 47 ++++++++++++++++++++++++--------- internal/cert/mutex.go | 36 ++++++++++++++++--------- internal/cert/register.go | 2 ++ internal/kernel/boot.go | 4 ++- internal/nginx_log/log_cache.go | 7 +---- 7 files changed, 68 insertions(+), 35 deletions(-) diff --git a/api/system/install.go b/api/system/install.go index 0c348b84..475e51bf 100644 --- a/api/system/install.go +++ b/api/system/install.go @@ -25,7 +25,7 @@ func init() { } func installLockStatus() bool { - return settings.NodeSettings.SkipInstallation || "" != cSettings.AppSettings.JwtSecret + return settings.NodeSettings.SkipInstallation || cSettings.AppSettings.JwtSecret != "" } // Check if installation time limit (10 minutes) is exceeded diff --git a/internal/cache/cache.go b/internal/cache/cache.go index 7baca34c..8748b0dc 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -1,6 +1,7 @@ package cache import ( + "context" "time" "github.com/dgraph-io/ristretto/v2" @@ -9,7 +10,7 @@ import ( var cache *ristretto.Cache[string, any] -func Init() { +func Init(ctx context.Context) { var err error cache, err = ristretto.NewCache(&ristretto.Config[string, any]{ NumCounters: 1e7, // number of keys to track frequency of (10M). @@ -22,7 +23,7 @@ func Init() { } // Initialize the config scanner - InitScanner() + InitScanner(ctx) } func Set(key string, value interface{}, ttl time.Duration) { diff --git a/internal/cache/index.go b/internal/cache/index.go index 4fb6b2e9..b4a5847c 100644 --- a/internal/cache/index.go +++ b/internal/cache/index.go @@ -1,6 +1,7 @@ package cache import ( + "context" "os" "path/filepath" "regexp" @@ -19,6 +20,7 @@ type ScanCallback func(configPath string, content []byte) error // Scanner is responsible for scanning and watching nginx config files type Scanner struct { + ctx context.Context // Context for the scanner watcher *fsnotify.Watcher // File system watcher scanTicker *time.Ticker // Ticker for periodic scanning initialized bool // Whether the scanner has been initialized @@ -39,24 +41,19 @@ var ( includeRegex = regexp.MustCompile(`include\s+([^;]+);`) // Global callbacks that will be executed during config file scanning - scanCallbacks []ScanCallback + scanCallbacks = make([]ScanCallback, 0) scanCallbacksMutex sync.RWMutex ) -func init() { - // Initialize the callbacks slice - scanCallbacks = make([]ScanCallback, 0) -} - // InitScanner initializes the config scanner -func InitScanner() { +func InitScanner(ctx context.Context) { if nginx.GetConfPath() == "" { logger.Error("Nginx config path is not set") return } s := GetScanner() - err := s.Initialize() + err := s.Initialize(ctx) if err != nil { logger.Error("Failed to initialize config scanner:", err) } @@ -140,7 +137,7 @@ func UnsubscribeScanningStatus(ch chan bool) { } // Initialize sets up the scanner and starts watching for file changes -func (s *Scanner) Initialize() error { +func (s *Scanner) Initialize(ctx context.Context) error { if s.initialized { return nil } @@ -151,6 +148,7 @@ func (s *Scanner) Initialize() error { return err } s.watcher = watcher + s.ctx = ctx // Scan for the first time err = s.ScanAllConfigs() @@ -207,14 +205,26 @@ func (s *Scanner) Initialize() error { // Setup a ticker for periodic scanning (every 5 minutes) s.scanTicker = time.NewTicker(5 * time.Minute) go func() { - for range s.scanTicker.C { - err := s.ScanAllConfigs() - if err != nil { - logger.Error("Periodic config scan failed:", err) + for { + select { + case <-s.ctx.Done(): + return + case <-s.scanTicker.C: + err := s.ScanAllConfigs() + if err != nil { + logger.Error("Periodic config scan failed:", err) + } } } }() + // Start a goroutine to listen for context cancellation + go func() { + <-s.ctx.Done() + logger.Debug("Context cancelled, shutting down scanner") + s.Shutdown() + }() + s.initialized = true return nil } @@ -223,6 +233,8 @@ func (s *Scanner) Initialize() error { func (s *Scanner) watchForChanges() { for { select { + case <-s.ctx.Done(): + return case event, ok := <-s.watcher.Events: if !ok { return @@ -471,3 +483,12 @@ func IsScanningInProgress() bool { defer s.scanMutex.RUnlock() return s.scanning } + +// WithContext sets a context for the scanner that will be used to control its lifecycle +func (s *Scanner) WithContext(ctx context.Context) *Scanner { + // Create a context with cancel if not already done in Initialize + if s.ctx == nil { + s.ctx = ctx + } + return s +} diff --git a/internal/cert/mutex.go b/internal/cert/mutex.go index ecd68e9c..d24d78d5 100644 --- a/internal/cert/mutex.go +++ b/internal/cert/mutex.go @@ -1,6 +1,7 @@ package cert import ( + "context" "sync" ) @@ -24,28 +25,39 @@ var ( processingMutex sync.RWMutex ) -func init() { +func initBroadcastStatus(ctx context.Context) { // Initialize channels and maps statusChan = make(chan bool, 10) // Buffer to prevent blocking subscribers = make(map[chan bool]struct{}) // Start broadcasting goroutine - go broadcastStatus() + go broadcastStatus(ctx) } // broadcastStatus listens for status changes and broadcasts to all subscribers -func broadcastStatus() { - for status := range statusChan { - subscriberMux.RLock() - for ch := range subscribers { - // Non-blocking send to prevent slow subscribers from blocking others - select { - case ch <- status: - default: - // Skip if channel buffer is full +func broadcastStatus(ctx context.Context) { + for { + select { + case <-ctx.Done(): + // Context cancelled, clean up resources and exit + close(statusChan) + return + case status, ok := <-statusChan: + if !ok { + // Channel closed, exit + return } + subscriberMux.RLock() + for ch := range subscribers { + // Non-blocking send to prevent slow subscribers from blocking others + select { + case ch <- status: + default: + // Skip if channel buffer is full + } + } + subscriberMux.RUnlock() } - subscriberMux.RUnlock() } } diff --git a/internal/cert/register.go b/internal/cert/register.go index 1eb61436..692b03a0 100644 --- a/internal/cert/register.go +++ b/internal/cert/register.go @@ -52,6 +52,8 @@ func InitRegister(ctx context.Context) { } logger.Info("ACME Default User registered") + + initBroadcastStatus(ctx) } func GetDefaultACMEUser() (user *model.AcmeUser, err error) { diff --git a/internal/kernel/boot.go b/internal/kernel/boot.go index 0f3c212a..96a994e4 100644 --- a/internal/kernel/boot.go +++ b/internal/kernel/boot.go @@ -38,7 +38,9 @@ func Boot(ctx context.Context) { InitNodeSecret, InitCryptoSecret, validation.Init, - cache.Init, + func() { + cache.Init(ctx) + }, CheckAndCleanupOTAContainers, } diff --git a/internal/nginx_log/log_cache.go b/internal/nginx_log/log_cache.go index c98ea732..4d05dab5 100644 --- a/internal/nginx_log/log_cache.go +++ b/internal/nginx_log/log_cache.go @@ -13,15 +13,10 @@ type NginxLogCache struct { var ( // logCache is the map to store all found log files - logCache map[string]*NginxLogCache + logCache = make(map[string]*NginxLogCache) cacheMutex sync.RWMutex ) -func init() { - // Initialize the cache - logCache = make(map[string]*NginxLogCache) -} - // AddLogPath adds a log path to the log cache func AddLogPath(path, logType, name string) { cacheMutex.Lock()