diff --git a/src/core/search/CMakeLists.txt b/src/core/search/CMakeLists.txt index a68c73bbe..bbbe89e53 100644 --- a/src/core/search/CMakeLists.txt +++ b/src/core/search/CMakeLists.txt @@ -5,7 +5,7 @@ cur_gen_dir(gen_dir) find_package(ICU REQUIRED COMPONENTS uc i18n) -add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector.cc compressed_sorted_set.cc +add_library(query_parser ast_expr.cc query_driver.cc search.cc indices.cc vector_utils.cc compressed_sorted_set.cc ${gen_dir}/parser.cc ${gen_dir}/lexer.cc) target_link_libraries(query_parser ICU::uc ICU::i18n) diff --git a/src/core/search/ast_expr.cc b/src/core/search/ast_expr.cc index 57208a5ca..86c763d97 100644 --- a/src/core/search/ast_expr.cc +++ b/src/core/search/ast_expr.cc @@ -56,9 +56,11 @@ AstTagsNode::AstTagsNode(AstExpr&& l, std::string tag) { tags.push_back(move(tag)); } -AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string field, FtVector vec) - : filter{make_unique(move(filter))}, limit{limit}, field{field.substr(1)}, vector{move( - vec)} { +AstKnnNode::AstKnnNode(AstNode&& filter, size_t limit, std::string_view field, OwnedFtVector vec) + : filter{make_unique(std::move(filter))}, + limit{limit}, + field{field.substr(1)}, + vec{std::move(vec)} { } } // namespace dfly::search diff --git a/src/core/search/ast_expr.h b/src/core/search/ast_expr.h index 7ff81b1f7..1fb1e8401 100644 --- a/src/core/search/ast_expr.h +++ b/src/core/search/ast_expr.h @@ -74,12 +74,12 @@ struct AstTagsNode { // Applies nearest neighbor search to the final result set struct AstKnnNode { - AstKnnNode(AstNode&& sub, size_t limit, std::string field, FtVector vec); + AstKnnNode(AstNode&& sub, size_t limit, std::string_view field, OwnedFtVector vec); std::unique_ptr filter; size_t limit; std::string field; - FtVector vector; + OwnedFtVector vec; }; using NodeVariants = diff --git a/src/core/search/base.h b/src/core/search/base.h index 41f8f554e..79d2e6b18 100644 --- a/src/core/search/base.h +++ b/src/core/search/base.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -14,7 +15,9 @@ namespace dfly::search { using DocId = uint32_t; -using FtVector = std::vector; +enum class VectorSimilarity { L2, COSINE }; + +using OwnedFtVector = std::pair, size_t /* dimension (size) */>; // Query params represent named parameters for queries supplied via PARAMS. struct QueryParams { @@ -38,9 +41,11 @@ struct QueryParams { // Interface for accessing document values with different data structures underneath. struct DocumentAccessor { + using VectorInfo = search::OwnedFtVector; + virtual ~DocumentAccessor() = default; virtual std::string_view GetString(std::string_view active_field) const = 0; - virtual FtVector GetVector(std::string_view active_field) const = 0; + virtual VectorInfo GetVector(std::string_view active_field) const = 0; }; // Base class for type-specific indices. diff --git a/src/core/search/compressed_sorted_set.h b/src/core/search/compressed_sorted_set.h index 0ae429ab9..927fa014c 100644 --- a/src/core/search/compressed_sorted_set.h +++ b/src/core/search/compressed_sorted_set.h @@ -59,6 +59,11 @@ class CompressedSortedSet { size_t Size() const; size_t ByteSize() const; + // To use transparently in templates together with stl containers + size_t size() const { + return Size(); + } + private: struct EntryLocation { IntType value; // Value or 0 diff --git a/src/core/search/indices.cc b/src/core/search/indices.cc index d45a939d9..032e03fe4 100644 --- a/src/core/search/indices.cc +++ b/src/core/search/indices.cc @@ -151,17 +151,31 @@ absl::flat_hash_set TagIndex::Tokenize(std::string_view value) cons return NormalizeTags(value); } +VectorIndex::VectorIndex(size_t dim, VectorSimilarity sim) : dim_{dim}, sim_{sim}, entries_{} { +} + void VectorIndex::Add(DocId id, DocumentAccessor* doc, string_view field) { - entries_[id] = doc->GetVector(field); + DCHECK_LE(id * dim_, entries_.size()); + if (id * dim_ == entries_.size()) + entries_.resize((id + 1) * dim_); + + // TODO: Let get vector write to buf itself + auto [ptr, size] = doc->GetVector(field); + + if (size == dim_) + memcpy(&entries_[id * dim_], ptr.get(), dim_ * sizeof(float)); } void VectorIndex::Remove(DocId id, DocumentAccessor* doc, string_view field) { - entries_.erase(id); + // noop } -FtVector VectorIndex::Get(DocId doc) const { - auto it = entries_.find(doc); - return it != entries_.end() ? it->second : FtVector{}; +const float* VectorIndex::Get(DocId doc) const { + return &entries_[doc * dim_]; +} + +std::pair VectorIndex::Info() const { + return {dim_, sim_}; } } // namespace dfly::search diff --git a/src/core/search/indices.h b/src/core/search/indices.h index 1b1bea42c..ca76fc3cb 100644 --- a/src/core/search/indices.h +++ b/src/core/search/indices.h @@ -57,13 +57,18 @@ struct TagIndex : public BaseStringIndex { // Index for vector fields. // Only supports lookup by id. struct VectorIndex : public BaseIndex { + VectorIndex(size_t dim, VectorSimilarity sim); + void Add(DocId id, DocumentAccessor* doc, std::string_view field) override; void Remove(DocId id, DocumentAccessor* doc, std::string_view field) override; - FtVector Get(DocId doc) const; + const float* Get(DocId doc) const; + std::pair Info() const; private: - absl::flat_hash_map entries_; + size_t dim_; + VectorSimilarity sim_; + std::vector entries_; }; } // namespace dfly::search diff --git a/src/core/search/parser.y b/src/core/search/parser.y index 36e270a39..86d24fa2a 100644 --- a/src/core/search/parser.y +++ b/src/core/search/parser.y @@ -25,7 +25,7 @@ // Added to cc file %code { #include "core/search/query_driver.h" -#include "core/search/vector.h" +#include "core/search/vector_utils.h" // Have to disable because GCC doesn't understand `symbol_type`'s union // implementation diff --git a/src/core/search/search.cc b/src/core/search/search.cc index 01c61fe73..06ea23904 100644 --- a/src/core/search/search.cc +++ b/src/core/search/search.cc @@ -10,6 +10,7 @@ #include #include +#include #include #include "base/logging.h" @@ -18,7 +19,7 @@ #include "core/search/compressed_sorted_set.h" #include "core/search/indices.h" #include "core/search/query_driver.h" -#include "core/search/vector.h" +#include "core/search/vector_utils.h" using namespace std; @@ -35,11 +36,18 @@ AstExpr ParseQuery(std::string_view query, const QueryParams* params) { return driver.Take(); } +// GCC 12 yields a wrong warning in a deeply inlined call in UnifyResults, only ignoring the whole +// scope solves it +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" + // Represents an either owned or non-owned result set that can be accessed transparently. struct IndexResult { using DocVec = vector; + using BorrowedView = variant; - IndexResult() : value_{DocVec{}} {}; + IndexResult() : value_{DocVec{}} { + } IndexResult(const CompressedSortedSet* css) : value_{css} { if (css == nullptr) @@ -49,10 +57,11 @@ struct IndexResult { IndexResult(DocVec&& dv) : value_{move(dv)} { } + IndexResult(const DocVec* dv) : value_{dv} { + } + size_t Size() const { - if (holds_alternative(value_)) - return get(value_).size(); - return get(value_)->Size(); + return visit([](auto* set) { return set->size(); }, Borrowed()); } bool IsOwned() const { @@ -64,28 +73,31 @@ struct IndexResult { swap(get(value_), entries); // swap to keep backing array entries.clear(); } else { - value_ = move(entries); + value_ = std::move(entries); } return *this; } - variant Borrowed() { - if (holds_alternative(value_)) - return &get(value_); - return get(value_); + BorrowedView Borrowed() const { + auto cb = [](const auto& v) -> BorrowedView { + if constexpr (is_pointer_v>) + return v; + else + return &v; + }; + return visit(cb, value_); } // Move out of owned or copy borrowed DocVec Take() { - if (holds_alternative(value_)) + if (IsOwned()) return move(get(value_)); - const CompressedSortedSet* css = get(value_); - return DocVec(css->begin(), css->end()); + return visit([](auto* set) { return DocVec(set->begin(), set->end()); }, Borrowed()); } private: - variant value_; + variant value_; }; struct ProfileBuilder { @@ -194,7 +206,7 @@ struct BasicSearch { sort(sub_results.begin(), sub_results.end(), [](const auto& l, const auto& r) { return l.Size() < r.Size(); }); - IndexResult out{move(sub_results[0])}; + IndexResult out{std::move(sub_results[0])}; for (auto& matched : absl::MakeSpan(sub_results).subspan(1)) Merge(move(matched), &out, op); return out; @@ -206,7 +218,7 @@ struct BasicSearch { IndexResult Search(const AstStarNode& node, string_view active_field) { DCHECK(active_field.empty()); - return vector{indices_->GetAllDocs()}; // TODO FIX; + return {&indices_->GetAllDocs()}; } // "term": access field's text index or unify results from all text indices if no field is set @@ -268,19 +280,23 @@ struct BasicSearch { auto sub_results = SearchGeneric(*knn.filter, active_field); auto* vec_index = GetIndex(knn.field); + if (auto [dim, _] = vec_index->Info(); dim != knn.vec.second) + return IndexResult{}; distances_.reserve(sub_results.Size()); auto cb = [&](auto* set) { + auto [dim, sim] = vec_index->Info(); for (DocId matched_doc : *set) { - float dist = VectorDistance(knn.vector, vec_index->Get(matched_doc)); + float dist = VectorDistance(knn.vec.first.get(), vec_index->Get(matched_doc), dim, sim); distances_.emplace_back(dist, matched_doc); } }; visit(cb, sub_results.Borrowed()); - sort(distances_.begin(), distances_.end()); + size_t prefix_size = min(knn.limit, distances_.size()); + partial_sort(distances_.begin(), distances_.begin() + prefix_size, distances_.end()); - vector out(min(knn.limit, distances_.size())); + vector out(prefix_size); for (size_t i = 0; i < out.size(); i++) out[i] = distances_[i].second; @@ -331,6 +347,8 @@ struct BasicSearch { vector> distances_; }; +#pragma GCC diagnostic pop + } // namespace FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, indices_{} { @@ -346,7 +364,7 @@ FieldIndices::FieldIndices(Schema schema) : schema_{move(schema)}, all_ids_{}, i indices_[field_ident] = make_unique(); break; case SchemaField::VECTOR: - indices_[field_ident] = make_unique(); + indices_[field_ident] = make_unique(field_info.knn_dim, field_info.knn_sim); break; } } diff --git a/src/core/search/search.h b/src/core/search/search.h index b8118df2e..8b1ad15e6 100644 --- a/src/core/search/search.h +++ b/src/core/search/search.h @@ -25,6 +25,9 @@ struct SchemaField { FieldType type; std::string short_name; // equal to ident if none provided + + size_t knn_dim = 0u; // dimension of knn vectors + VectorSimilarity knn_sim = VectorSimilarity::L2; // similarity type }; // Describes the fields of an index diff --git a/src/core/search/search_test.cc b/src/core/search/search_test.cc index ba60bfb20..b995d3bca 100644 --- a/src/core/search/search_test.cc +++ b/src/core/search/search_test.cc @@ -19,6 +19,7 @@ #include "base/logging.h" #include "core/search/base.h" #include "core/search/query_driver.h" +#include "core/search/vector_utils.h" namespace dfly { namespace search { @@ -40,15 +41,8 @@ struct MockedDocument : public DocumentAccessor { return it != fields_.end() ? string_view{it->second} : ""; } - FtVector GetVector(string_view field) const override { - string_view str_value = fields_.at(field); - FtVector out; - for (string_view coord : absl::StrSplit(str_value, ',')) { - float v; - CHECK(absl::SimpleAtof(coord, &v)); - out.push_back(v); - } - return out; + VectorInfo GetVector(string_view field) const override { + return BytesToFtVector(GetString(field)); } string DebugFormat() { @@ -331,17 +325,18 @@ TEST_F(SearchParserTest, IntegerTerms) { EXPECT_TRUE(Check()) << GetError(); } -std::string FtVectorToBytes(FtVector vec) { +std::string ToBytes(absl::Span vec) { return string{reinterpret_cast(vec.data()), sizeof(float) * vec.size()}; } TEST_F(SearchParserTest, SimpleKnn) { auto schema = MakeSimpleSchema({{"even", SchemaField::TAG}, {"pos", SchemaField::VECTOR}}); + schema.fields["pos"].knn_dim = 1; FieldIndices indices{schema}; // Place points on a straight line for (size_t i = 0; i < 100; i++) { - Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", to_string(float(i))}}}; + Map values{{{"even", i % 2 == 0 ? "YES" : "NO"}, {"pos", ToBytes({float(i)})}}}; MockedDocument doc{values}; indices.Add(i, &doc); } @@ -351,35 +346,35 @@ TEST_F(SearchParserTest, SimpleKnn) { // Five closest to 50 { - params["vec"] = FtVectorToBytes(FtVector{50.0}); + params["vec"] = ToBytes({50.0}); algo.Init("*=>[KNN 5 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(48, 49, 50, 51, 52)); } // Five closest to 0 { - params["vec"] = FtVectorToBytes(FtVector{0.0}); + params["vec"] = ToBytes({0.0}); algo.Init("*=>[KNN 5 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4)); } // Five closest to 20, all even { - params["vec"] = FtVectorToBytes(FtVector{20.0}); + params["vec"] = ToBytes({20.0}); algo.Init("@even:{yes} =>[KNN 5 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(16, 18, 20, 22, 24)); } // Three closest to 31, all odd { - params["vec"] = FtVectorToBytes(FtVector{31.0}); + params["vec"] = ToBytes({31.0}); algo.Init("@even:{no} =>[KNN 3 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(29, 31, 33)); } // Two closest to 70.5 { - params["vec"] = FtVectorToBytes(FtVector{70.5}); + params["vec"] = ToBytes({70.5}); algo.Init("* =>[KNN 2 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(70, 71)); } @@ -393,11 +388,11 @@ TEST_F(SearchParserTest, Simple2dKnn) { const pair kTestCoords[] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}, {0.5, 0.5}}; auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}}); + schema.fields["pos"].knn_dim = 2; FieldIndices indices{schema}; for (size_t i = 0; i < ABSL_ARRAYSIZE(kTestCoords); i++) { - auto [x, y] = kTestCoords[i]; - string coords = absl::StrCat(x, ",", y); + string coords = ToBytes({kTestCoords[i].first, kTestCoords[i].second}); MockedDocument doc{Map{{"pos", coords}}}; indices.Add(i, &doc); } @@ -407,47 +402,83 @@ TEST_F(SearchParserTest, Simple2dKnn) { // Single center { - params["vec"] = FtVectorToBytes(FtVector{0.5, 0.5}); + params["vec"] = ToBytes({0.5, 0.5}); algo.Init("* =>[KNN 1 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(4)); } // Lower left { - params["vec"] = FtVectorToBytes(FtVector{0, 0}); + params["vec"] = ToBytes({0, 0}); algo.Init("* =>[KNN 4 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 3, 4)); } // Upper right { - params["vec"] = FtVectorToBytes(FtVector{1, 1}); + params["vec"] = ToBytes({1, 1}); algo.Init("* =>[KNN 4 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(1, 2, 3, 4)); } // Request more than there is { - params["vec"] = FtVectorToBytes(FtVector{0, 0}); + params["vec"] = ToBytes({0, 0}); algo.Init("* => [KNN 10 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::UnorderedElementsAre(0, 1, 2, 3, 4)); } // Test correct order: (0.7, 0.15) { - params["vec"] = FtVectorToBytes(FtVector{0.7, 0.15}); + params["vec"] = ToBytes({0.7, 0.15}); algo.Init("* => [KNN 10 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(1, 4, 0, 2, 3)); } // Test correct order: (0.8, 0.9) { - params["vec"] = FtVectorToBytes(FtVector{0.8, 0.9}); + params["vec"] = ToBytes({0.8, 0.9}); algo.Init("* => [KNN 10 @pos $vec]", ¶ms); EXPECT_THAT(algo.Search(&indices).ids, testing::ElementsAre(2, 4, 3, 1, 0)); } } +static void BM_VectorSearch(benchmark::State& state) { + unsigned ndims = state.range(0); + unsigned nvecs = state.range(1); + + auto schema = MakeSimpleSchema({{"pos", SchemaField::VECTOR}}); + schema.fields["pos"].knn_dim = ndims; + FieldIndices indices{schema}; + + auto random_vec = [ndims]() { + vector coords; + for (size_t j = 0; j < ndims; j++) + coords.push_back(static_cast(rand()) / static_cast(RAND_MAX)); + return coords; + }; + + for (size_t i = 0; i < nvecs; i++) { + auto rv = random_vec(); + MockedDocument doc{Map{{"pos", ToBytes(rv)}}}; + indices.Add(i, &doc); + } + + SearchAlgorithm algo{}; + QueryParams params; + + auto rv = random_vec(); + params["vec"] = ToBytes(rv); + algo.Init("* =>[KNN 1 @pos $vec]", ¶ms); + + while (state.KeepRunningBatch(10)) { + for (size_t i = 0; i < 10; i++) + benchmark::DoNotOptimize(algo.Search(&indices)); + } +} + +BENCHMARK(BM_VectorSearch)->Args({120, 10'000}); + } // namespace search } // namespace dfly diff --git a/src/core/search/vector.cc b/src/core/search/vector.cc deleted file mode 100644 index a768c215e..000000000 --- a/src/core/search/vector.cc +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2023, DragonflyDB authors. All rights reserved. -// See LICENSE for licensing terms. -// - -#include "core/search/vector.h" - -#include -#include - -#include "base/logging.h" - -namespace dfly::search { - -using namespace std; - -FtVector BytesToFtVector(string_view value) { - DCHECK_EQ(value.size() % sizeof(float), 0u); - FtVector out(value.size() / sizeof(float)); - - // Create copy for aligned access - unique_ptr float_ptr = make_unique(out.size()); - memcpy(float_ptr.get(), value.data(), value.size()); - - for (size_t i = 0; i < out.size(); i++) - out[i] = float_ptr[i]; - return out; -} - -// Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 ) -__attribute__((optimize("fast-math"))) float VectorDistance(const FtVector& u, const FtVector& v) { - DCHECK_EQ(u.size(), v.size()); - float sum = 0; - for (size_t i = 0; i < u.size(); i++) - sum += (u[i] - v[i]) * (u[i] - v[i]); - return sqrt(sum); -} - -} // namespace dfly::search diff --git a/src/core/search/vector_utils.cc b/src/core/search/vector_utils.cc new file mode 100644 index 000000000..c0a4ac181 --- /dev/null +++ b/src/core/search/vector_utils.cc @@ -0,0 +1,65 @@ +// Copyright 2023, DragonflyDB authors. All rights reserved. +// See LICENSE for licensing terms. +// + +#include "core/search/vector_utils.h" + +#include +#include + +#include "base/logging.h" + +namespace dfly::search { + +using namespace std; + +namespace { + +// Euclidean vector distance: sqrt( sum: (u[i] - v[i])^2 ) +__attribute__((optimize("fast-math"))) float L2Distance(const float* u, const float* v, + size_t dims) { + float sum = 0; + for (size_t i = 0; i < dims; i++) + sum += (u[i] - v[i]) * (u[i] - v[i]); + return sqrt(sum); +} + +__attribute__((optimize("fast-math"))) float CosineDistance(const float* u, const float* v, + size_t dims) { + float sum_uv = 0, sum_uu = 0, sum_vv = 0; + for (size_t i = 0; i < dims; i++) { + sum_uv += u[i] * v[i]; + sum_uu += u[i] * u[i]; + sum_vv += v[i] * v[i]; + } + + if (float denom = sum_uu * sum_vv; denom != 0.0f) + return sum_uv / sqrt(denom); + return 0.0f; +} + +} // namespace + +OwnedFtVector BytesToFtVector(string_view value) { + DCHECK_EQ(value.size() % sizeof(float), 0u) << value.size(); + + // Value cannot be casted directly as it might be not aligned as a float (4 bytes). + // Misaligned memory access is UB. + size_t size = value.size() / sizeof(float); + auto out = make_unique(size); + memcpy(out.get(), value.data(), size * sizeof(float)); + + return {std::move(out), size}; +} + +float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim) { + switch (sim) { + case VectorSimilarity::L2: + return L2Distance(u, v, dims); + case VectorSimilarity::COSINE: + return CosineDistance(u, v, dims); + }; + return 0.0f; +} + +} // namespace dfly::search diff --git a/src/core/search/vector.h b/src/core/search/vector_utils.h similarity index 58% rename from src/core/search/vector.h rename to src/core/search/vector_utils.h index f2e7e558b..ea19db478 100644 --- a/src/core/search/vector.h +++ b/src/core/search/vector_utils.h @@ -8,8 +8,8 @@ namespace dfly::search { -FtVector BytesToFtVector(std::string_view value); +OwnedFtVector BytesToFtVector(std::string_view value); -float VectorDistance(const FtVector& v1, const FtVector& v2); +float VectorDistance(const float* u, const float* v, size_t dims, VectorSimilarity sim); } // namespace dfly::search diff --git a/src/facade/cmd_arg_parser.h b/src/facade/cmd_arg_parser.h index 174718831..986a87b0c 100644 --- a/src/facade/cmd_arg_parser.h +++ b/src/facade/cmd_arg_parser.h @@ -157,6 +157,10 @@ struct CmdArgParser { return cur_i_ < args_.size() && !error_; } + bool HasError() { + return error_.has_value(); + } + // Get optional error if occured std::optional Error() { return std::exchange(error_, {}); diff --git a/src/server/search/doc_accessors.cc b/src/server/search/doc_accessors.cc index e8d15296d..cd251ba47 100644 --- a/src/server/search/doc_accessors.cc +++ b/src/server/search/doc_accessors.cc @@ -12,7 +12,7 @@ #include "core/json_object.h" #include "core/search/search.h" -#include "core/search/vector.h" +#include "core/search/vector_utils.h" #include "core/string_map.h" #include "server/container_utils.h" @@ -32,10 +32,11 @@ string_view SdsToSafeSv(sds str) { } string PrintField(search::SchemaField::FieldType type, string_view value) { - if (type == search::SchemaField::VECTOR) - return absl::StrCat("[", absl::StrJoin(search::BytesToFtVector(value), ","), "]"); - else - return string{value}; + if (type == search::SchemaField::VECTOR) { + auto [ptr, size] = search::BytesToFtVector(value); + return absl::StrCat("[", absl::StrJoin(absl::Span{ptr.get(), size}, ","), "]"); + } + return string{value}; } string ExtractValue(const search::Schema& schema, string_view key, string_view value) { @@ -63,7 +64,7 @@ string_view ListPackAccessor::GetString(string_view active_field) const { return container_utils::LpFind(lp_, active_field, intbuf_[0].data()).value_or(""sv); } -search::FtVector ListPackAccessor::GetVector(string_view active_field) const { +BaseAccessor::VectorInfo ListPackAccessor::GetVector(string_view active_field) const { return search::BytesToFtVector(GetString(active_field)); } @@ -89,7 +90,7 @@ string_view StringMapAccessor::GetString(string_view active_field) const { return SdsToSafeSv(hset_->Find(active_field)); } -search::FtVector StringMapAccessor::GetVector(string_view active_field) const { +BaseAccessor::VectorInfo StringMapAccessor::GetVector(string_view active_field) const { return search::BytesToFtVector(GetString(active_field)); } @@ -113,16 +114,20 @@ string_view JsonAccessor::GetString(string_view active_field) const { return buf_; } -search::FtVector JsonAccessor::GetVector(string_view active_field) const { +BaseAccessor::VectorInfo JsonAccessor::GetVector(string_view active_field) const { auto res = GetPath(active_field)->evaluate(json_); DCHECK(res.is_array()); if (res.empty()) - return {}; + return {nullptr, 0}; - search::FtVector out; + size_t size = res[0].size(); + auto ptr = make_unique(size); + + size_t i = 0; for (auto v : res[0].array_range()) - out.push_back(v.as()); - return out; + ptr[i++] = v.as(); + + return {std::move(ptr), size}; } JsonAccessor::JsonPathContainer* JsonAccessor::GetPath(std::string_view field) const { diff --git a/src/server/search/doc_accessors.h b/src/server/search/doc_accessors.h index 451b97bca..aed91d489 100644 --- a/src/server/search/doc_accessors.h +++ b/src/server/search/doc_accessors.h @@ -40,7 +40,7 @@ struct ListPackAccessor : public BaseAccessor { } std::string_view GetString(std::string_view field) const override; - search::FtVector GetVector(std::string_view field) const override; + VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; private: @@ -54,7 +54,7 @@ struct StringMapAccessor : public BaseAccessor { } std::string_view GetString(std::string_view field) const override; - search::FtVector GetVector(std::string_view field) const override; + VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; private: @@ -69,7 +69,7 @@ struct JsonAccessor : public BaseAccessor { } std::string_view GetString(std::string_view field) const override; - search::FtVector GetVector(std::string_view field) const override; + VectorInfo GetVector(std::string_view field) const override; SearchDocData Serialize(const search::Schema& schema) const override; // The JsonAccessor works with structured types and not plain strings, so an overload is needed diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc index 546be6eb8..ecfb9c841 100644 --- a/src/server/search/search_family.cc +++ b/src/server/search/search_family.cc @@ -18,7 +18,7 @@ #include "base/logging.h" #include "core/json_object.h" #include "core/search/search.h" -#include "core/search/vector.h" +#include "core/search/vector_utils.h" #include "facade/cmd_arg_parser.h" #include "facade/error.h" #include "facade/reply_builder.h" @@ -46,6 +46,30 @@ bool IsValidJsonPath(string_view path) { return !ec; } +pair ParseVectorFieldInfo(CmdArgParser* parser, + ConnectionContext* cntx) { + size_t dim = 0; + search::VectorSimilarity sim = search::VectorSimilarity::L2; + + size_t num_args = parser->Next().Int(); + for (size_t i = 0; i * 2 < num_args; i++) { + parser->ToUpper(); + if (parser->Check("DIM").ExpectTail(1)) { + dim = parser->Next().Int(); + continue; + } + if (parser->Check("DISTANCE_METRIC").ExpectTail(1)) { + sim = parser->Next() + .Case("L2", search::VectorSimilarity::L2) + .Case("COSINE", search::VectorSimilarity::COSINE); + continue; + } + parser->Skip(2); + } + + return {dim, sim}; +} + optional ParseSchemaOrReply(DocIndex::DataType type, CmdArgParser parser, ConnectionContext* cntx) { search::Schema schema; @@ -74,15 +98,24 @@ optional ParseSchemaOrReply(DocIndex::DataType type, CmdArgParse return nullopt; } - // Skip {algorithm} {dim} flags - if (*type == search::SchemaField::VECTOR) - parser.Skip(2); + // Vector fields include: {algorithm} num_args args... + size_t knn_dim = 0; + search::VectorSimilarity knn_sim = search::VectorSimilarity::L2; + if (*type == search::SchemaField::VECTOR) { + parser.Skip(1); // algorithm + std::tie(knn_dim, knn_sim) = ParseVectorFieldInfo(&parser, cntx); + + if (!parser.HasError() && knn_dim == 0) { + (*cntx)->SendError("Vector dimension cannot be zero"); + return nullopt; + } + } // Skip all trailing ignored parameters while (kIgnoredOptions.count(parser.Peek()) > 0) parser.Skip(2); - schema.fields[field] = {*type, string{field_alias}}; + schema.fields[field] = {*type, string{field_alias}, knn_dim, knn_sim}; } // Build field name mapping table @@ -208,8 +241,9 @@ void ReplyKnn(size_t knn_limit, const SearchParams& params, absl::Spanknn_distance < r->knn_distance; }); docs.resize(min(docs.size(), knn_limit));