Mirror of https://github.com/dragonflydb/dragonfly.git (synced 2025-05-10 18:05:44 +02:00)
fix(search_family): Process wrong field types in indexes for the FT.SEARCH and FT.AGGREGATE commands (#4070)
Fixes #3986. Signed-off-by: Stepan Bagritsevich <stefan@dragonflydb.io>
Parent: f745f3133d
Commit: 503bb4ed33
16 changed files with 686 additions and 219 deletions
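At a high level, the commit makes the DocumentAccessor getters return std::optional and makes the per-field index Add methods return bool, so a document whose field value cannot be parsed as the declared index type (for example, a text value in a NUMERIC field) is skipped instead of being indexed with a bogus value. A condensed restatement of the new SearchFamilyTest.WrongFieldTypeHash test from this commit shows the user-visible effect (the Run, IsMapWithSize and IsMap helpers come from the test file further below):

// "one" cannot be parsed as a number, "1" can
Run({"HSET", "h1", "value", "one"});
Run({"HSET", "h2", "value", "1"});
Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "value", "NUMERIC", "SORTABLE"});

// Only h2 ends up in the index; h1 is skipped because its field has the wrong type.
auto resp = Run({"FT.SEARCH", "i1", "*"});
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1")));

Internally, FieldIndices::Add (changed below) rolls back any per-field indices that were already updated when a later field fails to index, so a document is either indexed for all of its schema fields or not at all.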
@@ -4,6 +4,8 @@
#include "core/search/base.h"

#include <absl/strings/numbers.h>

namespace dfly::search {

std::string_view QueryParams::operator[](std::string_view name) const {

@@ -37,4 +39,11 @@ WrappedStrPtr::operator std::string_view() const {
  return std::string_view{ptr.get(), std::strlen(ptr.get())};
}

std::optional<double> ParseNumericField(std::string_view value) {
  double value_as_double;
  if (absl::SimpleAtod(value, &value_as_double))
    return value_as_double;
  return std::nullopt;
}

} // namespace dfly::search
@@ -68,11 +68,18 @@ using SortableValue = std::variant<std::monostate, double, std::string>;

struct DocumentAccessor {
  using VectorInfo = search::OwnedFtVector;
  using StringList = absl::InlinedVector<std::string_view, 1>;
  using NumsList = absl::InlinedVector<double, 1>;

  virtual ~DocumentAccessor() = default;

  virtual StringList GetStrings(std::string_view active_field) const = 0;
  virtual VectorInfo GetVector(std::string_view active_field) const = 0;

  /* Returns nullopt if the specified field is not a list of strings */
  virtual std::optional<StringList> GetStrings(std::string_view active_field) const = 0;

  /* Returns nullopt if the specified field is not a vector */
  virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const = 0;

  /* Return nullopt if the specified field is not a list of doubles */
  virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const = 0;
};

// Base class for type-specific indices.

@@ -81,8 +88,10 @@ struct DocumentAccessor {
// query functions. All results for all index types should be sorted.
struct BaseIndex {
  virtual ~BaseIndex() = default;
  virtual void Add(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
  virtual void Remove(DocId id, DocumentAccessor* doc, std::string_view field) = 0;

  // Returns true if the document was added / indexed
  virtual bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
  virtual void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) = 0;
};

// Base class for type-specific sorting indices.

@@ -91,4 +100,20 @@ struct BaseSortIndex : BaseIndex {
  virtual std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const = 0;
};

/* Used for converting field values to double. Returns std::nullopt if the conversion fails */
std::optional<double> ParseNumericField(std::string_view value);

/* Temporary method to create an empty std::optional<InlinedVector> in DocumentAccessor::GetString
   and DocumentAccessor::GetNumbers methods. The problem is that due to internal implementation
   details of absl::InlineVector, we are getting a -Wmaybe-uninitialized compiler warning. To
   suppress this false warning, we temporarily disable it around this block of code using GCC
   diagnostic directives. */
template <typename InlinedVector> std::optional<InlinedVector> EmptyAccessResult() {
  // GCC 13.1 throws spurious warnings around this code.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
  return InlinedVector{};
#pragma GCC diagnostic pop
}

} // namespace dfly::search
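The comments above draw the distinction the rest of the diff relies on: a field that is merely absent yields an empty list (the document is still indexed), while a field whose value has the wrong type yields std::nullopt (the document is rejected). A minimal caller-side sketch of that contract against the DocumentAccessor interface declared above; the helper name and surrounding logic are illustrative, not part of the commit:

// Hypothetical consumer of the new GetNumbers() contract.
bool IndexNumericField(const dfly::search::DocumentAccessor& doc, std::string_view field) {
  auto numbers = doc.GetNumbers(field);
  if (!numbers)
    return false;  // wrong field type: the caller skips (and rolls back) this document
  for (double value : *numbers) {
    (void)value;  // insert (value, doc_id) into the numeric index here; an empty list
                  // (see EmptyAccessResult above) still counts as success
  }
  return true;
}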
@ -71,19 +71,22 @@ absl::flat_hash_set<string> NormalizeTags(string_view taglist, bool case_sensiti
|
|||
NumericIndex::NumericIndex(PMR_NS::memory_resource* mr) : entries_{mr} {
|
||||
}
|
||||
|
||||
void NumericIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
for (auto str : doc->GetStrings(field)) {
|
||||
double num;
|
||||
if (absl::SimpleAtod(str, &num))
|
||||
entries_.emplace(num, id);
|
||||
bool NumericIndex::Add(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
auto numbers = doc.GetNumbers(field);
|
||||
if (!numbers) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto num : numbers.value()) {
|
||||
entries_.emplace(num, id);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void NumericIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
for (auto str : doc->GetStrings(field)) {
|
||||
double num;
|
||||
if (absl::SimpleAtod(str, &num))
|
||||
entries_.erase({num, id});
|
||||
void NumericIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
auto numbers = doc.GetNumbers(field).value();
|
||||
for (auto num : numbers) {
|
||||
entries_.erase({num, id});
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -139,19 +142,27 @@ typename BaseStringIndex<C>::Container* BaseStringIndex<C>::GetOrCreate(string_v
|
|||
}
|
||||
|
||||
template <typename C>
|
||||
void BaseStringIndex<C>::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
bool BaseStringIndex<C>::Add(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
auto strings_list = doc.GetStrings(field);
|
||||
if (!strings_list) {
|
||||
return false;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<std::string> tokens;
|
||||
for (string_view str : doc->GetStrings(field))
|
||||
for (string_view str : strings_list.value())
|
||||
tokens.merge(Tokenize(str));
|
||||
|
||||
for (string_view token : tokens)
|
||||
GetOrCreate(token)->Insert(id);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename C>
|
||||
void BaseStringIndex<C>::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
void BaseStringIndex<C>::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
auto strings_list = doc.GetStrings(field).value();
|
||||
|
||||
absl::flat_hash_set<std::string> tokens;
|
||||
for (string_view str : doc->GetStrings(field))
|
||||
for (string_view str : strings_list)
|
||||
tokens.merge(Tokenize(str));
|
||||
|
||||
for (const auto& token : tokens) {
|
||||
|
@ -192,6 +203,20 @@ std::pair<size_t /*dim*/, VectorSimilarity> BaseVectorIndex::Info() const {
|
|||
return {dim_, sim_};
|
||||
}
|
||||
|
||||
bool BaseVectorIndex::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
|
||||
auto vector = doc.GetVector(field);
|
||||
if (!vector)
|
||||
return false;
|
||||
|
||||
auto& [ptr, size] = vector.value();
|
||||
if (ptr && size != dim_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
AddVector(id, ptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
|
||||
PMR_NS::memory_resource* mr)
|
||||
: BaseVectorIndex{params.dim, params.sim}, entries_{mr} {
|
||||
|
@ -199,19 +224,18 @@ FlatVectorIndex::FlatVectorIndex(const SchemaField::VectorParams& params,
|
|||
entries_.reserve(params.capacity * params.dim);
|
||||
}
|
||||
|
||||
void FlatVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
void FlatVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
|
||||
DCHECK_LE(id * dim_, entries_.size());
|
||||
if (id * dim_ == entries_.size())
|
||||
entries_.resize((id + 1) * dim_);
|
||||
|
||||
// TODO: Let get vector write to buf itself
|
||||
auto [ptr, size] = doc->GetVector(field);
|
||||
|
||||
if (size == dim_)
|
||||
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
|
||||
if (vector) {
|
||||
memcpy(&entries_[id * dim_], vector.get(), dim_ * sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
void FlatVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
void FlatVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
// noop
|
||||
}
|
||||
|
||||
|
@ -229,7 +253,7 @@ struct HnswlibAdapter {
|
|||
100 /* seed*/} {
|
||||
}
|
||||
|
||||
void Add(float* data, DocId id) {
|
||||
void Add(const float* data, DocId id) {
|
||||
if (world_.cur_element_count + 1 >= world_.max_elements_)
|
||||
world_.resizeIndex(world_.cur_element_count * 2);
|
||||
world_.addPoint(data, id);
|
||||
|
@ -298,10 +322,10 @@ HnswVectorIndex::HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS
|
|||
HnswVectorIndex::~HnswVectorIndex() {
|
||||
}
|
||||
|
||||
void HnswVectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
auto [ptr, size] = doc->GetVector(field);
|
||||
if (size == dim_)
|
||||
adapter_->Add(ptr.get(), id);
|
||||
void HnswVectorIndex::AddVector(DocId id, const VectorPtr& vector) {
|
||||
if (vector) {
|
||||
adapter_->Add(vector.get(), id);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t k,
|
||||
|
@ -314,7 +338,7 @@ std::vector<std::pair<float, DocId>> HnswVectorIndex::Knn(float* target, size_t
|
|||
return adapter_->Knn(target, k, ef, allowed);
|
||||
}
|
||||
|
||||
void HnswVectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||
void HnswVectorIndex::Remove(DocId id, const DocumentAccessor& doc, string_view field) {
|
||||
adapter_->Remove(id);
|
||||
}
|
||||
|
||||
|
|
|
@ -28,8 +28,8 @@ namespace dfly::search {
|
|||
struct NumericIndex : public BaseIndex {
|
||||
explicit NumericIndex(PMR_NS::memory_resource* mr);
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
|
||||
std::vector<DocId> Range(double l, double r) const;
|
||||
|
||||
|
@ -44,8 +44,8 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
|
|||
|
||||
BaseStringIndex(PMR_NS::memory_resource* mr, bool case_sensitive);
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
|
||||
// Used by Add & Remove to tokenize text value
|
||||
virtual absl::flat_hash_set<std::string> Tokenize(std::string_view value) const = 0;
|
||||
|
@ -53,7 +53,7 @@ template <typename C> struct BaseStringIndex : public BaseIndex {
|
|||
// Pointer is valid as long as index is not mutated. Nullptr if not found
|
||||
const Container* Matching(std::string_view str) const;
|
||||
|
||||
// Iterate over all Machting on prefix.
|
||||
// Iterate over all Matching on prefix.
|
||||
void MatchingPrefix(std::string_view prefix, absl::FunctionRef<void(const Container*)> cb) const;
|
||||
|
||||
// Returns all the terms that appear as keys in the reverse index.
|
||||
|
@ -97,9 +97,14 @@ struct TagIndex : public BaseStringIndex<SortedVector> {
|
|||
struct BaseVectorIndex : public BaseIndex {
|
||||
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
|
||||
|
||||
bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override final;
|
||||
|
||||
protected:
|
||||
BaseVectorIndex(size_t dim, VectorSimilarity sim);
|
||||
|
||||
using VectorPtr = decltype(std::declval<OwnedFtVector>().first);
|
||||
virtual void AddVector(DocId id, const VectorPtr& vector) = 0;
|
||||
|
||||
size_t dim_;
|
||||
VectorSimilarity sim_;
|
||||
};
|
||||
|
@ -109,11 +114,13 @@ struct BaseVectorIndex : public BaseIndex {
|
|||
struct FlatVectorIndex : public BaseVectorIndex {
|
||||
FlatVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
|
||||
const float* Get(DocId doc) const;
|
||||
|
||||
protected:
|
||||
void AddVector(DocId id, const VectorPtr& vector) override;
|
||||
|
||||
private:
|
||||
PMR_NS::vector<float> entries_;
|
||||
};
|
||||
|
@ -124,13 +131,15 @@ struct HnswVectorIndex : public BaseVectorIndex {
|
|||
HnswVectorIndex(const SchemaField::VectorParams& params, PMR_NS::memory_resource* mr);
|
||||
~HnswVectorIndex();
|
||||
|
||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||
void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;
|
||||
|
||||
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef) const;
|
||||
std::vector<std::pair<float, DocId>> Knn(float* target, size_t k, std::optional<size_t> ef,
|
||||
const std::vector<DocId>& allowed) const;
|
||||
|
||||
protected:
|
||||
void AddVector(DocId id, const VectorPtr& vector) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<HnswlibAdapter> adapter_;
|
||||
};
|
||||
|
|
|
@@ -571,23 +571,48 @@ void FieldIndices::CreateSortIndices(PMR_NS::memory_resource* mr) {
  }
}

void FieldIndices::Add(DocId doc, DocumentAccessor* access) {
  for (auto& [field, index] : indices_)
    index->Add(doc, access, field);
  for (auto& [field, sort_index] : sort_indices_)
    sort_index->Add(doc, access, field);
bool FieldIndices::Add(DocId doc, const DocumentAccessor& access) {
  bool was_added = true;

  std::vector<std::pair<std::string_view, BaseIndex*>> successfully_added_indices;
  successfully_added_indices.reserve(indices_.size() + sort_indices_.size());

  auto try_add = [&](const auto& indices_container) {
    for (auto& [field, index] : indices_container) {
      if (index->Add(doc, access, field)) {
        successfully_added_indices.emplace_back(field, index.get());
      } else {
        was_added = false;
        break;
      }
    }
  };

  try_add(indices_);

  if (was_added) {
    try_add(sort_indices_);
  }

  if (!was_added) {
    for (auto& [field, index] : successfully_added_indices) {
      index->Remove(doc, access, field);
    }
    return false;
  }

  all_ids_.insert(upper_bound(all_ids_.begin(), all_ids_.end(), doc), doc);
  return true;
}

void FieldIndices::Remove(DocId doc, DocumentAccessor* access) {
void FieldIndices::Remove(DocId doc, const DocumentAccessor& access) {
  for (auto& [field, index] : indices_)
    index->Remove(doc, access, field);
  for (auto& [field, sort_index] : sort_indices_)
    sort_index->Remove(doc, access, field);

  auto it = lower_bound(all_ids_.begin(), all_ids_.end(), doc);
  CHECK(it != all_ids_.end() && *it == doc);
  DCHECK(it != all_ids_.end() && *it == doc);
  all_ids_.erase(it);
}

@@ -77,8 +77,9 @@ class FieldIndices {
  // Create indices based on schema and options. Both must outlive the indices
  FieldIndices(const Schema& schema, const IndicesOptions& options, PMR_NS::memory_resource* mr);

  void Add(DocId doc, DocumentAccessor* access);
  void Remove(DocId doc, DocumentAccessor* access);
  // Returns true if document was added
  bool Add(DocId doc, const DocumentAccessor& access);
  void Remove(DocId doc, const DocumentAccessor& access);

  BaseIndex* GetIndex(std::string_view field) const;
  BaseSortIndex* GetSortIndex(std::string_view field) const;
@ -44,13 +44,36 @@ struct MockedDocument : public DocumentAccessor {
|
|||
MockedDocument(std::string test_field) : fields_{{"field", test_field}} {
|
||||
}
|
||||
|
||||
StringList GetStrings(string_view field) const override {
|
||||
std::optional<StringList> GetStrings(string_view field) const override {
|
||||
auto it = fields_.find(field);
|
||||
return {it != fields_.end() ? string_view{it->second} : ""};
|
||||
if (it == fields_.end()) {
|
||||
return EmptyAccessResult<StringList>();
|
||||
}
|
||||
return StringList{string_view{it->second}};
|
||||
}
|
||||
|
||||
VectorInfo GetVector(string_view field) const override {
|
||||
return BytesToFtVector(GetStrings(field).front());
|
||||
std::optional<VectorInfo> GetVector(string_view field) const override {
|
||||
auto strings_list = GetStrings(field);
|
||||
if (!strings_list)
|
||||
return std::nullopt;
|
||||
return !strings_list->empty() ? BytesToFtVectorSafe(strings_list->front()) : VectorInfo{};
|
||||
}
|
||||
|
||||
std::optional<NumsList> GetNumbers(std::string_view field) const override {
|
||||
auto strings_list = GetStrings(field);
|
||||
if (!strings_list)
|
||||
return std::nullopt;
|
||||
|
||||
NumsList nums_list;
|
||||
nums_list.reserve(strings_list->size());
|
||||
for (auto str : strings_list.value()) {
|
||||
auto num = ParseNumericField(str);
|
||||
if (!num) {
|
||||
return std::nullopt;
|
||||
}
|
||||
nums_list.push_back(num.value());
|
||||
}
|
||||
return nums_list;
|
||||
}
|
||||
|
||||
string DebugFormat() {
|
||||
|
@ -121,7 +144,7 @@ class SearchTest : public ::testing::Test {
|
|||
|
||||
shuffle(entries_.begin(), entries_.end(), default_random_engine{});
|
||||
for (DocId i = 0; i < entries_.size(); i++)
|
||||
index.Add(i, &entries_[i].first);
|
||||
index.Add(i, entries_[i].first);
|
||||
|
||||
SearchAlgorithm search_algo{};
|
||||
if (!search_algo.Init(query_, ¶ms_)) {
|
||||
|
@ -430,7 +453,7 @@ TEST_F(SearchTest, StopWords) {
|
|||
"explicitly found!"};
|
||||
for (size_t i = 0; i < documents.size(); i++) {
|
||||
MockedDocument doc{{{"title", documents[i]}}};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
// words is a stopword
|
||||
|
@ -484,7 +507,7 @@ TEST_P(KnnTest, Simple1D) {
|
|||
for (size_t i = 0; i < 100; i++) {
|
||||
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
|
||||
MockedDocument doc{values};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
|
@ -540,7 +563,7 @@ TEST_P(KnnTest, Simple2D) {
|
|||
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
|
||||
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
|
||||
MockedDocument doc{Map{{"pos", coords}}};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
|
@ -602,7 +625,7 @@ TEST_P(KnnTest, Cosine) {
|
|||
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
|
||||
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
|
||||
MockedDocument doc{Map{{"pos", coords}}};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
|
@ -646,7 +669,7 @@ TEST_P(KnnTest, AddRemove) {
|
|||
vector<MockedDocument> documents(10);
|
||||
for (size_t i = 0; i < 10; i++) {
|
||||
documents[i] = Map{{"pos", ToBytes({float(i)})}};
|
||||
indices.Add(i, &documents[i]);
|
||||
indices.Add(i, documents[i]);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
|
@ -661,7 +684,7 @@ TEST_P(KnnTest, AddRemove) {
|
|||
|
||||
// delete leftmost 5
|
||||
for (size_t i = 0; i < 5; i++)
|
||||
indices.Remove(i, &documents[i]);
|
||||
indices.Remove(i, documents[i]);
|
||||
|
||||
// search leftmost 5 again
|
||||
{
|
||||
|
@ -672,7 +695,7 @@ TEST_P(KnnTest, AddRemove) {
|
|||
|
||||
// add removed elements
|
||||
for (size_t i = 0; i < 5; i++)
|
||||
indices.Add(i, &documents[i]);
|
||||
indices.Add(i, documents[i]);
|
||||
|
||||
// repeat first search
|
||||
{
|
||||
|
@ -693,7 +716,7 @@ TEST_P(KnnTest, AutoResize) {
|
|||
|
||||
for (size_t i = 0; i < 100; i++) {
|
||||
MockedDocument doc{Map{{"pos", ToBytes({float(i)})}}};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
EXPECT_EQ(indices.GetAllDocs().size(), 100);
|
||||
|
@ -720,7 +743,7 @@ static void BM_VectorSearch(benchmark::State& state) {
|
|||
for (size_t i = 0; i < nvecs; i++) {
|
||||
auto rv = random_vec();
|
||||
MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
|
||||
indices.Add(i, &doc);
|
||||
indices.Add(i, doc);
|
||||
}
|
||||
|
||||
SearchAlgorithm algo{};
|
||||
|
|
|
@@ -46,15 +46,23 @@ std::vector<ResultScore> SimpleValueSortIndex<T>::Sort(std::vector<DocId>* ids,
}

template <typename T>
void SimpleValueSortIndex<T>::Add(DocId id, DocumentAccessor* doc, std::string_view field) {
bool SimpleValueSortIndex<T>::Add(DocId id, const DocumentAccessor& doc, std::string_view field) {
  auto field_value = Get(doc, field);
  if (!field_value) {
    return false;
  }

  DCHECK_LE(id, values_.size());  // Doc ids grow at most by one
  if (id >= values_.size())
    values_.resize(id + 1);
  values_[id] = Get(id, doc, field);

  values_[id] = field_value.value();
  return true;
}

template <typename T>
void SimpleValueSortIndex<T>::Remove(DocId id, DocumentAccessor* doc, std::string_view field) {
void SimpleValueSortIndex<T>::Remove(DocId id, const DocumentAccessor& doc,
                                     std::string_view field) {
  DCHECK_LT(id, values_.size());
  values_[id] = T{};
}

@@ -66,23 +74,22 @@ template <typename T> PMR_NS::memory_resource* SimpleValueSortIndex<T>::GetMemRe
template struct SimpleValueSortIndex<double>;
template struct SimpleValueSortIndex<PMR_NS::string>;

double NumericSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
  auto str = doc->GetStrings(field);
  if (str.empty())
    return 0;

  double v;
  if (!absl::SimpleAtod(str.front(), &v))
    return 0;
  return v;
std::optional<double> NumericSortIndex::Get(const DocumentAccessor& doc, std::string_view field) {
  auto numbers_list = doc.GetNumbers(field);
  if (!numbers_list) {
    return std::nullopt;
  }
  return !numbers_list->empty() ? numbers_list->front() : 0.0;
}

PMR_NS::string StringSortIndex::Get(DocId id, DocumentAccessor* doc, std::string_view field) {
  auto str = doc->GetStrings(field);
  if (str.empty())
    return "";

  return PMR_NS::string{str.front(), GetMemRes()};
std::optional<PMR_NS::string> StringSortIndex::Get(const DocumentAccessor& doc,
                                                   std::string_view field) {
  auto strings_list = doc.GetStrings(field);
  if (!strings_list) {
    return std::nullopt;
  }
  return !strings_list->empty() ? PMR_NS::string{strings_list->front(), GetMemRes()}
                                : PMR_NS::string{GetMemRes()};
}

} // namespace dfly::search

@@ -24,11 +24,11 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
  SortableValue Lookup(DocId doc) const override;
  std::vector<ResultScore> Sort(std::vector<DocId>* ids, size_t limit, bool desc) const override;

  void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
  void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
  bool Add(DocId id, const DocumentAccessor& doc, std::string_view field) override;
  void Remove(DocId id, const DocumentAccessor& doc, std::string_view field) override;

 protected:
  virtual T Get(DocId id, DocumentAccessor* doc, std::string_view field) = 0;
  virtual std::optional<T> Get(const DocumentAccessor& doc, std::string_view field_value) = 0;

  PMR_NS::memory_resource* GetMemRes() const;

@@ -39,14 +39,14 @@ template <typename T> struct SimpleValueSortIndex : BaseSortIndex {
struct NumericSortIndex : public SimpleValueSortIndex<double> {
  NumericSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};

  double Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
  std::optional<double> Get(const DocumentAccessor& doc, std::string_view field) override;
};

// TODO: Map tags to integers for fast sort
struct StringSortIndex : public SimpleValueSortIndex<PMR_NS::string> {
  StringSortIndex(PMR_NS::memory_resource* mr) : SimpleValueSortIndex{mr} {};

  PMR_NS::string Get(DocId id, DocumentAccessor* doc, std::string_view field) override;
  std::optional<PMR_NS::string> Get(const DocumentAccessor& doc, std::string_view field) override;
};

} // namespace dfly::search
@@ -39,18 +39,28 @@ __attribute__((optimize("fast-math"))) float CosineDistance(const float* u, cons
  return 0.0f;
}

} // namespace

OwnedFtVector BytesToFtVector(string_view value) {
  DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();

OwnedFtVector ConvertToFtVector(string_view value) {
  // Value cannot be casted directly as it might be not aligned as a float (4 bytes).
  // Misaligned memory access is UB.
  size_t size = value.size() / sizeof(float);
  auto out = make_unique<float[]>(size);
  memcpy(out.get(), value.data(), size * sizeof(float));

  return {std::move(out), size};
  return OwnedFtVector{std::move(out), size};
}

} // namespace

OwnedFtVector BytesToFtVector(string_view value) {
  DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
  return ConvertToFtVector(value);
}

std::optional<OwnedFtVector> BytesToFtVectorSafe(string_view value) {
  if (value.size() % sizeof(float)) {
    return std::nullopt;
  }
  return ConvertToFtVector(value);
}

float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {

@@ -10,6 +10,10 @@ namespace dfly::search {

OwnedFtVector BytesToFtVector(std::string_view value);

// Returns std::nullopt if value can not be converted to the vector
// TODO: Remove unsafe version
std::optional<OwnedFtVector> BytesToFtVectorSafe(std::string_view value);

float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);

} // namespace dfly::search
@ -38,43 +38,44 @@ string_view SdsToSafeSv(sds str) {
|
|||
return str != nullptr ? string_view{str, sdslen(str)} : ""sv;
|
||||
}
|
||||
|
||||
search::SortableValue FieldToSortableValue(search::SchemaField::FieldType type, string_view value) {
|
||||
using FieldValue = std::optional<search::SortableValue>;
|
||||
|
||||
FieldValue ToSortableValue(search::SchemaField::FieldType type, string_view value) {
|
||||
if (value.empty()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (type == search::SchemaField::NUMERIC) {
|
||||
double value_as_double = 0;
|
||||
if (!absl::SimpleAtod(value, &value_as_double)) { // temporary convert to double
|
||||
auto value_as_double = search::ParseNumericField(value);
|
||||
if (!value_as_double) { // temporary convert to double
|
||||
LOG(DFATAL) << "Failed to convert " << value << " to double";
|
||||
return std::nullopt;
|
||||
}
|
||||
return value_as_double;
|
||||
return value_as_double.value();
|
||||
}
|
||||
if (type == search::SchemaField::VECTOR) {
|
||||
auto [ptr, size] = search::BytesToFtVector(value);
|
||||
auto opt_vector = search::BytesToFtVectorSafe(value);
|
||||
if (!opt_vector) {
|
||||
LOG(DFATAL) << "Failed to convert " << value << " to vector";
|
||||
return std::nullopt;
|
||||
}
|
||||
auto& [ptr, size] = opt_vector.value();
|
||||
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
|
||||
}
|
||||
return string{value};
|
||||
}
|
||||
|
||||
search::SortableValue JsonToSortableValue(const search::SchemaField::FieldType type,
|
||||
const JsonType& json) {
|
||||
if (type == search::SchemaField::NUMERIC) {
|
||||
return json.as_double();
|
||||
}
|
||||
return json.to_string();
|
||||
}
|
||||
|
||||
search::SortableValue ExtractSortableValue(const search::Schema& schema, string_view key,
|
||||
string_view value) {
|
||||
FieldValue ExtractSortableValue(const search::Schema& schema, string_view key, string_view value) {
|
||||
auto it = schema.fields.find(key);
|
||||
if (it == schema.fields.end())
|
||||
return FieldToSortableValue(search::SchemaField::TEXT, value);
|
||||
return FieldToSortableValue(it->second.type, value);
|
||||
return ToSortableValue(search::SchemaField::TEXT, value);
|
||||
return ToSortableValue(it->second.type, value);
|
||||
}
|
||||
|
||||
search::SortableValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
|
||||
const JsonType& json) {
|
||||
auto it = schema.fields.find(key);
|
||||
if (it == schema.fields.end())
|
||||
return JsonToSortableValue(search::SchemaField::TEXT, json);
|
||||
return JsonToSortableValue(it->second.type, json);
|
||||
FieldValue ExtractSortableValueFromJson(const search::Schema& schema, string_view key,
|
||||
const JsonType& json) {
|
||||
auto json_as_string = json.to_string();
|
||||
return ExtractSortableValue(schema, key, json_as_string);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -83,7 +84,11 @@ SearchDocData BaseAccessor::Serialize(
|
|||
const search::Schema& schema, absl::Span<const SearchField<std::string_view>> fields) const {
|
||||
SearchDocData out{};
|
||||
for (const auto& [fident, fname] : fields) {
|
||||
out[fname] = ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident), ","));
|
||||
auto field_value =
|
||||
ExtractSortableValue(schema, fident, absl::StrJoin(GetStrings(fident).value(), ","));
|
||||
if (field_value) {
|
||||
out[fname] = std::move(field_value).value();
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
@ -92,14 +97,39 @@ SearchDocData BaseAccessor::SerializeDocument(const search::Schema& schema) cons
|
|||
return Serialize(schema);
|
||||
}
|
||||
|
||||
BaseAccessor::StringList ListPackAccessor::GetStrings(string_view active_field) const {
|
||||
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
|
||||
return strsv.has_value() ? StringList{*strsv} : StringList{};
|
||||
std::optional<BaseAccessor::VectorInfo> BaseAccessor::GetVector(
|
||||
std::string_view active_field) const {
|
||||
auto strings_list = GetStrings(active_field);
|
||||
if (strings_list) {
|
||||
return !strings_list->empty() ? search::BytesToFtVectorSafe(strings_list->front())
|
||||
: VectorInfo{};
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
|
||||
auto strlist = GetStrings(active_field);
|
||||
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
|
||||
std::optional<BaseAccessor::NumsList> BaseAccessor::GetNumbers(
|
||||
std::string_view active_field) const {
|
||||
auto strings_list = GetStrings(active_field);
|
||||
if (!strings_list) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
NumsList nums_list;
|
||||
nums_list.reserve(strings_list->size());
|
||||
for (auto str : strings_list.value()) {
|
||||
auto num = search::ParseNumericField(str);
|
||||
if (!num) {
|
||||
return std::nullopt;
|
||||
}
|
||||
nums_list.push_back(num.value());
|
||||
}
|
||||
return nums_list;
|
||||
}
|
||||
|
||||
std::optional<BaseAccessor::StringList> ListPackAccessor::GetStrings(
|
||||
string_view active_field) const {
|
||||
auto strsv = container_utils::LpFind(lp_, active_field, intbuf_[0].data());
|
||||
return strsv.has_value() ? StringList{*strsv} : StringList{};
|
||||
}
|
||||
|
||||
SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
|
||||
|
@ -114,27 +144,29 @@ SearchDocData ListPackAccessor::Serialize(const search::Schema& schema) const {
|
|||
string_view v = container_utils::LpGetView(fptr, intbuf_[1].data());
|
||||
fptr = lpNext(lp_, fptr);
|
||||
|
||||
out[k] = ExtractSortableValue(schema, k, v);
|
||||
auto field_value = ExtractSortableValue(schema, k, v);
|
||||
if (field_value) {
|
||||
out[k] = std::move(field_value).value();
|
||||
}
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
BaseAccessor::StringList StringMapAccessor::GetStrings(string_view active_field) const {
|
||||
std::optional<BaseAccessor::StringList> StringMapAccessor::GetStrings(
|
||||
string_view active_field) const {
|
||||
auto it = hset_->Find(active_field);
|
||||
return it != hset_->end() ? StringList{SdsToSafeSv(it->second)} : StringList{};
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
|
||||
auto strlist = GetStrings(active_field);
|
||||
return strlist.empty() ? VectorInfo{} : search::BytesToFtVector(strlist.front());
|
||||
}
|
||||
|
||||
SearchDocData StringMapAccessor::Serialize(const search::Schema& schema) const {
|
||||
SearchDocData out{};
|
||||
for (const auto& [kptr, vptr] : *hset_)
|
||||
out[SdsToSafeSv(kptr)] = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
|
||||
|
||||
for (const auto& [kptr, vptr] : *hset_) {
|
||||
auto field_value = ExtractSortableValue(schema, SdsToSafeSv(kptr), SdsToSafeSv(vptr));
|
||||
if (field_value) {
|
||||
out[SdsToSafeSv(kptr)] = std::move(field_value).value();
|
||||
}
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -159,27 +191,54 @@ struct JsonAccessor::JsonPathContainer {
|
|||
variant<json::Path, jsoncons::jsonpath::jsonpath_expression<JsonType>> val;
|
||||
};
|
||||
|
||||
BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) const {
|
||||
std::optional<BaseAccessor::StringList> JsonAccessor::GetStrings(string_view active_field) const {
|
||||
auto* path = GetPath(active_field);
|
||||
if (!path)
|
||||
return {};
|
||||
return search::EmptyAccessResult<StringList>();
|
||||
|
||||
auto path_res = path->Evaluate(json_);
|
||||
if (path_res.empty())
|
||||
return {};
|
||||
return search::EmptyAccessResult<StringList>();
|
||||
|
||||
if (path_res.size() == 1 && !path_res[0].is_array()) {
|
||||
if (!path_res[0].is_string())
|
||||
return std::nullopt;
|
||||
|
||||
if (path_res.size() == 1) {
|
||||
buf_ = path_res[0].as_string();
|
||||
return {buf_};
|
||||
return StringList{buf_};
|
||||
}
|
||||
|
||||
buf_.clear();
|
||||
|
||||
// First, grow buffer and compute string sizes
|
||||
vector<size_t> sizes;
|
||||
for (const auto& element : path_res) {
|
||||
|
||||
auto add_json_to_buf = [&](const JsonType& json) {
|
||||
size_t start = buf_.size();
|
||||
buf_ += element.as_string();
|
||||
buf_ += json.as_string();
|
||||
sizes.push_back(buf_.size() - start);
|
||||
};
|
||||
|
||||
if (!path_res[0].is_array()) {
|
||||
sizes.reserve(path_res.size());
|
||||
for (const auto& element : path_res) {
|
||||
if (!element.is_string())
|
||||
return std::nullopt;
|
||||
|
||||
add_json_to_buf(element);
|
||||
}
|
||||
} else {
|
||||
if (path_res.size() > 1) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
sizes.reserve(path_res[0].size());
|
||||
for (const auto& element : path_res[0].array_range()) {
|
||||
if (!element.is_string())
|
||||
return std::nullopt;
|
||||
|
||||
add_json_to_buf(element);
|
||||
}
|
||||
}
|
||||
|
||||
// Reposition start pointers to the most recent allocation of buf
|
||||
|
@ -194,23 +253,62 @@ BaseAccessor::StringList JsonAccessor::GetStrings(string_view active_field) cons
|
|||
return out;
|
||||
}
|
||||
|
||||
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
|
||||
std::optional<BaseAccessor::VectorInfo> JsonAccessor::GetVector(string_view active_field) const {
|
||||
auto* path = GetPath(active_field);
|
||||
if (!path)
|
||||
return {};
|
||||
return VectorInfo{};
|
||||
|
||||
auto res = path->Evaluate(json_);
|
||||
if (res.empty())
|
||||
return {nullptr, 0};
|
||||
return VectorInfo{};
|
||||
|
||||
if (!res[0].is_array())
|
||||
return std::nullopt;
|
||||
|
||||
size_t size = res[0].size();
|
||||
auto ptr = make_unique<float[]>(size);
|
||||
|
||||
size_t i = 0;
|
||||
for (const auto& v : res[0].array_range())
|
||||
for (const auto& v : res[0].array_range()) {
|
||||
if (!v.is_number()) {
|
||||
return std::nullopt;
|
||||
}
|
||||
ptr[i++] = v.as<float>();
|
||||
}
|
||||
|
||||
return {std::move(ptr), size};
|
||||
return BaseAccessor::VectorInfo{std::move(ptr), size};
|
||||
}
|
||||
|
||||
std::optional<BaseAccessor::NumsList> JsonAccessor::GetNumbers(string_view active_field) const {
|
||||
auto* path = GetPath(active_field);
|
||||
if (!path)
|
||||
return search::EmptyAccessResult<NumsList>();
|
||||
|
||||
auto path_res = path->Evaluate(json_);
|
||||
if (path_res.empty())
|
||||
return search::EmptyAccessResult<NumsList>();
|
||||
|
||||
NumsList nums_list;
|
||||
if (!path_res[0].is_array()) {
|
||||
nums_list.reserve(path_res.size());
|
||||
for (const auto& element : path_res) {
|
||||
if (!element.is_number())
|
||||
return std::nullopt;
|
||||
nums_list.push_back(element.as<double>());
|
||||
}
|
||||
} else {
|
||||
if (path_res.size() > 1) {
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
nums_list.reserve(path_res[0].size());
|
||||
for (const auto& element : path_res[0].array_range()) {
|
||||
if (!element.is_number())
|
||||
return std::nullopt;
|
||||
nums_list.push_back(element.as<double>());
|
||||
}
|
||||
}
|
||||
return nums_list;
|
||||
}
|
||||
|
||||
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
|
||||
|
@ -259,8 +357,12 @@ SearchDocData JsonAccessor::Serialize(
|
|||
SearchDocData out{};
|
||||
for (const auto& [ident, name] : fields) {
|
||||
if (auto* path = GetPath(ident); path) {
|
||||
if (auto res = path->Evaluate(json_); !res.empty())
|
||||
out[name] = ExtractSortableValueFromJson(schema, ident, res[0]);
|
||||
if (auto res = path->Evaluate(json_); !res.empty()) {
|
||||
auto field_value = ExtractSortableValueFromJson(schema, ident, res[0]);
|
||||
if (field_value) {
|
||||
out[name] = std::move(field_value).value();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "core/json/json_object.h"
|
||||
#include "core/search/search.h"
|
||||
#include "core/search/vector_utils.h"
|
||||
#include "server/common.h"
|
||||
#include "server/search/doc_index.h"
|
||||
#include "server/table.h"
|
||||
|
@ -37,6 +38,10 @@ struct BaseAccessor : public search::DocumentAccessor {
|
|||
indexed field
|
||||
*/
|
||||
virtual SearchDocData SerializeDocument(const search::Schema& schema) const;
|
||||
|
||||
// Default implementation uses GetStrings
|
||||
virtual std::optional<VectorInfo> GetVector(std::string_view active_field) const;
|
||||
virtual std::optional<NumsList> GetNumbers(std::string_view active_field) const;
|
||||
};
|
||||
|
||||
// Accessor for hashes stored with listpack
|
||||
|
@ -46,8 +51,7 @@ struct ListPackAccessor : public BaseAccessor {
|
|||
explicit ListPackAccessor(LpPtr ptr) : lp_{ptr} {
|
||||
}
|
||||
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
std::optional<StringList> GetStrings(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
private:
|
||||
|
@ -60,8 +64,7 @@ struct StringMapAccessor : public BaseAccessor {
|
|||
explicit StringMapAccessor(StringMap* hset) : hset_{hset} {
|
||||
}
|
||||
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
std::optional<StringList> GetStrings(std::string_view field) const override;
|
||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||
|
||||
private:
|
||||
|
@ -75,8 +78,9 @@ struct JsonAccessor : public BaseAccessor {
|
|||
explicit JsonAccessor(const JsonType* json) : json_{*json} {
|
||||
}
|
||||
|
||||
StringList GetStrings(std::string_view field) const override;
|
||||
VectorInfo GetVector(std::string_view field) const override;
|
||||
std::optional<StringList> GetStrings(std::string_view field) const override;
|
||||
std::optional<VectorInfo> GetVector(std::string_view field) const override;
|
||||
std::optional<NumsList> GetNumbers(std::string_view active_field) const override;
|
||||
|
||||
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
|
||||
SearchDocData Serialize(const search::Schema& schema,
|
||||
|
|
|
@@ -41,7 +41,7 @@ void TraverseAllMatching(const DocIndex& index, const OpArgs& op_args, F&& f) {
      return;

    auto accessor = GetAccessor(op_args.db_cntx, pv);
    f(key, accessor.get());
    f(key, *accessor);
  };

  PrimeTable::Cursor cursor;

@@ -146,12 +146,14 @@ ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Add(string_view key) {
  return id;
}

ShardDocIndex::DocId ShardDocIndex::DocKeyIndex::Remove(string_view key) {
  DCHECK_GT(ids_.count(key), 0u);
std::optional<ShardDocIndex::DocId> ShardDocIndex::DocKeyIndex::Remove(string_view key) {
  auto it = ids_.extract(key);
  if (!it) {
    return std::nullopt;
  }

  DocId id = ids_.find(key)->second;
  const DocId id = it.mapped();
  keys_[id] = "";
  ids_.erase(key);
  free_ids_.push_back(id);

  return id;

@@ -184,7 +186,13 @@ void ShardDocIndex::Rebuild(const OpArgs& op_args, PMR_NS::memory_resource* mr)
  key_index_ = DocKeyIndex{};
  indices_.emplace(base_->schema, base_->options, mr);

  auto cb = [this](string_view key, BaseAccessor* doc) { indices_->Add(key_index_.Add(key), doc); };
  auto cb = [this](string_view key, const BaseAccessor& doc) {
    DocId id = key_index_.Add(key);
    if (!indices_->Add(id, doc)) {
      key_index_.Remove(key);
    }
  };

  TraverseAllMatching(*base_, op_args, cb);

  VLOG(1) << "Indexed " << key_index_.Size() << " docs on " << base_->prefix;

@@ -195,7 +203,10 @@ void ShardDocIndex::AddDoc(string_view key, const DbContext& db_cntx, const Prim
    return;

  auto accessor = GetAccessor(db_cntx, pv);
  indices_->Add(key_index_.Add(key), accessor.get());
  DocId id = key_index_.Add(key);
  if (!indices_->Add(id, *accessor)) {
    key_index_.Remove(key);
  }
}

void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const PrimeValue& pv) {

@@ -203,8 +214,10 @@ void ShardDocIndex::RemoveDoc(string_view key, const DbContext& db_cntx, const P
    return;

  auto accessor = GetAccessor(db_cntx, pv);
  DocId id = key_index_.Remove(key);
  indices_->Remove(id, accessor.get());
  auto id = key_index_.Remove(key);
  if (id) {
    indices_->Remove(id.value(), *accessor);
  }
}

bool ShardDocIndex::Matches(string_view key, unsigned obj_code) const {

@@ -133,7 +133,7 @@ class ShardDocIndex {
  // DocKeyIndex manages mapping document keys to ids and vice versa through a simple interface.
  struct DocKeyIndex {
    DocId Add(std::string_view key);
    DocId Remove(std::string_view key);
    std::optional<DocId> Remove(std::string_view key);

    std::string_view Get(DocId id) const;
    size_t Size() const;
@ -90,6 +90,13 @@ template <typename... Args> auto IsArray(Args... args) {
|
|||
template <typename... Args> auto IsUnordArray(Args... args) {
|
||||
return RespArray(UnorderedElementsAre(std::forward<Args>(args)...));
|
||||
}
|
||||
template <typename Expected, size_t... Is>
|
||||
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
|
||||
const Expected& expected, std::index_sequence<Is...>) {
|
||||
std::initializer_list<int>{
|
||||
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
|
||||
0)...};
|
||||
}
|
||||
|
||||
MATCHER_P(IsMapMatcher, expected, "") {
|
||||
if (arg.type != RespExpr::ARRAY) {
|
||||
|
@ -97,73 +104,29 @@ MATCHER_P(IsMapMatcher, expected, "") {
|
|||
return false;
|
||||
}
|
||||
|
||||
constexpr size_t expected_size = std::tuple_size<decltype(expected)>::value;
|
||||
constexpr size_t exprected_pairs_number = expected_size / 2;
|
||||
|
||||
auto result = arg.GetVec();
|
||||
if (result.size() != expected.size()) {
|
||||
if (result.size() != expected_size) {
|
||||
*result_listener << "Wrong resp array size: " << result.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
using KeyValueArray = std::vector<std::pair<std::string, std::string>>;
|
||||
|
||||
KeyValueArray received_pairs;
|
||||
std::vector<std::pair<std::string, RespExpr>> received_pairs;
|
||||
for (size_t i = 0; i < result.size(); i += 2) {
|
||||
received_pairs.emplace_back(result[i].GetString(), result[i + 1].GetString());
|
||||
received_pairs.emplace_back(result[i].GetString(), result[i + 1]);
|
||||
}
|
||||
|
||||
KeyValueArray expected_pairs;
|
||||
for (size_t i = 0; i < expected.size(); i += 2) {
|
||||
expected_pairs.emplace_back(expected[i], expected[i + 1]);
|
||||
}
|
||||
std::vector<Matcher<std::pair<std::string, RespExpr>>> kv_matchers;
|
||||
BuildKvMatchers(kv_matchers, expected, std::make_index_sequence<exprected_pairs_number>{});
|
||||
|
||||
// Custom unordered comparison
|
||||
std::sort(received_pairs.begin(), received_pairs.end());
|
||||
std::sort(expected_pairs.begin(), expected_pairs.end());
|
||||
|
||||
return received_pairs == expected_pairs;
|
||||
}
|
||||
|
||||
template <typename... Matchers> auto IsMap(Matchers... matchers) {
|
||||
return IsMapMatcher(std::vector<std::string>{std::forward<Matchers>(matchers)...});
|
||||
}
|
||||
|
||||
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
|
||||
if (arg.type != RespExpr::ARRAY) {
|
||||
*result_listener << "Wrong response type: " << arg.type;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = arg.GetVec();
|
||||
size_t expected_size = std::tuple_size<decltype(expected)>::value;
|
||||
if (result.size() != expected_size + 1) {
|
||||
*result_listener << "Wrong resp array size: " << result.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result[0].GetInt() != expected_size) {
|
||||
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
|
||||
|
||||
// Create a vector of matchers from the tuple
|
||||
std::vector<Matcher<RespExpr>> matchers;
|
||||
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
|
||||
|
||||
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
|
||||
return ExplainMatchResult(UnorderedElementsAreArray(kv_matchers), received_pairs,
|
||||
result_listener);
|
||||
}
|
||||
|
||||
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) {
|
||||
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...));
|
||||
}
|
||||
|
||||
template <typename Expected, size_t... Is>
|
||||
void BuildKvMatchers(std::vector<Matcher<std::pair<std::string, RespExpr>>>& kv_matchers,
|
||||
const Expected& expected, std::index_sequence<Is...>) {
|
||||
std::initializer_list<int>{
|
||||
(kv_matchers.emplace_back(Pair(std::get<Is * 2>(expected), std::get<Is * 2 + 1>(expected))),
|
||||
0)...};
|
||||
template <typename... Args> auto IsMap(Args... args) {
|
||||
return IsMapMatcher(std::make_tuple(args...));
|
||||
}
|
||||
|
||||
MATCHER_P(IsMapWithSizeMatcher, expected, "") {
|
||||
|
@ -201,6 +164,38 @@ template <typename... Args> auto IsMapWithSize(Args... args) {
|
|||
return IsMapWithSizeMatcher(std::make_tuple(args...));
|
||||
}
|
||||
|
||||
MATCHER_P(IsUnordArrayWithSizeMatcher, expected, "") {
|
||||
if (arg.type != RespExpr::ARRAY) {
|
||||
*result_listener << "Wrong response type: " << arg.type;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto result = arg.GetVec();
|
||||
size_t expected_size = std::tuple_size<decltype(expected)>::value;
|
||||
if (result.size() != expected_size + 1) {
|
||||
*result_listener << "Wrong resp array size: " << result.size();
|
||||
return false;
|
||||
}
|
||||
|
||||
if (result[0].GetInt() != expected_size) {
|
||||
*result_listener << "Wrong elements count: " << result[0].GetInt().value_or(-1);
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<RespExpr> received_elements(result.begin() + 1, result.end());
|
||||
|
||||
// Create a vector of matchers from the tuple
|
||||
std::vector<Matcher<RespExpr>> matchers;
|
||||
std::apply([&matchers](auto&&... args) { ((matchers.push_back(args)), ...); }, expected);
|
||||
|
||||
return ExplainMatchResult(UnorderedElementsAreArray(matchers), received_elements,
|
||||
result_listener);
|
||||
}
|
||||
|
||||
template <typename... Matchers> auto IsUnordArrayWithSize(Matchers... matchers) {
|
||||
return IsUnordArrayWithSizeMatcher(std::make_tuple(matchers...));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, CreateDropListIndex) {
|
||||
EXPECT_EQ(Run({"ft.create", "idx-1", "ON", "HASH", "PREFIX", "1", "prefix-1"}), "OK");
|
||||
EXPECT_EQ(Run({"ft.create", "idx-2", "ON", "JSON", "PREFIX", "1", "prefix-2"}), "OK");
|
||||
|
@ -649,7 +644,7 @@ TEST_F(SearchFamilyTest, TestReturn) {
|
|||
|
||||
// Check non-existing field
|
||||
resp = Run({"ft.search", "i1", "@justA:0", "return", "1", "nothere"});
|
||||
EXPECT_THAT(resp, MatchEntry("k0", "nothere", ""));
|
||||
EXPECT_THAT(resp, MatchEntry("k0"));
|
||||
|
||||
// Checl implcit __vector_score is provided
|
||||
float score = 20;
|
||||
|
@ -1194,8 +1189,8 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
|
|||
IsMap("foo_total", "10", "word", "item1")));
|
||||
|
||||
// Test JSON
|
||||
Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":"10","text":"first key"})"});
|
||||
Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":"20","text":"second key"})"});
|
||||
Run({"JSON.SET", "j1", ".", R"({"word":"item1","foo":10,"text":"first key"})"});
|
||||
Run({"JSON.SET", "j2", ".", R"({"word":"item2","foo":20,"text":"second key"})"});
|
||||
|
||||
resp = Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.word", "AS", "word", "TAG", "$.foo",
|
||||
"AS", "foo", "NUMERIC", "$.text", "AS", "text", "TEXT"});
|
||||
|
@ -1214,4 +1209,220 @@ TEST_F(SearchFamilyTest, AggregateWithLoadOptionHard) {
|
|||
}
|
||||
#endif
|
||||
|
||||
TEST_F(SearchFamilyTest, WrongFieldTypeJson) {
|
||||
// Test simple
|
||||
Run({"JSON.SET", "j1", ".", R"({"value":"one"})"});
|
||||
Run({"JSON.SET", "j2", ".", R"({"value":1})"});
|
||||
|
||||
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC",
|
||||
"SORTABLE"}),
|
||||
"OK");
|
||||
|
||||
auto resp = Run({"FT.SEARCH", "i1", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2"));
|
||||
|
||||
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "$.value"});
|
||||
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("$.value", "1")));
|
||||
|
||||
// Test with two fields. One is loading
|
||||
Run({"JSON.SET", "j3", ".", R"({"value":"two","another_value":1})"});
|
||||
Run({"JSON.SET", "j4", ".", R"({"value":2,"another_value":2})"});
|
||||
|
||||
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.value", "AS", "value", "NUMERIC"}),
|
||||
"OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "$.another_value"});
|
||||
EXPECT_THAT(
|
||||
resp, IsMapWithSize("j2", IsMap("$", R"({"value":1})"), "j4",
|
||||
IsMap("$", R"({"another_value":2,"value":2})", "$.another_value", "2")));
|
||||
|
||||
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "$.value", "$.another_value", "GROUPBY", "2",
|
||||
"$.value", "$.another_value", "REDUCE", "COUNT", "0", "AS", "count"});
|
||||
EXPECT_THAT(resp,
|
||||
IsUnordArrayWithSize(
|
||||
IsMap("$.value", "1", "$.another_value", ArgType(RespExpr::NIL), "count", "1"),
|
||||
IsMap("$.value", "2", "$.another_value", "2", "count", "1")));
|
||||
|
||||
// Test multiple field values
|
||||
Run({"JSON.SET", "j5", ".", R"({"arr":[{"id":1},{"id":"two"}]})"});
|
||||
Run({"JSON.SET", "j6", ".", R"({"arr":[{"id":1},{"id":2}]})"});
|
||||
Run({"JSON.SET", "j7", ".", R"({"arr":[]})"});
|
||||
|
||||
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i3", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
|
||||
|
||||
resp = Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.arr[*].id", "AS", "id", "NUMERIC",
|
||||
"SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i4", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j1", "j2", "j3", "j4", "j6", "j7")); // Only j5 fails
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, WrongFieldTypeHash) {
|
||||
// Test simple
|
||||
Run({"HSET", "h1", "value", "one"});
|
||||
Run({"HSET", "h2", "value", "1"});
|
||||
|
||||
EXPECT_EQ(Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "value", "NUMERIC", "SORTABLE"}), "OK");
|
||||
|
||||
auto resp = Run({"FT.SEARCH", "i1", "*"});
|
||||
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1")));
|
||||
|
||||
resp = Run({"FT.AGGREGATE", "i1", "*", "LOAD", "1", "@value"});
|
||||
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("value", "1")));
|
||||
|
||||
// Test with two fields. One is loading
|
||||
Run({"HSET", "h3", "value", "two", "another_value", "1"});
|
||||
Run({"HSET", "h4", "value", "2", "another_value", "2"});
|
||||
|
||||
EXPECT_EQ(Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "value", "NUMERIC"}), "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i2", "*", "LOAD", "1", "@another_value"});
|
||||
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("value", "1"), "h4",
|
||||
IsMap("value", "2", "another_value", "2")));
|
||||
|
||||
resp = Run({"FT.AGGREGATE", "i2", "*", "LOAD", "2", "@value", "@another_value", "GROUPBY", "2",
|
||||
"@value", "@another_value", "REDUCE", "COUNT", "0", "AS", "count"});
|
||||
EXPECT_THAT(resp, IsUnordArrayWithSize(
|
||||
IsMap("value", "1", "another_value", ArgType(RespExpr::NIL), "count", "1"),
|
||||
IsMap("value", "2", "another_value", "2", "count", "1")));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, WrongFieldTypeHardJson) {
|
||||
Run({"JSON.SET", "j1", ".", R"({"data":1,"name":"doc_with_int"})"});
|
||||
Run({"JSON.SET", "j2", ".", R"({"data":"1","name":"doc_with_int_as_string"})"});
|
||||
Run({"JSON.SET", "j3", ".", R"({"data":"string","name":"doc_with_string"})"});
|
||||
Run({"JSON.SET", "j4", ".", R"({"name":"no_data"})"});
|
||||
Run({"JSON.SET", "j5", ".", R"({"data":[5,4,3],"name":"doc_with_vector"})"});
|
||||
Run({"JSON.SET", "j6", ".", R"({"data":"[5,4,3]","name":"doc_with_vector_as_string"})"});
|
||||
|
||||
auto resp = Run({"FT.CREATE", "i1", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run(
|
||||
{"FT.CREATE", "i2", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "NUMERIC", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i3", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp =
|
||||
Run({"FT.CREATE", "i4", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TAG", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i5", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp =
|
||||
Run({"FT.CREATE", "i6", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "TEXT", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i7", "ON", "JSON", "SCHEMA", "$.data", "AS", "data", "VECTOR", "FLAT",
|
||||
"6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i1", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i2", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j1", "j4", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i3", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i4", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j3", "j6", "j4"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i5", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i6", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j3", "j6"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i7", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j5"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, WrongFieldTypeHardHash) {
|
||||
Run({"HSET", "j1", "data", "1", "name", "doc_with_int"});
|
||||
Run({"HSET", "j2", "data", "1", "name", "doc_with_int_as_string"});
|
||||
Run({"HSET", "j3", "data", "string", "name", "doc_with_string"});
|
||||
Run({"HSET", "j4", "name", "no_data"});
|
||||
Run({"HSET", "j5", "data", "5,4,3", "name", "doc_with_fake_vector"});
|
||||
Run({"HSET", "j6", "data", "[5,4,3]", "name", "doc_with_fake_vector_as_string"});
|
||||
|
||||
// Vector [1, 2, 3]
|
||||
std::string vector = std::string("\x3f\x80\x00\x00\x40\x00\x00\x00\x40\x40\x00\x00", 12);
|
||||
Run({"HSET", "j7", "data", vector, "name", "doc_with_vector [1, 2, 3]"});
|
||||
|
||||
auto resp = Run({"FT.CREATE", "i1", "ON", "HASH", "SCHEMA", "data", "NUMERIC"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i2", "ON", "HASH", "SCHEMA", "data", "NUMERIC", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i3", "ON", "HASH", "SCHEMA", "data", "TAG"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i4", "ON", "HASH", "SCHEMA", "data", "TAG", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i5", "ON", "HASH", "SCHEMA", "data", "TEXT"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i6", "ON", "HASH", "SCHEMA", "data", "TEXT", "SORTABLE"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.CREATE", "i7", "ON", "HASH", "SCHEMA", "data", "VECTOR", "FLAT", "6", "TYPE",
|
||||
"FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "i1", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i2", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j1", "j4"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i3", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i4", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j2", "j7", "j3", "j6", "j1", "j4", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i5", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i6", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j2", "j7", "j3", "j6", "j1", "j5"));
|
||||
|
||||
resp = Run({"FT.SEARCH", "i7", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j4", "j7"));
|
||||
}
|
||||
|
||||
TEST_F(SearchFamilyTest, WrongVectorFieldType) {
|
||||
Run({"JSON.SET", "j1", ".",
|
||||
R"({"vector_field": [0.1, 0.2, 0.3], "name": "doc_with_correct_dim"})"});
|
||||
Run({"JSON.SET", "j2", ".", R"({"vector_field": [0.1, 0.2], "name": "doc_with_small_dim"})"});
|
||||
Run({"JSON.SET", "j3", ".",
|
||||
R"({"vector_field": [0.1, 0.2, 0.3, 0.4], "name": "doc_with_large_dim"})"});
|
||||
Run({"JSON.SET", "j4", ".", R"({"vector_field": [1, 2, 3], "name": "doc_with_int_values"})"});
|
||||
Run({"JSON.SET", "j5", ".",
|
||||
R"({"vector_field":"not_vector", "name":"doc_with_incorrect_field_type"})"});
|
||||
Run({"JSON.SET", "j6", ".", R"({"name":"doc_with_no_field"})"});
|
||||
Run({"JSON.SET", "j7", ".",
|
||||
R"({"vector_field": [999999999999999999999999999999999999999, -999999999999999999999999999999999999999, 500000000000000000000000000000000000000], "name": "doc_with_out_of_range_values"})"});
|
||||
|
||||
auto resp =
|
||||
Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.vector_field", "AS", "vector_field",
|
||||
"VECTOR", "FLAT", "6", "TYPE", "FLOAT32", "DIM", "3", "DISTANCE_METRIC", "L2"});
|
||||
EXPECT_EQ(resp, "OK");
|
||||
|
||||
resp = Run({"FT.SEARCH", "index", "*"});
|
||||
EXPECT_THAT(resp, AreDocIds("j6", "j7", "j1", "j4"));
|
||||
}
|
||||
|
||||
} // namespace dfly