kyber512r3_indcpa_avx2.c 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. #include <stddef.h>
  2. #include <stdint.h>
  3. #include <string.h>
  4. #include "kyber512r3_align_avx2.h"
  5. #include "kyber512r3_params.h"
  6. #include "kyber512r3_indcpa_avx2.h"
  7. #include "kyber512r3_polyvec_avx2.h"
  8. #include "kyber512r3_poly_avx2.h"
  9. #include "kyber512r3_rejsample_avx2.h"
  10. #include "kyber512r3_fips202.h"
  11. #include "kyber512r3_fips202x4_avx2.h"
  12. #include "pq-crypto/s2n_pq_random.h"
  13. #include "utils/s2n_safety.h"
  14. #if defined(S2N_KYBER512R3_AVX2_BMI2)
  15. #include <immintrin.h>
  16. /*************************************************
  17. * Name: pack_pk
  18. *
  19. * Description: Serialize the public key as concatenation of the
  20. * serialized vector of polynomials pk and the
  21. * public seed used to generate the matrix A.
  22. * The polynomial coefficients in pk are assumed to
  23. * lie in the invertal [0,q], i.e. pk must be reduced
  24. * by polyvec_reduce_avx2().
  25. *
  26. * Arguments: uint8_t *r: pointer to the output serialized public key
  27. * polyvec *pk: pointer to the input public-key polyvec
  28. * const uint8_t *seed: pointer to the input public seed
  29. **************************************************/
  30. static void pack_pk(uint8_t r[S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES],
  31. polyvec *pk,
  32. const uint8_t seed[S2N_KYBER_512_R3_SYMBYTES])
  33. {
  34. polyvec_tobytes_avx2(r, pk);
  35. memcpy(r+S2N_KYBER_512_R3_POLYVECBYTES, seed, S2N_KYBER_512_R3_SYMBYTES);
  36. }
  37. /*************************************************
  38. * Name: unpack_pk
  39. *
  40. * Description: De-serialize public key from a byte array;
  41. * approximate inverse of pack_pk
  42. *
  43. * Arguments: - polyvec *pk: pointer to output public-key polynomial vector
  44. * - uint8_t *seed: pointer to output seed to generate matrix A
  45. * - const uint8_t *packedpk: pointer to input serialized public key
  46. **************************************************/
  47. static void unpack_pk(polyvec *pk,
  48. uint8_t seed[S2N_KYBER_512_R3_SYMBYTES],
  49. const uint8_t packedpk[S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES])
  50. {
  51. polyvec_frombytes_avx2(pk, packedpk);
  52. memcpy(seed, packedpk+S2N_KYBER_512_R3_POLYVECBYTES, S2N_KYBER_512_R3_SYMBYTES);
  53. }
  54. /*************************************************
  55. * Name: pack_sk
  56. *
  57. * Description: Serialize the secret key.
  58. * The polynomial coefficients in sk are assumed to
  59. * lie in the invertal [0,q], i.e. sk must be reduced
  60. * by polyvec_reduce_avx2().
  61. *
  62. * Arguments: - uint8_t *r: pointer to output serialized secret key
  63. * - polyvec *sk: pointer to input vector of polynomials (secret key)
  64. **************************************************/
  65. static void pack_sk(uint8_t r[S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES], polyvec *sk)
  66. {
  67. polyvec_tobytes_avx2(r, sk);
  68. }
  69. /*************************************************
  70. * Name: unpack_sk
  71. *
  72. * Description: De-serialize the secret key; inverse of pack_sk
  73. *
  74. * Arguments: - polyvec *sk: pointer to output vector of polynomials (secret key)
  75. * - const uint8_t *packedsk: pointer to input serialized secret key
  76. **************************************************/
  77. static void unpack_sk(polyvec *sk, const uint8_t packedsk[S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES])
  78. {
  79. polyvec_frombytes_avx2(sk, packedsk);
  80. }
  81. /*************************************************
  82. * Name: pack_ciphertext
  83. *
  84. * Description: Serialize the ciphertext as concatenation of the
  85. * compressed and serialized vector of polynomials b
  86. * and the compressed and serialized polynomial v.
  87. * The polynomial coefficients in b and v are assumed to
  88. * lie in the invertal [0,q], i.e. b and v must be reduced
  89. * by polyvec_reduce_avx2() and poly_reduce_avx2(), respectively.
  90. *
  91. * Arguments: uint8_t *r: pointer to the output serialized ciphertext
  92. * poly *pk: pointer to the input vector of polynomials b
  93. * poly *v: pointer to the input polynomial v
  94. **************************************************/
  95. static void pack_ciphertext(uint8_t r[S2N_KYBER_512_R3_INDCPA_BYTES], polyvec *b, poly *v)
  96. {
  97. polyvec_compress_avx2(r, b);
  98. poly_compress_avx2(r+S2N_KYBER_512_R3_POLYVECCOMPRESSEDBYTES, v);
  99. }
  100. /*************************************************
  101. * Name: unpack_ciphertext
  102. *
  103. * Description: De-serialize and decompress ciphertext from a byte array;
  104. * approximate inverse of pack_ciphertext
  105. *
  106. * Arguments: - polyvec *b: pointer to the output vector of polynomials b
  107. * - poly *v: pointer to the output polynomial v
  108. * - const uint8_t *c: pointer to the input serialized ciphertext
  109. **************************************************/
  110. static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[S2N_KYBER_512_R3_INDCPA_BYTES])
  111. {
  112. polyvec_decompress_avx2(b, c);
  113. poly_decompress_avx2(v, c+S2N_KYBER_512_R3_POLYVECCOMPRESSEDBYTES);
  114. }
  115. /*************************************************
  116. * Name: rej_uniform
  117. *
  118. * Description: Run rejection sampling on uniform random bytes to generate
  119. * uniform random integers mod q
  120. *
  121. * Arguments: - int16_t *r: pointer to output array
  122. * - unsigned int len: requested number of 16-bit integers (uniform mod q)
  123. * - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes)
  124. * - unsigned int buflen: length of input buffer in bytes
  125. *
  126. * Returns number of sampled 16-bit integers (at most len)
  127. **************************************************/
  128. static unsigned int rej_uniform(int16_t *r,
  129. unsigned int len,
  130. const uint8_t *buf,
  131. unsigned int buflen)
  132. {
  133. unsigned int ctr, pos;
  134. uint16_t val0, val1;
  135. ctr = pos = 0;
  136. while(ctr < len && pos <= buflen - 3) { // buflen is always at least 3
  137. val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF;
  138. val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF;
  139. pos += 3;
  140. if(val0 < S2N_KYBER_512_R3_Q)
  141. r[ctr++] = val0;
  142. if(ctr < len && val1 < S2N_KYBER_512_R3_Q)
  143. r[ctr++] = val1;
  144. }
  145. return ctr;
  146. }
  147. #define gen_a(A,B) gen_matrix_avx2(A,B,0)
  148. #define gen_at(A,B) gen_matrix_avx2(A,B,1)
  149. /*************************************************
  150. * Name: gen_matrix_avx2
  151. *
  152. * Description: Deterministically generate matrix A (or the transpose of A)
  153. * from a seed. Entries of the matrix are polynomials that look
  154. * uniformly random. Performs rejection sampling on output of
  155. * a XOF
  156. *
  157. * Arguments: - polyvec *a: pointer to ouptput matrix A
  158. * - const uint8_t *seed: pointer to input seed
  159. * - int transposed: boolean deciding whether A or A^T is generated
  160. **************************************************/
  161. void gen_matrix_avx2(polyvec *a, const uint8_t seed[32], int transposed)
  162. {
  163. unsigned int ctr0, ctr1, ctr2, ctr3;
  164. ALIGNED_UINT8(S2N_KYBER_512_R3_REJ_UNIFORM_AVX_NBLOCKS*S2N_KYBER_512_R3_SHAKE128_RATE) buf[4];
  165. __m256i f;
  166. keccakx4_state state;
  167. // correcting cast-align and cast-qual errors
  168. // old version: f = _mm256_loadu_si256((__m256i *)seed);
  169. f = _mm256_loadu_si256((const void *)seed);
  170. _mm256_store_si256(buf[0].vec, f);
  171. _mm256_store_si256(buf[1].vec, f);
  172. _mm256_store_si256(buf[2].vec, f);
  173. _mm256_store_si256(buf[3].vec, f);
  174. if(transposed) {
  175. buf[0].coeffs[32] = 0;
  176. buf[0].coeffs[33] = 0;
  177. buf[1].coeffs[32] = 0;
  178. buf[1].coeffs[33] = 1;
  179. buf[2].coeffs[32] = 1;
  180. buf[2].coeffs[33] = 0;
  181. buf[3].coeffs[32] = 1;
  182. buf[3].coeffs[33] = 1;
  183. }
  184. else {
  185. buf[0].coeffs[32] = 0;
  186. buf[0].coeffs[33] = 0;
  187. buf[1].coeffs[32] = 1;
  188. buf[1].coeffs[33] = 0;
  189. buf[2].coeffs[32] = 0;
  190. buf[2].coeffs[33] = 1;
  191. buf[3].coeffs[32] = 1;
  192. buf[3].coeffs[33] = 1;
  193. }
  194. shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34);
  195. shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, S2N_KYBER_512_R3_REJ_UNIFORM_AVX_NBLOCKS, &state);
  196. ctr0 = rej_uniform_avx2(a[0].vec[0].coeffs, buf[0].coeffs);
  197. ctr1 = rej_uniform_avx2(a[0].vec[1].coeffs, buf[1].coeffs);
  198. ctr2 = rej_uniform_avx2(a[1].vec[0].coeffs, buf[2].coeffs);
  199. ctr3 = rej_uniform_avx2(a[1].vec[1].coeffs, buf[3].coeffs);
  200. while(ctr0 < S2N_KYBER_512_R3_N || ctr1 < S2N_KYBER_512_R3_N || ctr2 < S2N_KYBER_512_R3_N || ctr3 < S2N_KYBER_512_R3_N) {
  201. shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state);
  202. ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, S2N_KYBER_512_R3_N - ctr0, buf[0].coeffs, S2N_KYBER_512_R3_SHAKE128_RATE);
  203. ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, S2N_KYBER_512_R3_N - ctr1, buf[1].coeffs, S2N_KYBER_512_R3_SHAKE128_RATE);
  204. ctr2 += rej_uniform(a[1].vec[0].coeffs + ctr2, S2N_KYBER_512_R3_N - ctr2, buf[2].coeffs, S2N_KYBER_512_R3_SHAKE128_RATE);
  205. ctr3 += rej_uniform(a[1].vec[1].coeffs + ctr3, S2N_KYBER_512_R3_N - ctr3, buf[3].coeffs, S2N_KYBER_512_R3_SHAKE128_RATE);
  206. }
  207. poly_nttunpack_avx2(&a[0].vec[0]);
  208. poly_nttunpack_avx2(&a[0].vec[1]);
  209. poly_nttunpack_avx2(&a[1].vec[0]);
  210. poly_nttunpack_avx2(&a[1].vec[1]);
  211. }
  212. /*************************************************
  213. * Name: indcpa_keypair_avx2
  214. *
  215. * Description: Generates public and private key for the CPA-secure
  216. * public-key encryption scheme underlying Kyber
  217. *
  218. * Arguments: - uint8_t *pk: pointer to output public key
  219. * (of length S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES bytes)
  220. * - uint8_t *sk: pointer to output private key
  221. (of length S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES bytes)
  222. **************************************************/
  223. int indcpa_keypair_avx2(uint8_t pk[S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES],
  224. uint8_t sk[S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES])
  225. {
  226. unsigned int i;
  227. uint8_t buf[2*S2N_KYBER_512_R3_SYMBYTES];
  228. const uint8_t *publicseed = buf;
  229. const uint8_t *noiseseed = buf + S2N_KYBER_512_R3_SYMBYTES;
  230. polyvec a[S2N_KYBER_512_R3_K], e, pkpv, skpv;
  231. POSIX_GUARD_RESULT(s2n_get_random_bytes(buf, S2N_KYBER_512_R3_SYMBYTES));
  232. sha3_512(buf, buf, S2N_KYBER_512_R3_SYMBYTES);
  233. gen_a(a, publicseed);
  234. poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, e.vec+0, e.vec+1, noiseseed, 0, 1, 2, 3);
  235. polyvec_ntt_avx2(&skpv);
  236. polyvec_reduce_avx2(&skpv);
  237. polyvec_ntt_avx2(&e);
  238. // matrix-vector multiplication
  239. for(i=0;i<S2N_KYBER_512_R3_K;i++) {
  240. polyvec_basemul_acc_montgomery_avx2(&pkpv.vec[i], &a[i], &skpv);
  241. poly_tomont_avx2(&pkpv.vec[i]);
  242. }
  243. polyvec_add_avx2(&pkpv, &pkpv, &e);
  244. polyvec_reduce_avx2(&pkpv);
  245. pack_sk(sk, &skpv);
  246. pack_pk(pk, &pkpv, publicseed);
  247. return 0;
  248. }
  249. /*************************************************
  250. * Name: indcpa_enc_avx2
  251. *
  252. * Description: Encryption function of the CPA-secure
  253. * public-key encryption scheme underlying Kyber.
  254. *
  255. * Arguments: - uint8_t *c: pointer to output ciphertext
  256. * (of length S2N_KYBER_512_R3_INDCPA_BYTES bytes)
  257. * - const uint8_t *m: pointer to input message
  258. * (of length S2N_KYBER_512_R3_INDCPA_MSGBYTES bytes)
  259. * - const uint8_t *pk: pointer to input public key
  260. * (of length S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES)
  261. * - const uint8_t *coins: pointer to input random coins used as seed
  262. * (of length S2N_KYBER_512_R3_SYMBYTES) to deterministically
  263. * generate all randomness
  264. **************************************************/
  265. void indcpa_enc_avx2(uint8_t c[S2N_KYBER_512_R3_INDCPA_BYTES],
  266. const uint8_t m[S2N_KYBER_512_R3_INDCPA_MSGBYTES],
  267. const uint8_t pk[S2N_KYBER_512_R3_INDCPA_PUBLICKEYBYTES],
  268. const uint8_t coins[S2N_KYBER_512_R3_SYMBYTES])
  269. {
  270. unsigned int i;
  271. uint8_t seed[S2N_KYBER_512_R3_SYMBYTES];
  272. polyvec sp, pkpv, ep, at[S2N_KYBER_512_R3_K], b;
  273. poly v, k, epp;
  274. unpack_pk(&pkpv, seed, pk);
  275. poly_frommsg_avx2(&k, m);
  276. gen_at(at, seed);
  277. poly_getnoise_eta1122_4x(sp.vec+0, sp.vec+1, ep.vec+0, ep.vec+1, coins, 0, 1, 2, 3);
  278. poly_getnoise_eta2_avx2(&epp, coins, 4);
  279. polyvec_ntt_avx2(&sp);
  280. // matrix-vector multiplication
  281. for(i=0;i<S2N_KYBER_512_R3_K;i++)
  282. polyvec_basemul_acc_montgomery_avx2(&b.vec[i], &at[i], &sp);
  283. polyvec_basemul_acc_montgomery_avx2(&v, &pkpv, &sp);
  284. polyvec_invntt_tomont_avx2(&b);
  285. poly_invntt_tomont_avx2(&v);
  286. polyvec_add_avx2(&b, &b, &ep);
  287. poly_add_avx2(&v, &v, &epp);
  288. poly_add_avx2(&v, &v, &k);
  289. polyvec_reduce_avx2(&b);
  290. poly_reduce_avx2(&v);
  291. pack_ciphertext(c, &b, &v);
  292. }
  293. /*************************************************
  294. * Name: indcpa_dec_avx2
  295. *
  296. * Description: Decryption function of the CPA-secure
  297. * public-key encryption scheme underlying Kyber.
  298. *
  299. * Arguments: - uint8_t *m: pointer to output decrypted message
  300. * (of length S2N_KYBER_512_R3_INDCPA_MSGBYTES)
  301. * - const uint8_t *c: pointer to input ciphertext
  302. * (of length S2N_KYBER_512_R3_INDCPA_BYTES)
  303. * - const uint8_t *sk: pointer to input secret key
  304. * (of length S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES)
  305. **************************************************/
  306. void indcpa_dec_avx2(uint8_t m[S2N_KYBER_512_R3_INDCPA_MSGBYTES],
  307. const uint8_t c[S2N_KYBER_512_R3_INDCPA_BYTES],
  308. const uint8_t sk[S2N_KYBER_512_R3_INDCPA_SECRETKEYBYTES])
  309. {
  310. polyvec b, skpv;
  311. poly v, mp;
  312. unpack_ciphertext(&b, &v, c);
  313. unpack_sk(&skpv, sk);
  314. polyvec_ntt_avx2(&b);
  315. polyvec_basemul_acc_montgomery_avx2(&mp, &skpv, &b);
  316. poly_invntt_tomont_avx2(&mp);
  317. poly_sub_avx2(&mp, &v, &mp);
  318. poly_reduce_avx2(&mp);
  319. poly_tomsg_avx2(m, &mp);
  320. }
  321. #endif