diff --git a/src/redis/geo.c b/src/redis/geo.c
index 6698cde4e..0670c5a0d 100644
--- a/src/redis/geo.c
+++ b/src/redis/geo.c
@@ -130,6 +130,35 @@ int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance)
     return C_OK;
 }
 
+/* Compute the sorted set scores min (inclusive), max (exclusive) we should
+ * query in order to retrieve all the elements inside the specified area
+ * 'hash'. The two scores are returned by reference in *min and *max. */
+void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max) {
+    /* We want to compute the sorted set scores that will include all the
+     * elements inside the specified Geohash 'hash', which has as many
+     * bits as specified by hash.step * 2.
+     *
+     * So if step is, for example, 3, and the hash value in binary
+     * is 101010, since our score is 52 bits we want every element which
+     * is in binary: 101010?????????????????????????????????????????????
+     * Where ? can be 0 or 1.
+     *
+     * To get the min score we just use the initial hash value left
+     * shifted enough to get the 52 bit value. Later we increment the
+     * 6 bit prefix (see the hash.bits++ statement), and get the new
+     * prefix: 101011, which we align again to 52 bits to get the maximum
+     * value (which is excluded from the search). So we get everything
+     * between the two following scores (represented in binary):
+     *
+     * 1010100000000000000000000000000000000000000000000000 (included)
+     * and
+     * 1010110000000000000000000000000000000000000000000000 (excluded).
+     */
+    *min = geohashAlign52Bits(hash);
+    hash.bits++;
+    *max = geohashAlign52Bits(hash);
+}
+
 #if 0
 
 /* Query a Redis sorted set to extract all the elements between 'min' and
@@ -208,35 +237,6 @@ int geoGetPointsInRange(robj *zobj, double min, double max, GeoShape *shape, geo
     return ga->used - origincount;
 }
 
-/* Compute the sorted set scores min (inclusive), max (exclusive) we should
- * query in order to retrieve all the elements inside the specified area
- * 'hash'. The two scores are returned by reference in *min and *max. */
-void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max) {
-    /* We want to compute the sorted set scores that will include all the
-     * elements inside the specified Geohash 'hash', which has as many
-     * bits as specified by hash.step * 2.
-     *
-     * So if step is, for example, 3, and the hash value in binary
-     * is 101010, since our score is 52 bits we want every element which
-     * is in binary: 101010?????????????????????????????????????????????
-     * Where ? can be 0 or 1.
-     *
-     * To get the min score we just use the initial hash value left
-     * shifted enough to get the 52 bit value. Later we increment the
-     * 6 bit prefix (see the hash.bits++ statement), and get the new
-     * prefix: 101011, which we align again to 52 bits to get the maximum
-     * value (which is excluded from the search). So we get everything
-     * between the two following scores (represented in binary):
-     *
-     * 1010100000000000000000000000000000000000000000000000 (included)
-     * and
-     * 1010110000000000000000000000000000000000000000000000 (excluded).
-     */
-    *min = geohashAlign52Bits(hash);
-    hash.bits++;
-    *max = geohashAlign52Bits(hash);
-}
-
 /* Obtain all members between the min/max of this geohash bounding box.
  * Populate a geoArray of GeoPoints by calling geoGetPointsInRange().
  * Return the number of points added to the array. */
diff --git a/src/redis/geo.h b/src/redis/geo.h
index c4c11ae1a..10802eb1f 100644
--- a/src/redis/geo.h
+++ b/src/redis/geo.h
@@ -2,6 +2,7 @@
 #define __GEO_H__
 
 #include /* for size_t */
+#include "geohash_helper.h"
 
 /* Structures used inside geo.c in order to represent points and array of
  * points on the earth. */
@@ -19,4 +20,7 @@ typedef struct geoArray {
     size_t used;
 } geoArray;
 
+int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance);
+void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max);
+
 #endif
diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc
index 2e250cafb..6e9a3e571 100644
--- a/src/server/zset_family.cc
+++ b/src/server/zset_family.cc
@@ -7,6 +7,7 @@
 #include "server/acl/acl_commands_def.h"
 
 extern "C" {
+#include "redis/geo.h"
 #include "redis/geohash.h"
 #include "redis/geohash_helper.h"
 #include "redis/listpack.h"
@@ -38,6 +39,11 @@ namespace {
 using CI = CommandId;
 
 static const char kNxXxErr[] = "XX and NX options at the same time are not compatible";
+static const char kFromMemberLonglatErr[] =
+    "FROMMEMBER and FROMLONLAT options at the same time are not compatible";
+static const char kByRadiusBoxErr[] =
+    "BYRADIUS and BYBOX options at the same time are not compatible";
+static const char kAscDescErr[] = "ASC and DESC options at the same time are not compatible";
 static const char kScoreNaN[] = "resulting score is not a number (NaN)";
 static const char kFloatRangeErr[] = "min or max is not a float";
 static const char kLexRangeErr[] = "min or max not valid string range item";
@@ -48,6 +54,33 @@ using MScoreResponse = std::vector>;
 using ScoredMember = std::pair;
 using ScoredArray = std::vector;
 
+// A sorted-set member that matched a geo query, together with its decoded position.
+struct GeoPoint {
+  double longitude;
+  double latitude;
+  double dist;   // distance filled in by geoWithinShape()
+  double score;  // raw sorted-set score of the member
+  std::string member;
+  GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){};
+  GeoPoint(double _longitude, double _latitude, double _dist, double _score,
+           const std::string& _member)
+      : longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){};
+};
+using GeoArray = std::vector<GeoPoint>;
+
+enum class Sorting { kUnsorted, kAsc, kDesc };
+
+// Options parsed from a GEOSEARCH command line.
+struct GeoSearchOpts {
+  double conversion = 0;  // unit factor; reported distances are divided by it
+  uint64_t count = 0;     // COUNT argument, 0 when COUNT was not given
+  Sorting sorting = Sorting::kUnsorted;
+  bool any = false;  // COUNT ... ANY
+  bool withdist = false;
+  bool withcoord = false;
+  bool withhash = false;
+};
+
 inline zrangespec GetZrangeSpec(bool reverse, const ZSetFamily::ScoreInterval& si) {
   auto interval = si;
   if (reverse)
@@ -1348,6 +1381,27 @@ auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, st
   return iv.PopResult();
 }
 
+// Runs every range in range_specs against the sorted set stored at key and returns one
+// result array per range, in the same order. Returns the Find() status when the lookup of
+// key as OBJ_ZSET fails.
+auto OpRanges(const std::vector<ZSetFamily::ZRangeSpec>& range_specs, const OpArgs& op_args,
+              string_view key) -> OpResult<vector<ScoredArray>> {
+  auto res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
+  if (!res_it)
+    return res_it.status();
+
+  PrimeValue& pv = res_it.value()->second;
+  vector<ScoredArray> result_arrays;
+  result_arrays.reserve(range_specs.size());
+  for (const auto& range_spec : range_specs) {
+    IntervalVisitor iv{Action::RANGE, range_spec.params, &pv};
+    std::visit(iv, range_spec.interval);
+    result_arrays.push_back(iv.PopResult());
+  }
+
+  return result_arrays;
+}
+
 OpResult OpRemRange(const OpArgs& op_args, string_view key,
                     const ZSetFamily::ZRangeSpec& range_spec) {
   auto& db_slice = op_args.shard->db_slice();
@@ -2599,6 +2653,329 @@ void ZSetFamily::GeoDist(CmdArgList args, ConnectionContext* cntx) {
                              distance_multiplier);
 }
 
+namespace {
+// Collects into ga every member of the sorted set at key that lies inside the search shape.
+// The center box of n and its eight neighbors are queried in a single hop and each candidate
+// is then filtered with geoWithinShape(). When limit is non-zero, collection stops as soon
+// as limit members were gathered. Returns false if an error reply was already sent.
+bool MembersOfAllNeighbors(ConnectionContext* cntx, string_view key, const GeoHashRadius& n,
+                           const GeoShape& shape_ref, GeoArray* ga, unsigned long limit) {
+  array<GeoHashBits, 9> neighbors;
+  unsigned int last_processed = 0;
+  // geoWithinShape() takes a non-const GeoShape pointer.
+  GeoShape* shape = &(const_cast<GeoShape&>(shape_ref));
+
+  neighbors[0] = n.hash;
+  neighbors[1] = n.neighbors.north;
+  neighbors[2] = n.neighbors.south;
+  neighbors[3] = n.neighbors.east;
+  neighbors[4] = n.neighbors.west;
+  neighbors[5] = n.neighbors.north_east;
+  neighbors[6] = n.neighbors.north_west;
+  neighbors[7] = n.neighbors.south_east;
+  neighbors[8] = n.neighbors.south_west;
+
+  // Get range_specs for neighbors (*and* our own hashbox)
+  std::vector<ZSetFamily::ZRangeSpec> range_specs;
+  for (unsigned int i = 0; i < neighbors.size(); i++) {
+    if (HASHISZERO(neighbors[i])) {
+      continue;
+    }
+
+    // When a huge Radius (in the 5000 km range or more) is used,
+    // adjacent neighbors can be the same, leading to duplicated
+    // elements. Skip every range which is the same as the one
+    // processed previously.
+    if (last_processed && neighbors[i].bits == neighbors[last_processed].bits &&
+        neighbors[i].step == neighbors[last_processed].step) {
+      continue;
+    }
+
+    GeoHashFix52Bits min, max;
+    scoresOfGeoHashBox(neighbors[i], &min, &max);
+
+    ZSetFamily::ScoreInterval si;
+    si.first = ZSetFamily::Bound{static_cast<double>(min), false};
+    si.second = ZSetFamily::Bound{static_cast<double>(max), true};
+
+    ZSetFamily::RangeParams range_params;
+    range_params.interval_type = ZSetFamily::RangeParams::IntervalType::SCORE;
+    range_params.with_scores = true;
+    range_specs.emplace_back(si, range_params);
+
+    last_processed = i;
+  }
+
+  // get all the matching members and add them to the potential result list
+  auto cb = [&](Transaction* t, EngineShard* shard) {
+    return OpRanges(range_specs, t->GetOpArgs(shard), key);
+  };
+  OpResult<std::vector<ScoredArray>> result_arrays =
+      cntx->transaction->ScheduleSingleHopT(std::move(cb));
+  if (result_arrays.status() == OpStatus::WRONG_TYPE) {
+    (*cntx)->SendError(kWrongTypeErr);
+    return false;
+  }
+  if (!result_arrays) {
+    // Any other failure is treated as "no candidates"; the caller replies accordingly.
+    return true;
+  }
+
+  // filter potential result list
+  double xy[2];
+  double distance;
+  for (const auto& arr : *result_arrays) {
+    for (const auto& p : arr) {
+      if (geoWithinShape(shape, p.second, xy, &distance) == 0) {
+        ga->emplace_back(xy[0], xy[1], distance, p.second, p.first);
+        // Leave both loops: a break here would only end the current range.
+        if (limit > 0 && ga->size() >= limit)
+          return true;
+      }
+    }
+  }
+  return true;
+}
+
+// Orders ga by distance as requested by sorting and keeps at most count entries (all of
+// them when count is 0). With Sorting::kUnsorted the array is only truncated.
+void SortIfNeeded(GeoArray* ga, Sorting sorting, uint64_t count) {
+  // Clamp to the number of matches: an unclamped count would move the partial_sort middle
+  // iterator past end() and make resize() pad the reply with empty points.
+  const size_t keep = count > 0 ? std::min<uint64_t>(count, ga->size()) : ga->size();
+
+  if (sorting != Sorting::kUnsorted) {
+    auto comparator = [sorting](const GeoPoint& a, const GeoPoint& b) {
+      if (sorting == Sorting::kAsc) {
+        return a.dist < b.dist;
+      }
+      DCHECK(sorting == Sorting::kDesc);
+      return a.dist > b.dist;
+    };
+
+    if (keep < ga->size()) {
+      std::partial_sort(ga->begin(), ga->begin() + keep, ga->end(), comparator);
+    } else {
+      std::sort(ga->begin(), ga->end(), comparator);
+    }
+  }
+  ga->resize(keep);
+}
+
+// Runs the search described by shape_ref and geo_ops against the sorted set at key and
+// sends the reply to the client.
+void GeoSearchGeneric(ConnectionContext* cntx, const GeoShape& shape_ref, string_view key,
+                      const GeoSearchOpts& geo_ops) {
+  // query
+  GeoShape* shape = &(const_cast<GeoShape&>(shape_ref));
+  GeoHashRadius georadius = geohashCalculateAreasByShapeWGS84(shape);
+  GeoArray ga;
+  if (!MembersOfAllNeighbors(cntx, key, georadius, shape_ref, &ga,
+                             geo_ops.any ? geo_ops.count : 0)) {
+    return;
+  }
+
+  // if no matching results, the user gets an empty array, as Redis replies.
+  if (ga.empty()) {
+    (*cntx)->StartArray(0);
+    return;
+  }
+
+  // COUNT without ANY has to return the closest matches, so it implies ascending order when
+  // no order was requested (same rule as Redis).
+  Sorting sorting = geo_ops.sorting;
+  if (sorting == Sorting::kUnsorted && geo_ops.count > 0 && !geo_ops.any) {
+    sorting = Sorting::kAsc;
+  }
+
+  // sort and trim by count
+  SortIfNeeded(&ga, sorting, geo_ops.count);
+
+  // A record is the member name followed by the requested extras, in this order:
+  // [member, dist, hash, [longitude, latitude]]. Without extras Redis replies with a flat
+  // array of member names, so no per-record array is opened.
+  const int record_size = 1 + geo_ops.withdist + geo_ops.withhash + geo_ops.withcoord;
+  (*cntx)->StartArray(ga.size());
+  for (const auto& p : ga) {
+    if (record_size > 1) {
+      (*cntx)->StartArray(record_size);
+    }
+    (*cntx)->SendBulkString(p.member);
+    if (geo_ops.withdist) {
+      (*cntx)->SendDouble(p.dist / geo_ops.conversion);
+    }
+    if (geo_ops.withhash) {
+      (*cntx)->SendDouble(p.score);
+    }
+    if (geo_ops.withcoord) {
+      (*cntx)->StartArray(2);
+      (*cntx)->SendDouble(p.longitude);
+      (*cntx)->SendDouble(p.latitude);
+    }
+  }
+}
+}  // namespace
+
+// GEOSEARCH key <FROMMEMBER member | FROMLONLAT longitude latitude>
+//   <BYRADIUS radius unit | BYBOX width height unit>
+//   [ASC | DESC] [COUNT count [ANY]] [WITHCOORD] [WITHDIST] [WITHHASH]
+// Options are accepted in any order; one FROM* and one BY* option are mandatory.
+void ZSetFamily::GeoSearch(CmdArgList args, ConnectionContext* cntx) {
+  // parse arguments
+  string_view key = ArgS(args, 0);
+  GeoShape shape = {};
+  GeoSearchOpts geo_ops;
+
+  // FROMMEMBER or FROMLONLAT is set
+  bool from_set = false;
+  // BYRADIUS or BYBOX is set
+  bool by_set = false;
+
+  for (size_t i = 1; i < args.size(); ++i) {
+    ToUpper(&args[i]);
+
+    string_view cur_arg = ArgS(args, i);
+
+    if (cur_arg == "FROMMEMBER") {
+      if (from_set) {
+        return (*cntx)->SendError(kFromMemberLonglatErr);
+      } else if (i + 1 < args.size()) {
+        string_view member;
+        member = ArgS(args, i + 1);
+
+        // member to latlong, set shape.xy
+        // NOTE(review): this is a hop of its own and GeoSearchGeneric() schedules a second
+        // one on the same transaction - confirm the transaction allows that.
+        auto cb = [&](Transaction* t, EngineShard* shard) {
+          return OpScore(t->GetOpArgs(shard), key, member);
+        };
+        OpResult<double> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
+        if (result.status() == OpStatus::WRONG_TYPE) {
+          return (*cntx)->SendError(kWrongTypeErr);
+        } else if (!result) {
+          return (*cntx)->SendError("Member not found");
+        }
+        ScoreToLongLat(*result, shape.xy);
+        from_set = true;
+        i++;
+      } else {
+        return (*cntx)->SendError(kSyntaxErr);
+      }
+    } else if (cur_arg == "FROMLONLAT") {
+      if (from_set) {
+        return (*cntx)->SendError(kFromMemberLonglatErr);
+      } else if (i + 2 < args.size()) {
+        string_view longitude_str = ArgS(args, i + 1);
+        string_view latitude_str = ArgS(args, i + 2);
+        pair<double, double> longlat;
+        if (!ParseLongLat(longitude_str, latitude_str, &longlat)) {
+          string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude_str, ",",
+                                    latitude_str);
+          return (*cntx)->SendError(err, kSyntaxErrType);
+        }
+        shape.xy[0] = longlat.first;
+        shape.xy[1] = longlat.second;
+        from_set = true;
+        i += 2;
+      } else {
+        return (*cntx)->SendError(kSyntaxErr);
+      }
+    } else if (cur_arg == "BYRADIUS") {
+      if (by_set) {
+        return (*cntx)->SendError(kByRadiusBoxErr);
+      } else if (i + 2 < args.size()) {
+        if (!ParseDouble(ArgS(args, i + 1), &shape.t.radius)) {
+          return (*cntx)->SendError(kInvalidFloatErr);
+        }
+        string_view unit;
+        unit = ArgS(args, i + 2);
+        shape.conversion = ExtractUnit(unit);
+        geo_ops.conversion = shape.conversion;
+        if (shape.conversion == -1) {
+          return (*cntx)->SendError("unsupported unit provided. please use M, KM, FT, MI");
+        }
+        shape.type = CIRCULAR_TYPE;
+        by_set = true;
+        i += 2;
+      } else {
+        return (*cntx)->SendError(kSyntaxErr);
+      }
+    } else if (cur_arg == "BYBOX") {
+      if (by_set) {
+        return (*cntx)->SendError(kByRadiusBoxErr);
+      } else if (i + 3 < args.size()) {
+        if (!ParseDouble(ArgS(args, i + 1), &shape.t.r.width)) {
+          return (*cntx)->SendError(kInvalidFloatErr);
+        }
+        if (!ParseDouble(ArgS(args, i + 2), &shape.t.r.height)) {
+          return (*cntx)->SendError(kInvalidFloatErr);
+        }
+        string_view unit;
+        unit = ArgS(args, i + 3);
+        shape.conversion = ExtractUnit(unit);
+        geo_ops.conversion = shape.conversion;
+        if (shape.conversion == -1) {
+          return (*cntx)->SendError("unsupported unit provided. please use M, KM, FT, MI");
+        }
+        shape.type = RECTANGLE_TYPE;
+        by_set = true;
+        i += 3;
+      } else {
+        return (*cntx)->SendError(kSyntaxErr);
+      }
+    } else if (cur_arg == "ASC") {
+      if (geo_ops.sorting != Sorting::kUnsorted) {
+        return (*cntx)->SendError(kAscDescErr);
+      } else {
+        geo_ops.sorting = Sorting::kAsc;
+      }
+    } else if (cur_arg == "DESC") {
+      if (geo_ops.sorting != Sorting::kUnsorted) {
+        return (*cntx)->SendError(kAscDescErr);
+      } else {
+        geo_ops.sorting = Sorting::kDesc;
+      }
+    } else if (cur_arg == "COUNT") {
+      if (i + 1 >= args.size()) {
+        return (*cntx)->SendError(kSyntaxErr);
+      }
+      // Reject garbage instead of silently running with whatever SimpleAtoi left behind.
+      if (!absl::SimpleAtoi(std::string(ArgS(args, i + 1)), &geo_ops.count)) {
+        return (*cntx)->SendError("value is not an integer or out of range");
+      }
+      if (geo_ops.count == 0) {
+        return (*cntx)->SendError("COUNT must be > 0");
+      }
+      i++;
+      // ANY is matched here, before the loop had a chance to upper-case it.
+      if (i + 1 < args.size()) {
+        ToUpper(&args[i + 1]);
+        if (ArgS(args, i + 1) == "ANY") {
+          geo_ops.any = true;
+          i++;
+        }
+      }
+    } else if (cur_arg == "WITHCOORD") {
+      geo_ops.withcoord = true;
+    } else if (cur_arg == "WITHDIST") {
+      geo_ops.withdist = true;
+    } else if (cur_arg == "WITHHASH") {
+      geo_ops.withhash = true;
+    } else {
+      return (*cntx)->SendError(kSyntaxErr);
+    }
+  }
+
+  // check mandatory options
+  if (!from_set) {
+    return (*cntx)->SendError(kSyntaxErr);
+  }
+  if (!by_set) {
+    return (*cntx)->SendError(kSyntaxErr);
+  }
+  GeoSearchGeneric(cntx, shape, key, geo_ops);
+}
+
 #define HFUNC(x) SetHandler(&ZSetFamily::x)
 
 namespace acl {
@@ -2635,6 +3012,7 @@ constexpr uint32_t kGeoAdd = WRITE | GEO | SLOW;
 constexpr uint32_t kGeoHash = READ | GEO | SLOW;
 constexpr uint32_t kGeoPos = READ | GEO | SLOW;
 constexpr uint32_t kGeoDist = READ | GEO | SLOW;
+constexpr uint32_t kGeoSearch = READ | GEO | SLOW;
 }  // namespace acl
 
 void ZSetFamily::Register(CommandRegistry* registry) {
@@ -2695,7 +3073,8 @@ void ZSetFamily::Register(CommandRegistry* registry) {
       << CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, 1, acl::kGeoAdd}.HFUNC(GeoAdd)
       << CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, 1, acl::kGeoHash}.HFUNC(GeoHash)
       << CI{"GEOPOS", CO::FAST | CO::READONLY, -2, 1, 1, 1, acl::kGeoPos}.HFUNC(GeoPos)
-      << CI{"GEODIST", CO::READONLY, -4, 1, 1, 1, acl::kGeoDist}.HFUNC(GeoDist);
+      << CI{"GEODIST", CO::READONLY, -4, 1, 1, 1, acl::kGeoDist}.HFUNC(GeoDist)
+      << CI{"GEOSEARCH", CO::READONLY, -4, 1, 1, 1, acl::kGeoSearch}.HFUNC(GeoSearch);
 }
 
 }  // namespace dfly
diff --git a/src/server/zset_family.h b/src/server/zset_family.h
index dc3632439..1f12caeec 100644
--- a/src/server/zset_family.h
+++ b/src/server/zset_family.h
@@ -48,6 +48,8 @@ class ZSetFamily {
   struct ZRangeSpec {
     std::variant interval;
     RangeParams params;
+    ZRangeSpec() = default;
+    ZRangeSpec(const ScoreInterval& si, const RangeParams& rp) : interval(si), params(rp){};
   };
 
  private:
@@ -98,6 +100,7 @@ class ZSetFamily {
   static void GeoHash(CmdArgList args, ConnectionContext* cntx);
   static void GeoPos(CmdArgList args, ConnectionContext* cntx);
   static void GeoDist(CmdArgList args, ConnectionContext* cntx);
+  static void GeoSearch(CmdArgList args, ConnectionContext* cntx);
 };
 
 }  // namespace dfly
diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc
index f0f5a22e2..3a1c717b4 100644
--- a/src/server/zset_family_test.cc
+++ b/src/server/zset_family_test.cc
@@ -777,4 +777,86 @@ TEST_F(ZSetFamilyTest, GeoDist) {
   EXPECT_THAT(resp, ArgType(RespExpr::NIL));
 }
 
+TEST_F(ZSetFamilyTest, GeoSearch) {
+  EXPECT_EQ(10, CheckedInt({"geoadd", "Europe", "13.4050", "52.5200", "Berlin", "3.7038",
+                            "40.4168", "Madrid", "9.1427", "38.7369", "Lisbon", "2.3522",
+                            "48.8566", "Paris", "16.3738", "48.2082", "Vienna", "4.8952",
+                            "52.3702", "Amsterdam", "10.7522", "59.9139", "Oslo", "23.7275",
+                            "37.9838", "Athens", "19.0402", "47.4979", "Budapest", "6.2603",
+                            "53.3498", "Dublin"}));
+
+  auto resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500",
+                   "KM", "WITHCOORD", "WITHDIST", "WITHHASH"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre("Berlin", "0.00017343178521311378", "3673983950397063",
+                                RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
+          RespArray(
+              ElementsAre("Dublin", "487.5619030644293", "3678981558208417",
+                          RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
+
+  resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYBOX", "1000", "1000",
+              "KM", "WITHCOORD", "WITHDIST"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre("Vienna", "523.6926930553866",
+                                RespArray(ElementsAre("16.373799741268158", "48.20820011474228")))),
+          RespArray(ElementsAre("Berlin", "0.00017343178521311378",
+                                RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
+          RespArray(
+              ElementsAre("Dublin", "487.5619030644293",
+                          RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
+
+  resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
+              "COUNT", "3", "WITHCOORD", "WITHDIST"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre("Berlin", "0.00017343178521311378",
+                                RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
+          RespArray(
+              ElementsAre("Dublin", "487.5619030644293",
+                          RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
+
+  resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
+              "DESC", "WITHCOORD", "WITHDIST"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre("Dublin", "487.5619030644293",
+                                RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))),
+          RespArray(
+              ElementsAre("Berlin", "0.00017343178521311378",
+                          RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))))));
+
+  resp = Run({"GEOSEARCH", "Europe", "FROMMEMBER", "Madrid", "BYRADIUS", "700", "KM", "WITHCOORD",
+              "WITHDIST"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre(
+              "Madrid", "0", RespArray(ElementsAre("3.7038007378578186", "40.416799319406216")))),
+          RespArray(
+              ElementsAre("Lisbon", "502.20769462704084",
+                          RespArray(ElementsAre("9.142698347568512", "38.736900197448534")))))));
+
+  // A COUNT larger than the number of matches must return just the matches.
+  resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
+              "ASC", "COUNT", "5", "WITHCOORD", "WITHDIST"});
+  EXPECT_THAT(
+      resp,
+      RespArray(ElementsAre(
+          RespArray(ElementsAre("Berlin", "0.00017343178521311378",
+                                RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
+          RespArray(
+              ElementsAre("Dublin", "487.5619030644293",
+                          RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
+
+  // Without WITH* options the reply is a flat array of member names.
+  resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
+              "ASC"});
+  EXPECT_THAT(resp, RespArray(ElementsAre("Berlin", "Dublin")));
+}
 }  // namespace dfly