fix(search): Support indexing array paths (#2074)

* fix(search): Support indexing array paths

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>


---------

Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
Vladislav 2023-10-29 15:14:23 +03:00 committed by GitHub
parent 47d92fb010
commit 04cd2ff3f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 175 additions and 39 deletions

View file

@ -5,6 +5,7 @@
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/container/inlined_vector.h>
#include <cstdint>
#include <memory>
@ -57,9 +58,11 @@ using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;
// Interface for accessing document values with different data structures underneath.
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
using StringList = absl::InlinedVector<std::string_view, 1>;
virtual ~DocumentAccessor() = default;
virtual std::string_view GetString(std::string_view active_field) const = 0;
virtual StringList GetStrings(std::string_view active_field) const = 0;
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
};

View file

@ -7,6 +7,7 @@
#include <absl/container/flat_hash_set.h>
#include <absl/strings/ascii.h>
#include <absl/strings/numbers.h>
#include <absl/strings/str_join.h>
#include <absl/strings/str_split.h>
#define UNI_ALGO_DISABLE_NFKC_NFKD
@ -59,15 +60,19 @@ NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
}
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
double num;
if (absl::SimpleAtod(doc->GetString(field), &num))
entries_.emplace(num, id);
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.emplace(num, id);
}
}
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
int64_t num;
if (absl::SimpleAtoi(doc->GetString(field), &num))
entries_.erase({num, id});
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.erase({num, id});
}
}
vector<DocId> NumericIndex::Range(double l, double r) const {
@ -79,6 +84,7 @@ vector<DocId> NumericIndex::Range(double l, double r) const {
out.push_back(it->second);
sort(out.begin(), out.end());
out.erase(unique(out.begin(), out.end()), out.end());
return out;
}
@ -104,17 +110,27 @@ CompressedSortedSet* BaseStringIndex::GetOrCreate(string_view word) {
}
void BaseStringIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
for (const auto& word : Tokenize(doc->GetString(field)))
GetOrCreate(word)->Insert(id);
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
tokens.merge(Tokenize(str));
for (string_view token : tokens)
GetOrCreate(token)->Insert(id);
}
void BaseStringIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
for (const auto& word : Tokenize(doc->GetString(field))) {
if (auto it = entries_.find(word); it != entries_.end()) {
it->second.Remove(id);
if (it->second.Size() == 0)
entries_.erase(it);
}
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
tokens.merge(Tokenize(str));
for (const auto& token : tokens) {
auto it = entries_.find(token);
if (it == entries_.end())
continue;
it->second.Remove(id);
if (it->second.Size() == 0)
entries_.erase(it);
}
}

View file

@ -39,13 +39,13 @@ struct MockedDocument : public DocumentAccessor {
MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
}
string_view GetString(string_view field) const override {
StringList GetStrings(string_view field) const override {
auto it = fields_.find(field);
return it != fields_.end() ? string_view{it->second} : "";
return {it != fields_.end() ? string_view{it->second} : ""};
}
VectorInfo GetVector(string_view field) const override {
return BytesToFtVector(GetString(field));
return BytesToFtVector(GetStrings(field).front());
}
string DebugFormat() {

View file

@ -56,14 +56,22 @@ template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;
double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
auto str = doc->GetStrings(field);
if (str.empty())
return 0;
double v;
if (!absl::SimpleAtod(doc->GetString(field), &v))
if (!absl::SimpleAtod(str.front(), &v))
return 0;
return v;
}
PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
return PMR_NS::string{doc->GetString(field), GetMemRes()};
auto str = doc->GetStrings(field);
if (str.empty())
return "";
return PMR_NS::string{str.front(), GetMemRes()};
}
} // namespace dfly::search

View file

@ -55,17 +55,19 @@ SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
for (const auto& [fident, fname] : fields) {
auto it = schema.fields.find(fident);
auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT;
out[fname] = PrintField(type, GetString(fident));
out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ","));
}
return out;
}
string_view ListPackAccessor::GetString(string_view active_field) const {
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
return strsv.has_value() ? StringList{*strsv} : StringList{};
}
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}
SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
@ -86,13 +88,14 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
return out;
}
string_view StringMapAccessor::GetString(string_view active_field) const {
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const {
auto it = hset_->Find(active_field);
return it != hset_->end() ? it->second : ""sv;
return it != hset_->end() ? StringList{it->second} : StringList{};
}
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
return search::BytesToFtVector(GetString(active_field));
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
@ -106,13 +109,35 @@ SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
struct JsonAccessor::JsonPathContainer : public jsoncons::jsonpath::jsonpath_expression<JsonType> {
};
string_view JsonAccessor::GetString(string_view active_field) const {
auto res = GetPath(active_field)->evaluate(json_);
DCHECK(res.is_array());
if (res.empty())
return "";
buf_ = res[0].as_string();
return buf_;
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const {
auto path_res = GetPath(active_field)->evaluate(json_);
DCHECK(path_res.is_array()); // json path always returns arrays
if (path_res.empty())
return {};
if (path_res.size() == 1) {
buf_ = path_res[0].as_string();
return {buf_};
}
// First, grow buffer and compute string sizes
vector<size_t> sizes;
for (auto element : path_res.array_range()) {
size_t start = buf_.size();
buf_ += element.as_string();
sizes.push_back(buf_.size() - start);
}
// Reposition start pointers to the most recent allocation of buf
StringList out(sizes.size());
size_t start = 0;
for (size_t i = 0; i < out.size(); i++) {
out[i] = string_view{buf_}.substr(start, sizes[i]);
start += sizes[i];
}
return out;
}
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
@ -156,7 +181,7 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema,
const SearchParams::FieldReturnList& fields) const {
SearchDocData out{};
for (const auto& [ident, name] : fields)
out[name] = GetString(ident);
out[name] = GetPath(ident)->evaluate(json_).to_string();
return out;
}

View file

@ -39,7 +39,7 @@ struct ListPackAccessor : public BaseAccessor {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
}
std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
@ -53,7 +53,7 @@ struct StringMapAccessor : public BaseAccessor {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
}
std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
@ -68,7 +68,7 @@ struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(const JsonType* json) : json_{*json} {
}
std::string_view GetString(std::string_view field) const override;
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;

View file

@ -181,7 +181,7 @@ TEST_F(SearchFamilyTest, Json) {
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
}
TEST_F(SearchFamilyTest, AttributesJsonPaths) {
TEST_F(SearchFamilyTest, JsonAttributesPaths) {
Run({"json.set", "k1", ".", R"( {"nested": {"value": "no"}} )"});
Run({"json.set", "k2", ".", R"( {"nested": {"value": "yes"}} )"});
Run({"json.set", "k3", ".", R"( {"nested": {"value": "maybe"}} )"});
@ -193,6 +193,65 @@ TEST_F(SearchFamilyTest, AttributesJsonPaths) {
EXPECT_THAT(Run({"ft.search", "i1", "yes"}), AreDocIds("k2"));
}
TEST_F(SearchFamilyTest, JsonArrayValues) {
string_view D1 = R"(
{
"name": "Alex",
"plays" : [
{"game": "Pacman", "score": 10},
{"game": "Tetris", "score": 15}
],
"areas": ["EU-west", "EU-central"]
}
)";
string_view D2 = R"(
{
"name": "Bob",
"plays" : [
{"game": "Pacman", "score": 15},
{"game": "Mario", "score": 7}
],
"areas": "US-central"
}
)";
string_view D3 = R"(
{
"name": "Caren",
"plays" : [
{"game": "Mario", "score": 9},
{"game": "Doom", "score": 20}
],
"areas": ["EU-central", "EU-east"]
}
)";
Run({"json.set", "k1", ".", D1});
Run({"json.set", "k2", ".", D2});
Run({"json.set", "k3", ".", D3});
auto resp = Run({"ft.create", "i1", "on", "json", "schema", "$.name", "text", "$.plays[*].game",
"as", "games", "tag", "$.plays[*].score", "as", "scores", "numeric",
"$.areas[*]", "as", "areas", "tag"});
EXPECT_EQ(resp, "OK");
EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("k1", "k2", "k3"));
// Find players by games
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Tetris | Mario | Doom}"}),
AreDocIds("k1", "k2", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Pacman}"}), AreDocIds("k1", "k2"));
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Mario}"}), AreDocIds("k2", "k3"));
// Find players by scores
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[15 15]"}), AreDocIds("k1", "k2"));
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[0 (10]"}), AreDocIds("k2", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[(15 20]"}), AreDocIds("k3"));
// Find players by areas
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"EU-central\"}"}), AreDocIds("k1", "k3"));
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"US-central\"}"}), AreDocIds("k2"));
}
TEST_F(SearchFamilyTest, Tags) {
Run({"hset", "d:1", "color", "red, green"});
Run({"hset", "d:2", "color", "green, blue"});

View file

@ -193,6 +193,31 @@ async def test_basic(async_client: aioredis.Redis, index_type):
await i1.dropindex()
@dfly_args({"proactor_threads": 4})
async def test_big_json(async_client: aioredis.Redis):
i1 = async_client.ft("i1")
gen_arr = lambda base: {"blob": [base + str(i) for i in range(100)]}
await async_client.json().set("k1", "$", gen_arr("alex"))
await async_client.json().set("k2", "$", gen_arr("bob"))
await i1.create_index(
[TextField(name="$.blob", as_name="items")],
definition=IndexDefinition(index_type=IndexType.JSON),
)
res = await i1.search("alex55")
assert res.docs[0].id == "k1"
res = await i1.search("bob77")
assert res.docs[0].id == "k2"
res = await i1.search("alex11 | bob22")
assert res.total == 2
await i1.dropindex()
async def knn_query(idx, query, vector):
params = {"vec": np.array(vector, dtype=np.float32).tobytes()}
result = await idx.search(query, params)