mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
refact pkg/database: unnecessary pointers (#3611)
* refact pkg/database: unnecessary pointers * lint
This commit is contained in:
parent
73a423034f
commit
31b914512a
7 changed files with 45 additions and 40 deletions
|
@ -11,8 +11,6 @@ import (
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"gopkg.in/tomb.v2"
|
"gopkg.in/tomb.v2"
|
||||||
|
|
||||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
|
||||||
|
|
||||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/args"
|
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/args"
|
||||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
|
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
|
||||||
|
@ -76,17 +74,17 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
|
||||||
|
|
||||||
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
|
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastTimestampStr = ptr.Of("never")
|
lastTimestampStr = "never"
|
||||||
}
|
}
|
||||||
|
|
||||||
// both can and did happen
|
// both can and did happen
|
||||||
if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" {
|
if lastTimestampStr == "" || lastTimestampStr == "0001-01-01T00:00:00Z" {
|
||||||
lastTimestampStr = ptr.Of("never")
|
lastTimestampStr = "never"
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n")
|
fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n")
|
||||||
fmt.Fprintf(out, "Console plan: %s\n", perms.Plan)
|
fmt.Fprintf(out, "Console plan: %s\n", perms.Plan)
|
||||||
fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr)
|
fmt.Fprintf(out, "Last order received: %s\n", lastTimestampStr)
|
||||||
fmt.Fprint(out, "PAPI subscriptions:\n")
|
fmt.Fprint(out, "PAPI subscriptions:\n")
|
||||||
|
|
||||||
for _, sub := range perms.Categories {
|
for _, sub := range perms.Categories {
|
||||||
|
|
|
@ -174,7 +174,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
|
||||||
return &v2Decisions, resp, nil
|
return &v2Decisions, resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) {
|
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp string) ([]*models.Decision, bool, error) {
|
||||||
if blocklist.URL == nil {
|
if blocklist.URL == nil {
|
||||||
return nil, false, errors.New("blocklist URL is nil")
|
return nil, false, errors.New("blocklist URL is nil")
|
||||||
}
|
}
|
||||||
|
@ -188,8 +188,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if lastPullTimestamp != nil {
|
if lastPullTimestamp != "" {
|
||||||
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
|
req.Header.Set("If-Modified-Since", lastPullTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
||||||
|
@ -217,8 +217,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusNotModified {
|
if resp.StatusCode == http.StatusNotModified {
|
||||||
if lastPullTimestamp != nil {
|
if lastPullTimestamp != "" {
|
||||||
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp)
|
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, lastPullTimestamp)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
|
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
|
||||||
}
|
}
|
||||||
|
|
|
@ -362,7 +362,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
||||||
Remediation: &tremediationBlocklist,
|
Remediation: &tremediationBlocklist,
|
||||||
Name: &tnameBlocklist,
|
Name: &tnameBlocklist,
|
||||||
Duration: &tdurationBlocklist,
|
Duration: &tdurationBlocklist,
|
||||||
}, nil)
|
}, "")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, isModified)
|
assert.True(t, isModified)
|
||||||
|
|
||||||
|
@ -381,7 +381,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
||||||
Remediation: &tremediationBlocklist,
|
Remediation: &tremediationBlocklist,
|
||||||
Name: &tnameBlocklist,
|
Name: &tnameBlocklist,
|
||||||
Duration: &tdurationBlocklist,
|
Duration: &tdurationBlocklist,
|
||||||
}, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT"))
|
}, "Sun, 01 Jan 2023 01:01:01 GMT")
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, isModified)
|
assert.False(t, isModified)
|
||||||
|
@ -392,7 +392,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
||||||
Remediation: &tremediationBlocklist,
|
Remediation: &tremediationBlocklist,
|
||||||
Name: &tnameBlocklist,
|
Name: &tnameBlocklist,
|
||||||
Duration: &tdurationBlocklist,
|
Duration: &tdurationBlocklist,
|
||||||
}, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT"))
|
}, "Mon, 02 Jan 2023 01:01:01 GMT")
|
||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, isModified)
|
assert.True(t, isModified)
|
||||||
|
|
|
@ -243,6 +243,7 @@ func NewAPIC(ctx context.Context, config *csconfig.OnlineApiClientCfg, dbClient
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ret.Authenticate(ctx, config)
|
err = ret.Authenticate(ctx, config)
|
||||||
|
|
||||||
return ret, err
|
return ret, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -260,13 +261,14 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
|
||||||
return "", time.Time{}, false
|
return "", time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if token == nil {
|
if token == "" {
|
||||||
log.Debug("no token found in DB")
|
log.Debug("no token found in DB")
|
||||||
return "", time.Time{}, false
|
return "", time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
parser := new(jwt.Parser)
|
parser := new(jwt.Parser)
|
||||||
tok, _, err := parser.ParseUnverified(*token, jwt.MapClaims{})
|
|
||||||
|
tok, _, err := parser.ParseUnverified(token, jwt.MapClaims{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("error parsing token: %s", err)
|
log.Debugf("error parsing token: %s", err)
|
||||||
return "", time.Time{}, false
|
return "", time.Time{}, false
|
||||||
|
@ -285,12 +287,12 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
|
||||||
}
|
}
|
||||||
|
|
||||||
exp := time.Unix(int64(expFloat), 0)
|
exp := time.Unix(int64(expFloat), 0)
|
||||||
if time.Now().UTC().After(exp.Add(-1*time.Minute)) {
|
if time.Now().UTC().After(exp.Add(-1 * time.Minute)) {
|
||||||
log.Debug("auth token expired")
|
log.Debug("auth token expired")
|
||||||
return "", time.Time{}, false
|
return "", time.Time{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return *token, exp, true
|
return token, exp, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// saveAPICToken stores the given JWT token in the local database under the "apic_token" config item.
|
// saveAPICToken stores the given JWT token in the local database under the "apic_token" config item.
|
||||||
|
@ -310,6 +312,7 @@ func saveAPICToken(ctx context.Context, db *database.Client, token string) error
|
||||||
func (a *apic) Authenticate(ctx context.Context, config *csconfig.OnlineApiClientCfg) error {
|
func (a *apic) Authenticate(ctx context.Context, config *csconfig.OnlineApiClientCfg) error {
|
||||||
if token, exp, valid := loadAPICToken(ctx, a.dbClient); valid {
|
if token, exp, valid := loadAPICToken(ctx, a.dbClient); valid {
|
||||||
log.Debug("using valid token from DB")
|
log.Debug("using valid token from DB")
|
||||||
|
|
||||||
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = token
|
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Token = token
|
||||||
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration = exp
|
a.apiClient.GetClient().Transport.(*apiclient.JWTTransport).Expiration = exp
|
||||||
}
|
}
|
||||||
|
@ -1043,7 +1046,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
|
||||||
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
|
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
lastPullTimestamp *string
|
lastPullTimestamp string
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1060,10 +1063,10 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !hasChanged {
|
if !hasChanged {
|
||||||
if lastPullTimestamp == nil {
|
if lastPullTimestamp == "" {
|
||||||
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
|
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
|
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, lastPullTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -99,7 +99,7 @@ func assertTotalValidDecisionCount(t *testing.T, dbClient *database.Client, coun
|
||||||
assert.Len(t, d, count)
|
assert.Len(t, d, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
func jsonMarshalX(v interface{}) []byte {
|
func jsonMarshalX(v any) []byte {
|
||||||
data, err := json.Marshal(v)
|
data, err := json.Marshal(v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -932,7 +932,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
||||||
blocklistConfigItemName := "blocklist:blocklist1:last_pull"
|
blocklistConfigItemName := "blocklist:blocklist1:last_pull"
|
||||||
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotEmpty(t, *lastPullTimestamp)
|
assert.NotEmpty(t, lastPullTimestamp)
|
||||||
|
|
||||||
// new call should return 304 and should not change lastPullTimestamp
|
// new call should return 304 and should not change lastPullTimestamp
|
||||||
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
|
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
|
||||||
|
@ -944,7 +944,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
|
assert.Equal(t, lastPullTimestamp, secondLastPullTimestamp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
||||||
|
|
|
@ -22,9 +22,10 @@ import (
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const SyncInterval = time.Second * 10
|
const (
|
||||||
|
SyncInterval = time.Second * 10
|
||||||
const PapiPullKey = "papi:last_pull"
|
PapiPullKey = "papi:last_pull"
|
||||||
|
)
|
||||||
|
|
||||||
var operationMap = map[string]func(*Message, *Papi, bool) error{
|
var operationMap = map[string]func(*Message, *Papi, bool) error{
|
||||||
"decision": DecisionCmd,
|
"decision": DecisionCmd,
|
||||||
|
@ -48,7 +49,7 @@ type Source struct {
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
Header *Header
|
Header *Header
|
||||||
Data interface{} `json:"data"`
|
Data any `json:"data"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OperationChannels struct {
|
type OperationChannels struct {
|
||||||
|
@ -240,7 +241,7 @@ func (p *Papi) Pull(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// value doesn't exist, it's first time we're pulling
|
// value doesn't exist, it's first time we're pulling
|
||||||
if lastTimestampStr == nil {
|
if lastTimestampStr == "" {
|
||||||
binTime, err := lastTimestamp.MarshalText()
|
binTime, err := lastTimestamp.MarshalText()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to serialize last timestamp: %w", err)
|
return fmt.Errorf("failed to serialize last timestamp: %w", err)
|
||||||
|
@ -252,7 +253,7 @@ func (p *Papi) Pull(ctx context.Context) error {
|
||||||
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
|
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil {
|
if err := lastTimestamp.UnmarshalText([]byte(lastTimestampStr)); err != nil {
|
||||||
return fmt.Errorf("failed to parse last timestamp: %w", err)
|
return fmt.Errorf("failed to parse last timestamp: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,27 +9,30 @@ import (
|
||||||
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
|
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
|
func (c *Client) GetConfigItem(ctx context.Context, key string) (string, error) {
|
||||||
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
|
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
|
||||||
if err != nil && ent.IsNotFound(err) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
switch {
|
||||||
return nil, errors.Wrapf(QueryFail, "select config item: %s", err)
|
case ent.IsNotFound(err):
|
||||||
|
return "", nil
|
||||||
|
case err != nil:
|
||||||
|
return "", errors.Wrapf(QueryFail, "select config item: %s", err)
|
||||||
|
default:
|
||||||
|
return result.Value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return &result.Value, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
|
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
|
||||||
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
|
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
|
||||||
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create
|
|
||||||
|
switch {
|
||||||
|
case ent.IsNotFound(err) || nbUpdated == 0:
|
||||||
|
// not found, create
|
||||||
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
|
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrapf(QueryFail, "insert config item: %s", err)
|
return errors.Wrapf(QueryFail, "insert config item: %s", err)
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
case err != nil:
|
||||||
return errors.Wrapf(QueryFail, "update config item: %s", err)
|
return errors.Wrapf(QueryFail, "update config item: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue