From 04cd2ff3f95c14049ee39c4a1add138785d04d2e Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sun, 29 Oct 2023 15:14:23 +0300 Subject: [PATCH] fix(search): Support indexing array paths (#2074) * fix(search): Support indexing array paths Signed-off-by: Vladislav Oleshko --------- Signed-off-by: Vladislav Oleshko --- src/core/search/base.h | 5 +- src/core/search/indices.cc | 44 ++++++++++++------ src/core/search/search_test.cc | 6 +-- src/core/search/sort_indices.cc | 12 ++++- src/server/search/doc_accessors.cc | 55 ++++++++++++++++------ src/server/search/doc_accessors.h | 6 +-- src/server/search/search_family_test.cc | 61 ++++++++++++++++++++++++- tests/dragonfly/search_test.py | 25 ++++++++++ 8 files changed, 175 insertions(+), 39 deletions(-) diff --git a/src/core/search/base.h b/src/core/search/base.h index 7662e79ff..89fe9b74c 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -57,9 +58,11 @@ using ResultScore = std::variant; // Interface for accessing document values with different data structures underneath. struct DocumentAccessor { using VectorInfo = search::OwnedFtVector; + using StringList = absl::InlinedVector; virtual ~DocumentAccessor() = default; - virtual std::string_view GetString(std::string_view active_field) const = 0; + + virtual StringList GetStrings(std::string_view active_field) const = 0; virtual VectorInfo GetVector(std::string_view active_field) const = 0; }; diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index 67a77ee91..2c435612c 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #define UNI_ALGO_DISABLE_NFKC_NFKD @@ -59,15 +60,19 @@ NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} { } void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { - double num; - if (absl::SimpleAtod(doc->GetString(field), &num)) - entries_.emplace(num, id); + for (auto str : doc->GetStrings(field)) { + double num; + if (absl::SimpleAtod(str, &num)) + entries_.emplace(num, id); + } } void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { - int64_t num; - if (absl::SimpleAtoi(doc->GetString(field), &num)) - entries_.erase({num, id}); + for (auto str : doc->GetStrings(field)) { + double num; + if (absl::SimpleAtod(str, &num)) + entries_.erase({num, id}); + } } vector NumericIndex::Range(double l, double r) const { @@ -79,6 +84,7 @@ vector NumericIndex::Range(double l, double r) const { out.push_back(it->second); sort(out.begin(), out.end()); + out.erase(unique(out.begin(), out.end()), out.end()); return out; } @@ -104,17 +110,27 @@ CompressedSortedSet* BaseStringIndex::GetOrCreate(string_view word) { } void BaseStringIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { - for (const auto& word : Tokenize(doc->GetString(field))) - GetOrCreate(word)->Insert(id); + absl::flat_hash_set tokens; + for (string_view str : doc->GetStrings(field)) + tokens.merge(Tokenize(str)); + + for (string_view token : tokens) + GetOrCreate(token)->Insert(id); } void BaseStringIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { - for (const auto& word : Tokenize(doc->GetString(field))) { - if (auto it = entries_.find(word); it != entries_.end()) { - it->second.Remove(id); - if (it->second.Size() == 0) - entries_.erase(it); - } + absl::flat_hash_set tokens; + for (string_view str : doc->GetStrings(field)) + tokens.merge(Tokenize(str)); + + for (const auto& token : tokens) { + auto it = entries_.find(token); + if (it == entries_.end()) + continue; + + it->second.Remove(id); + if (it->second.Size() == 0) + entries_.erase(it); } } diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc index c0b384ae2..b74ee149f 100644 --- a/src/core/search/search_test.cc +++ b/src/core/search/search_test.cc @@ -39,13 +39,13 @@ struct MockedDocument : public DocumentAccessor { MockedDocument(std::string test_field) : fields_{{"field", test_field}} { } - string_view GetString(string_view field) const override { + StringList GetStrings(string_view field) const override { auto it = fields_.find(field); - return it != fields_.end() ? string_view{it->second} : ""; + return {it != fields_.end() ? string_view{it->second} : ""}; } VectorInfo GetVector(string_view field) const override { - return BytesToFtVector(GetString(field)); + return BytesToFtVector(GetStrings(field).front()); } string DebugFormat() { diff --git a/src/core/search/sort_indices.cc b/src/core/search/sort_indices.cc index 771cc5d1e..8f4cc0a6f 100644 --- a/src/core/search/sort_indices.cc +++ b/src/core/search/sort_indices.cc @@ -56,14 +56,22 @@ template struct SimpleValueSortIndex; template struct SimpleValueSortIndex; double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { + auto str = doc->GetStrings(field); + if (str.empty()) + return 0; + double v; - if (!absl::SimpleAtod(doc->GetString(field), &v)) + if (!absl::SimpleAtod(str.front(), &v)) return 0; return v; } PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { - return PMR_NS::string{doc->GetString(field), GetMemRes()}; + auto str = doc->GetStrings(field); + if (str.empty()) + return ""; + + return PMR_NS::string{str.front(), GetMemRes()}; } } // namespace dfly::search diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index 0bd3f0561..bdd6314f8 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -55,17 +55,19 @@ SearchDocData BaseAccessor::Serialize(const search::Schema& schema, for (const auto& [fident, fname] : fields) { auto it = schema.fields.find(fident); auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT; - out[fname] = PrintField(type, GetString(fident)); + out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ",")); } return out; } -string_view ListPackAccessor::GetString(string_view active_field) const { - return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv); +BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const { + auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data()); + return strsv.has_value() ? StringList{*strsv} : StringList{}; } BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const { - return search::BytesToFtVector(GetString(active_field)); + auto strlist = GetStrings(active_field); + return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front()); } SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { @@ -86,13 +88,14 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { return out; } -string_view StringMapAccessor::GetString(string_view active_field) const { +BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const { auto it = hset_->Find(active_field); - return it != hset_->end() ? it->second : ""sv; + return it != hset_->end() ? StringList{it->second} : StringList{}; } BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const { - return search::BytesToFtVector(GetString(active_field)); + auto strlist = GetStrings(active_field); + return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front()); } SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const { @@ -106,13 +109,35 @@ SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const { struct JsonAccessor::JsonPathContainer : public jsoncons::jsonpath::jsonpath_expression { }; -string_view JsonAccessor::GetString(string_view active_field) const { - auto res = GetPath(active_field)->evaluate(json_); - DCHECK(res.is_array()); - if (res.empty()) - return ""; - buf_ = res[0].as_string(); - return buf_; +BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const { + auto path_res = GetPath(active_field)->evaluate(json_); + DCHECK(path_res.is_array()); // json path always returns arrays + + if (path_res.empty()) + return {}; + + if (path_res.size() == 1) { + buf_ = path_res[0].as_string(); + return {buf_}; + } + + // First, grow buffer and compute string sizes + vector sizes; + for (auto element : path_res.array_range()) { + size_t start = buf_.size(); + buf_ += element.as_string(); + sizes.push_back(buf_.size() - start); + } + + // Reposition start pointers to the most recent allocation of buf + StringList out(sizes.size()); + size_t start = 0; + for (size_t i = 0; i < out.size(); i++) { + out[i] = string_view{buf_}.substr(start, sizes[i]); + start += sizes[i]; + } + + return out; } BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const { @@ -156,7 +181,7 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema, const SearchParams::FieldReturnList& fields) const { SearchDocData out{}; for (const auto& [ident, name] : fields) - out[name] = GetString(ident); + out[name] = GetPath(ident)->evaluate(json_).to_string(); return out; } diff --git a/src/server/search/doc_accessors.h b/src/server/search/doc_accessors.h index aed91d489..37f8bf682 100644 --- a/src/server/search/doc_accessors.h +++ b/src/server/search/doc_accessors.h @@ -39,7 +39,7 @@ struct ListPackAccessor : public BaseAccessor { explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} { } - std::string_view GetString(std::string_view field) const override; + StringList GetStrings(std::string_view field) const override; VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; @@ -53,7 +53,7 @@ struct StringMapAccessor : public BaseAccessor { explicit StringMapAccessor(StringMap* hset) : hset_{hset} { } - std::string_view GetString(std::string_view field) const override; + StringList GetStrings(std::string_view field) const override; VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; @@ -68,7 +68,7 @@ struct JsonAccessor : public BaseAccessor { explicit JsonAccessor(const JsonType* json) : json_{*json} { } - std::string_view GetString(std::string_view field) const override; + StringList GetStrings(std::string_view field) const override; VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index d857d2eea..b4f6a2afb 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -181,7 +181,7 @@ TEST_F(SearchFamilyTest, Json) { EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults); } -TEST_F(SearchFamilyTest, AttributesJsonPaths) { +TEST_F(SearchFamilyTest, JsonAttributesPaths) { Run({"json.set", "k1", ".", R"( {"nested": {"value": "no"}} )"}); Run({"json.set", "k2", ".", R"( {"nested": {"value": "yes"}} )"}); Run({"json.set", "k3", ".", R"( {"nested": {"value": "maybe"}} )"}); @@ -193,6 +193,65 @@ TEST_F(SearchFamilyTest, AttributesJsonPaths) { EXPECT_THAT(Run({"ft.search", "i1", "yes"}), AreDocIds("k2")); } +TEST_F(SearchFamilyTest, JsonArrayValues) { + string_view D1 = R"( +{ + "name": "Alex", + "plays" : [ + {"game": "Pacman", "score": 10}, + {"game": "Tetris", "score": 15} + ], + "areas": ["EU-west", "EU-central"] +} +)"; + string_view D2 = R"( +{ + "name": "Bob", + "plays" : [ + {"game": "Pacman", "score": 15}, + {"game": "Mario", "score": 7} + ], + "areas": "US-central" +} +)"; + string_view D3 = R"( +{ + "name": "Caren", + "plays" : [ + {"game": "Mario", "score": 9}, + {"game": "Doom", "score": 20} + ], + "areas": ["EU-central", "EU-east"] +} +)"; + + Run({"json.set", "k1", ".", D1}); + Run({"json.set", "k2", ".", D2}); + Run({"json.set", "k3", ".", D3}); + + auto resp = Run({"ft.create", "i1", "on", "json", "schema", "$.name", "text", "$.plays[*].game", + "as", "games", "tag", "$.plays[*].score", "as", "scores", "numeric", + "$.areas[*]", "as", "areas", "tag"}); + EXPECT_EQ(resp, "OK"); + + EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("k1", "k2", "k3")); + + // Find players by games + EXPECT_THAT(Run({"ft.search", "i1", "@games:{Tetris | Mario | Doom}"}), + AreDocIds("k1", "k2", "k3")); + EXPECT_THAT(Run({"ft.search", "i1", "@games:{Pacman}"}), AreDocIds("k1", "k2")); + EXPECT_THAT(Run({"ft.search", "i1", "@games:{Mario}"}), AreDocIds("k2", "k3")); + + // Find players by scores + EXPECT_THAT(Run({"ft.search", "i1", "@scores:[15 15]"}), AreDocIds("k1", "k2")); + EXPECT_THAT(Run({"ft.search", "i1", "@scores:[0 (10]"}), AreDocIds("k2", "k3")); + EXPECT_THAT(Run({"ft.search", "i1", "@scores:[(15 20]"}), AreDocIds("k3")); + + // Find platers by areas + EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"EU-central\"}"}), AreDocIds("k1", "k3")); + EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"US-central\"}"}), AreDocIds("k2")); +} + TEST_F(SearchFamilyTest, Tags) { Run({"hset", "d:1", "color", "red, green"}); Run({"hset", "d:2", "color", "green, blue"}); diff --git a/tests/dragonfly/search_test.py b/tests/dragonfly/search_test.py index 361c190f2..c0d9b1f10 100644 --- a/tests/dragonfly/search_test.py +++ b/tests/dragonfly/search_test.py @@ -193,6 +193,31 @@ async def test_basic(async_client: aioredis.Redis, index_type): await i1.dropindex() +@dfly_args({"proactor_threads": 4}) +async def test_big_json(async_client: aioredis.Redis): + i1 = async_client.ft("i1") + gen_arr = lambda base: {"blob": [base + str(i) for i in range(100)]} + + await async_client.json().set("k1", "$", gen_arr("alex")) + await async_client.json().set("k2", "$", gen_arr("bob")) + + await i1.create_index( + [TextField(name="$.blob", as_name="items")], + definition=IndexDefinition(index_type=IndexType.JSON), + ) + + res = await i1.search("alex55") + assert res.docs[0].id == "k1" + + res = await i1.search("bob77") + assert res.docs[0].id == "k2" + + res = await i1.search("alex11 | bob22") + assert res.total == 2 + + await i1.dropindex() + + async def knn_query(idx, query, vector): params = {"vec": np.array(vector, dtype=np.float32).tobytes()} result = await idx.search(query, params)