SCEVValidator.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. #include "polly/Support/SCEVValidator.h"
  2. #include "polly/ScopDetection.h"
  3. #include "llvm/Analysis/RegionInfo.h"
  4. #include "llvm/Analysis/ScalarEvolution.h"
  5. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  6. #include "llvm/Support/Debug.h"
  7. using namespace llvm;
  8. using namespace polly;
  9. #define DEBUG_TYPE "polly-scev-validator"
  10. namespace SCEVType {
  11. /// The type of a SCEV
  12. ///
  13. /// To check for the validity of a SCEV we assign to each SCEV a type. The
  14. /// possible types are INT, PARAM, IV and INVALID. The order of the types is
  15. /// important. The subexpressions of SCEV with a type X can only have a type
  16. /// that is smaller or equal than X.
  17. enum TYPE {
  18. // An integer value.
  19. INT,
  20. // An expression that is constant during the execution of the Scop,
  21. // but that may depend on parameters unknown at compile time.
  22. PARAM,
  23. // An expression that may change during the execution of the SCoP.
  24. IV,
  25. // An invalid expression.
  26. INVALID
  27. };
  28. } // namespace SCEVType
  29. /// The result the validator returns for a SCEV expression.
  30. class ValidatorResult {
  31. /// The type of the expression
  32. SCEVType::TYPE Type;
  33. /// The set of Parameters in the expression.
  34. ParameterSetTy Parameters;
  35. public:
  36. /// The copy constructor
  37. ValidatorResult(const ValidatorResult &Source) {
  38. Type = Source.Type;
  39. Parameters = Source.Parameters;
  40. }
  41. /// Construct a result with a certain type and no parameters.
  42. ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
  43. assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
  44. }
  45. /// Construct a result with a certain type and a single parameter.
  46. ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
  47. Parameters.insert(Expr);
  48. }
  49. /// Get the type of the ValidatorResult.
  50. SCEVType::TYPE getType() { return Type; }
  51. /// Is the analyzed SCEV constant during the execution of the SCoP.
  52. bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
  53. /// Is the analyzed SCEV valid.
  54. bool isValid() { return Type != SCEVType::INVALID; }
  55. /// Is the analyzed SCEV of Type IV.
  56. bool isIV() { return Type == SCEVType::IV; }
  57. /// Is the analyzed SCEV of Type INT.
  58. bool isINT() { return Type == SCEVType::INT; }
  59. /// Is the analyzed SCEV of Type PARAM.
  60. bool isPARAM() { return Type == SCEVType::PARAM; }
  61. /// Get the parameters of this validator result.
  62. const ParameterSetTy &getParameters() { return Parameters; }
  63. /// Add the parameters of Source to this result.
  64. void addParamsFrom(const ValidatorResult &Source) {
  65. Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
  66. }
  67. /// Merge a result.
  68. ///
  69. /// This means to merge the parameters and to set the Type to the most
  70. /// specific Type that matches both.
  71. void merge(const ValidatorResult &ToMerge) {
  72. Type = std::max(Type, ToMerge.Type);
  73. addParamsFrom(ToMerge);
  74. }
  75. void print(raw_ostream &OS) {
  76. switch (Type) {
  77. case SCEVType::INT:
  78. OS << "SCEVType::INT";
  79. break;
  80. case SCEVType::PARAM:
  81. OS << "SCEVType::PARAM";
  82. break;
  83. case SCEVType::IV:
  84. OS << "SCEVType::IV";
  85. break;
  86. case SCEVType::INVALID:
  87. OS << "SCEVType::INVALID";
  88. break;
  89. }
  90. }
  91. };
  92. raw_ostream &operator<<(raw_ostream &OS, class ValidatorResult &VR) {
  93. VR.print(OS);
  94. return OS;
  95. }
  96. /// Check if a SCEV is valid in a SCoP.
  97. struct SCEVValidator
  98. : public SCEVVisitor<SCEVValidator, class ValidatorResult> {
  99. private:
  100. const Region *R;
  101. Loop *Scope;
  102. ScalarEvolution &SE;
  103. InvariantLoadsSetTy *ILS;
  104. public:
  105. SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
  106. InvariantLoadsSetTy *ILS)
  107. : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
  108. class ValidatorResult visitConstant(const SCEVConstant *Constant) {
  109. return ValidatorResult(SCEVType::INT);
  110. }
  111. class ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
  112. const SCEV *Operand) {
  113. ValidatorResult Op = visit(Operand);
  114. auto Type = Op.getType();
  115. // If unsigned operations are allowed return the operand, otherwise
  116. // check if we can model the expression without unsigned assumptions.
  117. if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
  118. return Op;
  119. if (Type == SCEVType::IV)
  120. return ValidatorResult(SCEVType::INVALID);
  121. return ValidatorResult(SCEVType::PARAM, Expr);
  122. }
  123. class ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
  124. return visit(Expr->getOperand());
  125. }
  126. class ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
  127. return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
  128. }
  129. class ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
  130. return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
  131. }
  132. class ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
  133. return visit(Expr->getOperand());
  134. }
  135. class ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
  136. ValidatorResult Return(SCEVType::INT);
  137. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  138. ValidatorResult Op = visit(Expr->getOperand(i));
  139. Return.merge(Op);
  140. // Early exit.
  141. if (!Return.isValid())
  142. break;
  143. }
  144. return Return;
  145. }
  146. class ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
  147. ValidatorResult Return(SCEVType::INT);
  148. bool HasMultipleParams = false;
  149. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  150. ValidatorResult Op = visit(Expr->getOperand(i));
  151. if (Op.isINT())
  152. continue;
  153. if (Op.isPARAM() && Return.isPARAM()) {
  154. HasMultipleParams = true;
  155. continue;
  156. }
  157. if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
  158. LLVM_DEBUG(
  159. dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
  160. << "\tExpr: " << *Expr << "\n"
  161. << "\tPrevious expression type: " << Return << "\n"
  162. << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
  163. << "\n");
  164. return ValidatorResult(SCEVType::INVALID);
  165. }
  166. Return.merge(Op);
  167. }
  168. if (HasMultipleParams && Return.isValid())
  169. return ValidatorResult(SCEVType::PARAM, Expr);
  170. return Return;
  171. }
  172. class ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
  173. if (!Expr->isAffine()) {
  174. LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
  175. return ValidatorResult(SCEVType::INVALID);
  176. }
  177. ValidatorResult Start = visit(Expr->getStart());
  178. ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
  179. if (!Start.isValid())
  180. return Start;
  181. if (!Recurrence.isValid())
  182. return Recurrence;
  183. auto *L = Expr->getLoop();
  184. if (R->contains(L) && (!Scope || !L->contains(Scope))) {
  185. LLVM_DEBUG(
  186. dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
  187. "non-affine subregion or has a non-synthesizable exit "
  188. "value.");
  189. return ValidatorResult(SCEVType::INVALID);
  190. }
  191. if (R->contains(L)) {
  192. if (Recurrence.isINT()) {
  193. ValidatorResult Result(SCEVType::IV);
  194. Result.addParamsFrom(Start);
  195. return Result;
  196. }
  197. LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
  198. "recurrence part");
  199. return ValidatorResult(SCEVType::INVALID);
  200. }
  201. assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
  202. // Directly generate ValidatorResult for Expr if 'start' is zero.
  203. if (Expr->getStart()->isZero())
  204. return ValidatorResult(SCEVType::PARAM, Expr);
  205. // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
  206. // if 'start' is not zero.
  207. const SCEV *ZeroStartExpr = SE.getAddRecExpr(
  208. SE.getConstant(Expr->getStart()->getType(), 0),
  209. Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
  210. ValidatorResult ZeroStartResult =
  211. ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
  212. ZeroStartResult.addParamsFrom(Start);
  213. return ZeroStartResult;
  214. }
  215. class ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
  216. ValidatorResult Return(SCEVType::INT);
  217. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  218. ValidatorResult Op = visit(Expr->getOperand(i));
  219. if (!Op.isValid())
  220. return Op;
  221. Return.merge(Op);
  222. }
  223. return Return;
  224. }
  225. class ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
  226. ValidatorResult Return(SCEVType::INT);
  227. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  228. ValidatorResult Op = visit(Expr->getOperand(i));
  229. if (!Op.isValid())
  230. return Op;
  231. Return.merge(Op);
  232. }
  233. return Return;
  234. }
  235. class ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
  236. // We do not support unsigned max operations. If 'Expr' is constant during
  237. // Scop execution we treat this as a parameter, otherwise we bail out.
  238. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  239. ValidatorResult Op = visit(Expr->getOperand(i));
  240. if (!Op.isConstant()) {
  241. LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
  242. return ValidatorResult(SCEVType::INVALID);
  243. }
  244. }
  245. return ValidatorResult(SCEVType::PARAM, Expr);
  246. }
  247. class ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
  248. // We do not support unsigned min operations. If 'Expr' is constant during
  249. // Scop execution we treat this as a parameter, otherwise we bail out.
  250. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  251. ValidatorResult Op = visit(Expr->getOperand(i));
  252. if (!Op.isConstant()) {
  253. LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
  254. return ValidatorResult(SCEVType::INVALID);
  255. }
  256. }
  257. return ValidatorResult(SCEVType::PARAM, Expr);
  258. }
  259. class ValidatorResult
  260. visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
  261. // We do not support unsigned min operations. If 'Expr' is constant during
  262. // Scop execution we treat this as a parameter, otherwise we bail out.
  263. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  264. ValidatorResult Op = visit(Expr->getOperand(i));
  265. if (!Op.isConstant()) {
  266. LLVM_DEBUG(
  267. dbgs()
  268. << "INVALID: SCEVSequentialUMinExpr has a non-constant operand");
  269. return ValidatorResult(SCEVType::INVALID);
  270. }
  271. }
  272. return ValidatorResult(SCEVType::PARAM, Expr);
  273. }
  274. ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
  275. if (R->contains(I)) {
  276. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
  277. "within the region\n");
  278. return ValidatorResult(SCEVType::INVALID);
  279. }
  280. return ValidatorResult(SCEVType::PARAM, S);
  281. }
  282. ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
  283. if (R->contains(I) && ILS) {
  284. ILS->insert(cast<LoadInst>(I));
  285. return ValidatorResult(SCEVType::PARAM, S);
  286. }
  287. return visitGenericInst(I, S);
  288. }
  289. ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
  290. const SCEV *DivExpr,
  291. Instruction *SDiv = nullptr) {
  292. // First check if we might be able to model the division, thus if the
  293. // divisor is constant. If so, check the dividend, otherwise check if
  294. // the whole division can be seen as a parameter.
  295. if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
  296. return visit(Dividend);
  297. // For signed divisions use the SDiv instruction to check for a parameter
  298. // division, for unsigned divisions check the operands.
  299. if (SDiv)
  300. return visitGenericInst(SDiv, DivExpr);
  301. ValidatorResult LHS = visit(Dividend);
  302. ValidatorResult RHS = visit(Divisor);
  303. if (LHS.isConstant() && RHS.isConstant())
  304. return ValidatorResult(SCEVType::PARAM, DivExpr);
  305. LLVM_DEBUG(
  306. dbgs() << "INVALID: unsigned division of non-constant expressions");
  307. return ValidatorResult(SCEVType::INVALID);
  308. }
  309. ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
  310. if (!PollyAllowUnsignedOperations)
  311. return ValidatorResult(SCEVType::INVALID);
  312. auto *Dividend = Expr->getLHS();
  313. auto *Divisor = Expr->getRHS();
  314. return visitDivision(Dividend, Divisor, Expr);
  315. }
  316. ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
  317. assert(SDiv->getOpcode() == Instruction::SDiv &&
  318. "Assumed SDiv instruction!");
  319. auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
  320. auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
  321. return visitDivision(Dividend, Divisor, Expr, SDiv);
  322. }
  323. ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
  324. assert(SRem->getOpcode() == Instruction::SRem &&
  325. "Assumed SRem instruction!");
  326. auto *Divisor = SRem->getOperand(1);
  327. auto *CI = dyn_cast<ConstantInt>(Divisor);
  328. if (!CI || CI->isZeroValue())
  329. return visitGenericInst(SRem, S);
  330. auto *Dividend = SRem->getOperand(0);
  331. auto *DividendSCEV = SE.getSCEV(Dividend);
  332. return visit(DividendSCEV);
  333. }
  334. ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
  335. Value *V = Expr->getValue();
  336. if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
  337. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
  338. return ValidatorResult(SCEVType::INVALID);
  339. }
  340. if (isa<UndefValue>(V)) {
  341. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
  342. return ValidatorResult(SCEVType::INVALID);
  343. }
  344. if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
  345. switch (I->getOpcode()) {
  346. case Instruction::IntToPtr:
  347. return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
  348. case Instruction::Load:
  349. return visitLoadInstruction(I, Expr);
  350. case Instruction::SDiv:
  351. return visitSDivInstruction(I, Expr);
  352. case Instruction::SRem:
  353. return visitSRemInstruction(I, Expr);
  354. default:
  355. return visitGenericInst(I, Expr);
  356. }
  357. }
  358. if (Expr->getType()->isPointerTy()) {
  359. if (isa<ConstantPointerNull>(V))
  360. return ValidatorResult(SCEVType::INT); // "int"
  361. }
  362. return ValidatorResult(SCEVType::PARAM, Expr);
  363. }
  364. };
  365. /// Check whether a SCEV refers to an SSA name defined inside a region.
  366. class SCEVInRegionDependences {
  367. const Region *R;
  368. Loop *Scope;
  369. const InvariantLoadsSetTy &ILS;
  370. bool AllowLoops;
  371. bool HasInRegionDeps = false;
  372. public:
  373. SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
  374. const InvariantLoadsSetTy &ILS)
  375. : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
  376. bool follow(const SCEV *S) {
  377. if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
  378. Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
  379. if (Inst) {
  380. // When we invariant load hoist a load, we first make sure that there
  381. // can be no dependences created by it in the Scop region. So, we should
  382. // not consider scalar dependences to `LoadInst`s that are invariant
  383. // load hoisted.
  384. //
  385. // If this check is not present, then we create data dependences which
  386. // are strictly not necessary by tracking the invariant load as a
  387. // scalar.
  388. LoadInst *LI = dyn_cast<LoadInst>(Inst);
  389. if (LI && ILS.contains(LI))
  390. return false;
  391. }
  392. // Return true when Inst is defined inside the region R.
  393. if (!Inst || !R->contains(Inst))
  394. return true;
  395. HasInRegionDeps = true;
  396. return false;
  397. }
  398. if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
  399. if (AllowLoops)
  400. return true;
  401. auto *L = AddRec->getLoop();
  402. if (R->contains(L) && !L->contains(Scope)) {
  403. HasInRegionDeps = true;
  404. return false;
  405. }
  406. }
  407. return true;
  408. }
  409. bool isDone() { return false; }
  410. bool hasDependences() { return HasInRegionDeps; }
  411. };
  412. namespace polly {
  413. /// Find all loops referenced in SCEVAddRecExprs.
  414. class SCEVFindLoops {
  415. SetVector<const Loop *> &Loops;
  416. public:
  417. SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
  418. bool follow(const SCEV *S) {
  419. if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
  420. Loops.insert(AddRec->getLoop());
  421. return true;
  422. }
  423. bool isDone() { return false; }
  424. };
  425. void findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
  426. SCEVFindLoops FindLoops(Loops);
  427. SCEVTraversal<SCEVFindLoops> ST(FindLoops);
  428. ST.visitAll(Expr);
  429. }
  430. /// Find all values referenced in SCEVUnknowns.
  431. class SCEVFindValues {
  432. ScalarEvolution &SE;
  433. SetVector<Value *> &Values;
  434. public:
  435. SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
  436. : SE(SE), Values(Values) {}
  437. bool follow(const SCEV *S) {
  438. const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
  439. if (!Unknown)
  440. return true;
  441. Values.insert(Unknown->getValue());
  442. Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
  443. if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
  444. Inst->getOpcode() != Instruction::SDiv))
  445. return false;
  446. auto *Dividend = SE.getSCEV(Inst->getOperand(1));
  447. if (!isa<SCEVConstant>(Dividend))
  448. return false;
  449. auto *Divisor = SE.getSCEV(Inst->getOperand(0));
  450. SCEVFindValues FindValues(SE, Values);
  451. SCEVTraversal<SCEVFindValues> ST(FindValues);
  452. ST.visitAll(Dividend);
  453. ST.visitAll(Divisor);
  454. return false;
  455. }
  456. bool isDone() { return false; }
  457. };
  458. void findValues(const SCEV *Expr, ScalarEvolution &SE,
  459. SetVector<Value *> &Values) {
  460. SCEVFindValues FindValues(SE, Values);
  461. SCEVTraversal<SCEVFindValues> ST(FindValues);
  462. ST.visitAll(Expr);
  463. }
  464. bool hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
  465. llvm::Loop *Scope, bool AllowLoops,
  466. const InvariantLoadsSetTy &ILS) {
  467. SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
  468. SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
  469. ST.visitAll(Expr);
  470. return InRegionDeps.hasDependences();
  471. }
  472. bool isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
  473. ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
  474. if (isa<SCEVCouldNotCompute>(Expr))
  475. return false;
  476. SCEVValidator Validator(R, Scope, SE, ILS);
  477. LLVM_DEBUG({
  478. dbgs() << "\n";
  479. dbgs() << "Expr: " << *Expr << "\n";
  480. dbgs() << "Region: " << R->getNameStr() << "\n";
  481. dbgs() << " -> ";
  482. });
  483. ValidatorResult Result = Validator.visit(Expr);
  484. LLVM_DEBUG({
  485. if (Result.isValid())
  486. dbgs() << "VALID\n";
  487. dbgs() << "\n";
  488. });
  489. return Result.isValid();
  490. }
  491. static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
  492. ScalarEvolution &SE, ParameterSetTy &Params) {
  493. auto *E = SE.getSCEV(V);
  494. if (isa<SCEVCouldNotCompute>(E))
  495. return false;
  496. SCEVValidator Validator(R, Scope, SE, nullptr);
  497. ValidatorResult Result = Validator.visit(E);
  498. if (!Result.isValid())
  499. return false;
  500. auto ResultParams = Result.getParameters();
  501. Params.insert(ResultParams.begin(), ResultParams.end());
  502. return true;
  503. }
  504. bool isAffineConstraint(Value *V, const Region *R, llvm::Loop *Scope,
  505. ScalarEvolution &SE, ParameterSetTy &Params,
  506. bool OrExpr) {
  507. if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
  508. return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
  509. true) &&
  510. isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
  511. } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
  512. auto Opcode = BinOp->getOpcode();
  513. if (Opcode == Instruction::And || Opcode == Instruction::Or)
  514. return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
  515. false) &&
  516. isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
  517. false);
  518. /* Fall through */
  519. }
  520. if (!OrExpr)
  521. return false;
  522. return isAffineExpr(V, R, Scope, SE, Params);
  523. }
  524. ParameterSetTy getParamsInAffineExpr(const Region *R, Loop *Scope,
  525. const SCEV *Expr, ScalarEvolution &SE) {
  526. if (isa<SCEVCouldNotCompute>(Expr))
  527. return ParameterSetTy();
  528. InvariantLoadsSetTy ILS;
  529. SCEVValidator Validator(R, Scope, SE, &ILS);
  530. ValidatorResult Result = Validator.visit(Expr);
  531. assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
  532. return Result.getParameters();
  533. }
  534. std::pair<const SCEVConstant *, const SCEV *>
  535. extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
  536. auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
  537. if (auto *Constant = dyn_cast<SCEVConstant>(S))
  538. return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
  539. auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
  540. if (AddRec) {
  541. auto *StartExpr = AddRec->getStart();
  542. if (StartExpr->isZero()) {
  543. auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
  544. auto *LeftOverAddRec =
  545. SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
  546. AddRec->getNoWrapFlags());
  547. return std::make_pair(StepPair.first, LeftOverAddRec);
  548. }
  549. return std::make_pair(ConstPart, S);
  550. }
  551. if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
  552. SmallVector<const SCEV *, 4> LeftOvers;
  553. auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
  554. auto *Factor = Op0Pair.first;
  555. if (SE.isKnownNegative(Factor)) {
  556. Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
  557. LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
  558. } else {
  559. LeftOvers.push_back(Op0Pair.second);
  560. }
  561. for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
  562. auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
  563. // TODO: Use something smarter than equality here, e.g., gcd.
  564. if (Factor == OpUPair.first)
  565. LeftOvers.push_back(OpUPair.second);
  566. else if (Factor == SE.getNegativeSCEV(OpUPair.first))
  567. LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
  568. else
  569. return std::make_pair(ConstPart, S);
  570. }
  571. auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
  572. return std::make_pair(Factor, NewAdd);
  573. }
  574. auto *Mul = dyn_cast<SCEVMulExpr>(S);
  575. if (!Mul)
  576. return std::make_pair(ConstPart, S);
  577. SmallVector<const SCEV *, 4> LeftOvers;
  578. for (auto *Op : Mul->operands())
  579. if (isa<SCEVConstant>(Op))
  580. ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
  581. else
  582. LeftOvers.push_back(Op);
  583. return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
  584. }
  585. const SCEV *tryForwardThroughPHI(const SCEV *Expr, Region &R,
  586. ScalarEvolution &SE, ScopDetection *SD) {
  587. if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
  588. Value *V = Unknown->getValue();
  589. auto *PHI = dyn_cast<PHINode>(V);
  590. if (!PHI)
  591. return Expr;
  592. Value *Final = nullptr;
  593. for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
  594. BasicBlock *Incoming = PHI->getIncomingBlock(i);
  595. if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming))
  596. continue;
  597. if (Final)
  598. return Expr;
  599. Final = PHI->getIncomingValue(i);
  600. }
  601. if (Final)
  602. return SE.getSCEV(Final);
  603. }
  604. return Expr;
  605. }
  606. Value *getUniqueNonErrorValue(PHINode *PHI, Region *R, ScopDetection *SD) {
  607. Value *V = nullptr;
  608. for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
  609. BasicBlock *BB = PHI->getIncomingBlock(i);
  610. if (!SD->isErrorBlock(*BB, *R)) {
  611. if (V)
  612. return nullptr;
  613. V = PHI->getIncomingValue(i);
  614. }
  615. }
  616. return V;
  617. }
  618. } // namespace polly