fix: choose an appropriate SIMD implementation for aarch64 (#579)

This commit is contained in:
Roman Gershman 2022-12-19 12:18:41 +02:00 committed by GitHub
parent 69d9ef2c03
commit 5f572f00f1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -14,6 +14,8 @@
#endif
#include <absl/base/internal/endian.h>
using namespace std;
namespace dfly {
namespace detail {
@ -31,6 +33,75 @@ static inline uint64_t Compress8x7bit(uint64_t x) {
return x;
}
static inline pair<const char*, uint8_t*> simd_variant1_pack(const char* ascii, const char* end,
uint8_t* bin) {
__m128i val, rpart, lpart;
// Skips 8th byte (indexc 7) in the lower 8-byte part.
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
while (ascii <= end) {
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
/*
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
*/
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F007F007F007F00));
val = _mm_or_si128(_mm_srli_epi64(lpart, 1), rpart);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x3FFF00003FFF0000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 2), rpart);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
val = _mm_shuffle_epi8(val, control);
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
bin += 14;
ascii += 16;
}
return make_pair(ascii, bin);
}
static inline pair<const char*, uint8_t*> simd_variant2_pack(const char* ascii, const char* end,
uint8_t* bin) {
// Skips 8th byte (indexc 7) in the lower 8-byte part.
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
__m128i val, rpart, lpart;
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
while (ascii <= end) {
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
/*
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
*/
val = _mm_maddubs_epi16(_mm_set1_epi16(0x8001), val);
val = _mm_madd_epi16(_mm_set1_epi32(0x40000001), val);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
val = _mm_shuffle_epi8(val, control);
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
bin += 14;
ascii += 16;
}
return make_pair(ascii, bin);
}
// Daniel Lemire's function validate_ascii_fast() - under Apache/MIT license.
// See https://github.com/lemire/fastvalidate-utf-8/
// The function returns true (1) if all chars passed in src are
@ -103,38 +174,7 @@ void ascii_pack_simd(const char* ascii, size_t len, uint8_t* bin) {
// overwrite we finish loop one iteration earlier.
const char* end = ascii + len - 32;
// Skips 8th byte (indexc 7) in the lower 8-byte part.
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
__m128i val, rpart, lpart;
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
while (ascii <= end) {
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
/*
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
*/
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x007F007F007F007F));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x7F007F007F007F00));
val = _mm_or_si128(_mm_srli_epi64(lpart, 1), rpart);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x00003FFF00003FFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x3FFF00003FFF0000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 2), rpart);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
val = _mm_shuffle_epi8(val, control);
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
bin += 14;
ascii += 16;
}
tie(ascii, bin) = simd_variant1_pack(ascii, end, bin);
end += 32; // Bring back end.
DCHECK(ascii < end);
@ -147,32 +187,12 @@ void ascii_pack_simd2(const char* ascii, size_t len, uint8_t* bin) {
// overwrite we finish loop one iteration earlier.
const char* end = ascii + len - 32;
// Skips 8th byte (indexc 7) in the lower 8-byte part.
const __m128i control = _mm_set_epi8(-1, -1, 14, 13, 12, 11, 10, 9, 8, 6, 5, 4, 3, 2, 1, 0);
__m128i val, rpart, lpart;
// Based on the question I asked here: https://stackoverflow.com/q/74831843/2280111
while (ascii <= end) {
val = _mm_loadu_si128(reinterpret_cast<const __m128i*>(ascii));
/*
x = ((x & 0x7F007F007F007F00) >> 1) | (x & 0x007F007F007F007F);
x = ((x & 0x3FFF00003FFF0000) >> 2) | (x & 0x00003FFF00003FFF);
x = ((x & 0x0FFFFFFF00000000) >> 4) | (x & 0x000000000FFFFFFF);
*/
val = _mm_maddubs_epi16(_mm_set1_epi16(0x8001), val);
val = _mm_madd_epi16(_mm_set1_epi32(0x40000001), val);
rpart = _mm_and_si128(val, _mm_set1_epi64x(0x000000000FFFFFFF));
lpart = _mm_and_si128(val, _mm_set1_epi64x(0x0FFFFFFF00000000));
val = _mm_or_si128(_mm_srli_epi64(lpart, 4), rpart);
val = _mm_shuffle_epi8(val, control);
_mm_storeu_si128(reinterpret_cast<__m128i*>(bin), val);
bin += 14;
ascii += 16;
}
// on arm var
#if defined(__aarch64__)
tie(ascii, bin) = simd_variant1_pack(ascii, end, bin);
#else
tie(ascii, bin) = simd_variant2_pack(ascii, end, bin);
#endif
end += 32; // Bring back end.
DCHECK(ascii < end);