mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-10 18:05:44 +02:00
fix(search): Support indexing array paths (#2074)
* fix(search): Support indexing array paths Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io> --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
47d92fb010
commit
04cd2ff3f9
8 changed files with 175 additions and 39 deletions
|
@ -5,6 +5,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <absl/container/flat_hash_map.h>
|
||||
#include <absl/container/inlined_vector.h>
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
@ -57,9 +58,11 @@ using ResultScore = std::variant<std::monostate, float, double, WrappedStrPtr>;
|
|||
// Interface for accessing document values with different data structures underneath.
|
||||
struct DocumentAccessor {
|
||||
using VectorInfo = search::OwnedFtVector;
|
||||
using StringList = absl::InlinedVector<std::string_view, 1>;
|
||||
|
||||
virtual ~DocumentAccessor() = default;
|
||||
virtual std::string_view GetString(std::string_view active_field) const = 0;
|
||||
|
||||
virtual StringList GetStrings(std::string_view active_field) const = 0;
|
||||
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
|
||||
};
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include <absl/container/flat_hash_set.h>
|
||||
#include <absl/strings/ascii.h>
|
||||
#include <absl/strings/numbers.h>
|
||||
#include <absl/strings/str_join.h>
|
||||
#include <absl/strings/str_split.h>
|
||||
|
||||
#define UNI_ALGO_DISABLE_NFKC_NFKD
|
||||
|
@ -59,15 +60,19 @@ NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
|
|||
}
|
||||
|
||||
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
double num;
|
||||
if (absl::SimpleAtod(doc->GetString(field), &num))
|
||||
entries_.emplace(num, id);
|
||||
for (auto str : doc->GetStrings(field)) {
|
||||
double num;
|
||||
if (absl::SimpleAtod(str, &num))
|
||||
entries_.emplace(num, id);
|
||||
}
|
||||
}
|
||||
|
||||
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
int64_t num;
|
||||
if (absl::SimpleAtoi(doc->GetString(field), &num))
|
||||
entries_.erase({num, id});
|
||||
for (auto str : doc->GetStrings(field)) {
|
||||
double num;
|
||||
if (absl::SimpleAtod(str, &num))
|
||||
entries_.erase({num, id});
|
||||
}
|
||||
}
|
||||
|
||||
vector<DocId> NumericIndex::Range(double l, double r) const {
|
||||
|
@ -79,6 +84,7 @@ vector<DocId> NumericIndex::Range(double l, double r) const {
|
|||
out.push_back(it->second);
|
||||
|
||||
sort(out.begin(), out.end());
|
||||
out.erase(unique(out.begin(), out.end()), out.end());
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -104,17 +110,27 @@ CompressedSortedSet* BaseStringIndex::GetOrCreate(string_view word) {
|
|||
}
|
||||
|
||||
void BaseStringIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
for (const auto& word : Tokenize(doc->GetString(field)))
|
||||
GetOrCreate(word)->Insert(id);
|
||||
absl::flat_hash_set<std::string> tokens;
|
||||
for (string_view str : doc->GetStrings(field))
|
||||
tokens.merge(Tokenize(str));
|
||||
|
||||
for (string_view token : tokens)
|
||||
GetOrCreate(token)->Insert(id);
|
||||
}
|
||||
|
||||
void BaseStringIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
for (const auto& word : Tokenize(doc->GetString(field))) {
|
||||
if (auto it = entries_.find(word); it != entries_.end()) {
|
||||
it->second.Remove(id);
|
||||
if (it->second.Size() == 0)
|
||||
entries_.erase(it);
|
||||
}
|
||||
absl::flat_hash_set<std::string> tokens;
|
||||
for (string_view str : doc->GetStrings(field))
|
||||
tokens.merge(Tokenize(str));
|
||||
|
||||
for (const auto& token : tokens) {
|
||||
auto it = entries_.find(token);
|
||||
if (it == entries_.end())
|
||||
continue;
|
||||
|
||||
it->second.Remove(id);
|
||||
if (it->second.Size() == 0)
|
||||
entries_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -39,13 +39,13 @@ struct MockedDocument : public DocumentAccessor {
|
|||
MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
|
||||
}
|
||||
|
||||
string_view GetString(string_view field) const override {
|
||||
StringList GetStrings(string_view field) const override {
|
||||
auto it = fields_.find(field);
|
||||
return it != fields_.end() ? string_view{it->second} : "";
|
||||
return {it != fields_.end() ? string_view{it->second} : ""};
|
||||
}
|
||||
|
||||
VectorInfo GetVector(string_view field) const override {
|
||||
return BytesToFtVector(GetString(field));
|
||||
return BytesToFtVector(GetStrings(field).front());
|
||||
}
|
||||
|
||||
string DebugFormat() {
|
||||
|
|
|
@ -56,14 +56,22 @@ template struct SimpleValueSortIndex<double>;
|
|||
template struct SimpleValueSortIndex<PMR_NS::string>;
|
||||
|
||||
double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
|
||||
auto str = doc->GetStrings(field);
|
||||
if (str.empty())
|
||||
return 0;
|
||||
|
||||
double v;
|
||||
if (!absl::SimpleAtod(doc->GetString(field), &v))
|
||||
if (!absl::SimpleAtod(str.front(), &v))
|
||||
return 0;
|
||||
return v;
|
||||
}
|
||||
|
||||
PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
|
||||
return PMR_NS::string{doc->GetString(field), GetMemRes()};
|
||||
auto str = doc->GetStrings(field);
|
||||
if (str.empty())
|
||||
return "";
|
||||
|
||||
return PMR_NS::string{str.front(), GetMemRes()};
|
||||
}
|
||||
|
||||
} // namespace dfly::search
|
||||
|
|
|
@ -55,17 +55,19 @@ SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
|
|||
for (const auto& [fident, fname] : fields) {
|
||||
auto it = schema.fields.find(fident);
|
||||
auto type = it != schema.fields.end() ? it->second.type : search::SchemaField::TEXT;
|
||||
out[fname] = PrintField(type, GetString(fident));
|
||||
out[fname] = PrintField(type, absl::StrJoin(GetStrings(fident), ","));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
string_view ListPackAccessor::GetString(string_view active_field) const {
|
||||
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
|
||||
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const {
|
||||
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
|
||||
return strsv.has_value() ? StringList{*strsv} : StringList{};
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
|
||||
return search::BytesToFtVector(GetString(active_field));
|
||||
auto strlist = GetStrings(active_field);
|
||||
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
|
||||
}
|
||||
|
||||
SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
|
||||
|
@ -86,13 +88,14 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
|
|||
return out;
|
||||
}
|
||||
|
||||
string_view StringMapAccessor::GetString(string_view active_field) const {
|
||||
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const {
|
||||
auto it = hset_->Find(active_field);
|
||||
return it != hset_->end() ? it->second : ""sv;
|
||||
return it != hset_->end() ? StringList{it->second} : StringList{};
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
|
||||
return search::BytesToFtVector(GetString(active_field));
|
||||
auto strlist = GetStrings(active_field);
|
||||
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
|
||||
}
|
||||
|
||||
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
|
||||
|
@ -106,13 +109,35 @@ SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
|
|||
struct JsonAccessor::JsonPathContainer : public jsoncons::jsonpath::jsonpath_expression<JsonType> {
|
||||
};
|
||||
|
||||
string_view JsonAccessor::GetString(string_view active_field) const {
|
||||
auto res = GetPath(active_field)->evaluate(json_);
|
||||
DCHECK(res.is_array());
|
||||
if (res.empty())
|
||||
return "";
|
||||
buf_ = res[0].as_string();
|
||||
return buf_;
|
||||
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const {
|
||||
auto path_res = GetPath(active_field)->evaluate(json_);
|
||||
DCHECK(path_res.is_array()); // json path always returns arrays
|
||||
|
||||
if (path_res.empty())
|
||||
return {};
|
||||
|
||||
if (path_res.size() == 1) {
|
||||
buf_ = path_res[0].as_string();
|
||||
return {buf_};
|
||||
}
|
||||
|
||||
// First, grow buffer and compute string sizes
|
||||
vector<size_t> sizes;
|
||||
for (auto element : path_res.array_range()) {
|
||||
size_t start = buf_.size();
|
||||
buf_ += element.as_string();
|
||||
sizes.push_back(buf_.size() - start);
|
||||
}
|
||||
|
||||
// Reposition start pointers to the most recent allocation of buf
|
||||
StringList out(sizes.size());
|
||||
size_t start = 0;
|
||||
for (size_t i = 0; i < out.size(); i++) {
|
||||
out[i] = string_view{buf_}.substr(start, sizes[i]);
|
||||
start += sizes[i];
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
|
||||
|
@ -156,7 +181,7 @@ SearchDocData JsonAccessor::Serialize(const search::Schema& schema,
|
|||
const SearchParams::FieldReturnList& fields) const {
|
||||
SearchDocData out{};
|
||||
for (const auto& [ident, name] : fields)
|
||||
out[name] = GetString(ident);
|
||||
out[name] = GetPath(ident)->evaluate(json_).to_string();
|
||||
return out;
|
||||
}
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ struct ListPackAccessor : public BaseAccessor {
|
|||
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
|
||||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
|
@ -53,7 +53,7 @@ struct StringMapAccessor : public BaseAccessor {
|
|||
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
|
||||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
|
@ -68,7 +68,7 @@ struct JsonAccessor : public BaseAccessor {
|
|||
explicit JsonAccessor(const JsonType* json) : json_{*json} {
|
||||
}
|
||||
|
||||
std::string_view GetString(std::string_view field) const override;
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
|
|
|
@ -181,7 +181,7 @@ TEST_F(SearchFamilyTest, Json) {
|
|||
EXPECT_THAT(Run({"ft.search", "i1", "@a:small @b:secret"}), kNoResults);
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, AttributesJsonPaths) {
|
||||
TEST_F(SearchFamilyTest, JsonAttributesPaths) {
|
||||
Run({"json.set", "k1", ".", R"( {"nested": {"value": "no"}} )"});
|
||||
Run({"json.set", "k2", ".", R"( {"nested": {"value": "yes"}} )"});
|
||||
Run({"json.set", "k3", ".", R"( {"nested": {"value": "maybe"}} )"});
|
||||
|
@ -193,6 +193,65 @@ TEST_F(SearchFamilyTest, AttributesJsonPaths) {
|
|||
EXPECT_THAT(Run({"ft.search", "i1", "yes"}), AreDocIds("k2"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, JsonArrayValues) {
|
||||
string_view D1 = R"(
|
||||
{
|
||||
"name": "Alex",
|
||||
"plays" : [
|
||||
{"game": "Pacman", "score": 10},
|
||||
{"game": "Tetris", "score": 15}
|
||||
],
|
||||
"areas": ["EU-west", "EU-central"]
|
||||
}
|
||||
)";
|
||||
string_view D2 = R"(
|
||||
{
|
||||
"name": "Bob",
|
||||
"plays" : [
|
||||
{"game": "Pacman", "score": 15},
|
||||
{"game": "Mario", "score": 7}
|
||||
],
|
||||
"areas": "US-central"
|
||||
}
|
||||
)";
|
||||
string_view D3 = R"(
|
||||
{
|
||||
"name": "Caren",
|
||||
"plays" : [
|
||||
{"game": "Mario", "score": 9},
|
||||
{"game": "Doom", "score": 20}
|
||||
],
|
||||
"areas": ["EU-central", "EU-east"]
|
||||
}
|
||||
)";
|
||||
|
||||
Run({"json.set", "k1", ".", D1});
|
||||
Run({"json.set", "k2", ".", D2});
|
||||
Run({"json.set", "k3", ".", D3});
|
||||
|
||||
auto resp = Run({"ft.create", "i1", "on", "json", "schema", "$.name", "text", "$.plays[*].game",
|
||||
"as", "games", "tag", "$.plays[*].score", "as", "scores", "numeric",
|
||||
"$.areas[*]", "as", "areas", "tag"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "*"}), AreDocIds("k1", "k2", "k3"));
|
||||
|
||||
// Find players by games
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Tetris | Mario | Doom}"}),
|
||||
AreDocIds("k1", "k2", "k3"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Pacman}"}), AreDocIds("k1", "k2"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@games:{Mario}"}), AreDocIds("k2", "k3"));
|
||||
|
||||
// Find players by scores
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[15 15]"}), AreDocIds("k1", "k2"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[0 (10]"}), AreDocIds("k2", "k3"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@scores:[(15 20]"}), AreDocIds("k3"));
|
||||
|
||||
// Find players by areas
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"EU-central\"}"}), AreDocIds("k1", "k3"));
|
||||
EXPECT_THAT(Run({"ft.search", "i1", "@areas:{\"US-central\"}"}), AreDocIds("k2"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, Tags) {
|
||||
Run({"hset", "d:1", "color", "red, green"});
|
||||
Run({"hset", "d:2", "color", "green, blue"});
|
||||
|
|
|
@ -193,6 +193,31 @@ async def test_basic(async_client: aioredis.Redis, index_type):
|
|||
await i1.dropindex()
|
||||
|
||||
|
||||
@dfly_args({"proactor_threads": 4})
|
||||
async def test_big_json(async_client: aioredis.Redis):
|
||||
i1 = async_client.ft("i1")
|
||||
gen_arr = lambda base: {"blob": [base + str(i) for i in range(100)]}
|
||||
|
||||
await async_client.json().set("k1", "$", gen_arr("alex"))
|
||||
await async_client.json().set("k2", "$", gen_arr("bob"))
|
||||
|
||||
await i1.create_index(
|
||||
[TextField(name="$.blob", as_name="items")],
|
||||
definition=IndexDefinition(index_type=IndexType.JSON),
|
||||
)
|
||||
|
||||
res = await i1.search("alex55")
|
||||
assert res.docs[0].id == "k1"
|
||||
|
||||
res = await i1.search("bob77")
|
||||
assert res.docs[0].id == "k2"
|
||||
|
||||
res = await i1.search("alex11 | bob22")
|
||||
assert res.total == 2
|
||||
|
||||
await i1.dropindex()
|
||||
|
||||
|
||||
async def knn_query(idx, query, vector):
|
||||
params = {"vec": np.array(vector, dtype=np.float32).tobytes()}
|
||||
result = await idx.search(query, params)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue