// dot_product_simple.cpp (1.4 KB)
  1. #include "dot_product_simple.h"
  namespace {
      // Scalar reference dot product over the first `length` elements of `lhs` and `rhs`.
      //
      // Every element is converted to `Res` before multiplying, and the products are
      // accumulated in `Res`. Four independent partial sums are kept: partial sum k
      // collects the products at indices 4*j + k of the unrolled part, and the
      // leftover (length % 4) elements are added to the first partial sum.
      // Returns 0 when `length` is 0 (the pointers are not dereferenced in that case).
      template <typename Res, typename Number>
      static Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept {
          Res lane0 = 0;
          Res lane1 = 0;
          Res lane2 = 0;
          Res lane3 = 0;

          // Largest multiple of 4 that does not exceed `length`.
          const size_t unrolledEnd = length - length % 4;

          size_t idx = 0;
          for (; idx < unrolledEnd; idx += 4) {
              lane0 += static_cast<Res>(lhs[idx]) * static_cast<Res>(rhs[idx]);
              lane1 += static_cast<Res>(lhs[idx + 1]) * static_cast<Res>(rhs[idx + 1]);
              lane2 += static_cast<Res>(lhs[idx + 2]) * static_cast<Res>(rhs[idx + 2]);
              lane3 += static_cast<Res>(lhs[idx + 3]) * static_cast<Res>(rhs[idx + 3]);
          }

          // Tail: the remaining 0..3 elements all go into the first partial sum.
          for (; idx < length; ++idx) {
              lane0 += static_cast<Res>(lhs[idx]) * static_cast<Res>(rhs[idx]);
          }

          // The summation order of the partial sums matters for floating-point results.
          return lane0 + lane1 + lane2 + lane3;
      }
  }
  24. float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept {
  25. return DotProductSimpleImpl<float, float>(lhs, rhs, length);
  26. }
  27. double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept {
  28. return DotProductSimpleImpl<double, double>(lhs, rhs, length);
  29. }
  30. ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept {
  31. ui32 res = 0;
  32. for (size_t i = 0; i < lengtInBytes; ++i) {
  33. res += static_cast<ui32>(lhs[i] & 0x0f) * static_cast<ui32>(rhs[i] & 0x0f);
  34. res += static_cast<ui32>(lhs[i] & 0xf0) * static_cast<ui32>(rhs[i] & 0xf0) >> 8;
  35. }
  36. return res;
  37. }