// isa_erasure.h (6.1 KB)
  1. #pragma once
  2. #include "public.h"
  3. #include "helpers.h"
  4. #include <library/cpp/yt/assert/assert.h>
  5. #include <util/generic/array_ref.h>
  6. #include <util/generic/ptr.h>
  7. #include <util/generic/singleton.h>
  8. #include <vector>
  9. extern "C" {
  10. #include <contrib/libs/isa-l/include/erasure_code.h>
  11. }
  12. namespace NErasure {
  //! Reinterprets a blob's const iterator as a writable byte pointer.
  /*!
   *  ISA-L entry points take non-const |unsigned char*| arguments, so the
   *  constness is stripped here; the bytes themselves are not touched.
   */
  template <class TBlobType>
  static inline unsigned char* ConstCast(typename TBlobType::const_iterator blobIter)
  {
      const auto* readOnlyBytes = reinterpret_cast<const unsigned char*>(blobIter);
      return const_cast<unsigned char*>(readOnlyBytes);
  }
  17. template <int DataPartCount, int ParityPartCount, class TCodecTraits, class TBlobType = typename TCodecTraits::TBlobType, class TMutableBlobType = typename TCodecTraits::TMutableBlobType>
  18. std::vector<TBlobType> ISAErasureEncode(
  19. const std::vector<unsigned char>& encodeGFTables,
  20. const std::vector<TBlobType>& dataBlocks)
  21. {
  22. YT_VERIFY(dataBlocks.size() == DataPartCount);
  23. size_t blockLength = dataBlocks.front().Size();
  24. for (size_t i = 1; i < dataBlocks.size(); ++i) {
  25. YT_VERIFY(dataBlocks[i].Size() == blockLength);
  26. }
  27. std::vector<unsigned char*> dataPointers;
  28. for (const auto& block : dataBlocks) {
  29. dataPointers.emplace_back(ConstCast<TBlobType>(block.Begin()));
  30. }
  31. std::vector<TMutableBlobType> parities(ParityPartCount);
  32. std::vector<unsigned char*> parityPointers(ParityPartCount);
  33. for (size_t i = 0; i < ParityPartCount; ++i) {
  34. parities[i] = TCodecTraits::AllocateBlob(blockLength);
  35. parityPointers[i] = ConstCast<TBlobType>(parities[i].Begin());
  36. memset(parityPointers[i], 0, blockLength);
  37. }
  38. ec_encode_data(
  39. blockLength,
  40. DataPartCount,
  41. ParityPartCount,
  42. const_cast<unsigned char*>(encodeGFTables.data()),
  43. dataPointers.data(),
  44. parityPointers.data());
  45. return std::vector<TBlobType>(parities.begin(), parities.end());
  46. }
  47. template <int DataPartCount, int ParityPartCount, class TCodecTraits, class TBlobType = typename TCodecTraits::TBlobType, class TMutableBlobType = typename TCodecTraits::TMutableBlobType>
  48. std::vector<TBlobType> ISAErasureDecode(
  49. const std::vector<TBlobType>& dataBlocks,
  50. const TPartIndexList& erasedIndices,
  51. TConstArrayRef<TPartIndexList> groups,
  52. const std::vector<unsigned char>& fullGeneratorMatrix)
  53. {
  54. YT_VERIFY(dataBlocks.size() >= DataPartCount);
  55. YT_VERIFY(erasedIndices.size() <= ParityPartCount);
  56. size_t blockLength = dataBlocks.front().Size();
  57. for (size_t i = 1; i < dataBlocks.size(); ++i) {
  58. YT_VERIFY(dataBlocks[i].Size() == blockLength);
  59. }
  60. std::vector<unsigned char> partialGeneratorMatrix(DataPartCount * DataPartCount, 0);
  61. std::vector<unsigned char*> recoveryBlocks;
  62. for (size_t i = 0; i < DataPartCount; ++i) {
  63. recoveryBlocks.emplace_back(ConstCast<TBlobType>(dataBlocks[i].Begin()));
  64. }
  65. // Groups check is specific for LRC.
  66. std::vector<int> isGroupHealthy(2, 1);
  67. for (size_t i = 0; i < 2; ++i) {
  68. for (const auto& index : erasedIndices) {
  69. if (!groups.empty() && Contains(groups[0], index)) {
  70. isGroupHealthy[0] = 0;
  71. } else if (!groups.empty() && Contains(groups[1], index)) {
  72. isGroupHealthy[1] = 0;
  73. }
  74. }
  75. }
  76. // When a group is healthy we cannot use its local parity, thus skip it using gap.
  77. size_t gap = 0;
  78. size_t decodeMatrixIndex = 0;
  79. size_t erasedBlockIndex = 0;
  80. while (decodeMatrixIndex < DataPartCount) {
  81. size_t globalIndex = decodeMatrixIndex + erasedBlockIndex + gap;
  82. if (erasedBlockIndex < erasedIndices.size() &&
  83. globalIndex == static_cast<size_t>(erasedIndices[erasedBlockIndex]))
  84. {
  85. ++erasedBlockIndex;
  86. continue;
  87. }
  88. if (!groups.empty() && globalIndex >= DataPartCount && globalIndex < DataPartCount + 2) {
  89. if (Contains(groups[0], globalIndex) && isGroupHealthy[0]) {
  90. ++gap;
  91. continue;
  92. }
  93. if (Contains(groups[1], globalIndex) && isGroupHealthy[1]) {
  94. ++gap;
  95. continue;
  96. }
  97. }
  98. memcpy(&partialGeneratorMatrix[decodeMatrixIndex * DataPartCount], &fullGeneratorMatrix[globalIndex * DataPartCount], DataPartCount);
  99. ++decodeMatrixIndex;
  100. }
  101. std::vector<unsigned char> invertedGeneratorMatrix(DataPartCount * DataPartCount, 0);
  102. int res = gf_invert_matrix(partialGeneratorMatrix.data(), invertedGeneratorMatrix.data(), DataPartCount);
  103. YT_VERIFY(res == 0);
  104. std::vector<unsigned char> decodeMatrix(DataPartCount * (DataPartCount + ParityPartCount), 0);
  105. //! Some magical code from library example.
  106. for (size_t i = 0; i < erasedIndices.size(); ++i) {
  107. if (erasedIndices[i] < DataPartCount) {
  108. memcpy(&decodeMatrix[i * DataPartCount], &invertedGeneratorMatrix[erasedIndices[i] * DataPartCount], DataPartCount);
  109. } else {
  110. for (int k = 0; k < DataPartCount; ++k) {
  111. int val = 0;
  112. for (int j = 0; j < DataPartCount; ++j) {
  113. val ^= gf_mul_erasure(invertedGeneratorMatrix[j * DataPartCount + k], fullGeneratorMatrix[DataPartCount * erasedIndices[i] + j]);
  114. }
  115. decodeMatrix[DataPartCount * i + k] = val;
  116. }
  117. }
  118. }
  119. std::vector<unsigned char> decodeGFTables(DataPartCount * erasedIndices.size() * 32);
  120. ec_init_tables(DataPartCount, erasedIndices.size(), decodeMatrix.data(), decodeGFTables.data());
  121. std::vector<TMutableBlobType> recoveredParts;
  122. std::vector<unsigned char*> recoveredPartsPointers;
  123. for (size_t i = 0; i < erasedIndices.size(); ++i) {
  124. recoveredParts.emplace_back(TCodecTraits::AllocateBlob(blockLength));
  125. recoveredPartsPointers.emplace_back(ConstCast<TBlobType>(recoveredParts.back().Begin()));
  126. memset(recoveredPartsPointers.back(), 0, blockLength);
  127. }
  128. ec_encode_data(
  129. blockLength,
  130. DataPartCount,
  131. erasedIndices.size(),
  132. decodeGFTables.data(),
  133. recoveryBlocks.data(),
  134. recoveredPartsPointers.data());
  135. return std::vector<TBlobType>(recoveredParts.begin(), recoveredParts.end());
  136. }
  137. } // namespace NErasure