refact: reduce code nesting (acquisition/file, tests) (#3200)

* reduce if nesting

* lint: gocritic (nestingReduce)

* lint
This commit is contained in:
mmetc 2024-09-03 12:25:30 +02:00 committed by GitHub
parent ae5e99ff13
commit 5a50fd06bb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 249 additions and 280 deletions

View file

@ -27,7 +27,7 @@ linters-settings:
nestif:
# lower this after refactoring
min-complexity: 20
min-complexity: 19
nlreturn:
block-size: 5

View file

@ -31,12 +31,12 @@ import (
type configGetter func() *csconfig.Config
type cliConsole struct {
cfg configGetter
cfg configGetter
}
func New(cfg configGetter) *cliConsole {
return &cliConsole{
cfg: cfg,
cfg: cfg,
}
}
@ -88,23 +88,25 @@ func (cli *cliConsole) enroll(key string, name string, overwrite bool, tags []st
}
for _, availableOpt := range csconfig.CONSOLE_CONFIGS {
if opt == availableOpt {
valid = true
enable := true
for _, enabledOpt := range enableOpts {
if opt == enabledOpt {
enable = false
continue
}
}
if enable {
enableOpts = append(enableOpts, opt)
}
break
if opt != availableOpt {
continue
}
valid = true
enable := true
for _, enabledOpt := range enableOpts {
if opt == enabledOpt {
enable = false
continue
}
}
if enable {
enableOpts = append(enableOpts, opt)
}
break
}
if !valid {

View file

@ -426,118 +426,122 @@ func (f *FileSource) monitorNewFiles(out chan types.Event, t *tomb.Tomb) error {
return nil
}
if event.Op&fsnotify.Create == fsnotify.Create {
fi, err := os.Stat(event.Name)
if err != nil {
logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err)
continue
}
if fi.IsDir() {
continue
}
logger.Debugf("Detected new file %s", event.Name)
matched := false
for _, pattern := range f.config.Filenames {
logger.Debugf("Matching %s with %s", pattern, event.Name)
matched, err = filepath.Match(pattern, event.Name)
if err != nil {
logger.Errorf("Could not match pattern : %s", err)
continue
}
if matched {
logger.Debugf("Matched %s with %s", pattern, event.Name)
break
}
}
if !matched {
continue
}
// before opening the file, check if we need to specifically avoid it. (XXX)
skip := false
for _, pattern := range f.exclude_regexps {
if pattern.MatchString(event.Name) {
f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String())
skip = true
break
}
}
if skip {
continue
}
f.tailMapMutex.RLock()
if f.tails[event.Name] {
f.tailMapMutex.RUnlock()
// we already have a tail on it, do not start a new one
logger.Debugf("Already tailing file %s, not creating a new tail", event.Name)
break
}
f.tailMapMutex.RUnlock()
// cf. https://github.com/crowdsecurity/crowdsec/issues/1168
// do not rely on stat, reclose file immediately as it's opened by Tail
fd, err := os.Open(event.Name)
if err != nil {
f.logger.Errorf("unable to read %s : %s", event.Name, err)
continue
}
if err := fd.Close(); err != nil {
f.logger.Errorf("unable to close %s : %s", event.Name, err)
continue
}
pollFile := false
if f.config.PollWithoutInotify != nil {
pollFile = *f.config.PollWithoutInotify
} else {
networkFS, fsType, err := types.IsNetworkFS(event.Name)
if err != nil {
f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err)
}
f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType)
if networkFS {
pollFile = true
}
}
filink, err := os.Lstat(event.Name)
if err != nil {
logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err)
continue
}
if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile {
logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", event.Name)
}
//Slightly different parameters for Location, as we want to read the first lines of the newly created file
tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}})
if err != nil {
logger.Errorf("Could not start tailing file %s : %s", event.Name, err)
break
}
f.tailMapMutex.Lock()
f.tails[event.Name] = true
f.tailMapMutex.Unlock()
t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/tailfile")
return f.tailFile(out, t, tail)
})
if event.Op&fsnotify.Create != fsnotify.Create {
continue
}
fi, err := os.Stat(event.Name)
if err != nil {
logger.Errorf("Could not stat() new file %s, ignoring it : %s", event.Name, err)
continue
}
if fi.IsDir() {
continue
}
logger.Debugf("Detected new file %s", event.Name)
matched := false
for _, pattern := range f.config.Filenames {
logger.Debugf("Matching %s with %s", pattern, event.Name)
matched, err = filepath.Match(pattern, event.Name)
if err != nil {
logger.Errorf("Could not match pattern : %s", err)
continue
}
if matched {
logger.Debugf("Matched %s with %s", pattern, event.Name)
break
}
}
if !matched {
continue
}
// before opening the file, check if we need to specifically avoid it. (XXX)
skip := false
for _, pattern := range f.exclude_regexps {
if pattern.MatchString(event.Name) {
f.logger.Infof("file %s matches exclusion pattern %s, skipping", event.Name, pattern.String())
skip = true
break
}
}
if skip {
continue
}
f.tailMapMutex.RLock()
if f.tails[event.Name] {
f.tailMapMutex.RUnlock()
// we already have a tail on it, do not start a new one
logger.Debugf("Already tailing file %s, not creating a new tail", event.Name)
break
}
f.tailMapMutex.RUnlock()
// cf. https://github.com/crowdsecurity/crowdsec/issues/1168
// do not rely on stat, reclose file immediately as it's opened by Tail
fd, err := os.Open(event.Name)
if err != nil {
f.logger.Errorf("unable to read %s : %s", event.Name, err)
continue
}
if err = fd.Close(); err != nil {
f.logger.Errorf("unable to close %s : %s", event.Name, err)
continue
}
pollFile := false
if f.config.PollWithoutInotify != nil {
pollFile = *f.config.PollWithoutInotify
} else {
networkFS, fsType, err := types.IsNetworkFS(event.Name)
if err != nil {
f.logger.Warningf("Could not get fs type for %s : %s", event.Name, err)
}
f.logger.Debugf("fs for %s is network: %t (%s)", event.Name, networkFS, fsType)
if networkFS {
pollFile = true
}
}
filink, err := os.Lstat(event.Name)
if err != nil {
logger.Errorf("Could not lstat() new file %s, ignoring it : %s", event.Name, err)
continue
}
if filink.Mode()&os.ModeSymlink == os.ModeSymlink && !pollFile {
logger.Warnf("File %s is a symlink, but inotify polling is enabled. Crowdsec will not be able to detect rotation. Consider setting poll_without_inotify to true in your configuration", event.Name)
}
// Slightly different parameters for Location, as we want to read the first lines of the newly created file
tail, err := tail.TailFile(event.Name, tail.Config{ReOpen: true, Follow: true, Poll: pollFile, Location: &tail.SeekInfo{Offset: 0, Whence: io.SeekStart}})
if err != nil {
logger.Errorf("Could not start tailing file %s : %s", event.Name, err)
break
}
f.tailMapMutex.Lock()
f.tails[event.Name] = true
f.tailMapMutex.Unlock()
t.Go(func() error {
defer trace.CatchPanic("crowdsec/acquis/tailfile")
return f.tailFile(out, t, tail)
})
case err, ok := <-f.watcher.Errors:
if !ok {
return nil
@ -571,8 +575,9 @@ func (f *FileSource) tailFile(out chan types.Event, t *tomb.Tomb, tail *tail.Tai
return nil
case <-tail.Dying(): // our tailer is dying
err := tail.Err()
errMsg := fmt.Sprintf("file reader of %s died", tail.Filename)
err := tail.Err()
if err != nil {
errMsg = fmt.Sprintf(errMsg+" : %s", err)
}

View file

@ -4,6 +4,10 @@ import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/crowdsecurity/go-cs-lib/cstest"
)
func TestPri(t *testing.T) {
@ -26,28 +30,20 @@ func TestPri(t *testing.T) {
r := &RFC3164{}
r.buf = []byte(test.input)
r.len = len(r.buf)
err := r.parsePRI()
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error %s, got %s", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: %s", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error %s, got no error", test.expectedErr)
} else if r.PRI != test.expected {
t.Errorf("expected %d, got %d", test.expected, r.PRI)
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected, r.PRI)
})
}
}
func TestTimestamp(t *testing.T) {
tests := []struct {
input string
expected string
@ -68,25 +64,19 @@ func TestTimestamp(t *testing.T) {
if test.currentYear {
opts = append(opts, WithCurrentYear())
}
r := NewRFC3164Parser(opts...)
r.buf = []byte(test.input)
r.len = len(r.buf)
err := r.parseTimestamp()
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error %s, got %s", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: %s", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error %s, got no error", test.expectedErr)
} else if r.Timestamp.Format(time.RFC3339) != test.expected {
t.Errorf("expected %s, got %s", test.expected, r.Timestamp.Format(time.RFC3339))
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected, r.Timestamp.Format(time.RFC3339))
})
}
}
@ -121,25 +111,19 @@ func TestHostname(t *testing.T) {
if test.strictHostname {
opts = append(opts, WithStrictHostname())
}
r := NewRFC3164Parser(opts...)
r.buf = []byte(test.input)
r.len = len(r.buf)
err := r.parseHostname()
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error %s, got %s", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: %s", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error %s, got no error", test.expectedErr)
} else if r.Hostname != test.expected {
t.Errorf("expected %s, got %s", test.expected, r.Hostname)
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected, r.Hostname)
})
}
}
@ -164,27 +148,16 @@ func TestTag(t *testing.T) {
r := &RFC3164{}
r.buf = []byte(test.input)
r.len = len(r.buf)
err := r.parseTag()
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error %s, got %s", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: %s", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error %s, got no error", test.expectedErr)
} else {
if r.Tag != test.expected {
t.Errorf("expected %s, got %s", test.expected, r.Tag)
}
if r.PID != test.expectedPID {
t.Errorf("expected %s, got %s", test.expected, r.Message)
}
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected, r.Tag)
assert.Equal(t, test.expectedPID, r.PID)
})
}
}
@ -207,22 +180,15 @@ func TestMessage(t *testing.T) {
r := &RFC3164{}
r.buf = []byte(test.input)
r.len = len(r.buf)
err := r.parseMessage()
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error %s, got %s", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: %s", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error %s, got no error", test.expectedErr)
} else if r.Message != test.expected {
t.Errorf("expected message %s, got %s", test.expected, r.Tag)
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected, r.Message)
})
}
}
@ -236,6 +202,7 @@ func TestParse(t *testing.T) {
Message string
PRI int
}
tests := []struct {
input string
expected expected
@ -326,39 +293,20 @@ func TestParse(t *testing.T) {
for _, test := range tests {
t.Run(test.input, func(t *testing.T) {
r := NewRFC3164Parser(test.opts...)
err := r.Parse([]byte(test.input))
if err != nil {
if test.expectedErr != "" {
if err.Error() != test.expectedErr {
t.Errorf("expected error '%s', got '%s'", test.expectedErr, err)
}
} else {
t.Errorf("unexpected error: '%s'", err)
}
} else {
if test.expectedErr != "" {
t.Errorf("expected error '%s', got no error", test.expectedErr)
} else {
if r.Timestamp != test.expected.Timestamp {
t.Errorf("expected timestamp '%s', got '%s'", test.expected.Timestamp, r.Timestamp)
}
if r.Hostname != test.expected.Hostname {
t.Errorf("expected hostname '%s', got '%s'", test.expected.Hostname, r.Hostname)
}
if r.Tag != test.expected.Tag {
t.Errorf("expected tag '%s', got '%s'", test.expected.Tag, r.Tag)
}
if r.PID != test.expected.PID {
t.Errorf("expected pid '%s', got '%s'", test.expected.PID, r.PID)
}
if r.Message != test.expected.Message {
t.Errorf("expected message '%s', got '%s'", test.expected.Message, r.Message)
}
if r.PRI != test.expected.PRI {
t.Errorf("expected pri '%d', got '%d'", test.expected.PRI, r.PRI)
}
}
cstest.RequireErrorContains(t, err, test.expectedErr)
if test.expectedErr != "" {
return
}
assert.Equal(t, test.expected.Timestamp, r.Timestamp)
assert.Equal(t, test.expected.Hostname, r.Hostname)
assert.Equal(t, test.expected.Tag, r.Tag)
assert.Equal(t, test.expected.PID, r.PID)
assert.Equal(t, test.expected.Message, r.Message)
assert.Equal(t, test.expected.PRI, r.PRI)
})
}
}

View file

@ -26,6 +26,7 @@ type ExprDbgTest struct {
func UpperTwo(params ...any) (any, error) {
s := params[0].(string)
v := params[1].(string)
return strings.ToUpper(s) + strings.ToUpper(v), nil
}
@ -33,6 +34,7 @@ func UpperThree(params ...any) (any, error) {
s := params[0].(string)
v := params[1].(string)
x := params[2].(string)
return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x), nil
}
@ -41,6 +43,7 @@ func UpperN(params ...any) (any, error) {
v := params[1].(string)
x := params[2].(string)
y := params[3].(string)
return strings.ToUpper(s) + strings.ToUpper(v) + strings.ToUpper(x) + strings.ToUpper(y), nil
}
@ -76,9 +79,9 @@ func TestBaseDbg(t *testing.T) {
// use '%#v' to dump in golang syntax
// use regexp to clear empty/default fields:
// [a-z]+: (false|\[\]string\(nil\)|""),
//ConditionResult:(*bool)
// ConditionResult:(*bool)
//Missing multi parametes function
// Missing multi parametes function
tests := []ExprDbgTest{
{
Name: "nil deref",
@ -272,6 +275,7 @@ func TestBaseDbg(t *testing.T) {
}
logger := log.WithField("test", "exprhelpers")
for _, test := range tests {
if test.LogLevel != 0 {
log.SetLevel(test.LogLevel)
@ -308,10 +312,13 @@ func TestBaseDbg(t *testing.T) {
t.Fatalf("test %s : unexpected compile error : %s", test.Name, err)
}
}
if test.Name == "nil deref" {
test.Env["nilvar"] = nil
}
outdbg, ret, err := RunWithDebug(prog, test.Env, logger)
if test.ExpectedFailRuntime {
if err == nil {
t.Fatalf("test %s : expected runtime error", test.Name)
@ -321,25 +328,30 @@ func TestBaseDbg(t *testing.T) {
t.Fatalf("test %s : unexpected runtime error : %s", test.Name, err)
}
}
log.SetLevel(log.DebugLevel)
DisplayExprDebug(prog, outdbg, logger, ret)
if len(outdbg) != len(test.ExpectedOutputs) {
t.Errorf("failed test %s", test.Name)
t.Errorf("%#v", outdbg)
//out, _ := yaml.Marshal(outdbg)
//fmt.Printf("%s", string(out))
// out, _ := yaml.Marshal(outdbg)
// fmt.Printf("%s", string(out))
t.Fatalf("test %s : expected %d outputs, got %d", test.Name, len(test.ExpectedOutputs), len(outdbg))
}
for i, out := range outdbg {
if !reflect.DeepEqual(out, test.ExpectedOutputs[i]) {
spew.Config.DisableMethods = true
t.Errorf("failed test %s", test.Name)
t.Errorf("expected : %#v", test.ExpectedOutputs[i])
t.Errorf("got : %#v", out)
t.Fatalf("%d/%d : mismatch", i, len(outdbg))
if reflect.DeepEqual(out, test.ExpectedOutputs[i]) {
// DisplayExprDebug(prog, outdbg, logger, ret)
continue
}
//DisplayExprDebug(prog, outdbg, logger, ret)
spew.Config.DisableMethods = true
t.Errorf("failed test %s", test.Name)
t.Errorf("expected : %#v", test.ExpectedOutputs[i])
t.Errorf("got : %#v", out)
t.Fatalf("%d/%d : mismatch", i, len(outdbg))
}
}
}

View file

@ -509,37 +509,39 @@ func LoadBucketsState(file string, buckets *Buckets, bucketFactories []BucketFac
found := false
for _, h := range bucketFactories {
if h.Name == v.Name {
log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description)
// check in which mode the bucket was
if v.Mode == types.TIMEMACHINE {
tbucket = NewTimeMachine(h)
} else if v.Mode == types.LIVE {
tbucket = NewLeaky(h)
} else {
log.Errorf("Unknown bucket type : %d", v.Mode)
}
/*Trying to restore queue state*/
tbucket.Queue = v.Queue
/*Trying to set the limiter to the saved values*/
tbucket.Limiter.Load(v.SerializedState)
tbucket.In = make(chan *types.Event)
tbucket.Mapkey = k
tbucket.Signal = make(chan bool, 1)
tbucket.First_ts = v.First_ts
tbucket.Last_ts = v.Last_ts
tbucket.Ovflw_ts = v.Ovflw_ts
tbucket.Total_count = v.Total_count
buckets.Bucket_map.Store(k, tbucket)
h.tomb.Go(func() error {
return LeakRoutine(tbucket)
})
<-tbucket.Signal
found = true
break
if h.Name != v.Name {
continue
}
log.Debugf("found factory %s/%s -> %s", h.Author, h.Name, h.Description)
// check in which mode the bucket was
if v.Mode == types.TIMEMACHINE {
tbucket = NewTimeMachine(h)
} else if v.Mode == types.LIVE {
tbucket = NewLeaky(h)
} else {
log.Errorf("Unknown bucket type : %d", v.Mode)
}
/*Trying to restore queue state*/
tbucket.Queue = v.Queue
/*Trying to set the limiter to the saved values*/
tbucket.Limiter.Load(v.SerializedState)
tbucket.In = make(chan *types.Event)
tbucket.Mapkey = k
tbucket.Signal = make(chan bool, 1)
tbucket.First_ts = v.First_ts
tbucket.Last_ts = v.Last_ts
tbucket.Ovflw_ts = v.Ovflw_ts
tbucket.Total_count = v.Total_count
buckets.Bucket_map.Store(k, tbucket)
h.tomb.Go(func() error {
return LeakRoutine(tbucket)
})
<-tbucket.Signal
found = true
break
}
if !found {