diff --git a/pkg/utils/yaml_utils/yaml_utils.go b/pkg/utils/yaml_utils/yaml_utils.go index 9d96fa7a7..48f70fff0 100644 --- a/pkg/utils/yaml_utils/yaml_utils.go +++ b/pkg/utils/yaml_utils/yaml_utils.go @@ -153,3 +153,74 @@ func renameYamlKey(node *yaml.Node, path []string, newKey string) (bool, error) return renameYamlKey(valueNode, path[1:], newKey) } + +// Traverses a yaml document, calling the callback function for each node. The +// callback is allowed to modify the node in place, in which case it should +// return true. The function returns the original yaml document if none of the +// callbacks returned true, and the modified document otherwise. +func Walk(yamlBytes []byte, callback func(node *yaml.Node, path string) bool) ([]byte, error) { + // Parse the YAML file. + var node yaml.Node + err := yaml.Unmarshal(yamlBytes, &node) + if err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + + // Empty document: nothing to do. + if len(node.Content) == 0 { + return yamlBytes, nil + } + + body := node.Content[0] + + if didChange, err := walk(body, "", callback); err != nil || !didChange { + return yamlBytes, err + } + + // Convert the updated YAML node back to YAML bytes. + updatedYAMLBytes, err := yaml.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to convert YAML node to bytes: %w", err) + } + + return updatedYAMLBytes, nil +} + +func walk(node *yaml.Node, path string, callback func(*yaml.Node, string) bool) (bool, error) { + didChange := callback(node, path) + switch node.Kind { + case yaml.DocumentNode: + return false, fmt.Errorf("Unexpected document node in the middle of a yaml tree") + case yaml.MappingNode: + for i := 0; i < len(node.Content); i += 2 { + name := node.Content[i].Value + childNode := node.Content[i+1] + var childPath string + if path == "" { + childPath = name + } else { + childPath = fmt.Sprintf("%s.%s", path, name) + } + didChangeChild, err := walk(childNode, childPath, callback) + if err != nil { + return false, err + } + didChange = didChange || didChangeChild + } + case yaml.SequenceNode: + for i := 0; i < len(node.Content); i++ { + childPath := fmt.Sprintf("%s[%d]", path, i) + didChangeChild, err := walk(node.Content[i], childPath, callback) + if err != nil { + return false, err + } + didChange = didChange || didChangeChild + } + case yaml.ScalarNode: + // nothing to do + case yaml.AliasNode: + return false, fmt.Errorf("Alias nodes are not supported") + } + + return didChange, nil +} diff --git a/pkg/utils/yaml_utils/yaml_utils_test.go b/pkg/utils/yaml_utils/yaml_utils_test.go index 7f9dc20f7..0b445a7ab 100644 --- a/pkg/utils/yaml_utils/yaml_utils_test.go +++ b/pkg/utils/yaml_utils/yaml_utils_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "gopkg.in/yaml.v3" ) func TestUpdateYamlValue(t *testing.T) { @@ -199,3 +200,82 @@ func TestRenameYamlKey(t *testing.T) { }) } } + +func TestWalk_paths(t *testing.T) { + tests := []struct { + name string + document string + expectedPaths []string + }{ + { + name: "empty document", + document: "", + expectedPaths: []string{}, + }, + { + name: "scalar", + document: "x: 5", + expectedPaths: []string{"", "x"}, // called with an empty path for the root node + }, + { + name: "nested", + document: "foo:\n x: 5", + expectedPaths: []string{"", "foo", "foo.x"}, + }, + { + name: "array", + document: "foo:\n bar: [3, 7]", + expectedPaths: []string{"", "foo", "foo.bar", "foo.bar[0]", "foo.bar[1]"}, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + paths := []string{} + _, err := Walk([]byte(test.document), func(node *yaml.Node, path string) bool { + paths = append(paths, path) + return true + }) + + assert.NoError(t, err) + assert.Equal(t, test.expectedPaths, paths) + }) + } +} + +func TestWalk_inPlaceChanges(t *testing.T) { + tests := []struct { + name string + in string + callback func(node *yaml.Node, path string) bool + expectedOut string + }{ + { + name: "no change", + in: "x: 5", + callback: func(node *yaml.Node, path string) bool { return false }, + expectedOut: "x: 5", + }, + { + name: "change value", + in: "x: 5\ny: 3", + callback: func(node *yaml.Node, path string) bool { + if path == "x" { + node.Value = "7" + return true + } + return false + }, + expectedOut: "x: 7\ny: 3\n", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result, err := Walk([]byte(test.in), test.callback) + + assert.NoError(t, err) + assert.Equal(t, test.expectedOut, string(result)) + }) + } +}