diff --git a/src/core/search/base.h b/src/core/search/base.h index 9ff30472f..d81008f0a 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -80,6 +80,9 @@ struct DocumentAccessor { /* Return nullopt if the specified field is not a list of doubles */ virtual std::optional GetNumbers(std::string_view active_field) const = 0; + + /* Same as GetStrings, but also supports boolean values */ + virtual std::optional GetTags(std::string_view active_field) const = 0; }; // Base class for type-specific indices. diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index 9d125a5a8..82805c11f 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -143,7 +143,7 @@ typename BaseStringIndex::Container* BaseStringIndex::GetOrCreate(string_v template bool BaseStringIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) { - auto strings_list = doc.GetStrings(field); + auto strings_list = GetStrings(doc, field); if (!strings_list) { return false; } @@ -159,7 +159,7 @@ bool BaseStringIndex::Add(DocId id, const DocumentAccessor& doc, string_view template void BaseStringIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) { - auto strings_list = doc.GetStrings(field).value(); + auto strings_list = GetStrings(doc, field).value(); absl::flat_hash_set tokens; for (string_view str : strings_list) @@ -188,10 +188,20 @@ template vector BaseStringIndex::GetTerms() const { template struct BaseStringIndex; template struct BaseStringIndex; +std::optional TextIndex::GetStrings(const DocumentAccessor& doc, + std::string_view field) const { + return doc.GetStrings(field); +} + absl::flat_hash_set TextIndex::Tokenize(std::string_view value) const { return TokenizeWords(value, *stopwords_); } +std::optional TagIndex::GetStrings(const DocumentAccessor& doc, + std::string_view field) const { + return doc.GetTags(field); +} + absl::flat_hash_set TagIndex::Tokenize(std::string_view value) const { return 
NormalizeTags(value, case_sensitive_, separator_); } diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 0058e2043..7cd22ccc4 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -47,9 +47,6 @@ template struct BaseStringIndex : public BaseIndex { bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; - // Used by Add & Remove to tokenize text value - virtual absl::flat_hash_set Tokenize(std::string_view value) const = 0; - // Pointer is valid as long as index is not mutated. Nullptr if not found const Container* Matching(std::string_view str) const; @@ -60,6 +57,15 @@ template struct BaseStringIndex : public BaseIndex { std::vector GetTerms() const; protected: + using StringList = DocumentAccessor::StringList; + + // Used by Add & Remove to get strings from document + virtual std::optional GetStrings(const DocumentAccessor& doc, + std::string_view field) const = 0; + + // Used by Add & Remove to tokenize text value + virtual absl::flat_hash_set Tokenize(std::string_view value) const = 0; + Container* GetOrCreate(std::string_view word); bool case_sensitive_ = false; @@ -75,6 +81,9 @@ struct TextIndex : public BaseStringIndex { : BaseStringIndex(mr, false), stopwords_{stopwords} { } + protected: + std::optional GetStrings(const DocumentAccessor& doc, + std::string_view field) const override; absl::flat_hash_set Tokenize(std::string_view value) const override; private: @@ -88,6 +97,9 @@ struct TagIndex : public BaseStringIndex { : BaseStringIndex(mr, params.case_sensitive), separator_{params.separator} { } + protected: + std::optional GetStrings(const DocumentAccessor& doc, + std::string_view field) const override; absl::flat_hash_set Tokenize(std::string_view value) const override; private: diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc index 37752ebdb..886546013 100644 --- 
a/src/core/search/search_test.cc +++ b/src/core/search/search_test.cc @@ -52,6 +52,10 @@ struct MockedDocument : public DocumentAccessor { return StringList{string_view{it->second}}; } + std::optional GetTags(string_view field) const override { + return GetStrings(field); + } + std::optional GetVector(string_view field) const override { auto strings_list = GetStrings(field); if (!strings_list) diff --git a/src/core/search/sort_indices.cc b/src/core/search/sort_indices.cc index 2eb2c4aa3..3378754e8 100644 --- a/src/core/search/sort_indices.cc +++ b/src/core/search/sort_indices.cc @@ -10,7 +10,9 @@ #include #include +#include #include +#include namespace dfly::search { @@ -18,11 +20,23 @@ using namespace std; namespace {} // namespace +template bool SimpleValueSortIndex::ParsedSortValue::HasValue() const { + return !std::holds_alternative(value); +} + +template bool SimpleValueSortIndex::ParsedSortValue::IsNullValue() const { + return std::holds_alternative(value); +} + template SimpleValueSortIndex::SimpleValueSortIndex(PMR_NS::memory_resource* mr) : values_{mr} { } template SortableValue SimpleValueSortIndex::Lookup(DocId doc) const { + if (null_values_.contains(doc)) { + return std::monostate{}; + } + DCHECK_LT(doc, values_.size()); if constexpr (is_same_v) { return std::string(values_[doc]); @@ -48,21 +62,30 @@ std::vector SimpleValueSortIndex::Sort(std::vector* ids, template bool SimpleValueSortIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) { auto field_value = Get(doc, field); - if (!field_value) { + if (!field_value.HasValue()) { return false; } - DCHECK_LE(id, values_.size()); // Doc ids grow at most by one + if (field_value.IsNullValue()) { + null_values_.insert(id); + return true; + } + if (id >= values_.size()) values_.resize(id + 1); - values_[id] = field_value.value(); + values_[id] = std::move(std::get(field_value.value)); return true; } template void SimpleValueSortIndex::Remove(DocId id, const DocumentAccessor& doc, 
std::string_view field) { + if (auto it = null_values_.find(id); it != null_values_.end()) { + null_values_.erase(it); + return; + } + DCHECK_LT(id, values_.size()); values_[id] = T{}; } @@ -74,22 +97,28 @@ template PMR_NS::memory_resource* SimpleValueSortIndex::GetMemRe template struct SimpleValueSortIndex; template struct SimpleValueSortIndex; -std::optional NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) { +SimpleValueSortIndex::ParsedSortValue NumericSortIndex::Get(const DocumentAccessor& doc, + std::string_view field) { auto numbers_list = doc.GetNumbers(field); if (!numbers_list) { - return std::nullopt; + return {}; } - return !numbers_list->empty() ? numbers_list->front() : 0.0; + if (numbers_list->empty()) { + return ParsedSortValue{std::nullopt}; + } + return ParsedSortValue{numbers_list->front()}; } -std::optional StringSortIndex::Get(const DocumentAccessor& doc, - std::string_view field) { - auto strings_list = doc.GetStrings(field); +SimpleValueSortIndex::ParsedSortValue StringSortIndex::Get( + const DocumentAccessor& doc, std::string_view field) { + auto strings_list = doc.GetTags(field); if (!strings_list) { - return std::nullopt; + return {}; } - return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()} - : PMR_NS::string{GetMemRes()}; + if (strings_list->empty()) { + return ParsedSortValue{std::nullopt}; + } + return ParsedSortValue{PMR_NS::string{strings_list->front(), GetMemRes()}}; } } // namespace dfly::search diff --git a/src/core/search/sort_indices.h b/src/core/search/sort_indices.h index bdffc1a0f..b347ea29e 100644 --- a/src/core/search/sort_indices.h +++ b/src/core/search/sort_indices.h @@ -18,7 +18,19 @@ namespace dfly::search { -template struct SimpleValueSortIndex : BaseSortIndex { +template struct SimpleValueSortIndex : public BaseSortIndex { + protected: + struct ParsedSortValue { + bool HasValue() const; + bool IsNullValue() const; + + // std::monostate - no value was found. 
+ // std::nullopt - found value is null. + // T - found value. + std::variant value; + }; + + public: SimpleValueSortIndex(PMR_NS::memory_resource* mr); SortableValue Lookup(DocId doc) const override; @@ -28,25 +40,26 @@ template struct SimpleValueSortIndex : BaseSortIndex { void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override; protected: - virtual std::optional Get(const DocumentAccessor& doc, std::string_view field_value) = 0; + virtual ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field_value) = 0; PMR_NS::memory_resource* GetMemRes() const; private: PMR_NS::vector values_; + absl::flat_hash_set null_values_; }; struct NumericSortIndex : public SimpleValueSortIndex { NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; - std::optional Get(const DocumentAccessor& doc, std::string_view field) override; + ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override; }; // TODO: Map tags to integers for fast sort struct StringSortIndex : public SimpleValueSortIndex { StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; - std::optional Get(const DocumentAccessor& doc, std::string_view field) override; + ParsedSortValue Get(const DocumentAccessor& doc, std::string_view field) override; }; } // namespace dfly::search diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index 1e2c01c83..d82f7b8bf 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -9,6 +9,7 @@ #include "server/search/doc_accessors.h" +#include #include #include @@ -72,10 +73,31 @@ FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, s FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key, const JsonType& json) { + if (json.is_null()) { + return std::monostate{}; + } auto json_as_string = json.to_string(); return ExtractSortableValue(schema, key, 
json_as_string); } +/* Returns true if json elements were successfully processed. */ +template +bool ProcessJsonElements(const std::vector& json_elements, Callback&& cb) { + auto process = [&cb](const auto& json_range) -> bool { + for (const auto& json : json_range) { + if (!json.is_null() && !cb(json)) { + return false; + } + } + return true; + }; + + if (!json_elements[0].is_array()) { + return process(json_elements); + } + return json_elements.size() == 1 && process(json_elements[0].array_range()); +} + } // namespace SearchDocData BaseAccessor::Serialize(const search::Schema& schema, @@ -127,6 +149,10 @@ std::optional BaseAccessor::GetNumbers( return nums_list; } +std::optional BaseAccessor::GetTags(std::string_view active_field) const { + return GetStrings(active_field); +} + std::optional ListPackAccessor::GetStrings( string_view active_field) const { auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data()); @@ -192,8 +218,17 @@ struct JsonAccessor::JsonPathContainer { variant> val; }; -std::optional JsonAccessor::GetStrings(string_view active_field) const { - auto* path = GetPath(active_field); +std::optional JsonAccessor::GetStrings(std::string_view field) const { + return GetStrings(field, false); +} + +std::optional JsonAccessor::GetTags(std::string_view active_field) const { + return GetStrings(active_field, true); +} + +std::optional JsonAccessor::GetStrings(std::string_view field, + bool accept_boolean_values) const { + auto* path = GetPath(field); if (!path) return search::EmptyAccessResult(); @@ -201,8 +236,18 @@ std::optional JsonAccessor::GetStrings(string_view act if (path_res.empty()) return search::EmptyAccessResult(); + auto is_convertible_to_string = [](bool accept_boolean_values) -> bool (*)(const JsonType& json) { + if (accept_boolean_values) { + return [](const JsonType& json) -> bool { return json.is_string() || json.is_bool(); }; + } else { + return [](const JsonType& json) -> bool { return json.is_string(); }; + } + 
}(accept_boolean_values); + if (path_res.size() == 1 && !path_res[0].is_array()) { - if (!path_res[0].is_string()) + if (path_res[0].is_null()) + return StringList{}; + if (!is_convertible_to_string(path_res[0])) return std::nullopt; buf_ = path_res[0].as_string(); @@ -213,33 +258,21 @@ std::optional JsonAccessor::GetStrings(string_view act // First, grow buffer and compute string sizes vector sizes; + sizes.reserve(path_res.size()); + + // Returns true if json element is convertible to string + auto add_json_element_to_buf = [&](const JsonType& json) -> bool { + if (!is_convertible_to_string(json)) + return false; - auto add_json_to_buf = [&](const JsonType& json) { size_t start = buf_.size(); buf_ += json.as_string(); sizes.push_back(buf_.size() - start); + return true; }; - if (!path_res[0].is_array()) { - sizes.reserve(path_res.size()); - for (const auto& element : path_res) { - if (!element.is_string()) - return std::nullopt; - - add_json_to_buf(element); - } - } else { - if (path_res.size() > 1) { - return std::nullopt; - } - - sizes.reserve(path_res[0].size()); - for (const auto& element : path_res[0].array_range()) { - if (!element.is_string()) - return std::nullopt; - - add_json_to_buf(element); - } + if (!ProcessJsonElements(path_res, std::move(add_json_element_to_buf))) { + return std::nullopt; } // Reposition start pointers to the most recent allocation of buf @@ -260,7 +293,7 @@ std::optional JsonAccessor::GetVector(string_view acti return VectorInfo{}; auto res = path->Evaluate(json_); - if (res.empty()) + if (res.empty() || res[0].is_null()) return VectorInfo{}; if (!res[0].is_array()) @@ -290,24 +323,18 @@ std::optional JsonAccessor::GetNumbers(string_view activ return search::EmptyAccessResult(); NumsList nums_list; - if (!path_res[0].is_array()) { - nums_list.reserve(path_res.size()); - for (const auto& element : path_res) { - if (!element.is_number()) - return std::nullopt; - nums_list.push_back(element.as()); - } - } else { - if 
(path_res.size() > 1) { - return std::nullopt; - } + nums_list.reserve(path_res.size()); - nums_list.reserve(path_res[0].size()); - for (const auto& element : path_res[0].array_range()) { - if (!element.is_number()) - return std::nullopt; - nums_list.push_back(element.as()); - } + // Returns true if json element is convertible to number + auto add_json_element = [&](const JsonType& json) -> bool { + if (!json.is_number()) + return false; + nums_list.push_back(json.as()); + return true; + }; + + if (!ProcessJsonElements(path_res, std::move(add_json_element))) { + return std::nullopt; } return nums_list; } diff --git a/src/server/search/doc_accessors.h b/src/server/search/doc_accessors.h index f4f0baca6..35f1723fc 100644 --- a/src/server/search/doc_accessors.h +++ b/src/server/search/doc_accessors.h @@ -40,8 +40,9 @@ struct BaseAccessor : public search::DocumentAccessor { virtual SearchDocData SerializeDocument(const search::Schema& schema) const; // Default implementation uses GetStrings - virtual std::optional GetVector(std::string_view active_field) const; - virtual std::optional GetNumbers(std::string_view active_field) const; + virtual std::optional GetVector(std::string_view active_field) const override; + virtual std::optional GetNumbers(std::string_view active_field) const override; + virtual std::optional GetTags(std::string_view active_field) const override; }; // Accessor for hashes stored with listpack @@ -81,6 +82,7 @@ struct JsonAccessor : public BaseAccessor { std::optional GetStrings(std::string_view field) const override; std::optional GetVector(std::string_view field) const override; std::optional GetNumbers(std::string_view active_field) const override; + std::optional GetTags(std::string_view active_field) const override; // The JsonAccessor works with structured types and not plain strings, so an overload is needed SearchDocData Serialize(const search::Schema& schema, @@ -91,6 +93,9 @@ struct JsonAccessor : public BaseAccessor { static void 
RemoveFieldFromCache(std::string_view field); private: + /* If accept_boolean_values is true, then json boolean values are converted to strings */ + std::optional GetStrings(std::string_view field, bool accept_boolean_values) const; + /// Parses `field` into a JSON path. Caches the results internally. JsonPathContainer* GetPath(std::string_view field) const; diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 6aefe955a..0e1aebbd9 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -1298,9 +1298,15 @@ TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) { Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"}); Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"}); Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"}); - Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"}); - Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"}); - Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"}); + Run({"JSON.SET", "j4", ".", + R"({"data":["first", "second", "third"],"name":"doc_with_strings"})"}); + Run({"JSON.SET", "j5", ".", R"({"name":"no_data"})"}); + Run({"JSON.SET", "j6", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"}); + Run({"JSON.SET", "j7", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"}); + Run({"JSON.SET", "j8", ".", R"({"data":null,"name":"doc_with_null"})"}); + Run({"JSON.SET", "j9", ".", R"({"data":[null, null, null],"name":"doc_with_nulls"})"}); + Run({"JSON.SET", "j10", ".", R"({"data":true,"name":"doc_with_boolean"})"}); + Run({"JSON.SET", "j11", ".", R"({"data":[true, false, true],"name":"doc_with_booleans"})"}); auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"}); EXPECT_EQ(resp, "OK"); @@ -1328,25 +1334,25 @@ TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) { EXPECT_EQ(resp, 
"OK"); resp = Run({"FT.SEARCH", "i1", "*"}); - EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5")); + EXPECT_THAT(resp, AreDocIds("j1", "j5", "j6", "j8", "j9")); resp = Run({"FT.SEARCH", "i2", "*"}); - EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5")); + EXPECT_THAT(resp, AreDocIds("j1", "j5", "j6", "j8", "j9")); resp = Run({"FT.SEARCH", "i3", "*"}); - EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4")); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9", "j10", "j11")); resp = Run({"FT.SEARCH", "i4", "*"}); - EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4")); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9", "j10", "j11")); resp = Run({"FT.SEARCH", "i5", "*"}); - EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6")); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9")); resp = Run({"FT.SEARCH", "i6", "*"}); - EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6")); + EXPECT_THAT(resp, AreDocIds("j2", "j3", "j4", "j5", "j7", "j8", "j9")); resp = Run({"FT.SEARCH", "i7", "*"}); - EXPECT_THAT(resp, AreDocIds("j4", "j5")); + EXPECT_THAT(resp, AreDocIds("j5", "j6", "j8")); } TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) { @@ -1417,6 +1423,12 @@ TEST_F(SearchFamilyTest, WrongVectorFieldType) { Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"}); Run({"JSON.SET", "j7", ".", R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"}); + Run({"JSON.SET", "j8", ".", R"({"vector_field":null, "name": "doc_with_null"})"}); + Run({"JSON.SET", "j9", ".", R"({"vector_field":[null, null, null], "name": "doc_with_nulls"})"}); + Run({"JSON.SET", "j10", ".", R"({"vector_field":true, "name": "doc_with_boolean"})"}); + Run({"JSON.SET", "j11", ".", + R"({"vector_field":[true, false, true], "name": "doc_with_booleans"})"}); + Run({"JSON.SET", "j12", ".", R"({"vector_field":1, "name": 
"doc_with_int"})"}); auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field", @@ -1424,7 +1436,7 @@ TEST_F(SearchFamilyTest, WrongVectorFieldType) { EXPECT_EQ(resp, "OK"); resp = Run({"FT.SEARCH", "index", "*"}); - EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4")); + EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4", "j8")); } #ifndef SANITIZERS