From 31b914512a327ac5a5ecf324d10f248f10915d4b Mon Sep 17 00:00:00 2001 From: mmetc <92726601+mmetc@users.noreply.github.com> Date: Wed, 7 May 2025 11:12:27 +0200 Subject: [PATCH] refact pkg/database: unnecessary pointers (#3611) * refact pkg/database: unnecessary pointers * lint --- cmd/crowdsec-cli/clipapi/papi.go | 10 ++++------ pkg/apiclient/decisions_service.go | 10 +++++----- pkg/apiclient/decisions_service_test.go | 6 +++--- pkg/apiserver/apic.go | 17 ++++++++++------- pkg/apiserver/apic_test.go | 6 +++--- pkg/apiserver/papi.go | 13 +++++++------ pkg/database/config.go | 23 +++++++++++++---------- 7 files changed, 45 insertions(+), 40 deletions(-) diff --git a/cmd/crowdsec-cli/clipapi/papi.go b/cmd/crowdsec-cli/clipapi/papi.go index c2be87f8a..b48685bd1 100644 --- a/cmd/crowdsec-cli/clipapi/papi.go +++ b/cmd/crowdsec-cli/clipapi/papi.go @@ -11,8 +11,6 @@ import ( "github.com/spf13/cobra" "gopkg.in/tomb.v2" - "github.com/crowdsecurity/go-cs-lib/ptr" - "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/args" "github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require" "github.com/crowdsecurity/crowdsec/pkg/apiserver" @@ -76,17 +74,17 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey) if err != nil { - lastTimestampStr = ptr.Of("never") + lastTimestampStr = "never" } // both can and did happen - if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" { - lastTimestampStr = ptr.Of("never") + if lastTimestampStr == "" || lastTimestampStr == "0001-01-01T00:00:00Z" { + lastTimestampStr = "never" } fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n") fmt.Fprintf(out, "Console plan: %s\n", perms.Plan) - fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr) + fmt.Fprintf(out, "Last order received: %s\n", lastTimestampStr) fmt.Fprint(out, "PAPI subscriptions:\n") for _, sub := range perms.Categories { diff --git a/pkg/apiclient/decisions_service.go b/pkg/apiclient/decisions_service.go index c222e2ddb..47a37773e 100644 --- a/pkg/apiclient/decisions_service.go +++ b/pkg/apiclient/decisions_service.go @@ -174,7 +174,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m return &v2Decisions, resp, nil } -func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) { +func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp string) ([]*models.Decision, bool, error) { if blocklist.URL == nil { return nil, false, errors.New("blocklist URL is nil") } @@ -188,8 +188,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl return nil, false, err } - if lastPullTimestamp != nil { - req.Header.Set("If-Modified-Since", *lastPullTimestamp) + if lastPullTimestamp != "" { + req.Header.Set("If-Modified-Since", lastPullTimestamp) } log.Debugf("[URL] %s %s", req.Method, req.URL) @@ -217,8 +217,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl } if resp.StatusCode == http.StatusNotModified { - if lastPullTimestamp != nil { - log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp) + if lastPullTimestamp != "" { + log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, lastPullTimestamp) } else { log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL) } diff --git a/pkg/apiclient/decisions_service_test.go b/pkg/apiclient/decisions_service_test.go index c9e555e92..4bc655876 100644 --- a/pkg/apiclient/decisions_service_test.go +++ b/pkg/apiclient/decisions_service_test.go @@ -362,7 +362,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Remediation: &tremediationBlocklist, Name: &tnameBlocklist, Duration: &tdurationBlocklist, - }, nil) + }, "") require.NoError(t, err) assert.True(t, isModified) @@ -381,7 +381,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Remediation: &tremediationBlocklist, Name: &tnameBlocklist, Duration: &tdurationBlocklist, - }, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT")) + }, "Sun, 01 Jan 2023 01:01:01 GMT") require.NoError(t, err) assert.False(t, isModified) @@ -392,7 +392,7 @@ func TestDecisionsFromBlocklist(t *testing.T) { Remediation: &tremediationBlocklist, Name: &tnameBlocklist, Duration: &tdurationBlocklist, - }, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT")) + }, "Mon, 02 Jan 2023 01:01:01 GMT") require.NoError(t, err) assert.True(t, isModified) diff --git a/pkg/apiserver/apic.go b/pkg/apiserver/apic.go index 8e92dc674..3ab75f5ec 100644 --- a/pkg/apiserver/apic.go +++ b/pkg/apiserver/apic.go @@ -243,6 +243,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient } err = ret.Authenticate(ctx, config) + return ret, err } @@ -260,13 +261,14 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time, return "", time.Time{}, false } - if token == nil { + if token == "" { log.Debug("no token found in DB") return "", time.Time{}, false } parser := new(jwt.Parser) - tok, _, err := parser.ParseUnverified(*token, jwt.MapClaims{}) + + tok, _, err := parser.ParseUnverified(token, jwt.MapClaims{}) if err != nil { log.Debugf("error parsing token: %s", err) return "", time.Time{}, false @@ -285,12 +287,12 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time, } exp := time.Unix(int64(expFloat), 0) - if time.Now().UTC().After(exp.Add(-1*time.Minute)) { + if time.Now().UTC().After(exp.Add(-1 * time.Minute)) { log.Debug("auth token expired") return "", time.Time{}, false } - return *token, exp, true + return token, exp, true } // saveAPICToken stores the given JWT token in the local database under the "apic_token" config item. @@ -310,6 +312,7 @@ func saveAPICToken(ctx context.Context, db *database.Client, token string) error func (a *apic) Authenticate(ctx context.Context, config *csconfig.OnlineApiClientCfg) error { if token, exp, valid := loadAPICToken(ctx, a.dbClient); valid { log.Debug("using valid token from DB") + a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = token a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration = exp } @@ -1043,7 +1046,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name) var ( - lastPullTimestamp *string + lastPullTimestamp string err error ) @@ -1060,10 +1063,10 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient, } if !hasChanged { - if lastPullTimestamp == nil { + if lastPullTimestamp == "" { log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name) } else { - log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp) + log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, lastPullTimestamp) } return nil diff --git a/pkg/apiserver/apic_test.go b/pkg/apiserver/apic_test.go index fc4e29087..a99318755 100644 --- a/pkg/apiserver/apic_test.go +++ b/pkg/apiserver/apic_test.go @@ -99,7 +99,7 @@ func assertTotalValidDecisionCount(t *testing.T, dbClient *database.Client, coun assert.Len(t, d, count) } -func jsonMarshalX(v interface{}) []byte { +func jsonMarshalX(v any) []byte { data, err := json.Marshal(v) if err != nil { panic(err) @@ -932,7 +932,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { blocklistConfigItemName := "blocklist:blocklist1:last_pull" lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) - assert.NotEmpty(t, *lastPullTimestamp) + assert.NotEmpty(t, lastPullTimestamp) // new call should return 304 and should not change lastPullTimestamp httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) { @@ -944,7 +944,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) { require.NoError(t, err) secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName) require.NoError(t, err) - assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp) + assert.Equal(t, lastPullTimestamp, secondLastPullTimestamp) } func TestAPICPullTopBLCacheForceCall(t *testing.T) { diff --git a/pkg/apiserver/papi.go b/pkg/apiserver/papi.go index 442c57295..501f09e8f 100644 --- a/pkg/apiserver/papi.go +++ b/pkg/apiserver/papi.go @@ -22,9 +22,10 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/types" ) -const SyncInterval = time.Second * 10 - -const PapiPullKey = "papi:last_pull" +const ( + SyncInterval = time.Second * 10 + PapiPullKey = "papi:last_pull" +) var operationMap = map[string]func(*Message, *Papi, bool) error{ "decision": DecisionCmd, @@ -48,7 +49,7 @@ type Source struct { type Message struct { Header *Header - Data interface{} `json:"data"` + Data any `json:"data"` } type OperationChannels struct { @@ -240,7 +241,7 @@ func (p *Papi) Pull(ctx context.Context) error { } // value doesn't exist, it's first time we're pulling - if lastTimestampStr == nil { + if lastTimestampStr == "" { binTime, err := lastTimestamp.MarshalText() if err != nil { return fmt.Errorf("failed to serialize last timestamp: %w", err) @@ -252,7 +253,7 @@ func (p *Papi) Pull(ctx context.Context) error { p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime)) } } else { - if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil { + if err := lastTimestamp.UnmarshalText([]byte(lastTimestampStr)); err != nil { return fmt.Errorf("failed to parse last timestamp: %w", err) } } diff --git a/pkg/database/config.go b/pkg/database/config.go index 89ccb1e1b..2f2629650 100644 --- a/pkg/database/config.go +++ b/pkg/database/config.go @@ -9,27 +9,30 @@ import ( "github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem" ) -func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) { +func (c *Client) GetConfigItem(ctx context.Context, key string) (string, error) { result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx) - if err != nil && ent.IsNotFound(err) { - return nil, nil - } - if err != nil { - return nil, errors.Wrapf(QueryFail, "select config item: %s", err) + switch { + case ent.IsNotFound(err): + return "", nil + case err != nil: + return "", errors.Wrapf(QueryFail, "select config item: %s", err) + default: + return result.Value, nil } - - return &result.Value, nil } func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error { nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx) - if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create + + switch { + case ent.IsNotFound(err) || nbUpdated == 0: + // not found, create err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx) if err != nil { return errors.Wrapf(QueryFail, "insert config item: %s", err) } - } else if err != nil { + case err != nil: return errors.Wrapf(QueryFail, "update config item: %s", err) }