dot_product_sse.cpp 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. #include "dot_product_sse.h"
  2. #include <library/cpp/sse/sse.h>
  3. #include <util/system/platform.h>
  4. #include <util/system/compiler.h>
  5. #ifdef ARCADIA_SSE
  6. i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept {
  7. const __m128i zero = _mm_setzero_si128();
  8. __m128i resVec = zero;
  9. while (length >= 16) {
  10. __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
  11. __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
  12. #ifdef _sse4_1_
  13. __m128i lLo = _mm_cvtepi8_epi16(lVec);
  14. __m128i rLo = _mm_cvtepi8_epi16(rVec);
  15. __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8));
  16. __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8));
  17. #else
  18. __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8);
  19. __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8);
  20. __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8);
  21. __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8);
  22. #endif
  23. resVec = _mm_add_epi32(resVec,
  24. _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
  25. lhs += 16;
  26. rhs += 16;
  27. length -= 16;
  28. }
  29. alignas(16) i32 res[4];
  30. _mm_store_si128((__m128i*)res, resVec);
  31. i32 sum = res[0] + res[1] + res[2] + res[3];
  32. for (size_t i = 0; i < length; ++i) {
  33. sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
  34. }
  35. return sum;
  36. }
  37. ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
  38. const __m128i zero = _mm_setzero_si128();
  39. __m128i resVec = zero;
  40. while (length >= 16) {
  41. __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
  42. __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
  43. __m128i lLo = _mm_unpacklo_epi8(lVec, zero);
  44. __m128i rLo = _mm_unpacklo_epi8(rVec, zero);
  45. __m128i lHi = _mm_unpackhi_epi8(lVec, zero);
  46. __m128i rHi = _mm_unpackhi_epi8(rVec, zero);
  47. resVec = _mm_add_epi32(resVec,
  48. _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
  49. lhs += 16;
  50. rhs += 16;
  51. length -= 16;
  52. }
  53. alignas(16) i32 res[4];
  54. _mm_store_si128((__m128i*)res, resVec);
  55. i32 sum = res[0] + res[1] + res[2] + res[3];
  56. for (size_t i = 0; i < length; ++i) {
  57. sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
  58. }
  59. return static_cast<ui32>(sum);
  60. }
  61. #ifdef _sse4_1_
  62. i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
  63. __m128i zero = _mm_setzero_si128();
  64. __m128i res = zero;
  65. while (length >= 4) {
  66. __m128i a = _mm_loadu_si128((const __m128i*)lhs);
  67. __m128i b = _mm_loadu_si128((const __m128i*)rhs);
  68. res = _mm_add_epi64(_mm_mul_epi32(a, b), res); // This is lower parts multiplication
  69. a = _mm_alignr_epi8(a, a, 4);
  70. b = _mm_alignr_epi8(b, b, 4);
  71. res = _mm_add_epi64(_mm_mul_epi32(a, b), res);
  72. rhs += 4;
  73. lhs += 4;
  74. length -= 4;
  75. }
  76. alignas(16) i64 r[2];
  77. _mm_store_si128((__m128i*)r, res);
  78. i64 sum = r[0] + r[1];
  79. for (size_t i = 0; i < length; ++i) {
  80. sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]);
  81. }
  82. return sum;
  83. }
  84. #else
  85. #include "dot_product_simple.h"
  86. i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
  87. return DotProductSimple(lhs, rhs, length);
  88. }
  89. #endif
  // Single-precision dot product.
  //
  // Accumulation order:
  //   - The main loop consumes 8 floats per iteration into two independent
  //     4-lane accumulators: elements 0-3 of each chunk go to accLo, elements
  //     4-7 go to accHi.
  //   - At most one further 4-float chunk is added to accLo.
  //   - The two accumulators are then combined.
  //   - The 1-3 trailing elements are zero-padded to a full vector and added.
  //   - The four lanes are reduced left to right.
  float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept {
      __m128 accLo = _mm_setzero_ps();
      __m128 accHi = _mm_setzero_ps();
      for (; length >= 8; length -= 8, lhs += 8, rhs += 8) {
          accLo = _mm_add_ps(accLo, _mm_mul_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs)));
          accHi = _mm_add_ps(accHi, _mm_mul_ps(_mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)));
      }
      if (length >= 4) {
          accLo = _mm_add_ps(accLo, _mm_mul_ps(_mm_loadu_ps(lhs), _mm_loadu_ps(rhs)));
          length -= 4;
          lhs += 4;
          rhs += 4;
      }

      __m128 acc = _mm_add_ps(accLo, accHi);
      if (length != 0) {
          // 1..3 leftovers: copy them into zeroed buffers so unused lanes contribute 0 * 0.
          alignas(16) float lhsTail[4] = {0.0f, 0.0f, 0.0f, 0.0f};
          alignas(16) float rhsTail[4] = {0.0f, 0.0f, 0.0f, 0.0f};
          for (size_t k = 0; k != length; ++k) {
              lhsTail[k] = lhs[k];
              rhsTail[k] = rhs[k];
          }
          acc = _mm_add_ps(acc, _mm_mul_ps(_mm_load_ps(lhsTail), _mm_load_ps(rhsTail)));
      }

      alignas(16) float lanes[4];
      _mm_store_ps(lanes, acc);
      return lanes[0] + lanes[1] + lanes[2] + lanes[3];
  }
  // Double-precision dot product.
  //
  // Accumulation order:
  //   - The main loop consumes 4 doubles per iteration into two independent
  //     2-lane accumulators: elements 0-1 go to accLo, elements 2-3 go to accHi.
  //   - At most one further 2-double chunk is added to accLo.
  //   - The two accumulators are then combined.
  //   - A single leftover element, if any, is added.
  //   - The two lanes are summed.
  double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept {
      __m128d accLo = _mm_setzero_pd();
      __m128d accHi = _mm_setzero_pd();
      for (; length >= 4; length -= 4, lhs += 4, rhs += 4) {
          accLo = _mm_add_pd(accLo, _mm_mul_pd(_mm_loadu_pd(lhs), _mm_loadu_pd(rhs)));
          accHi = _mm_add_pd(accHi, _mm_mul_pd(_mm_loadu_pd(lhs + 2), _mm_loadu_pd(rhs + 2)));
      }
      if (length >= 2) {
          accLo = _mm_add_pd(accLo, _mm_mul_pd(_mm_loadu_pd(lhs), _mm_loadu_pd(rhs)));
          length -= 2;
          lhs += 2;
          rhs += 2;
      }

      __m128d acc = _mm_add_pd(accLo, accHi);
      if (length != 0) {
          // _mm_set_pd takes (high, low): the leftover product lands in lane 1,
          // while lane 0 receives 0 * 0.
          const __m128d x = _mm_set_pd(lhs[0], 0.0);
          const __m128d y = _mm_set_pd(rhs[0], 0.0);
          acc = _mm_add_pd(acc, _mm_mul_pd(x, y));
      }

      alignas(16) double lanes[2];
      _mm_store_pd(lanes, acc);
      return lanes[0] + lanes[1];
  }
  172. #endif // ARCADIA_SSE