levenshtein_diff.h 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. #pragma once
  2. #include <util/draft/matrix.h>
  3. #include <util/generic/algorithm.h>
  4. #include <util/generic/vector.h>
  5. #include <util/system/yassert.h>
  6. #include <type_traits>
  7. #include <utility>
  8. namespace NLevenshtein {
  /// Kinds of elementary steps that make up an edit chain.
  enum EEditMoveType {
      EMT_SPECIAL = 0,  ///< Marker that consumes nothing (used as the start cell of the DP table and as a chain terminator)
      EMT_PRESERVE = 1, ///< Characters match: advance in both strings, nothing changes
      EMT_REPLACE = 2,  ///< Characters differ: advance in both strings
      EMT_DELETE = 3,   ///< Advance only in the first string
      EMT_INSERT = 4    ///< Advance only in the second string
  };
  16. inline bool IsImportantEditMove(EEditMoveType p) {
  17. return (p != EMT_SPECIAL && p != EMT_PRESERVE);
  18. }
  19. inline void MakeMove(EEditMoveType t, int& p1, int& p2) {
  20. switch (t) {
  21. case EMT_PRESERVE:
  22. case EMT_REPLACE:
  23. p1++;
  24. p2++;
  25. break;
  26. case EMT_DELETE:
  27. p1++;
  28. break;
  29. case EMT_INSERT:
  30. p2++;
  31. break;
  32. default:
  33. break;
  34. }
  35. }
  36. using TEditChain = TVector<EEditMoveType>;
  /// Single-argument weigher that assigns the same unit weight to every element.
  /// Serves as the default delete/insert weigher of GetEditChain.
  template <typename TArgType>
  struct TWeightOneUnaryGetter {
      /// @return always 1, regardless of the argument.
      int operator()(const TArgType& /*element*/) const {
          const int unitWeight = 1;
          return unitWeight;
      }
  };
  /// Two-argument weigher that assigns the same unit weight to every pair of elements.
  /// Serves as the default replace weigher of GetEditChain.
  template <typename TArgType>
  struct TWeightOneBinaryGetter {
      /// @return always 1, regardless of the arguments.
      int operator()(const TArgType& /*from*/, const TArgType& /*to*/) const {
          const int unitWeight = 1;
          return unitWeight;
      }
  };
  /// Element type of a string-like container: the decayed type of `s[0]` for a const `s`.
  /// std::declval is used instead of constructing a temporary, so TStringType does not
  /// have to be default-constructible (the expression is only inspected, never evaluated).
  template <typename TStringType>
  using TCharType = std::decay_t<decltype(std::declval<const TStringType&>()[0])>;
  51. /// Finds sequence of "edit moves" for two strings
  52. template <class TStringType, class TWeightType = int,
  53. class TReplaceWeigher = TWeightOneBinaryGetter<TCharType<TStringType>>,
  54. class TDeleteWeigher = TWeightOneUnaryGetter<TCharType<TStringType>>,
  55. class TInsertWeigher = TWeightOneUnaryGetter<TCharType<TStringType>>
  56. >
  57. void GetEditChain(const TStringType& str1, const TStringType& str2, TEditChain& res, TWeightType* weight = nullptr,
  58. const TReplaceWeigher& replaceWeigher = TReplaceWeigher(),
  59. const TDeleteWeigher& deleteWeigher = TDeleteWeigher(),
  60. const TInsertWeigher& insertWeigher = TInsertWeigher())
  61. {
  62. int l1 = (int)str1.size();
  63. int l2 = (int)str2.size();
  64. TMatrix<std::pair<TWeightType, EEditMoveType>> ma(l1 + 1, l2 + 1); /// ma[i][j].first = diff(str1[0..i-1], str2[0..j-1])
  65. ma[0][0] = std::make_pair(0, EMT_SPECIAL); // starting point
  66. for (int i = 1; i <= l1; i++) {
  67. ma[i][0] = std::make_pair(ma[i - 1][0].first + deleteWeigher(str1[i - 1]), EMT_DELETE);
  68. }
  69. for (int i = 1; i <= l2; i++) {
  70. ma[0][i] = std::make_pair(ma[0][i - 1].first + insertWeigher(str2[i - 1]), EMT_INSERT);
  71. }
  72. // Here goes basic Levestein's algorithm
  73. for (int i = 1; i <= l1; i++) {
  74. for (int j = 1; j <= l2; j++) {
  75. if (str1[i - 1] == str2[j - 1]) {
  76. ma[i][j] = std::make_pair(ma[i - 1][j - 1].first, EMT_PRESERVE);
  77. } else {
  78. const TWeightType replaceWeight = replaceWeigher(str1[i - 1], str2[j - 1]);
  79. Y_ASSERT(replaceWeight >= 0);
  80. ma[i][j] = std::make_pair(ma[i - 1][j - 1].first + replaceWeight, EMT_REPLACE);
  81. }
  82. if (ma[i][j].first > ma[i - 1][j].first) {
  83. const TWeightType deleteWeight = deleteWeigher(str1[i - 1]);
  84. Y_ASSERT(deleteWeight >= 0);
  85. const TWeightType deletePathWeight = ma[i - 1][j].first + deleteWeight;
  86. if (deletePathWeight <= ma[i][j].first) {
  87. ma[i][j] = std::make_pair(deletePathWeight, EMT_DELETE);
  88. }
  89. }
  90. if (ma[i][j].first > ma[i][j - 1].first) {
  91. const TWeightType insertWeight = insertWeigher(str2[j - 1]);
  92. Y_ASSERT(insertWeight >= 0);
  93. const TWeightType insertPathWeight = ma[i][j - 1].first + insertWeight;
  94. if (insertPathWeight <= ma[i][j].first) {
  95. ma[i][j] = std::make_pair(insertPathWeight, EMT_INSERT);
  96. }
  97. }
  98. }
  99. }
  100. // Tracing the path from final point
  101. res.clear();
  102. res.reserve(Max<size_t>(l1, l2));
  103. for (int i = l1, j = l2; ma[i][j].second != EMT_SPECIAL;) {
  104. res.push_back(ma[i][j].second);
  105. switch (ma[i][j].second) {
  106. case EMT_PRESERVE:
  107. case EMT_REPLACE:
  108. --i;
  109. --j;
  110. break;
  111. case EMT_DELETE:
  112. --i;
  113. break;
  114. case EMT_INSERT:
  115. --j;
  116. break;
  117. default:
  118. // TODO: throw exception
  119. break;
  120. }
  121. }
  122. std::reverse(res.begin(), res.end());
  123. if (weight != nullptr) {
  124. *weight = ma[l1][l2].first;
  125. }
  126. }
  127. template <class TStringType>
  128. size_t Distance(const TStringType& str1, const TStringType& str2) {
  129. TEditChain editChain;
  130. GetEditChain(str1, str2, editChain);
  131. size_t result = 0;
  132. for (auto edit : editChain) {
  133. if (IsImportantEditMove(edit))
  134. result++;
  135. }
  136. return result;
  137. }
  /// One differing region between two strings, as produced by GetStringReplacements:
  /// the "Correct" span is measured in its first argument (str1) and the "Misspelled"
  /// span in its second argument (str2).
  struct TReplacement {
      int CorrectOffset = 0;    ///< Start of the region in str1
      int CorrectLength = 0;    ///< Number of str1 elements covered by the region
      int MisspelledOffset = 0; ///< Start of the region in str2
      int MisspelledLength = 0; ///< Number of str2 elements covered by the region

      /// Creates an empty replacement: every offset and length is zero.
      TReplacement() = default;

      /// Creates a replacement from explicit offsets and lengths.
      TReplacement(int correctOffset, int correctLength, int misspelledOffset, int misspelledLength)
          : CorrectOffset(correctOffset)
          , CorrectLength(correctLength)
          , MisspelledOffset(misspelledOffset)
          , MisspelledLength(misspelledLength)
      {
      }
  };
  156. template <class TStringType>
  157. void GetStringReplacements(const TStringType& str1, const TStringType& str2, TVector<TReplacement>& res) {
  158. TEditChain editChain;
  159. GetEditChain(str1, str2, editChain);
  160. editChain.push_back(EMT_SPECIAL);
  161. int c1 = 0, c2 = 0;
  162. res.clear();
  163. for (TEditChain::const_iterator it = editChain.begin(); it != editChain.end(); it++) {
  164. if (IsImportantEditMove(*it)) {
  165. int sc1 = c1, sc2 = c2;
  166. do {
  167. MakeMove(*it, c1, c2);
  168. ++it;
  169. } while (IsImportantEditMove(*it));
  170. res.push_back(TReplacement(sc1, c1 - sc1, sc2, c2 - sc2));
  171. }
  172. MakeMove(*it, c1, c2);
  173. }
  174. }
  175. }