Browse Source

Publish library/cpp/dot_product

Publish pod_product to  https://github.com/ydb-platform/ydb

It has already been published to github: https://github.com/catboost/catboost/tree/master/library/cpp/dot_product
44150e7508881f4239c960f90320799b1b090072
azevaykin 11 months ago
parent
commit
ce034cd07e

+ 15 - 0
library/cpp/dot_product/README.md

@@ -0,0 +1,15 @@
+Библиотека для вычисления скалярного произведения векторов.
+=====================================================
+
+Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов.
+В отличии от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для сравнения
+можно посмотреть результаты бенчмарка.
+
+Типичное использование - замена кусков кода вроде:
+```
+for (int i = 0; i < len; i++)
+    dot_product += a[i] * b[i]);
+```
+на существенно более эффективный вызов ```DotProduct(a, b, len)```.
+
+Работает для типов i8, i32, float, double.

+ 274 - 0
library/cpp/dot_product/dot_product.cpp

@@ -0,0 +1,274 @@
+#include "dot_product.h"
+#include "dot_product_sse.h"
+#include "dot_product_avx2.h"
+#include "dot_product_simple.h"
+
+#include <library/cpp/sse/sse.h>
+#include <library/cpp/testing/common/env.h>
+#include <util/system/compiler.h>
+#include <util/generic/utility.h>
+#include <util/system/cpu_id.h>
+#include <util/system/env.h>
+
+namespace NDotProductImpl {
+    i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple;
+    ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple;
+    i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple;
+    float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple;
+    double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple;
+
+    namespace {
+        [[maybe_unused]] const int _ = [] {
+            if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) {
+                DotProductI8Impl = &DotProductAvx2;
+                DotProductUi8Impl = &DotProductAvx2;
+                DotProductI32Impl = &DotProductAvx2;
+                DotProductFloatImpl = &DotProductAvx2;
+                DotProductDoubleImpl = &DotProductAvx2;
+            } else {
+#ifdef ARCADIA_SSE
+                DotProductI8Impl = &DotProductSse;
+                DotProductUi8Impl = &DotProductSse;
+                DotProductI32Impl = &DotProductSse;
+                DotProductFloatImpl = &DotProductSse;
+                DotProductDoubleImpl = &DotProductSse;
+#endif
+            }
+            return 0;
+        }();
+    }
+}
+
+#ifdef ARCADIA_SSE
+float L2NormSquared(const float* v, size_t length) noexcept {
+    __m128 sum1 = _mm_setzero_ps();
+    __m128 sum2 = _mm_setzero_ps();
+    __m128 a1, a2, m1, m2;
+
+    while (length >= 8) {
+        a1 = _mm_loadu_ps(v);
+        m1 = _mm_mul_ps(a1, a1);
+
+        a2 = _mm_loadu_ps(v + 4);
+        sum1 = _mm_add_ps(sum1, m1);
+
+        m2 = _mm_mul_ps(a2, a2);
+        sum2 = _mm_add_ps(sum2, m2);
+
+        length -= 8;
+        v += 8;
+    }
+
+    if (length >= 4) {
+        a1 = _mm_loadu_ps(v);
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
+
+        length -= 4;
+        v += 4;
+    }
+
+    sum1 = _mm_add_ps(sum1, sum2);
+
+    if (length) {
+        switch (length) {
+            case 3:
+                a1 = _mm_set_ps(0.0f, v[2], v[1], v[0]);
+                break;
+
+            case 2:
+                a1 = _mm_set_ps(0.0f, 0.0f, v[1], v[0]);
+                break;
+
+            case 1:
+                a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, v[0]);
+                break;
+
+            default:
+                Y_UNREACHABLE();
+        }
+
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, a1));
+    }
+
+    alignas(16) float res[4];
+    _mm_store_ps(res, sum1);
+
+    return res[0] + res[1] + res[2] + res[3];
+}
+
+template <bool computeLL, bool computeLR, bool computeRR>
+Y_FORCE_INLINE
+static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) {
+    if constexpr (computeLL) {
+        sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a));
+    }
+    if constexpr (computeLR) {
+        sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b));
+    }
+    if constexpr (computeRR) {
+        sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b));
+    }
+}
+
+
+template <bool computeLL, bool computeLR, bool computeRR>
+static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept {
+    __m128 sumLL1 = _mm_setzero_ps();
+    __m128 sumLR1 = _mm_setzero_ps();
+    __m128 sumRR1 = _mm_setzero_ps();
+    __m128 sumLL2 = _mm_setzero_ps();
+    __m128 sumLR2 = _mm_setzero_ps();
+    __m128 sumRR2 = _mm_setzero_ps();
+
+    while (length >= 8) {
+        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
+        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4));
+        length -= 8;
+        lhs += 8;
+        rhs += 8;
+    }
+
+    if (length >= 4) {
+        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0));
+        length -= 4;
+        lhs += 4;
+        rhs += 4;
+    }
+
+    if constexpr (computeLL) {
+        sumLL1 = _mm_add_ps(sumLL1, sumLL2);
+    }
+    if constexpr (computeLR) {
+        sumLR1 = _mm_add_ps(sumLR1, sumLR2);
+    }
+    if constexpr (computeRR) {
+        sumRR1 = _mm_add_ps(sumRR1, sumRR2);
+    }
+
+    if (length) {
+        __m128 a, b;
+        switch (length) {
+            case 3:
+                a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
+                b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
+                break;
+            case 2:
+                a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
+                b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
+                break;
+            case 1:
+                a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
+                b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
+                break;
+            default:
+                Y_UNREACHABLE();
+        }
+        TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b);
+    }
+
+    __m128 t0 = sumLL1;
+    __m128 t1 = sumLR1;
+    __m128 t2 = sumRR1;
+    __m128 t3 = _mm_setzero_ps();
+    _MM_TRANSPOSE4_PS(t0, t1, t2, t3);
+    t0 = _mm_add_ps(t0, t1);
+    t0 = _mm_add_ps(t0, t2);
+    t0 = _mm_add_ps(t0, t3);
+
+    alignas(16) float res[4];
+    _mm_store_ps(res, t0);
+    TTriWayDotProduct<float> result{res[0], res[1], res[2]};
+    static constexpr const TTriWayDotProduct<float> def;
+    // fill skipped fields with default values
+    if constexpr (!computeLL) {
+        result.LL = def.LL;
+    }
+    if constexpr (!computeLR) {
+        result.LR = def.LR;
+    }
+    if constexpr (!computeRR) {
+        result.RR = def.RR;
+    }
+    return result;
+}
+
+
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
+    mask &= 0b111;
+    if (Y_LIKELY(mask == 0b111)) { // compute dot-product and length² of two vectors
+        return TriWayDotProductImpl<true, true, true>(lhs, rhs, length);
+    } else if (Y_LIKELY(mask == 0b110 || mask == 0b011)) { // compute dot-product and length² of one vector
+        const bool computeLL = (mask == 0b110);
+        if (!computeLL) {
+            DoSwap(lhs, rhs);
+        }
+        auto result = TriWayDotProductImpl<true, true, false>(lhs, rhs, length);
+        if (!computeLL) {
+            DoSwap(result.LL, result.RR);
+        }
+        return result;
+    } else {
+        // dispatch unlikely & sparse cases
+        TTriWayDotProduct<float> result{};
+        switch(mask) {
+            case 0b000:
+                break;
+            case 0b100:
+                result.LL = L2NormSquared(lhs, length);
+                break;
+            case 0b010:
+                result.LR = DotProduct(lhs, rhs, length);
+                break;
+            case 0b001:
+                result.RR = L2NormSquared(rhs, length);
+                break;
+            case 0b101:
+                result.LL = L2NormSquared(lhs, length);
+                result.RR = L2NormSquared(rhs, length);
+                break;
+            default:
+                Y_UNREACHABLE();
+        }
+        return result;
+    }
+}
+
+#else
+
+float L2NormSquared(const float* v, size_t length) noexcept {
+    return DotProduct(v, v, length);
+}
+
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
+    TTriWayDotProduct<float> result;
+    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LL)) {
+        result.LL = L2NormSquared(lhs, length);
+    }
+    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::LR)) {
+        result.LR = DotProduct(lhs, rhs, length);
+    }
+    if (mask & static_cast<unsigned>(ETriWayDotProductComputeMask::RR)) {
+        result.RR = L2NormSquared(rhs, length);
+    }
+    return result;
+}
+
+#endif // ARCADIA_SSE
+
+namespace NDotProduct {
+    void DisableAvx2() {
+#ifdef ARCADIA_SSE
+        NDotProductImpl::DotProductI8Impl = &DotProductSse;
+        NDotProductImpl::DotProductUi8Impl = &DotProductSse;
+        NDotProductImpl::DotProductI32Impl = &DotProductSse;
+        NDotProductImpl::DotProductFloatImpl = &DotProductSse;
+        NDotProductImpl::DotProductDoubleImpl = &DotProductSse;
+#else
+        NDotProductImpl::DotProductI8Impl = &DotProductSimple;
+        NDotProductImpl::DotProductUi8Impl = &DotProductSimple;
+        NDotProductImpl::DotProductI32Impl = &DotProductSimple;
+        NDotProductImpl::DotProductFloatImpl = &DotProductSimple;
+        NDotProductImpl::DotProductDoubleImpl = &DotProductSimple;
+#endif
+    }
+}

+ 96 - 0
library/cpp/dot_product/dot_product.h

@@ -0,0 +1,96 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+#include <numeric>
+
+/**
+ * Dot product (Inner product or scalar product) implementation using SSE when possible.
+ */
+namespace NDotProductImpl {
+    extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept;
+    extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+    extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept;
+    extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept;
+    extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept;
+}
+
+Y_PURE_FUNCTION
+inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    return NDotProductImpl::DotProductI8Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    return NDotProductImpl::DotProductUi8Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return NDotProductImpl::DotProductI32Impl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
+    return NDotProductImpl::DotProductFloatImpl(lhs, rhs, length);
+}
+
+Y_PURE_FUNCTION
+inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
+    return NDotProductImpl::DotProductDoubleImpl(lhs, rhs, length);
+}
+
+/**
+ * Dot product to itself
+ */
+Y_PURE_FUNCTION
+float L2NormSquared(const float* v, size_t length) noexcept;
+
+// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct`
+// where `lhs == rhs` because it will save N load instructions.
+
+template <typename T>
+struct TTriWayDotProduct {
+    T LL = 1;
+    T LR = 0;
+    T RR = 1;
+};
+
+enum class ETriWayDotProductComputeMask: unsigned {
+    // basic
+    LL = 0b100,
+    LR = 0b010,
+    RR = 0b001,
+
+    // useful combinations
+    All = 0b111,
+    Left = 0b110, // skip computation of R·R
+    Right = 0b011, // skip computation of L·L
+};
+
+Y_PURE_FUNCTION
+TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;
+
+/**
+ * For two vectors L and R computes 3 dot-products: L·L, L·R, R·R
+ */
+Y_PURE_FUNCTION
+static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
+    return TriWayDotProduct(lhs, rhs, length, static_cast<unsigned>(mask));
+}
+
+namespace NDotProduct {
+    // Simpler wrapper allowing to use this functions as template argument.
+    template <typename T>
+    struct TDotProduct {
+        using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0));
+        Y_PURE_FUNCTION
+        inline TResult operator()(const T* l, const T* r, size_t length) const {
+            return DotProduct(l, r, length);
+        }
+    };
+
+    void DisableAvx2();
+}
+

+ 344 - 0
library/cpp/dot_product/dot_product_avx2.cpp

@@ -0,0 +1,344 @@
+#include "dot_product_avx2.h"
+#include "dot_product_simple.h"
+#include "dot_product_sse.h"
+
+#if defined(_avx2_) && defined(_fma_)
+
+#include <util/system/platform.h>
+#include <util/system/compiler.h>
+#include <util/generic/utility.h>
+
+#include <immintrin.h>
+
+namespace {
+    constexpr i64 Bits(int n) {
+        return i64(-1) ^ ((i64(1) << (64 - n)) - 1);
+    }
+
+    constexpr __m256 BlendMask64[8] = {
+        __m256i{Bits(64), Bits(64), Bits(64), Bits(64)},
+        __m256i{0, Bits(64), Bits(64), Bits(64)},
+        __m256i{0, 0, Bits(64), Bits(64)},
+        __m256i{0, 0, 0, Bits(64)},
+    };
+
+    constexpr __m256 BlendMask32[8] = {
+        __m256i{Bits(64), Bits(64), Bits(64), Bits(64)},
+        __m256i{Bits(32), Bits(64), Bits(64), Bits(64)},
+        __m256i{0, Bits(64), Bits(64), Bits(64)},
+        __m256i{0, Bits(32), Bits(64), Bits(64)},
+        __m256i{0, 0, Bits(64), Bits(64)},
+        __m256i{0, 0, Bits(32), Bits(64)},
+        __m256i{0, 0, 0, Bits(64)},
+        __m256i{0, 0, 0, Bits(32)},
+    };
+
+    constexpr __m128 BlendMask8[16] = {
+        __m128i{Bits(64), Bits(64)},
+        __m128i{Bits(56), Bits(64)},
+        __m128i{Bits(48), Bits(64)},
+        __m128i{Bits(40), Bits(64)},
+        __m128i{Bits(32), Bits(64)},
+        __m128i{Bits(24), Bits(64)},
+        __m128i{Bits(16), Bits(64)},
+        __m128i{Bits(8), Bits(64)},
+        __m128i{0, Bits(64)},
+        __m128i{0, Bits(56)},
+        __m128i{0, Bits(48)},
+        __m128i{0, Bits(40)},
+        __m128i{0, Bits(32)},
+        __m128i{0, Bits(24)},
+        __m128i{0, Bits(16)},
+        __m128i{0, Bits(8)},
+    };
+
+    // See https://stackoverflow.com/a/60109639
+    // Horizontal sum of eight i32 values in an avx register
+    i32 HsumI32(__m256i v) {
+        __m128i x = _mm_add_epi32(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
+        __m128i hi64  = _mm_unpackhi_epi64(x, x);
+        __m128i sum64 = _mm_add_epi32(hi64, x);
+        __m128i hi32  = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+        __m128i sum32 = _mm_add_epi32(sum64, hi32);
+        return _mm_cvtsi128_si32(sum32);
+    }
+
+    // Horizontal sum of four i64 values in an avx register
+    i64 HsumI64(__m256i v) {
+        __m128i x = _mm_add_epi64(_mm256_castsi256_si128(v), _mm256_extracti128_si256(v, 1));
+        return _mm_cvtsi128_si64(x) + _mm_extract_epi64(x, 1);
+    }
+
+    // Horizontal sum of eight float values in an avx register
+    float HsumFloat(__m256 v) {
+        __m256 y = _mm256_permute2f128_ps(v, v, 1);
+        v = _mm256_add_ps(v, y);
+        v = _mm256_hadd_ps(v, v);
+        return _mm256_cvtss_f32(_mm256_hadd_ps(v, v));
+    }
+
+    // Horizontal sum of four double values in an avx register
+    double HsumDouble(__m256 v) {
+        __m128d x = _mm_add_pd(_mm256_castpd256_pd128(v), _mm256_extractf128_pd(v, 1));
+        x = _mm_add_pd(x, _mm_shuffle_pd(x, x, 1));
+        return _mm_cvtsd_f64(x);
+    }
+
+    __m128i Load128i(const void* ptr) {
+        return _mm_loadu_si128((const __m128i*)ptr);
+    }
+
+    __m256i Load256i(const void* ptr) {
+        return _mm256_loadu_si256((const __m256i*)ptr);
+    }
+
+    // Unrolled dot product for relatively small sizes
+    // The loop with known upper bound is unrolled by the compiler, no need to do anything special about it
+    template <size_t size, class TInput, class TExtend>
+    i32 DotProductInt8Avx2_Unroll(const TInput* lhs, const TInput* rhs, TExtend extend) noexcept {
+        static_assert(size % 16 == 0);
+        auto sum = _mm256_setzero_ps();
+        for (size_t i = 0; i != size; i += 16) {
+            sum = _mm256_add_epi32(sum, _mm256_madd_epi16(extend(Load128i(lhs + i)), extend(Load128i(rhs + i))));
+        }
+
+        return HsumI32(sum);
+    }
+
+    template <class TInput, class TExtend>
+    i32 DotProductInt8Avx2(const TInput* lhs, const TInput* rhs, size_t length, TExtend extend) noexcept {
+        // Fully unrolled versions for small multiples for 16
+        switch (length) {
+            case 16: return DotProductInt8Avx2_Unroll<16>(lhs, rhs, extend);
+            case 32: return DotProductInt8Avx2_Unroll<32>(lhs, rhs, extend);
+            case 48: return DotProductInt8Avx2_Unroll<48>(lhs, rhs, extend);
+            case 64: return DotProductInt8Avx2_Unroll<64>(lhs, rhs, extend);
+        }
+
+        __m256i sum = _mm256_setzero_ps();
+
+        if (const auto leftover = length % 16; leftover != 0) {
+            auto a = _mm_blendv_epi8(
+                    Load128i(lhs), _mm_setzero_ps(), BlendMask8[leftover]);
+            auto b = _mm_blendv_epi8(
+                    Load128i(rhs), _mm_setzero_ps(), BlendMask8[leftover]);
+
+            sum = _mm256_madd_epi16(extend(a), extend(b));
+
+            lhs += leftover;
+            rhs += leftover;
+            length -= leftover;
+        }
+
+        while (length >= 32) {
+            const auto l0 = extend(Load128i(lhs));
+            const auto r0 = extend(Load128i(rhs));
+            const auto l1 = extend(Load128i(lhs + 16));
+            const auto r1 = extend(Load128i(rhs + 16));
+
+            const auto s0 = _mm256_madd_epi16(l0, r0);
+            const auto s1 = _mm256_madd_epi16(l1, r1);
+
+            sum = _mm256_add_epi32(sum, _mm256_add_epi32(s0, s1));
+
+            lhs += 32;
+            rhs += 32;
+            length -= 32;
+        }
+
+        if (length > 0) {
+            auto l = extend(Load128i(lhs));
+            auto r = extend(Load128i(rhs));
+
+            sum = _mm256_add_epi32(sum, _mm256_madd_epi16(l, r));
+        }
+
+        return HsumI32(sum);
+    }
+}
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    if (length < 16) {
+        return DotProductSse(lhs, rhs, length);
+    }
+    return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) {
+        return _mm256_cvtepi8_epi16(x);
+    });
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    if (length < 16) {
+        return DotProductSse(lhs, rhs, length);
+    }
+    return DotProductInt8Avx2(lhs, rhs, length, [](const __m128i x) {
+        return _mm256_cvtepu8_epi16(x);
+    });
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    if (length < 16) {
+        return DotProductSse(lhs, rhs, length);
+    }
+    __m256i res = _mm256_setzero_ps();
+
+    if (const auto leftover = length % 8; leftover != 0) {
+        // Use floating-point blendv. Who cares as long as the size is right.
+        __m256i a = _mm256_blendv_ps(
+                Load256i(lhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+        __m256i b = _mm256_blendv_ps(
+                Load256i(rhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+
+        res = _mm256_mul_epi32(a, b);
+        a = _mm256_alignr_epi8(a, a, 4);
+        b = _mm256_alignr_epi8(b, b, 4);
+        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);
+
+        lhs += leftover;
+        rhs += leftover;
+        length -= leftover;
+    }
+
+    while (length >= 8) {
+        __m256i a = Load256i(lhs);
+        __m256i b = Load256i(rhs);
+        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);    // This is lower parts multiplication
+        a = _mm256_alignr_epi8(a, a, 4);
+        b = _mm256_alignr_epi8(b, b, 4);
+        res = _mm256_add_epi64(_mm256_mul_epi32(a, b), res);
+        rhs += 8;
+        lhs += 8;
+        length -= 8;
+    }
+
+    return HsumI64(res);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+    if (length < 16) {
+        return DotProductSse(lhs, rhs, length);
+    }
+    __m256 sum1 = _mm256_setzero_ps();
+    __m256 sum2 = _mm256_setzero_ps();
+    __m256 a1, b1, a2, b2;
+
+    if (const auto leftover = length % 8; leftover != 0) {
+        a1 = _mm256_blendv_ps(
+                _mm256_loadu_ps(lhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+        b1 = _mm256_blendv_ps(
+                _mm256_loadu_ps(rhs), _mm256_setzero_ps(), BlendMask32[leftover]);
+        sum1 = _mm256_mul_ps(a1, b1);
+        lhs += leftover;
+        rhs += leftover;
+        length -= leftover;
+    }
+
+    while (length >= 16) {
+        a1 = _mm256_loadu_ps(lhs);
+        b1 = _mm256_loadu_ps(rhs);
+        a2 = _mm256_loadu_ps(lhs + 8);
+        b2 = _mm256_loadu_ps(rhs + 8);
+
+        sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+        sum2 = _mm256_fmadd_ps(a2, b2, sum2);
+
+        length -= 16;
+        lhs += 16;
+        rhs += 16;
+    }
+
+    if (length > 0) {
+        a1 = _mm256_loadu_ps(lhs);
+        b1 = _mm256_loadu_ps(rhs);
+        sum1 = _mm256_fmadd_ps(a1, b1, sum1);
+    }
+
+    return HsumFloat(_mm256_add_ps(sum1, sum2));
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+    if (length < 16) {
+        return DotProductSse(lhs, rhs, length);
+    }
+    __m256d sum1 = _mm256_setzero_pd();
+    __m256d sum2 = _mm256_setzero_pd();
+    __m256d a1, b1, a2, b2;
+
+    if (const auto leftover = length % 4; leftover != 0) {
+        a1 = _mm256_blendv_pd(
+                _mm256_loadu_pd(lhs), _mm256_setzero_ps(), BlendMask64[leftover]);
+        b1 = _mm256_blendv_pd(
+                _mm256_loadu_pd(rhs), _mm256_setzero_ps(), BlendMask64[leftover]);
+        sum1 = _mm256_mul_pd(a1, b1);
+        lhs += leftover;
+        rhs += leftover;
+        length -= leftover;
+    }
+
+    while (length >= 8) {
+        a1 = _mm256_loadu_pd(lhs);
+        b1 = _mm256_loadu_pd(rhs);
+        a2 = _mm256_loadu_pd(lhs + 4);
+        b2 = _mm256_loadu_pd(rhs + 4);
+
+        sum1 = _mm256_fmadd_pd(a1, b1, sum1);
+        sum2 = _mm256_fmadd_pd(a2, b2, sum2);
+
+        length -= 8;
+        lhs += 8;
+        rhs += 8;
+    }
+
+    if (length > 0) {
+        a1 = _mm256_loadu_pd(lhs);
+        b1 = _mm256_loadu_pd(rhs);
+        sum1 = _mm256_fmadd_pd(a1, b1, sum1);
+    }
+
+    return HsumDouble(_mm256_add_pd(sum1, sum2));
+}
+
+#elif defined(ARCADIA_SSE)
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    return DotProductSse(lhs, rhs, length);
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    return DotProductSse(lhs, rhs, length);
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return DotProductSse(lhs, rhs, length);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+    return DotProductSse(lhs, rhs, length);
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+    return DotProductSse(lhs, rhs, length);
+}
+
+#else
+
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+#endif

+ 19 - 0
library/cpp/dot_product/dot_product_avx2.h

@@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+Y_PURE_FUNCTION
+i32 DotProductAvx2(const i8* lhs, const i8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductAvx2(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+i64 DotProductAvx2(const i32* lhs, const i32* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+float DotProductAvx2(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductAvx2(const double* lhs, const double* rhs, size_t length) noexcept;

+ 44 - 0
library/cpp/dot_product/dot_product_simple.cpp

@@ -0,0 +1,44 @@
+#include "dot_product_simple.h"
+
+namespace {
+    template <typename Res, typename Number>
+    static Res DotProductSimpleImpl(const Number* lhs, const Number* rhs, size_t length) noexcept {
+        Res s0 = 0;
+        Res s1 = 0;
+        Res s2 = 0;
+        Res s3 = 0;
+
+        while (length >= 4) {
+            s0 += static_cast<Res>(lhs[0]) * static_cast<Res>(rhs[0]);
+            s1 += static_cast<Res>(lhs[1]) * static_cast<Res>(rhs[1]);
+            s2 += static_cast<Res>(lhs[2]) * static_cast<Res>(rhs[2]);
+            s3 += static_cast<Res>(lhs[3]) * static_cast<Res>(rhs[3]);
+            lhs += 4;
+            rhs += 4;
+            length -= 4;
+        }
+
+        while (length--) {
+            s0 += static_cast<Res>(*lhs++) * static_cast<Res>(*rhs++);
+        }
+
+        return s0 + s1 + s2 + s3;
+    }
+}
+
+float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept {
+    return DotProductSimpleImpl<float, float>(lhs, rhs, length);
+}
+
+double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept {
+    return DotProductSimpleImpl<double, double>(lhs, rhs, length);
+}
+
+ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept {
+    ui32 res = 0;
+    for (size_t i = 0; i < lengtInBytes; ++i) {
+        res += static_cast<ui32>(lhs[i] & 0x0f) * static_cast<ui32>(rhs[i] & 0x0f);
+        res += static_cast<ui32>(lhs[i] & 0xf0) * static_cast<ui32>(rhs[i] & 0xf0) >> 8;
+    }
+    return res;
+}

+ 40 - 0
library/cpp/dot_product/dot_product_simple.h

@@ -0,0 +1,40 @@
+#pragma once
+
+#include <util/system/compiler.h>
+#include <util/system/types.h>
+
+#include <numeric>
+
+/**
+ * Dot product implementation without SSE optimizations.
+ */
+Y_PURE_FUNCTION
+inline ui32 DotProductSimple(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<ui32>(0u),
+                              [](ui32 x1, ui16 x2) {return x1 + x2;},
+                              [](ui16 x1, ui8 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+inline i32 DotProductSimple(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<i32>(0),
+                              [](i32 x1, i16 x2) {return x1 + x2;},
+                              [](i16 x1, i8 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+inline i64 DotProductSimple(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return std::inner_product(lhs, lhs + length, rhs, static_cast<i64>(0),
+                              [](i64 x1, i64 x2) {return x1 + x2;},
+                              [](i64 x1, i32 x2) {return x1 * x2;});
+}
+
+Y_PURE_FUNCTION
+float DotProductSimple(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductSimple(const double* lhs, const double* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductUI4Simple(const ui8* lhs, const ui8* rhs, size_t lengtInBytes) noexcept;
+

+ 219 - 0
library/cpp/dot_product/dot_product_sse.cpp

@@ -0,0 +1,219 @@
+#include "dot_product_sse.h"
+
+#include <library/cpp/sse/sse.h>
+#include <util/system/platform.h>
+#include <util/system/compiler.h>
+
+#ifdef ARCADIA_SSE
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept {
+    const __m128i zero = _mm_setzero_si128();
+    __m128i resVec = zero;
+    while (length >= 16) {
+        __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+#ifdef _sse4_1_
+        __m128i lLo = _mm_cvtepi8_epi16(lVec);
+        __m128i rLo = _mm_cvtepi8_epi16(rVec);
+        __m128i lHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(lVec, lVec, 8));
+        __m128i rHi = _mm_cvtepi8_epi16(_mm_alignr_epi8(rVec, rVec, 8));
+#else
+        __m128i lLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, lVec), 8);
+        __m128i rLo = _mm_srai_epi16(_mm_unpacklo_epi8(zero, rVec), 8);
+        __m128i lHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, lVec), 8);
+        __m128i rHi = _mm_srai_epi16(_mm_unpackhi_epi8(zero, rVec), 8);
+#endif
+        resVec = _mm_add_epi32(resVec,
+                               _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+        lhs += 16;
+        rhs += 16;
+        length -= 16;
+    }
+
+    alignas(16) i32 res[4];
+    _mm_store_si128((__m128i*)res, resVec);
+    i32 sum = res[0] + res[1] + res[2] + res[3];
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
+    }
+
+    return sum;
+}
+
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
+    const __m128i zero = _mm_setzero_si128();
+    __m128i resVec = zero;
+    while (length >= 16) {
+        __m128i lVec = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i rVec = _mm_loadu_si128((const __m128i*)rhs);
+
+        __m128i lLo = _mm_unpacklo_epi8(lVec, zero);
+        __m128i rLo = _mm_unpacklo_epi8(rVec, zero);
+        __m128i lHi = _mm_unpackhi_epi8(lVec, zero);
+        __m128i rHi = _mm_unpackhi_epi8(rVec, zero);
+
+        resVec = _mm_add_epi32(resVec,
+                               _mm_add_epi32(_mm_madd_epi16(lLo, rLo), _mm_madd_epi16(lHi, rHi)));
+
+        lhs += 16;
+        rhs += 16;
+        length -= 16;
+    }
+
+    alignas(16) i32 res[4];
+    _mm_store_si128((__m128i*)res, resVec);
+    i32 sum = res[0] + res[1] + res[2] + res[3];
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<i32>(lhs[i]) * static_cast<i32>(rhs[i]);
+    }
+
+    return static_cast<ui32>(sum);
+}
+#ifdef _sse4_1_
+
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    __m128i zero = _mm_setzero_si128();
+    __m128i res = zero;
+
+    while (length >= 4) {
+        __m128i a = _mm_loadu_si128((const __m128i*)lhs);
+        __m128i b = _mm_loadu_si128((const __m128i*)rhs);
+        res = _mm_add_epi64(_mm_mul_epi32(a, b), res);    // This is lower parts multiplication
+        a = _mm_alignr_epi8(a, a, 4);
+        b = _mm_alignr_epi8(b, b, 4);
+        res = _mm_add_epi64(_mm_mul_epi32(a, b), res);
+        rhs += 4;
+        lhs += 4;
+        length -= 4;
+    }
+
+    alignas(16) i64 r[2];
+    _mm_store_si128((__m128i*)r, res);
+    i64 sum = r[0] + r[1];
+
+    for (size_t i = 0; i < length; ++i) {
+        sum += static_cast<i64>(lhs[i]) * static_cast<i64>(rhs[i]);
+    }
+
+    return sum;
+}
+
+#else
+#include "dot_product_simple.h"
+
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept {
+    return DotProductSimple(lhs, rhs, length);
+}
+
+#endif
+
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept {
+    __m128 sum1 = _mm_setzero_ps();
+    __m128 sum2 = _mm_setzero_ps();
+    __m128 a1, b1, a2, b2, m1, m2;
+
+    while (length >= 8) {
+        a1 = _mm_loadu_ps(lhs);
+        b1 = _mm_loadu_ps(rhs);
+        m1 = _mm_mul_ps(a1, b1);
+
+        a2 = _mm_loadu_ps(lhs + 4);
+        sum1 = _mm_add_ps(sum1, m1);
+
+        b2 = _mm_loadu_ps(rhs + 4);
+        m2 = _mm_mul_ps(a2, b2);
+
+        sum2 = _mm_add_ps(sum2, m2);
+
+        length -= 8;
+        lhs += 8;
+        rhs += 8;
+    }
+
+    if (length >= 4) {
+        a1 = _mm_loadu_ps(lhs);
+        b1 = _mm_loadu_ps(rhs);
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+
+        length -= 4;
+        lhs += 4;
+        rhs += 4;
+    }
+
+    sum1 = _mm_add_ps(sum1, sum2);
+
+    if (length) {
+        switch (length) {
+            case 3:
+                a1 = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]);
+                b1 = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]);
+                break;
+
+            case 2:
+                a1 = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]);
+                b1 = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]);
+                break;
+
+            case 1:
+                a1 = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]);
+                b1 = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]);
+                break;
+
+            default:
+                Y_UNREACHABLE();
+        }
+
+        sum1 = _mm_add_ps(sum1, _mm_mul_ps(a1, b1));
+    }
+
+    alignas(16) float res[4];
+    _mm_store_ps(res, sum1);
+
+    return res[0] + res[1] + res[2] + res[3];
+}
+
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept {
+    __m128d sum1 = _mm_setzero_pd();
+    __m128d sum2 = _mm_setzero_pd();
+    __m128d a1, b1, a2, b2;
+
+    while (length >= 4) {
+        a1 = _mm_loadu_pd(lhs);
+        b1 = _mm_loadu_pd(rhs);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+        a2 = _mm_loadu_pd(lhs + 2);
+        b2 = _mm_loadu_pd(rhs + 2);
+        sum2 = _mm_add_pd(sum2, _mm_mul_pd(a2, b2));
+
+        length -= 4;
+        lhs += 4;
+        rhs += 4;
+    }
+
+    if (length >= 2) {
+        a1 = _mm_loadu_pd(lhs);
+        b1 = _mm_loadu_pd(rhs);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+
+        length -= 2;
+        lhs += 2;
+        rhs += 2;
+    }
+
+    sum1 = _mm_add_pd(sum1, sum2);
+
+    if (length > 0) {
+        a1 = _mm_set_pd(lhs[0], 0.0);
+        b1 = _mm_set_pd(rhs[0], 0.0);
+        sum1 = _mm_add_pd(sum1, _mm_mul_pd(a1, b1));
+    }
+
+    alignas(16) double res[2];
+    _mm_store_pd(res, sum1);
+
+    return res[0] + res[1];
+}
+
+#endif // ARCADIA_SSE

+ 19 - 0
library/cpp/dot_product/dot_product_sse.h

@@ -0,0 +1,19 @@
+#pragma once
+
+#include <util/system/types.h>
+#include <util/system/compiler.h>
+
+Y_PURE_FUNCTION
+i32 DotProductSse(const i8* lhs, const i8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+ui32 DotProductSse(const ui8* lhs, const ui8* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+i64 DotProductSse(const i32* lhs, const i32* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+float DotProductSse(const float* lhs, const float* rhs, size_t length) noexcept;
+
+Y_PURE_FUNCTION
+double DotProductSse(const double* lhs, const double* rhs, size_t length) noexcept;

+ 20 - 0
library/cpp/dot_product/ya.make

@@ -0,0 +1,20 @@
+LIBRARY()
+
+SRCS(
+    dot_product.cpp
+    dot_product_sse.cpp
+    dot_product_simple.cpp
+)
+
+IF (USE_SSE4 == "yes" AND OS_LINUX == "yes")
+    SRC_C_AVX2(dot_product_avx2.cpp -mfma)
+ELSE()
+    SRC(dot_product_avx2.cpp)
+ENDIF()
+
+PEERDIR(
+    library/cpp/sse
+    library/cpp/testing/common
+)
+
+END()