levenshtein_diff.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. #pragma once
  2. #include <util/draft/matrix.h>
  3. #include <util/generic/algorithm.h>
  4. #include <util/generic/vector.h>
  5. #include <util/system/yassert.h>
  6. #include <type_traits>
  7. #include <utility>
  8. namespace NLevenshtein {
  /// Kind of a single step in an edit chain that turns str1 into str2.
  enum EEditMoveType {
      EMT_SPECIAL = 0,   ///< Start marker / sentinel; consumes no characters.
      EMT_PRESERVE = 1,  ///< Characters are equal; advance one position in both strings.
      EMT_REPLACE = 2,   ///< Substitute a character; advance one position in both strings.
      EMT_DELETE = 3,    ///< Drop a character of str1; advance in str1 only.
      EMT_INSERT = 4,    ///< Take a character of str2; advance in str2 only.
      EMT_TRANSPOSE = 5  ///< Swap two adjacent characters; advance two positions in both strings.
  };

  /// True for moves that actually change text, i.e. everything except the
  /// start marker and EMT_PRESERVE.
  inline bool IsImportantEditMove(EEditMoveType p) {
      switch (p) {
          case EMT_SPECIAL:
          case EMT_PRESERVE:
              return false;
          default:
              return true;
      }
  }

  /// Advances the positions p1 (in str1) and p2 (in str2) past the characters
  /// consumed by move t. EMT_SPECIAL and unknown values leave both unchanged.
  inline void MakeMove(EEditMoveType t, int& p1, int& p2) {
      int step1 = 0;
      int step2 = 0;
      if (t == EMT_TRANSPOSE) {
          step1 = 2;
          step2 = 2;
      } else if (t == EMT_PRESERVE || t == EMT_REPLACE) {
          step1 = 1;
          step2 = 1;
      } else if (t == EMT_DELETE) {
          step1 = 1;
      } else if (t == EMT_INSERT) {
          step2 = 1;
      }
      p1 += step1;
      p2 += step2;
  }
  41. using TEditChain = TVector<EEditMoveType>;
  /// Unary weigher that charges 1 for any character (default delete/insert cost).
  template <class TArgType, class TWeightType = int>
  struct TWeightOneUnaryGetter {
      TWeightType operator()(const TArgType& /*ch*/) const {
          const TWeightType unit = 1;
          return unit;
      }
  };

  /// Binary weigher that charges 1 for any pair of characters (default replace cost).
  template <class TArgType, class TWeightType = int>
  struct TWeightOneBinaryGetter {
      TWeightType operator()(const TArgType& /*lhs*/, const TArgType& /*rhs*/) const {
          const TWeightType unit = 1;
          return unit;
      }
  };

  /// Binary weigher that returns std::numeric_limits<TWeightType>::max() for any pair.
  /// GetEditChain skips transpositions whose weight is not below that value, so
  /// using this as the transpose weigher disables transpositions.
  template <class TArgType, class TWeightType = int>
  struct TWeightInfBinaryGetter {
      TWeightType operator()(const TArgType& /*lhs*/, const TArgType& /*rhs*/) const {
          const TWeightType infinite = std::numeric_limits<TWeightType>::max();
          return infinite;
      }
  };
  /// Character type of a string-like type: the decayed type of indexing a const
  /// TStringType with operator[].
  /// std::declval is used so TStringType does not have to be default-constructible.
  /// The former `std::add_const_t<TStringType>()[0]` form value-constructed the
  /// string inside decltype and thus required that.
  template <typename TStringType>
  using TCharType = std::decay_t<decltype(std::declval<const TStringType&>()[0])>;
  62. /// Finds sequence of "edit moves" for two strings
  63. template <class TStringType, class TWeightType = int,
  64. class TReplaceWeigher = TWeightOneBinaryGetter<TCharType<TStringType>, TWeightType>,
  65. class TDeleteWeigher = TWeightOneUnaryGetter<TCharType<TStringType>, TWeightType>,
  66. class TInsertWeigher = TWeightOneUnaryGetter<TCharType<TStringType>, TWeightType>,
  67. class TTransposeWeigher = TWeightInfBinaryGetter<TCharType<TStringType>, TWeightType>
  68. >
  69. void GetEditChain(const TStringType& str1, const TStringType& str2, TEditChain& res, TWeightType* weight = nullptr,
  70. const TReplaceWeigher& replaceWeigher = TReplaceWeigher(),
  71. const TDeleteWeigher& deleteWeigher = TDeleteWeigher(),
  72. const TInsertWeigher& insertWeigher = TInsertWeigher(),
  73. const TTransposeWeigher& transposeWeigher = TTransposeWeigher())
  74. {
  75. int l1 = (int)str1.size();
  76. int l2 = (int)str2.size();
  77. TMatrix<std::pair<TWeightType, EEditMoveType>> ma(l1 + 1, l2 + 1); /// ma[i][j].first = diff(str1[0..i-1], str2[0..j-1])
  78. ma[0][0] = std::make_pair(0, EMT_SPECIAL); // starting point
  79. for (int i = 1; i <= l1; i++) {
  80. ma[i][0] = std::make_pair(ma[i - 1][0].first + deleteWeigher(str1[i - 1]), EMT_DELETE);
  81. }
  82. for (int i = 1; i <= l2; i++) {
  83. ma[0][i] = std::make_pair(ma[0][i - 1].first + insertWeigher(str2[i - 1]), EMT_INSERT);
  84. }
  85. const TWeightType maxWeight = std::numeric_limits<TWeightType>::max();
  86. // Here goes basic Damerau-Levenshtein's algorithm
  87. for (int i = 1; i <= l1; i++) {
  88. for (int j = 1; j <= l2; j++) {
  89. if (str1[i - 1] == str2[j - 1]) {
  90. ma[i][j] = std::make_pair(ma[i - 1][j - 1].first, EMT_PRESERVE);
  91. } else {
  92. const TWeightType replaceWeight = replaceWeigher(str1[i - 1], str2[j - 1]);
  93. Y_ASSERT(replaceWeight >= 0);
  94. if (replaceWeight < maxWeight) {
  95. ma[i][j] = std::make_pair(ma[i - 1][j - 1].first + replaceWeight, EMT_REPLACE);
  96. } else {
  97. ma[i][j] = std::make_pair(replaceWeight, EMT_REPLACE);
  98. }
  99. }
  100. if (ma[i][j].first > ma[i - 1][j].first) {
  101. const TWeightType deleteWeight = deleteWeigher(str1[i - 1]);
  102. Y_ASSERT(deleteWeight >= 0);
  103. if (deleteWeight < maxWeight) {
  104. const TWeightType deletePathWeight = ma[i - 1][j].first + deleteWeight;
  105. if (deletePathWeight <= ma[i][j].first) {
  106. ma[i][j] = std::make_pair(deletePathWeight, EMT_DELETE);
  107. }
  108. }
  109. }
  110. if (ma[i][j].first > ma[i][j - 1].first) {
  111. const TWeightType insertWeight = insertWeigher(str2[j - 1]);
  112. Y_ASSERT(insertWeight >= 0);
  113. if (insertWeight < maxWeight) {
  114. const TWeightType insertPathWeight = ma[i][j - 1].first + insertWeight;
  115. if (insertPathWeight <= ma[i][j].first) {
  116. ma[i][j] = std::make_pair(insertPathWeight, EMT_INSERT);
  117. }
  118. }
  119. }
  120. if (i > 1 && j > 1 && str1[i - 1] == str2[j - 2] && str1[i - 2] == str2[j - 1]) {
  121. const TWeightType transposeWeight = transposeWeigher(str1[i - 2], str2[j - 2]);
  122. Y_ASSERT(transposeWeight >= 0);
  123. if (transposeWeight < maxWeight) {
  124. const TWeightType transposePathWeight = ma[i - 2][j - 2].first + transposeWeight;
  125. if (transposePathWeight <= ma[i][j].first) {
  126. ma[i][j] = std::make_pair(transposePathWeight, EMT_TRANSPOSE);
  127. }
  128. }
  129. }
  130. }
  131. }
  132. // Tracing the path from final point
  133. res.clear();
  134. res.reserve(Max<size_t>(l1, l2));
  135. for (int i = l1, j = l2; ma[i][j].second != EMT_SPECIAL;) {
  136. res.push_back(ma[i][j].second);
  137. switch (ma[i][j].second) {
  138. case EMT_TRANSPOSE:
  139. i -= 2;
  140. j -= 2;
  141. break;
  142. case EMT_PRESERVE:
  143. case EMT_REPLACE:
  144. --i;
  145. --j;
  146. break;
  147. case EMT_DELETE:
  148. --i;
  149. break;
  150. case EMT_INSERT:
  151. --j;
  152. break;
  153. default:
  154. // TODO: throw exception
  155. break;
  156. }
  157. }
  158. std::reverse(res.begin(), res.end());
  159. if (weight != nullptr) {
  160. *weight = ma[l1][l2].first;
  161. }
  162. }
  163. template <class TStringType, bool Damerau = false, class TWeightType = int>
  164. void GetEditChainGeneric(const TStringType& str1, const TStringType& str2, TEditChain& res, TWeightType* weight = nullptr) {
  165. typedef TCharType<TStringType> TArgType;
  166. GetEditChain<
  167. TStringType, TWeightType,
  168. TWeightOneBinaryGetter<TArgType, TWeightType>,
  169. TWeightOneUnaryGetter<TArgType, TWeightType>,
  170. TWeightOneUnaryGetter<TArgType, TWeightType>,
  171. std::conditional_t<
  172. Damerau,
  173. TWeightOneBinaryGetter<TArgType, TWeightType>,
  174. TWeightInfBinaryGetter<TArgType, TWeightType>
  175. >
  176. >(str1, str2, res, weight);
  177. }
  178. template <class TStringType, bool Damerau = false>
  179. size_t DistanceImpl(const TStringType& str1, const TStringType& str2) {
  180. if (str1.size() > str2.size()) {
  181. return DistanceImpl<TStringType, Damerau>(str2, str1);
  182. }
  183. size_t size1 = str1.size();
  184. size_t size2 = str2.size();
  185. TVector<size_t> currentRow(size1 + 1);
  186. TVector<size_t> previousRow(size1 + 1);
  187. TVector<size_t> transpositionRow(size1 + 1);
  188. for (size_t i = 0; i <= size1; ++i) {
  189. previousRow[i] = i;
  190. }
  191. for (size_t i = 1; i <= size2; ++i) {
  192. currentRow[0] = i;
  193. for (size_t j = 1; j <= size1; ++j) {
  194. size_t cost = str1[j - 1] == str2[i - 1] ? 0 : 1;
  195. currentRow[j] = std::min(previousRow[j - 1] + cost, std::min(currentRow[j - 1], previousRow[j]) + 1);
  196. if (Damerau && i > 1 && j > 1 && str1[j - 2] == str2[i - 1] && str1[j - 1] == str2[i - 2]) {
  197. currentRow[j] = std::min(currentRow[j], transpositionRow[j - 2] + cost);
  198. }
  199. }
  200. if (Damerau) {
  201. std::swap(transpositionRow, previousRow);
  202. }
  203. std::swap(previousRow, currentRow);
  204. }
  205. return previousRow[size1];
  206. }
  207. template <class TStringType>
  208. size_t Distance(const TStringType& str1, const TStringType& str2) {
  209. return DistanceImpl<TStringType, false>(str1, str2);
  210. }
  211. template <class TStringType>
  212. size_t DamerauDistance(const TStringType& str1, const TStringType& str2) {
  213. return DistanceImpl<TStringType, true>(str1, str2);
  214. }
  /// One contiguous differing region between two strings, as reported by
  /// GetStringReplacements. The span [CorrectOffset, CorrectOffset + CorrectLength)
  /// of str1 corresponds to [MisspelledOffset, MisspelledOffset + MisspelledLength)
  /// of str2. The naming assumes str1 is the correct spelling.
  struct TReplacement {
      int CorrectOffset = 0;
      int CorrectLength = 0;
      int MisspelledOffset = 0;
      int MisspelledLength = 0;

      TReplacement() = default;

      TReplacement(int correctOffset, int correctLength, int misspelledOffset, int misspelledLength)
          : CorrectOffset(correctOffset)
          , CorrectLength(correctLength)
          , MisspelledOffset(misspelledOffset)
          , MisspelledLength(misspelledLength)
      {
      }
  };
  233. template <class TStringType>
  234. void GetStringReplacements(const TStringType& str1, const TStringType& str2, TVector<TReplacement>& res) {
  235. TEditChain editChain;
  236. GetEditChain(str1, str2, editChain);
  237. editChain.push_back(EMT_SPECIAL);
  238. int c1 = 0, c2 = 0;
  239. res.clear();
  240. for (TEditChain::const_iterator it = editChain.begin(); it != editChain.end(); it++) {
  241. if (IsImportantEditMove(*it)) {
  242. int sc1 = c1, sc2 = c2;
  243. do {
  244. MakeMove(*it, c1, c2);
  245. ++it;
  246. } while (IsImportantEditMove(*it));
  247. res.push_back(TReplacement(sc1, c1 - sc1, sc2, c2 - sc2));
  248. }
  249. MakeMove(*it, c1, c2);
  250. }
  251. }
  252. }