l2_distance.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. #pragma once
  2. #include <util/system/types.h>
  3. #include <util/generic/ymath.h>
  4. #include <cmath>
  5. namespace NPrivate {
  6. namespace NL2Distance {
  // Generic square root helper: evaluates std::sqrt on the argument and
  // converts the result back to Number. The ui64 and ui32 cases are handled by
  // dedicated specializations that stay in integer arithmetic.
  template <typename Number>
  inline Number L2DistanceSqrt(Number a) {
      const auto root = std::sqrt(a);
      return root;
  }
  11. template <>
  12. inline ui64 L2DistanceSqrt(ui64 a) {
  13. // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_.28base_2.29
  14. ui64 res = 0;
  15. ui64 bit = static_cast<ui64>(1) << (sizeof(ui64) * 8 - 2);
  16. while (bit > a)
  17. bit >>= 2;
  18. while (bit != 0) {
  19. if (a >= res + bit) {
  20. a -= (res + bit);
  21. res = (res >> 1) + bit;
  22. } else {
  23. res >>= 1;
  24. }
  25. bit >>= 2;
  26. }
  27. return res;
  28. }
  29. template <>
  30. inline ui32 L2DistanceSqrt(ui32 a) {
  31. return L2DistanceSqrt<ui64>(a);
  32. }
  33. // Special class to match argument type and result type.
  34. template <typename Arg>
  35. class TMatchArgumentResult {
  36. public:
  37. using TResult = Arg;
  38. };
  39. template <>
  40. class TMatchArgumentResult<i8> {
  41. public:
  42. using TResult = ui32;
  43. };
  44. template <>
  45. class TMatchArgumentResult<ui8> {
  46. public:
  47. using TResult = ui32;
  48. };
  49. template <>
  50. class TMatchArgumentResult<i32> {
  51. public:
  52. using TResult = ui64;
  53. };
  54. template <>
  55. class TMatchArgumentResult<ui32> {
  56. public:
  57. using TResult = ui64;
  58. };
  59. }
  60. }
  61. /**
  62. * sqr(l2_distance) = sum((a[i]-b[i])^2)
  63. * If target system does not support SSE2 Slow functions are used automatically.
  64. */
  65. ui32 L2SqrDistance(const i8* a, const i8* b, int cnt);
  66. ui32 L2SqrDistance(const ui8* a, const ui8* b, int cnt);
  67. ui64 L2SqrDistance(const i32* a, const i32* b, int length);
  68. ui64 L2SqrDistance(const ui32* a, const ui32* b, int length);
  69. float L2SqrDistance(const float* a, const float* b, int length);
  70. double L2SqrDistance(const double* a, const double* b, int length);
  71. ui32 L2SqrDistanceUI4(const ui8* a, const ui8* b, int cnt);
  72. ui32 L2SqrDistanceSlow(const i8* a, const i8* b, int cnt);
  73. ui32 L2SqrDistanceSlow(const ui8* a, const ui8* b, int cnt);
  74. ui64 L2SqrDistanceSlow(const i32* a, const i32* b, int length);
  75. ui64 L2SqrDistanceSlow(const ui32* a, const ui32* b, int length);
  76. float L2SqrDistanceSlow(const float* a, const float* b, int length);
  77. double L2SqrDistanceSlow(const double* a, const double* b, int length);
  78. ui32 L2SqrDistanceUI4Slow(const ui8* a, const ui8* b, int cnt);
  79. /**
  80. * L2 distance = sqrt(sum((a[i]-b[i])^2))
  81. */
  82. template <typename Number, typename Result = typename NPrivate::NL2Distance::TMatchArgumentResult<Number>::TResult>
  83. inline Result L2Distance(const Number* a, const Number* b, int cnt) {
  84. return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistance(a, b, cnt));
  85. }
  86. template <typename Number, typename Result = typename NPrivate::NL2Distance::TMatchArgumentResult<Number>::TResult>
  87. inline Result L2DistanceSlow(const Number* a, const Number* b, int cnt) {
  88. return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistanceSlow(a, b, cnt));
  89. }
  90. namespace NL2Distance {
  // You can use these structures as template function arguments.
  // Function object wrapping L2Distance for element type T.
  template <typename T>
  struct TL2Distance {
      // Whatever L2Distance yields when called with two `const T*` and an int.
      using TResult = decltype(L2Distance(
          static_cast<const T*>(nullptr),
          static_cast<const T*>(nullptr),
          0));

      TResult operator()(const T* a, const T* b, int length) const {
          return L2Distance(a, b, length);
      }
  };
  99. struct TL2DistanceUI4 {
  100. using TResult = ui32;
  101. inline TResult operator()(const ui8* a, const ui8* b, int lengtInBytes) const {
  102. return NPrivate::NL2Distance::L2DistanceSqrt(L2SqrDistanceUI4(a, b, lengtInBytes));
  103. }
  104. };
  // Function object wrapping L2SqrDistance for element type T.
  template <typename T>
  struct TL2SqrDistance {
      // Whatever L2SqrDistance yields when called with two `const T*` and an int.
      using TResult = decltype(L2SqrDistance(
          static_cast<const T*>(nullptr),
          static_cast<const T*>(nullptr),
          0));

      TResult operator()(const T* a, const T* b, int length) const {
          return L2SqrDistance(a, b, length);
      }
  };
  112. struct TL2SqrDistanceUI4 {
  113. using TResult = ui32;
  114. inline TResult operator()(const ui8* a, const ui8* b, int lengtInBytes) const {
  115. return L2SqrDistanceUI4(a, b, lengtInBytes);
  116. }
  117. };
  118. }