refact pkg/database: unnecessary pointers (#3611)

* refact pkg/database: unnecessary pointers

* lint
This commit is contained in:
mmetc 2025-05-07 11:12:27 +02:00 committed by GitHub
parent 73a423034f
commit 31b914512a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 45 additions and 40 deletions

View file

@ -11,8 +11,6 @@ import (
"github.com/spf13/cobra"
"gopkg.in/tomb.v2"
"github.com/crowdsecurity/go-cs-lib/ptr"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/args"
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
@ -76,17 +74,17 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
if err != nil {
lastTimestampStr = ptr.Of("never")
lastTimestampStr = "never"
}
// both can and did happen
if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" {
lastTimestampStr = ptr.Of("never")
if lastTimestampStr == "" || lastTimestampStr == "0001-01-01T00:00:00Z" {
lastTimestampStr = "never"
}
fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n")
fmt.Fprintf(out, "Console plan: %s\n", perms.Plan)
fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr)
fmt.Fprintf(out, "Last order received: %s\n", lastTimestampStr)
fmt.Fprint(out, "PAPI subscriptions:\n")
for _, sub := range perms.Categories {

View file

@ -174,7 +174,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
return &v2Decisions, resp, nil
}
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) {
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp string) ([]*models.Decision, bool, error) {
if blocklist.URL == nil {
return nil, false, errors.New("blocklist URL is nil")
}
@ -188,8 +188,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
return nil, false, err
}
if lastPullTimestamp != nil {
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
if lastPullTimestamp != "" {
req.Header.Set("If-Modified-Since", lastPullTimestamp)
}
log.Debugf("[URL] %s %s", req.Method, req.URL)
@ -217,8 +217,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
}
if resp.StatusCode == http.StatusNotModified {
if lastPullTimestamp != nil {
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp)
if lastPullTimestamp != "" {
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, lastPullTimestamp)
} else {
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
}

View file

@ -362,7 +362,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, nil)
}, "")
require.NoError(t, err)
assert.True(t, isModified)
@ -381,7 +381,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT"))
}, "Sun, 01 Jan 2023 01:01:01 GMT")
require.NoError(t, err)
assert.False(t, isModified)
@ -392,7 +392,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
Remediation: &tremediationBlocklist,
Name: &tnameBlocklist,
Duration: &tdurationBlocklist,
}, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT"))
}, "Mon, 02 Jan 2023 01:01:01 GMT")
require.NoError(t, err)
assert.True(t, isModified)

View file

@ -243,6 +243,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
}
err = ret.Authenticate(ctx, config)
return ret, err
}
@ -260,13 +261,14 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
return "", time.Time{}, false
}
if token == nil {
if token == "" {
log.Debug("no token found in DB")
return "", time.Time{}, false
}
parser := new(jwt.Parser)
tok, _, err := parser.ParseUnverified(*token, jwt.MapClaims{})
tok, _, err := parser.ParseUnverified(token, jwt.MapClaims{})
if err != nil {
log.Debugf("error parsing token: %s", err)
return "", time.Time{}, false
@ -285,12 +287,12 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
}
exp := time.Unix(int64(expFloat), 0)
if time.Now().UTC().After(exp.Add(-1*time.Minute)) {
if time.Now().UTC().After(exp.Add(-1 * time.Minute)) {
log.Debug("auth token expired")
return "", time.Time{}, false
}
return *token, exp, true
return token, exp, true
}
// saveAPICToken stores the given JWT token in the local database under the "apic_token" config item.
@ -310,6 +312,7 @@ func saveAPICToken(ctx context.Context, db *database.Client, token string) error
func (a *apic) Authenticate(ctx context.Context, config *csconfig.OnlineApiClientCfg) error {
if token, exp, valid := loadAPICToken(ctx, a.dbClient); valid {
log.Debug("using valid token from DB")
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = token
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration = exp
}
@ -1043,7 +1046,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
var (
lastPullTimestamp *string
lastPullTimestamp string
err error
)
@ -1060,10 +1063,10 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
}
if !hasChanged {
if lastPullTimestamp == nil {
if lastPullTimestamp == "" {
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
} else {
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, lastPullTimestamp)
}
return nil

View file

@ -99,7 +99,7 @@ func assertTotalValidDecisionCount(t *testing.T, dbClient *database.Client, coun
assert.Len(t, d, count)
}
func jsonMarshalX(v interface{}) []byte {
func jsonMarshalX(v any) []byte {
data, err := json.Marshal(v)
if err != nil {
panic(err)
@ -932,7 +932,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
blocklistConfigItemName := "blocklist:blocklist1:last_pull"
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.NotEmpty(t, *lastPullTimestamp)
assert.NotEmpty(t, lastPullTimestamp)
// new call should return 304 and should not change lastPullTimestamp
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
@ -944,7 +944,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
require.NoError(t, err)
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
require.NoError(t, err)
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
assert.Equal(t, lastPullTimestamp, secondLastPullTimestamp)
}
func TestAPICPullTopBLCacheForceCall(t *testing.T) {

View file

@ -22,9 +22,10 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/types"
)
const SyncInterval = time.Second * 10
const PapiPullKey = "papi:last_pull"
const (
SyncInterval = time.Second * 10
PapiPullKey = "papi:last_pull"
)
var operationMap = map[string]func(*Message, *Papi, bool) error{
"decision": DecisionCmd,
@ -48,7 +49,7 @@ type Source struct {
type Message struct {
Header *Header
Data interface{} `json:"data"`
Data any `json:"data"`
}
type OperationChannels struct {
@ -240,7 +241,7 @@ func (p *Papi) Pull(ctx context.Context) error {
}
// value doesn't exist, it's first time we're pulling
if lastTimestampStr == nil {
if lastTimestampStr == "" {
binTime, err := lastTimestamp.MarshalText()
if err != nil {
return fmt.Errorf("failed to serialize last timestamp: %w", err)
@ -252,7 +253,7 @@ func (p *Papi) Pull(ctx context.Context) error {
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
}
} else {
if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil {
if err := lastTimestamp.UnmarshalText([]byte(lastTimestampStr)); err != nil {
return fmt.Errorf("failed to parse last timestamp: %w", err)
}
}

View file

@ -9,27 +9,30 @@ import (
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
)
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
func (c *Client) GetConfigItem(ctx context.Context, key string) (string, error) {
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
if err != nil && ent.IsNotFound(err) {
return nil, nil
}
if err != nil {
return nil, errors.Wrapf(QueryFail, "select config item: %s", err)
switch {
case ent.IsNotFound(err):
return "", nil
case err != nil:
return "", errors.Wrapf(QueryFail, "select config item: %s", err)
default:
return result.Value, nil
}
return &result.Value, nil
}
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create
switch {
case ent.IsNotFound(err) || nbUpdated == 0:
// not found, create
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
if err != nil {
return errors.Wrapf(QueryFail, "insert config item: %s", err)
}
} else if err != nil {
case err != nil:
return errors.Wrapf(QueryFail, "update config item: %s", err)
}