123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376 |
- #include "l2_distance.h"
- #include <library/cpp/sse/sse.h>
- #include <contrib/libs/cblas/include/cblas.h>
- #include <util/system/platform.h>
// Squared absolute difference (a - b)^2, computed and returned in type Result.
//
// Both operands are converted to Result *before* the subtraction, so the
// difference never has to fit in Number's (promoted) type. Example: for
// Number = i32, INT_MIN and INT_MAX differ by 2^32 - 1, which would overflow
// `int` (undefined behaviour). The smaller operand is always subtracted from
// the larger one. For an unsigned Result wider than the true difference, the
// modular subtraction therefore yields the exact non-negative difference.
//
// Result: accumulator/return type (ui32, ui64, float, double in this file).
// Number: element type of the compared vectors.
template <typename Result, typename Number>
inline Result SqrDelta(Number a, Number b) {
    const Result diff = a < b
        ? static_cast<Result>(b) - static_cast<Result>(a)
        : static_cast<Result>(a) - static_cast<Result>(b);
    return diff * diff;
}
- template <typename Result, typename Number>
- inline Result L2SqrDistanceImpl(const Number* a, const Number* b, int length) {
- Result res = 0;
- for (int i = 0; i < length; i++) {
- res += SqrDelta<Result, Number>(a[i], b[i]);
- }
- return res;
- }
- template <typename Result, typename Number>
- inline Result L2SqrDistanceImpl2(const Number* a, const Number* b, int length) {
- Result s0 = 0;
- Result s1 = 0;
- while (length >= 2) {
- s0 += SqrDelta<Result, Number>(a[0], b[0]);
- s1 += SqrDelta<Result, Number>(a[1], b[1]);
- a += 2;
- b += 2;
- length -= 2;
- }
- while (length--)
- s0 += SqrDelta<Result, Number>(*a++, *b++);
- return s0 + s1;
- }
- template <typename Result, typename Number>
- inline Result L2SqrDistanceImpl4(const Number* a, const Number* b, int length) {
- Result s0 = 0;
- Result s1 = 0;
- Result s2 = 0;
- Result s3 = 0;
- while (length >= 4) {
- s0 += SqrDelta<Result, Number>(a[0], b[0]);
- s1 += SqrDelta<Result, Number>(a[1], b[1]);
- s2 += SqrDelta<Result, Number>(a[2], b[2]);
- s3 += SqrDelta<Result, Number>(a[3], b[3]);
- a += 4;
- b += 4;
- length -= 4;
- }
- while (length--)
- s0 += SqrDelta<Result, Number>(*a++, *b++);
- return s0 + s1 + s2 + s3;
- }
- inline ui32 L2SqrDistanceImplUI4(const ui8* a, const ui8* b, int length) {
- ui32 res = 0;
- for (int i = 0; i < length; i++) {
- res += SqrDelta<ui32, ui8>(a[i] & 0x0f, b[i] & 0x0f);
- res += SqrDelta<ui32, ui8>(a[i] & 0xf0, b[i] & 0xf0) >> 8;
- }
- return res;
- }
- #ifdef ARCADIA_SSE
namespace NL2Distance {
    // Byte-wise nibble masks for 4-bit packed data, replicated into all 16 bytes
    // of an SSE register. MASK_UI4_1 selects each byte's low nibble (0x0f) and
    // MASK_UI4_2 selects each byte's high nibble (0xf0). _mm_set1_epi8 broadcasts
    // one byte and is equivalent to _mm_set_epi8 with 16 equal arguments.
    static const __m128i MASK_UI4_1 = _mm_set1_epi8(0x0f);
    static const __m128i MASK_UI4_2 = _mm_set1_epi8(static_cast<char>(0xf0));
}
- ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) {
- const __m128i zero = _mm_setzero_si128();
- __m128i resVec = zero;
- while (length >= 16) {
- __m128i vec = _mm_subs_epi8(_mm_loadu_si128((const __m128i*)lhs), _mm_loadu_si128((const __m128i*)rhs));
- #ifdef _sse4_1_
- __m128i lo = _mm_cvtepi8_epi16(vec);
- __m128i hi = _mm_cvtepi8_epi16(_mm_alignr_epi8(vec, vec, 8));
- #else
- __m128i lo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, vec), 8);
- __m128i hi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, vec), 8);
- #endif
- resVec = _mm_add_epi32(resVec,
- _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi)));
- lhs += 16;
- rhs += 16;
- length -= 16;
- }
- alignas(16) ui32 res[4];
- _mm_store_si128((__m128i*)res, resVec);
- ui32 sum = res[0] + res[1] + res[2] + res[3];
- for (int i = 0; i < length; ++i) {
- sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i]));
- }
- return sum;
- }
- ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) {
- const __m128i zero = _mm_setzero_si128();
- __m128i resVec = zero;
- while (length >= 16) {
- __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
- __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
- // We will think about this vectors as about i16.
- __m128i lo = _mm_sub_epi16(_mm_unpacklo_epi8(lVec, zero), _mm_unpacklo_epi8(rVec, zero));
- __m128i hi = _mm_sub_epi16(_mm_unpackhi_epi8(lVec, zero), _mm_unpackhi_epi8(rVec, zero));
- resVec = _mm_add_epi32(resVec,
- _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi)));
- lhs += 16;
- rhs += 16;
- length -= 16;
- }
- alignas(16) ui32 res[4];
- _mm_store_si128((__m128i*)res, resVec);
- ui32 sum = res[0] + res[1] + res[2] + res[3];
- for (int i = 0; i < length; ++i) {
- sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i]));
- }
- return sum;
- }
- float L2SqrDistance(const float* lhs, const float* rhs, int length) {
- __m128 sum = _mm_setzero_ps();
- while (length >= 4) {
- __m128 a = _mm_loadu_ps(lhs);
- __m128 b = _mm_loadu_ps(rhs);
- __m128 delta = _mm_sub_ps(a, b);
- sum = _mm_add_ps(sum, _mm_mul_ps(delta, delta));
- length -= 4;
- rhs += 4;
- lhs += 4;
- }
- alignas(16) float res[4];
- _mm_store_ps(res, sum);
- while (length--)
- res[0] += Sqr(*rhs++ - *lhs++);
- return res[0] + res[1] + res[2] + res[3];
- }
- double L2SqrDistance(const double* lhs, const double* rhs, int length) {
- __m128d sum = _mm_setzero_pd();
- while (length >= 2) {
- __m128d a = _mm_loadu_pd(lhs);
- __m128d b = _mm_loadu_pd(rhs);
- __m128d delta = _mm_sub_pd(a, b);
- sum = _mm_add_pd(sum, _mm_mul_pd(delta, delta));
- length -= 2;
- rhs += 2;
- lhs += 2;
- }
- alignas(16) double res[2];
- _mm_store_pd(res, sum);
- while (length--)
- res[0] += Sqr(*rhs++ - *lhs++);
- return res[0] + res[1];
- }
- ui64 L2SqrDistance(const i32* lhs, const i32* rhs, int length) {
- __m128i zero = _mm_setzero_si128();
- __m128i res = zero;
- while (length >= 4) {
- __m128i a = _mm_loadu_si128((const __m128i*)lhs);
- __m128i b = _mm_loadu_si128((const __m128i*)rhs);
- #ifdef _sse4_1_
- // In SSE4.1 si32*si32->si64 is available, so we may do just (a-b)*(a-b) not caring about (a-b) sign
- a = _mm_sub_epi32(a, b);
- res = _mm_add_epi64(_mm_mul_epi32(a, a), res);
- a = _mm_alignr_epi8(a, a, 4);
- res = _mm_add_epi64(_mm_mul_epi32(a, a), res);
- #else
- __m128i mask = _mm_cmpgt_epi32(a, b); // mask = a > b? 0xffffffff: 0;
- __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask) (for a > b)
- b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a)); // b = (b & ~mask) - (a & ~mask) (for b > a)
- a = _mm_or_si128(a2, b); // a = abs(a - b)
- a2 = _mm_unpackhi_epi32(a, zero);
- res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
- a2 = _mm_unpacklo_epi32(a, zero);
- res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
- #endif
- rhs += 4;
- lhs += 4;
- length -= 4;
- }
- alignas(16) ui64 r[2];
- _mm_store_si128((__m128i*)r, res);
- ui64 sum = r[0] + r[1];
- while (length) {
- sum += SqrDelta<ui64, i32>(lhs[0], rhs[0]);
- ++lhs;
- ++rhs;
- --length;
- }
- return sum;
- }
- ui64 L2SqrDistance(const ui32* lhs, const ui32* rhs, int length) {
- __m128i zero = _mm_setzero_si128();
- __m128i shift = _mm_set1_epi32(0x80000000);
- __m128i res = zero;
- while (length >= 4) {
- __m128i a = _mm_add_epi32(_mm_loadu_si128((const __m128i*)lhs), shift);
- __m128i b = _mm_add_epi32(_mm_loadu_si128((const __m128i*)rhs), shift);
- __m128i mask = _mm_cmpgt_epi32(a, b); // mask = a > b? 0xffffffff: 0;
- __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask) (for a > b)
- b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a)); // b = (b & ~mask) - (a & ~mask) (for b > a)
- a = _mm_or_si128(a2, b); // a = abs(a - b)
- #ifdef _sse4_1_
- res = _mm_add_epi64(_mm_mul_epu32(a, a), res);
- a = _mm_alignr_epi8(a, a, 4);
- res = _mm_add_epi64(_mm_mul_epu32(a, a), res);
- #else
- a2 = _mm_unpackhi_epi32(a, zero);
- res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
- a2 = _mm_unpacklo_epi32(a, zero);
- res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
- #endif
- rhs += 4;
- lhs += 4;
- length -= 4;
- }
- alignas(16) ui64 r[2];
- _mm_store_si128((__m128i*)r, res);
- ui64 sum = r[0] + r[1];
- while (length) {
- sum += SqrDelta<ui64, ui32>(lhs[0], rhs[0]);
- ++lhs;
- ++rhs;
- --length;
- }
- return sum;
- }
- ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) {
- const __m128i zero = _mm_setzero_si128();
- __m128i resVec1 = zero;
- __m128i resVec2 = zero;
- while (length >= 16) {
- __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
- __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
- __m128i lVec1 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_1);
- __m128i lVec2 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_2);
- __m128i rVec1 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_1);
- __m128i rVec2 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_2);
- // We will think about this vectors as about i16.
- __m128i lo1 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec1, zero), _mm_unpacklo_epi8(rVec1, zero));
- __m128i hi1 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec1, zero), _mm_unpackhi_epi8(rVec1, zero));
- __m128i lo2 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec2, zero), _mm_unpacklo_epi8(rVec2, zero));
- __m128i hi2 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec2, zero), _mm_unpackhi_epi8(rVec2, zero));
- resVec1 = _mm_add_epi32(resVec1, _mm_add_epi32(_mm_madd_epi16(lo1, lo1), _mm_madd_epi16(hi1, hi1)));
- resVec2 = _mm_add_epi32(resVec2, _mm_add_epi32(_mm_madd_epi16(lo2, lo2), _mm_madd_epi16(hi2, hi2)));
- lhs += 16;
- rhs += 16;
- length -= 16;
- }
- alignas(16) ui32 res[4];
- _mm_store_si128((__m128i*)res, resVec1);
- ui32 sum = res[0] + res[1] + res[2] + res[3];
- _mm_store_si128((__m128i*)res, resVec2);
- sum += (res[0] + res[1] + res[2] + res[3]) >> 8;
- for (int i = 0; i < length; ++i) {
- sum += Sqr(static_cast<i32>(lhs[i] & 0x0f) - static_cast<i32>(rhs[i] & 0x0f));
- sum += Sqr(static_cast<i32>(lhs[i] & 0xf0) - static_cast<i32>(rhs[i] & 0xf0)) >> 8;
- }
- return sum;
- }
- #else /* !ARCADIA_SSE */
- ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) {
- return L2SqrDistanceImpl<ui32, i8>(lhs, rhs, length);
- }
- ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) {
- return L2SqrDistanceImpl<ui32, ui8>(lhs, rhs, length);
- }
- ui64 L2SqrDistance(const i32* a, const i32* b, int length) {
- return L2SqrDistanceImpl2<ui64, i32>(a, b, length);
- }
- ui64 L2SqrDistance(const ui32* a, const ui32* b, int length) {
- return L2SqrDistanceImpl2<ui64, ui32>(a, b, length);
- }
- float L2SqrDistance(const float* a, const float* b, int length) {
- return L2SqrDistanceImpl4<float, float>(a, b, length);
- }
- double L2SqrDistance(const double* a, const double* b, int length) {
- return L2SqrDistanceImpl2<double, double>(a, b, length);
- }
- ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) {
- return L2SqrDistanceImplUI4(lhs, rhs, length);
- }
- #endif /* ARCADIA_SSE */
- ui32 L2SqrDistanceSlow(const i8* lhs, const i8* rhs, int length) {
- return L2SqrDistanceImpl<ui32, i8>(lhs, rhs, length);
- }
- ui32 L2SqrDistanceSlow(const ui8* lhs, const ui8* rhs, int length) {
- return L2SqrDistanceImpl<ui32, ui8>(lhs, rhs, length);
- }
- ui64 L2SqrDistanceSlow(const i32* a, const i32* b, int length) {
- return L2SqrDistanceImpl2<ui64, i32>(a, b, length);
- }
- ui64 L2SqrDistanceSlow(const ui32* a, const ui32* b, int length) {
- return L2SqrDistanceImpl2<ui64, ui32>(a, b, length);
- }
- float L2SqrDistanceSlow(const float* a, const float* b, int length) {
- return L2SqrDistanceImpl4<float, float>(a, b, length);
- }
- double L2SqrDistanceSlow(const double* a, const double* b, int length) {
- return L2SqrDistanceImpl2<double, double>(a, b, length);
- }
- ui32 L2SqrDistanceUI4Slow(const ui8* lhs, const ui8* rhs, int length) {
- return L2SqrDistanceImplUI4(lhs, rhs, length);
- }
|