mirror of
https://github.com/dragonflydb/dragonfly.git
synced 2025-05-11 02:15:45 +02:00
feat(search): sized vectors (#1788)
* feat(search): Sized vectors --------- Signed-off-by: Vladislav Oleshko <vlad@dragonflydb.io>
This commit is contained in:
parent
36be222091
commit
aa4cadfa12
18 changed files with 275 additions and 122 deletions
|
@ -5,7 +5,7 @@ cur_gen_dir(gen_dir)
|
||||||
|
|
||||||
find_package(ICU REQUIRED COMPONENTS uc i18n)
|
find_package(ICU REQUIRED COMPONENTS uc i18n)
|
||||||
|
|
||||||
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector.cc compressed_sorted_set.cc
|
add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc
|
||||||
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)
|
${gen_dir}/parser.cc ${gen_dir}/lexer.cc)
|
||||||
|
|
||||||
target_link_libraries(query_parser ICU::uc ICU::i18n)
|
target_link_libraries(query_parser ICU::uc ICU::i18n)
|
||||||
|
|
|
@ -56,9 +56,11 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) {
|
||||||
tags.push_back(move(tag));
|
tags.push_back(move(tag));
|
||||||
}
|
}
|
||||||
|
|
||||||
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string field, FtVector vec)
|
AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec)
|
||||||
: filter{make_unique<AstNode>(move(filter))}, limit{limit}, field{field.substr(1)}, vector{move(
|
: filter{make_unique<AstNode>(std::move(filter))},
|
||||||
vec)} {
|
limit{limit},
|
||||||
|
field{field.substr(1)},
|
||||||
|
vec{std::move(vec)} {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dfly::search
|
} // namespace dfly::search
|
||||||
|
|
|
@ -74,12 +74,12 @@ struct AstTagsNode {
|
||||||
|
|
||||||
// Applies nearest neighbor search to the final result set
|
// Applies nearest neighbor search to the final result set
|
||||||
struct AstKnnNode {
|
struct AstKnnNode {
|
||||||
AstKnnNode(AstNode&& sub, size_t limit, std::string field, FtVector vec);
|
AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec);
|
||||||
|
|
||||||
std::unique_ptr<AstNode> filter;
|
std::unique_ptr<AstNode> filter;
|
||||||
size_t limit;
|
size_t limit;
|
||||||
std::string field;
|
std::string field;
|
||||||
FtVector vector;
|
OwnedFtVector vec;
|
||||||
};
|
};
|
||||||
|
|
||||||
using NodeVariants =
|
using NodeVariants =
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
#include <absl/container/flat_hash_map.h>
|
#include <absl/container/flat_hash_map.h>
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
@ -14,7 +15,9 @@ namespace dfly::search {
|
||||||
|
|
||||||
using DocId = uint32_t;
|
using DocId = uint32_t;
|
||||||
|
|
||||||
using FtVector = std::vector<float>;
|
enum class VectorSimilarity { L2, COSINE };
|
||||||
|
|
||||||
|
using OwnedFtVector = std::pair<std::unique_ptr<float[]>, size_t /* dimension (size) */>;
|
||||||
|
|
||||||
// Query params represent named parameters for queries supplied via PARAMS.
|
// Query params represent named parameters for queries supplied via PARAMS.
|
||||||
struct QueryParams {
|
struct QueryParams {
|
||||||
|
@ -38,9 +41,11 @@ struct QueryParams {
|
||||||
|
|
||||||
// Interface for accessing document values with different data structures underneath.
|
// Interface for accessing document values with different data structures underneath.
|
||||||
struct DocumentAccessor {
|
struct DocumentAccessor {
|
||||||
|
using VectorInfo = search::OwnedFtVector;
|
||||||
|
|
||||||
virtual ~DocumentAccessor() = default;
|
virtual ~DocumentAccessor() = default;
|
||||||
virtual std::string_view GetString(std::string_view active_field) const = 0;
|
virtual std::string_view GetString(std::string_view active_field) const = 0;
|
||||||
virtual FtVector GetVector(std::string_view active_field) const = 0;
|
virtual VectorInfo GetVector(std::string_view active_field) const = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Base class for type-specific indices.
|
// Base class for type-specific indices.
|
||||||
|
|
|
@ -59,6 +59,11 @@ class CompressedSortedSet {
|
||||||
size_t Size() const;
|
size_t Size() const;
|
||||||
size_t ByteSize() const;
|
size_t ByteSize() const;
|
||||||
|
|
||||||
|
// To use transparently in templates together with stl containers
|
||||||
|
size_t size() const {
|
||||||
|
return Size();
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct EntryLocation {
|
struct EntryLocation {
|
||||||
IntType value; // Value or 0
|
IntType value; // Value or 0
|
||||||
|
|
|
@ -151,17 +151,31 @@ absl::flat_hash_set<std::string> TagIndex::Tokenize(std::string_view value) cons
|
||||||
return NormalizeTags(value);
|
return NormalizeTags(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} {
|
||||||
|
}
|
||||||
|
|
||||||
void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) {
|
||||||
entries_[id] = doc->GetVector(field);
|
DCHECK_LE(id * dim_, entries_.size());
|
||||||
|
if (id * dim_ == entries_.size())
|
||||||
|
entries_.resize((id + 1) * dim_);
|
||||||
|
|
||||||
|
// TODO: Let get vector write to buf itself
|
||||||
|
auto [ptr, size] = doc->GetVector(field);
|
||||||
|
|
||||||
|
if (size == dim_)
|
||||||
|
memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float));
|
||||||
}
|
}
|
||||||
|
|
||||||
void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) {
|
||||||
entries_.erase(id);
|
// noop
|
||||||
}
|
}
|
||||||
|
|
||||||
FtVector VectorIndex::Get(DocId doc) const {
|
const float* VectorIndex::Get(DocId doc) const {
|
||||||
auto it = entries_.find(doc);
|
return &entries_[doc * dim_];
|
||||||
return it != entries_.end() ? it->second : FtVector{};
|
}
|
||||||
|
|
||||||
|
std::pair<size_t /*dim*/, VectorSimilarity> VectorIndex::Info() const {
|
||||||
|
return {dim_, sim_};
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dfly::search
|
} // namespace dfly::search
|
||||||
|
|
|
@ -57,13 +57,18 @@ struct TagIndex : public BaseStringIndex {
|
||||||
// Index for vector fields.
|
// Index for vector fields.
|
||||||
// Only supports lookup by id.
|
// Only supports lookup by id.
|
||||||
struct VectorIndex : public BaseIndex {
|
struct VectorIndex : public BaseIndex {
|
||||||
|
VectorIndex(size_t dim, VectorSimilarity sim);
|
||||||
|
|
||||||
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
void Add(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||||
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override;
|
||||||
|
|
||||||
FtVector Get(DocId doc) const;
|
const float* Get(DocId doc) const;
|
||||||
|
std::pair<size_t /*dim*/, VectorSimilarity> Info() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
absl::flat_hash_map<DocId, FtVector> entries_;
|
size_t dim_;
|
||||||
|
VectorSimilarity sim_;
|
||||||
|
std::vector<float> entries_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dfly::search
|
} // namespace dfly::search
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
// Added to cc file
|
// Added to cc file
|
||||||
%code {
|
%code {
|
||||||
#include "core/search/query_driver.h"
|
#include "core/search/query_driver.h"
|
||||||
#include "core/search/vector.h"
|
#include "core/search/vector_utils.h"
|
||||||
|
|
||||||
// Have to disable because GCC doesn't understand `symbol_type`'s union
|
// Have to disable because GCC doesn't understand `symbol_type`'s union
|
||||||
// implementation
|
// implementation
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#include <absl/strings/str_join.h>
|
#include <absl/strings/str_join.h>
|
||||||
|
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
#include <type_traits>
|
||||||
#include <variant>
|
#include <variant>
|
||||||
|
|
||||||
#include "base/logging.h"
|
#include "base/logging.h"
|
||||||
|
@ -18,7 +19,7 @@
|
||||||
#include "core/search/compressed_sorted_set.h"
|
#include "core/search/compressed_sorted_set.h"
|
||||||
#include "core/search/indices.h"
|
#include "core/search/indices.h"
|
||||||
#include "core/search/query_driver.h"
|
#include "core/search/query_driver.h"
|
||||||
#include "core/search/vector.h"
|
#include "core/search/vector_utils.h"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
@ -35,11 +36,18 @@ AstExpr ParseQuery(std::string_view query, const QueryParams* params) {
|
||||||
return driver.Take();
|
return driver.Take();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GCC 12 yields a wrong warning in a deeply inlined call in UnifyResults, only ignoring the whole
|
||||||
|
// scope solves it
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
|
||||||
|
|
||||||
// Represents an either owned or non-owned result set that can be accessed transparently.
|
// Represents an either owned or non-owned result set that can be accessed transparently.
|
||||||
struct IndexResult {
|
struct IndexResult {
|
||||||
using DocVec = vector<DocId>;
|
using DocVec = vector<DocId>;
|
||||||
|
using BorrowedView = variant<const DocVec*, const CompressedSortedSet*>;
|
||||||
|
|
||||||
IndexResult() : value_{DocVec{}} {};
|
IndexResult() : value_{DocVec{}} {
|
||||||
|
}
|
||||||
|
|
||||||
IndexResult(const CompressedSortedSet* css) : value_{css} {
|
IndexResult(const CompressedSortedSet* css) : value_{css} {
|
||||||
if (css == nullptr)
|
if (css == nullptr)
|
||||||
|
@ -49,10 +57,11 @@ struct IndexResult {
|
||||||
IndexResult(DocVec&& dv) : value_{move(dv)} {
|
IndexResult(DocVec&& dv) : value_{move(dv)} {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
IndexResult(const DocVec* dv) : value_{dv} {
|
||||||
|
}
|
||||||
|
|
||||||
size_t Size() const {
|
size_t Size() const {
|
||||||
if (holds_alternative<DocVec>(value_))
|
return visit([](auto* set) { return set->size(); }, Borrowed());
|
||||||
return get<DocVec>(value_).size();
|
|
||||||
return get<const CompressedSortedSet*>(value_)->Size();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsOwned() const {
|
bool IsOwned() const {
|
||||||
|
@ -64,28 +73,31 @@ struct IndexResult {
|
||||||
swap(get<DocVec>(value_), entries); // swap to keep backing array
|
swap(get<DocVec>(value_), entries); // swap to keep backing array
|
||||||
entries.clear();
|
entries.clear();
|
||||||
} else {
|
} else {
|
||||||
value_ = move(entries);
|
value_ = std::move(entries);
|
||||||
}
|
}
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
variant<const DocVec*, const CompressedSortedSet*> Borrowed() {
|
BorrowedView Borrowed() const {
|
||||||
if (holds_alternative<DocVec>(value_))
|
auto cb = [](const auto& v) -> BorrowedView {
|
||||||
return &get<DocVec>(value_);
|
if constexpr (is_pointer_v<remove_reference_t<decltype(v)>>)
|
||||||
return get<const CompressedSortedSet*>(value_);
|
return v;
|
||||||
|
else
|
||||||
|
return &v;
|
||||||
|
};
|
||||||
|
return visit(cb, value_);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Move out of owned or copy borrowed
|
// Move out of owned or copy borrowed
|
||||||
DocVec Take() {
|
DocVec Take() {
|
||||||
if (holds_alternative<DocVec>(value_))
|
if (IsOwned())
|
||||||
return move(get<DocVec>(value_));
|
return move(get<DocVec>(value_));
|
||||||
|
|
||||||
const CompressedSortedSet* css = get<const CompressedSortedSet*>(value_);
|
return visit([](auto* set) { return DocVec(set->begin(), set->end()); }, Borrowed());
|
||||||
return DocVec(css->begin(), css->end());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
variant<DocVec /*owned*/, const CompressedSortedSet* /* borrowed */> value_;
|
variant<DocVec /*owned*/, const CompressedSortedSet*, const DocVec*> value_;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ProfileBuilder {
|
struct ProfileBuilder {
|
||||||
|
@ -194,7 +206,7 @@ struct BasicSearch {
|
||||||
sort(sub_results.begin(), sub_results.end(),
|
sort(sub_results.begin(), sub_results.end(),
|
||||||
[](const auto& l, const auto& r) { return l.Size() < r.Size(); });
|
[](const auto& l, const auto& r) { return l.Size() < r.Size(); });
|
||||||
|
|
||||||
IndexResult out{move(sub_results[0])};
|
IndexResult out{std::move(sub_results[0])};
|
||||||
for (auto& matched : absl::MakeSpan(sub_results).subspan(1))
|
for (auto& matched : absl::MakeSpan(sub_results).subspan(1))
|
||||||
Merge(move(matched), &out, op);
|
Merge(move(matched), &out, op);
|
||||||
return out;
|
return out;
|
||||||
|
@ -206,7 +218,7 @@ struct BasicSearch {
|
||||||
|
|
||||||
IndexResult Search(const AstStarNode& node, string_view active_field) {
|
IndexResult Search(const AstStarNode& node, string_view active_field) {
|
||||||
DCHECK(active_field.empty());
|
DCHECK(active_field.empty());
|
||||||
return vector<DocId>{indices_->GetAllDocs()}; // TODO FIX;
|
return {&indices_->GetAllDocs()};
|
||||||
}
|
}
|
||||||
|
|
||||||
// "term": access field's text index or unify results from all text indices if no field is set
|
// "term": access field's text index or unify results from all text indices if no field is set
|
||||||
|
@ -268,19 +280,23 @@ struct BasicSearch {
|
||||||
auto sub_results = SearchGeneric(*knn.filter, active_field);
|
auto sub_results = SearchGeneric(*knn.filter, active_field);
|
||||||
|
|
||||||
auto* vec_index = GetIndex<VectorIndex>(knn.field);
|
auto* vec_index = GetIndex<VectorIndex>(knn.field);
|
||||||
|
if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second)
|
||||||
|
return IndexResult{};
|
||||||
|
|
||||||
distances_.reserve(sub_results.Size());
|
distances_.reserve(sub_results.Size());
|
||||||
auto cb = [&](auto* set) {
|
auto cb = [&](auto* set) {
|
||||||
|
auto [dim, sim] = vec_index->Info();
|
||||||
for (DocId matched_doc : *set) {
|
for (DocId matched_doc : *set) {
|
||||||
float dist = VectorDistance(knn.vector, vec_index->Get(matched_doc));
|
float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim);
|
||||||
distances_.emplace_back(dist, matched_doc);
|
distances_.emplace_back(dist, matched_doc);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
visit(cb, sub_results.Borrowed());
|
visit(cb, sub_results.Borrowed());
|
||||||
|
|
||||||
sort(distances_.begin(), distances_.end());
|
size_t prefix_size = min(knn.limit, distances_.size());
|
||||||
|
partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end());
|
||||||
|
|
||||||
vector<DocId> out(min(knn.limit, distances_.size()));
|
vector<DocId> out(prefix_size);
|
||||||
for (size_t i = 0; i < out.size(); i++)
|
for (size_t i = 0; i < out.size(); i++)
|
||||||
out[i] = distances_[i].second;
|
out[i] = distances_[i].second;
|
||||||
|
|
||||||
|
@ -331,6 +347,8 @@ struct BasicSearch {
|
||||||
vector<pair<float, DocId>> distances_;
|
vector<pair<float, DocId>> distances_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
|
FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} {
|
||||||
|
@ -346,7 +364,7 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i
|
||||||
indices_[field_ident] = make_unique<NumericIndex>();
|
indices_[field_ident] = make_unique<NumericIndex>();
|
||||||
break;
|
break;
|
||||||
case SchemaField::VECTOR:
|
case SchemaField::VECTOR:
|
||||||
indices_[field_ident] = make_unique<VectorIndex>();
|
indices_[field_ident] = make_unique<VectorIndex>(field_info.knn_dim, field_info.knn_sim);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,9 @@ struct SchemaField {
|
||||||
|
|
||||||
FieldType type;
|
FieldType type;
|
||||||
std::string short_name; // equal to ident if none provided
|
std::string short_name; // equal to ident if none provided
|
||||||
|
|
||||||
|
size_t knn_dim = 0u; // dimension of knn vectors
|
||||||
|
VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type
|
||||||
};
|
};
|
||||||
|
|
||||||
// Describes the fields of an index
|
// Describes the fields of an index
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include "base/logging.h"
|
#include "base/logging.h"
|
||||||
#include "core/search/base.h"
|
#include "core/search/base.h"
|
||||||
#include "core/search/query_driver.h"
|
#include "core/search/query_driver.h"
|
||||||
|
#include "core/search/vector_utils.h"
|
||||||
|
|
||||||
namespace dfly {
|
namespace dfly {
|
||||||
namespace search {
|
namespace search {
|
||||||
|
@ -40,15 +41,8 @@ struct MockedDocument : public DocumentAccessor {
|
||||||
return it != fields_.end() ? string_view{it->second} : "";
|
return it != fields_.end() ? string_view{it->second} : "";
|
||||||
}
|
}
|
||||||
|
|
||||||
FtVector GetVector(string_view field) const override {
|
VectorInfo GetVector(string_view field) const override {
|
||||||
string_view str_value = fields_.at(field);
|
return BytesToFtVector(GetString(field));
|
||||||
FtVector out;
|
|
||||||
for (string_view coord : absl::StrSplit(str_value, ',')) {
|
|
||||||
float v;
|
|
||||||
CHECK(absl::SimpleAtof(coord, &v));
|
|
||||||
out.push_back(v);
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
string DebugFormat() {
|
string DebugFormat() {
|
||||||
|
@ -331,17 +325,18 @@ TEST_F(SearchParserTest, IntegerTerms) {
|
||||||
EXPECT_TRUE(Check()) << GetError();
|
EXPECT_TRUE(Check()) << GetError();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string FtVectorToBytes(FtVector vec) {
|
std::string ToBytes(absl::Span<const float> vec) {
|
||||||
return string{reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size()};
|
return string{reinterpret_cast<const char*>(vec.data()), sizeof(float) * vec.size()};
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SearchParserTest, SimpleKnn) {
|
TEST_F(SearchParserTest, SimpleKnn) {
|
||||||
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
|
auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}});
|
||||||
|
schema.fields["pos"].knn_dim = 1;
|
||||||
FieldIndices indices{schema};
|
FieldIndices indices{schema};
|
||||||
|
|
||||||
// Place points on a straight line
|
// Place points on a straight line
|
||||||
for (size_t i = 0; i < 100; i++) {
|
for (size_t i = 0; i < 100; i++) {
|
||||||
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", to_string(float(i))}}};
|
Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}};
|
||||||
MockedDocument doc{values};
|
MockedDocument doc{values};
|
||||||
indices.Add(i, &doc);
|
indices.Add(i, &doc);
|
||||||
}
|
}
|
||||||
|
@ -351,35 +346,35 @@ TEST_F(SearchParserTest, SimpleKnn) {
|
||||||
|
|
||||||
// Five closest to 50
|
// Five closest to 50
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{50.0});
|
params["vec"] = ToBytes({50.0});
|
||||||
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Five closest to 0
|
// Five closest to 0
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0.0});
|
params["vec"] = ToBytes({0.0});
|
||||||
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
algo.Init("*=>[KNN 5 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Five closest to 20, all even
|
// Five closest to 20, all even
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{20.0});
|
params["vec"] = ToBytes({20.0});
|
||||||
algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", ¶ms);
|
algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Three closest to 31, all odd
|
// Three closest to 31, all odd
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{31.0});
|
params["vec"] = ToBytes({31.0});
|
||||||
algo.Init("@even:{no} =>[KNN 3 @pos $vec]", ¶ms);
|
algo.Init("@even:{no} =>[KNN 3 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Two closest to 70.5
|
// Two closest to 70.5
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{70.5});
|
params["vec"] = ToBytes({70.5});
|
||||||
algo.Init("* =>[KNN 2 @pos $vec]", ¶ms);
|
algo.Init("* =>[KNN 2 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71));
|
||||||
}
|
}
|
||||||
|
@ -393,11 +388,11 @@ TEST_F(SearchParserTest, Simple2dKnn) {
|
||||||
const pair<float, float> kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}};
|
const pair<float, float> kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}};
|
||||||
|
|
||||||
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
|
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
|
||||||
|
schema.fields["pos"].knn_dim = 2;
|
||||||
FieldIndices indices{schema};
|
FieldIndices indices{schema};
|
||||||
|
|
||||||
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
|
for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) {
|
||||||
auto [x, y] = kTestCoords[i];
|
string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second});
|
||||||
string coords = absl::StrCat(x, ",", y);
|
|
||||||
MockedDocument doc{Map{{"pos", coords}}};
|
MockedDocument doc{Map{{"pos", coords}}};
|
||||||
indices.Add(i, &doc);
|
indices.Add(i, &doc);
|
||||||
}
|
}
|
||||||
|
@ -407,47 +402,83 @@ TEST_F(SearchParserTest, Simple2dKnn) {
|
||||||
|
|
||||||
// Single center
|
// Single center
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0.5, 0.5});
|
params["vec"] = ToBytes({0.5, 0.5});
|
||||||
algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
|
algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lower left
|
// Lower left
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0, 0});
|
params["vec"] = ToBytes({0, 0});
|
||||||
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upper right
|
// Upper right
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{1, 1});
|
params["vec"] = ToBytes({1, 1});
|
||||||
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
algo.Init("* =>[KNN 4 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Request more than there is
|
// Request more than there is
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0, 0});
|
params["vec"] = ToBytes({0, 0});
|
||||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test correct order: (0.7, 0.15)
|
// Test correct order: (0.7, 0.15)
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0.7, 0.15});
|
params["vec"] = ToBytes({0.7, 0.15});
|
||||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test correct order: (0.8, 0.9)
|
// Test correct order: (0.8, 0.9)
|
||||||
{
|
{
|
||||||
params["vec"] = FtVectorToBytes(FtVector{0.8, 0.9});
|
params["vec"] = ToBytes({0.8, 0.9});
|
||||||
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
algo.Init("* => [KNN 10 @pos $vec]", ¶ms);
|
||||||
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
|
EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void BM_VectorSearch(benchmark::State& state) {
|
||||||
|
unsigned ndims = state.range(0);
|
||||||
|
unsigned nvecs = state.range(1);
|
||||||
|
|
||||||
|
auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}});
|
||||||
|
schema.fields["pos"].knn_dim = ndims;
|
||||||
|
FieldIndices indices{schema};
|
||||||
|
|
||||||
|
auto random_vec = [ndims]() {
|
||||||
|
vector<float> coords;
|
||||||
|
for (size_t j = 0; j < ndims; j++)
|
||||||
|
coords.push_back(static_cast<float>(rand()) / static_cast<float>(RAND_MAX));
|
||||||
|
return coords;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t i = 0; i < nvecs; i++) {
|
||||||
|
auto rv = random_vec();
|
||||||
|
MockedDocument doc{Map{{"pos", ToBytes(rv)}}};
|
||||||
|
indices.Add(i, &doc);
|
||||||
|
}
|
||||||
|
|
||||||
|
SearchAlgorithm algo{};
|
||||||
|
QueryParams params;
|
||||||
|
|
||||||
|
auto rv = random_vec();
|
||||||
|
params["vec"] = ToBytes(rv);
|
||||||
|
algo.Init("* =>[KNN 1 @pos $vec]", ¶ms);
|
||||||
|
|
||||||
|
while (state.KeepRunningBatch(10)) {
|
||||||
|
for (size_t i = 0; i < 10; i++)
|
||||||
|
benchmark::DoNotOptimize(algo.Search(&indices));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
BENCHMARK(BM_VectorSearch)->Args({120, 10'000});
|
||||||
|
|
||||||
} // namespace search
|
} // namespace search
|
||||||
|
|
||||||
} // namespace dfly
|
} // namespace dfly
|
||||||
|
|
|
@ -1,38 +0,0 @@
|
||||||
// Copyright 2023, DragonflyDB authors. All rights reserved.
|
|
||||||
// See LICENSE for licensing terms.
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "core/search/vector.h"
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "base/logging.h"
|
|
||||||
|
|
||||||
namespace dfly::search {
|
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
FtVector BytesToFtVector(string_view value) {
|
|
||||||
DCHECK_EQ(value.size() % sizeof(float), 0u);
|
|
||||||
FtVector out(value.size() / sizeof(float));
|
|
||||||
|
|
||||||
// Create copy for aligned access
|
|
||||||
unique_ptr<float[]> float_ptr = make_unique<float[]>(out.size());
|
|
||||||
memcpy(float_ptr.get(), value.data(), value.size());
|
|
||||||
|
|
||||||
for (size_t i = 0; i < out.size(); i++)
|
|
||||||
out[i] = float_ptr[i];
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 )
|
|
||||||
__attribute__((optimize("fast-math"))) float VectorDistance(const FtVector& u, const FtVector& v) {
|
|
||||||
DCHECK_EQ(u.size(), v.size());
|
|
||||||
float sum = 0;
|
|
||||||
for (size_t i = 0; i < u.size(); i++)
|
|
||||||
sum += (u[i] - v[i]) * (u[i] - v[i]);
|
|
||||||
return sqrt(sum);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace dfly::search
|
|
65
src/core/search/vector_utils.cc
Normal file
65
src/core/search/vector_utils.cc
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
// Copyright 2023, DragonflyDB authors. All rights reserved.
|
||||||
|
// See LICENSE for licensing terms.
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "core/search/vector_utils.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "base/logging.h"
|
||||||
|
|
||||||
|
namespace dfly::search {
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 )
|
||||||
|
__attribute__((optimize("fast-math"))) float L2Distance(const float* u, const float* v,
|
||||||
|
size_t dims) {
|
||||||
|
float sum = 0;
|
||||||
|
for (size_t i = 0; i < dims; i++)
|
||||||
|
sum += (u[i] - v[i]) * (u[i] - v[i]);
|
||||||
|
return sqrt(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
__attribute__((optimize("fast-math"))) float CosineDistance(const float* u, const float* v,
|
||||||
|
size_t dims) {
|
||||||
|
float sum_uv = 0, sum_uu = 0, sum_vv = 0;
|
||||||
|
for (size_t i = 0; i < dims; i++) {
|
||||||
|
sum_uv += u[i] * v[i];
|
||||||
|
sum_uu += u[i] * u[i];
|
||||||
|
sum_vv += v[i] * v[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (float denom = sum_uu * sum_vv; denom != 0.0f)
|
||||||
|
return sum_uv / sqrt(denom);
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
OwnedFtVector BytesToFtVector(string_view value) {
|
||||||
|
DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size();
|
||||||
|
|
||||||
|
// Value cannot be casted directly as it might be not aligned as a float (4 bytes).
|
||||||
|
// Misaligned memory access is UB.
|
||||||
|
size_t size = value.size() / sizeof(float);
|
||||||
|
auto out = make_unique<float[]>(size);
|
||||||
|
memcpy(out.get(), value.data(), size * sizeof(float));
|
||||||
|
|
||||||
|
return {std::move(out), size};
|
||||||
|
}
|
||||||
|
|
||||||
|
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) {
|
||||||
|
switch (sim) {
|
||||||
|
case VectorSimilarity::L2:
|
||||||
|
return L2Distance(u, v, dims);
|
||||||
|
case VectorSimilarity::COSINE:
|
||||||
|
return CosineDistance(u, v, dims);
|
||||||
|
};
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dfly::search
|
|
@ -8,8 +8,8 @@
|
||||||
|
|
||||||
namespace dfly::search {
|
namespace dfly::search {
|
||||||
|
|
||||||
FtVector BytesToFtVector(std::string_view value);
|
OwnedFtVector BytesToFtVector(std::string_view value);
|
||||||
|
|
||||||
float VectorDistance(const FtVector& v1, const FtVector& v2);
|
float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim);
|
||||||
|
|
||||||
} // namespace dfly::search
|
} // namespace dfly::search
|
|
@ -157,6 +157,10 @@ struct CmdArgParser {
|
||||||
return cur_i_ < args_.size() && !error_;
|
return cur_i_ < args_.size() && !error_;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HasError() {
|
||||||
|
return error_.has_value();
|
||||||
|
}
|
||||||
|
|
||||||
// Get optional error if occured
|
// Get optional error if occured
|
||||||
std::optional<ErrorInfo> Error() {
|
std::optional<ErrorInfo> Error() {
|
||||||
return std::exchange(error_, {});
|
return std::exchange(error_, {});
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
|
|
||||||
#include "core/json_object.h"
|
#include "core/json_object.h"
|
||||||
#include "core/search/search.h"
|
#include "core/search/search.h"
|
||||||
#include "core/search/vector.h"
|
#include "core/search/vector_utils.h"
|
||||||
#include "core/string_map.h"
|
#include "core/string_map.h"
|
||||||
#include "server/container_utils.h"
|
#include "server/container_utils.h"
|
||||||
|
|
||||||
|
@ -32,10 +32,11 @@ string_view SdsToSafeSv(sds str) {
|
||||||
}
|
}
|
||||||
|
|
||||||
string PrintField(search::SchemaField::FieldType type, string_view value) {
|
string PrintField(search::SchemaField::FieldType type, string_view value) {
|
||||||
if (type == search::SchemaField::VECTOR)
|
if (type == search::SchemaField::VECTOR) {
|
||||||
return absl::StrCat("[", absl::StrJoin(search::BytesToFtVector(value), ","), "]");
|
auto [ptr, size] = search::BytesToFtVector(value);
|
||||||
else
|
return absl::StrCat("[", absl::StrJoin(absl::Span<const float>{ptr.get(), size}, ","), "]");
|
||||||
return string{value};
|
}
|
||||||
|
return string{value};
|
||||||
}
|
}
|
||||||
|
|
||||||
string ExtractValue(const search::Schema& schema, string_view key, string_view value) {
|
string ExtractValue(const search::Schema& schema, string_view key, string_view value) {
|
||||||
|
@ -63,7 +64,7 @@ string_view ListPackAccessor::GetString(string_view active_field) const {
|
||||||
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
|
return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv);
|
||||||
}
|
}
|
||||||
|
|
||||||
search::FtVector ListPackAccessor::GetVector(string_view active_field) const {
|
BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const {
|
||||||
return search::BytesToFtVector(GetString(active_field));
|
return search::BytesToFtVector(GetString(active_field));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,7 +90,7 @@ string_view StringMapAccessor::GetString(string_view active_field) const {
|
||||||
return SdsToSafeSv(hset_->Find(active_field));
|
return SdsToSafeSv(hset_->Find(active_field));
|
||||||
}
|
}
|
||||||
|
|
||||||
search::FtVector StringMapAccessor::GetVector(string_view active_field) const {
|
BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const {
|
||||||
return search::BytesToFtVector(GetString(active_field));
|
return search::BytesToFtVector(GetString(active_field));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,16 +114,20 @@ string_view JsonAccessor::GetString(string_view active_field) const {
|
||||||
return buf_;
|
return buf_;
|
||||||
}
|
}
|
||||||
|
|
||||||
search::FtVector JsonAccessor::GetVector(string_view active_field) const {
|
BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const {
|
||||||
auto res = GetPath(active_field)->evaluate(json_);
|
auto res = GetPath(active_field)->evaluate(json_);
|
||||||
DCHECK(res.is_array());
|
DCHECK(res.is_array());
|
||||||
if (res.empty())
|
if (res.empty())
|
||||||
return {};
|
return {nullptr, 0};
|
||||||
|
|
||||||
search::FtVector out;
|
size_t size = res[0].size();
|
||||||
|
auto ptr = make_unique<float[]>(size);
|
||||||
|
|
||||||
|
size_t i = 0;
|
||||||
for (auto v : res[0].array_range())
|
for (auto v : res[0].array_range())
|
||||||
out.push_back(v.as<float>());
|
ptr[i++] = v.as<float>();
|
||||||
return out;
|
|
||||||
|
return {std::move(ptr), size};
|
||||||
}
|
}
|
||||||
|
|
||||||
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
|
JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const {
|
||||||
|
|
|
@ -40,7 +40,7 @@ struct ListPackAccessor : public BaseAccessor {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view GetString(std::string_view field) const override;
|
std::string_view GetString(std::string_view field) const override;
|
||||||
search::FtVector GetVector(std::string_view field) const override;
|
VectorInfo GetVector(std::string_view field) const override;
|
||||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -54,7 +54,7 @@ struct StringMapAccessor : public BaseAccessor {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view GetString(std::string_view field) const override;
|
std::string_view GetString(std::string_view field) const override;
|
||||||
search::FtVector GetVector(std::string_view field) const override;
|
VectorInfo GetVector(std::string_view field) const override;
|
||||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -69,7 +69,7 @@ struct JsonAccessor : public BaseAccessor {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string_view GetString(std::string_view field) const override;
|
std::string_view GetString(std::string_view field) const override;
|
||||||
search::FtVector GetVector(std::string_view field) const override;
|
VectorInfo GetVector(std::string_view field) const override;
|
||||||
SearchDocData Serialize(const search::Schema& schema) const override;
|
SearchDocData Serialize(const search::Schema& schema) const override;
|
||||||
|
|
||||||
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
|
// The JsonAccessor works with structured types and not plain strings, so an overload is needed
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "base/logging.h"
|
#include "base/logging.h"
|
||||||
#include "core/json_object.h"
|
#include "core/json_object.h"
|
||||||
#include "core/search/search.h"
|
#include "core/search/search.h"
|
||||||
#include "core/search/vector.h"
|
#include "core/search/vector_utils.h"
|
||||||
#include "facade/cmd_arg_parser.h"
|
#include "facade/cmd_arg_parser.h"
|
||||||
#include "facade/error.h"
|
#include "facade/error.h"
|
||||||
#include "facade/reply_builder.h"
|
#include "facade/reply_builder.h"
|
||||||
|
@ -46,6 +46,30 @@ bool IsValidJsonPath(string_view path) {
|
||||||
return !ec;
|
return !ec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pair<size_t, search::VectorSimilarity> ParseVectorFieldInfo(CmdArgParser* parser,
|
||||||
|
ConnectionContext* cntx) {
|
||||||
|
size_t dim = 0;
|
||||||
|
search::VectorSimilarity sim = search::VectorSimilarity::L2;
|
||||||
|
|
||||||
|
size_t num_args = parser->Next().Int<size_t>();
|
||||||
|
for (size_t i = 0; i * 2 < num_args; i++) {
|
||||||
|
parser->ToUpper();
|
||||||
|
if (parser->Check("DIM").ExpectTail(1)) {
|
||||||
|
dim = parser->Next().Int<size_t>();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) {
|
||||||
|
sim = parser->Next()
|
||||||
|
.Case("L2", search::VectorSimilarity::L2)
|
||||||
|
.Case("COSINE", search::VectorSimilarity::COSINE);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
parser->Skip(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
return {dim, sim};
|
||||||
|
}
|
||||||
|
|
||||||
optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParser parser,
|
optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParser parser,
|
||||||
ConnectionContext* cntx) {
|
ConnectionContext* cntx) {
|
||||||
search::Schema schema;
|
search::Schema schema;
|
||||||
|
@ -74,15 +98,24 @@ optional<search::Schema> ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse
|
||||||
return nullopt;
|
return nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip {algorithm} {dim} flags
|
// Vector fields include: {algorithm} num_args args...
|
||||||
if (*type == search::SchemaField::VECTOR)
|
size_t knn_dim = 0;
|
||||||
parser.Skip(2);
|
search::VectorSimilarity knn_sim = search::VectorSimilarity::L2;
|
||||||
|
if (*type == search::SchemaField::VECTOR) {
|
||||||
|
parser.Skip(1); // algorithm
|
||||||
|
std::tie(knn_dim, knn_sim) = ParseVectorFieldInfo(&parser, cntx);
|
||||||
|
|
||||||
|
if (!parser.HasError() && knn_dim == 0) {
|
||||||
|
(*cntx)->SendError("Vector dimension cannot be zero");
|
||||||
|
return nullopt;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Skip all trailing ignored parameters
|
// Skip all trailing ignored parameters
|
||||||
while (kIgnoredOptions.count(parser.Peek()) > 0)
|
while (kIgnoredOptions.count(parser.Peek()) > 0)
|
||||||
parser.Skip(2);
|
parser.Skip(2);
|
||||||
|
|
||||||
schema.fields[field] = {*type, string{field_alias}};
|
schema.fields[field] = {*type, string{field_alias}, knn_dim, knn_sim};
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build field name mapping table
|
// Build field name mapping table
|
||||||
|
@ -208,8 +241,9 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Span<SearchRes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
partial_sort(docs.begin(),
|
size_t prefix = min(params.limit_offset + params.limit_total, knn_limit);
|
||||||
docs.begin() + min(params.limit_offset + params.limit_total, knn_limit), docs.end(),
|
|
||||||
|
partial_sort(docs.begin(), docs.begin() + min(docs.size(), prefix), docs.end(),
|
||||||
[](const auto* l, const auto* r) { return l->knn_distance < r->knn_distance; });
|
[](const auto* l, const auto* r) { return l->knn_distance < r->knn_distance; });
|
||||||
docs.resize(min(docs.size(), knn_limit));
|
docs.resize(min(docs.size(), knn_limit));
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue