// cbo_optimizer_new.cpp (9.3 KB)


#include "cbo_optimizer_new.h"

#include <algorithm>
#include <array>
#include <optional>
#include <utility>

#include <util/generic/hash.h>
#include <util/generic/hash_set.h>
#include <util/string/builder.h>
#include <util/string/cast.h>
#include <util/string/join.h>
#include <util/string/printf.h>
  9. const TString& ToString(NYql::EJoinKind);
  10. const TString& ToString(NYql::EJoinAlgoType);
  11. namespace NYql {
  12. using namespace NYql::NDq;
  13. namespace {
  14. THashMap<TString,EJoinKind> JoinKindMap = {
  15. {"Inner",EJoinKind::InnerJoin},
  16. {"Left",EJoinKind::LeftJoin},
  17. {"Right",EJoinKind::RightJoin},
  18. {"Full",EJoinKind::OuterJoin},
  19. {"LeftOnly",EJoinKind::LeftOnly},
  20. {"RightOnly",EJoinKind::RightOnly},
  21. {"Exclusion",EJoinKind::Exclusion},
  22. {"LeftSemi",EJoinKind::LeftSemi},
  23. {"RightSemi",EJoinKind::RightSemi},
  24. {"Cross",EJoinKind::Cross}};
  25. THashMap<TString,TCardinalityHints::ECardOperation> HintOpMap = {
  26. {"+",TCardinalityHints::ECardOperation::Add},
  27. {"-",TCardinalityHints::ECardOperation::Subtract},
  28. {"*",TCardinalityHints::ECardOperation::Multiply},
  29. {"/",TCardinalityHints::ECardOperation::Divide},
  30. {"#",TCardinalityHints::ECardOperation::Replace}};
  31. }
  32. EJoinKind ConvertToJoinKind(const TString& joinString) {
  33. auto maybeKind = JoinKindMap.find(joinString);
  34. Y_ENSURE(maybeKind != JoinKindMap.end());
  35. return maybeKind->second;
  36. }
  37. TString ConvertToJoinString(const EJoinKind kind) {
  38. for (auto [k,v] : JoinKindMap) {
  39. if (v == kind) {
  40. return k;
  41. }
  42. }
  43. Y_ENSURE(false,"Unknown join kind");
  44. }
  45. TVector<TString> TRelOptimizerNode::Labels() {
  46. TVector<TString> res;
  47. res.emplace_back(Label);
  48. return res;
  49. }
  50. void TRelOptimizerNode::Print(std::stringstream& stream, int ntabs) {
  51. for (int i = 0; i < ntabs; i++){
  52. stream << " ";
  53. }
  54. stream << "Rel: " << Label << "\n";
  55. for (int i = 0; i < ntabs; i++){
  56. stream << " ";
  57. }
  58. stream << Stats << "\n";
  59. }
  60. TJoinOptimizerNode::TJoinOptimizerNode(
  61. const std::shared_ptr<IBaseOptimizerNode>& left,
  62. const std::shared_ptr<IBaseOptimizerNode>& right,
  63. TVector<TJoinColumn> leftKeys,
  64. TVector<TJoinColumn> rightKeys,
  65. const EJoinKind joinType,
  66. const EJoinAlgoType joinAlgo,
  67. bool leftAny,
  68. bool rightAny,
  69. bool nonReorderable
  70. ) : IBaseOptimizerNode(JoinNodeType)
  71. , LeftArg(left)
  72. , RightArg(right)
  73. , LeftJoinKeys(leftKeys)
  74. , RightJoinKeys(rightKeys)
  75. , JoinType(joinType)
  76. , JoinAlgo(joinAlgo)
  77. , LeftAny(leftAny)
  78. , RightAny(rightAny)
  79. , IsReorderable(!nonReorderable)
  80. {}
  81. TVector<TString> TJoinOptimizerNode::Labels() {
  82. auto res = LeftArg->Labels();
  83. auto rightLabels = RightArg->Labels();
  84. res.insert(res.begin(),rightLabels.begin(),rightLabels.end());
  85. return res;
  86. }
  87. void TJoinOptimizerNode::Print(std::stringstream& stream, int ntabs) {
  88. for (int i = 0; i < ntabs; i++){
  89. stream << " ";
  90. }
  91. stream << "Join: (" << ToString(JoinType) << "," << ToString(JoinAlgo);
  92. if (LeftAny) {
  93. stream << ",LeftAny";
  94. }
  95. if (RightAny) {
  96. stream << ",RightAny";
  97. }
  98. stream << ") ";
  99. for (size_t i=0; i<LeftJoinKeys.size(); i++){
  100. stream << LeftJoinKeys[i].RelName << "." << LeftJoinKeys[i].AttributeName
  101. << "=" << RightJoinKeys[i].RelName << "."
  102. << RightJoinKeys[i].AttributeName << ",";
  103. }
  104. stream << "\n";
  105. for (int i = 0; i < ntabs; i++){
  106. stream << " ";
  107. }
  108. stream << Stats << "\n";
  109. LeftArg->Print(stream, ntabs+1);
  110. RightArg->Print(stream, ntabs+1);
  111. }
  112. bool IsPKJoin(const TOptimizerStatistics& stats, const TVector<TJoinColumn>& joinKeys) {
  113. if (!stats.KeyColumns) {
  114. return false;
  115. }
  116. for(size_t i = 0; i < stats.KeyColumns->Data.size(); i++){
  117. if (std::find_if(joinKeys.begin(), joinKeys.end(),
  118. [&] (const TJoinColumn& c) { return c.AttributeName == stats.KeyColumns->Data[i];}) == joinKeys.end()) {
  119. return false;
  120. }
  121. }
  122. return true;
  123. }
  124. bool TBaseProviderContext::IsJoinApplicable(const std::shared_ptr<IBaseOptimizerNode>& left,
  125. const std::shared_ptr<IBaseOptimizerNode>& right,
  126. const TVector<TJoinColumn>& leftJoinKeys,
  127. const TVector<TJoinColumn>& rightJoinKeys,
  128. EJoinAlgoType joinAlgo,
  129. EJoinKind joinKind) {
  130. Y_UNUSED(left);
  131. Y_UNUSED(right);
  132. Y_UNUSED(leftJoinKeys);
  133. Y_UNUSED(rightJoinKeys);
  134. Y_UNUSED(joinKind);
  135. return joinAlgo == EJoinAlgoType::MapJoin;
  136. }
  137. double TBaseProviderContext::ComputeJoinCost(const TOptimizerStatistics& leftStats, const TOptimizerStatistics& rightStats, const double outputRows, const double outputByteSize, EJoinAlgoType joinAlgo) const {
  138. Y_UNUSED(outputByteSize);
  139. Y_UNUSED(joinAlgo);
  140. return leftStats.Nrows + 2.0 * rightStats.Nrows + outputRows;
  141. }
  142. /**
  143. * Compute the cost and output cardinality of a join
  144. *
  145. * Currently a very basic computation targeted at GraceJoin
  146. *
  147. * The build is on the right side, so we make the build side a bit more expensive than the probe
  148. */
  149. TOptimizerStatistics TBaseProviderContext::ComputeJoinStats(
  150. const TOptimizerStatistics& leftStats,
  151. const TOptimizerStatistics& rightStats,
  152. const TVector<TJoinColumn>& leftJoinKeys,
  153. const TVector<TJoinColumn>& rightJoinKeys,
  154. EJoinAlgoType joinAlgo,
  155. EJoinKind joinKind,
  156. TCardinalityHints::TCardinalityHint* maybeHint) const
  157. {
  158. double newCard{};
  159. EStatisticsType outputType;
  160. bool leftKeyColumns = false;
  161. bool rightKeyColumns = false;
  162. double selectivity = 1.0;
  163. bool isRightPKJoin = IsPKJoin(rightStats,rightJoinKeys);
  164. bool isLeftPKJoin = IsPKJoin(leftStats,leftJoinKeys);
  165. if (isRightPKJoin && isLeftPKJoin) {
  166. auto rightPKJoinCard = leftStats.Nrows * rightStats.Selectivity;
  167. auto leftPKJoinCard = rightStats.Nrows * leftStats.Selectivity;
  168. if (rightPKJoinCard > leftPKJoinCard) {
  169. isRightPKJoin = false;
  170. }
  171. }
  172. if (isRightPKJoin) {
  173. switch (joinKind) {
  174. case EJoinKind::LeftJoin:
  175. case EJoinKind::LeftOnly:
  176. newCard = leftStats.Nrows; break;
  177. default: {
  178. newCard = leftStats.Nrows * rightStats.Selectivity;
  179. }
  180. }
  181. selectivity = leftStats.Selectivity * rightStats.Selectivity;
  182. leftKeyColumns = true;
  183. if (leftStats.Type == EStatisticsType::BaseTable){
  184. outputType = EStatisticsType::FilteredFactTable;
  185. } else {
  186. outputType = leftStats.Type;
  187. }
  188. } else if (isLeftPKJoin) {
  189. switch (joinKind) {
  190. case EJoinKind::RightJoin:
  191. case EJoinKind::RightOnly:
  192. newCard = rightStats.Nrows; break;
  193. default: {
  194. newCard = leftStats.Selectivity * rightStats.Nrows;
  195. }
  196. }
  197. selectivity = leftStats.Selectivity * rightStats.Selectivity;
  198. rightKeyColumns = true;
  199. if (rightStats.Type == EStatisticsType::BaseTable){
  200. outputType = EStatisticsType::FilteredFactTable;
  201. } else {
  202. outputType = rightStats.Type;
  203. }
  204. } else {
  205. std::optional<double> lhsUniqueVals;
  206. std::optional<double> rhsUniqueVals;
  207. if (leftStats.ColumnStatistics && rightStats.ColumnStatistics && !leftJoinKeys.empty() && !rightJoinKeys.empty()) {
  208. auto lhs = leftJoinKeys[0].AttributeName;
  209. lhsUniqueVals = leftStats.ColumnStatistics->Data[lhs].NumUniqueVals;
  210. auto rhs = rightJoinKeys[0].AttributeName;
  211. rightStats.ColumnStatistics->Data[rhs];
  212. rhsUniqueVals = leftStats.ColumnStatistics->Data[lhs].NumUniqueVals;
  213. }
  214. if (lhsUniqueVals.has_value() && rhsUniqueVals.has_value()) {
  215. newCard = leftStats.Nrows * rightStats.Nrows / std::max(*lhsUniqueVals, *rhsUniqueVals);
  216. } else {
  217. newCard = 0.2 * leftStats.Nrows * rightStats.Nrows;
  218. }
  219. outputType = EStatisticsType::ManyManyJoin;
  220. }
  221. if (maybeHint) {
  222. newCard = maybeHint->ApplyHint(newCard);
  223. }
  224. int newNCols = leftStats.Ncols + rightStats.Ncols;
  225. double newByteSize = leftStats.Nrows ? (leftStats.ByteSize / leftStats.Nrows) * newCard : 0 +
  226. rightStats.Nrows ? (rightStats.ByteSize / rightStats.Nrows) * newCard : 0;
  227. double cost = ComputeJoinCost(leftStats, rightStats, newCard, newByteSize, joinAlgo)
  228. + leftStats.Cost + rightStats.Cost;
  229. auto result = TOptimizerStatistics(outputType, newCard, newNCols, newByteSize, cost,
  230. leftKeyColumns ? leftStats.KeyColumns : ( rightKeyColumns ? rightStats.KeyColumns : TIntrusivePtr<TOptimizerStatistics::TKeyColumns>()));
  231. result.Selectivity = selectivity;
  232. return result;
  233. }
  234. const TBaseProviderContext& TBaseProviderContext::Instance() {
  235. static TBaseProviderContext staticContext;
  236. return staticContext;
  237. }
  238. TVector<TString> TOptimizerHints::GetUnappliedString() {
  239. TVector<TString> res;
  240. for (const auto& hint: JoinAlgoHints->Hints) {
  241. if (!hint.Applied) {
  242. res.push_back(hint.StringRepr);
  243. }
  244. }
  245. for (const auto& hint: JoinOrderHints->Hints) {
  246. if (!hint.Applied) {
  247. res.push_back(hint.StringRepr);
  248. }
  249. }
  250. for (const auto& hint: CardinalityHints->Hints) {
  251. if (!hint.Applied) {
  252. res.push_back(hint.StringRepr);
  253. }
  254. }
  255. return res;
  256. }
  257. } // namespace NYql