fix: pass context to cert and cache

Jacky 2025-05-05 01:36:39 +00:00
parent 910a2efbd9
commit 0162e10c53
7 changed files with 68 additions and 35 deletions

View file

@@ -25,7 +25,7 @@ func init() {
 }
 
 func installLockStatus() bool {
-	return settings.NodeSettings.SkipInstallation || "" != cSettings.AppSettings.JwtSecret
+	return settings.NodeSettings.SkipInstallation || cSettings.AppSettings.JwtSecret != ""
 }
 
 // Check if installation time limit (10 minutes) is exceeded

View file

@ -1,6 +1,7 @@
package cache package cache
import ( import (
"context"
"time" "time"
"github.com/dgraph-io/ristretto/v2" "github.com/dgraph-io/ristretto/v2"
@@ -9,7 +10,7 @@ import (
 
 var cache *ristretto.Cache[string, any]
 
-func Init() {
+func Init(ctx context.Context) {
 	var err error
 	cache, err = ristretto.NewCache(&ristretto.Config[string, any]{
 		NumCounters: 1e7, // number of keys to track frequency of (10M).
@@ -22,7 +23,7 @@ func Init() {
 	}
 
 	// Initialize the config scanner
-	InitScanner()
+	InitScanner(ctx)
 }
 
 func Set(key string, value interface{}, ttl time.Duration) {
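Init now requires a caller-supplied context, so the caller owns the cache and scanner lifecycle. A minimal sketch of wiring it up from an entry point with a signal-cancelled context; only cache.Init(ctx) comes from this diff, the import path and surrounding wiring are assumptions:

package main

import (
	"context"
	"os/signal"
	"syscall"

	"github.com/0xJacky/Nginx-UI/internal/cache" // assumed import path
)

func main() {
	// Cancel the context on SIGINT/SIGTERM so the scanner goroutines exit cleanly.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	cache.Init(ctx) // starts the config scanner bound to ctx

	<-ctx.Done() // block until shutdown is requested
}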

View file

@@ -1,6 +1,7 @@
 package cache
 
 import (
+	"context"
 	"os"
 	"path/filepath"
 	"regexp"
@@ -19,6 +20,7 @@ type ScanCallback func(configPath string, content []byte) error
 
 // Scanner is responsible for scanning and watching nginx config files
 type Scanner struct {
+	ctx         context.Context   // Context for the scanner
 	watcher     *fsnotify.Watcher // File system watcher
 	scanTicker  *time.Ticker      // Ticker for periodic scanning
 	initialized bool              // Whether the scanner has been initialized
@@ -39,24 +41,19 @@ var (
 	includeRegex = regexp.MustCompile(`include\s+([^;]+);`)
 
 	// Global callbacks that will be executed during config file scanning
-	scanCallbacks      []ScanCallback
+	scanCallbacks      = make([]ScanCallback, 0)
 	scanCallbacksMutex sync.RWMutex
 )
 
-func init() {
-	// Initialize the callbacks slice
-	scanCallbacks = make([]ScanCallback, 0)
-}
-
 // InitScanner initializes the config scanner
-func InitScanner() {
+func InitScanner(ctx context.Context) {
 	if nginx.GetConfPath() == "" {
 		logger.Error("Nginx config path is not set")
 		return
 	}
 
 	s := GetScanner()
-	err := s.Initialize()
+	err := s.Initialize(ctx)
 	if err != nil {
 		logger.Error("Failed to initialize config scanner:", err)
 	}
@@ -140,7 +137,7 @@ func UnsubscribeScanningStatus(ch chan bool) {
 }
 
 // Initialize sets up the scanner and starts watching for file changes
-func (s *Scanner) Initialize() error {
+func (s *Scanner) Initialize(ctx context.Context) error {
 	if s.initialized {
 		return nil
 	}
@@ -151,6 +148,7 @@ func (s *Scanner) Initialize() error {
 		return err
 	}
 	s.watcher = watcher
+	s.ctx = ctx
 
 	// Scan for the first time
 	err = s.ScanAllConfigs()
@@ -207,14 +205,26 @@ func (s *Scanner) Initialize() error {
 	// Setup a ticker for periodic scanning (every 5 minutes)
 	s.scanTicker = time.NewTicker(5 * time.Minute)
 	go func() {
-		for range s.scanTicker.C {
-			err := s.ScanAllConfigs()
-			if err != nil {
-				logger.Error("Periodic config scan failed:", err)
+		for {
+			select {
+			case <-s.ctx.Done():
+				return
+			case <-s.scanTicker.C:
+				err := s.ScanAllConfigs()
+				if err != nil {
+					logger.Error("Periodic config scan failed:", err)
+				}
 			}
 		}
 	}()
 
+	// Start a goroutine to listen for context cancellation
+	go func() {
+		<-s.ctx.Done()
+		logger.Debug("Context cancelled, shutting down scanner")
+		s.Shutdown()
+	}()
+
 	s.initialized = true
 	return nil
 }
@@ -223,6 +233,8 @@ func (s *Scanner) Initialize() error {
 func (s *Scanner) watchForChanges() {
 	for {
 		select {
+		case <-s.ctx.Done():
+			return
 		case event, ok := <-s.watcher.Events:
 			if !ok {
 				return
@@ -471,3 +483,12 @@ func IsScanningInProgress() bool {
 	defer s.scanMutex.RUnlock()
 	return s.scanning
 }
+
+// WithContext sets a context for the scanner that will be used to control its lifecycle
+func (s *Scanner) WithContext(ctx context.Context) *Scanner {
+	// Create a context with cancel if not already done in Initialize
+	if s.ctx == nil {
+		s.ctx = ctx
+	}
+	return s
+}
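The ticker loop and the cancellation goroutine above follow the usual select-on-Done shape. A self-contained sketch of the same pattern outside the scanner (names and timings illustrative, not project code):

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	ticker := time.NewTicker(200 * time.Millisecond)
	defer ticker.Stop()

	go func() {
		for {
			select {
			case <-ctx.Done():
				fmt.Println("context cancelled, stopping periodic work")
				return
			case <-ticker.C:
				fmt.Println("periodic scan tick")
			}
		}
	}()

	time.Sleep(time.Second)
	cancel() // the goroutine observes Done() and exits
	time.Sleep(100 * time.Millisecond)
}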

View file

@@ -1,6 +1,7 @@
 package cert
 
 import (
+	"context"
 	"sync"
 )
@@ -24,28 +25,39 @@ var (
 	processingMutex sync.RWMutex
 )
 
-func init() {
+func initBroadcastStatus(ctx context.Context) {
 	// Initialize channels and maps
 	statusChan = make(chan bool, 10) // Buffer to prevent blocking
 	subscribers = make(map[chan bool]struct{})
 
 	// Start broadcasting goroutine
-	go broadcastStatus()
+	go broadcastStatus(ctx)
 }
 
 // broadcastStatus listens for status changes and broadcasts to all subscribers
-func broadcastStatus() {
-	for status := range statusChan {
-		subscriberMux.RLock()
-		for ch := range subscribers {
-			// Non-blocking send to prevent slow subscribers from blocking others
-			select {
-			case ch <- status:
-			default:
-				// Skip if channel buffer is full
-			}
-		}
-		subscriberMux.RUnlock()
-	}
-}
+func broadcastStatus(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			// Context cancelled, clean up resources and exit
+			close(statusChan)
+			return
+		case status, ok := <-statusChan:
+			if !ok {
+				// Channel closed, exit
+				return
+			}
+			subscriberMux.RLock()
+			for ch := range subscribers {
+				// Non-blocking send to prevent slow subscribers from blocking others
+				select {
+				case ch <- status:
+				default:
+					// Skip if channel buffer is full
+				}
+			}
+			subscriberMux.RUnlock()
+		}
+	}
+}
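For orientation, a stripped-down, standalone model of this broadcast pattern with one subscriber, showing the non-blocking fan-out and the context-driven shutdown; the driver code is illustrative only and not part of the cert package:

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

var (
	statusChan    = make(chan bool, 10)
	subscribers   = make(map[chan bool]struct{})
	subscriberMux sync.RWMutex
)

func broadcastStatus(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			close(statusChan)
			return
		case status, ok := <-statusChan:
			if !ok {
				return
			}
			subscriberMux.RLock()
			for ch := range subscribers {
				select {
				case ch <- status: // deliver if the subscriber has buffer space
				default: // otherwise drop rather than block the broadcaster
				}
			}
			subscriberMux.RUnlock()
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	sub := make(chan bool, 1)
	subscriberMux.Lock()
	subscribers[sub] = struct{}{}
	subscriberMux.Unlock()

	go broadcastStatus(ctx)

	statusChan <- true
	fmt.Println("received:", <-sub)

	cancel()
	time.Sleep(50 * time.Millisecond) // give the broadcaster a moment to exit
}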

View file

@@ -52,6 +52,8 @@ func InitRegister(ctx context.Context) {
 	}
 
 	logger.Info("ACME Default User registered")
+
+	initBroadcastStatus(ctx)
 }
 
 func GetDefaultACMEUser() (user *model.AcmeUser, err error) {

View file

@@ -38,7 +38,9 @@ func Boot(ctx context.Context) {
 		InitNodeSecret,
 		InitCryptoSecret,
 		validation.Init,
-		cache.Init,
+		func() {
+			cache.Init(ctx)
+		},
 		CheckAndCleanupOTAContainers,
 	}
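The closure here suggests the boot task list holds plain func() values, so ctx is captured rather than passed as an argument. A small sketch of that pattern with hypothetical names (initCache, boot, and tasks are not the project's identifiers):

package main

import (
	"context"
	"fmt"
)

// initCache stands in for cache.Init from the diff; it needs a context.
func initCache(ctx context.Context) {
	fmt.Println("cache initialized, ctx err:", ctx.Err())
}

func boot(ctx context.Context) {
	// The slice element type is func(), so context-aware initializers
	// are adapted with a closure that captures ctx.
	tasks := []func(){
		func() { initCache(ctx) },
		func() { fmt.Println("other init task") },
	}
	for _, task := range tasks {
		task()
	}
}

func main() {
	boot(context.Background())
}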

View file

@@ -13,15 +13,10 @@ type NginxLogCache struct {
 
 var (
 	// logCache is the map to store all found log files
-	logCache   map[string]*NginxLogCache
+	logCache   = make(map[string]*NginxLogCache)
 	cacheMutex sync.RWMutex
 )
 
-func init() {
-	// Initialize the cache
-	logCache = make(map[string]*NginxLogCache)
-}
-
 // AddLogPath adds a log path to the log cache
 func AddLogPath(path, logType, name string) {
 	cacheMutex.Lock()