feat: implement two geo commands GEOADD/GEOHASH (#1543)

Only most basic functionality is covered, the options support
and variadic functionality for GEOADD is missing.

Signed-off-by: Roman Gershman <roman@dragonflydb.io>
This commit is contained in:
Roman Gershman 2023-07-16 09:19:35 +03:00 committed by GitHub
parent c5922fec8a
commit 187bca9317
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 170 additions and 27 deletions

View file

@ -5,6 +5,8 @@
#include "server/zset_family.h"
extern "C" {
#include "redis/geohash.h"
#include "redis/geohash_helper.h"
#include "redis/listpack.h"
#include "redis/object.h"
#include "redis/util.h"
@ -34,6 +36,7 @@ static const char kNxXxErr[] = "XX and NX options at the same time are not compa
static const char kScoreNaN[] = "resulting score is not a number (NaN)";
static const char kFloatRangeErr[] = "min or max is not a float";
static const char kLexRangeErr[] = "min or max not valid string range item";
constexpr string_view kGeoAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"sv;
constexpr unsigned kMaxListPackValue = 64;
using MScoreResponse = std::vector<std::optional<double>>;
@ -130,6 +133,45 @@ OpResult<PrimeIterator> FindZEntry(const ZParams& zparams, const OpArgs& op_args
return it;
}
bool ToAsciiGeoHash(const std::optional<double>& val, array<char, 12>* buf) {
if (!val.has_value())
return false;
double score = *val;
double xy[2];
GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX};
if (!geohashDecodeToLongLatType(hash, xy)) {
return false;
}
/* Re-encode */
GeoHashRange r[2];
r[0].min = -180;
r[0].max = 180;
r[1].min = -90;
r[1].max = 90;
geohashEncode(&r[0], &r[1], xy[0], xy[1], 26, &hash);
for (int i = 0; i < 11; i++) {
int idx;
if (i == 10) {
/* We have just 52 bits, but the API used to output
* an 11 bytes geohash. For compatibility we assume
* zero. */
idx = 0;
} else {
idx = (hash.bits >> (52 - ((i + 1) * 5))) % kGeoAlphabet.size();
}
(*buf)[i] = kGeoAlphabet[idx];
}
(*buf)[11] = '\0';
return true;
}
enum class Action { RANGE = 0, REMOVE = 1, POP = 2 };
class IntervalVisitor {
@ -599,6 +641,20 @@ bool ParseBound(string_view src, ZSetFamily::Bound* bound) {
return ParseDouble(src, &bound->val);
}
bool ParseLongLat(string_view lon, string_view lat, std::pair<double, double>* res) {
if (!ParseDouble(lon, &res->first))
return false;
if (!ParseDouble(lat, &res->second))
return false;
if (res->first < GEO_LONG_MIN || res->first > GEO_LONG_MAX || res->second < GEO_LAT_MIN ||
res->second > GEO_LAT_MAX) {
return false;
}
return true;
}
bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) {
if (src.empty())
return false;
@ -1573,6 +1629,36 @@ OpResult<StringVec> OpScan(const OpArgs& op_args, std::string_view key, uint64_t
return res;
}
void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp,
ConnectionContext* cntx) {
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp);
};
OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) {
return (*cntx)->SendError(add_result.status());
}
// KEY_NOTFOUND may happen in case of XX flag.
if (add_result.status() == OpStatus::KEY_NOTFOUND) {
if (zparams.flags & ZADD_IN_INCR)
(*cntx)->SendNull();
else
(*cntx)->SendLong(0);
} else if (add_result.status() == OpStatus::SKIPPED) {
(*cntx)->SendNull();
} else if (add_result->is_nan) {
(*cntx)->SendError(kScoreNaN);
} else {
if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result->new_score);
} else {
(*cntx)->SendLong(add_result->num_updated);
}
}
}
} // namespace
void ZSetFamily::BZPopMin(CmdArgList args, ConnectionContext* cntx) {
@ -1651,32 +1737,7 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) {
DCHECK(cntx->transaction);
absl::Span memb_sp{members.data(), members.size()};
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp);
};
OpResult<AddResult> add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) {
return (*cntx)->SendError(add_result.status());
}
// KEY_NOTFOUND may happen in case of XX flag.
if (add_result.status() == OpStatus::KEY_NOTFOUND) {
if (zparams.flags & ZADD_IN_INCR)
(*cntx)->SendNull();
else
(*cntx)->SendLong(0);
} else if (add_result.status() == OpStatus::SKIPPED) {
(*cntx)->SendNull();
} else if (add_result->is_nan) {
(*cntx)->SendError(kScoreNaN);
} else {
if (zparams.flags & ZADD_IN_INCR) {
(*cntx)->SendDouble(add_result->new_score);
} else {
(*cntx)->SendLong(add_result->num_updated);
}
}
ZAddGeneric(key, zparams, memb_sp, cntx);
}
void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) {
@ -2370,6 +2431,73 @@ void ZSetFamily::ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cn
OutputScoredArrayResult(result, range_params, cntx);
}
void ZSetFamily::GeoAdd(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
// TODO: to handle options and multiple elements
ZParams zparams;
string_view longitude = ArgS(args, 1);
string_view latitude = ArgS(args, 2);
string_view member = ArgS(args, 3);
// TODO: to remove this check once the TODO above is handled.
if (args.size() != 4) {
return (*cntx)->SendError(kSyntaxErr);
}
// TODO: the code handles only a single tripple of long,lat,member.
// it has to be extended to handle multiple elements.
pair<double, double> longlat;
for (int i = 0; i < 1; i++) {
if (!ParseLongLat(longitude, latitude, &longlat)) {
string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude, ",", latitude);
return (*cntx)->SendError(err, kSyntaxErrType);
}
}
/* Turn the coordinates into the score of the element. */
GeoHashBits hash;
geohashEncodeWGS84(longlat.first, longlat.second, GEO_STEP_MAX, &hash);
GeoHashFix52Bits bits = geohashAlign52Bits(hash);
absl::InlinedVector<ScoredMemberView, 4> members;
members.emplace_back(bits, member);
ZAddGeneric(key, zparams, absl::Span{members.data(), members.size()}, cntx);
}
void ZSetFamily::GeoHash(CmdArgList args, ConnectionContext* cntx) {
string_view key = ArgS(args, 0);
absl::InlinedVector<string_view, 8> members(args.size() - 1);
for (size_t i = 1; i < args.size(); ++i) {
members[i - 1] = ArgS(args, i);
}
auto cb = [&](Transaction* t, EngineShard* shard) {
return OpMScore(t->GetOpArgs(shard), key, members);
};
OpResult<MScoreResponse> result = cntx->transaction->ScheduleSingleHopT(std::move(cb));
if (result.status() == OpStatus::WRONG_TYPE) {
return (*cntx)->SendError(kWrongTypeErr);
}
(*cntx)->StartArray(result->size()); // Array return type.
const MScoreResponse& arr = result.value();
array<char, 12> buf;
for (const auto& p : arr) {
if (ToAsciiGeoHash(p, &buf)) {
(*cntx)->SendBulkString(string_view{buf.data(), buf.size() - 1});
} else {
(*cntx)->SendNull();
}
}
}
#define HFUNC(x) SetHandler(&ZSetFamily::x)
void ZSetFamily::Register(CommandRegistry* registry) {
@ -2408,7 +2536,11 @@ void ZSetFamily::Register(CommandRegistry* registry) {
<< CI{"ZSCAN", CO::READONLY, -3, 1, 1, 1}.HFUNC(ZScan)
<< CI{"ZUNION", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 2, 2, 1}.HFUNC(
ZUnion)
<< CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZUnionStore);
<< CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZUnionStore)
// GEO functions
<< CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, 1}.HFUNC(GeoAdd)
<< CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, 1}.HFUNC(GeoHash);
}
} // namespace dfly

View file

@ -96,6 +96,9 @@ class ZSetFamily {
static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx);
static bool ParseRangeByScoreParams(CmdArgList args, RangeParams* params);
static void ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx);
static void GeoAdd(CmdArgList args, ConnectionContext* cntx);
static void GeoHash(CmdArgList args, ConnectionContext* cntx);
};
} // namespace dfly

View file

@ -637,4 +637,12 @@ TEST_F(ZSetFamilyTest, ZDiff) {
EXPECT_THAT(resp.GetVec(), ElementsAre("two", "2", "three", "3", "four", "4"));
}
TEST_F(ZSetFamilyTest, GeoAdd) {
EXPECT_EQ(1, CheckedInt({"geoadd", "Sicily", "13.361389", "38.115556", "Palermo"}));
EXPECT_EQ(1, CheckedInt({"geoadd", "Sicily", "15.087269", "37.502669", "Catania"}));
EXPECT_EQ(0, CheckedInt({"geoadd", "Sicily", "15.087269", "37.502669", "Catania"}));
auto resp = Run({"geohash", "Sicily", "Palermo", "Catania"});
EXPECT_THAT(resp, RespArray(ElementsAre("sqc8b49rny0", "sqdtr74hyu0")));
}
} // namespace dfly