dot_product.cpp 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. #include "dot_product.h"
  2. #include "dot_product_sse.h"
  3. #include "dot_product_avx2.h"
  4. #include "dot_product_simple.h"
  5. #include <library/cpp/sse/sse.h>
  6. #include <library/cpp/testing/common/env.h>
  7. #include <util/system/compiler.h>
  8. #include <util/generic/utility.h>
  9. #include <util/system/cpu_id.h>
  10. #include <util/system/env.h>
  11. namespace NDotProductImpl {
  12. i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple;
  13. ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple;
  14. i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple;
  15. float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple;
  16. double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple;
  17. namespace {
  18. [[maybe_unused]] const int _ = [] {
  19. if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) {
  20. DotProductI8Impl = &DotProductAvx2;
  21. DotProductUi8Impl = &DotProductAvx2;
  22. DotProductI32Impl = &DotProductAvx2;
  23. DotProductFloatImpl = &DotProductAvx2;
  24. DotProductDoubleImpl = &DotProductAvx2;
  25. } else {
  26. #ifdef ARCADIA_SSE
  27. DotProductI8Impl = &DotProductSse;
  28. DotProductUi8Impl = &DotProductSse;
  29. DotProductI32Impl = &DotProductSse;
  30. DotProductFloatImpl = &DotProductSse;
  31. DotProductDoubleImpl = &DotProductSse;
  32. #endif
  33. }
  34. return 0;
  35. }();
  36. }
  37. }
  38. #ifdef ARCADIA_SSE
  // Squared Euclidean norm of `v[0..length)` using SSE.
  // Full 8-float chunks feed two independent accumulators, one optional 4-float
  // chunk feeds the first, and a 1-3 element tail is zero-padded into one register.
  // The four lanes are summed left to right at the end.
  float L2NormSquared(const float* v, size_t length) noexcept {
      __m128 accLo = _mm_setzero_ps();
      __m128 accHi = _mm_setzero_ps();

      for (; length >= 8; length -= 8, v += 8) {
          const __m128 lo = _mm_loadu_ps(v);
          const __m128 hi = _mm_loadu_ps(v + 4);
          accLo = _mm_add_ps(accLo, _mm_mul_ps(lo, lo));
          accHi = _mm_add_ps(accHi, _mm_mul_ps(hi, hi));
      }

      if (length >= 4) {
          const __m128 quad = _mm_loadu_ps(v);
          accLo = _mm_add_ps(accLo, _mm_mul_ps(quad, quad));
          length -= 4;
          v += 4;
      }

      accLo = _mm_add_ps(accLo, accHi);

      // At most 3 elements remain here; unused lanes are filled with 0.0f.
      if (length != 0) {
          __m128 tail;
          if (length == 3) {
              tail = _mm_set_ps(0.0f, v[2], v[1], v[0]);
          } else if (length == 2) {
              tail = _mm_set_ps(0.0f, 0.0f, v[1], v[0]);
          } else {
              tail = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]);
          }
          accLo = _mm_add_ps(accLo, _mm_mul_ps(tail, tail));
      }

      alignas(16) float lanes[4];
      _mm_store_ps(lanes, accLo);
      return lanes[0] + lanes[1] + lanes[2] + lanes[3];
  }
  80. template <bool computeLL, bool computeLR, bool computeRR>
  81. Y_FORCE_INLINE
  82. static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) {
  83. if constexpr (computeLL) {
  84. sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a));
  85. }
  86. if constexpr (computeLR) {
  87. sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b));
  88. }
  89. if constexpr (computeRR) {
  90. sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b));
  91. }
  92. }
  93. template <bool computeLL, bool computeLR, bool computeRR>
  94. static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept {
  95. __m128 sumLL1 = _mm_setzero_ps();
  96. __m128 sumLR1 = _mm_setzero_ps();
  97. __m128 sumRR1 = _mm_setzero_ps();
  98. __m128 sumLL2 = _mm_setzero_ps();
  99. __m128 sumLR2 = _mm_setzero_ps();
  100. __m128 sumRR2 = _mm_setzero_ps();
  101. while (length >= 8) {
  102. TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
  103. TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
  104. length -= 8;
  105. lhs += 8;
  106. rhs += 8;
  107. }
  108. if (length >= 4) {
  109. TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
  110. length -= 4;
  111. lhs += 4;
  112. rhs += 4;
  113. }
  114. if constexpr (computeLL) {
  115. sumLL1 = _mm_add_ps(sumLL1, sumLL2);
  116. }
  117. if constexpr (computeLR) {
  118. sumLR1 = _mm_add_ps(sumLR1, sumLR2);
  119. }
  120. if constexpr (computeRR) {
  121. sumRR1 = _mm_add_ps(sumRR1, sumRR2);
  122. }
  123. if (length) {
  124. __m128 a, b;
  125. switch (length) {
  126. case 3:
  127. a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
  128. b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
  129. break;
  130. case 2:
  131. a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
  132. b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
  133. break;
  134. case 1:
  135. a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
  136. b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
  137. break;
  138. default:
  139. Y_UNREACHABLE();
  140. }
  141. TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b);
  142. }
  143. __m128 t0 = sumLL1;
  144. __m128 t1 = sumLR1;
  145. __m128 t2 = sumRR1;
  146. __m128 t3 = _mm_setzero_ps();
  147. _MM_TRANSPOSE4_PS(t0, t1, t2, t3);
  148. t0 = _mm_add_ps(t0, t1);
  149. t0 = _mm_add_ps(t0, t2);
  150. t0 = _mm_add_ps(t0, t3);
  151. alignas(16) float res[4];
  152. _mm_store_ps(res, t0);
  153. TTriWayDotProduct<float> result{res[0], res[1], res[2]};
  154. static constexpr const TTriWayDotProduct<float> def;
  155. // fill skipped fields with default values
  156. if constexpr (!computeLL) {
  157. result.LL = def.LL;
  158. }
  159. if constexpr (!computeLR) {
  160. result.LR = def.LR;
  161. }
  162. if constexpr (!computeRR) {
  163. result.RR = def.RR;
  164. }
  165. return result;
  166. }
  167. TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
  168. mask &= 0b111;
  169. if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors
  170. return TriWayDotProductImpl<true, true, true>(lhs, rhs, length);
  171. } else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector
  172. const bool computeLL = (mask == 0b110);
  173. if (!computeLL) {
  174. DoSwap(lhs, rhs);
  175. }
  176. auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length);
  177. if (!computeLL) {
  178. DoSwap(result.LL, result.RR);
  179. }
  180. return result;
  181. } else {
  182. // dispatch unlikely & sparse cases
  183. TTriWayDotProduct<float> result{};
  184. switch(mask) {
  185. case 0b000:
  186. break;
  187. case 0b100:
  188. result.LL = L2NormSquared(lhs, length);
  189. break;
  190. case 0b010:
  191. result.LR = DotProduct(lhs, rhs, length);
  192. break;
  193. case 0b001:
  194. result.RR = L2NormSquared(rhs, length);
  195. break;
  196. case 0b101:
  197. result.LL = L2NormSquared(lhs, length);
  198. result.RR = L2NormSquared(rhs, length);
  199. break;
  200. default:
  201. Y_UNREACHABLE();
  202. }
  203. return result;
  204. }
  205. }
  206. #else
  207. float L2NormSquared(const float* v, size_t length) noexcept {
  208. return DotProduct(v, v, length);
  209. }
  210. TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
  211. TTriWayDotProduct<float> result;
  212. if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) {
  213. result.LL = L2NormSquared(lhs, length);
  214. }
  215. if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) {
  216. result.LR = DotProduct(lhs, rhs, length);
  217. }
  218. if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) {
  219. result.RR = L2NormSquared(rhs, length);
  220. }
  221. return result;
  222. }
  223. #endif // ARCADIA_SSE
  224. namespace NDotProduct {
  225. void DisableAvx2() {
  226. #ifdef ARCADIA_SSE
  227. NDotProductImpl::DotProductI8Impl = &DotProductSse;
  228. NDotProductImpl::DotProductUi8Impl = &DotProductSse;
  229. NDotProductImpl::DotProductI32Impl = &DotProductSse;
  230. NDotProductImpl::DotProductFloatImpl = &DotProductSse;
  231. NDotProductImpl::DotProductDoubleImpl = &DotProductSse;
  232. #else
  233. NDotProductImpl::DotProductI8Impl = &DotProductSimple;
  234. NDotProductImpl::DotProductUi8Impl = &DotProductSimple;
  235. NDotProductImpl::DotProductI32Impl = &DotProductSimple;
  236. NDotProductImpl::DotProductFloatImpl = &DotProductSimple;
  237. NDotProductImpl::DotProductDoubleImpl = &DotProductSimple;
  238. #endif
  239. }
  240. }