feat(search): STOPWORDS (#3851)

Adds support for STOPWORDS option
This commit is contained in:
Vladislav 2024-10-10 21:58:12 +03:00 committed by GitHub
parent d876bcd5cb
commit e71f083f34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 138 additions and 40 deletions

View file

@ -39,10 +39,13 @@ string ToLower(string_view word) {
}
// Get all words from text as matched by the ICU library
absl::flat_hash_set<std::string> TokenizeWords(std::string_view text) {
absl::flat_hash_set<std::string> TokenizeWords(std::string_view text,
const TextIndex::StopWords& stopwords) {
absl::flat_hash_set<std::string> words;
for (std::string_view word : una::views::word_only::utf8(text))
words.insert(una::cases::to_lowercase_utf8(word));
for (std::string_view word : una::views::word_only::utf8(text)) {
if (std::string word_lc = una::cases::to_lowercase_utf8(word); !stopwords.contains(word_lc))
words.insert(std::move(word_lc));
}
return words;
}
@ -166,7 +169,7 @@ template struct BaseStringIndex<CompressedSortedSet>;
template struct BaseStringIndex<SortedVector>;
absl::flat_hash_set<std::string> TextIndex::Tokenize(std::string_view value) const {
return TokenizeWords(value);
return TokenizeWords(value, *stopwords_);
}
absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) const {

View file

@ -87,10 +87,16 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
// Index for text fields.
// Hashmap based lookup per word.
struct TextIndex : public BaseStringIndex<CompressedSortedSet> {
TextIndex(PMR_NS::memory_resource* mr) : BaseStringIndex(mr, false) {
using StopWords = absl::flat_hash_set<std::string>;
TextIndex(PMR_NS::memory_resource* mr, const StopWords* stopwords)
: BaseStringIndex(mr, false), stopwords_{stopwords} {
}
absl::flat_hash_set<std::string> Tokenize(std::string_view value) const override;
private:
const StopWords* stopwords_;
};
// Index for text fields.

View file

@ -13,6 +13,7 @@
#include <type_traits>
#include <variant>
#include "absl/container/flat_hash_set.h"
#include "base/logging.h"
#include "core/overloaded.h"
#include "core/search/ast_expr.h"
@ -454,8 +455,18 @@ string_view Schema::LookupAlias(string_view alias) const {
return alias;
}
FieldIndices::FieldIndices(Schema schema, PMR_NS::memory_resource* mr)
: schema_{std::move(schema)}, all_ids_{}, indices_{} {
IndicesOptions::IndicesOptions() {
static absl::flat_hash_set<std::string> kDefaultStopwords{
"a", "is", "the", "an", "and", "are", "as", "at", "be", "but", "by",
"for", "if", "in", "into", "it", "no", "not", "of", "on", "or", "such",
"that", "their", "then", "there", "these", "they", "this", "to", "was", "will", "with"};
stopwords = kDefaultStopwords;
}
FieldIndices::FieldIndices(const Schema& schema, const IndicesOptions& options,
PMR_NS::memory_resource* mr)
: schema_{schema}, options_{options} {
CreateIndices(mr);
CreateSortIndices(mr);
}
@ -467,7 +478,7 @@ void FieldIndices::CreateIndices(PMR_NS::memory_resource* mr) {
switch (field_info.type) {
case SchemaField::TEXT:
indices_[field_ident] = make_unique<TextIndex>(mr);
indices_[field_ident] = make_unique<TextIndex>(mr, &options_.stopwords);
break;
case SchemaField::NUMERIC:
indices_[field_ident] = make_unique<NumericIndex>(mr);
@ -546,7 +557,7 @@ BaseSortIndex* FieldIndices::GetSortIndex(string_view field) const {
std::vector<TextIndex*> FieldIndices::GetAllTextIndices() const {
vector<TextIndex*> out;
for (auto& [field_name, field_info] : schema_.fields) {
for (const auto& [field_name, field_info] : schema_.fields) {
if (field_info.type != SchemaField::TEXT || (field_info.flags & SchemaField::NOINDEX) > 0)
continue;
auto* index = dynamic_cast<TextIndex*>(GetIndex(field_name));

View file

@ -5,6 +5,7 @@
#pragma once
#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <functional>
#include <memory>
@ -61,11 +62,20 @@ struct Schema {
std::string_view LookupAlias(std::string_view alias) const;
};
struct IndicesOptions {
IndicesOptions();
explicit IndicesOptions(absl::flat_hash_set<std::string> stopwords)
: stopwords{std::move(stopwords)} {
}
absl::flat_hash_set<std::string> stopwords;
};
// Collection of indices for all fields in schema
class FieldIndices {
public:
// Create indices based on schema
FieldIndices(Schema schema, PMR_NS::memory_resource* mr);
// Create indices based on schema and options. Both must outlive the indices
FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);
void Add(DocId doc, DocumentAccessor* access);
void Remove(DocId doc, DocumentAccessor* access);
@ -84,8 +94,8 @@ class FieldIndices {
void CreateIndices(PMR_NS::memory_resource* mr);
void CreateSortIndices(PMR_NS::memory_resource* mr);
private:
Schema schema_;
const Schema& schema_;
const IndicesOptions& options_;
std::vector<DocId> all_ids_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseIndex>> indices_;
absl::flat_hash_map<std::string, std::unique_ptr<BaseSortIndex>> sort_indices_;

View file

@ -66,6 +66,8 @@ struct MockedDocument : public DocumentAccessor {
Map fields_{};
};
IndicesOptions kEmptyOptions{{}};
Schema MakeSimpleSchema(initializer_list<pair<string_view, SchemaField::FieldType>> ilist) {
Schema schema;
for (auto [name, type] : ilist) {
@ -105,7 +107,7 @@ class SearchTest : public ::testing::Test {
bool Check() {
absl::Cleanup cl{[this] { entries_.clear(); }};
FieldIndices index{schema_, PMR_NS::get_default_resource()};
FieldIndices index{schema_, kEmptyOptions, PMR_NS::get_default_resource()};
shuffle(entries_.begin(), entries_.end(), default_random_engine{});
for (DocId i = 0; i < entries_.size(); i++)
@ -372,6 +374,36 @@ TEST_F(SearchTest, IntegerTerms) {
EXPECT_TRUE(Check()) << GetError();
}
TEST_F(SearchTest, StopWords) {
auto schema = MakeSimpleSchema({{"title", SchemaField::TEXT}});
IndicesOptions options{{"some", "words", "are", "left", "out"}};
FieldIndices indices{schema, options, PMR_NS::get_default_resource()};
SearchAlgorithm algo{};
QueryParams params;
vector<string> documents = {"some words left out", //
"some can be found", //
"words are never matched", //
"explicitly found!"};
for (size_t i = 0; i < documents.size(); i++) {
MockedDocument doc{{{"title", documents[i]}}};
indices.Add(i, &doc);
}
// words is a stopword
algo.Init("words", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre());
// some is a stopword
algo.Init("some", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre());
// found is not a stopword
algo.Init("found", &params);
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 3));
}
std::string ToBytes(absl::Span<const float> vec) {
return string{reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size()};
}
@ -380,7 +412,7 @@ TEST_F(SearchTest, Errors) {
auto schema = MakeSimpleSchema(
{{"score", SchemaField::NUMERIC}, {"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params = SchemaField::VectorParams{false, 1};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
SearchAlgorithm algo{};
QueryParams params;
@ -404,7 +436,7 @@ class KnnTest : public SearchTest, public testing::WithParamInterface<bool /* hn
TEST_P(KnnTest, Simple1D) {
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params = SchemaField::VectorParams{GetParam(), 1};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
// Place points on a straight line
for (size_t i = 0; i < 100; i++) {
@ -461,7 +493,7 @@ TEST_P(KnnTest, Simple2D) {
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params = SchemaField::VectorParams{GetParam(), 2};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
@ -523,7 +555,7 @@ TEST_P(KnnTest, Cosine) {
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params =
SchemaField::VectorParams{GetParam(), 2, VectorSimilarity::COSINE};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
@ -567,7 +599,7 @@ TEST_P(KnnTest, AddRemove) {
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params =
SchemaField::VectorParams{GetParam(), 1, VectorSimilarity::L2};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
vector<MockedDocument> documents(10);
for (size_t i = 0; i < 10; i++) {
@ -615,7 +647,7 @@ TEST_P(KnnTest, AutoResize) {
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params =
SchemaField::VectorParams{GetParam(), 1, VectorSimilarity::L2, kInitialCapacity};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
for (size_t i = 0; i < 100; i++) {
MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}};
@ -634,7 +666,7 @@ static void BM_VectorSearch(benchmark::State& state) {
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
schema.fields["pos"].special_params = SchemaField::VectorParams{false, ndims};
FieldIndices indices{schema, PMR_NS::get_default_resource()};
FieldIndices indices{schema, kEmptyOptions, PMR_NS::get_default_resource()};
auto random_vec = [ndims]() {
vector<float> coords;

View file

@ -8,6 +8,7 @@
#include <memory>
#include "absl/strings/str_cat.h"
#include "base/logging.h"
#include "core/overloaded.h"
#include "core/search/indices.h"
@ -89,6 +90,11 @@ string DocIndexInfo::BuildRestoreCommand() const {
if (!base_index.prefix.empty())
absl::StrAppend(&out, " PREFIX", " 1 ", base_index.prefix);
// STOPWORDS
absl::StrAppend(&out, " STOPWORDS ", base_index.options.stopwords.size());
for (const auto& sw : base_index.options.stopwords)
absl::StrAppend(&out, " ", sw);
absl::StrAppend(&out, " SCHEMA");
for (const auto& [fident, finfo] : base_index.schema.fields) {
// Store field name, alias and type
@ -170,36 +176,35 @@ bool DocIndex::Matches(string_view key, unsigned obj_code) const {
return obj_code == GetObjCode() && key.rfind(prefix, 0) == 0;
}
ShardDocIndex::ShardDocIndex(shared_ptr<DocIndex> index)
: base_{std::move(index)}, indices_{{}, nullptr}, key_index_{} {
ShardDocIndex::ShardDocIndex(shared_ptr<const DocIndex> index)
: base_{std::move(index)}, key_index_{} {
}
void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr) {
key_index_ = DocKeyIndex{};
indices_ = search::FieldIndices{base_->schema, mr};
indices_.emplace(base_->schema, base_->options, mr);
auto cb = [this](string_view key, BaseAccessor* doc) { indices_.Add(key_index_.Add(key), doc); };
auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); };
TraverseAllMatching(*base_, op_args, cb);
was_built_ = true;
VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix;
}
void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
if (!was_built_)
if (!indices_)
return;
auto accessor = GetAccessor(db_cntx, pv);
indices_.Add(key_index_.Add(key), accessor.get());
indices_->Add(key_index_.Add(key), accessor.get());
}
void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
if (!was_built_)
if (!indices_)
return;
auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key);
indices_.Remove(id, accessor.get());
indices_->Remove(id, accessor.get());
}
bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
@ -209,7 +214,7 @@ bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {
SearchResult ShardDocIndex::Search(const OpArgs& op_args, const SearchParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.GetDbSlice();
auto search_results = search_algo->Search(&indices_, params.limit_offset + params.limit_total);
auto search_results = search_algo->Search(&*indices_, params.limit_offset + params.limit_total);
if (!search_results.error.empty())
return SearchResult{facade::ErrorReply{std::move(search_results.error)}};
@ -253,7 +258,7 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
const OpArgs& op_args, const AggregateParams& params,
search::SearchAlgorithm* search_algo) const {
auto& db_slice = op_args.GetDbSlice();
auto search_results = search_algo->Search(&indices_);
auto search_results = search_algo->Search(&*indices_);
if (!search_results.error.empty())
return {};
@ -267,7 +272,7 @@ vector<SearchDocData> ShardDocIndex::SearchForAggregator(
continue;
auto accessor = GetAccessor(op_args.db_cntx, (*it)->second);
auto extracted = indices_.ExtractStoredValues(doc);
auto extracted = indices_->ExtractStoredValues(doc);
SearchDocData loaded;
if (params.load_fields.ShouldReturnAllFields()) {
@ -290,7 +295,7 @@ DocIndexInfo ShardDocIndex::GetInfo() const {
}
io::Result<StringVec, ErrorReply> ShardDocIndex::GetTagVals(string_view field) const {
search::BaseIndex* base_index = indices_.GetIndex(field);
search::BaseIndex* base_index = indices_->GetIndex(field);
if (base_index == nullptr) {
return make_unexpected(ErrorReply{"-No such field"});
}
@ -312,8 +317,8 @@ ShardDocIndex* ShardDocIndices::GetIndex(string_view name) {
}
void ShardDocIndices::InitIndex(const OpArgs& op_args, std::string_view name,
shared_ptr<DocIndex> index_ptr) {
auto shard_index = make_unique<ShardDocIndex>(index_ptr);
shared_ptr<const DocIndex> index_ptr) {
auto shard_index = make_unique<ShardDocIndex>(std::move(index_ptr));
auto [it, _] = indices_.emplace(name, std::move(shard_index));
// Don't build while loading, shutting down, etc.

View file

@ -120,6 +120,7 @@ struct DocIndex {
bool Matches(std::string_view key, unsigned obj_code) const;
search::Schema schema;
search::IndicesOptions options{};
std::string prefix{};
DataType type{HASH};
};
@ -156,7 +157,7 @@ class ShardDocIndex {
public:
// Index must be rebuilt at least once after intialization
ShardDocIndex(std::shared_ptr<DocIndex> index);
ShardDocIndex(std::shared_ptr<const DocIndex> index);
// Perform search on all indexed documents and return results.
SearchResult Search(const OpArgs& op_args, const SearchParams& params,
@ -182,9 +183,8 @@ class ShardDocIndex {
void Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr);
private:
bool was_built_ = false;
std::shared_ptr<const DocIndex> base_;
search::FieldIndices indices_;
std::optional<search::FieldIndices> indices_;
DocKeyIndex key_index_;
};
@ -198,7 +198,8 @@ class ShardDocIndices {
// Init index: create shard local state for given index with given name.
// Build if instance is in active state.
void InitIndex(const OpArgs& op_args, std::string_view name, std::shared_ptr<DocIndex> index);
void InitIndex(const OpArgs& op_args, std::string_view name,
std::shared_ptr<const DocIndex> index);
// Drop index, return true if it existed and was dropped
bool DropIndex(std::string_view name);

View file

@ -485,6 +485,14 @@ void SearchFamily::FtCreate(CmdArgList args, ConnectionContext* cntx) {
continue;
}
// STOWORDS count [words...]
if (parser.Check("STOPWORDS")) {
index.options.stopwords.clear();
for (size_t num = parser.Next<size_t>(); num > 0; num--)
index.options.stopwords.emplace(parser.Next());
continue;
}
// SCHEMA
if (parser.Check("SCHEMA")) {
auto schema = ParseSchemaOrReply(index.type, parser.Tail(), cntx);

View file

@ -604,6 +604,22 @@ TEST_F(SearchFamilyTest, TestReturn) {
EXPECT_THAT(resp, MatchEntry("k0", "vec_return", "20"));
}
TEST_F(SearchFamilyTest, TestStopWords) {
Run({"ft.create", "i1", "STOPWORDS", "3", "red", "green", "blue", "SCHEMA", "title", "TEXT"});
Run({"hset", "d:1", "title", "ReD? parrot flies away"});
Run({"hset", "d:2", "title", "GrEEn crocodile eats you"});
Run({"hset", "d:3", "title", "BLUe. Whale surfes the sea"});
EXPECT_THAT(Run({"ft.search", "i1", "red"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "green"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "blue"}), kNoResults);
EXPECT_THAT(Run({"ft.search", "i1", "parrot"}), AreDocIds("d:1"));
EXPECT_THAT(Run({"ft.search", "i1", "crocodile"}), AreDocIds("d:2"));
EXPECT_THAT(Run({"ft.search", "i1", "whale"}), AreDocIds("d:3"));
}
TEST_F(SearchFamilyTest, SimpleUpdates) {
EXPECT_EQ(Run({"ft.create", "i1", "schema", "title", "text", "visits", "numeric"}), "OK");

View file

@ -388,6 +388,7 @@ async def test_index_persistence(df_server):
i1 = client.ft("i1")
await i1.create_index(
fix_schema_naming(IndexType.JSON, SCHEMA_1),
stopwords=["interesting", "stopwords"],
definition=IndexDefinition(index_type=IndexType.JSON, prefix=["blog-"]),
)
@ -470,6 +471,11 @@ async def test_index_persistence(df_server):
"age"
] == "199"
# Check stopwords were loaded
await client.json().set("blog-sw1", ".", {"title": "some stopwords"})
assert (await i1.search("some")).total == 1
assert (await i1.search("stopwords")).total == 0
await i1.dropindex()
await i2.dropindex()