feat: Add GEOSEARCH support (#2070)

Signed-off-by: azuredream <zhaozixuan67@gmail.com>
This commit is contained in:
zixuan zhao 2023-10-31 06:22:04 -04:00 committed by GitHub
parent 7b71b728c7
commit 05919efcbd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 443 additions and 30 deletions

View file

@ -130,6 +130,35 @@ int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance)
return C_OK;
}
/* Compute the sorted set scores min (inclusive), max (exclusive) we should
* query in order to retrieve all the elements inside the specified area
* 'hash'. The two scores are returned by reference in *min and *max. */
void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max) {
/* We want to compute the sorted set scores that will include all the
* elements inside the specified Geohash 'hash', which has as many
* bits as specified by hash.step * 2.
*
* So if step is, for example, 3, and the hash value in binary
* is 101010, since our score is 52 bits we want every element which
* is in binary: 101010?????????????????????????????????????????????
* Where ? can be 0 or 1.
*
* To get the min score we just use the initial hash value left
* shifted enough to get the 52 bit value. Later we increment the
* 6 bit prefix (see the hash.bits++ statement), and get the new
* prefix: 101011, which we align again to 52 bits to get the maximum
* value (which is excluded from the search). So we get everything
* between the two following scores (represented in binary):
*
* 1010100000000000000000000000000000000000000000000000 (included)
* and
* 1010110000000000000000000000000000000000000000000000 (excluded).
*/
*min = geohashAlign52Bits(hash);
hash.bits++;
*max = geohashAlign52Bits(hash);
}
#if 0
/* Query a Redis sorted set to extract all the elements between 'min' and
@ -208,35 +237,6 @@ int geoGetPointsInRange(robj *zobj, double min, double max, GeoShape *shape, geo
return ga->used - origincount;
}
/* Compute the sorted set scores min (inclusive), max (exclusive) we should
* query in order to retrieve all the elements inside the specified area
* 'hash'. The two scores are returned by reference in *min and *max. */
void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max) {
/* We want to compute the sorted set scores that will include all the
* elements inside the specified Geohash 'hash', which has as many
* bits as specified by hash.step * 2.
*
* So if step is, for example, 3, and the hash value in binary
* is 101010, since our score is 52 bits we want every element which
* is in binary: 101010?????????????????????????????????????????????
* Where ? can be 0 or 1.
*
* To get the min score we just use the initial hash value left
* shifted enough to get the 52 bit value. Later we increment the
* 6 bit prefix (see the hash.bits++ statement), and get the new
* prefix: 101011, which we align again to 52 bits to get the maximum
* value (which is excluded from the search). So we get everything
* between the two following scores (represented in binary):
*
* 1010100000000000000000000000000000000000000000000000 (included)
* and
* 1010110000000000000000000000000000000000000000000000 (excluded).
*/
*min = geohashAlign52Bits(hash);
hash.bits++;
*max = geohashAlign52Bits(hash);
}
/* Obtain all members between the min/max of this geohash bounding box.
* Populate a geoArray of GeoPoints by calling geoGetPointsInRange().
* Return the number of points added to the array. */

View file

@ -2,6 +2,7 @@
#define __GEO_H__
#include <stddef.h> /* for size_t */
#include "geohash_helper.h"
/* Structures used inside geo.c in order to represent points and array of
* points on the earth. */
@ -19,4 +20,7 @@ typedef struct geoArray {
size_t used;
} geoArray;
int geoWithinShape(GeoShape *shape, double score, double *xy, double *distance);
void scoresOfGeoHashBox(GeoHashBits hash, GeoHashFix52Bits *min, GeoHashFix52Bits *max);
#endif

View file

@ -7,6 +7,7 @@
#include "server/acl/acl_commands_def.h"
extern "C" {
#include "redis/geo.h"
#include "redis/geohash.h"
#include "redis/geohash_helper.h"
#include "redis/listpack.h"
@ -38,6 +39,11 @@ namespace {
using CI = CommandId;
static const char kNxXxErr[] = "XX and NX options at the same time are not compatible";
static const char kFromMemberLonglatErr[] =
"FROMMEMBER and FROMLONLAT options at the same time are not compatible";
static const char kByRadiusBoxErr[] =
"BYRADIUS and BYBOX options at the same time are not compatible";
static const char kAscDescErr[] = "ASC and DESC options at the same time are not compatible";
static const char kScoreNaN[] = "resulting score is not a number (NaN)";
static const char kFloatRangeErr[] = "min or max is not a float";
static const char kLexRangeErr[] = "min or max not valid string range item";
@ -48,6 +54,30 @@ using MScoreResponse = std::vector<std::optional<double>>;
using ScoredMember = std::pair<std::string, double>;
using ScoredArray = std::vector<ScoredMember>;
struct GeoPoint {
double longitude;
double latitude;
double dist;
double score;
std::string member;
GeoPoint() : longitude(0.0), latitude(0.0), dist(0.0), score(0.0){};
GeoPoint(double _longitude, double _latitude, double _dist, double _score,
const std::string& _member)
: longitude(_longitude), latitude(_latitude), dist(_dist), score(_score), member(_member){};
};
using GeoArray = std::vector<GeoPoint>;
enum class Sorting { kUnsorted, kAsc, kDesc };
struct GeoSearchOpts {
double conversion = 0;
uint64_t count = 0;
Sorting sorting = Sorting::kUnsorted;
bool any = 0;
bool withdist = 0;
bool withcoord = 0;
bool withhash = 0;
};
inline zrangespec GetZrangeSpec(bool reverse, const ZSetFamily::ScoreInterval& si) {
auto interval = si;
if (reverse)
@ -1348,6 +1378,23 @@ auto OpRange(const ZSetFamily::ZRangeSpec& range_spec, const OpArgs& op_args, st
return iv.PopResult();
}
auto OpRanges(const std::vector<ZSetFamily::ZRangeSpec>& range_specs, const OpArgs& op_args,
string_view key) -> OpResult<vector<ScoredArray>> {
OpResult<PrimeIterator> res_it = op_args.shard->db_slice().Find(op_args.db_cntx, key, OBJ_ZSET);
if (!res_it)
return res_it.status();
PrimeValue& pv = res_it.value()->second;
vector<ScoredArray> result_arrays;
for (auto& range_spec : range_specs) {
IntervalVisitor iv{Action::RANGE, range_spec.params, &pv};
std::visit(iv, range_spec.interval);
result_arrays.push_back(iv.PopResult());
}
return result_arrays;
}
OpResult<unsigned> OpRemRange(const OpArgs& op_args, string_view key,
const ZSetFamily::ZRangeSpec& range_spec) {
auto& db_slice = op_args.shard->db_slice();
@ -2599,6 +2646,298 @@ void ZSetFamily::GeoDist(CmdArgList args, ConnectionContext* cntx) {
distance_multiplier);
}
namespace {
// Search all eight neighbors + self geohash box
bool MembersOfAllNeighbors(ConnectionContext* cntx, string_view key, const GeoHashRadius& n,
const GeoShape& shape_ref, GeoArray* ga, unsigned long limit) {
array<GeoHashBits, 9> neighbors;
unsigned int last_processed = 0;
GeoShape* shape = &(const_cast<GeoShape&>(shape_ref));
neighbors[0] = n.hash;
neighbors[1] = n.neighbors.north;
neighbors[2] = n.neighbors.south;
neighbors[3] = n.neighbors.east;
neighbors[4] = n.neighbors.west;
neighbors[5] = n.neighbors.north_east;
neighbors[6] = n.neighbors.north_west;
neighbors[7] = n.neighbors.south_east;
neighbors[8] = n.neighbors.south_west;
// Get range_specs for neighbors (*and* our own hashbox)
std::vector<ZSetFamily::ZRangeSpec> range_specs;
for (unsigned int i = 0; i < neighbors.size(); i++) {
if (HASHISZERO(neighbors[i])) {
continue;
}
// When a huge Radius (in the 5000 km range or more) is used,
// adjacent neighbors can be the same, leading to duplicated
// elements. Skip every range which is the same as the one
// processed previously.
if (last_processed && neighbors[i].bits == neighbors[last_processed].bits &&
neighbors[i].step == neighbors[last_processed].step) {
continue;
}
GeoHashFix52Bits min, max;
scoresOfGeoHashBox(neighbors[i], &min, &max);
ZSetFamily::ScoreInterval si;
si.first = ZSetFamily::Bound{static_cast<double>(min), false};
si.second = ZSetFamily::Bound{static_cast<double>(max), true};
ZSetFamily::RangeParams range_params;
range_params.interval_type = ZSetFamily::RangeParams::IntervalType::SCORE;
range_params.with_scores = true;
range_specs.emplace_back(si, range_params);
last_processed = i;
}
// get all the matching members and add them to the potential result list
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpRanges(range_specs, t->GetOpArgs(shard), key);
};
OpResult<vector<ScoredArray>> result_arrays =
cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result_arrays.status() == OpStatus::WRONG_TYPE) {
(*cntx)->SendError(kWrongTypeErr);
return false;
}
// filter potential result list
double xy[2];
double distance;
for (auto& arr : *result_arrays) {
for (auto& p : arr) {
if (geoWithinShape(shape, p.second, xy, &distance) == 0) {
ga->emplace_back(xy[0], xy[1], distance, p.second, p.first);
if (limit > 0 && ga->size() >= limit)
break;
}
}
}
return true;
}
void SortIfNeeded(GeoArray* ga, Sorting sorting, uint64_t count) {
if (sorting == Sorting::kUnsorted)
return;
auto comparator = [&](const GeoPoint& a, const GeoPoint& b) {
if (sorting == Sorting::kAsc) {
return a.dist < b.dist;
} else {
DCHECK(sorting == Sorting::kDesc);
return a.dist > b.dist;
}
};
if (count > 0) {
std::partial_sort(ga->begin(), ga->begin() + count, ga->end(), comparator);
ga->resize(count);
} else {
std::sort(ga->begin(), ga->end(), comparator);
}
}
void GeoSearchGeneric(ConnectionContext* cntx, const GeoShape& shape_ref, string_view key,
const GeoSearchOpts& geo_ops) {
// query
GeoShape* shape = &(const_cast<GeoShape&>(shape_ref));
GeoHashRadius georadius = geohashCalculateAreasByShapeWGS84(shape);
GeoArray ga;
if (!MembersOfAllNeighbors(cntx, key, georadius, shape_ref, &ga,
geo_ops.any ? geo_ops.count : 0)) {
return;
}
// if no matching results, the user gets an empty reply.
if (ga.empty()) {
(*cntx)->SendNull();
return;
}
// sort and trim by count
SortIfNeeded(&ga, geo_ops.sorting, geo_ops.count);
// generate reply array withdist, withcoords, withhash
int record_size = 1;
if (geo_ops.withdist) {
record_size++;
}
if (geo_ops.withhash) {
record_size++;
}
if (geo_ops.withcoord) {
record_size++;
}
(*cntx)->StartArray(ga.size());
for (const auto& p : ga) {
// [member, dist, x, y, hash]
(*cntx)->StartArray(record_size);
(*cntx)->SendBulkString(p.member);
if (geo_ops.withdist) {
(*cntx)->SendDouble(p.dist / geo_ops.conversion);
}
if (geo_ops.withhash) {
(*cntx)->SendDouble(p.score);
}
if (geo_ops.withcoord) {
(*cntx)->StartArray(2);
(*cntx)->SendDouble(p.longitude);
(*cntx)->SendDouble(p.latitude);
}
}
}
} // namespace
void ZSetFamily::GeoSearch(CmdArgList args, ConnectionContext* cntx) {
// parse arguments
string_view key = ArgS(args, 0);
GeoShape shape = {};
GeoSearchOpts geo_ops;
// FROMMEMBER or FROMLONLAT is set
bool from_set = false;
// BYRADIUS or BYBOX is set
bool by_set = false;
for (size_t i = 1; i < args.size(); ++i) {
ToUpper(&args[i]);
string_view cur_arg = ArgS(args, i);
if (cur_arg == "FROMMEMBER") {
if (from_set) {
return (*cntx)->SendError(kFromMemberLonglatErr);
} else if (i + 1 < args.size()) {
string_view member;
member = ArgS(args, i + 1);
// member to latlong, set shape.xy
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpScore(t->GetOpArgs(shard), key, member);
};
OpResult<double> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr);
} else if (!result) {
return (*cntx)->SendError("Member not found");
}
ScoreToLongLat(*result, shape.xy);
from_set = true;
i++;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
} else if (cur_arg == "FROMLONLAT") {
if (from_set) {
return (*cntx)->SendError(kFromMemberLonglatErr);
} else if (i + 2 < args.size()) {
string_view longitude_str = ArgS(args, i + 1);
string_view latitude_str = ArgS(args, i + 2);
pair<double, double> longlat;
if (!ParseLongLat(longitude_str, latitude_str, &longlat)) {
string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude_str, ",",
latitude_str);
return (*cntx)->SendError(err, kSyntaxErrType);
}
shape.xy[0] = longlat.first;
shape.xy[1] = longlat.second;
from_set = true;
i += 2;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
} else if (cur_arg == "BYRADIUS") {
if (by_set) {
return (*cntx)->SendError(kByRadiusBoxErr);
} else if (i + 2 < args.size()) {
if (!ParseDouble(ArgS(args, i + 1), &shape.t.radius)) {
return (*cntx)->SendError(kInvalidFloatErr);
}
string_view unit;
unit = ArgS(args, i + 2);
shape.conversion = ExtractUnit(unit);
geo_ops.conversion = shape.conversion;
if (shape.conversion == -1) {
return (*cntx)->SendError("unsupported unit provided. please use M, KM, FT, MI");
}
shape.type = CIRCULAR_TYPE;
by_set = true;
i += 2;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
} else if (cur_arg == "BYBOX") {
if (by_set) {
return (*cntx)->SendError(kByRadiusBoxErr);
} else if (i + 3 < args.size()) {
if (!ParseDouble(ArgS(args, i + 1), &shape.t.r.width)) {
return (*cntx)->SendError(kInvalidFloatErr);
}
if (!ParseDouble(ArgS(args, i + 2), &shape.t.r.height)) {
return (*cntx)->SendError(kInvalidFloatErr);
}
string_view unit;
unit = ArgS(args, i + 3);
shape.conversion = ExtractUnit(unit);
geo_ops.conversion = shape.conversion;
if (shape.conversion == -1) {
return (*cntx)->SendError("unsupported unit provided. please use M, KM, FT, MI");
}
shape.type = RECTANGLE_TYPE;
by_set = true;
i += 3;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
} else if (cur_arg == "ASC") {
if (geo_ops.sorting != Sorting::kUnsorted) {
return (*cntx)->SendError(kAscDescErr);
} else {
geo_ops.sorting = Sorting::kAsc;
}
} else if (cur_arg == "DESC") {
if (geo_ops.sorting != Sorting::kUnsorted) {
return (*cntx)->SendError(kAscDescErr);
} else {
geo_ops.sorting = Sorting::kDesc;
}
} else if (cur_arg == "COUNT") {
if (i + 1 < args.size()) {
absl::SimpleAtoi(std::string(ArgS(args, i + 1)), &geo_ops.count);
i++;
} else {
return (*cntx)->SendError(kSyntaxErr);
}
if (i + 1 < args.size() && ArgS(args, i + 1) == "ANY") {
geo_ops.any = true;
i++;
}
} else if (cur_arg == "WITHCOORD") {
geo_ops.withcoord = true;
} else if (cur_arg == "WITHDIST") {
geo_ops.withdist = true;
} else if (cur_arg == "WITHHASH")
geo_ops.withhash = true;
else {
return (*cntx)->SendError(kSyntaxErr);
}
}
// check mandatory options
if (!from_set) {
return (*cntx)->SendError(kSyntaxErr);
}
if (!by_set) {
return (*cntx)->SendError(kSyntaxErr);
}
GeoSearchGeneric(cntx, shape, key, geo_ops);
}
#define HFUNC(x) SetHandler(&ZSetFamily::x)
namespace acl {
@ -2635,6 +2974,7 @@ constexpr uint32_t kGeoAdd = WRITE | GEO | SLOW;
constexpr uint32_t kGeoHash = READ | GEO | SLOW;
constexpr uint32_t kGeoPos = READ | GEO | SLOW;
constexpr uint32_t kGeoDist = READ | GEO | SLOW;
constexpr uint32_t kGeoSearch = READ | GEO | SLOW;
} // namespace acl
void ZSetFamily::Register(CommandRegistry* registry) {
@ -2695,7 +3035,8 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, 1, acl::kGeoAdd}.HFUNC(GeoAdd)
<< CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, 1, acl::kGeoHash}.HFUNC(GeoHash)
<< CI{"GEOPOS", CO::FAST | CO::READONLY, -2, 1, 1, 1, acl::kGeoPos}.HFUNC(GeoPos)
<< CI{"GEODIST", CO::READONLY, -4, 1, 1, 1, acl::kGeoDist}.HFUNC(GeoDist);
<< CI{"GEODIST", CO::READONLY, -4, 1, 1, 1, acl::kGeoDist}.HFUNC(GeoDist)
<< CI{"GEOSEARCH", CO::READONLY, -4, 1, 1, 1, acl::kGeoSearch}.HFUNC(GeoSearch);
}
} // namespace dfly

View file

@ -48,6 +48,8 @@ class ZSetFamily {
struct ZRangeSpec {
std::variant<IndexInterval, ScoreInterval, LexInterval, TopNScored> interval;
RangeParams params;
ZRangeSpec() = default;
ZRangeSpec(const ScoreInterval& si, const RangeParams& rp) : interval(si), params(rp){};
};
private:
@ -98,6 +100,7 @@ class ZSetFamily {
static void GeoHash(CmdArgList args, ConnectionContext* cntx);
static void GeoPos(CmdArgList args, ConnectionContext* cntx);
static void GeoDist(CmdArgList args, ConnectionContext* cntx);
static void GeoSearch(CmdArgList args, ConnectionContext* cntx);
};
} // namespace dfly

View file

@ -777,4 +777,69 @@ TEST_F(ZSetFamilyTest, GeoDist) {
EXPECT_THAT(resp, ArgType(RespExpr::NIL));
}
TEST_F(ZSetFamilyTest, GeoSearch) {
EXPECT_EQ(10, CheckedInt({"geoadd", "Europe", "13.4050", "52.5200", "Berlin", "3.7038",
"40.4168", "Madrid", "9.1427", "38.7369", "Lisbon", "2.3522",
"48.8566", "Paris", "16.3738", "48.2082", "Vienna", "4.8952",
"52.3702", "Amsterdam", "10.7522", "59.9139", "Oslo", "23.7275",
"37.9838", "Athens", "19.0402", "47.4979", "Budapest", "6.2603",
"53.3498", "Dublin"}));
auto resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500",
"KM", "WITHCOORD", "WITHDIST", "WITHHASH"});
EXPECT_THAT(
resp,
RespArray(ElementsAre(
RespArray(ElementsAre("Berlin", "0.00017343178521311378", "3673983950397063",
RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
RespArray(
ElementsAre("Dublin", "487.5619030644293", "3678981558208417",
RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYBOX", "1000", "1000",
"KM", "WITHCOORD", "WITHDIST"});
EXPECT_THAT(
resp,
RespArray(ElementsAre(
RespArray(ElementsAre("Vienna", "523.6926930553866",
RespArray(ElementsAre("16.373799741268158", "48.20820011474228")))),
RespArray(ElementsAre("Berlin", "0.00017343178521311378",
RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
RespArray(
ElementsAre("Dublin", "487.5619030644293",
RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
"COUNT", "3", "WITHCOORD", "WITHDIST"});
EXPECT_THAT(
resp,
RespArray(ElementsAre(
RespArray(ElementsAre("Berlin", "0.00017343178521311378",
RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))),
RespArray(
ElementsAre("Dublin", "487.5619030644293",
RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))))));
resp = Run({"GEOSEARCH", "Europe", "FROMLONLAT", "13.4050", "52.5200", "BYRADIUS", "500", "KM",
"DESC", "WITHCOORD", "WITHDIST"});
EXPECT_THAT(
resp,
RespArray(ElementsAre(
RespArray(ElementsAre("Dublin", "487.5619030644293",
RespArray(ElementsAre("6.260299980640411", "53.34980087538425")))),
RespArray(
ElementsAre("Berlin", "0.00017343178521311378",
RespArray(ElementsAre("13.405002057552338", "52.51999907056681")))))));
resp = Run({"GEOSEARCH", "Europe", "FROMMEMBER", "Madrid", "BYRADIUS", "700", "KM", "WITHCOORD",
"WITHDIST"});
EXPECT_THAT(
resp,
RespArray(ElementsAre(
RespArray(ElementsAre(
"Madrid", "0", RespArray(ElementsAre("3.7038007378578186", "40.416799319406216")))),
RespArray(
ElementsAre("Lisbon", "502.20769462704084",
RespArray(ElementsAre("9.142698347568512", "38.736900197448534")))))));
}
} // namespace dfly