fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands (#4070)

* fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands

fixes  #3986
---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-11-10 15:56:25 +01:00 committed by GitHub
parent f745f3133d
commit 503bb4ed33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 686 additions and 219 deletions

View file

@ -4,6 +4,8 @@
#include "core/search/base.h" #include "core/search/base.h"
#include <absl/strings/numbers.h>
namespace dfly::search { namespace dfly::search {
std::string_view QueryParams::operator[](std::string_view name) const { std::string_view QueryParams::operator[](std::string_view name) const {
@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const {
return std::string_view{ptr.get(), std::strlen(ptr.get())}; return std::string_view{ptr.get(), std::strlen(ptr.get())};
} }
std::optional<double> ParseNumericField(std::string_view value) {
double value_as_double;
if (absl::SimpleAtod(value, &value_as_double))
return value_as_double;
return std::nullopt;
}
} // namespace dfly::search } // namespace dfly::search

View file

@ -68,11 +68,18 @@ using SortableValue = std::variant<std::monostate, double, std::string>;
struct DocumentAccessor { struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector; using VectorInfo = search::OwnedFtVector;
using StringList = absl::InlinedVector<std::string_view, 1>; using StringList = absl::InlinedVector<std::string_view, 1>;
using NumsList = absl::InlinedVector<double, 1>;
virtual ~DocumentAccessor() = default; virtual ~DocumentAccessor() = default;
virtual StringList GetStrings(std::string_view active_field) const = 0; /* Returns nullopt if the specified field is not a list of strings */
virtual VectorInfo GetVector(std::string_view active_field) const = 0; virtual std::optional<StringList> GetStrings(std::string_view active_field) const = 0;
/* Returns nullopt if the specified field is not a vector */
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const = 0;
/* Return nullopt if the specified field is not a list of doubles */
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
}; };
// Base class for type-specific indices. // Base class for type-specific indices.
@ -81,8 +88,10 @@ struct DocumentAccessor {
// query functions. All results for all index types should be sorted. // query functions. All results for all index types should be sorted.
struct BaseIndex { struct BaseIndex {
virtual ~BaseIndex() = default; virtual ~BaseIndex() = default;
virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0; // Returns true if the document was added / indexed
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
}; };
// Base class for type-specific sorting indices. // Base class for type-specific sorting indices.
@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex {
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0; virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
}; };
/* Used for converting field values to double. Returns std::nullopt if the conversion fails */
std::optional<double> ParseNumericField(std::string_view value);
/* Temporary method to create an empty std::optional<InlinedVector> in DocumentAccessor::GetString
and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation
details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. To
suppress this false warning, we temporarily disable it around this block of code using GCC
diagnostic directives. */
template <typename InlinedVector> std::optional<InlinedVector> EmptyAccessResult() {
// GCC 13.1 throws spurious warnings around this code.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
return InlinedVector{};
#pragma GCC diagnostic pop
}
} // namespace dfly::search } // namespace dfly::search

View file

@ -71,19 +71,22 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} { NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
} }
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) {
for (auto str : doc->GetStrings(field)) { auto numbers = doc.GetNumbers(field);
double num; if (!numbers) {
if (absl::SimpleAtod(str, &num)) return false;
entries_.emplace(num, id);
} }
for (auto num : numbers.value()) {
entries_.emplace(num, id);
}
return true;
} }
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
for (auto str : doc->GetStrings(field)) { auto numbers = doc.GetNumbers(field).value();
double num; for (auto num : numbers) {
if (absl::SimpleAtod(str, &num)) entries_.erase({num, id});
entries_.erase({num, id});
} }
} }
@ -139,19 +142,27 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
} }
template <typename C> template <typename C>
void BaseStringIndex<C>::Add(DocId id, DocumentAccessor* doc, string_view field) { bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field);
if (!strings_list) {
return false;
}
absl::flat_hash_set<std::string> tokens; absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field)) for (string_view str : strings_list.value())
tokens.merge(Tokenize(str)); tokens.merge(Tokenize(str));
for (string_view token : tokens) for (string_view token : tokens)
GetOrCreate(token)->Insert(id); GetOrCreate(token)->Insert(id);
return true;
} }
template <typename C> template <typename C>
void BaseStringIndex<C>::Remove(DocId id, DocumentAccessor* doc, string_view field) { void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field).value();
absl::flat_hash_set<std::string> tokens; absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field)) for (string_view str : strings_list)
tokens.merge(Tokenize(str)); tokens.merge(Tokenize(str));
for (const auto& token : tokens) { for (const auto& token : tokens) {
@ -192,6 +203,20 @@ std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
return {dim_, sim_}; return {dim_, sim_};
} }
bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto vector = doc.GetVector(field);
if (!vector)
return false;
auto& [ptr, size] = vector.value();
if (ptr && size != dim_) {
return false;
}
AddVector(id, ptr);
return true;
}
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params, FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
PMR_NS::memory_resource* mr) PMR_NS::memory_resource* mr)
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} { : BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
@ -199,19 +224,18 @@ FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
entries_.reserve(params.capacity * params.dim); entries_.reserve(params.capacity * params.dim);
} }
void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
DCHECK_LE(id * dim_, entries_.size()); DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size()) if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_); entries_.resize((id + 1) * dim_);
// TODO: Let get vector write to buf itself // TODO: Let get vector write to buf itself
auto [ptr, size] = doc->GetVector(field); if (vector) {
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
if (size == dim_) }
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
} }
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
// noop // noop
} }
@ -229,7 +253,7 @@ struct HnswlibAdapter {
100 /* seed*/} { 100 /* seed*/} {
} }
void Add(float* data, DocId id) { void Add(const float* data, DocId id) {
if (world_.cur_element_count + 1 >= world_.max_elements_) if (world_.cur_element_count + 1 >= world_.max_elements_)
world_.resizeIndex(world_.cur_element_count * 2); world_.resizeIndex(world_.cur_element_count * 2);
world_.addPoint(data, id); world_.addPoint(data, id);
@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS
HnswVectorIndex::~HnswVectorIndex() { HnswVectorIndex::~HnswVectorIndex() {
} }
void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
auto [ptr, size] = doc->GetVector(field); if (vector) {
if (size == dim_) adapter_->Add(vector.get(), id);
adapter_->Add(ptr.get(), id); }
} }
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k, std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
@ -314,7 +338,7 @@ std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t
return adapter_->Knn(target, k, ef, allowed); return adapter_->Knn(target, k, ef, allowed);
} }
void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
adapter_->Remove(id); adapter_->Remove(id);
} }

View file

@ -28,8 +28,8 @@ namespace dfly::search {
struct NumericIndex : public BaseIndex { struct NumericIndex : public BaseIndex {
explicit NumericIndex(PMR_NS::memory_resource* mr); explicit NumericIndex(PMR_NS::memory_resource* mr);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
std::vector<DocId> Range(double l, double r) const; std::vector<DocId> Range(double l, double r) const;
@ -44,8 +44,8 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive); BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
// Used by Add & Remove to tokenize text value // Used by Add & Remove to tokenize text value
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0; virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
@ -53,7 +53,7 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
// Pointer is valid as long as index is not mutated. Nullptr if not found // Pointer is valid as long as index is not mutated. Nullptr if not found
const Container* Matching(std::string_view str) const; const Container* Matching(std::string_view str) const;
// Iterate over all Machting on prefix. // Iterate over all Matching on prefix.
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const; void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
// Returns all the terms that appear as keys in the reverse index. // Returns all the terms that appear as keys in the reverse index.
@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
struct BaseVectorIndex : public BaseIndex { struct BaseVectorIndex : public BaseIndex {
std::pair<size_t /*dim*/, VectorSimilarity> Info() const; std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;
protected: protected:
BaseVectorIndex(size_t dim, VectorSimilarity sim); BaseVectorIndex(size_t dim, VectorSimilarity sim);
using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;
size_t dim_; size_t dim_;
VectorSimilarity sim_; VectorSimilarity sim_;
}; };
@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex {
struct FlatVectorIndex : public BaseVectorIndex { struct FlatVectorIndex : public BaseVectorIndex {
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
const float* Get(DocId doc) const; const float* Get(DocId doc) const;
protected:
void AddVector(DocId id, const VectorPtr& vector) override;
private: private:
PMR_NS::vector<float> entries_; PMR_NS::vector<float> entries_;
}; };
@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex {
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr); HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
~HnswVectorIndex(); ~HnswVectorIndex();
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const; std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef, std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
const std::vector<DocId>& allowed) const; const std::vector<DocId>& allowed) const;
protected:
void AddVector(DocId id, const VectorPtr& vector) override;
private: private:
std::unique_ptr<HnswlibAdapter> adapter_; std::unique_ptr<HnswlibAdapter> adapter_;
}; };

View file

@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
} }
} }
void FieldIndices::Add(DocId doc, DocumentAccessor* access) { bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) {
for (auto& [field, index] : indices_) bool was_added = true;
index->Add(doc, access, field);
for (auto& [field, sort_index] : sort_indices_) std::vector<std::pair<std::string_view, BaseIndex*>> successfully_added_indices;
sort_index->Add(doc, access, field); successfully_added_indices.reserve(indices_.size() + sort_indices_.size());
auto try_add = [&](const auto& indices_container) {
for (auto& [field, index] : indices_container) {
if (index->Add(doc, access, field)) {
successfully_added_indices.emplace_back(field, index.get());
} else {
was_added = false;
break;
}
}
};
try_add(indices_);
if (was_added) {
try_add(sort_indices_);
}
if (!was_added) {
for (auto& [field, index] : successfully_added_indices) {
index->Remove(doc, access, field);
}
return false;
}
all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc); all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc);
return true;
} }
void FieldIndices::Remove(DocId doc, DocumentAccessor* access) { void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) {
for (auto& [field, index] : indices_) for (auto& [field, index] : indices_)
index->Remove(doc, access, field); index->Remove(doc, access, field);
for (auto& [field, sort_index] : sort_indices_) for (auto& [field, sort_index] : sort_indices_)
sort_index->Remove(doc, access, field); sort_index->Remove(doc, access, field);
auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc); auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc);
CHECK(it != all_ids_.end() && *it == doc); DCHECK(it != all_ids_.end() && *it == doc);
all_ids_.erase(it); all_ids_.erase(it);
} }

View file

@ -77,8 +77,9 @@ class FieldIndices {
// Create indices based on schema and options. Both must outlive the indices // Create indices based on schema and options. Both must outlive the indices
FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr); FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);
void Add(DocId doc, DocumentAccessor* access); // Returns true if document was added
void Remove(DocId doc, DocumentAccessor* access); bool Add(DocId doc, const DocumentAccessor& access);
void Remove(DocId doc, const DocumentAccessor& access);
BaseIndex* GetIndex(std::string_view field) const; BaseIndex* GetIndex(std::string_view field) const;
BaseSortIndex* GetSortIndex(std::string_view field) const; BaseSortIndex* GetSortIndex(std::string_view field) const;

View file

@ -44,13 +44,36 @@ struct MockedDocument : public DocumentAccessor {
MockedDocument(std::string test_field) : fields_{{"field", test_field}} { MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
} }
StringList GetStrings(string_view field) const override { std::optional<StringList> GetStrings(string_view field) const override {
auto it = fields_.find(field); auto it = fields_.find(field);
return {it != fields_.end() ? string_view{it->second} : ""}; if (it == fields_.end()) {
return EmptyAccessResult<StringList>();
}
return StringList{string_view{it->second}};
} }
VectorInfo GetVector(string_view field) const override { std::optional<VectorInfo> GetVector(string_view field) const override {
return BytesToFtVector(GetStrings(field).front()); auto strings_list = GetStrings(field);
if (!strings_list)
return std::nullopt;
return !strings_list->empty() ? BytesToFtVectorSafe(strings_list->front()) : VectorInfo{};
}
std::optional<NumsList> GetNumbers(std::string_view field) const override {
auto strings_list = GetStrings(field);
if (!strings_list)
return std::nullopt;
NumsList nums_list;
nums_list.reserve(strings_list->size());
for (auto str : strings_list.value()) {
auto num = ParseNumericField(str);
if (!num) {
return std::nullopt;
}
nums_list.push_back(num.value());
}
return nums_list;
} }
string DebugFormat() { string DebugFormat() {
@ -121,7 +144,7 @@ class SearchTest : public ::testing::Test {
shuffle(entries_.begin(), entries_.end(), default_random_engine{}); shuffle(entries_.begin(), entries_.end(), default_random_engine{});
for (DocId i = 0; i < entries_.size(); i++) for (DocId i = 0; i < entries_.size(); i++)
index.Add(i, &entries_[i].first); index.Add(i, entries_[i].first);
SearchAlgorithm search_algo{}; SearchAlgorithm search_algo{};
if (!search_algo.Init(query_, &params_)) { if (!search_algo.Init(query_, &params_)) {
@ -430,7 +453,7 @@ TEST_F(SearchTest, StopWords) {
"explicitly found!"}; "explicitly found!"};
for (size_t i = 0; i < documents.size(); i++) { for (size_t i = 0; i < documents.size(); i++) {
MockedDocument doc{{{"title", documents[i]}}}; MockedDocument doc{{{"title", documents[i]}}};
indices.Add(i, &doc); indices.Add(i, doc);
} }
// words is a stopword // words is a stopword
@ -484,7 +507,7 @@ TEST_P(KnnTest, Simple1D) {
for (size_t i = 0; i < 100; i++) { for (size_t i = 0; i < 100; i++) {
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}}; Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
MockedDocument doc{values}; MockedDocument doc{values};
indices.Add(i, &doc); indices.Add(i, doc);
} }
SearchAlgorithm algo{}; SearchAlgorithm algo{};
@ -540,7 +563,7 @@ TEST_P(KnnTest, Simple2D) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) { for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second}); string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
MockedDocument doc{Map{{"pos", coords}}}; MockedDocument doc{Map{{"pos", coords}}};
indices.Add(i, &doc); indices.Add(i, doc);
} }
SearchAlgorithm algo{}; SearchAlgorithm algo{};
@ -602,7 +625,7 @@ TEST_P(KnnTest, Cosine) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) { for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second}); string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
MockedDocument doc{Map{{"pos", coords}}}; MockedDocument doc{Map{{"pos", coords}}};
indices.Add(i, &doc); indices.Add(i, doc);
} }
SearchAlgorithm algo{}; SearchAlgorithm algo{};
@ -646,7 +669,7 @@ TEST_P(KnnTest, AddRemove) {
vector<MockedDocument> documents(10); vector<MockedDocument> documents(10);
for (size_t i = 0; i < 10; i++) { for (size_t i = 0; i < 10; i++) {
documents[i] = Map{{"pos", ToBytes({float(i)})}}; documents[i] = Map{{"pos", ToBytes({float(i)})}};
indices.Add(i, &documents[i]); indices.Add(i, documents[i]);
} }
SearchAlgorithm algo{}; SearchAlgorithm algo{};
@ -661,7 +684,7 @@ TEST_P(KnnTest, AddRemove) {
// delete leftmost 5 // delete leftmost 5
for (size_t i = 0; i < 5; i++) for (size_t i = 0; i < 5; i++)
indices.Remove(i, &documents[i]); indices.Remove(i, documents[i]);
// search leftmost 5 again // search leftmost 5 again
{ {
@ -672,7 +695,7 @@ TEST_P(KnnTest, AddRemove) {
// add removed elements // add removed elements
for (size_t i = 0; i < 5; i++) for (size_t i = 0; i < 5; i++)
indices.Add(i, &documents[i]); indices.Add(i, documents[i]);
// repeat first search // repeat first search
{ {
@ -693,7 +716,7 @@ TEST_P(KnnTest, AutoResize) {
for (size_t i = 0; i < 100; i++) { for (size_t i = 0; i < 100; i++) {
MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}}; MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}};
indices.Add(i, &doc); indices.Add(i, doc);
} }
EXPECT_EQ(indices.GetAllDocs().size(), 100); EXPECT_EQ(indices.GetAllDocs().size(), 100);
@ -720,7 +743,7 @@ static void BM_VectorSearch(benchmark::State& state) {
for (size_t i = 0; i < nvecs; i++) { for (size_t i = 0; i < nvecs; i++) {
auto rv = random_vec(); auto rv = random_vec();
MockedDocument doc{Map{{"pos", ToBytes(rv)}}}; MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
indices.Add(i, &doc); indices.Add(i, doc);
} }
SearchAlgorithm algo{}; SearchAlgorithm algo{};

View file

@ -46,15 +46,23 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
} }
template <typename T> template <typename T>
void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_view field) { bool SimpleValueSortIndex<T>::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto field_value = Get(doc, field);
if (!field_value) {
return false;
}
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
if (id >= values_.size()) if (id >= values_.size())
values_.resize(id + 1); values_.resize(id + 1);
values_[id] = Get(id, doc, field);
values_[id] = field_value.value();
return true;
} }
template <typename T> template <typename T>
void SimpleValueSortIndex<T>::Remove(DocId id, DocumentAccessor* doc, std::string_view field) { void SimpleValueSortIndex<T>::Remove(DocId id, const DocumentAccessor& doc,
std::string_view field) {
DCHECK_LT(id, values_.size()); DCHECK_LT(id, values_.size());
values_[id] = T{}; values_[id] = T{};
} }
@ -66,23 +74,22 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
template struct SimpleValueSortIndex<double>; template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>; template struct SimpleValueSortIndex<PMR_NS::string>;
double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { std::optional<double> NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) {
auto str = doc->GetStrings(field); auto numbers_list = doc.GetNumbers(field);
if (str.empty()) if (!numbers_list) {
return 0; return std::nullopt;
}
double v; return !numbers_list->empty() ? numbers_list->front() : 0.0;
if (!absl::SimpleAtod(str.front(), &v))
return 0;
return v;
} }
PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) { std::optional<PMR_NS::string> StringSortIndex::Get(const DocumentAccessor& doc,
auto str = doc->GetStrings(field); std::string_view field) {
if (str.empty()) auto strings_list = doc.GetStrings(field);
return ""; if (!strings_list) {
return std::nullopt;
return PMR_NS::string{str.front(), GetMemRes()}; }
return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()}
: PMR_NS::string{GetMemRes()};
} }
} // namespace dfly::search } // namespace dfly::search

View file

@ -24,11 +24,11 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
SortableValue Lookup(DocId doc) const override; SortableValue Lookup(DocId doc) const override;
std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override; std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override;
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
protected: protected:
virtual T Get(DocId id, DocumentAccessor* doc, std::string_view field) = 0; virtual std::optional<T> Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
PMR_NS::memory_resource* GetMemRes() const; PMR_NS::memory_resource* GetMemRes() const;
@ -39,14 +39,14 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
struct NumericSortIndex : public SimpleValueSortIndex<double> { struct NumericSortIndex : public SimpleValueSortIndex<double> {
NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
double Get(DocId id, DocumentAccessor* doc, std::string_view field) override; std::optional<double> Get(const DocumentAccessor& doc, std::string_view field) override;
}; };
// TODO: Map tags to integers for fast sort // TODO: Map tags to integers for fast sort
struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> { struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> {
StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {}; StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
PMR_NS::string Get(DocId id, DocumentAccessor* doc, std::string_view field) override; std::optional<PMR_NS::string> Get(const DocumentAccessor& doc, std::string_view field) override;
}; };
} // namespace dfly::search } // namespace dfly::search

View file

@ -39,18 +39,28 @@ __attribute__((optimize("fast-math"))) float CosineDistance(const float* u, cons
return 0.0f; return 0.0f;
} }
} // namespace OwnedFtVector ConvertToFtVector(string_view value) {
OwnedFtVector BytesToFtVector(string_view value) {
DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
// Value cannot be casted directly as it might be not aligned as a float (4 bytes). // Value cannot be casted directly as it might be not aligned as a float (4 bytes).
// Misaligned memory access is UB. // Misaligned memory access is UB.
size_t size = value.size() / sizeof(float); size_t size = value.size() / sizeof(float);
auto out = make_unique<float[]>(size); auto out = make_unique<float[]>(size);
memcpy(out.get(), value.data(), size * sizeof(float)); memcpy(out.get(), value.data(), size * sizeof(float));
return {std::move(out), size}; return OwnedFtVector{std::move(out), size};
}
} // namespace
OwnedFtVector BytesToFtVector(string_view value) {
DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
return ConvertToFtVector(value);
}
std::optional<OwnedFtVector> BytesToFtVectorSafe(string_view value) {
if (value.size() % sizeof(float)) {
return std::nullopt;
}
return ConvertToFtVector(value);
} }
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) { float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {

View file

@ -10,6 +10,10 @@ namespace dfly::search {
OwnedFtVector BytesToFtVector(std::string_view value); OwnedFtVector BytesToFtVector(std::string_view value);
// Returns std::nullopt if value can not be converted to the vector
// TODO: Remove unsafe version
std::optional<OwnedFtVector> BytesToFtVectorSafe(std::string_view value);
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim); float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);
} // namespace dfly::search } // namespace dfly::search

View file

@ -38,43 +38,44 @@ string_view SdsToSafeSv(sds str) {
return str != nullptr ? string_view{str, sdslen(str)} : ""sv; return str != nullptr ? string_view{str, sdslen(str)} : ""sv;
} }
search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) { using FieldValue = std::optional<search::SortableValue>;
FieldValue ToSortableValue(search::SchemaField::FieldType type, string_view value) {
if (value.empty()) {
return std::nullopt;
}
if (type == search::SchemaField::NUMERIC) { if (type == search::SchemaField::NUMERIC) {
double value_as_double = 0; auto value_as_double = search::ParseNumericField(value);
if (!absl::SimpleAtod(value, &value_as_double)) { // temporary convert to double if (!value_as_double) { // temporary convert to double
LOG(DFATAL) << "Failed to convert " << value << " to double"; LOG(DFATAL) << "Failed to convert " << value << " to double";
return std::nullopt;
} }
return value_as_double; return value_as_double.value();
} }
if (type == search::SchemaField::VECTOR) { if (type == search::SchemaField::VECTOR) {
auto [ptr, size] = search::BytesToFtVector(value); auto opt_vector = search::BytesToFtVectorSafe(value);
if (!opt_vector) {
LOG(DFATAL) << "Failed to convert " << value << " to vector";
return std::nullopt;
}
auto& [ptr, size] = opt_vector.value();
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]"); return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
} }
return string{value}; return string{value};
} }
search::SortableValue JsonToSortableValue(const search::SchemaField::FieldType type, FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, string_view value) {
const JsonType& json) {
if (type == search::SchemaField::NUMERIC) {
return json.as_double();
}
return json.to_string();
}
search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key,
string_view value) {
auto it = schema.fields.find(key); auto it = schema.fields.find(key);
if (it == schema.fields.end()) if (it == schema.fields.end())
return FieldToSortableValue(search::SchemaField::TEXT, value); return ToSortableValue(search::SchemaField::TEXT, value);
return FieldToSortableValue(it->second.type, value); return ToSortableValue(it->second.type, value);
} }
search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key, FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
const JsonType& json) { const JsonType& json) {
auto it = schema.fields.find(key); auto json_as_string = json.to_string();
if (it == schema.fields.end()) return ExtractSortableValue(schema, key, json_as_string);
return JsonToSortableValue(search::SchemaField::TEXT, json);
return JsonToSortableValue(it->second.type, json);
} }
} // namespace } // namespace
@ -83,7 +84,11 @@ SearchDocData BaseAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const { const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{}; SearchDocData out{};
for (const auto& [fident, fname] : fields) { for (const auto& [fident, fname] : fields) {
out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ",")); auto field_value =
ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ","));
if (field_value) {
out[fname] = std::move(field_value).value();
}
} }
return out; return out;
} }
@ -92,14 +97,39 @@ SearchDocData BaseAccessor::SerializeDocument(const search::Schema& schema) cons
return Serialize(schema); return Serialize(schema);
} }
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const { std::optional<BaseAccessor::VectorInfo> BaseAccessor::GetVector(
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data()); std::string_view active_field) const {
return strsv.has_value() ? StringList{*strsv} : StringList{}; auto strings_list = GetStrings(active_field);
if (strings_list) {
return !strings_list->empty() ? search::BytesToFtVectorSafe(strings_list->front())
: VectorInfo{};
}
return std::nullopt;
} }
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const { std::optional<BaseAccessor::NumsList> BaseAccessor::GetNumbers(
auto strlist = GetStrings(active_field); std::string_view active_field) const {
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front()); auto strings_list = GetStrings(active_field);
if (!strings_list) {
return std::nullopt;
}
NumsList nums_list;
nums_list.reserve(strings_list->size());
for (auto str : strings_list.value()) {
auto num = search::ParseNumericField(str);
if (!num) {
return std::nullopt;
}
nums_list.push_back(num.value());
}
return nums_list;
}
std::optional<BaseAccessor::StringList> ListPackAccessor::GetStrings(
string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
return strsv.has_value() ? StringList{*strsv} : StringList{};
} }
SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const { SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
@ -114,27 +144,29 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
string_view v = container_utils::LpGetView(fptr, intbuf_[1].data()); string_view v = container_utils::LpGetView(fptr, intbuf_[1].data());
fptr = lpNext(lp_, fptr); fptr = lpNext(lp_, fptr);
out[k] = ExtractSortableValue(schema, k, v); auto field_value = ExtractSortableValue(schema, k, v);
if (field_value) {
out[k] = std::move(field_value).value();
}
} }
return out; return out;
} }
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const { std::optional<BaseAccessor::StringList> StringMapAccessor::GetStrings(
string_view active_field) const {
auto it = hset_->Find(active_field); auto it = hset_->Find(active_field);
return it != hset_->end() ? StringList{SdsToSafeSv(it->second)} : StringList{}; return it != hset_->end() ? StringList{SdsToSafeSv(it->second)} : StringList{};
} }
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const { SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
SearchDocData out{}; SearchDocData out{};
for (const auto& [kptr, vptr] : *hset_) for (const auto& [kptr, vptr] : *hset_) {
out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr)); auto field_value = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
if (field_value) {
out[SdsToSafeSv(kptr)] = std::move(field_value).value();
}
}
return out; return out;
} }
@ -159,27 +191,54 @@ struct JsonAccessor::JsonPathContainer {
variant<json::Path, jsoncons::jsonpath::jsonpath_expression<JsonType>> val; variant<json::Path, jsoncons::jsonpath::jsonpath_expression<JsonType>> val;
}; };
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const { std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view active_field) const {
auto* path = GetPath(active_field); auto* path = GetPath(active_field);
if (!path) if (!path)
return {}; return search::EmptyAccessResult<StringList>();
auto path_res = path->Evaluate(json_); auto path_res = path->Evaluate(json_);
if (path_res.empty()) if (path_res.empty())
return {}; return search::EmptyAccessResult<StringList>();
if (path_res.size() == 1 && !path_res[0].is_array()) {
if (!path_res[0].is_string())
return std::nullopt;
if (path_res.size() == 1) {
buf_ = path_res[0].as_string(); buf_ = path_res[0].as_string();
return {buf_}; return StringList{buf_};
} }
buf_.clear(); buf_.clear();
// First, grow buffer and compute string sizes // First, grow buffer and compute string sizes
vector<size_t> sizes; vector<size_t> sizes;
for (const auto& element : path_res) {
auto add_json_to_buf = [&](const JsonType& json) {
size_t start = buf_.size(); size_t start = buf_.size();
buf_ += element.as_string(); buf_ += json.as_string();
sizes.push_back(buf_.size() - start); sizes.push_back(buf_.size() - start);
};
if (!path_res[0].is_array()) {
sizes.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
sizes.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
} }
// Reposition start pointers to the most recent allocation of buf // Reposition start pointers to the most recent allocation of buf
@ -194,23 +253,62 @@ BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) cons
return out; return out;
} }
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const { std::optional<BaseAccessor::VectorInfo> JsonAccessor::GetVector(string_view active_field) const {
auto* path = GetPath(active_field); auto* path = GetPath(active_field);
if (!path) if (!path)
return {}; return VectorInfo{};
auto res = path->Evaluate(json_); auto res = path->Evaluate(json_);
if (res.empty()) if (res.empty())
return {nullptr, 0}; return VectorInfo{};
if (!res[0].is_array())
return std::nullopt;
size_t size = res[0].size(); size_t size = res[0].size();
auto ptr = make_unique<float[]>(size); auto ptr = make_unique<float[]>(size);
size_t i = 0; size_t i = 0;
for (const auto& v : res[0].array_range()) for (const auto& v : res[0].array_range()) {
if (!v.is_number()) {
return std::nullopt;
}
ptr[i++] = v.as<float>(); ptr[i++] = v.as<float>();
}
return {std::move(ptr), size}; return BaseAccessor::VectorInfo{std::move(ptr), size};
}
std::optional<BaseAccessor::NumsList> JsonAccessor::GetNumbers(string_view active_field) const {
auto* path = GetPath(active_field);
if (!path)
return search::EmptyAccessResult<NumsList>();
auto path_res = path->Evaluate(json_);
if (path_res.empty())
return search::EmptyAccessResult<NumsList>();
NumsList nums_list;
if (!path_res[0].is_array()) {
nums_list.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
nums_list.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
}
return nums_list;
} }
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const { JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
@ -259,8 +357,12 @@ SearchDocData JsonAccessor::Serialize(
SearchDocData out{}; SearchDocData out{};
for (const auto& [ident, name] : fields) { for (const auto& [ident, name] : fields) {
if (auto* path = GetPath(ident); path) { if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty()) if (auto res = path->Evaluate(json_); !res.empty()) {
out[name] = ExtractSortableValueFromJson(schema, ident, res[0]); auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]);
if (field_value) {
out[name] = std::move(field_value).value();
}
}
} }
} }
return out; return out;

View file

@ -12,6 +12,7 @@
#include "core/json/json_object.h" #include "core/json/json_object.h"
#include "core/search/search.h" #include "core/search/search.h"
#include "core/search/vector_utils.h"
#include "server/common.h" #include "server/common.h"
#include "server/search/doc_index.h" #include "server/search/doc_index.h"
#include "server/table.h" #include "server/table.h"
@ -37,6 +38,10 @@ struct BaseAccessor : public search::DocumentAccessor {
indexed field indexed field
*/ */
virtual SearchDocData SerializeDocument(const search::Schema& schema) const; virtual SearchDocData SerializeDocument(const search::Schema& schema) const;
// Default implementation uses GetStrings
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const;
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const;
}; };
// Accessor for hashes stored with listpack // Accessor for hashes stored with listpack
@ -46,8 +51,7 @@ struct ListPackAccessor : public BaseAccessor {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} { explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
} }
StringList GetStrings(std::string_view field) const override; std::optional<StringList> GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override; SearchDocData Serialize(const search::Schema& schema) const override;
private: private:
@ -60,8 +64,7 @@ struct StringMapAccessor : public BaseAccessor {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} { explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
} }
StringList GetStrings(std::string_view field) const override; std::optional<StringList> GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override; SearchDocData Serialize(const search::Schema& schema) const override;
private: private:
@ -75,8 +78,9 @@ struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(const JsonType* json) : json_{*json} { explicit JsonAccessor(const JsonType* json) : json_{*json} {
} }
StringList GetStrings(std::string_view field) const override; std::optional<StringList> GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override; std::optional<VectorInfo> GetVector(std::string_view field) const override;
std::optional<NumsList> GetNumbers(std::string_view active_field) const override;
// The JsonAccessor works with structured types and not plain strings, so an overload is needed // The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema, SearchDocData Serialize(const search::Schema& schema,

View file

@ -41,7 +41,7 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
return; return;
auto accessor = GetAccessor(op_args.db_cntx, pv); auto accessor = GetAccessor(op_args.db_cntx, pv);
f(key, accessor.get()); f(key, *accessor);
}; };
PrimeTable::Cursor cursor; PrimeTable::Cursor cursor;
@ -146,12 +146,14 @@ ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Add(string_view key) {
return id; return id;
} }
ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Remove(string_view key) { std::optional<ShardDocIndex::DocId> ShardDocIndex::DocKeyIndex::Remove(string_view key) {
DCHECK_GT(ids_.count(key), 0u); auto it = ids_.extract(key);
if (!it) {
return std::nullopt;
}
DocId id = ids_.find(key)->second; const DocId id = it.mapped();
keys_[id] = ""; keys_[id] = "";
ids_.erase(key);
free_ids_.push_back(id); free_ids_.push_back(id);
return id; return id;
@ -184,7 +186,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
key_index_ = DocKeyIndex{}; key_index_ = DocKeyIndex{};
indices_.emplace(base_->schema, base_->options, mr); indices_.emplace(base_->schema, base_->options, mr);
auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); }; auto cb = [this](string_view key, const BaseAccessor& doc) {
DocId id = key_index_.Add(key);
if (!indices_->Add(id, doc)) {
key_index_.Remove(key);
}
};
TraverseAllMatching(*base_, op_args, cb); TraverseAllMatching(*base_, op_args, cb);
VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix; VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix;
@ -195,7 +203,10 @@ void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const Prim
return; return;
auto accessor = GetAccessor(db_cntx, pv); auto accessor = GetAccessor(db_cntx, pv);
indices_->Add(key_index_.Add(key), accessor.get()); DocId id = key_index_.Add(key);
if (!indices_->Add(id, *accessor)) {
key_index_.Remove(key);
}
} }
void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) { void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
@ -203,8 +214,10 @@ void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const P
return; return;
auto accessor = GetAccessor(db_cntx, pv); auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key); auto id = key_index_.Remove(key);
indices_->Remove(id, accessor.get()); if (id) {
indices_->Remove(id.value(), *accessor);
}
} }
bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const { bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {

View file

@ -133,7 +133,7 @@ class ShardDocIndex {
// DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface. // DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface.
struct DocKeyIndex { struct DocKeyIndex {
DocId Add(std::string_view key); DocId Add(std::string_view key);
DocId Remove(std::string_view key); std::optional<DocId> Remove(std::string_view key);
std::string_view Get(DocId id) const; std::string_view Get(DocId id) const;
size_t Size() const; size_t Size() const;

View file

@ -90,6 +90,13 @@ template <typename... Args> auto IsArray(Args... args) {
template <typename... Args> auto IsUnordArray(Args... args) { template <typename... Args> auto IsUnordArray(Args... args) {
return RespArray(UnorderedElementsAre(std::forward<Args>(args)...)); return RespArray(UnorderedElementsAre(std::forward<Args>(args)...));
} }
template <typename Expected, size_t... Is>
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
const Expected& expected, std::index_sequence<Is...>) {
std::initializer_list<int>{
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
0)...};
}
MATCHER_P(IsMapMatcher, expected, "") { MATCHER_P(IsMapMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) { if (arg.type != RespExpr::ARRAY) {
@ -97,73 +104,29 @@ MATCHER_P(IsMapMatcher, expected, "") {
return false; return false;
} }
constexpr size_t expected_size = std::tuple_size<decltype(expected)>::value;
constexpr size_t exprected_pairs_number = expected_size / 2;
auto result = arg.GetVec(); auto result = arg.GetVec();
if (result.size() != expected.size()) { if (result.size() != expected_size) {
*result_listener << "Wrong resp array size: " << result.size(); *result_listener << "Wrong resp array size: " << result.size();
return false; return false;
} }
using KeyValueArray = std::vector<std::pair<std::string, std::string>>; std::vector<std::pair<std::string, RespExpr>> received_pairs;
KeyValueArray received_pairs;
for (size_t i = 0; i < result.size(); i += 2) { for (size_t i = 0; i < result.size(); i += 2) {
received_pairs.emplace_back(result[i].GetString(), result[i + 1].GetString()); received_pairs.emplace_back(result[i].GetString(), result[i + 1]);
} }
KeyValueArray expected_pairs; std::vector<Matcher<std::pair<std::string, RespExpr>>> kv_matchers;
for (size_t i = 0; i < expected.size(); i += 2) { BuildKvMatchers(kv_matchers, expected, std::make_index_sequence<exprected_pairs_number>{});
expected_pairs.emplace_back(expected[i], expected[i + 1]);
}
// Custom unordered comparison return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs,
std::sort(received_pairs.begin(), received_pairs.end());
std::sort(expected_pairs.begin(), expected_pairs.end());
return received_pairs == expected_pairs;
}
template <typename... Matchers> auto IsMap(Matchers... matchers) {
return IsMapMatcher(std::vector<std::string>{std::forward<Matchers>(matchers)...});
}
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) {
*result_listener << "Wrong response type: " << arg.type;
return false;
}
auto result = arg.GetVec();
size_t expected_size = std::tuple_size<decltype(expected)>::value;
if (result.size() != expected_size + 1) {
*result_listener << "Wrong resp array size: " << result.size();
return false;
}
if (result[0].GetInt() != expected_size) {
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
return false;
}
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
// Create a vector of matchers from the tuple
std::vector<Matcher<RespExpr>> matchers;
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
result_listener); result_listener);
} }
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) { template <typename... Args> auto IsMap(Args... args) {
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...)); return IsMapMatcher(std::make_tuple(args...));
}
template <typename Expected, size_t... Is>
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
const Expected& expected, std::index_sequence<Is...>) {
std::initializer_list<int>{
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
0)...};
} }
MATCHER_P(IsMapWithSizeMatcher, expected, "") { MATCHER_P(IsMapWithSizeMatcher, expected, "") {
@ -201,6 +164,38 @@ template <typename... Args> auto IsMapWithSize(Args... args) {
return IsMapWithSizeMatcher(std::make_tuple(args...)); return IsMapWithSizeMatcher(std::make_tuple(args...));
} }
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) {
*result_listener << "Wrong response type: " << arg.type;
return false;
}
auto result = arg.GetVec();
size_t expected_size = std::tuple_size<decltype(expected)>::value;
if (result.size() != expected_size + 1) {
*result_listener << "Wrong resp array size: " << result.size();
return false;
}
if (result[0].GetInt() != expected_size) {
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
return false;
}
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
// Create a vector of matchers from the tuple
std::vector<Matcher<RespExpr>> matchers;
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
result_listener);
}
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) {
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...));
}
TEST_F(SearchFamilyTest, CreateDropListIndex) { TEST_F(SearchFamilyTest, CreateDropListIndex) {
EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK"); EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK");
EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK"); EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK");
@ -649,7 +644,7 @@ TEST_F(SearchFamilyTest, TestReturn) {
// Check non-existing field // Check non-existing field
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"}); resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"});
EXPECT_THAT(resp, MatchEntry("k0", "nothere", "")); EXPECT_THAT(resp, MatchEntry("k0"));
// Checl implcit __vector_score is provided // Checl implcit __vector_score is provided
float score = 20; float score = 20;
@ -1194,8 +1189,8 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
IsMap("foo_total", "10", "word", "item1"))); IsMap("foo_total", "10", "word", "item1")));
// Test JSON // Test JSON
Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":"10","text":"first key"})"}); Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":10,"text":"first key"})"});
Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":"20","text":"second key"})"}); Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":20,"text":"second key"})"});
resp = Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.word", "AS", "word", "TAG", "$.foo", resp = Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.word", "AS", "word", "TAG", "$.foo",
"AS", "foo", "NUMERIC", "$.text", "AS", "text", "TEXT"}); "AS", "foo", "NUMERIC", "$.text", "AS", "text", "TEXT"});
@ -1214,4 +1209,220 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
} }
#endif #endif
TEST_F(SearchFamilyTest, WrongFieldTypeJson) {
// Test simple
Run({"JSON.SET", "j1", ".", R"({"value":"one"})"});
Run({"JSON.SET", "j2", ".", R"({"value":1})"});
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC",
"SORTABLE"}),
"OK");
auto resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j2"));
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "$.value"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("$.value", "1")));
// Test with two fields. One is loading
Run({"JSON.SET", "j3", ".", R"({"value":"two","another_value":1})"});
Run({"JSON.SET", "j4", ".", R"({"value":2,"another_value":2})"});
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC"}),
"OK");
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "$.another_value"});
EXPECT_THAT(
resp, IsMapWithSize("j2", IsMap("$", R"({"value":1})"), "j4",
IsMap("$", R"({"another_value":2,"value":2})", "$.another_value", "2")));
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "$.value", "$.another_value", "GROUPBY", "2",
"$.value", "$.another_value", "REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(
IsMap("$.value", "1", "$.another_value", ArgType(RespExpr::NIL), "count", "1"),
IsMap("$.value", "2", "$.another_value", "2", "count", "1")));
// Test multiple field values
Run({"JSON.SET", "j5", ".", R"({"arr":[{"id":1},{"id":"two"}]})"});
Run({"JSON.SET", "j6", ".", R"({"arr":[{"id":1},{"id":2}]})"});
Run({"JSON.SET", "j7", ".", R"({"arr":[]})"});
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
resp = Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC",
"SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
}
TEST_F(SearchFamilyTest, WrongFieldTypeHash) {
// Test simple
Run({"HSET", "h1", "value", "one"});
Run({"HSET", "h2", "value", "1"});
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "value", "NUMERIC", "SORTABLE"}), "OK");
auto resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1")));
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "@value"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("value", "1")));
// Test with two fields. One is loading
Run({"HSET", "h3", "value", "two", "another_value", "1"});
Run({"HSET", "h4", "value", "2", "another_value", "2"});
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "value", "NUMERIC"}), "OK");
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "@another_value"});
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1"), "h4",
IsMap("value", "2", "another_value", "2")));
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "@value", "@another_value", "GROUPBY", "2",
"@value", "@another_value", "REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("value", "1", "another_value", ArgType(RespExpr::NIL), "count", "1"),
IsMap("value", "2", "another_value", "2", "count", "1")));
}
TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) {
Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"});
Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"});
Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"});
Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"});
Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"});
Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"});
auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run(
{"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG"});
EXPECT_EQ(resp, "OK");
resp =
Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i5", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT"});
EXPECT_EQ(resp, "OK");
resp =
Run({"FT.CREATE", "i6", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i7", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "VECTOR", "FLAT",
"6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i2", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
resp = Run({"FT.SEARCH", "i5", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
resp = Run({"FT.SEARCH", "i6", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
resp = Run({"FT.SEARCH", "i7", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j5"));
}
TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) {
Run({"HSET", "j1", "data", "1", "name", "doc_with_int"});
Run({"HSET", "j2", "data", "1", "name", "doc_with_int_as_string"});
Run({"HSET", "j3", "data", "string", "name", "doc_with_string"});
Run({"HSET", "j4", "name", "no_data"});
Run({"HSET", "j5", "data", "5,4,3", "name", "doc_with_fake_vector"});
Run({"HSET", "j6", "data", "[5,4,3]", "name", "doc_with_fake_vector_as_string"});
// Vector [1, 2, 3]
std::string vector = std::string("\x3f\x80\x00\x00\x40\x00\x00\x00\x40\x40\x00\x00", 12);
Run({"HSET", "j7", "data", vector, "name", "doc_with_vector [1, 2, 3]"});
auto resp = Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "data", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "data", "NUMERIC", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i3", "ON", "HASH", "SCHEMA", "data", "TAG"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i4", "ON", "HASH", "SCHEMA", "data", "TAG", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i5", "ON", "HASH", "SCHEMA", "data", "TEXT"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i6", "ON", "HASH", "SCHEMA", "data", "TEXT", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i7", "ON", "HASH", "SCHEMA", "data", "VECTOR", "FLAT", "6", "TYPE",
"FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
resp = Run({"FT.SEARCH", "i2", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i5", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
resp = Run({"FT.SEARCH", "i6", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
resp = Run({"FT.SEARCH", "i7", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j7"));
}
TEST_F(SearchFamilyTest, WrongVectorFieldType) {
Run({"JSON.SET", "j1", ".",
R"({"vector_field": [0.1, 0.2, 0.3], "name": "doc_with_correct_dim"})"});
Run({"JSON.SET", "j2", ".", R"({"vector_field": [0.1, 0.2], "name": "doc_with_small_dim"})"});
Run({"JSON.SET", "j3", ".",
R"({"vector_field": [0.1, 0.2, 0.3, 0.4], "name": "doc_with_large_dim"})"});
Run({"JSON.SET", "j4", ".", R"({"vector_field": [1, 2, 3], "name": "doc_with_int_values"})"});
Run({"JSON.SET", "j5", ".",
R"({"vector_field":"not_vector", "name":"doc_with_incorrect_field_type"})"});
Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"});
Run({"JSON.SET", "j7", ".",
R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"});
auto resp =
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field",
"VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "index", "*"});
EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4"));
}
} // namespace dfly } // namespace dfly