pkg/cwhub: use explicit context for item install, upgrade (#3067)

This commit is contained in:
mmetc 2024-06-07 17:32:52 +02:00 committed by GitHub
parent cad760e605
commit dd6cf2b844
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 72 additions and 55 deletions

View file

@ -50,7 +50,7 @@ func (cli *cliConfig) restoreHub(ctx context.Context, dirPath string) error {
continue
}
if err = item.Install(false, false); err != nil {
if err = item.Install(ctx, false, false); err != nil {
log.Errorf("Error while installing %s : %s", toinstall, err)
}
}

View file

@ -158,7 +158,7 @@ func (cli *cliHub) upgrade(ctx context.Context, force bool) error {
log.Infof("Upgrading %s", itemType)
for _, item := range items {
didUpdate, err := item.Upgrade(force)
didUpdate, err := item.Upgrade(ctx, force)
if err != nil {
return err
}

View file

@ -83,7 +83,7 @@ func (cli cliItem) install(ctx context.Context, args []string, downloadOnly bool
continue
}
if err := item.Install(force, downloadOnly); err != nil {
if err := item.Install(ctx, force, downloadOnly); err != nil {
if !ignoreError {
return fmt.Errorf("error while installing '%s': %w", item.Name, err)
}
@ -270,7 +270,7 @@ func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all b
updated := 0
for _, item := range items {
didUpdate, err := item.Upgrade(force)
didUpdate, err := item.Upgrade(ctx, force)
if err != nil {
return err
}
@ -301,7 +301,7 @@ func (cli cliItem) upgrade(ctx context.Context, args []string, force bool, all b
return fmt.Errorf("can't find '%s' in %s", itemName, cli.name)
}
didUpdate, err := item.Upgrade(force)
didUpdate, err := item.Upgrade(ctx, force)
if err != nil {
return err
}
@ -376,7 +376,7 @@ func (cli cliItem) inspect(ctx context.Context, args []string, url string, diff
}
if diff {
fmt.Println(cli.whyTainted(hub, item, rev))
fmt.Println(cli.whyTainted(ctx, hub, item, rev))
continue
}
@ -466,7 +466,7 @@ func (cli cliItem) newListCmd() *cobra.Command {
}
// return the diff between the installed version and the latest version
func (cli cliItem) itemDiff(item *cwhub.Item, reverse bool) (string, error) {
func (cli cliItem) itemDiff(ctx context.Context, item *cwhub.Item, reverse bool) (string, error) {
if !item.State.Installed {
return "", fmt.Errorf("'%s' is not installed", item.FQName())
}
@ -477,7 +477,7 @@ func (cli cliItem) itemDiff(item *cwhub.Item, reverse bool) (string, error) {
}
defer os.Remove(dest.Name())
_, remoteURL, err := item.FetchContentTo(dest.Name())
_, remoteURL, err := item.FetchContentTo(ctx, dest.Name())
if err != nil {
return "", err
}
@ -508,7 +508,7 @@ func (cli cliItem) itemDiff(item *cwhub.Item, reverse bool) (string, error) {
return fmt.Sprintf("%s", diff), nil
}
func (cli cliItem) whyTainted(hub *cwhub.Hub, item *cwhub.Item, reverse bool) string {
func (cli cliItem) whyTainted(ctx context.Context, hub *cwhub.Hub, item *cwhub.Item, reverse bool) string {
if !item.State.Installed {
return fmt.Sprintf("# %s is not installed", item.FQName())
}
@ -533,7 +533,7 @@ func (cli cliItem) whyTainted(hub *cwhub.Hub, item *cwhub.Item, reverse bool) st
ret = append(ret, err.Error())
}
diff, err := cli.itemDiff(sub, reverse)
diff, err := cli.itemDiff(ctx, sub, reverse)
if err != nil {
ret = append(ret, err.Error())
}

View file

@ -320,7 +320,7 @@ func runSetupInstallHub(cmd *cobra.Command, args []string) error {
return err
}
return setup.InstallHubItems(hub, input, dryRun)
return setup.InstallHubItems(cmd.Context(), hub, input, dryRun)
}
func runSetupValidate(cmd *cobra.Command, args []string) error {

View file

@ -67,7 +67,8 @@ func testHub(t *testing.T, update bool) *Hub {
require.NoError(t, err)
if update {
err := hub.Update(context.TODO())
ctx := context.Background()
err := hub.Update(ctx)
require.NoError(t, err)
}

View file

@ -21,7 +21,7 @@ type DataSet struct {
}
// downloadDataSet downloads all the data files for an item.
func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error {
func downloadDataSet(ctx context.Context, dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error {
dec := yaml.NewDecoder(reader)
for {
@ -53,8 +53,6 @@ func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *lo
WithShelfLife(7 * 24 * time.Hour)
}
ctx := context.TODO()
downloaded, err := d.Download(ctx, dataS.SourceURL)
if err != nil {
return fmt.Errorf("while getting data: %w", err)

View file

@ -22,7 +22,9 @@ func TestInitHubUpdate(t *testing.T) {
_, err := NewHub(hub.local, remote, nil)
require.NoError(t, err)
err = hub.Update(context.TODO())
ctx := context.Background()
err = hub.Update(ctx)
require.NoError(t, err)
err = hub.Load()
@ -54,7 +56,9 @@ func TestUpdateIndex(t *testing.T) {
hub.local.HubIndexFile = tmpIndex.Name()
err = hub.Update(context.TODO())
ctx := context.Background()
err = hub.Update(ctx)
cstest.RequireErrorContains(t, err, "failed to build hub index request: invalid URL template 'x'")
// bad domain
@ -66,7 +70,7 @@ func TestUpdateIndex(t *testing.T) {
IndexPath: ".index.json",
}
err = hub.Update(context.TODO())
err = hub.Update(ctx)
require.NoError(t, err)
// XXX: this is not failing
// cstest.RequireErrorContains(t, err, "failed http request for hub index: Get")
@ -82,6 +86,6 @@ func TestUpdateIndex(t *testing.T) {
hub.local.HubIndexFile = "/does/not/exist/index.json"
err = hub.Update(context.TODO())
err = hub.Update(ctx)
cstest.RequireErrorContains(t, err, "failed to create temporary download file for /does/not/exist/index.json:")
}

View file

@ -1,6 +1,7 @@
package cwhub
import (
"context"
"fmt"
)
@ -39,7 +40,7 @@ func (i *Item) enable() error {
}
// Install installs the item from the hub, downloading it if needed.
func (i *Item) Install(force bool, downloadOnly bool) error {
func (i *Item) Install(ctx context.Context, force bool, downloadOnly bool) error {
if downloadOnly && i.State.Downloaded && i.State.UpToDate {
i.hub.logger.Infof("%s is already downloaded and up-to-date", i.Name)
@ -48,7 +49,7 @@ func (i *Item) Install(force bool, downloadOnly bool) error {
}
}
downloaded, err := i.downloadLatest(force, true)
downloaded, err := i.downloadLatest(ctx, force, true)
if err != nil {
return err
}

View file

@ -1,6 +1,7 @@
package cwhub
import (
"context"
"os"
"testing"
@ -9,8 +10,10 @@ import (
)
func testInstall(hub *Hub, t *testing.T, item *Item) {
ctx := context.Background()
// Install the parser
_, err := item.downloadLatest(false, false)
_, err := item.downloadLatest(ctx, false, false)
require.NoError(t, err, "failed to download %s", item.Name)
err = hub.localSync()
@ -48,8 +51,10 @@ func testTaint(hub *Hub, t *testing.T, item *Item) {
func testUpdate(hub *Hub, t *testing.T, item *Item) {
assert.False(t, item.State.UpToDate, "%s should not be up-to-date", item.Name)
ctx := context.Background()
// Update it + check status
_, err := item.downloadLatest(true, true)
_, err := item.downloadLatest(ctx, true, true)
require.NoError(t, err, "failed to update %s", item.Name)
// Local sync and check status

View file

@ -16,7 +16,7 @@ import (
)
// Upgrade downloads and applies the last version of the item from the hub.
func (i *Item) Upgrade(force bool) (bool, error) {
func (i *Item) Upgrade(ctx context.Context, force bool) (bool, error) {
if i.State.IsLocal() {
i.hub.logger.Infof("not upgrading %s: local item", i.Name)
return false, nil
@ -33,7 +33,7 @@ func (i *Item) Upgrade(force bool) (bool, error) {
if i.State.UpToDate {
i.hub.logger.Infof("%s: up-to-date", i.Name)
if err := i.DownloadDataIfNeeded(force); err != nil {
if err := i.DownloadDataIfNeeded(ctx, force); err != nil {
return false, fmt.Errorf("%s: download failed: %w", i.Name, err)
}
@ -43,7 +43,7 @@ func (i *Item) Upgrade(force bool) (bool, error) {
}
}
if _, err := i.downloadLatest(force, true); err != nil {
if _, err := i.downloadLatest(ctx, force, true); err != nil {
return false, fmt.Errorf("%s: download failed: %w", i.Name, err)
}
@ -65,7 +65,7 @@ func (i *Item) Upgrade(force bool) (bool, error) {
}
// downloadLatest downloads the latest version of the item to the hub directory.
func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (bool, error) {
func (i *Item) downloadLatest(ctx context.Context, overwrite bool, updateOnly bool) (bool, error) {
i.hub.logger.Debugf("Downloading %s %s", i.Type, i.Name)
for _, sub := range i.SubItems() {
@ -80,14 +80,14 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (bool, error) {
if sub.HasSubItems() {
i.hub.logger.Tracef("collection, recurse")
if _, err := sub.downloadLatest(overwrite, updateOnly); err != nil {
if _, err := sub.downloadLatest(ctx, overwrite, updateOnly); err != nil {
return false, err
}
}
downloaded := sub.State.Downloaded
if _, err := sub.download(overwrite); err != nil {
if _, err := sub.download(ctx, overwrite); err != nil {
return false, err
}
@ -105,11 +105,11 @@ func (i *Item) downloadLatest(overwrite bool, updateOnly bool) (bool, error) {
return false, nil
}
return i.download(overwrite)
return i.download(ctx, overwrite)
}
// FetchContentTo downloads the last version of the item's YAML file to the specified path.
func (i *Item) FetchContentTo(destPath string) (bool, string, error) {
func (i *Item) FetchContentTo(ctx context.Context, destPath string) (bool, string, error) {
url, err := i.hub.remote.urlTo(i.RemotePath)
if err != nil {
return false, "", fmt.Errorf("failed to build request: %w", err)
@ -131,8 +131,6 @@ func (i *Item) FetchContentTo(destPath string) (bool, string, error) {
// TODO: recommend hub update if hash does not match
ctx := context.TODO()
downloaded, err := d.Download(ctx, url)
if err != nil {
return false, "", fmt.Errorf("while downloading %s to %s: %w", i.Name, url, err)
@ -142,7 +140,7 @@ func (i *Item) FetchContentTo(destPath string) (bool, string, error) {
}
// download downloads the item from the hub and writes it to the hub directory.
func (i *Item) download(overwrite bool) (bool, error) {
func (i *Item) download(ctx context.Context, overwrite bool) (bool, error) {
// ensure that target file is within target dir
finalPath, err := i.downloadPath()
if err != nil {
@ -167,7 +165,7 @@ func (i *Item) download(overwrite bool) (bool, error) {
}
}
downloaded, _, err := i.FetchContentTo(finalPath)
downloaded, _, err := i.FetchContentTo(ctx, finalPath)
if err != nil {
return false, fmt.Errorf("while downloading %s: %w", i.Name, err)
}
@ -188,7 +186,7 @@ func (i *Item) download(overwrite bool) (bool, error) {
defer reader.Close()
if err = downloadDataSet(i.hub.local.InstallDataDir, overwrite, reader, i.hub.logger); err != nil {
if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, overwrite, reader, i.hub.logger); err != nil {
return false, fmt.Errorf("while downloading data for %s: %w", i.FileName, err)
}
@ -196,7 +194,7 @@ func (i *Item) download(overwrite bool) (bool, error) {
}
// DownloadDataIfNeeded downloads the data set for the item.
func (i *Item) DownloadDataIfNeeded(force bool) error {
func (i *Item) DownloadDataIfNeeded(ctx context.Context, force bool) error {
itemFilePath, err := i.installPath()
if err != nil {
return err
@ -209,7 +207,7 @@ func (i *Item) DownloadDataIfNeeded(force bool) error {
defer itemFile.Close()
if err = downloadDataSet(i.hub.local.InstallDataDir, force, itemFile, i.hub.logger); err != nil {
if err = downloadDataSet(ctx, i.hub.local.InstallDataDir, force, itemFile, i.hub.logger); err != nil {
return fmt.Errorf("while downloading data for %s: %w", itemFilePath, err)
}

View file

@ -19,7 +19,9 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) {
require.False(t, item.State.Downloaded)
require.False(t, item.State.Installed)
require.NoError(t, item.Install(false, false))
ctx := context.Background()
require.NoError(t, item.Install(ctx, false, false))
require.True(t, item.State.Downloaded)
require.True(t, item.State.Installed)
@ -43,7 +45,7 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) {
hub, err := NewHub(hub.local, remote, nil)
require.NoError(t, err)
err = hub.Update(context.TODO())
err = hub.Update(ctx)
require.NoError(t, err)
err = hub.Load()
@ -58,7 +60,7 @@ func TestUpgradeItemNewScenarioInCollection(t *testing.T) {
require.False(t, item.State.UpToDate)
require.False(t, item.State.Tainted)
didUpdate, err := item.Upgrade(false)
didUpdate, err := item.Upgrade(ctx, false)
require.NoError(t, err)
require.True(t, didUpdate)
assertCollectionDepsInstalled(t, hub, "crowdsecurity/test_collection")
@ -78,7 +80,9 @@ func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) {
require.False(t, item.State.Installed)
require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
require.NoError(t, item.Install(false, false))
ctx := context.Background()
require.NoError(t, item.Install(ctx, false, false))
require.True(t, item.State.Downloaded)
require.True(t, item.State.Installed)
@ -110,14 +114,14 @@ func TestUpgradeItemInDisabledScenarioShouldNotBeInstalled(t *testing.T) {
hub, err = NewHub(hub.local, remote, nil)
require.NoError(t, err)
err = hub.Update(context.TODO())
err = hub.Update(ctx)
require.NoError(t, err)
err = hub.Load()
require.NoError(t, err)
item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
didUpdate, err := item.Upgrade(false)
didUpdate, err := item.Upgrade(ctx, false)
require.NoError(t, err)
require.False(t, didUpdate)
@ -148,7 +152,9 @@ func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *te
require.False(t, item.State.Installed)
require.False(t, hub.GetItem(SCENARIOS, "crowdsecurity/foobar_scenario").State.Installed)
require.NoError(t, item.Install(false, false))
ctx := context.Background()
require.NoError(t, item.Install(ctx, false, false))
require.True(t, item.State.Downloaded)
require.True(t, item.State.Installed)
@ -185,7 +191,7 @@ func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *te
hub, err = NewHub(hub.local, remote, nil)
require.NoError(t, err)
err = hub.Update(context.TODO())
err = hub.Update(ctx)
require.NoError(t, err)
err = hub.Load()
@ -195,7 +201,7 @@ func TestUpgradeItemNewScenarioIsInstalledWhenReferencedScenarioIsDisabled(t *te
hub = getHubOrFail(t, hub.local, remote)
item = hub.GetItem(COLLECTIONS, "crowdsecurity/test_collection")
didUpdate, err := item.Upgrade(false)
didUpdate, err := item.Upgrade(ctx, false)
require.NoError(t, err)
require.True(t, didUpdate)

View file

@ -1,6 +1,7 @@
package hubtest
import (
"context"
"errors"
"fmt"
"net/url"
@ -219,11 +220,13 @@ func (t *HubTestItem) InstallHub() error {
return err
}
ctx := context.Background()
// install data for parsers if needed
ret := hub.GetItemMap(cwhub.PARSERS)
for parserName, item := range ret {
if item.State.Installed {
if err := item.DownloadDataIfNeeded(true); err != nil {
if err := item.DownloadDataIfNeeded(ctx, true); err != nil {
return fmt.Errorf("unable to download data for parser '%s': %+v", parserName, err)
}
@ -235,7 +238,7 @@ func (t *HubTestItem) InstallHub() error {
ret = hub.GetItemMap(cwhub.SCENARIOS)
for scenarioName, item := range ret {
if item.State.Installed {
if err := item.DownloadDataIfNeeded(true); err != nil {
if err := item.DownloadDataIfNeeded(ctx, true); err != nil {
return fmt.Errorf("unable to download data for parser '%s': %+v", scenarioName, err)
}
@ -247,7 +250,7 @@ func (t *HubTestItem) InstallHub() error {
ret = hub.GetItemMap(cwhub.POSTOVERFLOWS)
for postoverflowName, item := range ret {
if item.State.Installed {
if err := item.DownloadDataIfNeeded(true); err != nil {
if err := item.DownloadDataIfNeeded(ctx, true); err != nil {
return fmt.Errorf("unable to download data for parser '%s': %+v", postoverflowName, err)
}

View file

@ -2,6 +2,7 @@ package setup
import (
"bytes"
"context"
"errors"
"fmt"
"os"
@ -46,7 +47,7 @@ func decodeSetup(input []byte, fancyErrors bool) (Setup, error) {
}
// InstallHubItems installs the objects recommended in a setup file.
func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error {
func InstallHubItems(ctx context.Context, hub *cwhub.Hub, input []byte, dryRun bool) error {
setupEnvelope, err := decodeSetup(input, false)
if err != nil {
return err
@ -74,7 +75,7 @@ func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error {
continue
}
if err := item.Install(forceAction, downloadOnly); err != nil {
if err := item.Install(ctx, forceAction, downloadOnly); err != nil {
return fmt.Errorf("while installing collection %s: %w", item.Name, err)
}
}
@ -93,7 +94,7 @@ func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error {
return fmt.Errorf("parser %s not found", parser)
}
if err := item.Install(forceAction, downloadOnly); err != nil {
if err := item.Install(ctx, forceAction, downloadOnly); err != nil {
return fmt.Errorf("while installing parser %s: %w", item.Name, err)
}
}
@ -112,7 +113,7 @@ func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error {
return fmt.Errorf("scenario %s not found", scenario)
}
if err := item.Install(forceAction, downloadOnly); err != nil {
if err := item.Install(ctx, forceAction, downloadOnly); err != nil {
return fmt.Errorf("while installing scenario %s: %w", item.Name, err)
}
}
@ -131,7 +132,7 @@ func InstallHubItems(hub *cwhub.Hub, input []byte, dryRun bool) error {
return fmt.Errorf("postoverflow %s not found", postoverflow)
}
if err := item.Install(forceAction, downloadOnly); err != nil {
if err := item.Install(ctx, forceAction, downloadOnly); err != nil {
return fmt.Errorf("while installing postoverflow %s: %w", item.Name, err)
}
}