fix(search_family): Support boolean and nullable types in indexes (#4314)

* fix(search_family): Support boolean and nullable types in indexes

fixes dragonflydb#4107, dragonflydb#4129

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

* refactor: address comments

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>

---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-12-24 10:52:39 +04:00 committed by GitHub
parent 01f24da2b6
commit 3c7e31240f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 191 additions and 76 deletions

View file

@ -80,6 +80,9 @@ struct DocumentAccessor {
/* Return nullopt if the specified field is not a list of doubles */
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
/* Same as GetStrings, but also supports boolean values */
virtual std::optional<StringList> GetTags(std::string_view active_field) const = 0;
};
// Base class for type-specific indices.

View file

@ -143,7 +143,7 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
template <typename C>
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field);
auto strings_list = GetStrings(doc, field);
if (!strings_list) {
return false;
}
@ -159,7 +159,7 @@ bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view
template <typename C>
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field).value();
auto strings_list = GetStrings(doc, field).value();
absl::flat_hash_set<std::string> tokens;
for (string_view str : strings_list)
@ -188,10 +188,20 @@ template <typename C> vector<string> BaseStringIndex<C>::GetTerms() const {
template struct BaseStringIndex<CompressedSortedSet>;
template struct BaseStringIndex<SortedVector>;
std::optional<DocumentAccessor::StringList> TextIndex::GetStrings(const DocumentAccessor& doc,
std::string_view field) const {
return doc.GetStrings(field);
}
absl::flat_hash_set<std::string> TextIndex::Tokenize(std::string_view value) const {
return TokenizeWords(value, *stopwords_);
}
std::optional<DocumentAccessor::StringList> TagIndex::GetStrings(const DocumentAccessor& doc,
std::string_view field) const {
return doc.GetTags(field);
}
absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) const {
return NormalizeTags(value, case_sensitive_, separator_);
}

View file

@ -47,9 +47,6 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
// Used by Add & Remove to tokenize text value
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
// Pointer is valid as long as index is not mutated. Nullptr if not found
const Container* Matching(std::string_view str) const;
@ -60,6 +57,15 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
std::vector<std::string> GetTerms() const;
protected:
using StringList = DocumentAccessor::StringList;
// Used by Add & Remove to get strings from document
virtual std::optional<StringList> GetStrings(const DocumentAccessor& doc,
std::string_view field) const = 0;
// Used by Add & Remove to tokenize text value
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
Container* GetOrCreate(std::string_view word);
bool case_sensitive_ = false;
@ -75,6 +81,9 @@ struct TextIndex : public BaseStringIndex<CompressedSortedSet> {
: BaseStringIndex(mr, false), stopwords_{stopwords} {
}
protected:
std::optional<StringList> GetStrings(const DocumentAccessor& doc,
std::string_view field) const override;
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
private:
@ -88,6 +97,9 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
: BaseStringIndex(mr, params.case_sensitive), separator_{params.separator} {
}
protected:
std::optional<StringList> GetStrings(const DocumentAccessor& doc,
std::string_view field) const override;
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
private:

View file

@ -52,6 +52,10 @@ struct MockedDocument : public DocumentAccessor {
return StringList{string_view{it->second}};
}
std::optional<StringList> GetTags(string_view field) const override {
return GetStrings(field);
}
std::optional<VectorInfo> GetVector(string_view field) const override {
auto strings_list = GetStrings(field);
if (!strings_list)

View file

@ -10,7 +10,9 @@
#include <absl/strings/str_split.h>
#include <algorithm>
#include <optional>
#include <type_traits>
#include <variant>
namespace dfly::search {
@ -18,11 +20,23 @@ using namespace std;
namespace {} // namespace
template <typename T> bool SimpleValueSortIndex<T>::ParsedSortValue::HasValue() const {
return !std::holds_alternative<std::monostate>(value);
}
template <typename T> bool SimpleValueSortIndex<T>::ParsedSortValue::IsNullValue() const {
return std::holds_alternative<std::nullopt_t>(value);
}
template <typename T>
SimpleValueSortIndex<T>::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} {
}
template <typename T> SortableValue SimpleValueSortIndex<T>::Lookup(DocId doc) const {
if (null_values_.contains(doc)) {
return std::monostate{};
}
DCHECK_LT(doc, values_.size());
if constexpr (is_same_v<T, PMR_NS::string>) {
return std::string(values_[doc]);
@ -48,21 +62,30 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
template <typename T>
bool SimpleValueSortIndex<T>::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto field_value = Get(doc, field);
if (!field_value) {
if (!field_value.HasValue()) {
return false;
}
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
if (field_value.IsNullValue()) {
null_values_.insert(id);
return true;
}
if (id >= values_.size())
values_.resize(id + 1);
values_[id] = field_value.value();
values_[id] = std::move(std::get<T>(field_value.value));
return true;
}
template <typename T>
void SimpleValueSortIndex<T>::Remove(DocId id, const DocumentAccessor& doc,
std::string_view field) {
if (auto it = null_values_.find(id); it != null_values_.end()) {
null_values_.erase(it);
return;
}
DCHECK_LT(id, values_.size());
values_[id] = T{};
}
@ -74,22 +97,28 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;
std::optional<double> NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) {
SimpleValueSortIndex<double>::ParsedSortValue NumericSortIndex::Get(const DocumentAccessor& doc,
std::string_view field) {
auto numbers_list = doc.GetNumbers(field);
if (!numbers_list) {
return std::nullopt;
return {};
}
return !numbers_list->empty() ? numbers_list->front() : 0.0;
if (numbers_list->empty()) {
return ParsedSortValue{std::nullopt};
}
return ParsedSortValue{numbers_list->front()};
}
std::optional<PMR_NS::string> StringSortIndex::Get(const DocumentAccessor& doc,
std::string_view field) {
auto strings_list = doc.GetStrings(field);
SimpleValueSortIndex<PMR_NS::string>::ParsedSortValue StringSortIndex::Get(
const DocumentAccessor& doc, std::string_view field) {
auto strings_list = doc.GetTags(field);
if (!strings_list) {
return std::nullopt;
return {};
}
return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()}
: PMR_NS::string{GetMemRes()};
if (strings_list->empty()) {
return ParsedSortValue{std::nullopt};
}
return ParsedSortValue{PMR_NS::string{strings_list->front(), GetMemRes()}};
}
} // namespace dfly::search

View file

@ -18,7 +18,19 @@
namespace dfly::search {
template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
template <typename T> struct SimpleValueSortIndex : public BaseSortIndex {
protected:
struct ParsedSortValue {
bool HasValue() const;
bool IsNullValue() const;
// std::monostate - no value was found.
// std::nullopt - found value is null.
// T - found value.
std::variant<std::monostate, std::nullopt_t, T> value;
};
public:
SimpleValueSortIndex(PMR_NS::memory_resource* mr);
SortableValue Lookup(DocId doc) const override;
@ -28,25 +40,26 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
protected:
virtual std::optional<T> Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
virtual ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
PMR_NS::memory_resource* GetMemRes() const;
private:
PMR_NS::vector<T> values_;
absl::flat_hash_set<DocId> null_values_;
};
struct NumericSortIndex : public SimpleValueSortIndex<double> {
NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
std::optional<double> Get(const DocumentAccessor& doc, std::string_view field) override;
ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override;
};
// TODO: Map tags to integers for fast sort
struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> {
StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
std::optional<PMR_NS::string> Get(const DocumentAccessor& doc, std::string_view field) override;
ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override;
};
} // namespace dfly::search

View file

@ -9,6 +9,7 @@
#include "server/search/doc_accessors.h"
#include <absl/functional/any_invocable.h>
#include <absl/strings/str_cat.h>
#include <absl/strings/str_join.h>
@ -72,10 +73,31 @@ FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, s
FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
const JsonType& json) {
if (json.is_null()) {
return std::monostate{};
}
auto json_as_string = json.to_string();
return ExtractSortableValue(schema, key, json_as_string);
}
/* Returns true if json elements were successfully processed. */
template <typename Callback>
bool ProcessJsonElements(const std::vector<JsonType>& json_elements, Callback&& cb) {
auto process = [&cb](const auto& json_range) -> bool {
for (const auto& json : json_range) {
if (!json.is_null() && !cb(json)) {
return false;
}
}
return true;
};
if (!json_elements[0].is_array()) {
return process(json_elements);
}
return json_elements.size() == 1 && process(json_elements[0].array_range());
}
} // namespace
SearchDocData BaseAccessor::Serialize(const search::Schema& schema,
@ -127,6 +149,10 @@ std::optional<BaseAccessor::NumsList> BaseAccessor::GetNumbers(
return nums_list;
}
std::optional<BaseAccessor::StringList> BaseAccessor::GetTags(std::string_view active_field) const {
return GetStrings(active_field);
}
std::optional<BaseAccessor::StringList> ListPackAccessor::GetStrings(
string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
@ -192,8 +218,17 @@ struct JsonAccessor::JsonPathContainer {
variant<json::Path, jsoncons::jsonpath::jsonpath_expression<JsonType>> val;
};
std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view active_field) const {
auto* path = GetPath(active_field);
std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(std::string_view field) const {
return GetStrings(field, false);
}
std::optional<BaseAccessor::StringList> JsonAccessor::GetTags(std::string_view active_field) const {
return GetStrings(active_field, true);
}
std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(std::string_view field,
bool accept_boolean_values) const {
auto* path = GetPath(field);
if (!path)
return search::EmptyAccessResult<StringList>();
@ -201,8 +236,18 @@ std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view act
if (path_res.empty())
return search::EmptyAccessResult<StringList>();
auto is_convertible_to_string = [](bool accept_boolean_values) -> bool (*)(const JsonType& json) {
if (accept_boolean_values) {
return [](const JsonType& json) -> bool { return json.is_string() || json.is_bool(); };
} else {
return [](const JsonType& json) -> bool { return json.is_string(); };
}
}(accept_boolean_values);
if (path_res.size() == 1 && !path_res[0].is_array()) {
if (!path_res[0].is_string())
if (path_res[0].is_null())
return StringList{};
if (!is_convertible_to_string(path_res[0]))
return std::nullopt;
buf_ = path_res[0].as_string();
@ -213,33 +258,21 @@ std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view act
// First, grow buffer and compute string sizes
vector<size_t> sizes;
sizes.reserve(path_res.size());
// Returns true if json element is convertiable to string
auto add_json_element_to_buf = [&](const JsonType& json) -> bool {
if (!is_convertible_to_string(json))
return false;
auto add_json_to_buf = [&](const JsonType& json) {
size_t start = buf_.size();
buf_ += json.as_string();
sizes.push_back(buf_.size() - start);
return true;
};
if (!path_res[0].is_array()) {
sizes.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
sizes.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
if (!ProcessJsonElements(path_res, std::move(add_json_element_to_buf))) {
return std::nullopt;
}
// Reposition start pointers to the most recent allocation of buf
@ -260,7 +293,7 @@ std::optional<BaseAccessor::VectorInfo> JsonAccessor::GetVector(string_view acti
return VectorInfo{};
auto res = path->Evaluate(json_);
if (res.empty())
if (res.empty() || res[0].is_null())
return VectorInfo{};
if (!res[0].is_array())
@ -290,24 +323,18 @@ std::optional<BaseAccessor::NumsList> JsonAccessor::GetNumbers(string_view activ
return search::EmptyAccessResult<NumsList>();
NumsList nums_list;
if (!path_res[0].is_array()) {
nums_list.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
nums_list.reserve(path_res.size());
nums_list.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
// Returns true if json element is convertiable to number
auto add_json_element = [&](const JsonType& json) -> bool {
if (!json.is_number())
return false;
nums_list.push_back(json.as<double>());
return true;
};
if (!ProcessJsonElements(path_res, std::move(add_json_element))) {
return std::nullopt;
}
return nums_list;
}

View file

@ -40,8 +40,9 @@ struct BaseAccessor : public search::DocumentAccessor {
virtual SearchDocData SerializeDocument(const search::Schema& schema) const;
// Default implementation uses GetStrings
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const;
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const;
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const override;
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const override;
virtual std::optional<StringList> GetTags(std::string_view active_field) const override;
};
// Accessor for hashes stored with listpack
@ -81,6 +82,7 @@ struct JsonAccessor : public BaseAccessor {
std::optional<StringList> GetStrings(std::string_view field) const override;
std::optional<VectorInfo> GetVector(std::string_view field) const override;
std::optional<NumsList> GetNumbers(std::string_view active_field) const override;
std::optional<StringList> GetTags(std::string_view active_field) const override;
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema,
@ -91,6 +93,9 @@ struct JsonAccessor : public BaseAccessor {
static void RemoveFieldFromCache(std::string_view field);
private:
/* If accept_boolean_values is true, then json boolean values are converted to strings */
std::optional<StringList> GetStrings(std::string_view field, bool accept_boolean_values) const;
/// Parses `field` into a JSON path. Caches the results internally.
JsonPathContainer* GetPath(std::string_view field) const;

View file

@ -1298,9 +1298,15 @@ TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) {
Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"});
Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"});
Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"});
Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"});
Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"});
Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"});
Run({"JSON.SET", "j4", ".",
R"({"data":["first", "second", "third"],"name":"doc_with_strings"})"});
Run({"JSON.SET", "j5", ".", R"({"name":"no_data"})"});
Run({"JSON.SET", "j6", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"});
Run({"JSON.SET", "j7", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"});
Run({"JSON.SET", "j8", ".", R"({"data":null,"name":"doc_with_null"})"});
Run({"JSON.SET", "j9", ".", R"({"data":[null, null, null],"name":"doc_with_nulls"})"});
Run({"JSON.SET", "j10", ".", R"({"data":true,"name":"doc_with_boolean"})"});
Run({"JSON.SET", "j11", ".", R"({"data":[true, false, true],"name":"doc_with_booleans"})"});
auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"});
EXPECT_EQ(resp, "OK");
@ -1328,25 +1334,25 @@ TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) {
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
EXPECT_THAT(resp, AreDocIds("j1", "j5", "j6", "j8", "j9"));
resp = Run({"FT.SEARCH", "i2", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
EXPECT_THAT(resp, AreDocIds("j1", "j5", "j6", "j8", "j9"));
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9", "j10", "j11"));
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9", "j10", "j11"));
resp = Run({"FT.SEARCH", "i5", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9"));
resp = Run({"FT.SEARCH", "i6", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9"));
resp = Run({"FT.SEARCH", "i7", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j5"));
EXPECT_THAT(resp, AreDocIds("j5", "j6", "j8"));
}
TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) {
@ -1417,6 +1423,12 @@ TEST_F(SearchFamilyTest, WrongVectorFieldType) {
Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"});
Run({"JSON.SET", "j7", ".",
R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"});
Run({"JSON.SET", "j8", ".", R"({"vector_field":null, "name": "doc_with_null"})"});
Run({"JSON.SET", "j9", ".", R"({"vector_field":[null, null, null], "name": "doc_with_nulls"})"});
Run({"JSON.SET", "j10", ".", R"({"vector_field":true, "name": "doc_with_boolean"})"});
Run({"JSON.SET", "j11", ".",
R"({"vector_field":[true, false, true], "name": "doc_with_booleans"})"});
Run({"JSON.SET", "j12", ".", R"({"vector_field":1, "name": "doc_with_int"})"});
auto resp =
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field",
@ -1424,7 +1436,7 @@ TEST_F(SearchFamilyTest, WrongVectorFieldType) {
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "index", "*"});
EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4"));
EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4", "j8"));
}
#ifndef SANITIZERS