kyber512r3_poly_avx2.c 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. #include <stdint.h>
  2. #include <string.h>
  3. #include "kyber512r3_align_avx2.h"
  4. #include "kyber512r3_consts_avx2.h"
  5. #include "kyber512r3_poly_avx2.h"
  6. #include "kyber512r3_ntt_avx2.h"
  7. #include "kyber512r3_reduce_avx2.h"
  8. #include "kyber512r3_cbd_avx2.h"
  9. #include "kyber512r3_fips202.h"
  10. #include "kyber512r3_fips202x4_avx2.h"
  11. #include "kyber512r3_symmetric.h"
  12. #if defined(S2N_KYBER512R3_AVX2_BMI2)
  13. #include <immintrin.h>
  14. /*************************************************
  15. * Name: poly_compress_avx2
  16. *
  17. * Description: Compression and subsequent serialization of a polynomial.
  18. * The coefficients of the input polynomial are assumed to
  19. * lie in the invertal [0,q], i.e. the polynomial must be reduced
  20. * by poly_reduce_avx2().
  21. *
  22. * Arguments: - uint8_t *r: pointer to output byte array
  23. * (of length S2N_KYBER_512_R3_POLYCOMPRESSEDBYTES)
  24. * - const poly *a: pointer to input polynomial
  25. **************************************************/
  26. void poly_compress_avx2(uint8_t r[128], const poly * restrict a)
  27. {
  28. unsigned int i;
  29. __m256i f0, f1, f2, f3;
  30. const __m256i v = _mm256_load_si256(&qdata.vec[_16XV/16]);
  31. const __m256i shift1 = _mm256_set1_epi16(1 << 9);
  32. const __m256i mask = _mm256_set1_epi16(15);
  33. const __m256i shift2 = _mm256_set1_epi16((16 << 8) + 1);
  34. const __m256i permdidx = _mm256_set_epi32(7,3,6,2,5,1,4,0);
  35. for(i=0;i<S2N_KYBER_512_R3_N/64;i++) {
  36. f0 = _mm256_load_si256(&a->vec[4*i+0]);
  37. f1 = _mm256_load_si256(&a->vec[4*i+1]);
  38. f2 = _mm256_load_si256(&a->vec[4*i+2]);
  39. f3 = _mm256_load_si256(&a->vec[4*i+3]);
  40. f0 = _mm256_mulhi_epi16(f0,v);
  41. f1 = _mm256_mulhi_epi16(f1,v);
  42. f2 = _mm256_mulhi_epi16(f2,v);
  43. f3 = _mm256_mulhi_epi16(f3,v);
  44. f0 = _mm256_mulhrs_epi16(f0,shift1);
  45. f1 = _mm256_mulhrs_epi16(f1,shift1);
  46. f2 = _mm256_mulhrs_epi16(f2,shift1);
  47. f3 = _mm256_mulhrs_epi16(f3,shift1);
  48. f0 = _mm256_and_si256(f0,mask);
  49. f1 = _mm256_and_si256(f1,mask);
  50. f2 = _mm256_and_si256(f2,mask);
  51. f3 = _mm256_and_si256(f3,mask);
  52. f0 = _mm256_packus_epi16(f0,f1);
  53. f2 = _mm256_packus_epi16(f2,f3);
  54. f0 = _mm256_maddubs_epi16(f0,shift2);
  55. f2 = _mm256_maddubs_epi16(f2,shift2);
  56. f0 = _mm256_packus_epi16(f0,f2);
  57. f0 = _mm256_permutevar8x32_epi32(f0,permdidx);
  58. // correcting cast-align error
  59. // old version: _mm256_storeu_si256((__m256i *)&r[32*i],f0);
  60. _mm256_storeu_si256((void *)&r[32*i],f0);
  61. }
  62. }
  63. void poly_decompress_avx2(poly * restrict r, const uint8_t a[128])
  64. {
  65. unsigned int i;
  66. __m128i t;
  67. __m256i f;
  68. const __m256i q = _mm256_load_si256(&qdata.vec[_16XQ/16]);
  69. const __m256i shufbidx = _mm256_set_epi8(7,7,7,7,6,6,6,6,5,5,5,5,4,4,4,4,
  70. 3,3,3,3,2,2,2,2,1,1,1,1,0,0,0,0);
  71. const __m256i mask = _mm256_set1_epi32(0x00F0000F);
  72. const __m256i shift = _mm256_set1_epi32((128 << 16) + 2048);
  73. for(i=0;i<S2N_KYBER_512_R3_N/16;i++) {
  74. // correcting cast-align and cast-qual errors
  75. // old version: t = _mm_loadl_epi64((__m128i *)&a[8*i]);
  76. t = _mm_loadl_epi64((const void *)&a[8*i]);
  77. f = _mm256_broadcastsi128_si256(t);
  78. f = _mm256_shuffle_epi8(f,shufbidx);
  79. f = _mm256_and_si256(f,mask);
  80. f = _mm256_mullo_epi16(f,shift);
  81. f = _mm256_mulhrs_epi16(f,q);
  82. _mm256_store_si256(&r->vec[i],f);
  83. }
  84. }
  85. /*************************************************
  86. * Name: poly_tobytes_avx2
  87. *
  88. * Description: Serialization of a polynomial in NTT representation.
  89. * The coefficients of the input polynomial are assumed to
  90. * lie in the invertal [0,q], i.e. the polynomial must be reduced
  91. * by poly_reduce_avx2(). The coefficients are orderd as output by
  92. * poly_ntt_avx2(); the serialized output coefficients are in bitreversed
  93. * order.
  94. *
  95. * Arguments: - uint8_t *r: pointer to output byte array
  96. * (needs space for S2N_KYBER_512_R3_POLYBYTES bytes)
  97. * - poly *a: pointer to input polynomial
  98. **************************************************/
  99. void poly_tobytes_avx2(uint8_t r[S2N_KYBER_512_R3_POLYBYTES], const poly *a)
  100. {
  101. ntttobytes_avx2_asm(r, a->vec, qdata.vec);
  102. }
  103. /*************************************************
  104. * Name: poly_frombytes_avx2
  105. *
  106. * Description: De-serialization of a polynomial;
  107. * inverse of poly_tobytes_avx2
  108. *
  109. * Arguments: - poly *r: pointer to output polynomial
  110. * - const uint8_t *a: pointer to input byte array
  111. * (of S2N_KYBER_512_R3_POLYBYTES bytes)
  112. **************************************************/
  113. void poly_frombytes_avx2(poly *r, const uint8_t a[S2N_KYBER_512_R3_POLYBYTES])
  114. {
  115. nttfrombytes_avx2_asm(r->vec, a, qdata.vec);
  116. }
  117. /*************************************************
  118. * Name: poly_frommsg_avx2
  119. *
  120. * Description: Convert 32-byte message to polynomial
  121. *
  122. * Arguments: - poly *r: pointer to output polynomial
  123. * - const uint8_t *msg: pointer to input message
  124. **************************************************/
  125. void poly_frommsg_avx2(poly * restrict r, const uint8_t msg[S2N_KYBER_512_R3_INDCPA_MSGBYTES])
  126. {
  127. __m256i f, g0, g1, g2, g3, h0, h1, h2, h3;
  128. const __m256i shift = _mm256_broadcastsi128_si256(_mm_set_epi32(0,1,2,3));
  129. const __m256i idx = _mm256_broadcastsi128_si256(_mm_set_epi8(15,14,11,10,7,6,3,2,13,12,9,8,5,4,1,0));
  130. const __m256i hqs = _mm256_set1_epi16((S2N_KYBER_512_R3_Q+1)/2);
  131. #define FROMMSG64(i) \
  132. g3 = _mm256_shuffle_epi32(f,0x55*i); \
  133. g3 = _mm256_sllv_epi32(g3,shift); \
  134. g3 = _mm256_shuffle_epi8(g3,idx); \
  135. g0 = _mm256_slli_epi16(g3,12); \
  136. g1 = _mm256_slli_epi16(g3,8); \
  137. g2 = _mm256_slli_epi16(g3,4); \
  138. g0 = _mm256_srai_epi16(g0,15); \
  139. g1 = _mm256_srai_epi16(g1,15); \
  140. g2 = _mm256_srai_epi16(g2,15); \
  141. g3 = _mm256_srai_epi16(g3,15); \
  142. g0 = _mm256_and_si256(g0,hqs); /* 19 18 17 16 3 2 1 0 */ \
  143. g1 = _mm256_and_si256(g1,hqs); /* 23 22 21 20 7 6 5 4 */ \
  144. g2 = _mm256_and_si256(g2,hqs); /* 27 26 25 24 11 10 9 8 */ \
  145. g3 = _mm256_and_si256(g3,hqs); /* 31 30 29 28 15 14 13 12 */ \
  146. h0 = _mm256_unpacklo_epi64(g0,g1); \
  147. h2 = _mm256_unpackhi_epi64(g0,g1); \
  148. h1 = _mm256_unpacklo_epi64(g2,g3); \
  149. h3 = _mm256_unpackhi_epi64(g2,g3); \
  150. g0 = _mm256_permute2x128_si256(h0,h1,0x20); \
  151. g2 = _mm256_permute2x128_si256(h0,h1,0x31); \
  152. g1 = _mm256_permute2x128_si256(h2,h3,0x20); \
  153. g3 = _mm256_permute2x128_si256(h2,h3,0x31); \
  154. _mm256_store_si256(&r->vec[0+2*i+0],g0); \
  155. _mm256_store_si256(&r->vec[0+2*i+1],g1); \
  156. _mm256_store_si256(&r->vec[8+2*i+0],g2); \
  157. _mm256_store_si256(&r->vec[8+2*i+1],g3)
  158. // correcting cast-align and cast-qual errors
  159. // old version: f = _mm256_loadu_si256((__m256i *)msg);
  160. f = _mm256_loadu_si256((const void *)msg);
  161. FROMMSG64(0);
  162. FROMMSG64(1);
  163. FROMMSG64(2);
  164. FROMMSG64(3);
  165. }
  166. /*************************************************
  167. * Name: poly_tomsg_avx2
  168. *
  169. * Description: Convert polynomial to 32-byte message.
  170. * The coefficients of the input polynomial are assumed to
  171. * lie in the invertal [0,q], i.e. the polynomial must be reduced
  172. * by poly_reduce_avx2().
  173. *
  174. * Arguments: - uint8_t *msg: pointer to output message
  175. * - poly *a: pointer to input polynomial
  176. **************************************************/
  177. void poly_tomsg_avx2(uint8_t msg[S2N_KYBER_512_R3_INDCPA_MSGBYTES], const poly * restrict a)
  178. {
  179. unsigned int i;
  180. uint32_t small;
  181. __m256i f0, f1, g0, g1;
  182. const __m256i hq = _mm256_set1_epi16((S2N_KYBER_512_R3_Q - 1)/2);
  183. const __m256i hhq = _mm256_set1_epi16((S2N_KYBER_512_R3_Q - 1)/4);
  184. for(i=0;i<S2N_KYBER_512_R3_N/32;i++) {
  185. f0 = _mm256_load_si256(&a->vec[2*i+0]);
  186. f1 = _mm256_load_si256(&a->vec[2*i+1]);
  187. f0 = _mm256_sub_epi16(hq, f0);
  188. f1 = _mm256_sub_epi16(hq, f1);
  189. g0 = _mm256_srai_epi16(f0, 15);
  190. g1 = _mm256_srai_epi16(f1, 15);
  191. f0 = _mm256_xor_si256(f0, g0);
  192. f1 = _mm256_xor_si256(f1, g1);
  193. f0 = _mm256_sub_epi16(f0, hhq);
  194. f1 = _mm256_sub_epi16(f1, hhq);
  195. f0 = _mm256_packs_epi16(f0, f1);
  196. f0 = _mm256_permute4x64_epi64(f0, 0xD8);
  197. small = _mm256_movemask_epi8(f0);
  198. memcpy(&msg[4*i], &small, 4);
  199. }
  200. }
  201. /*************************************************
  202. * Name: poly_getnoise_eta1_avx2
  203. *
  204. * Description: Sample a polynomial deterministically from a seed and a nonce,
  205. * with output polynomial close to centered binomial distribution
  206. * with parameter S2N_KYBER_512_R3_ETA1
  207. *
  208. * Arguments: - poly *r: pointer to output polynomial
  209. * - const uint8_t *seed: pointer to input seed
  210. * (of length S2N_KYBER_512_R3_SYMBYTES bytes)
  211. * - uint8_t nonce: one-byte input nonce
  212. **************************************************/
  213. void poly_getnoise_eta1_avx2(poly *r, const uint8_t seed[S2N_KYBER_512_R3_SYMBYTES], uint8_t nonce)
  214. {
  215. ALIGNED_UINT8(S2N_KYBER_512_R3_ETA1*S2N_KYBER_512_R3_N/4+32) buf; // +32 bytes as required by poly_cbd_eta1_avx2
  216. shake256_prf(buf.coeffs, S2N_KYBER_512_R3_ETA1*S2N_KYBER_512_R3_N/4, seed, nonce);
  217. poly_cbd_eta1_avx2(r, buf.vec);
  218. }
  219. /*************************************************
  220. * Name: poly_getnoise_eta2_avx2
  221. *
  222. * Description: Sample a polynomial deterministically from a seed and a nonce,
  223. * with output polynomial close to centered binomial distribution
  224. * with parameter S2N_KYBER_512_R3_ETA2
  225. *
  226. * Arguments: - poly *r: pointer to output polynomial
  227. * - const uint8_t *seed: pointer to input seed
  228. * (of length S2N_KYBER_512_R3_SYMBYTES bytes)
  229. * - uint8_t nonce: one-byte input nonce
  230. **************************************************/
  231. void poly_getnoise_eta2_avx2(poly *r, const uint8_t seed[S2N_KYBER_512_R3_SYMBYTES], uint8_t nonce)
  232. {
  233. ALIGNED_UINT8(S2N_KYBER_512_R3_ETA2*S2N_KYBER_512_R3_N/4) buf;
  234. shake256_prf(buf.coeffs, S2N_KYBER_512_R3_ETA2*S2N_KYBER_512_R3_N/4, seed, nonce);
  235. poly_cbd_eta2_avx2(r, buf.vec);
  236. }
  237. #define NOISE_NBLOCKS ((S2N_KYBER_512_R3_ETA1*S2N_KYBER_512_R3_N/4+S2N_KYBER_512_R3_SHAKE256_RATE-1)/S2N_KYBER_512_R3_SHAKE256_RATE)
  238. void poly_getnoise_eta1_4x(poly *r0,
  239. poly *r1,
  240. poly *r2,
  241. poly *r3,
  242. const uint8_t seed[32],
  243. uint8_t nonce0,
  244. uint8_t nonce1,
  245. uint8_t nonce2,
  246. uint8_t nonce3)
  247. {
  248. ALIGNED_UINT8(NOISE_NBLOCKS*S2N_KYBER_512_R3_SHAKE256_RATE) buf[4];
  249. __m256i f;
  250. keccakx4_state state;
  251. // correcting cast-align and cast-qual errors
  252. // old version: f = _mm256_loadu_si256((__m256i *)seed);
  253. f = _mm256_loadu_si256((const void *)seed);
  254. _mm256_store_si256(buf[0].vec, f);
  255. _mm256_store_si256(buf[1].vec, f);
  256. _mm256_store_si256(buf[2].vec, f);
  257. _mm256_store_si256(buf[3].vec, f);
  258. buf[0].coeffs[32] = nonce0;
  259. buf[1].coeffs[32] = nonce1;
  260. buf[2].coeffs[32] = nonce2;
  261. buf[3].coeffs[32] = nonce3;
  262. shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33);
  263. shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state);
  264. poly_cbd_eta1_avx2(r0, buf[0].vec);
  265. poly_cbd_eta1_avx2(r1, buf[1].vec);
  266. poly_cbd_eta1_avx2(r2, buf[2].vec);
  267. poly_cbd_eta1_avx2(r3, buf[3].vec);
  268. }
  269. void poly_getnoise_eta1122_4x(poly *r0,
  270. poly *r1,
  271. poly *r2,
  272. poly *r3,
  273. const uint8_t seed[32],
  274. uint8_t nonce0,
  275. uint8_t nonce1,
  276. uint8_t nonce2,
  277. uint8_t nonce3)
  278. {
  279. ALIGNED_UINT8(NOISE_NBLOCKS*S2N_KYBER_512_R3_SHAKE256_RATE) buf[4];
  280. __m256i f;
  281. keccakx4_state state;
  282. // correcting cast-align and cast-qual errors
  283. // old version: f = _mm256_loadu_si256((__m256i *)seed);
  284. f = _mm256_loadu_si256((const void *)seed);
  285. _mm256_store_si256(buf[0].vec, f);
  286. _mm256_store_si256(buf[1].vec, f);
  287. _mm256_store_si256(buf[2].vec, f);
  288. _mm256_store_si256(buf[3].vec, f);
  289. buf[0].coeffs[32] = nonce0;
  290. buf[1].coeffs[32] = nonce1;
  291. buf[2].coeffs[32] = nonce2;
  292. buf[3].coeffs[32] = nonce3;
  293. shake256x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 33);
  294. shake256x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, NOISE_NBLOCKS, &state);
  295. poly_cbd_eta1_avx2(r0, buf[0].vec);
  296. poly_cbd_eta1_avx2(r1, buf[1].vec);
  297. poly_cbd_eta2_avx2(r2, buf[2].vec);
  298. poly_cbd_eta2_avx2(r3, buf[3].vec);
  299. }
  300. /*************************************************
  301. * Name: poly_ntt_avx2
  302. *
  303. * Description: Computes negacyclic number-theoretic transform (NTT) of
  304. * a polynomial in place.
  305. * Input coefficients assumed to be in normal order,
  306. * output coefficients are in special order that is natural
  307. * for the vectorization. Input coefficients are assumed to be
  308. * bounded by q in absolute value, output coefficients are bounded
  309. * by 16118 in absolute value.
  310. *
  311. * Arguments: - poly *r: pointer to in/output polynomial
  312. **************************************************/
  313. void poly_ntt_avx2(poly *r)
  314. {
  315. ntt_avx2_asm(r->vec, qdata.vec);
  316. }
  317. /*************************************************
  318. * Name: poly_invntt_tomont_avx2
  319. *
  320. * Description: Computes inverse of negacyclic number-theoretic transform (NTT)
  321. * of a polynomial in place;
  322. * Input coefficients assumed to be in special order from vectorized
  323. * forward ntt, output in normal order. Input coefficients can be
  324. * arbitrary 16-bit integers, output coefficients are bounded by 14870
  325. * in absolute value.
  326. *
  327. * Arguments: - poly *a: pointer to in/output polynomial
  328. **************************************************/
  329. void poly_invntt_tomont_avx2(poly *r)
  330. {
  331. invntt_avx2_asm(r->vec, qdata.vec);
  332. }
  333. void poly_nttunpack_avx2(poly *r)
  334. {
  335. nttunpack_avx2_asm(r->vec, qdata.vec);
  336. }
  337. /*************************************************
  338. * Name: poly_basemul_montgomery_avx2
  339. *
  340. * Description: Multiplication of two polynomials in NTT domain.
  341. * One of the input polynomials needs to have coefficients
  342. * bounded by q, the other polynomial can have arbitrary
  343. * coefficients. Output coefficients are bounded by 6656.
  344. *
  345. * Arguments: - poly *r: pointer to output polynomial
  346. * - const poly *a: pointer to first input polynomial
  347. * - const poly *b: pointer to second input polynomial
  348. **************************************************/
  349. void poly_basemul_montgomery_avx2(poly *r, const poly *a, const poly *b)
  350. {
  351. basemul_avx2_asm(r->vec, a->vec, b->vec, qdata.vec);
  352. }
  353. /*************************************************
  354. * Name: poly_tomont_avx2
  355. *
  356. * Description: Inplace conversion of all coefficients of a polynomial
  357. * from normal domain to Montgomery domain
  358. *
  359. * Arguments: - poly *r: pointer to input/output polynomial
  360. **************************************************/
  361. void poly_tomont_avx2(poly *r)
  362. {
  363. tomont_avx2_asm(r->vec, qdata.vec);
  364. }
  365. /*************************************************
  366. * Name: poly_reduce_avx2
  367. *
  368. * Description: Applies Barrett reduction to all coefficients of a polynomial
  369. * for details of the Barrett reduction see comments in reduce.c
  370. *
  371. * Arguments: - poly *r: pointer to input/output polynomial
  372. **************************************************/
  373. void poly_reduce_avx2(poly *r)
  374. {
  375. reduce_avx2_asm(r->vec, qdata.vec);
  376. }
  377. /*************************************************
  378. * Name: poly_add_avx2
  379. *
  380. * Description: Add two polynomials. No modular reduction
  381. * is performed.
  382. *
  383. * Arguments: - poly *r: pointer to output polynomial
  384. * - const poly *a: pointer to first input polynomial
  385. * - const poly *b: pointer to second input polynomial
  386. **************************************************/
  387. void poly_add_avx2(poly *r, const poly *a, const poly *b)
  388. {
  389. unsigned int i;
  390. __m256i f0, f1;
  391. for(i=0;i<S2N_KYBER_512_R3_N/16;i++) {
  392. f0 = _mm256_load_si256(&a->vec[i]);
  393. f1 = _mm256_load_si256(&b->vec[i]);
  394. f0 = _mm256_add_epi16(f0, f1);
  395. _mm256_store_si256(&r->vec[i], f0);
  396. }
  397. }
  398. /*************************************************
  399. * Name: poly_sub_avx2
  400. *
  401. * Description: Subtract two polynomials. No modular reduction
  402. * is performed.
  403. *
  404. * Arguments: - poly *r: pointer to output polynomial
  405. * - const poly *a: pointer to first input polynomial
  406. * - const poly *b: pointer to second input polynomial
  407. **************************************************/
  408. void poly_sub_avx2(poly *r, const poly *a, const poly *b)
  409. {
  410. unsigned int i;
  411. __m256i f0, f1;
  412. for(i=0;i<S2N_KYBER_512_R3_N/16;i++) {
  413. f0 = _mm256_load_si256(&a->vec[i]);
  414. f1 = _mm256_load_si256(&b->vec[i]);
  415. f0 = _mm256_sub_epi16(f0, f1);
  416. _mm256_store_si256(&r->vec[i], f0);
  417. }
  418. }
  419. #endif