mirror of
https://github.com/crowdsecurity/crowdsec.git
synced 2025-05-10 20:05:55 +02:00
refact pkg/database: unnecessary pointers
This commit is contained in:
parent
e6b85b641c
commit
4ba8c03774
7 changed files with 38 additions and 39 deletions
|
@ -11,8 +11,6 @@ import (
|
|||
"github.com/spf13/cobra"
|
||||
"gopkg.in/tomb.v2"
|
||||
|
||||
"github.com/crowdsecurity/go-cs-lib/ptr"
|
||||
|
||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/args"
|
||||
"github.com/crowdsecurity/crowdsec/cmd/crowdsec-cli/require"
|
||||
"github.com/crowdsecurity/crowdsec/pkg/apiserver"
|
||||
|
@ -76,17 +74,17 @@ func (cli *cliPapi) Status(ctx context.Context, out io.Writer, db *database.Clie
|
|||
|
||||
lastTimestampStr, err := db.GetConfigItem(ctx, apiserver.PapiPullKey)
|
||||
if err != nil {
|
||||
lastTimestampStr = ptr.Of("never")
|
||||
lastTimestampStr = "never"
|
||||
}
|
||||
|
||||
// both can and did happen
|
||||
if lastTimestampStr == nil || *lastTimestampStr == "0001-01-01T00:00:00Z" {
|
||||
lastTimestampStr = ptr.Of("never")
|
||||
if lastTimestampStr == "" || lastTimestampStr == "0001-01-01T00:00:00Z" {
|
||||
lastTimestampStr = "never"
|
||||
}
|
||||
|
||||
fmt.Fprint(out, "You can successfully interact with Polling API (PAPI)\n")
|
||||
fmt.Fprintf(out, "Console plan: %s\n", perms.Plan)
|
||||
fmt.Fprintf(out, "Last order received: %s\n", *lastTimestampStr)
|
||||
fmt.Fprintf(out, "Last order received: %s\n", lastTimestampStr)
|
||||
fmt.Fprint(out, "PAPI subscriptions:\n")
|
||||
|
||||
for _, sub := range perms.Categories {
|
||||
|
|
|
@ -174,7 +174,7 @@ func (s *DecisionsService) FetchV3Decisions(ctx context.Context, url string) (*m
|
|||
return &v2Decisions, resp, nil
|
||||
}
|
||||
|
||||
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp *string) ([]*models.Decision, bool, error) {
|
||||
func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blocklist *modelscapi.BlocklistLink, lastPullTimestamp string) ([]*models.Decision, bool, error) {
|
||||
if blocklist.URL == nil {
|
||||
return nil, false, errors.New("blocklist URL is nil")
|
||||
}
|
||||
|
@ -188,8 +188,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
|||
return nil, false, err
|
||||
}
|
||||
|
||||
if lastPullTimestamp != nil {
|
||||
req.Header.Set("If-Modified-Since", *lastPullTimestamp)
|
||||
if lastPullTimestamp != "" {
|
||||
req.Header.Set("If-Modified-Since", lastPullTimestamp)
|
||||
}
|
||||
|
||||
log.Debugf("[URL] %s %s", req.Method, req.URL)
|
||||
|
@ -217,8 +217,8 @@ func (s *DecisionsService) GetDecisionsFromBlocklist(ctx context.Context, blockl
|
|||
}
|
||||
|
||||
if resp.StatusCode == http.StatusNotModified {
|
||||
if lastPullTimestamp != nil {
|
||||
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, *lastPullTimestamp)
|
||||
if lastPullTimestamp != "" {
|
||||
log.Debugf("Blocklist %s has not been modified since %s", *blocklist.URL, lastPullTimestamp)
|
||||
} else {
|
||||
log.Debugf("Blocklist %s has not been modified (decisions about to expire)", *blocklist.URL)
|
||||
}
|
||||
|
|
|
@ -362,7 +362,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
|||
Remediation: &tremediationBlocklist,
|
||||
Name: &tnameBlocklist,
|
||||
Duration: &tdurationBlocklist,
|
||||
}, nil)
|
||||
}, "")
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isModified)
|
||||
|
||||
|
@ -381,7 +381,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
|||
Remediation: &tremediationBlocklist,
|
||||
Name: &tnameBlocklist,
|
||||
Duration: &tdurationBlocklist,
|
||||
}, ptr.Of("Sun, 01 Jan 2023 01:01:01 GMT"))
|
||||
}, "Sun, 01 Jan 2023 01:01:01 GMT")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, isModified)
|
||||
|
@ -392,7 +392,7 @@ func TestDecisionsFromBlocklist(t *testing.T) {
|
|||
Remediation: &tremediationBlocklist,
|
||||
Name: &tnameBlocklist,
|
||||
Duration: &tdurationBlocklist,
|
||||
}, ptr.Of("Mon, 02 Jan 2023 01:01:01 GMT"))
|
||||
}, "Mon, 02 Jan 2023 01:01:01 GMT")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, isModified)
|
||||
|
|
|
@ -260,13 +260,13 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
|
|||
return "", time.Time{}, false
|
||||
}
|
||||
|
||||
if token == nil {
|
||||
if token == "" {
|
||||
log.Debug("no token found in DB")
|
||||
return "", time.Time{}, false
|
||||
}
|
||||
|
||||
parser := new(jwt.Parser)
|
||||
tok, _, err := parser.ParseUnverified(*token, jwt.MapClaims{})
|
||||
tok, _, err := parser.ParseUnverified(token, jwt.MapClaims{})
|
||||
if err != nil {
|
||||
log.Debugf("error parsing token: %s", err)
|
||||
return "", time.Time{}, false
|
||||
|
@ -290,7 +290,7 @@ func loadAPICToken(ctx context.Context, db *database.Client) (string, time.Time,
|
|||
return "", time.Time{}, false
|
||||
}
|
||||
|
||||
return *token, exp, true
|
||||
return token, exp, true
|
||||
}
|
||||
|
||||
// saveAPICToken stores the given JWT token in the local database under the "apic_token" config item.
|
||||
|
@ -1043,7 +1043,7 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
|
|||
blocklistConfigItemName := fmt.Sprintf("blocklist:%s:last_pull", *blocklist.Name)
|
||||
|
||||
var (
|
||||
lastPullTimestamp *string
|
||||
lastPullTimestamp string
|
||||
err error
|
||||
)
|
||||
|
||||
|
@ -1060,10 +1060,10 @@ func (a *apic) updateBlocklist(ctx context.Context, client *apiclient.ApiClient,
|
|||
}
|
||||
|
||||
if !hasChanged {
|
||||
if lastPullTimestamp == nil {
|
||||
if lastPullTimestamp == "" {
|
||||
log.Infof("blocklist %s hasn't been modified or there was an error reading it, skipping", *blocklist.Name)
|
||||
} else {
|
||||
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, *lastPullTimestamp)
|
||||
log.Infof("blocklist %s hasn't been modified since %s, skipping", *blocklist.Name, lastPullTimestamp)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -932,7 +932,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
|||
blocklistConfigItemName := "blocklist:blocklist1:last_pull"
|
||||
lastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, *lastPullTimestamp)
|
||||
assert.NotEmpty(t, lastPullTimestamp)
|
||||
|
||||
// new call should return 304 and should not change lastPullTimestamp
|
||||
httpmock.RegisterResponder("GET", "http://api.crowdsec.net/blocklist1", func(req *http.Request) (*http.Response, error) {
|
||||
|
@ -944,7 +944,7 @@ func TestAPICPullTopBLCacheFirstCall(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
secondLastPullTimestamp, err := api.dbClient.GetConfigItem(ctx, blocklistConfigItemName)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *lastPullTimestamp, *secondLastPullTimestamp)
|
||||
assert.Equal(t, lastPullTimestamp, secondLastPullTimestamp)
|
||||
}
|
||||
|
||||
func TestAPICPullTopBLCacheForceCall(t *testing.T) {
|
||||
|
|
|
@ -22,9 +22,10 @@ import (
|
|||
"github.com/crowdsecurity/crowdsec/pkg/types"
|
||||
)
|
||||
|
||||
const SyncInterval = time.Second * 10
|
||||
|
||||
const PapiPullKey = "papi:last_pull"
|
||||
const (
|
||||
SyncInterval = time.Second * 10
|
||||
PapiPullKey = "papi:last_pull"
|
||||
)
|
||||
|
||||
var operationMap = map[string]func(*Message, *Papi, bool) error{
|
||||
"decision": DecisionCmd,
|
||||
|
@ -240,7 +241,7 @@ func (p *Papi) Pull(ctx context.Context) error {
|
|||
}
|
||||
|
||||
// value doesn't exist, it's first time we're pulling
|
||||
if lastTimestampStr == nil {
|
||||
if lastTimestampStr == "" {
|
||||
binTime, err := lastTimestamp.MarshalText()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize last timestamp: %w", err)
|
||||
|
@ -252,7 +253,7 @@ func (p *Papi) Pull(ctx context.Context) error {
|
|||
p.Logger.Debugf("config item '%s' set in database with value '%s'", PapiPullKey, string(binTime))
|
||||
}
|
||||
} else {
|
||||
if err := lastTimestamp.UnmarshalText([]byte(*lastTimestampStr)); err != nil {
|
||||
if err := lastTimestamp.UnmarshalText([]byte(lastTimestampStr)); err != nil {
|
||||
return fmt.Errorf("failed to parse last timestamp: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -9,29 +9,29 @@ import (
|
|||
"github.com/crowdsecurity/crowdsec/pkg/database/ent/configitem"
|
||||
)
|
||||
|
||||
func (c *Client) GetConfigItem(ctx context.Context, key string) (*string, error) {
|
||||
func (c *Client) GetConfigItem(ctx context.Context, key string) (string, error) {
|
||||
result, err := c.Ent.ConfigItem.Query().Where(configitem.NameEQ(key)).First(ctx)
|
||||
if err != nil && ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
switch {
|
||||
case ent.IsNotFound(err):
|
||||
return "", nil
|
||||
case err != nil:
|
||||
return "", errors.Wrapf(QueryFail, "select config item: %s", err)
|
||||
default:
|
||||
return result.Value, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(QueryFail, "select config item: %s", err)
|
||||
}
|
||||
|
||||
return &result.Value, nil
|
||||
}
|
||||
|
||||
func (c *Client) SetConfigItem(ctx context.Context, key string, value string) error {
|
||||
nbUpdated, err := c.Ent.ConfigItem.Update().SetValue(value).Where(configitem.NameEQ(key)).Save(ctx)
|
||||
if (err != nil && ent.IsNotFound(err)) || nbUpdated == 0 { // not found, create
|
||||
switch {
|
||||
case ent.IsNotFound(err) || nbUpdated == 0:
|
||||
// not found, create
|
||||
err := c.Ent.ConfigItem.Create().SetName(key).SetValue(value).Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrapf(QueryFail, "insert config item: %s", err)
|
||||
}
|
||||
} else if err != nil {
|
||||
case err != nil:
|
||||
return errors.Wrapf(QueryFail, "update config item: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue