encoding_avx2.c

/**
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0.
 */

#include <emmintrin.h>
#include <immintrin.h>

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <aws/common/common.h>

/***** Decode logic *****/
/*
 * Translates a range of byte values.
 * For each byte of 'in' that is between lo and hi (inclusive), the result byte is (byte - lo) + offset;
 * every byte outside the range becomes zero.
 */
static inline __m256i translate_range(__m256i in, uint8_t lo, uint8_t hi, uint8_t offset) {
    __m256i lovec = _mm256_set1_epi8(lo);
    __m256i hivec = _mm256_set1_epi8((char)(hi - lo));
    __m256i offsetvec = _mm256_set1_epi8(offset);

    __m256i tmp = _mm256_sub_epi8(in, lovec);
    /*
     * We'll use the unsigned min operator to do our comparison: there is no
     * unsigned byte compare intrinsic, so we compare min(tmp, hi - lo) with tmp instead.
     */
    __m256i mask = _mm256_min_epu8(tmp, hivec);
    /* if mask == tmp, the byte was in range, so keep it */
    mask = _mm256_cmpeq_epi8(mask, tmp);

    tmp = _mm256_add_epi8(tmp, offsetvec);
    tmp = _mm256_and_si256(tmp, mask);

    return tmp;
}

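/*
 * Example: translate_range(in, 'a', 'z', 27) (the lowercase decode entry used
 * below) maps the byte 'b' to 'b' - 'a' + 27 = 28, and any byte outside
 * 'a'..'z' to 0.
 */
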
/*
 * For each 8-bit element of 'in', the result element is 'decode' if the element
 * equals 'match', and zero otherwise.
 */
static inline __m256i translate_exact(__m256i in, uint8_t match, uint8_t decode) {
    __m256i mask = _mm256_cmpeq_epi8(in, _mm256_set1_epi8(match));
    return _mm256_and_si256(mask, _mm256_set1_epi8(decode));
}

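/*
 * Example: translate_exact(in, '+', 63) maps every '+' byte to 63 and every
 * other byte to 0.
 */
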
/*
 * Input: a pointer to a 256-bit vector of base64 characters.
 * The pointed-to vector is replaced by a 256-bit vector of decoded 6-bit values;
 * returns false on decode failure, true on success.
 */
static inline bool decode_vec(__m256i *in) {
    __m256i tmp1, tmp2, tmp3;

    /*
     * Base64 decoding table, see RFC4648
     *
     * Note that we use multiple vector registers to try to allow the CPU to
     * parallelize the merging ORs
     */
    tmp1 = translate_range(*in, 'A', 'Z', 0 + 1);
    tmp2 = translate_range(*in, 'a', 'z', 26 + 1);
    tmp3 = translate_range(*in, '0', '9', 52 + 1);
    tmp1 = _mm256_or_si256(tmp1, translate_exact(*in, '+', 62 + 1));
    tmp2 = _mm256_or_si256(tmp2, translate_exact(*in, '/', 63 + 1));
    tmp3 = _mm256_or_si256(tmp3, _mm256_or_si256(tmp1, tmp2));

    /*
     * We use 0 to mark decode failures, so everything is decoded to one higher
     * than normal. We'll shift this down now.
     */
    *in = _mm256_sub_epi8(tmp3, _mm256_set1_epi8(1));

    /* If any byte of tmp3 is zero, no table entry matched, and we had a decode failure */
    __m256i mask = _mm256_cmpeq_epi8(tmp3, _mm256_set1_epi8(0));
    return _mm256_testz_si256(mask, mask);
}

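/*
 * Example: the characters "TWFu" decode byte-by-byte to the 6-bit values
 * 19, 22, 5, 46; any byte that is not a base64 character makes decode_vec
 * return false.
 */
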
AWS_ALIGNED_TYPEDEF(uint8_t, aligned256[32], 32);

/*
 * Input: a 256-bit vector, interpreted as 32 * 6-bit values
 * Output: a 256-bit vector, the lower 24 bytes of which contain the packed version of the input
 */
static inline __m256i pack_vec(__m256i in) {
    /*
     * Our basic strategy is to split the input vector into three vectors, for each 6-bit component
     * of each 24-bit group, shift the groups into place, then OR the vectors together. Conveniently,
     * we can do this on a (32 bit) dword-by-dword basis.
     *
     * It's important to note that we're interpreting the vector as being little-endian. That is,
     * on entry, we have dwords that look like this:
     *
     * MSB                                     LSB
     * 00DD DDDD 00CC CCCC 00BB BBBB 00AA AAAA
     *
     * And we want to translate to:
     *
     * MSB                                     LSB
     * 0000 0000 AAAA AABB BBBB CCCC CCDD DDDD
     *
     * After which point we can pack these dwords together to produce our final output.
     */
    __m256i maskA = _mm256_set1_epi32(0xFF); /* low bits */
    __m256i maskB = _mm256_set1_epi32(0xFF00);
    __m256i maskC = _mm256_set1_epi32(0xFF0000);
    __m256i maskD = _mm256_set1_epi32((int)0xFF000000);

    __m256i bitsA = _mm256_slli_epi32(_mm256_and_si256(in, maskA), 18);
    __m256i bitsB = _mm256_slli_epi32(_mm256_and_si256(in, maskB), 4);
    __m256i bitsC = _mm256_srli_epi32(_mm256_and_si256(in, maskC), 10);
    __m256i bitsD = _mm256_srli_epi32(_mm256_and_si256(in, maskD), 24);

    __m256i dwords = _mm256_or_si256(_mm256_or_si256(bitsA, bitsB), _mm256_or_si256(bitsC, bitsD));
    /*
     * Now we have a series of dwords with empty MSBs.
     * We need to pack them together (and shift down) with a shuffle operation.
     * Unfortunately the shuffle operation operates independently within each 128-bit lane,
     * so we'll need to do this in two steps: First we compact dwords within each lane, then
     * we do a dword shuffle to compact the two lanes together.
     *
     * 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00  <- byte index (little endian)
     * -- 09 0a 0b -- 06 07 08 -- 03 04 05 -- 00 01 02  <- packed data index held by that byte
     *
     * We also byte-swap each 3-byte fragment: the fragments were constructed as
     * little-endian dword values, but the packed output must appear in memory
     * (big-endian) byte order.
     */
    const aligned256 shufvec_buf = {
        /* clang-format off */
        /* LSB */
        0xFF, 0xFF, 0xFF, 0xFF, /* Zero out the low 4 bytes of the lane */
         2,  1,  0,
         6,  5,  4,
        10,  9,  8,
        14, 13, 12,
        0xFF, 0xFF, 0xFF, 0xFF, /* Zero out the low 4 bytes of the lane */
         2,  1,  0,
         6,  5,  4,
        10,  9,  8,
        14, 13, 12
        /* MSB */
        /* clang-format on */
    };
    __m256i shufvec = _mm256_load_si256((__m256i const *)&shufvec_buf);
    dwords = _mm256_shuffle_epi8(dwords, shufvec);

    /*
     * Now shuffle the 32-bit words:
     * A B C 0 D E F 0 -> 0 0 A B C D E F
     */
    __m256i shuf32 = _mm256_set_epi32(0, 0, 7, 6, 5, 3, 2, 1);
    dwords = _mm256_permutevar8x32_epi32(dwords, shuf32);

    return dwords;
}

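/*
 * Worked example for one dword: the decoded values 19, 22, 5, 46 (from the
 * characters "TWFu") arrive as the little-endian dword 0x2E051613; packing
 * yields (19 << 18) | (22 << 12) | (5 << 6) | 46 = 0x4D616E, and the shuffles
 * above emit those bytes in memory order as 0x4D 0x61 0x6E, i.e. "Man".
 */
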
static inline bool decode(const unsigned char *in, unsigned char *out) {
    __m256i vec = _mm256_loadu_si256((__m256i const *)in);
    if (!decode_vec(&vec)) {
        return false;
    }
    vec = pack_vec(vec);
    /*
     * We do two stores to cover the 24 output bytes: the low 128 bits, then the next 64 bits.
     * Input (memory order): 0 1 2 3 4 5 - - (dwords)
     * Input (little endian): - - 5 4 3 2 1 0
     * Output in memory:
     * [0 1 2 3] [4 5]
     */
    __m128i lo = _mm256_extracti128_si256(vec, 0);
    /*
     * Unfortunately some compilers don't support _mm256_extract_epi64,
     * so we'll just copy right out of the vector as a fallback
     */
#ifdef HAVE_MM256_EXTRACT_EPI64
    uint64_t hi = _mm256_extract_epi64(vec, 2);
    const uint64_t *p_hi = &hi;
#else
    const uint64_t *p_hi = (uint64_t *)&vec + 2;
#endif

    _mm_storeu_si128((__m128i *)out, lo);
    memcpy(out + 16, p_hi, sizeof(*p_hi));

    return true;
}

size_t aws_common_private_base64_decode_sse41(const unsigned char *in, unsigned char *out, size_t len) {
    if (len % 4) {
        return (size_t)-1;
    }

    size_t outlen = 0;
    while (len > 32) {
        if (!decode(in, out)) {
            return (size_t)-1;
        }
        len -= 32;
        in += 32;
        out += 24;
        outlen += 24;
    }

    if (len > 0) {
        unsigned char tmp_in[32];
        unsigned char tmp_out[24];

        memset(tmp_out, 0xEE, sizeof(tmp_out));
        /* We need to ensure the vector contains valid b64 characters */
        memset(tmp_in, 'A', sizeof(tmp_in));
        memcpy(tmp_in, in, len);

        size_t final_out = (3 * len) / 4;

        /* Check for end-of-string padding (up to 2 characters) */
        for (int i = 0; i < 2; i++) {
            if (tmp_in[len - 1] == '=') {
                tmp_in[len - 1] = 'A'; /* replace '=' with a valid character so the decode below doesn't fail */
                len--;
                final_out--;
            }
        }
        if (!decode(tmp_in, tmp_out)) {
            return (size_t)-1;
        }

        /* Check that the discarded trailing bytes decoded to zero (i.e. no stray trailing bits) */
        for (size_t i = final_out; i < sizeof(tmp_out); i++) {
            if (tmp_out[i]) {
                return (size_t)-1;
            }
        }

        memcpy(out, tmp_out, final_out);
        outlen += final_out;
    }

    return outlen;
}

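/*
 * Illustrative caller sketch (not part of the original file); the buffer names
 * are hypothetical. Decoding "TWFu" (4 characters, no padding) yields 3 bytes:
 *
 *   unsigned char decoded[3];
 *   size_t n = aws_common_private_base64_decode_sse41(
 *       (const unsigned char *)"TWFu", decoded, 4);
 *   // n == 3 and decoded holds "Man"; n == (size_t)-1 signals invalid input.
 */
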
/***** Encode logic *****/

static inline __m256i encode_chars(__m256i in) {
    __m256i tmp1, tmp2, tmp3;

    /*
     * Base64 encoding table, see RFC4648
     *
     * As in decode_vec, we split the ORs across registers so they can merge in parallel.
     */
    tmp1 = translate_range(in, 0, 25, 'A');
    tmp2 = translate_range(in, 26, 26 + 25, 'a');
    tmp3 = translate_range(in, 52, 61, '0');
    tmp1 = _mm256_or_si256(tmp1, translate_exact(in, 62, '+'));
    tmp2 = _mm256_or_si256(tmp2, translate_exact(in, 63, '/'));

    return _mm256_or_si256(tmp3, _mm256_or_si256(tmp1, tmp2));
}

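/*
 * Example mappings: 0 -> 'A', 25 -> 'Z', 26 -> 'a', 51 -> 'z', 52 -> '0',
 * 61 -> '9', 62 -> '+', 63 -> '/'. Values outside 0..63 map to 0, but
 * encode_stride only ever passes in 6-bit values.
 */
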
/*
 * Input: a 256-bit vector, interpreted as 24 data bytes (in the low-order positions)
 *        plus 8 bytes of don't-care padding (in the high-order positions)
 * Output: a 256-bit vector of base64 characters
 */
static inline __m256i encode_stride(__m256i vec) {
    /*
     * First, since byte-shuffle operations operate within 128-bit subvectors, swap around the dwords
     * to balance the amount of actual data between 128-bit subvectors.
     * After this we want the LE representation to look like: -- XX XX XX -- XX XX XX
     */
    __m256i shuf32 = _mm256_set_epi32(7, 5, 4, 3, 6, 2, 1, 0);
    vec = _mm256_permutevar8x32_epi32(vec, shuf32);

    /*
     * Next, within each group of 3 bytes, we need to byteswap into little endian form so our bitshifts
     * will work properly. We also shuffle around so that each dword has one 3-byte group, plus one byte
     * (MSB) of zero-padding.
     * Because this is a byte-shuffle, indexes are within each 128-bit subvector.
     *
     * -- -- -- -- 11 10 09 08 07 06 05 04 03 02 01 00
     */
    const aligned256 shufvec_buf = {
        /* clang-format off */
        /* LSB */
         2,  1,  0, 0xFF,
         5,  4,  3, 0xFF,
         8,  7,  6, 0xFF,
        11, 10,  9, 0xFF,
         2,  1,  0, 0xFF,
         5,  4,  3, 0xFF,
         8,  7,  6, 0xFF,
        11, 10,  9, 0xFF
        /* MSB */
        /* clang-format on */
    };
    vec = _mm256_shuffle_epi8(vec, _mm256_load_si256((__m256i const *)&shufvec_buf));

    /*
     * Now shift and mask to split out 6-bit groups.
     * We'll also do a second byteswap to get back into big-endian
     */
    __m256i mask0 = _mm256_set1_epi32(0x3F);
    __m256i mask1 = _mm256_set1_epi32(0x3F << 6);
    __m256i mask2 = _mm256_set1_epi32(0x3F << 12);
    __m256i mask3 = _mm256_set1_epi32(0x3F << 18);

    __m256i digit0 = _mm256_and_si256(mask0, vec);
    __m256i digit1 = _mm256_and_si256(mask1, vec);
    __m256i digit2 = _mm256_and_si256(mask2, vec);
    __m256i digit3 = _mm256_and_si256(mask3, vec);

    /*
     * Because we want to byteswap, the low-order digit0 goes into the
     * high-order byte
     */
    digit0 = _mm256_slli_epi32(digit0, 24);
    digit1 = _mm256_slli_epi32(digit1, 10);
    digit2 = _mm256_srli_epi32(digit2, 4);
    digit3 = _mm256_srli_epi32(digit3, 18);

    vec = _mm256_or_si256(_mm256_or_si256(digit0, digit1), _mm256_or_si256(digit2, digit3));

    /* Finally translate to the base64 character set */
    return encode_chars(vec);
}

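/*
 * Worked example for one 3-byte group: the input bytes "Man" (0x4D 0x61 0x6E)
 * form the 24-bit value 0x4D616E, whose four 6-bit digits are 19, 22, 5, 46;
 * encode_chars maps those to "TWFu".
 */
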
void aws_common_private_base64_encode_sse41(const uint8_t *input, uint8_t *output, size_t inlen) {
    __m256i instride, outstride;

    while (inlen >= 32) {
        /*
         * Where possible, we'll load a full vector at a time and ignore the over-read.
         * However, if we have < 32 bytes left, this would result in a potential read
         * of unreadable pages, so we use bounce buffers below.
         */
        instride = _mm256_loadu_si256((__m256i const *)input);
        outstride = encode_stride(instride);
        _mm256_storeu_si256((__m256i *)output, outstride);

        input += 24;
        output += 32;
        inlen -= 24;
    }
    while (inlen) {
        /*
         * We need to go through a bounce buffer for anything remaining, as we
         * don't want to over-read or over-write the ends of the buffers.
         */
        size_t stridelen = inlen > 24 ? 24 : inlen;
        size_t outlen = ((stridelen + 2) / 3) * 4;

        memset(&instride, 0, sizeof(instride));
        memcpy(&instride, input, stridelen);

        outstride = encode_stride(instride);
        memcpy(output, &outstride, outlen);

        if (inlen < 24) {
            if (inlen % 3 >= 1) {
                /* AA== or AAA= */
                output[outlen - 1] = '=';
            }
            if (inlen % 3 == 1) {
                /* AA== */
                output[outlen - 2] = '=';
            }
            return;
        }

        input += stridelen;
        output += outlen;
        inlen -= stridelen;
    }
}
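
/*
 * Illustrative caller sketch (not part of the original file); the buffer name
 * is hypothetical. Encoding inlen bytes writes 4 * ((inlen + 2) / 3) output
 * characters and no NUL terminator:
 *
 *   uint8_t encoded[4];
 *   aws_common_private_base64_encode_sse41((const uint8_t *)"Man", encoded, 3);
 *   // encoded now holds "TWFu".
 */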