mlbe.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. // © 2022 and later: Unicode, Inc. and others.
  2. // License & terms of use: http://www.unicode.org/copyright.html
  3. #include "unicode/utypes.h"
  4. #if !UCONFIG_NO_BREAK_ITERATION
  5. #include "cmemory.h"
  6. #include "mlbe.h"
  7. #include "uassert.h"
  8. #include "ubrkimpl.h"
  9. #include "unicode/resbund.h"
  10. #include "unicode/udata.h"
  11. #include "unicode/utf16.h"
  12. #include "uresimp.h"
  13. #include "util.h"
  14. #include "uvectr32.h"
  15. U_NAMESPACE_BEGIN
  // Base offsets of each feature family inside fModel[]. loadMLModel() fills fModel in the order
  // UW1..UW6, BW1..BW3, TW1..TW4, so each family starts right after the previous one ends.
  enum class ModelIndex {
      kUWStart = 0,  // UW1 ~ UW6: single code points, fModel[0..5]
      kBWStart = 6,  // BW1 ~ BW3: code point pairs, fModel[6..8]
      kTWStart = 9   // TW1 ~ TW4: code point triples, fModel[9..12]
  };
  17. MlBreakEngine::MlBreakEngine(const UnicodeSet &digitOrOpenPunctuationOrAlphabetSet,
  18. const UnicodeSet &closePunctuationSet, UErrorCode &status)
  19. : fDigitOrOpenPunctuationOrAlphabetSet(digitOrOpenPunctuationOrAlphabetSet),
  20. fClosePunctuationSet(closePunctuationSet),
  21. fNegativeSum(0) {
  22. if (U_FAILURE(status)) {
  23. return;
  24. }
  25. loadMLModel(status);
  26. }
  27. MlBreakEngine::~MlBreakEngine() {}
  28. int32_t MlBreakEngine::divideUpRange(UText *inText, int32_t rangeStart, int32_t rangeEnd,
  29. UVector32 &foundBreaks, const UnicodeString &inString,
  30. const LocalPointer<UVector32> &inputMap,
  31. UErrorCode &status) const {
  32. if (U_FAILURE(status)) {
  33. return 0;
  34. }
  35. if (rangeStart >= rangeEnd) {
  36. status = U_ILLEGAL_ARGUMENT_ERROR;
  37. return 0;
  38. }
  39. UVector32 boundary(inString.countChar32() + 1, status);
  40. if (U_FAILURE(status)) {
  41. return 0;
  42. }
  43. int32_t numBreaks = 0;
  44. int32_t codePointLength = inString.countChar32();
  45. // The ML algorithm groups six char and evaluates whether the 4th char is a breakpoint.
  46. // In each iteration, it evaluates the 4th char and then moves forward one char like a sliding
  47. // window. Initially, the first six values in the indexList are [-1, -1, 0, 1, 2, 3]. After
  48. // moving forward, finally the last six values in the indexList are
  49. // [length-4, length-3, length-2, length-1, -1, -1]. The "+4" here means four extra "-1".
  50. int32_t indexSize = codePointLength + 4;
  51. LocalMemory<int32_t> indexList(static_cast<int32_t*>(uprv_malloc(indexSize * sizeof(int32_t))));
  52. if (indexList.isNull()) {
  53. status = U_MEMORY_ALLOCATION_ERROR;
  54. return 0;
  55. }
  56. int32_t numCodeUnits = initIndexList(inString, indexList.getAlias(), status);
  57. // Add a break for the start.
  58. boundary.addElement(0, status);
  59. numBreaks++;
  60. if (U_FAILURE(status)) return 0;
  61. for (int32_t idx = 0; idx + 1 < codePointLength && U_SUCCESS(status); idx++) {
  62. numBreaks =
  63. evaluateBreakpoint(inString, indexList.getAlias(), idx, numCodeUnits, numBreaks, boundary, status);
  64. if (idx + 4 < codePointLength) {
  65. indexList[idx + 6] = numCodeUnits;
  66. numCodeUnits += U16_LENGTH(inString.char32At(indexList[idx + 6]));
  67. }
  68. }
  69. if (U_FAILURE(status)) return 0;
  70. // Add a break for the end if there is not one there already.
  71. if (boundary.lastElementi() != inString.countChar32()) {
  72. boundary.addElement(inString.countChar32(), status);
  73. numBreaks++;
  74. }
  75. int32_t prevCPPos = -1;
  76. int32_t prevUTextPos = -1;
  77. int32_t correctedNumBreaks = 0;
  78. for (int32_t i = 0; i < numBreaks; i++) {
  79. int32_t cpPos = boundary.elementAti(i);
  80. int32_t utextPos = inputMap.isValid() ? inputMap->elementAti(cpPos) : cpPos + rangeStart;
  81. U_ASSERT(cpPos > prevCPPos);
  82. U_ASSERT(utextPos >= prevUTextPos);
  83. if (utextPos > prevUTextPos) {
  84. if (utextPos != rangeStart ||
  85. (utextPos > 0 &&
  86. fClosePunctuationSet.contains(utext_char32At(inText, utextPos - 1)))) {
  87. foundBreaks.push(utextPos, status);
  88. correctedNumBreaks++;
  89. }
  90. } else {
  91. // Normalization expanded the input text, the dictionary found a boundary
  92. // within the expansion, giving two boundaries with the same index in the
  93. // original text. Ignore the second. See ticket #12918.
  94. --numBreaks;
  95. }
  96. prevCPPos = cpPos;
  97. prevUTextPos = utextPos;
  98. }
  99. (void)prevCPPos; // suppress compiler warnings about unused variable
  100. UChar32 nextChar = utext_char32At(inText, rangeEnd);
  101. if (!foundBreaks.isEmpty() && foundBreaks.peeki() == rangeEnd) {
  102. // In phrase breaking, there has to be a breakpoint between Cj character and
  103. // the number/open punctuation.
  104. // E.g. る文字「そうだ、京都」->る▁文字▁「そうだ、▁京都」-> breakpoint between 字 and「
  105. // E.g. 乗車率90%程度だろうか -> 乗車▁率▁90%▁程度だろうか -> breakpoint between 率 and 9
  106. // E.g. しかもロゴがUnicode! -> しかも▁ロゴが▁Unicode!-> breakpoint between が and U
  107. if (!fDigitOrOpenPunctuationOrAlphabetSet.contains(nextChar)) {
  108. foundBreaks.popi();
  109. correctedNumBreaks--;
  110. }
  111. }
  112. return correctedNumBreaks;
  113. }
  114. int32_t MlBreakEngine::evaluateBreakpoint(const UnicodeString &inString, int32_t *indexList,
  115. int32_t startIdx, int32_t numCodeUnits, int32_t numBreaks,
  116. UVector32 &boundary, UErrorCode &status) const {
  117. if (U_FAILURE(status)) {
  118. return numBreaks;
  119. }
  120. int32_t start = 0, end = 0;
  121. int32_t score = fNegativeSum;
  122. for (int i = 0; i < 6; i++) {
  123. // UW1 ~ UW6
  124. start = startIdx + i;
  125. if (indexList[start] != -1) {
  126. end = (indexList[start + 1] != -1) ? indexList[start + 1] : numCodeUnits;
  127. score += fModel[static_cast<int32_t>(ModelIndex::kUWStart) + i].geti(
  128. inString.tempSubString(indexList[start], end - indexList[start]));
  129. }
  130. }
  131. for (int i = 0; i < 3; i++) {
  132. // BW1 ~ BW3
  133. start = startIdx + i + 1;
  134. if (indexList[start] != -1 && indexList[start + 1] != -1) {
  135. end = (indexList[start + 2] != -1) ? indexList[start + 2] : numCodeUnits;
  136. score += fModel[static_cast<int32_t>(ModelIndex::kBWStart) + i].geti(
  137. inString.tempSubString(indexList[start], end - indexList[start]));
  138. }
  139. }
  140. for (int i = 0; i < 4; i++) {
  141. // TW1 ~ TW4
  142. start = startIdx + i;
  143. if (indexList[start] != -1 && indexList[start + 1] != -1 && indexList[start + 2] != -1) {
  144. end = (indexList[start + 3] != -1) ? indexList[start + 3] : numCodeUnits;
  145. score += fModel[static_cast<int32_t>(ModelIndex::kTWStart) + i].geti(
  146. inString.tempSubString(indexList[start], end - indexList[start]));
  147. }
  148. }
  149. if (score > 0) {
  150. boundary.addElement(startIdx + 1, status);
  151. numBreaks++;
  152. }
  153. return numBreaks;
  154. }
  155. int32_t MlBreakEngine::initIndexList(const UnicodeString &inString, int32_t *indexList,
  156. UErrorCode &status) const {
  157. if (U_FAILURE(status)) {
  158. return 0;
  159. }
  160. int32_t index = 0;
  161. int32_t length = inString.countChar32();
  162. // Set all (lenght+4) items inside indexLength to -1 presuming -1 is 4 bytes of 0xff.
  163. uprv_memset(indexList, 0xff, (length + 4) * sizeof(int32_t));
  164. if (length > 0) {
  165. indexList[2] = 0;
  166. index = U16_LENGTH(inString.char32At(0));
  167. if (length > 1) {
  168. indexList[3] = index;
  169. index += U16_LENGTH(inString.char32At(index));
  170. if (length > 2) {
  171. indexList[4] = index;
  172. index += U16_LENGTH(inString.char32At(index));
  173. if (length > 3) {
  174. indexList[5] = index;
  175. index += U16_LENGTH(inString.char32At(index));
  176. }
  177. }
  178. }
  179. }
  180. return index;
  181. }
  182. void MlBreakEngine::loadMLModel(UErrorCode &error) {
  183. // BudouX's model consists of thirteen categories, each of which is make up of pairs of the
  184. // feature and its score. As integrating it into jaml.txt, we define thirteen kinds of key and
  185. // value to represent the feature and the corresponding score respectively.
  186. if (U_FAILURE(error)) return;
  187. UnicodeString key;
  188. StackUResourceBundle stackTempBundle;
  189. ResourceDataValue modelKey;
  190. LocalUResourceBundlePointer rbp(ures_openDirect(U_ICUDATA_BRKITR, "jaml", &error));
  191. UResourceBundle *rb = rbp.getAlias();
  192. if (U_FAILURE(error)) return;
  193. int32_t index = 0;
  194. initKeyValue(rb, "UW1Keys", "UW1Values", fModel[index++], error);
  195. initKeyValue(rb, "UW2Keys", "UW2Values", fModel[index++], error);
  196. initKeyValue(rb, "UW3Keys", "UW3Values", fModel[index++], error);
  197. initKeyValue(rb, "UW4Keys", "UW4Values", fModel[index++], error);
  198. initKeyValue(rb, "UW5Keys", "UW5Values", fModel[index++], error);
  199. initKeyValue(rb, "UW6Keys", "UW6Values", fModel[index++], error);
  200. initKeyValue(rb, "BW1Keys", "BW1Values", fModel[index++], error);
  201. initKeyValue(rb, "BW2Keys", "BW2Values", fModel[index++], error);
  202. initKeyValue(rb, "BW3Keys", "BW3Values", fModel[index++], error);
  203. initKeyValue(rb, "TW1Keys", "TW1Values", fModel[index++], error);
  204. initKeyValue(rb, "TW2Keys", "TW2Values", fModel[index++], error);
  205. initKeyValue(rb, "TW3Keys", "TW3Values", fModel[index++], error);
  206. initKeyValue(rb, "TW4Keys", "TW4Values", fModel[index++], error);
  207. fNegativeSum /= 2;
  208. }
  209. void MlBreakEngine::initKeyValue(UResourceBundle *rb, const char *keyName, const char *valueName,
  210. Hashtable &model, UErrorCode &error) {
  211. int32_t keySize = 0;
  212. int32_t valueSize = 0;
  213. int32_t stringLength = 0;
  214. UnicodeString key;
  215. StackUResourceBundle stackTempBundle;
  216. ResourceDataValue modelKey;
  217. // get modelValues
  218. LocalUResourceBundlePointer modelValue(ures_getByKey(rb, valueName, nullptr, &error));
  219. const int32_t *value = ures_getIntVector(modelValue.getAlias(), &valueSize, &error);
  220. if (U_FAILURE(error)) return;
  221. // get modelKeys
  222. ures_getValueWithFallback(rb, keyName, stackTempBundle.getAlias(), modelKey, error);
  223. ResourceArray stringArray = modelKey.getArray(error);
  224. keySize = stringArray.getSize();
  225. if (U_FAILURE(error)) return;
  226. for (int32_t idx = 0; idx < keySize; idx++) {
  227. stringArray.getValue(idx, modelKey);
  228. key = UnicodeString(modelKey.getString(stringLength, error));
  229. if (U_SUCCESS(error)) {
  230. U_ASSERT(idx < valueSize);
  231. fNegativeSum -= value[idx];
  232. model.puti(key, value[idx], error);
  233. }
  234. }
  235. }
  236. U_NAMESPACE_END
  237. #endif /* #if !UCONFIG_NO_BREAK_ITERATION */