dot_product.h 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #pragma once
  2. #include <util/system/types.h>
  3. #include <util/system/compiler.h>
  4. #include <numeric>
/**
 * Dot product (also known as inner product or scalar product) implementations
 * that use SSE when possible.
 */
namespace NDotProductImpl {
    // Kernel function pointers that the inline DotProduct() overloads in this header
    // call through. They are only declared here (`extern`); the definitions, and
    // presumably the choice of which SIMD kernel each one points to, live in the
    // corresponding source file — NOTE(review): confirm where they are initialized.
    // Each kernel takes two arrays of `length` elements and returns their dot product.
    extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept;
    extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
    extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept;
    extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept;
    extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept;
}
  15. Y_PURE_FUNCTION
  16. inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
  17. return NDotProductImpl::DotProductI8Impl(lhs, rhs, length);
  18. }
  19. Y_PURE_FUNCTION
  20. inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
  21. return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length);
  22. }
  23. Y_PURE_FUNCTION
  24. inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
  25. return NDotProductImpl::DotProductI32Impl(lhs, rhs, length);
  26. }
  27. Y_PURE_FUNCTION
  28. inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
  29. return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length);
  30. }
  31. Y_PURE_FUNCTION
  32. inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
  33. return NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length);
  34. }
/**
 * Dot product of a vector with itself, i.e. the squared L2 (Euclidean) norm:
 * the sum of v[i] * v[i] over the first `length` elements of `v`.
 * Declared here; the definition is not in this header.
 */
Y_PURE_FUNCTION
float L2NormSquared(const float* v, size_t length) noexcept;
// TODO(yazevnul): add an `L2NormSquared` overload for double. It should be faster than
// calling `DotProduct` with `lhs == rhs`, because it saves N load instructions.
  42. template <typename T>
  43. struct TTriWayDotProduct {
  44. T LL = 1;
  45. T LR = 0;
  46. T RR = 1;
  47. };
  48. enum class ETriWayDotProductComputeMask: unsigned {
  49. // basic
  50. LL = 0b100,
  51. LR = 0b010,
  52. RR = 0b001,
  53. // useful combinations
  54. All = 0b111,
  55. Left = 0b110, // skip computation of R·R
  56. Right = 0b011, // skip computation of L·L
  57. };
  58. Y_PURE_FUNCTION
  59. TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;
  60. /**
  61. * For two vectors L and R computes 3 dot-products: L·L, L·R, R·R
  62. */
  63. Y_PURE_FUNCTION
  64. static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
  65. return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask));
  66. }
  67. namespace NDotProduct {
  68. // Simpler wrapper allowing to use this functions as template argument.
  69. template <typename T>
  70. struct TDotProduct {
  71. using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0));
  72. Y_PURE_FUNCTION
  73. inline TResult operator()(const T* l, const T* r, size_t length) const {
  74. return DotProduct(l, r, length);
  75. }
  76. };
  77. void DisableAvx2();
  78. }