From e71f083f34b31f1149e0d1fc6fbec13f9dfe68a5 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 10 Oct 2024 21:58:12 +0300 Subject: [PATCH] feat(search): STOPWORDS (#3851) Adds support for STOPWORDS option --- src/core/search/indices.cc | 11 +++--- src/core/search/indices.h | 8 ++++- src/core/search/search.cc | 19 +++++++--- src/core/search/search.h | 18 +++++++--- src/core/search/search_test.cc | 48 ++++++++++++++++++++----- src/server/search/doc_index.cc | 35 ++++++++++-------- src/server/search/doc_index.h | 9 ++--- src/server/search/search_family.cc | 8 +++++ src/server/search/search_family_test.cc | 16 +++++++++ tests/dragonfly/search_test.py | 6 ++++ 10 files changed, 138 insertions(+), 40 deletions(-) diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index a6d490e0c..01121d552 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -39,10 +39,13 @@ string ToLower(string_view word) { } // Get all words from text as matched by the ICU library -absl::flat_hash_set TokenizeWords(std::string_view text) { +absl::flat_hash_set TokenizeWords(std::string_view text, + const TextIndex::StopWords& stopwords) { absl::flat_hash_set words; - for (std::string_view word : una::views::word_only::utf8(text)) - words.insert(una::cases::to_lowercase_utf8(word)); + for (std::string_view word : una::views::word_only::utf8(text)) { + if (std::string word_lc = una::cases::to_lowercase_utf8(word); !stopwords.contains(word_lc)) + words.insert(std::move(word_lc)); + } return words; } @@ -166,7 +169,7 @@ template struct BaseStringIndex; template struct BaseStringIndex; absl::flat_hash_set TextIndex::Tokenize(std::string_view value) const { - return TokenizeWords(value); + return TokenizeWords(value, *stopwords_); } absl::flat_hash_set TagIndex::Tokenize(std::string_view value) const { diff --git a/src/core/search/indices.h b/src/core/search/indices.h index e8b070e61..84bedd8eb 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -87,10 +87,16 @@ template struct BaseStringIndex : public BaseIndex { // Index for text fields. // Hashmap based lookup per word. struct TextIndex : public BaseStringIndex { - TextIndex(PMR_NS::memory_resource* mr) : BaseStringIndex(mr, false) { + using StopWords = absl::flat_hash_set; + + TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords) + : BaseStringIndex(mr, false), stopwords_{stopwords} { } absl::flat_hash_set Tokenize(std::string_view value) const override; + + private: + const StopWords* stopwords_; }; // Index for text fields. diff --git a/src/core/search/search.cc b/src/core/search/search.cc index a6a68dd36..2ce14b97f 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -13,6 +13,7 @@ #include #include +#include "absl/container/flat_hash_set.h" #include "base/logging.h" #include "core/overloaded.h" #include "core/search/ast_expr.h" @@ -454,8 +455,18 @@ string_view Schema::LookupAlias(string_view alias) const { return alias; } -FieldIndices::FieldIndices(Schema schema, PMR_NS::memory_resource* mr) - : schema_{std::move(schema)}, all_ids_{}, indices_{} { +IndicesOptions::IndicesOptions() { + static absl::flat_hash_set kDefaultStopwords{ + "a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by", + "for", "if", "in", "into", "it", "no", "not", "of", "on", "or", "such", + "that", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"}; + + stopwords = kDefaultStopwords; +} + +FieldIndices::FieldIndices(const Schema& schema, const IndicesOptions& options, + PMR_NS::memory_resource* mr) + : schema_{schema}, options_{options} { CreateIndices(mr); CreateSortIndices(mr); } @@ -467,7 +478,7 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) { switch (field_info.type) { case SchemaField::TEXT: - indices_[field_ident] = make_unique(mr); + indices_[field_ident] = make_unique(mr, &options_.stopwords); break; case SchemaField::NUMERIC: indices_[field_ident] = make_unique(mr); @@ -546,7 +557,7 @@ BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const { std::vector FieldIndices::GetAllTextIndices() const { vector out; - for (auto& [field_name, field_info] : schema_.fields) { + for (const auto& [field_name, field_info] : schema_.fields) { if (field_info.type != SchemaField::TEXT || (field_info.flags & SchemaField::NOINDEX) > 0) continue; auto* index = dynamic_cast(GetIndex(field_name)); diff --git a/src/core/search/search.h b/src/core/search/search.h index 5e8b14c94..d52e60648 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -61,11 +62,20 @@ struct Schema { std::string_view LookupAlias(std::string_view alias) const; }; +struct IndicesOptions { + IndicesOptions(); + explicit IndicesOptions(absl::flat_hash_set stopwords) + : stopwords{std::move(stopwords)} { + } + + absl::flat_hash_set stopwords; +}; + // Collection of indices for all fields in schema class FieldIndices { public: - // Create indices based on schema - FieldIndices(Schema schema, PMR_NS::memory_resource* mr); + // Create indices based on schema and options. Both must outlive the indices + FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr); void Add(DocId doc, DocumentAccessor* access); void Remove(DocId doc, DocumentAccessor* access); @@ -84,8 +94,8 @@ class FieldIndices { void CreateIndices(PMR_NS::memory_resource* mr); void CreateSortIndices(PMR_NS::memory_resource* mr); - private: - Schema schema_; + const Schema& schema_; + const IndicesOptions& options_; std::vector all_ids_; absl::flat_hash_map> indices_; absl::flat_hash_map> sort_indices_; diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc index 6e8fe9472..a5bc5495a 100644 --- a/src/core/search/search_test.cc +++ b/src/core/search/search_test.cc @@ -66,6 +66,8 @@ struct MockedDocument : public DocumentAccessor { Map fields_{}; }; +IndicesOptions kEmptyOptions{{}}; + Schema MakeSimpleSchema(initializer_list> ilist) { Schema schema; for (auto [name, type] : ilist) { @@ -105,7 +107,7 @@ class SearchTest : public ::testing::Test { bool Check() { absl::Cleanup cl{[this] { entries_.clear(); }}; - FieldIndices index{schema_, PMR_NS::get_default_resource()}; + FieldIndices index{schema_, kEmptyOptions, PMR_NS::get_default_resource()}; shuffle(entries_.begin(), entries_.end(), default_random_engine{}); for (DocId i = 0; i < entries_.size(); i++) @@ -372,6 +374,36 @@ TEST_F(SearchTest, IntegerTerms) { EXPECT_TRUE(Check()) << GetError(); } +TEST_F(SearchTest, StopWords) { + auto schema = MakeSimpleSchema({{"title", SchemaField::TEXT}}); + IndicesOptions options{{"some", "words", "are", "left", "out"}}; + + FieldIndices indices{schema, options, PMR_NS::get_default_resource()}; + SearchAlgorithm algo{}; + QueryParams params; + + vector documents = {"some words left out", // + "some can be found", // + "words are never matched", // + "explicitly found!"}; + for (size_t i = 0; i < documents.size(); i++) { + MockedDocument doc{{{"title", documents[i]}}}; + indices.Add(i, &doc); + } + + // words is a stopword + algo.Init("words", ¶ms); + EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre()); + + // some is a stopword + algo.Init("some", ¶ms); + EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre()); + + // found is not a stopword + algo.Init("found", ¶ms); + EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 3)); +} + std::string ToBytes(absl::Span vec) { return string{reinterpret_cast(vec.data()), sizeof(float) * vec.size()}; } @@ -380,7 +412,7 @@ TEST_F(SearchTest, Errors) { auto schema = MakeSimpleSchema( {{"score", SchemaField::NUMERIC}, {"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}}); schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1}; - FieldIndices indices{schema, PMR_NS::get_default_resource()}; + FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()}; SearchAlgorithm algo{}; QueryParams params; @@ -404,7 +436,7 @@ class KnnTest : public SearchTest, public testing::WithParamInterface documents(10); for (size_t i = 0; i < 10; i++) { @@ -615,7 +647,7 @@ TEST_P(KnnTest, AutoResize) { auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}}); schema.fields["pos"].special_params = SchemaField::VectorParams{GetParam(), 1, VectorSimilarity::L2, kInitialCapacity}; - FieldIndices indices{schema, PMR_NS::get_default_resource()}; + FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()}; for (size_t i = 0; i < 100; i++) { MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}}; @@ -634,7 +666,7 @@ static void BM_VectorSearch(benchmark::State& state) { auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}}); schema.fields["pos"].special_params = SchemaField::VectorParams{false, ndims}; - FieldIndices indices{schema, PMR_NS::get_default_resource()}; + FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()}; auto random_vec = [ndims]() { vector coords; diff --git a/src/server/search/doc_index.cc b/src/server/search/doc_index.cc index 5a5b00145..65067bda4 100644 --- a/src/server/search/doc_index.cc +++ b/src/server/search/doc_index.cc @@ -8,6 +8,7 @@ #include +#include "absl/strings/str_cat.h" #include "base/logging.h" #include "core/overloaded.h" #include "core/search/indices.h" @@ -89,6 +90,11 @@ string DocIndexInfo::BuildRestoreCommand() const { if (!base_index.prefix.empty()) absl::StrAppend(&out, " PREFIX", " 1 ", base_index.prefix); + // STOPWORDS + absl::StrAppend(&out, " STOPWORDS ", base_index.options.stopwords.size()); + for (const auto& sw : base_index.options.stopwords) + absl::StrAppend(&out, " ", sw); + absl::StrAppend(&out, " SCHEMA"); for (const auto& [fident, finfo] : base_index.schema.fields) { // Store field name, alias and type @@ -170,36 +176,35 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const { return obj_code == GetObjCode() && key.rfind(prefix, 0) == 0; } -ShardDocIndex::ShardDocIndex(shared_ptr index) - : base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} { +ShardDocIndex::ShardDocIndex(shared_ptr index) + : base_{std::move(index)}, key_index_{} { } void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) { key_index_ = DocKeyIndex{}; - indices_ = search::FieldIndices{base_->schema, mr}; + indices_.emplace(base_->schema, base_->options, mr); - auto cb = [this](string_view key, BaseAccessor* doc) { indices_.Add(key_index_.Add(key), doc); }; + auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); }; TraverseAllMatching(*base_, op_args, cb); - was_built_ = true; VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix; } void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { - if (!was_built_) + if (!indices_) return; auto accessor = GetAccessor(db_cntx, pv); - indices_.Add(key_index_.Add(key), accessor.get()); + indices_->Add(key_index_.Add(key), accessor.get()); } void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { - if (!was_built_) + if (!indices_) return; auto accessor = GetAccessor(db_cntx, pv); DocId id = key_index_.Remove(key); - indices_.Remove(id, accessor.get()); + indices_->Remove(id, accessor.get()); } bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { @@ -209,7 +214,7 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params, search::SearchAlgorithm* search_algo) const { auto& db_slice = op_args.GetDbSlice(); - auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total); + auto search_results = search_algo->Search(&*indices_, params.limit_offset + params.limit_total); if (!search_results.error.empty()) return SearchResult{facade::ErrorReply{std::move(search_results.error)}}; @@ -253,7 +258,7 @@ vector ShardDocIndex::SearchForAggregator( const OpArgs& op_args, const AggregateParams& params, search::SearchAlgorithm* search_algo) const { auto& db_slice = op_args.GetDbSlice(); - auto search_results = search_algo->Search(&indices_); + auto search_results = search_algo->Search(&*indices_); if (!search_results.error.empty()) return {}; @@ -267,7 +272,7 @@ vector ShardDocIndex::SearchForAggregator( continue; auto accessor = GetAccessor(op_args.db_cntx, (*it)->second); - auto extracted = indices_.ExtractStoredValues(doc); + auto extracted = indices_->ExtractStoredValues(doc); SearchDocData loaded; if (params.load_fields.ShouldReturnAllFields()) { @@ -290,7 +295,7 @@ DocIndexInfo ShardDocIndex::GetInfo() const { } io::Result ShardDocIndex::GetTagVals(string_view field) const { - search::BaseIndex* base_index = indices_.GetIndex(field); + search::BaseIndex* base_index = indices_->GetIndex(field); if (base_index == nullptr) { return make_unexpected(ErrorReply{"-No such field"}); } @@ -312,8 +317,8 @@ ShardDocIndex* ShardDocIndices::GetIndex(string_view name) { } void ShardDocIndices::InitIndex(const OpArgs& op_args, std::string_view name, - shared_ptr index_ptr) { - auto shard_index = make_unique(index_ptr); + shared_ptr index_ptr) { + auto shard_index = make_unique(std::move(index_ptr)); auto [it, _] = indices_.emplace(name, std::move(shard_index)); // Don't build while loading, shutting down, etc. diff --git a/src/server/search/doc_index.h b/src/server/search/doc_index.h index 6c96ca66d..9e0b268da 100644 --- a/src/server/search/doc_index.h +++ b/src/server/search/doc_index.h @@ -120,6 +120,7 @@ struct DocIndex { bool Matches(std::string_view key, unsigned obj_code) const; search::Schema schema; + search::IndicesOptions options{}; std::string prefix{}; DataType type{HASH}; }; @@ -156,7 +157,7 @@ class ShardDocIndex { public: // Index must be rebuilt at least once after intialization - ShardDocIndex(std::shared_ptr index); + ShardDocIndex(std::shared_ptr index); // Perform search on all indexed documents and return results. SearchResult Search(const OpArgs& op_args, const SearchParams& params, @@ -182,9 +183,8 @@ class ShardDocIndex { void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr); private: - bool was_built_ = false; std::shared_ptr base_; - search::FieldIndices indices_; + std::optional indices_; DocKeyIndex key_index_; }; @@ -198,7 +198,8 @@ class ShardDocIndices { // Init index: create shard local state for given index with given name. // Build if instance is in active state. - void InitIndex(const OpArgs& op_args, std::string_view name, std::shared_ptr index); + void InitIndex(const OpArgs& op_args, std::string_view name, + std::shared_ptr index); // Drop index, return true if it existed and was dropped bool DropIndex(std::string_view name); diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 0fc601b3d..2a646c687 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -485,6 +485,14 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) { continue; } + // STOWORDS count [words...] + if (parser.Check("STOPWORDS")) { + index.options.stopwords.clear(); + for (size_t num = parser.Next(); num > 0; num--) + index.options.stopwords.emplace(parser.Next()); + continue; + } + // SCHEMA if (parser.Check("SCHEMA")) { auto schema = ParseSchemaOrReply(index.type, parser.Tail(), cntx); diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc index 08e63bb26..87ac9d4e8 100644 --- a/src/server/search/search_family_test.cc +++ b/src/server/search/search_family_test.cc @@ -604,6 +604,22 @@ TEST_F(SearchFamilyTest, TestReturn) { EXPECT_THAT(resp, MatchEntry("k0", "vec_return", "20")); } +TEST_F(SearchFamilyTest, TestStopWords) { + Run({"ft.create", "i1", "STOPWORDS", "3", "red", "green", "blue", "SCHEMA", "title", "TEXT"}); + + Run({"hset", "d:1", "title", "ReD? parrot flies away"}); + Run({"hset", "d:2", "title", "GrEEn crocodile eats you"}); + Run({"hset", "d:3", "title", "BLUe. Whale surfes the sea"}); + + EXPECT_THAT(Run({"ft.search", "i1", "red"}), kNoResults); + EXPECT_THAT(Run({"ft.search", "i1", "green"}), kNoResults); + EXPECT_THAT(Run({"ft.search", "i1", "blue"}), kNoResults); + + EXPECT_THAT(Run({"ft.search", "i1", "parrot"}), AreDocIds("d:1")); + EXPECT_THAT(Run({"ft.search", "i1", "crocodile"}), AreDocIds("d:2")); + EXPECT_THAT(Run({"ft.search", "i1", "whale"}), AreDocIds("d:3")); +} + TEST_F(SearchFamilyTest, SimpleUpdates) { EXPECT_EQ(Run({"ft.create", "i1", "schema", "title", "text", "visits", "numeric"}), "OK"); diff --git a/tests/dragonfly/search_test.py b/tests/dragonfly/search_test.py index e5f1609a9..4e686164d 100644 --- a/tests/dragonfly/search_test.py +++ b/tests/dragonfly/search_test.py @@ -388,6 +388,7 @@ async def test_index_persistence(df_server): i1 = client.ft("i1") await i1.create_index( fix_schema_naming(IndexType.JSON, SCHEMA_1), + stopwords=["interesting", "stopwords"], definition=IndexDefinition(index_type=IndexType.JSON, prefix=["blog-"]), ) @@ -470,6 +471,11 @@ async def test_index_persistence(df_server): "age" ] == "199" + # Check stopwords were loaded + await client.json().set("blog-sw1", ".", {"title": "some stopwords"}) + assert (await i1.search("some")).total == 1 + assert (await i1.search("stopwords")).total == 0 + await i1.dropindex() await i2.dropindex()