gd_builder.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #include "gd_builder.h"
  2. #include <library/cpp/string_utils/relaxed_escaper/relaxed_escaper.h>
  3. #include <util/generic/algorithm.h>
  4. #include <util/random/shuffle.h>
  5. #include <util/stream/output.h>
  6. #include <util/string/printf.h>
  7. #include <util/system/rusage.h>
  8. namespace NGreedyDict {
  9. void TDictBuilder::RebuildCounts(ui32 maxcand, bool final) {
  10. if (!Current) {
  11. Current = MakeHolder<TEntrySet>();
  12. Current->InitWithAlpha();
  13. }
  14. TEntrySet& set = *Current;
  15. for (auto& it : set)
  16. it.Count = 0;
  17. CompoundCounts = nullptr;
  18. CompoundCountsPool.Clear();
  19. if (!final) {
  20. CompoundCounts = MakeHolder<TCompoundCounts>(&CompoundCountsPool);
  21. CompoundCounts->reserve(maxcand);
  22. }
  23. Shuffle(Input.begin(), Input.end(), Rng);
  24. for (auto str : Input) {
  25. if (!final && CompoundCounts->size() > maxcand)
  26. break;
  27. i32 prev = -1;
  28. while (!!str) {
  29. TEntry* e = set.FindPrefix(str);
  30. ui32 num = e->Number;
  31. e->Count += 1;
  32. if (!final && prev >= 0) {
  33. (*CompoundCounts)[Compose(prev, num)] += 1;
  34. }
  35. prev = num;
  36. ++set.TotalCount;
  37. }
  38. }
  39. Current->SetModelP();
  40. }
  41. ui32 TDictBuilder::BuildNextGeneration(ui32 maxent, ui32 maxlen) {
  42. TAutoPtr<TEntrySet> newset = new TEntrySet;
  43. newset->InitWithAlpha();
  44. maxent -= newset->size();
  45. ui32 additions = 0;
  46. ui32 deletions = 0;
  47. {
  48. const TEntrySet& set = *Current;
  49. Candidates.clear();
  50. const ui32 total = set.TotalCount;
  51. const float minpval = Settings.MinPValue;
  52. const EEntryStatTest test = Settings.StatTest;
  53. const EEntryScore score = Settings.Score;
  54. const ui32 mincnt = Settings.MinAbsCount;
  55. for (const auto& it : set) {
  56. const TEntry& e = it;
  57. float modelp = e.ModelP;
  58. ui32 cnt = e.Count;
  59. if (e.HasPrefix() && e.Count > mincnt && StatTest(test, modelp, cnt, total) > minpval)
  60. Candidates.push_back(TCandidate(-Score(score, e.Len(), modelp, cnt, total), it.Number));
  61. }
  62. if (!!CompoundCounts) {
  63. for (TCompoundCounts::const_iterator it = CompoundCounts->begin(); it != CompoundCounts->end(); ++it) {
  64. const TEntry& prev = set.Get(Prev(it->first));
  65. const TEntry& next = set.Get(Next(it->first));
  66. float modelp = ModelP(prev.Count, next.Count, total);
  67. ui32 cnt = it->second;
  68. if (cnt > mincnt && StatTest(test, modelp, cnt, total) > minpval && prev.Len() + next.Len() <= maxlen)
  69. Candidates.push_back(TCandidate(-Score(score, prev.Len() + next.Len(), modelp, cnt, total), it->first));
  70. }
  71. }
  72. Sort(Candidates.begin(), Candidates.end());
  73. if (Candidates.size() > maxent)
  74. Candidates.resize(maxent);
  75. for (const auto& candidate : Candidates) {
  76. if (IsCompound(candidate.second)) {
  77. additions++;
  78. newset->Add(set.Get(Prev(candidate.second)).Str, set.Get(Next(candidate.second)).Str);
  79. } else {
  80. newset->Add(set.Get(candidate.second).Str);
  81. }
  82. }
  83. deletions = set.size() - (newset->size() - additions);
  84. }
  85. Current = newset;
  86. Current->BuildHierarchy();
  87. return deletions + additions;
  88. }
  89. ui32 TDictBuilder::Build(ui32 maxentries, ui32 maxiters, ui32 maxlen, ui32 mindiff) {
  90. /* size_t totalsz = 0;
  91. for (auto it : Input)
  92. totalsz += it.size();*/
  93. while (maxiters) {
  94. maxiters--;
  95. RebuildCounts(maxentries * Settings.GrowLimit, false);
  96. if (Settings.Verbose) {
  97. TString mess = Sprintf("iter:%" PRIu32 " sz:%" PRIu32 " pend:%" PRIu32, maxiters, (ui32)Current->size(), (ui32)CompoundCounts->size());
  98. Clog << Sprintf("%-110s RSS=%" PRIu32 "M", mess.data(), (ui32)(TRusage::Get().MaxRss >> 20)) << Endl;
  99. }
  100. ui32 diff = BuildNextGeneration(maxentries, maxlen);
  101. if (Current->size() == maxentries && diff < mindiff)
  102. break;
  103. }
  104. RebuildCounts(0, true);
  105. Current->SetScores(Settings.Score);
  106. return maxiters;
  107. }
  108. }