kyber512r3_fips202x4_avx2.c 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210
  1. #include <stddef.h>
  2. #include <stdint.h>
  3. #include <string.h>
  4. #include "kyber512r3_fips202.h"
  5. #include "kyber512r3_fips202x4_avx2.h"
  6. #if defined(S2N_KYBER512R3_AVX2_BMI2)
  7. #include <immintrin.h>
/* Local alias for the 4-way Keccak permutation used by the absorb/squeeze
 * routines in this file. The macro maps KeccakF1600_StatePermute4x onto the
 * namespaced symbol for KeccakP1600times4_PermuteAll_24rounds; the function
 * itself is defined outside this file (hence the extern declaration).
 * It updates the interleaved states pointed to by s in place (callers here
 * pass a 25-element __m256i array, presumably one 64-bit lane per state —
 * confirm against the implementation). */
#define KeccakF1600_StatePermute4x S2N_KYBER_512_R3_NAMESPACE(KeccakP1600times4_PermuteAll_24rounds)
extern void KeccakF1600_StatePermute4x(__m256i *s);
  10. /* Implementation is used from Crystal Kyber Repository
  11. * See for more details: https://github.com/XKCP/XKCP */
  12. static void keccakx4_absorb_once(__m256i s[25],
  13. unsigned int r,
  14. const uint8_t *in0,
  15. const uint8_t *in1,
  16. const uint8_t *in2,
  17. const uint8_t *in3,
  18. size_t inlen,
  19. uint8_t p)
  20. {
  21. size_t i;
  22. uint64_t pos = 0;
  23. __m256i t, idx;
  24. for(i = 0; i < 25; ++i)
  25. s[i] = _mm256_setzero_si256();
  26. idx = _mm256_set_epi64x((long long)in3, (long long)in2, (long long)in1, (long long)in0);
  27. while(inlen >= r) {
  28. for(i = 0; i < r/8; ++i) {
  29. t = _mm256_i64gather_epi64((long long *)pos, idx, 1);
  30. s[i] = _mm256_xor_si256(s[i], t);
  31. pos += 8;
  32. }
  33. inlen -= r;
  34. KeccakF1600_StatePermute4x(s);
  35. }
  36. for(i = 0; i < inlen/8; ++i) {
  37. t = _mm256_i64gather_epi64((long long *)pos, idx, 1);
  38. s[i] = _mm256_xor_si256(s[i], t);
  39. pos += 8;
  40. }
  41. inlen -= 8*i;
  42. if(inlen) {
  43. t = _mm256_i64gather_epi64((long long *)pos, idx, 1);
  44. idx = _mm256_set1_epi64x((1ULL << (8*inlen)) - 1);
  45. t = _mm256_and_si256(t, idx);
  46. s[i] = _mm256_xor_si256(s[i], t);
  47. }
  48. t = _mm256_set1_epi64x((uint64_t)p << 8*inlen);
  49. s[i] = _mm256_xor_si256(s[i], t);
  50. t = _mm256_set1_epi64x(1ULL << 63);
  51. s[r/8 - 1] = _mm256_xor_si256(s[r/8 - 1], t);
  52. }
  53. static void keccakx4_squeezeblocks(uint8_t *out0,
  54. uint8_t *out1,
  55. uint8_t *out2,
  56. uint8_t *out3,
  57. size_t nblocks,
  58. unsigned int r,
  59. __m256i s[25])
  60. {
  61. unsigned int i;
  62. __m128d t;
  63. while(nblocks > 0) {
  64. KeccakF1600_StatePermute4x(s);
  65. for(i=0; i < r/8; ++i) {
  66. t = _mm_castsi128_pd(_mm256_castsi256_si128(s[i]));
  67. // correcting cast-align errors
  68. // old version: _mm_storel_pd((__attribute__((__may_alias__)) double *)&out0[8*i], t);
  69. _mm_storel_pd((__attribute__((__may_alias__)) void *)&out0[8*i], t);
  70. // old version: _mm_storeh_pd((__attribute__((__may_alias__)) double *)&out1[8*i], t);
  71. _mm_storeh_pd((__attribute__((__may_alias__)) void *)&out1[8*i], t);
  72. t = _mm_castsi128_pd(_mm256_extracti128_si256(s[i],1));
  73. // old version: _mm_storel_pd((__attribute__((__may_alias__)) double *)&out2[8*i], t);
  74. _mm_storel_pd((__attribute__((__may_alias__)) void *)&out2[8*i], t);
  75. // old version: _mm_storeh_pd((__attribute__((__may_alias__)) double *)&out3[8*i], t);
  76. _mm_storeh_pd((__attribute__((__may_alias__)) void *)&out3[8*i], t);
  77. }
  78. out0 += r;
  79. out1 += r;
  80. out2 += r;
  81. out3 += r;
  82. --nblocks;
  83. }
  84. }
  85. void shake128x4_absorb_once(keccakx4_state *state,
  86. const uint8_t *in0,
  87. const uint8_t *in1,
  88. const uint8_t *in2,
  89. const uint8_t *in3,
  90. size_t inlen)
  91. {
  92. keccakx4_absorb_once(state->s, S2N_KYBER_512_R3_SHAKE128_RATE, in0, in1, in2, in3, inlen, 0x1F);
  93. }
  94. void shake128x4_squeezeblocks(uint8_t *out0,
  95. uint8_t *out1,
  96. uint8_t *out2,
  97. uint8_t *out3,
  98. size_t nblocks,
  99. keccakx4_state *state)
  100. {
  101. keccakx4_squeezeblocks(out0, out1, out2, out3, nblocks, S2N_KYBER_512_R3_SHAKE128_RATE, state->s);
  102. }
  103. void shake256x4_absorb_once(keccakx4_state *state,
  104. const uint8_t *in0,
  105. const uint8_t *in1,
  106. const uint8_t *in2,
  107. const uint8_t *in3,
  108. size_t inlen)
  109. {
  110. keccakx4_absorb_once(state->s, S2N_KYBER_512_R3_SHAKE256_RATE, in0, in1, in2, in3, inlen, 0x1F);
  111. }
  112. void shake256x4_squeezeblocks(uint8_t *out0,
  113. uint8_t *out1,
  114. uint8_t *out2,
  115. uint8_t *out3,
  116. size_t nblocks,
  117. keccakx4_state *state)
  118. {
  119. keccakx4_squeezeblocks(out0, out1, out2, out3, nblocks, S2N_KYBER_512_R3_SHAKE256_RATE, state->s);
  120. }
  121. void shake128x4(uint8_t *out0,
  122. uint8_t *out1,
  123. uint8_t *out2,
  124. uint8_t *out3,
  125. size_t outlen,
  126. const uint8_t *in0,
  127. const uint8_t *in1,
  128. const uint8_t *in2,
  129. const uint8_t *in3,
  130. size_t inlen)
  131. {
  132. unsigned int i;
  133. size_t nblocks = outlen/S2N_KYBER_512_R3_SHAKE128_RATE;
  134. uint8_t t[4][S2N_KYBER_512_R3_SHAKE128_RATE];
  135. keccakx4_state state;
  136. shake128x4_absorb_once(&state, in0, in1, in2, in3, inlen);
  137. shake128x4_squeezeblocks(out0, out1, out2, out3, nblocks, &state);
  138. out0 += nblocks*S2N_KYBER_512_R3_SHAKE128_RATE;
  139. out1 += nblocks*S2N_KYBER_512_R3_SHAKE128_RATE;
  140. out2 += nblocks*S2N_KYBER_512_R3_SHAKE128_RATE;
  141. out3 += nblocks*S2N_KYBER_512_R3_SHAKE128_RATE;
  142. outlen -= nblocks*S2N_KYBER_512_R3_SHAKE128_RATE;
  143. if(outlen) {
  144. shake128x4_squeezeblocks(t[0], t[1], t[2], t[3], 1, &state);
  145. for(i = 0; i < outlen; ++i) {
  146. out0[i] = t[0][i];
  147. out1[i] = t[1][i];
  148. out2[i] = t[2][i];
  149. out3[i] = t[3][i];
  150. }
  151. }
  152. }
  153. void shake256x4(uint8_t *out0,
  154. uint8_t *out1,
  155. uint8_t *out2,
  156. uint8_t *out3,
  157. size_t outlen,
  158. const uint8_t *in0,
  159. const uint8_t *in1,
  160. const uint8_t *in2,
  161. const uint8_t *in3,
  162. size_t inlen)
  163. {
  164. unsigned int i;
  165. size_t nblocks = outlen/S2N_KYBER_512_R3_SHAKE256_RATE;
  166. uint8_t t[4][S2N_KYBER_512_R3_SHAKE256_RATE];
  167. keccakx4_state state;
  168. shake256x4_absorb_once(&state, in0, in1, in2, in3, inlen);
  169. shake256x4_squeezeblocks(out0, out1, out2, out3, nblocks, &state);
  170. out0 += nblocks*S2N_KYBER_512_R3_SHAKE256_RATE;
  171. out1 += nblocks*S2N_KYBER_512_R3_SHAKE256_RATE;
  172. out2 += nblocks*S2N_KYBER_512_R3_SHAKE256_RATE;
  173. out3 += nblocks*S2N_KYBER_512_R3_SHAKE256_RATE;
  174. outlen -= nblocks*S2N_KYBER_512_R3_SHAKE256_RATE;
  175. if(outlen) {
  176. shake256x4_squeezeblocks(t[0], t[1], t[2], t[3], 1, &state);
  177. for(i = 0; i < outlen; ++i) {
  178. out0[i] = t[0][i];
  179. out1[i] = t[1][i];
  180. out2[i] = t[2][i];
  181. out3[i] = t[3][i];
  182. }
  183. }
  184. }
  185. #endif