fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands (#4070)

* fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands

fixes  #3986
---------

Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
This commit is contained in:
Stepan Bagritsevich 2024-11-10 15:56:25 +01:00 committed by GitHub
parent f745f3133d
commit 503bb4ed33
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 686 additions and 219 deletions

View file

@ -4,6 +4,8 @@
#include "core/search/base.h"
#include <absl/strings/numbers.h>
namespace dfly::search {
std::string_view QueryParams::operator[](std::string_view name) const {
@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const {
return std::string_view{ptr.get(), std::strlen(ptr.get())};
}
std::optional<double> ParseNumericField(std::string_view value) {
double value_as_double;
if (absl::SimpleAtod(value, &value_as_double))
return value_as_double;
return std::nullopt;
}
} // namespace dfly::search

View file

@ -68,11 +68,18 @@ using SortableValue = std::variant<std::monostate, double, std::string>;
struct DocumentAccessor {
using VectorInfo = search::OwnedFtVector;
using StringList = absl::InlinedVector<std::string_view, 1>;
using NumsList = absl::InlinedVector<double, 1>;
virtual ~DocumentAccessor() = default;
virtual StringList GetStrings(std::string_view active_field) const = 0;
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
/* Returns nullopt if the specified field is not a list of strings */
virtual std::optional<StringList> GetStrings(std::string_view active_field) const = 0;
/* Returns nullopt if the specified field is not a vector */
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const = 0;
/* Return nullopt if the specified field is not a list of doubles */
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
};
// Base class for type-specific indices.
@ -81,8 +88,10 @@ struct DocumentAccessor {
// query functions. All results for all index types should be sorted.
struct BaseIndex {
virtual ~BaseIndex() = default;
virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
// Returns true if the document was added / indexed
virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
};
// Base class for type-specific sorting indices.
@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex {
virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
};
/* Used for converting field values to double. Returns std::nullopt if the conversion fails */
std::optional<double> ParseNumericField(std::string_view value);
/* Temporary method to create an empty std::optional<InlinedVector> in DocumentAccessor::GetString
and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation
details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. To
suppress this false warning, we temporarily disable it around this block of code using GCC
diagnostic directives. */
template <typename InlinedVector> std::optional<InlinedVector> EmptyAccessResult() {
// GCC 13.1 throws spurious warnings around this code.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
return InlinedVector{};
#pragma GCC diagnostic pop
}
} // namespace dfly::search

View file

@ -71,19 +71,22 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
}
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.emplace(num, id);
bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto numbers = doc.GetNumbers(field);
if (!numbers) {
return false;
}
for (auto num : numbers.value()) {
entries_.emplace(num, id);
}
return true;
}
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
for (auto str : doc->GetStrings(field)) {
double num;
if (absl::SimpleAtod(str, &num))
entries_.erase({num, id});
void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto numbers = doc.GetNumbers(field).value();
for (auto num : numbers) {
entries_.erase({num, id});
}
}
@ -139,19 +142,27 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
}
template <typename C>
void BaseStringIndex<C>::Add(DocId id, DocumentAccessor* doc, string_view field) {
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field);
if (!strings_list) {
return false;
}
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
for (string_view str : strings_list.value())
tokens.merge(Tokenize(str));
for (string_view token : tokens)
GetOrCreate(token)->Insert(id);
return true;
}
template <typename C>
void BaseStringIndex<C>::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
auto strings_list = doc.GetStrings(field).value();
absl::flat_hash_set<std::string> tokens;
for (string_view str : doc->GetStrings(field))
for (string_view str : strings_list)
tokens.merge(Tokenize(str));
for (const auto& token : tokens) {
@ -192,6 +203,20 @@ std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
return {dim_, sim_};
}
bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto vector = doc.GetVector(field);
if (!vector)
return false;
auto& [ptr, size] = vector.value();
if (ptr && size != dim_) {
return false;
}
AddVector(id, ptr);
return true;
}
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
PMR_NS::memory_resource* mr)
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
@ -199,19 +224,18 @@ FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
entries_.reserve(params.capacity * params.dim);
}
void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
DCHECK_LE(id * dim_, entries_.size());
if (id * dim_ == entries_.size())
entries_.resize((id + 1) * dim_);
// TODO: Let get vector write to buf itself
auto [ptr, size] = doc->GetVector(field);
if (size == dim_)
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
if (vector) {
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
}
}
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
// noop
}
@ -229,7 +253,7 @@ struct HnswlibAdapter {
100 /* seed*/} {
}
void Add(float* data, DocId id) {
void Add(const float* data, DocId id) {
if (world_.cur_element_count + 1 >= world_.max_elements_)
world_.resizeIndex(world_.cur_element_count * 2);
world_.addPoint(data, id);
@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS
HnswVectorIndex::~HnswVectorIndex() {
}
void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
auto [ptr, size] = doc->GetVector(field);
if (size == dim_)
adapter_->Add(ptr.get(), id);
void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
if (vector) {
adapter_->Add(vector.get(), id);
}
}
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
@ -314,7 +338,7 @@ std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t
return adapter_->Knn(target, k, ef, allowed);
}
void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
adapter_->Remove(id);
}

View file

@ -28,8 +28,8 @@ namespace dfly::search {
struct NumericIndex : public BaseIndex {
explicit NumericIndex(PMR_NS::memory_resource* mr);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
std::vector<DocId> Range(double l, double r) const;
@ -44,8 +44,8 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
// Used by Add & Remove to tokenize text value
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
@ -53,7 +53,7 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
// Pointer is valid as long as index is not mutated. Nullptr if not found
const Container* Matching(std::string_view str) const;
// Iterate over all Machting on prefix.
// Iterate over all Matching on prefix.
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
// Returns all the terms that appear as keys in the reverse index.
@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
struct BaseVectorIndex : public BaseIndex {
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;
protected:
BaseVectorIndex(size_t dim, VectorSimilarity sim);
using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;
size_t dim_;
VectorSimilarity sim_;
};
@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex {
struct FlatVectorIndex : public BaseVectorIndex {
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
const float* Get(DocId doc) const;
protected:
void AddVector(DocId id, const VectorPtr& vector) override;
private:
PMR_NS::vector<float> entries_;
};
@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex {
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
~HnswVectorIndex();
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
const std::vector<DocId>& allowed) const;
protected:
void AddVector(DocId id, const VectorPtr& vector) override;
private:
std::unique_ptr<HnswlibAdapter> adapter_;
};

View file

@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
}
}
void FieldIndices::Add(DocId doc, DocumentAccessor* access) {
for (auto& [field, index] : indices_)
index->Add(doc, access, field);
for (auto& [field, sort_index] : sort_indices_)
sort_index->Add(doc, access, field);
bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) {
bool was_added = true;
std::vector<std::pair<std::string_view, BaseIndex*>> successfully_added_indices;
successfully_added_indices.reserve(indices_.size() + sort_indices_.size());
auto try_add = [&](const auto& indices_container) {
for (auto& [field, index] : indices_container) {
if (index->Add(doc, access, field)) {
successfully_added_indices.emplace_back(field, index.get());
} else {
was_added = false;
break;
}
}
};
try_add(indices_);
if (was_added) {
try_add(sort_indices_);
}
if (!was_added) {
for (auto& [field, index] : successfully_added_indices) {
index->Remove(doc, access, field);
}
return false;
}
all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc);
return true;
}
void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) {
for (auto& [field, index] : indices_)
index->Remove(doc, access, field);
for (auto& [field, sort_index] : sort_indices_)
sort_index->Remove(doc, access, field);
auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc);
CHECK(it != all_ids_.end() && *it == doc);
DCHECK(it != all_ids_.end() && *it == doc);
all_ids_.erase(it);
}

View file

@ -77,8 +77,9 @@ class FieldIndices {
// Create indices based on schema and options. Both must outlive the indices
FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);
void Add(DocId doc, DocumentAccessor* access);
void Remove(DocId doc, DocumentAccessor* access);
// Returns true if document was added
bool Add(DocId doc, const DocumentAccessor& access);
void Remove(DocId doc, const DocumentAccessor& access);
BaseIndex* GetIndex(std::string_view field) const;
BaseSortIndex* GetSortIndex(std::string_view field) const;

View file

@ -44,13 +44,36 @@ struct MockedDocument : public DocumentAccessor {
MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
}
StringList GetStrings(string_view field) const override {
std::optional<StringList> GetStrings(string_view field) const override {
auto it = fields_.find(field);
return {it != fields_.end() ? string_view{it->second} : ""};
if (it == fields_.end()) {
return EmptyAccessResult<StringList>();
}
return StringList{string_view{it->second}};
}
VectorInfo GetVector(string_view field) const override {
return BytesToFtVector(GetStrings(field).front());
std::optional<VectorInfo> GetVector(string_view field) const override {
auto strings_list = GetStrings(field);
if (!strings_list)
return std::nullopt;
return !strings_list->empty() ? BytesToFtVectorSafe(strings_list->front()) : VectorInfo{};
}
std::optional<NumsList> GetNumbers(std::string_view field) const override {
auto strings_list = GetStrings(field);
if (!strings_list)
return std::nullopt;
NumsList nums_list;
nums_list.reserve(strings_list->size());
for (auto str : strings_list.value()) {
auto num = ParseNumericField(str);
if (!num) {
return std::nullopt;
}
nums_list.push_back(num.value());
}
return nums_list;
}
string DebugFormat() {
@ -121,7 +144,7 @@ class SearchTest : public ::testing::Test {
shuffle(entries_.begin(), entries_.end(), default_random_engine{});
for (DocId i = 0; i < entries_.size(); i++)
index.Add(i, &entries_[i].first);
index.Add(i, entries_[i].first);
SearchAlgorithm search_algo{};
if (!search_algo.Init(query_, &params_)) {
@ -430,7 +453,7 @@ TEST_F(SearchTest, StopWords) {
"explicitly found!"};
for (size_t i = 0; i < documents.size(); i++) {
MockedDocument doc{{{"title", documents[i]}}};
indices.Add(i, &doc);
indices.Add(i, doc);
}
// words is a stopword
@ -484,7 +507,7 @@ TEST_P(KnnTest, Simple1D) {
for (size_t i = 0; i < 100; i++) {
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
MockedDocument doc{values};
indices.Add(i, &doc);
indices.Add(i, doc);
}
SearchAlgorithm algo{};
@ -540,7 +563,7 @@ TEST_P(KnnTest, Simple2D) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
MockedDocument doc{Map{{"pos", coords}}};
indices.Add(i, &doc);
indices.Add(i, doc);
}
SearchAlgorithm algo{};
@ -602,7 +625,7 @@ TEST_P(KnnTest, Cosine) {
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
MockedDocument doc{Map{{"pos", coords}}};
indices.Add(i, &doc);
indices.Add(i, doc);
}
SearchAlgorithm algo{};
@ -646,7 +669,7 @@ TEST_P(KnnTest, AddRemove) {
vector<MockedDocument> documents(10);
for (size_t i = 0; i < 10; i++) {
documents[i] = Map{{"pos", ToBytes({float(i)})}};
indices.Add(i, &documents[i]);
indices.Add(i, documents[i]);
}
SearchAlgorithm algo{};
@ -661,7 +684,7 @@ TEST_P(KnnTest, AddRemove) {
// delete leftmost 5
for (size_t i = 0; i < 5; i++)
indices.Remove(i, &documents[i]);
indices.Remove(i, documents[i]);
// search leftmost 5 again
{
@ -672,7 +695,7 @@ TEST_P(KnnTest, AddRemove) {
// add removed elements
for (size_t i = 0; i < 5; i++)
indices.Add(i, &documents[i]);
indices.Add(i, documents[i]);
// repeat first search
{
@ -693,7 +716,7 @@ TEST_P(KnnTest, AutoResize) {
for (size_t i = 0; i < 100; i++) {
MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}};
indices.Add(i, &doc);
indices.Add(i, doc);
}
EXPECT_EQ(indices.GetAllDocs().size(), 100);
@ -720,7 +743,7 @@ static void BM_VectorSearch(benchmark::State& state) {
for (size_t i = 0; i < nvecs; i++) {
auto rv = random_vec();
MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
indices.Add(i, &doc);
indices.Add(i, doc);
}
SearchAlgorithm algo{};

View file

@ -46,15 +46,23 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
}
template <typename T>
void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_view field) {
bool SimpleValueSortIndex<T>::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
auto field_value = Get(doc, field);
if (!field_value) {
return false;
}
DCHECK_LE(id, values_.size()); // Doc ids grow at most by one
if (id >= values_.size())
values_.resize(id + 1);
values_[id] = Get(id, doc, field);
values_[id] = field_value.value();
return true;
}
template <typename T>
void SimpleValueSortIndex<T>::Remove(DocId id, DocumentAccessor* doc, std::string_view field) {
void SimpleValueSortIndex<T>::Remove(DocId id, const DocumentAccessor& doc,
std::string_view field) {
DCHECK_LT(id, values_.size());
values_[id] = T{};
}
@ -66,23 +74,22 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;
double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
auto str = doc->GetStrings(field);
if (str.empty())
return 0;
double v;
if (!absl::SimpleAtod(str.front(), &v))
return 0;
return v;
std::optional<double> NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) {
auto numbers_list = doc.GetNumbers(field);
if (!numbers_list) {
return std::nullopt;
}
return !numbers_list->empty() ? numbers_list->front() : 0.0;
}
PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
auto str = doc->GetStrings(field);
if (str.empty())
return "";
return PMR_NS::string{str.front(), GetMemRes()};
std::optional<PMR_NS::string> StringSortIndex::Get(const DocumentAccessor& doc,
std::string_view field) {
auto strings_list = doc.GetStrings(field);
if (!strings_list) {
return std::nullopt;
}
return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()}
: PMR_NS::string{GetMemRes()};
}
} // namespace dfly::search

View file

@ -24,11 +24,11 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
SortableValue Lookup(DocId doc) const override;
std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override;
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
protected:
virtual T Get(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
virtual std::optional<T> Get(const DocumentAccessor& doc, std::string_view field_value) = 0;
PMR_NS::memory_resource* GetMemRes() const;
@ -39,14 +39,14 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
struct NumericSortIndex : public SimpleValueSortIndex<double> {
NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
double Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
std::optional<double> Get(const DocumentAccessor& doc, std::string_view field) override;
};
// TODO: Map tags to integers for fast sort
struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> {
StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};
PMR_NS::string Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
std::optional<PMR_NS::string> Get(const DocumentAccessor& doc, std::string_view field) override;
};
} // namespace dfly::search

View file

@ -39,18 +39,28 @@ __attribute__((optimize("fast-math"))) float CosineDistance(const float* u, cons
return 0.0f;
}
} // namespace
OwnedFtVector BytesToFtVector(string_view value) {
DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
OwnedFtVector ConvertToFtVector(string_view value) {
// Value cannot be casted directly as it might be not aligned as a float (4 bytes).
// Misaligned memory access is UB.
size_t size = value.size() / sizeof(float);
auto out = make_unique<float[]>(size);
memcpy(out.get(), value.data(), size * sizeof(float));
return {std::move(out), size};
return OwnedFtVector{std::move(out), size};
}
} // namespace
OwnedFtVector BytesToFtVector(string_view value) {
DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
return ConvertToFtVector(value);
}
std::optional<OwnedFtVector> BytesToFtVectorSafe(string_view value) {
if (value.size() % sizeof(float)) {
return std::nullopt;
}
return ConvertToFtVector(value);
}
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {

View file

@ -10,6 +10,10 @@ namespace dfly::search {
OwnedFtVector BytesToFtVector(std::string_view value);
// Returns std::nullopt if value can not be converted to the vector
// TODO: Remove unsafe version
std::optional<OwnedFtVector> BytesToFtVectorSafe(std::string_view value);
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);
} // namespace dfly::search

View file

@ -38,43 +38,44 @@ string_view SdsToSafeSv(sds str) {
return str != nullptr ? string_view{str, sdslen(str)} : ""sv;
}
search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) {
using FieldValue = std::optional<search::SortableValue>;
FieldValue ToSortableValue(search::SchemaField::FieldType type, string_view value) {
if (value.empty()) {
return std::nullopt;
}
if (type == search::SchemaField::NUMERIC) {
double value_as_double = 0;
if (!absl::SimpleAtod(value, &value_as_double)) { // temporary convert to double
auto value_as_double = search::ParseNumericField(value);
if (!value_as_double) { // temporary convert to double
LOG(DFATAL) << "Failed to convert " << value << " to double";
return std::nullopt;
}
return value_as_double;
return value_as_double.value();
}
if (type == search::SchemaField::VECTOR) {
auto [ptr, size] = search::BytesToFtVector(value);
auto opt_vector = search::BytesToFtVectorSafe(value);
if (!opt_vector) {
LOG(DFATAL) << "Failed to convert " << value << " to vector";
return std::nullopt;
}
auto& [ptr, size] = opt_vector.value();
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
}
return string{value};
}
search::SortableValue JsonToSortableValue(const search::SchemaField::FieldType type,
const JsonType& json) {
if (type == search::SchemaField::NUMERIC) {
return json.as_double();
}
return json.to_string();
}
search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key,
string_view value) {
FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, string_view value) {
auto it = schema.fields.find(key);
if (it == schema.fields.end())
return FieldToSortableValue(search::SchemaField::TEXT, value);
return FieldToSortableValue(it->second.type, value);
return ToSortableValue(search::SchemaField::TEXT, value);
return ToSortableValue(it->second.type, value);
}
search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
const JsonType& json) {
auto it = schema.fields.find(key);
if (it == schema.fields.end())
return JsonToSortableValue(search::SchemaField::TEXT, json);
return JsonToSortableValue(it->second.type, json);
FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
const JsonType& json) {
auto json_as_string = json.to_string();
return ExtractSortableValue(schema, key, json_as_string);
}
} // namespace
@ -83,7 +84,11 @@ SearchDocData BaseAccessor::Serialize(
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
SearchDocData out{};
for (const auto& [fident, fname] : fields) {
out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ","));
auto field_value =
ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ","));
if (field_value) {
out[fname] = std::move(field_value).value();
}
}
return out;
}
@ -92,14 +97,39 @@ SearchDocData BaseAccessor::SerializeDocument(const search::Schema& schema) cons
return Serialize(schema);
}
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
return strsv.has_value() ? StringList{*strsv} : StringList{};
std::optional<BaseAccessor::VectorInfo> BaseAccessor::GetVector(
std::string_view active_field) const {
auto strings_list = GetStrings(active_field);
if (strings_list) {
return !strings_list->empty() ? search::BytesToFtVectorSafe(strings_list->front())
: VectorInfo{};
}
return std::nullopt;
}
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
std::optional<BaseAccessor::NumsList> BaseAccessor::GetNumbers(
std::string_view active_field) const {
auto strings_list = GetStrings(active_field);
if (!strings_list) {
return std::nullopt;
}
NumsList nums_list;
nums_list.reserve(strings_list->size());
for (auto str : strings_list.value()) {
auto num = search::ParseNumericField(str);
if (!num) {
return std::nullopt;
}
nums_list.push_back(num.value());
}
return nums_list;
}
std::optional<BaseAccessor::StringList> ListPackAccessor::GetStrings(
string_view active_field) const {
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
return strsv.has_value() ? StringList{*strsv} : StringList{};
}
SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
@ -114,27 +144,29 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
string_view v = container_utils::LpGetView(fptr, intbuf_[1].data());
fptr = lpNext(lp_, fptr);
out[k] = ExtractSortableValue(schema, k, v);
auto field_value = ExtractSortableValue(schema, k, v);
if (field_value) {
out[k] = std::move(field_value).value();
}
}
return out;
}
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const {
std::optional<BaseAccessor::StringList> StringMapAccessor::GetStrings(
string_view active_field) const {
auto it = hset_->Find(active_field);
return it != hset_->end() ? StringList{SdsToSafeSv(it->second)} : StringList{};
}
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
auto strlist = GetStrings(active_field);
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
}
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
SearchDocData out{};
for (const auto& [kptr, vptr] : *hset_)
out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
for (const auto& [kptr, vptr] : *hset_) {
auto field_value = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
if (field_value) {
out[SdsToSafeSv(kptr)] = std::move(field_value).value();
}
}
return out;
}
@ -159,27 +191,54 @@ struct JsonAccessor::JsonPathContainer {
variant<json::Path, jsoncons::jsonpath::jsonpath_expression<JsonType>> val;
};
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const {
std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view active_field) const {
auto* path = GetPath(active_field);
if (!path)
return {};
return search::EmptyAccessResult<StringList>();
auto path_res = path->Evaluate(json_);
if (path_res.empty())
return {};
return search::EmptyAccessResult<StringList>();
if (path_res.size() == 1 && !path_res[0].is_array()) {
if (!path_res[0].is_string())
return std::nullopt;
if (path_res.size() == 1) {
buf_ = path_res[0].as_string();
return {buf_};
return StringList{buf_};
}
buf_.clear();
// First, grow buffer and compute string sizes
vector<size_t> sizes;
for (const auto& element : path_res) {
auto add_json_to_buf = [&](const JsonType& json) {
size_t start = buf_.size();
buf_ += element.as_string();
buf_ += json.as_string();
sizes.push_back(buf_.size() - start);
};
if (!path_res[0].is_array()) {
sizes.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
sizes.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_string())
return std::nullopt;
add_json_to_buf(element);
}
}
// Reposition start pointers to the most recent allocation of buf
@ -194,23 +253,62 @@ BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) cons
return out;
}
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
std::optional<BaseAccessor::VectorInfo> JsonAccessor::GetVector(string_view active_field) const {
auto* path = GetPath(active_field);
if (!path)
return {};
return VectorInfo{};
auto res = path->Evaluate(json_);
if (res.empty())
return {nullptr, 0};
return VectorInfo{};
if (!res[0].is_array())
return std::nullopt;
size_t size = res[0].size();
auto ptr = make_unique<float[]>(size);
size_t i = 0;
for (const auto& v : res[0].array_range())
for (const auto& v : res[0].array_range()) {
if (!v.is_number()) {
return std::nullopt;
}
ptr[i++] = v.as<float>();
}
return {std::move(ptr), size};
return BaseAccessor::VectorInfo{std::move(ptr), size};
}
std::optional<BaseAccessor::NumsList> JsonAccessor::GetNumbers(string_view active_field) const {
auto* path = GetPath(active_field);
if (!path)
return search::EmptyAccessResult<NumsList>();
auto path_res = path->Evaluate(json_);
if (path_res.empty())
return search::EmptyAccessResult<NumsList>();
NumsList nums_list;
if (!path_res[0].is_array()) {
nums_list.reserve(path_res.size());
for (const auto& element : path_res) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
} else {
if (path_res.size() > 1) {
return std::nullopt;
}
nums_list.reserve(path_res[0].size());
for (const auto& element : path_res[0].array_range()) {
if (!element.is_number())
return std::nullopt;
nums_list.push_back(element.as<double>());
}
}
return nums_list;
}
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
@ -259,8 +357,12 @@ SearchDocData JsonAccessor::Serialize(
SearchDocData out{};
for (const auto& [ident, name] : fields) {
if (auto* path = GetPath(ident); path) {
if (auto res = path->Evaluate(json_); !res.empty())
out[name] = ExtractSortableValueFromJson(schema, ident, res[0]);
if (auto res = path->Evaluate(json_); !res.empty()) {
auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]);
if (field_value) {
out[name] = std::move(field_value).value();
}
}
}
}
return out;

View file

@ -12,6 +12,7 @@
#include "core/json/json_object.h"
#include "core/search/search.h"
#include "core/search/vector_utils.h"
#include "server/common.h"
#include "server/search/doc_index.h"
#include "server/table.h"
@ -37,6 +38,10 @@ struct BaseAccessor : public search::DocumentAccessor {
indexed field
*/
virtual SearchDocData SerializeDocument(const search::Schema& schema) const;
// Default implementation uses GetStrings
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const;
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const;
};
// Accessor for hashes stored with listpack
@ -46,8 +51,7 @@ struct ListPackAccessor : public BaseAccessor {
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
}
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
std::optional<StringList> GetStrings(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
private:
@ -60,8 +64,7 @@ struct StringMapAccessor : public BaseAccessor {
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
}
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
std::optional<StringList> GetStrings(std::string_view field) const override;
SearchDocData Serialize(const search::Schema& schema) const override;
private:
@ -75,8 +78,9 @@ struct JsonAccessor : public BaseAccessor {
explicit JsonAccessor(const JsonType* json) : json_{*json} {
}
StringList GetStrings(std::string_view field) const override;
VectorInfo GetVector(std::string_view field) const override;
std::optional<StringList> GetStrings(std::string_view field) const override;
std::optional<VectorInfo> GetVector(std::string_view field) const override;
std::optional<NumsList> GetNumbers(std::string_view active_field) const override;
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
SearchDocData Serialize(const search::Schema& schema,

View file

@ -41,7 +41,7 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
return;
auto accessor = GetAccessor(op_args.db_cntx, pv);
f(key, accessor.get());
f(key, *accessor);
};
PrimeTable::Cursor cursor;
@ -146,12 +146,14 @@ ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Add(string_view key) {
return id;
}
ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Remove(string_view key) {
DCHECK_GT(ids_.count(key), 0u);
std::optional<ShardDocIndex::DocId> ShardDocIndex::DocKeyIndex::Remove(string_view key) {
auto it = ids_.extract(key);
if (!it) {
return std::nullopt;
}
DocId id = ids_.find(key)->second;
const DocId id = it.mapped();
keys_[id] = "";
ids_.erase(key);
free_ids_.push_back(id);
return id;
@ -184,7 +186,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
key_index_ = DocKeyIndex{};
indices_.emplace(base_->schema, base_->options, mr);
auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); };
auto cb = [this](string_view key, const BaseAccessor& doc) {
DocId id = key_index_.Add(key);
if (!indices_->Add(id, doc)) {
key_index_.Remove(key);
}
};
TraverseAllMatching(*base_, op_args, cb);
VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix;
@ -195,7 +203,10 @@ void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const Prim
return;
auto accessor = GetAccessor(db_cntx, pv);
indices_->Add(key_index_.Add(key), accessor.get());
DocId id = key_index_.Add(key);
if (!indices_->Add(id, *accessor)) {
key_index_.Remove(key);
}
}
void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {
@ -203,8 +214,10 @@ void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const P
return;
auto accessor = GetAccessor(db_cntx, pv);
DocId id = key_index_.Remove(key);
indices_->Remove(id, accessor.get());
auto id = key_index_.Remove(key);
if (id) {
indices_->Remove(id.value(), *accessor);
}
}
bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {

View file

@ -133,7 +133,7 @@ class ShardDocIndex {
// DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface.
struct DocKeyIndex {
DocId Add(std::string_view key);
DocId Remove(std::string_view key);
std::optional<DocId> Remove(std::string_view key);
std::string_view Get(DocId id) const;
size_t Size() const;

View file

@ -90,6 +90,13 @@ template <typename... Args> auto IsArray(Args... args) {
template <typename... Args> auto IsUnordArray(Args... args) {
return RespArray(UnorderedElementsAre(std::forward<Args>(args)...));
}
template <typename Expected, size_t... Is>
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
const Expected& expected, std::index_sequence<Is...>) {
std::initializer_list<int>{
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
0)...};
}
MATCHER_P(IsMapMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) {
@ -97,73 +104,29 @@ MATCHER_P(IsMapMatcher, expected, "") {
return false;
}
constexpr size_t expected_size = std::tuple_size<decltype(expected)>::value;
constexpr size_t exprected_pairs_number = expected_size / 2;
auto result = arg.GetVec();
if (result.size() != expected.size()) {
if (result.size() != expected_size) {
*result_listener << "Wrong resp array size: " << result.size();
return false;
}
using KeyValueArray = std::vector<std::pair<std::string, std::string>>;
KeyValueArray received_pairs;
std::vector<std::pair<std::string, RespExpr>> received_pairs;
for (size_t i = 0; i < result.size(); i += 2) {
received_pairs.emplace_back(result[i].GetString(), result[i + 1].GetString());
received_pairs.emplace_back(result[i].GetString(), result[i + 1]);
}
KeyValueArray expected_pairs;
for (size_t i = 0; i < expected.size(); i += 2) {
expected_pairs.emplace_back(expected[i], expected[i + 1]);
}
std::vector<Matcher<std::pair<std::string, RespExpr>>> kv_matchers;
BuildKvMatchers(kv_matchers, expected, std::make_index_sequence<exprected_pairs_number>{});
// Custom unordered comparison
std::sort(received_pairs.begin(), received_pairs.end());
std::sort(expected_pairs.begin(), expected_pairs.end());
return received_pairs == expected_pairs;
}
template <typename... Matchers> auto IsMap(Matchers... matchers) {
return IsMapMatcher(std::vector<std::string>{std::forward<Matchers>(matchers)...});
}
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) {
*result_listener << "Wrong response type: " << arg.type;
return false;
}
auto result = arg.GetVec();
size_t expected_size = std::tuple_size<decltype(expected)>::value;
if (result.size() != expected_size + 1) {
*result_listener << "Wrong resp array size: " << result.size();
return false;
}
if (result[0].GetInt() != expected_size) {
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
return false;
}
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
// Create a vector of matchers from the tuple
std::vector<Matcher<RespExpr>> matchers;
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs,
result_listener);
}
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) {
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...));
}
template <typename Expected, size_t... Is>
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
const Expected& expected, std::index_sequence<Is...>) {
std::initializer_list<int>{
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
0)...};
template <typename... Args> auto IsMap(Args... args) {
return IsMapMatcher(std::make_tuple(args...));
}
MATCHER_P(IsMapWithSizeMatcher, expected, "") {
@ -201,6 +164,38 @@ template <typename... Args> auto IsMapWithSize(Args... args) {
return IsMapWithSizeMatcher(std::make_tuple(args...));
}
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
if (arg.type != RespExpr::ARRAY) {
*result_listener << "Wrong response type: " << arg.type;
return false;
}
auto result = arg.GetVec();
size_t expected_size = std::tuple_size<decltype(expected)>::value;
if (result.size() != expected_size + 1) {
*result_listener << "Wrong resp array size: " << result.size();
return false;
}
if (result[0].GetInt() != expected_size) {
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
return false;
}
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
// Create a vector of matchers from the tuple
std::vector<Matcher<RespExpr>> matchers;
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
result_listener);
}
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) {
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...));
}
TEST_F(SearchFamilyTest, CreateDropListIndex) {
EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK");
EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK");
@ -649,7 +644,7 @@ TEST_F(SearchFamilyTest, TestReturn) {
// Check non-existing field
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"});
EXPECT_THAT(resp, MatchEntry("k0", "nothere", ""));
EXPECT_THAT(resp, MatchEntry("k0"));
// Checl implcit __vector_score is provided
float score = 20;
@ -1194,8 +1189,8 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
IsMap("foo_total", "10", "word", "item1")));
// Test JSON
Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":"10","text":"first key"})"});
Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":"20","text":"second key"})"});
Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":10,"text":"first key"})"});
Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":20,"text":"second key"})"});
resp = Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.word", "AS", "word", "TAG", "$.foo",
"AS", "foo", "NUMERIC", "$.text", "AS", "text", "TEXT"});
@ -1214,4 +1209,220 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
}
#endif
TEST_F(SearchFamilyTest, WrongFieldTypeJson) {
// Test simple
Run({"JSON.SET", "j1", ".", R"({"value":"one"})"});
Run({"JSON.SET", "j2", ".", R"({"value":1})"});
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC",
"SORTABLE"}),
"OK");
auto resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j2"));
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "$.value"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("$.value", "1")));
// Test with two fields. One is loading
Run({"JSON.SET", "j3", ".", R"({"value":"two","another_value":1})"});
Run({"JSON.SET", "j4", ".", R"({"value":2,"another_value":2})"});
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC"}),
"OK");
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "$.another_value"});
EXPECT_THAT(
resp, IsMapWithSize("j2", IsMap("$", R"({"value":1})"), "j4",
IsMap("$", R"({"another_value":2,"value":2})", "$.another_value", "2")));
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "$.value", "$.another_value", "GROUPBY", "2",
"$.value", "$.another_value", "REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(
IsMap("$.value", "1", "$.another_value", ArgType(RespExpr::NIL), "count", "1"),
IsMap("$.value", "2", "$.another_value", "2", "count", "1")));
// Test multiple field values
Run({"JSON.SET", "j5", ".", R"({"arr":[{"id":1},{"id":"two"}]})"});
Run({"JSON.SET", "j6", ".", R"({"arr":[{"id":1},{"id":2}]})"});
Run({"JSON.SET", "j7", ".", R"({"arr":[]})"});
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
resp = Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC",
"SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
}
TEST_F(SearchFamilyTest, WrongFieldTypeHash) {
// Test simple
Run({"HSET", "h1", "value", "one"});
Run({"HSET", "h2", "value", "1"});
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "value", "NUMERIC", "SORTABLE"}), "OK");
auto resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1")));
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "@value"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("value", "1")));
// Test with two fields. One is loading
Run({"HSET", "h3", "value", "two", "another_value", "1"});
Run({"HSET", "h4", "value", "2", "another_value", "2"});
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "value", "NUMERIC"}), "OK");
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "@another_value"});
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1"), "h4",
IsMap("value", "2", "another_value", "2")));
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "@value", "@another_value", "GROUPBY", "2",
"@value", "@another_value", "REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("value", "1", "another_value", ArgType(RespExpr::NIL), "count", "1"),
IsMap("value", "2", "another_value", "2", "count", "1")));
}
TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) {
Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"});
Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"});
Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"});
Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"});
Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"});
Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"});
auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run(
{"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG"});
EXPECT_EQ(resp, "OK");
resp =
Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i5", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT"});
EXPECT_EQ(resp, "OK");
resp =
Run({"FT.CREATE", "i6", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i7", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "VECTOR", "FLAT",
"6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i2", "*"});
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
resp = Run({"FT.SEARCH", "i5", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
resp = Run({"FT.SEARCH", "i6", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
resp = Run({"FT.SEARCH", "i7", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j5"));
}
TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) {
Run({"HSET", "j1", "data", "1", "name", "doc_with_int"});
Run({"HSET", "j2", "data", "1", "name", "doc_with_int_as_string"});
Run({"HSET", "j3", "data", "string", "name", "doc_with_string"});
Run({"HSET", "j4", "name", "no_data"});
Run({"HSET", "j5", "data", "5,4,3", "name", "doc_with_fake_vector"});
Run({"HSET", "j6", "data", "[5,4,3]", "name", "doc_with_fake_vector_as_string"});
// Vector [1, 2, 3]
std::string vector = std::string("\x3f\x80\x00\x00\x40\x00\x00\x00\x40\x40\x00\x00", 12);
Run({"HSET", "j7", "data", vector, "name", "doc_with_vector [1, 2, 3]"});
auto resp = Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "data", "NUMERIC"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "data", "NUMERIC", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i3", "ON", "HASH", "SCHEMA", "data", "TAG"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i4", "ON", "HASH", "SCHEMA", "data", "TAG", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i5", "ON", "HASH", "SCHEMA", "data", "TEXT"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i6", "ON", "HASH", "SCHEMA", "data", "TEXT", "SORTABLE"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.CREATE", "i7", "ON", "HASH", "SCHEMA", "data", "VECTOR", "FLAT", "6", "TYPE",
"FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
resp = Run({"FT.SEARCH", "i2", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
resp = Run({"FT.SEARCH", "i3", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i4", "*"});
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
resp = Run({"FT.SEARCH", "i5", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
resp = Run({"FT.SEARCH", "i6", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
resp = Run({"FT.SEARCH", "i7", "*"});
EXPECT_THAT(resp, AreDocIds("j4", "j7"));
}
TEST_F(SearchFamilyTest, WrongVectorFieldType) {
Run({"JSON.SET", "j1", ".",
R"({"vector_field": [0.1, 0.2, 0.3], "name": "doc_with_correct_dim"})"});
Run({"JSON.SET", "j2", ".", R"({"vector_field": [0.1, 0.2], "name": "doc_with_small_dim"})"});
Run({"JSON.SET", "j3", ".",
R"({"vector_field": [0.1, 0.2, 0.3, 0.4], "name": "doc_with_large_dim"})"});
Run({"JSON.SET", "j4", ".", R"({"vector_field": [1, 2, 3], "name": "doc_with_int_values"})"});
Run({"JSON.SET", "j5", ".",
R"({"vector_field":"not_vector", "name":"doc_with_incorrect_field_type"})"});
Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"});
Run({"JSON.SET", "j7", ".",
R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"});
auto resp =
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field",
"VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
EXPECT_EQ(resp, "OK");
resp = Run({"FT.SEARCH", "index", "*"});
EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4"));
}
} // namespace dfly