#include "l2_distance.h"

// NOTE(review): the targets of the three includes below were lost when this file was extracted.
// They are restored from the symbols the file uses (ARCADIA_SSE and the intrinsics, _sse4_1_, Sqr);
// confirm the exact paths against version control.
#include <library/cpp/sse/sse.h>

#include <util/generic/ymath.h>
#include <util/system/platform.h>

// Returns (a - b)^2 computed in Result.
// The operands are ordered before subtracting so the difference is never negative,
// which keeps the result correct when Number is an unsigned type.
template <typename Result, typename Number>
inline Result SqrDelta(Number a, Number b) {
    Result diff = a < b ? b - a : a - b;
    return diff * diff;
}

// Plain scalar sum of squared differences over `length` elements.
template <typename Result, typename Number>
inline Result L2SqrDistanceImpl(const Number* a, const Number* b, int length) {
    Result res = 0;

    for (int i = 0; i < length; i++) {
        res += SqrDelta<Result>(a[i], b[i]);
    }

    return res;
}

// Scalar sum of squared differences, unrolled by two with independent accumulators.
template <typename Result, typename Number>
inline Result L2SqrDistanceImpl2(const Number* a, const Number* b, int length) {
    Result s0 = 0;
    Result s1 = 0;

    while (length >= 2) {
        s0 += SqrDelta<Result>(a[0], b[0]);
        s1 += SqrDelta<Result>(a[1], b[1]);
        a += 2;
        b += 2;
        length -= 2;
    }

    // At most one element is left.
    while (length--)
        s0 += SqrDelta<Result>(*a++, *b++);

    return s0 + s1;
}

// Scalar sum of squared differences, unrolled by four with independent accumulators.
template <typename Result, typename Number>
inline Result L2SqrDistanceImpl4(const Number* a, const Number* b, int length) {
    Result s0 = 0;
    Result s1 = 0;
    Result s2 = 0;
    Result s3 = 0;

    while (length >= 4) {
        s0 += SqrDelta<Result>(a[0], b[0]);
        s1 += SqrDelta<Result>(a[1], b[1]);
        s2 += SqrDelta<Result>(a[2], b[2]);
        s3 += SqrDelta<Result>(a[3], b[3]);
        a += 4;
        b += 4;
        length -= 4;
    }

    // At most three elements are left.
    while (length--)
        s0 += SqrDelta<Result>(*a++, *b++);

    return s0 + s1 + s2 + s3;
}

// Scalar squared distance for vectors of packed 4-bit values: every byte holds two components,
// one in the low nibble and one in the high nibble. `length` is the number of bytes.
inline ui32 L2SqrDistanceImplUI4(const ui8* a, const ui8* b, int length) {
    ui32 res = 0;
    for (int i = 0; i < length; i++) {
        res += SqrDelta<ui32>(a[i] & 0x0f, b[i] & 0x0f);
        // The high nibbles are compared in place, i.e. scaled by 16, so the square is scaled by 256.
        res += SqrDelta<ui32>(a[i] & 0xf0, b[i] & 0xf0) >> 8;
    }
    return res;
}

#ifdef ARCADIA_SSE
namespace NL2Distance {
    // Per-byte masks selecting the low (0x0f) and the high (0xf0) nibble of every byte.
    static const __m128i MASK_UI4_1 = _mm_set1_epi8(static_cast<char>(0x0f));
    static const __m128i MASK_UI4_2 = _mm_set1_epi8(static_cast<char>(0xf0));
}

// Squared L2 distance between two i8 vectors of `length` elements.
ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) {
    const __m128i zero = _mm_setzero_si128();
    // Flipping the sign bit maps i8 [-128, 127] onto ui8 [0, 255]; order and differences are preserved.
    const __m128i signBit = _mm_set1_epi8(static_cast<char>(0x80));
    __m128i resVec = zero;

    while (length >= 16) {
        const __m128i lVec = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(lhs)), signBit);
        const __m128i rVec = _mm_xor_si128(_mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs)), signBit);

        // |lhs - rhs| as an unsigned byte in [0, 255]. A signed saturating subtraction (_mm_subs_epi8)
        // would clamp the difference to [-128, 127] and disagree with the scalar code for distant values.
        const __m128i absDiff = _mm_sub_epi8(_mm_max_epu8(lVec, rVec), _mm_min_epu8(lVec, rVec));

        // Zero-extend to 16-bit lanes; _mm_madd_epi16 squares them and adds adjacent pairs into 32-bit lanes.
        const __m128i lo = _mm_unpacklo_epi8(absDiff, zero);
        const __m128i hi = _mm_unpackhi_epi8(absDiff, zero);

        resVec = _mm_add_epi32(resVec,
                               _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi)));

        lhs += 16;
        rhs += 16;
        length -= 16;
    }

    alignas(16) ui32 res[4];
    _mm_store_si128(reinterpret_cast<__m128i*>(res), resVec);
    ui32 sum = res[0] + res[1] + res[2] + res[3];

    // Scalar tail: fewer than 16 elements are left.
    for (int i = 0; i < length; ++i) {
        sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i]));
    }

    return sum;
}

// Squared L2 distance between two ui8 vectors of `length` elements.
ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) {
    const __m128i zero = _mm_setzero_si128();
    __m128i resVec = zero;

    while (length >= 16) {
        const __m128i lVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(lhs));
        const __m128i rVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs));

        // Zero-extend the bytes to 16-bit lanes, so the differences fit into signed 16-bit values.
        const __m128i lo = _mm_sub_epi16(_mm_unpacklo_epi8(lVec, zero), _mm_unpacklo_epi8(rVec, zero));
        const __m128i hi = _mm_sub_epi16(_mm_unpackhi_epi8(lVec, zero), _mm_unpackhi_epi8(rVec, zero));

        resVec = _mm_add_epi32(resVec,
                               _mm_add_epi32(_mm_madd_epi16(lo, lo), _mm_madd_epi16(hi, hi)));

        lhs += 16;
        rhs += 16;
        length -= 16;
    }

    alignas(16) ui32 res[4];
    _mm_store_si128(reinterpret_cast<__m128i*>(res), resVec);
    ui32 sum = res[0] + res[1] + res[2] + res[3];

    // Scalar tail: fewer than 16 elements are left.
    for (int i = 0; i < length; ++i) {
        sum += Sqr(static_cast<i32>(lhs[i]) - static_cast<i32>(rhs[i]));
    }

    return sum;
}

// Squared L2 distance between two float vectors of `length` elements.
float L2SqrDistance(const float* lhs, const float* rhs, int length) {
    __m128 sum = _mm_setzero_ps();

    while (length >= 4) {
        const __m128 a = _mm_loadu_ps(lhs);
        const __m128 b = _mm_loadu_ps(rhs);
        const __m128 delta = _mm_sub_ps(a, b);
        sum = _mm_add_ps(sum, _mm_mul_ps(delta, delta));
        length -= 4;
        rhs += 4;
        lhs += 4;
    }

    alignas(16) float res[4];
    _mm_store_ps(res, sum);

    // Scalar tail is folded into the first partial sum.
    while (length--)
        res[0] += Sqr(*rhs++ - *lhs++);

    return res[0] + res[1] + res[2] + res[3];
}

// Squared L2 distance between two double vectors of `length` elements.
double L2SqrDistance(const double* lhs, const double* rhs, int length) {
    __m128d sum = _mm_setzero_pd();

    while (length >= 2) {
        const __m128d a = _mm_loadu_pd(lhs);
        const __m128d b = _mm_loadu_pd(rhs);
        const __m128d delta = _mm_sub_pd(a, b);
        sum = _mm_add_pd(sum, _mm_mul_pd(delta, delta));
        length -= 2;
        rhs += 2;
        lhs += 2;
    }

    alignas(16) double res[2];
    _mm_store_pd(res, sum);

    // Scalar tail is folded into the first partial sum.
    while (length--)
        res[0] += Sqr(*rhs++ - *lhs++);

    return res[0] + res[1];
}

// Squared L2 distance between two i32 vectors of `length` elements, accumulated in 64 bits.
ui64 L2SqrDistance(const i32* lhs, const i32* rhs, int length) {
    __m128i zero = _mm_setzero_si128();
    __m128i res = zero;

    while (length >= 4) {
        __m128i a = _mm_loadu_si128(reinterpret_cast<const __m128i*>(lhs));
        __m128i b = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs));

#ifdef _sse4_1_
        // SSE4.1 has a signed 32x32->64 multiply, so (a - b) * (a - b) needs no care about the sign.
        // NOTE(review): the 32-bit subtraction wraps, so this is exact only while |a - b| < 2^31.
        a = _mm_sub_epi32(a, b);
        res = _mm_add_epi64(_mm_mul_epi32(a, a), res); // lanes 0 and 2
        a = _mm_alignr_epi8(a, a, 4);                  // rotate lanes 1 and 3 into positions 0 and 2
        res = _mm_add_epi64(_mm_mul_epi32(a, a), res);
#else
        __m128i mask = _mm_cmpgt_epi32(a, b);                                       // mask = a > b ? 0xffffffff : 0
        __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask), set where a > b
        b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a));    // b = (b & ~mask) - (a & ~mask), set where b >= a
        a = _mm_or_si128(a2, b);                                                    // a = abs(a - b)
        // Spread the 32-bit differences over 64-bit lanes and square them as unsigned values.
        a2 = _mm_unpackhi_epi32(a, zero);
        res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
        a2 = _mm_unpacklo_epi32(a, zero);
        res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
#endif

        rhs += 4;
        lhs += 4;
        length -= 4;
    }

    alignas(16) ui64 r[2];
    _mm_store_si128(reinterpret_cast<__m128i*>(r), res);
    ui64 sum = r[0] + r[1];

    // Scalar tail: fewer than 4 elements are left.
    while (length) {
        sum += SqrDelta<ui64>(lhs[0], rhs[0]);
        ++lhs;
        ++rhs;
        --length;
    }

    return sum;
}

// Squared L2 distance between two ui32 vectors of `length` elements, accumulated in 64 bits.
ui64 L2SqrDistance(const ui32* lhs, const ui32* rhs, int length) {
    __m128i zero = _mm_setzero_si128();
    // Adding 0x80000000 flips the top bit, so that the signed comparison below orders the values
    // the same way an unsigned comparison would; the shift cancels out in the subtraction.
    __m128i shift = _mm_set1_epi32(0x80000000);
    __m128i res = zero;

    while (length >= 4) {
        __m128i a = _mm_add_epi32(_mm_loadu_si128(reinterpret_cast<const __m128i*>(lhs)), shift);
        __m128i b = _mm_add_epi32(_mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs)), shift);
        __m128i mask = _mm_cmpgt_epi32(a, b);                                       // mask = a > b ? 0xffffffff : 0
        __m128i a2 = _mm_sub_epi32(_mm_and_si128(mask, a), _mm_and_si128(mask, b)); // a2 = (a & mask) - (b & mask), set where a > b
        b = _mm_sub_epi32(_mm_andnot_si128(mask, b), _mm_andnot_si128(mask, a));    // b = (b & ~mask) - (a & ~mask), set where b >= a
        a = _mm_or_si128(a2, b);                                                    // a = abs(a - b)

#ifdef _sse4_1_
        res = _mm_add_epi64(_mm_mul_epu32(a, a), res); // lanes 0 and 2
        a = _mm_alignr_epi8(a, a, 4);                  // rotate lanes 1 and 3 into positions 0 and 2
        res = _mm_add_epi64(_mm_mul_epu32(a, a), res);
#else
        // Spread the 32-bit differences over 64-bit lanes and square them as unsigned values.
        a2 = _mm_unpackhi_epi32(a, zero);
        res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
        a2 = _mm_unpacklo_epi32(a, zero);
        res = _mm_add_epi64(_mm_mul_epu32(a2, a2), res);
#endif

        rhs += 4;
        lhs += 4;
        length -= 4;
    }

    alignas(16) ui64 r[2];
    _mm_store_si128(reinterpret_cast<__m128i*>(r), res);
    ui64 sum = r[0] + r[1];

    // Scalar tail: fewer than 4 elements are left.
    while (length) {
        sum += SqrDelta<ui64>(lhs[0], rhs[0]);
        ++lhs;
        ++rhs;
        --length;
    }

    return sum;
}

// Squared L2 distance between two vectors of packed 4-bit values (two components per byte).
// `length` is the number of bytes.
ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) {
    const __m128i zero = _mm_setzero_si128();
    __m128i resVec1 = zero; // low nibbles
    __m128i resVec2 = zero; // high nibbles, still scaled by 16

    while (length >= 16) {
        const __m128i lVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(lhs));
        const __m128i rVec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs));

        const __m128i lVec1 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_1);
        const __m128i lVec2 = _mm_and_si128(lVec, NL2Distance::MASK_UI4_2);
        const __m128i rVec1 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_1);
        const __m128i rVec2 = _mm_and_si128(rVec, NL2Distance::MASK_UI4_2);

        // Zero-extend the bytes to 16-bit lanes, so the differences fit into signed 16-bit values.
        const __m128i lo1 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec1, zero), _mm_unpacklo_epi8(rVec1, zero));
        const __m128i hi1 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec1, zero), _mm_unpackhi_epi8(rVec1, zero));
        const __m128i lo2 = _mm_sub_epi16(_mm_unpacklo_epi8(lVec2, zero), _mm_unpacklo_epi8(rVec2, zero));
        const __m128i hi2 = _mm_sub_epi16(_mm_unpackhi_epi8(lVec2, zero), _mm_unpackhi_epi8(rVec2, zero));

        resVec1 = _mm_add_epi32(resVec1,
                                _mm_add_epi32(_mm_madd_epi16(lo1, lo1), _mm_madd_epi16(hi1, hi1)));
        resVec2 = _mm_add_epi32(resVec2,
                                _mm_add_epi32(_mm_madd_epi16(lo2, lo2), _mm_madd_epi16(hi2, hi2)));

        lhs += 16;
        rhs += 16;
        length -= 16;
    }

    alignas(16) ui32 res[4];
    _mm_store_si128(reinterpret_cast<__m128i*>(res), resVec1);
    ui32 sum = res[0] + res[1] + res[2] + res[3];
    _mm_store_si128(reinterpret_cast<__m128i*>(res), resVec2);
    // High-nibble differences are scaled by 16, so their squares are scaled by 256.
    sum += (res[0] + res[1] + res[2] + res[3]) >> 8;

    // Scalar tail: fewer than 16 bytes are left.
    for (int i = 0; i < length; ++i) {
        sum += Sqr(static_cast<i32>(lhs[i] & 0x0f) - static_cast<i32>(rhs[i] & 0x0f));
        sum += Sqr(static_cast<i32>(lhs[i] & 0xf0) - static_cast<i32>(rhs[i] & 0xf0)) >> 8;
    }

    return sum;
}

#else /* !ARCADIA_SSE */

// Without SSE the public entry points forward to the scalar implementations.

ui32 L2SqrDistance(const i8* lhs, const i8* rhs, int length) {
    return L2SqrDistanceImpl<ui32>(lhs, rhs, length);
}

ui32 L2SqrDistance(const ui8* lhs, const ui8* rhs, int length) {
    return L2SqrDistanceImpl<ui32>(lhs, rhs, length);
}

ui64 L2SqrDistance(const i32* a, const i32* b, int length) {
    return L2SqrDistanceImpl2<ui64>(a, b, length);
}

ui64 L2SqrDistance(const ui32* a, const ui32* b, int length) {
    return L2SqrDistanceImpl2<ui64>(a, b, length);
}

float L2SqrDistance(const float* a, const float* b, int length) {
    return L2SqrDistanceImpl4<float>(a, b, length);
}

double L2SqrDistance(const double* a, const double* b, int length) {
    return L2SqrDistanceImpl2<double>(a, b, length);
}

ui32 L2SqrDistanceUI4(const ui8* lhs, const ui8* rhs, int length) {
    return L2SqrDistanceImplUI4(lhs, rhs, length);
}

#endif /* ARCADIA_SSE */

// The *Slow variants always use the scalar implementations, regardless of ARCADIA_SSE.

ui32 L2SqrDistanceSlow(const i8* lhs, const i8* rhs, int length) {
    return L2SqrDistanceImpl<ui32>(lhs, rhs, length);
}

ui32 L2SqrDistanceSlow(const ui8* lhs, const ui8* rhs, int length) {
    return L2SqrDistanceImpl<ui32>(lhs, rhs, length);
}

ui64 L2SqrDistanceSlow(const i32* a, const i32* b, int length) {
    return L2SqrDistanceImpl2<ui64>(a, b, length);
}

ui64 L2SqrDistanceSlow(const ui32* a, const ui32* b, int length) {
    return L2SqrDistanceImpl2<ui64>(a, b, length);
}

float L2SqrDistanceSlow(const float* a, const float* b, int length) {
    return L2SqrDistanceImpl4<float>(a, b, length);
}

double L2SqrDistanceSlow(const double* a, const double* b, int length) {
    return L2SqrDistanceImpl2<double>(a, b, length);
}

ui32 L2SqrDistanceUI4Slow(const ui8* lhs, const ui8* rhs, int length) {
    return L2SqrDistanceImplUI4(lhs, rhs, length);
}