From 187bca931737729253c1bdfb0fee71c9559b31bb Mon Sep 17 00:00:00 2001 From: Roman Gershman Date: Sun, 16 Jul 2023 09:19:35 +0300 Subject: [PATCH] feat: implement two geo commands GEOADD/GEOHASH (#1543) Only most basic functionality is covered, the options support and variadic functionality for GEOADD is missing. Signed-off-by: Roman Gershman --- src/server/zset_family.cc | 186 ++++++++++++++++++++++++++++----- src/server/zset_family.h | 3 + src/server/zset_family_test.cc | 8 ++ 3 files changed, 170 insertions(+), 27 deletions(-) diff --git a/src/server/zset_family.cc b/src/server/zset_family.cc index ae166573a..632016a50 100644 --- a/src/server/zset_family.cc +++ b/src/server/zset_family.cc @@ -5,6 +5,8 @@ #include "server/zset_family.h" extern "C" { +#include "redis/geohash.h" +#include "redis/geohash_helper.h" #include "redis/listpack.h" #include "redis/object.h" #include "redis/util.h" @@ -34,6 +36,7 @@ static const char kNxXxErr[] = "XX and NX options at the same time are not compa static const char kScoreNaN[] = "resulting score is not a number (NaN)"; static const char kFloatRangeErr[] = "min or max is not a float"; static const char kLexRangeErr[] = "min or max not valid string range item"; +constexpr string_view kGeoAlphabet = "0123456789bcdefghjkmnpqrstuvwxyz"sv; constexpr unsigned kMaxListPackValue = 64; using MScoreResponse = std::vector>; @@ -130,6 +133,45 @@ OpResult FindZEntry(const ZParams& zparams, const OpArgs& op_args return it; } +bool ToAsciiGeoHash(const std::optional& val, array* buf) { + if (!val.has_value()) + return false; + + double score = *val; + + double xy[2]; + GeoHashBits hash = {.bits = (uint64_t)score, .step = GEO_STEP_MAX}; + + if (!geohashDecodeToLongLatType(hash, xy)) { + return false; + } + + /* Re-encode */ + GeoHashRange r[2]; + r[0].min = -180; + r[0].max = 180; + r[1].min = -90; + r[1].max = 90; + + geohashEncode(&r[0], &r[1], xy[0], xy[1], 26, &hash); + + for (int i = 0; i < 11; i++) { + int idx; + if (i == 10) { + /* We have just 52 bits, but the API used to output + * an 11 bytes geohash. For compatibility we assume + * zero. */ + idx = 0; + } else { + idx = (hash.bits >> (52 - ((i + 1) * 5))) % kGeoAlphabet.size(); + } + (*buf)[i] = kGeoAlphabet[idx]; + } + (*buf)[11] = '\0'; + + return true; +} + enum class Action { RANGE = 0, REMOVE = 1, POP = 2 }; class IntervalVisitor { @@ -599,6 +641,20 @@ bool ParseBound(string_view src, ZSetFamily::Bound* bound) { return ParseDouble(src, &bound->val); } +bool ParseLongLat(string_view lon, string_view lat, std::pair* res) { + if (!ParseDouble(lon, &res->first)) + return false; + + if (!ParseDouble(lat, &res->second)) + return false; + + if (res->first < GEO_LONG_MIN || res->first > GEO_LONG_MAX || res->second < GEO_LAT_MIN || + res->second > GEO_LAT_MAX) { + return false; + } + return true; +} + bool ParseLexBound(string_view src, ZSetFamily::LexBound* bound) { if (src.empty()) return false; @@ -1573,6 +1629,36 @@ OpResult OpScan(const OpArgs& op_args, std::string_view key, uint64_t return res; } +void ZAddGeneric(string_view key, const ZParams& zparams, ScoredMemberSpan memb_sp, + ConnectionContext* cntx) { + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp); + }; + + OpResult add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) { + return (*cntx)->SendError(add_result.status()); + } + + // KEY_NOTFOUND may happen in case of XX flag. + if (add_result.status() == OpStatus::KEY_NOTFOUND) { + if (zparams.flags & ZADD_IN_INCR) + (*cntx)->SendNull(); + else + (*cntx)->SendLong(0); + } else if (add_result.status() == OpStatus::SKIPPED) { + (*cntx)->SendNull(); + } else if (add_result->is_nan) { + (*cntx)->SendError(kScoreNaN); + } else { + if (zparams.flags & ZADD_IN_INCR) { + (*cntx)->SendDouble(add_result->new_score); + } else { + (*cntx)->SendLong(add_result->num_updated); + } + } +} + } // namespace void ZSetFamily::BZPopMin(CmdArgList args, ConnectionContext* cntx) { @@ -1651,32 +1737,7 @@ void ZSetFamily::ZAdd(CmdArgList args, ConnectionContext* cntx) { DCHECK(cntx->transaction); absl::Span memb_sp{members.data(), members.size()}; - auto cb = [&](Transaction* t, EngineShard* shard) { - return OpAdd(t->GetOpArgs(shard), zparams, key, memb_sp); - }; - - OpResult add_result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); - if (base::_in(add_result.status(), {OpStatus::WRONG_TYPE, OpStatus::OUT_OF_MEMORY})) { - return (*cntx)->SendError(add_result.status()); - } - - // KEY_NOTFOUND may happen in case of XX flag. - if (add_result.status() == OpStatus::KEY_NOTFOUND) { - if (zparams.flags & ZADD_IN_INCR) - (*cntx)->SendNull(); - else - (*cntx)->SendLong(0); - } else if (add_result.status() == OpStatus::SKIPPED) { - (*cntx)->SendNull(); - } else if (add_result->is_nan) { - (*cntx)->SendError(kScoreNaN); - } else { - if (zparams.flags & ZADD_IN_INCR) { - (*cntx)->SendDouble(add_result->new_score); - } else { - (*cntx)->SendLong(add_result->num_updated); - } - } + ZAddGeneric(key, zparams, memb_sp, cntx); } void ZSetFamily::ZCard(CmdArgList args, ConnectionContext* cntx) { @@ -2370,6 +2431,73 @@ void ZSetFamily::ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cn OutputScoredArrayResult(result, range_params, cntx); } +void ZSetFamily::GeoAdd(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 0); + + // TODO: to handle options and multiple elements + ZParams zparams; + + string_view longitude = ArgS(args, 1); + string_view latitude = ArgS(args, 2); + string_view member = ArgS(args, 3); + + // TODO: to remove this check once the TODO above is handled. + if (args.size() != 4) { + return (*cntx)->SendError(kSyntaxErr); + } + + // TODO: the code handles only a single tripple of long,lat,member. + // it has to be extended to handle multiple elements. + pair longlat; + for (int i = 0; i < 1; i++) { + if (!ParseLongLat(longitude, latitude, &longlat)) { + string err = absl::StrCat("-ERR invalid longitude,latitude pair ", longitude, ",", latitude); + + return (*cntx)->SendError(err, kSyntaxErrType); + } + } + + /* Turn the coordinates into the score of the element. */ + GeoHashBits hash; + geohashEncodeWGS84(longlat.first, longlat.second, GEO_STEP_MAX, &hash); + GeoHashFix52Bits bits = geohashAlign52Bits(hash); + + absl::InlinedVector members; + members.emplace_back(bits, member); + ZAddGeneric(key, zparams, absl::Span{members.data(), members.size()}, cntx); +} + +void ZSetFamily::GeoHash(CmdArgList args, ConnectionContext* cntx) { + string_view key = ArgS(args, 0); + + absl::InlinedVector members(args.size() - 1); + for (size_t i = 1; i < args.size(); ++i) { + members[i - 1] = ArgS(args, i); + } + + auto cb = [&](Transaction* t, EngineShard* shard) { + return OpMScore(t->GetOpArgs(shard), key, members); + }; + + OpResult result = cntx->transaction->ScheduleSingleHopT(std::move(cb)); + + if (result.status() == OpStatus::WRONG_TYPE) { + return (*cntx)->SendError(kWrongTypeErr); + } + + (*cntx)->StartArray(result->size()); // Array return type. + const MScoreResponse& arr = result.value(); + + array buf; + for (const auto& p : arr) { + if (ToAsciiGeoHash(p, &buf)) { + (*cntx)->SendBulkString(string_view{buf.data(), buf.size() - 1}); + } else { + (*cntx)->SendNull(); + } + } +} + #define HFUNC(x) SetHandler(&ZSetFamily::x) void ZSetFamily::Register(CommandRegistry* registry) { @@ -2408,7 +2536,11 @@ void ZSetFamily::Register(CommandRegistry* registry) { << CI{"ZSCAN", CO::READONLY, -3, 1, 1, 1}.HFUNC(ZScan) << CI{"ZUNION", CO::READONLY | CO::REVERSE_MAPPING | CO::VARIADIC_KEYS, -3, 2, 2, 1}.HFUNC( ZUnion) - << CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZUnionStore); + << CI{"ZUNIONSTORE", kStoreMask, -4, 3, 3, 1}.HFUNC(ZUnionStore) + + // GEO functions + << CI{"GEOADD", CO::FAST | CO::WRITE | CO::DENYOOM, -5, 1, 1, 1}.HFUNC(GeoAdd) + << CI{"GEOHASH", CO::FAST | CO::READONLY, -2, 1, 1, 1}.HFUNC(GeoHash); } } // namespace dfly diff --git a/src/server/zset_family.h b/src/server/zset_family.h index cfeb596b5..7b73ac4a9 100644 --- a/src/server/zset_family.h +++ b/src/server/zset_family.h @@ -96,6 +96,9 @@ class ZSetFamily { static void ZRankGeneric(CmdArgList args, bool reverse, ConnectionContext* cntx); static bool ParseRangeByScoreParams(CmdArgList args, RangeParams* params); static void ZPopMinMax(CmdArgList args, bool reverse, ConnectionContext* cntx); + + static void GeoAdd(CmdArgList args, ConnectionContext* cntx); + static void GeoHash(CmdArgList args, ConnectionContext* cntx); }; } // namespace dfly diff --git a/src/server/zset_family_test.cc b/src/server/zset_family_test.cc index 947de4230..85e952266 100644 --- a/src/server/zset_family_test.cc +++ b/src/server/zset_family_test.cc @@ -637,4 +637,12 @@ TEST_F(ZSetFamilyTest, ZDiff) { EXPECT_THAT(resp.GetVec(), ElementsAre("two", "2", "three", "3", "four", "4")); } +TEST_F(ZSetFamilyTest, GeoAdd) { + EXPECT_EQ(1, CheckedInt({"geoadd", "Sicily", "13.361389", "38.115556", "Palermo"})); + EXPECT_EQ(1, CheckedInt({"geoadd", "Sicily", "15.087269", "37.502669", "Catania"})); + EXPECT_EQ(0, CheckedInt({"geoadd", "Sicily", "15.087269", "37.502669", "Catania"})); + auto resp = Run({"geohash", "Sicily", "Palermo", "Catania"}); + EXPECT_THAT(resp, RespArray(ElementsAre("sqc8b49rny0", "sqdtr74hyu0"))); +} + } // namespace dfly