// lrc.h — Locally Reconstructable Codes (LRC) codec base.
  1. #pragma once
  2. #include "helpers.h"
  3. #include <library/cpp/sse/sse.h>
  4. #include <library/cpp/yt/assert/assert.h>
  5. #include <util/generic/array_ref.h>
  6. #include <algorithm>
  7. #include <optional>
  8. namespace NErasure {
  9. template <class TCodecTraits, class TBlobType = typename TCodecTraits::TBlobType>
  10. static inline TBlobType Xor(const std::vector<TBlobType>& refs) {
  11. using TBufferType = typename TCodecTraits::TBufferType;
  12. size_t size = refs.front().Size();
  13. TBufferType result = TCodecTraits::AllocateBuffer(size); // this also fills the buffer with zeros
  14. for (const TBlobType& ref : refs) {
  15. const char* data = reinterpret_cast<const char*>(ref.Begin());
  16. size_t pos = 0;
  17. #ifdef ARCADIA_SSE
  18. for (; pos + sizeof(__m128i) <= size; pos += sizeof(__m128i)) {
  19. __m128i* dst = reinterpret_cast<__m128i*>(result.Begin() + pos);
  20. const __m128i* src = reinterpret_cast<const __m128i*>(data + pos);
  21. _mm_storeu_si128(dst, _mm_xor_si128(_mm_loadu_si128(src), _mm_loadu_si128(dst)));
  22. }
  23. #endif
  24. for (; pos < size; ++pos) {
  25. *(result.Begin() + pos) ^= data[pos];
  26. }
  27. }
  28. return TCodecTraits::FromBufferToBlob(std::move(result));
  29. }
//! Locally Reconstructable Codes
/*!
 *  See https://www.usenix.org/conference/usenixfederatedconferencesweek/erasure-coding-windows-azure-storage
 *  for more details.
 *
 *  Part layout (total = DataPartCount + ParityPartCount):
 *    [0, DataPartCount)                          - data parts, split into two halves (groups);
 *    DataPartCount, DataPartCount + 1            - xor parities of the first and second group;
 *    DataPartCount + 2, DataPartCount + 3        - global parities.
 */
template <int DataPartCount, int ParityPartCount, int WordSize, class TCodecTraits>
class TLrcCodecBase
    : public ICodec<typename TCodecTraits::TBlobType>
{
    static_assert(DataPartCount % 2 == 0, "Data part count must be even.");
    static_assert(ParityPartCount == 4, "Now we only support n-2-2 scheme for LRC codec");
    static_assert(1 + DataPartCount / 2 < (1 << (WordSize / 2)), "Data part count should be enough small to construct proper matrix.");

public:
    //! Main blob for storing data.
    using TBlobType = typename TCodecTraits::TBlobType;
    //! Main mutable blob for decoding data.
    using TMutableBlobType = typename TCodecTraits::TMutableBlobType;

    static constexpr ui64 RequiredDataAlignment = alignof(ui64);

    //! Builds the two xor groups (each half of the data parts plus its xor parity)
    //! and, when the total part count is small, precomputes CanRepair_ for every
    //! possible erasure bitmask.
    TLrcCodecBase() {
        Groups_[0] = MakeSegment(0, DataPartCount / 2);
        // Xor.
        Groups_[0].push_back(DataPartCount);

        Groups_[1] = MakeSegment(DataPartCount / 2, DataPartCount);
        // Xor.
        Groups_[1].push_back(DataPartCount + 1);

        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            // Precompute the answer for all 2^totalPartCount masks.
            // NB: bit i of the mask is SET when part i is present (not erased).
            CanRepair_.resize(1 << totalPartCount);
            for (int mask = 0; mask < (1 << totalPartCount); ++mask) {
                TPartIndexList erasedIndices;
                for (size_t i = 0; i < totalPartCount; ++i) {
                    if ((mask & (1 << i)) == 0) {
                        erasedIndices.push_back(i);
                    }
                }
                CanRepair_[mask] = CalculateCanRepair(erasedIndices);
            }
        }
    }

    /*! Note that if you want to restore any internal data, blocks offsets must by WordSize * sizeof(long) aligned.
     * Though it is possible to restore unaligned data if no more than one index in each Group is failed. See unittests for this case.
     *
     * \param blocks Surviving blocks, ordered so that blocks[i] corresponds to
     *        part GetRepairIndices(erasedIndices)[i]; all must have equal sizes.
     * \param erasedIndices Indices of the erased parts.
     * \returns The recovered blocks, one per erased index (empty when nothing is erased).
     */
    std::vector<TBlobType> Decode(
        const std::vector<TBlobType>& blocks,
        const TPartIndexList& erasedIndices) const override
    {
        if (erasedIndices.empty()) {
            return std::vector<TBlobType>();
        }

        size_t blockLength = blocks.front().Size();
        for (size_t i = 1; i < blocks.size(); ++i) {
            YT_VERIFY(blocks[i].Size() == blockLength);
        }

        TPartIndexList indices = UniqueSortedIndices(erasedIndices);

        // We can restore one block by xor.
        if (indices.size() == 1) {
            int index = erasedIndices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    // The provided blocks are exactly the rest of the group:
                    // xor of them reconstructs the missing part.
                    return std::vector<TBlobType>(1, Xor<TCodecTraits>(blocks));
                }
            }
        }

        TPartIndexList recoveryIndices = GetRepairIndices(indices).value();
        // We can restore two blocks from different groups using xor.
        if (indices.size() == 2 &&
            indices.back() < DataPartCount + 2 &&
            recoveryIndices.back() < DataPartCount + 2)
        {
            std::vector<TBlobType> result;
            for (int index : indices) {
                for (size_t groupIndex = 0; groupIndex < 2; ++groupIndex) {
                    if (!Contains(Groups_[groupIndex], index)) {
                        continue;
                    }
                    // Collect the supplied blocks belonging to this group
                    // (blocks[i] corresponds to part recoveryIndices[i]).
                    std::vector<TBlobType> correspondingBlocks;
                    for (int pos : Groups_[groupIndex]) {
                        for (size_t i = 0; i < blocks.size(); ++i) {
                            if (recoveryIndices[i] != pos) {
                                continue;
                            }
                            correspondingBlocks.push_back(blocks[i]);
                        }
                    }

                    result.push_back(Xor<TCodecTraits>(correspondingBlocks));
                }
            }
            return result;
        }

        // General case: delegate to the concrete codec implementation.
        return FallbackToCodecDecode(blocks, std::move(indices));
    }

    //! Returns whether the given set of erased parts is repairable.
    //! Uses the precomputed bitmask table when it was built in the constructor.
    bool CanRepair(const TPartIndexList& erasedIndices) const final {
        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            // Convert the erased-index list to a present-parts bitmask
            // (the table is indexed by which parts are still available).
            int mask = (1 << (totalPartCount)) - 1;
            for (int index : erasedIndices) {
                mask -= (1 << index);
            }
            return CanRepair_[mask];
        } else {
            return CalculateCanRepair(erasedIndices);
        }
    }

    //! Same as above, but takes a bitset with erased indices set.
    bool CanRepair(const TPartIndexSet& erasedIndicesMask) const final {
        constexpr int totalPartCount = DataPartCount + ParityPartCount;
        if constexpr (totalPartCount <= BitmaskOptimizationThreshold) {
            // flip() turns the erased-mask into a present-mask for table lookup.
            // NOTE(review): flip() inverts ALL bits of TPartIndexSet; this indexing
            // presumably relies on the bitset width matching totalPartCount here —
            // verify, since wider sets would overflow the CanRepair_ table.
            TPartIndexSet mask = erasedIndicesMask;
            return CanRepair_[mask.flip().to_ulong()];
        } else {
            TPartIndexList erasedIndices;
            for (size_t i = 0; i < erasedIndicesMask.size(); ++i) {
                if (erasedIndicesMask[i]) {
                    erasedIndices.push_back(i);
                }
            }
            return CalculateCanRepair(erasedIndices);
        }
    }

    //! Returns the (sorted) list of part indices needed to repair the given
    //! erasures, or std::nullopt when repair is impossible.
    std::optional<TPartIndexList> GetRepairIndices(const TPartIndexList& erasedIndices) const final {
        if (erasedIndices.empty()) {
            return TPartIndexList();
        }

        TPartIndexList indices = UniqueSortedIndices(erasedIndices);

        // More erasures than parities can never be repaired.
        if (indices.size() > ParityPartCount) {
            return std::nullopt;
        }

        // One erasure from data or xor blocks.
        if (indices.size() == 1) {
            int index = indices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    // The rest of the group suffices.
                    return Difference(Groups_[i], index);
                }
            }
        }

        // Null if we have 4 erasures in one group.
        if (indices.size() == ParityPartCount) {
            bool intersectsAny = true;
            for (size_t i = 0; i < 2; ++i) {
                if (Intersection(indices, Groups_[i]).empty()) {
                    intersectsAny = false;
                }
            }
            if (!intersectsAny) {
                return std::nullopt;
            }
        }

        // Calculate coverage of each group.
        int groupCoverage[2] = {};
        for (int index : indices) {
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    ++groupCoverage[i];
                }
            }
        }

        // Two erasures, one in each group.
        if (indices.size() == 2 && groupCoverage[0] == 1 && groupCoverage[1] == 1) {
            // Each group repairs its own erasure by xor.
            return Difference(Union(Groups_[0], Groups_[1]), indices);
        }

        // Erasures in only parity blocks.
        if (indices.front() >= DataPartCount) {
            // All data parts suffice to recompute any parity.
            return MakeSegment(0, DataPartCount);
        }

        // Remove unnecessary xor parities.
        TPartIndexList result = Difference(0, DataPartCount + ParityPartCount, indices);
        for (size_t i = 0; i < 2; ++i) {
            if (groupCoverage[i] == 0 && indices.size() <= 3) {
                result = Difference(result, DataPartCount + i);
            }
        }

        return result;
    }

    int GetDataPartCount() const override {
        return DataPartCount;
    }

    int GetParityPartCount() const override {
        return ParityPartCount;
    }

    //! Any ParityPartCount - 1 erasures are always repairable; ParityPartCount
    //! erasures are repairable only for some configurations (see CalculateCanRepair).
    int GetGuaranteedRepairablePartCount() const override {
        return ParityPartCount - 1;
    }

    int GetWordSize() const override {
        return WordSize * sizeof(long);
    }

    virtual ~TLrcCodecBase() = default;

protected:
    // Indices of data blocks and corresponding xor (we have two xor parities).
    TConstArrayRef<TPartIndexList> GetXorGroups() const {
        return Groups_;
    }

    //! Implemented by the concrete codec; called by Decode when the erasures
    //! cannot be handled by plain xor within the groups.
    virtual std::vector<TBlobType> FallbackToCodecDecode(
        const std::vector<TBlobType>& /* blocks */,
        TPartIndexList /* erasedIndices */) const = 0;

    //! Fills the ParityPartCount x DataPartCount generator matrix (row-major).
    /*!
     *  \param generatorMatrix Destination buffer of ParityPartCount * DataPartCount elements.
     *  \param GFSquare Squaring function in the underlying Galois field.
     */
    template <typename T>
    void InitializeGeneratorMatrix(T* generatorMatrix, const std::function<T(T)>& GFSquare) {
        for (int row = 0; row < ParityPartCount; ++row) {
            for (int column = 0; column < DataPartCount; ++column) {
                int index = row * DataPartCount + column;

                bool isFirstHalf = column < DataPartCount / 2;

                // Rows 0 and 1 are the xor parities of the two halves.
                if (row == 0) generatorMatrix[index] = isFirstHalf ? 1 : 0;
                if (row == 1) generatorMatrix[index] = isFirstHalf ? 0 : 1;

                // Let alpha_i be coefficient of first half and beta_i of the second half.
                // Then matrix is non-singular iff:
                // a) alpha_i, beta_j != 0
                // b) alpha_i != beta_j
                // c) alpha_i + alpha_k != beta_j + beta_l
                // for any i, j, k, l.
                if (row == 2) {
                    int shift = isFirstHalf ? 1 : (1 << (WordSize / 2));
                    int relativeColumn = isFirstHalf ? column : (column - (DataPartCount / 2));
                    generatorMatrix[index] = shift * (1 + relativeColumn);
                }

                // The last row is the square of the row before last.
                if (row == 3) {
                    auto prev = generatorMatrix[index - DataPartCount];
                    generatorMatrix[index] = GFSquare(prev);
                }
            }
        }
    }

private:
    //! Directly computes repairability of the given erasures (no lookup table).
    bool CalculateCanRepair(const TPartIndexList& erasedIndices) const {
        TPartIndexList indices = UniqueSortedIndices(erasedIndices);
        if (indices.size() > ParityPartCount) {
            return false;
        }

        // A single erasure inside a group is always repairable by xor.
        if (indices.size() == 1) {
            int index = indices.front();
            for (size_t i = 0; i < 2; ++i) {
                if (Contains(Groups_[i], index)) {
                    return true;
                }
            }
        }

        // If 4 indices miss in one block we cannot recover.
        if (indices.size() == ParityPartCount) {
            for (size_t i = 0; i < 2; ++i) {
                if (Intersection(indices, Groups_[i]).empty()) {
                    return false;
                }
            }
        }

        return true;
    }

    // The two xor groups: each holds the indices of one half of the data parts
    // plus the index of that half's xor parity.
    TPartIndexList Groups_[2];
    // Lookup table indexed by present-parts bitmask; empty when
    // totalPartCount > BitmaskOptimizationThreshold.
    std::vector<bool> CanRepair_;
};
  278. } // namespace NErasure