SCEVValidator.cpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769
  1. #include "polly/Support/SCEVValidator.h"
  2. #include "polly/ScopDetection.h"
  3. #include "llvm/Analysis/RegionInfo.h"
  4. #include "llvm/Analysis/ScalarEvolution.h"
  5. #include "llvm/Analysis/ScalarEvolutionExpressions.h"
  6. #include "llvm/Support/Debug.h"
  7. using namespace llvm;
  8. using namespace polly;
  9. #define DEBUG_TYPE "polly-scev-validator"
  10. namespace SCEVType {
  11. /// The type of a SCEV
  12. ///
  13. /// To check for the validity of a SCEV we assign to each SCEV a type. The
  14. /// possible types are INT, PARAM, IV and INVALID. The order of the types is
  15. /// important. The subexpressions of SCEV with a type X can only have a type
  16. /// that is smaller or equal than X.
  17. enum TYPE {
  18. // An integer value.
  19. INT,
  20. // An expression that is constant during the execution of the Scop,
  21. // but that may depend on parameters unknown at compile time.
  22. PARAM,
  23. // An expression that may change during the execution of the SCoP.
  24. IV,
  25. // An invalid expression.
  26. INVALID
  27. };
  28. } // namespace SCEVType
  29. /// The result the validator returns for a SCEV expression.
  30. class ValidatorResult final {
  31. /// The type of the expression
  32. SCEVType::TYPE Type;
  33. /// The set of Parameters in the expression.
  34. ParameterSetTy Parameters;
  35. public:
  36. /// The copy constructor
  37. ValidatorResult(const ValidatorResult &Source) {
  38. Type = Source.Type;
  39. Parameters = Source.Parameters;
  40. }
  41. /// Construct a result with a certain type and no parameters.
  42. ValidatorResult(SCEVType::TYPE Type) : Type(Type) {
  43. assert(Type != SCEVType::PARAM && "Did you forget to pass the parameter");
  44. }
  45. /// Construct a result with a certain type and a single parameter.
  46. ValidatorResult(SCEVType::TYPE Type, const SCEV *Expr) : Type(Type) {
  47. Parameters.insert(Expr);
  48. }
  49. /// Get the type of the ValidatorResult.
  50. SCEVType::TYPE getType() { return Type; }
  51. /// Is the analyzed SCEV constant during the execution of the SCoP.
  52. bool isConstant() { return Type == SCEVType::INT || Type == SCEVType::PARAM; }
  53. /// Is the analyzed SCEV valid.
  54. bool isValid() { return Type != SCEVType::INVALID; }
  55. /// Is the analyzed SCEV of Type IV.
  56. bool isIV() { return Type == SCEVType::IV; }
  57. /// Is the analyzed SCEV of Type INT.
  58. bool isINT() { return Type == SCEVType::INT; }
  59. /// Is the analyzed SCEV of Type PARAM.
  60. bool isPARAM() { return Type == SCEVType::PARAM; }
  61. /// Get the parameters of this validator result.
  62. const ParameterSetTy &getParameters() { return Parameters; }
  63. /// Add the parameters of Source to this result.
  64. void addParamsFrom(const ValidatorResult &Source) {
  65. Parameters.insert(Source.Parameters.begin(), Source.Parameters.end());
  66. }
  67. /// Merge a result.
  68. ///
  69. /// This means to merge the parameters and to set the Type to the most
  70. /// specific Type that matches both.
  71. void merge(const ValidatorResult &ToMerge) {
  72. Type = std::max(Type, ToMerge.Type);
  73. addParamsFrom(ToMerge);
  74. }
  75. void print(raw_ostream &OS) {
  76. switch (Type) {
  77. case SCEVType::INT:
  78. OS << "SCEVType::INT";
  79. break;
  80. case SCEVType::PARAM:
  81. OS << "SCEVType::PARAM";
  82. break;
  83. case SCEVType::IV:
  84. OS << "SCEVType::IV";
  85. break;
  86. case SCEVType::INVALID:
  87. OS << "SCEVType::INVALID";
  88. break;
  89. }
  90. }
  91. };
  92. raw_ostream &operator<<(raw_ostream &OS, ValidatorResult &VR) {
  93. VR.print(OS);
  94. return OS;
  95. }
  96. /// Check if a SCEV is valid in a SCoP.
  97. class SCEVValidator : public SCEVVisitor<SCEVValidator, ValidatorResult> {
  98. private:
  99. const Region *R;
  100. Loop *Scope;
  101. ScalarEvolution &SE;
  102. InvariantLoadsSetTy *ILS;
  103. public:
  104. SCEVValidator(const Region *R, Loop *Scope, ScalarEvolution &SE,
  105. InvariantLoadsSetTy *ILS)
  106. : R(R), Scope(Scope), SE(SE), ILS(ILS) {}
  107. ValidatorResult visitConstant(const SCEVConstant *Constant) {
  108. return ValidatorResult(SCEVType::INT);
  109. }
  110. ValidatorResult visitZeroExtendOrTruncateExpr(const SCEV *Expr,
  111. const SCEV *Operand) {
  112. ValidatorResult Op = visit(Operand);
  113. auto Type = Op.getType();
  114. // If unsigned operations are allowed return the operand, otherwise
  115. // check if we can model the expression without unsigned assumptions.
  116. if (PollyAllowUnsignedOperations || Type == SCEVType::INVALID)
  117. return Op;
  118. if (Type == SCEVType::IV)
  119. return ValidatorResult(SCEVType::INVALID);
  120. return ValidatorResult(SCEVType::PARAM, Expr);
  121. }
  122. ValidatorResult visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
  123. return visit(Expr->getOperand());
  124. }
  125. ValidatorResult visitTruncateExpr(const SCEVTruncateExpr *Expr) {
  126. return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
  127. }
  128. ValidatorResult visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
  129. return visitZeroExtendOrTruncateExpr(Expr, Expr->getOperand());
  130. }
  131. ValidatorResult visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
  132. return visit(Expr->getOperand());
  133. }
  134. ValidatorResult visitAddExpr(const SCEVAddExpr *Expr) {
  135. ValidatorResult Return(SCEVType::INT);
  136. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  137. ValidatorResult Op = visit(Expr->getOperand(i));
  138. Return.merge(Op);
  139. // Early exit.
  140. if (!Return.isValid())
  141. break;
  142. }
  143. return Return;
  144. }
  145. ValidatorResult visitMulExpr(const SCEVMulExpr *Expr) {
  146. ValidatorResult Return(SCEVType::INT);
  147. bool HasMultipleParams = false;
  148. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  149. ValidatorResult Op = visit(Expr->getOperand(i));
  150. if (Op.isINT())
  151. continue;
  152. if (Op.isPARAM() && Return.isPARAM()) {
  153. HasMultipleParams = true;
  154. continue;
  155. }
  156. if ((Op.isIV() || Op.isPARAM()) && !Return.isINT()) {
  157. LLVM_DEBUG(
  158. dbgs() << "INVALID: More than one non-int operand in MulExpr\n"
  159. << "\tExpr: " << *Expr << "\n"
  160. << "\tPrevious expression type: " << Return << "\n"
  161. << "\tNext operand (" << Op << "): " << *Expr->getOperand(i)
  162. << "\n");
  163. return ValidatorResult(SCEVType::INVALID);
  164. }
  165. Return.merge(Op);
  166. }
  167. if (HasMultipleParams && Return.isValid())
  168. return ValidatorResult(SCEVType::PARAM, Expr);
  169. return Return;
  170. }
  171. ValidatorResult visitAddRecExpr(const SCEVAddRecExpr *Expr) {
  172. if (!Expr->isAffine()) {
  173. LLVM_DEBUG(dbgs() << "INVALID: AddRec is not affine");
  174. return ValidatorResult(SCEVType::INVALID);
  175. }
  176. ValidatorResult Start = visit(Expr->getStart());
  177. ValidatorResult Recurrence = visit(Expr->getStepRecurrence(SE));
  178. if (!Start.isValid())
  179. return Start;
  180. if (!Recurrence.isValid())
  181. return Recurrence;
  182. auto *L = Expr->getLoop();
  183. if (R->contains(L) && (!Scope || !L->contains(Scope))) {
  184. LLVM_DEBUG(
  185. dbgs() << "INVALID: Loop of AddRec expression boxed in an a "
  186. "non-affine subregion or has a non-synthesizable exit "
  187. "value.");
  188. return ValidatorResult(SCEVType::INVALID);
  189. }
  190. if (R->contains(L)) {
  191. if (Recurrence.isINT()) {
  192. ValidatorResult Result(SCEVType::IV);
  193. Result.addParamsFrom(Start);
  194. return Result;
  195. }
  196. LLVM_DEBUG(dbgs() << "INVALID: AddRec within scop has non-int"
  197. "recurrence part");
  198. return ValidatorResult(SCEVType::INVALID);
  199. }
  200. assert(Recurrence.isConstant() && "Expected 'Recurrence' to be constant");
  201. // Directly generate ValidatorResult for Expr if 'start' is zero.
  202. if (Expr->getStart()->isZero())
  203. return ValidatorResult(SCEVType::PARAM, Expr);
  204. // Translate AddRecExpr from '{start, +, inc}' into 'start + {0, +, inc}'
  205. // if 'start' is not zero.
  206. const SCEV *ZeroStartExpr = SE.getAddRecExpr(
  207. SE.getConstant(Expr->getStart()->getType(), 0),
  208. Expr->getStepRecurrence(SE), Expr->getLoop(), Expr->getNoWrapFlags());
  209. ValidatorResult ZeroStartResult =
  210. ValidatorResult(SCEVType::PARAM, ZeroStartExpr);
  211. ZeroStartResult.addParamsFrom(Start);
  212. return ZeroStartResult;
  213. }
  214. ValidatorResult visitSMaxExpr(const SCEVSMaxExpr *Expr) {
  215. ValidatorResult Return(SCEVType::INT);
  216. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  217. ValidatorResult Op = visit(Expr->getOperand(i));
  218. if (!Op.isValid())
  219. return Op;
  220. Return.merge(Op);
  221. }
  222. return Return;
  223. }
  224. ValidatorResult visitSMinExpr(const SCEVSMinExpr *Expr) {
  225. ValidatorResult Return(SCEVType::INT);
  226. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  227. ValidatorResult Op = visit(Expr->getOperand(i));
  228. if (!Op.isValid())
  229. return Op;
  230. Return.merge(Op);
  231. }
  232. return Return;
  233. }
  234. ValidatorResult visitUMaxExpr(const SCEVUMaxExpr *Expr) {
  235. // We do not support unsigned max operations. If 'Expr' is constant during
  236. // Scop execution we treat this as a parameter, otherwise we bail out.
  237. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  238. ValidatorResult Op = visit(Expr->getOperand(i));
  239. if (!Op.isConstant()) {
  240. LLVM_DEBUG(dbgs() << "INVALID: UMaxExpr has a non-constant operand");
  241. return ValidatorResult(SCEVType::INVALID);
  242. }
  243. }
  244. return ValidatorResult(SCEVType::PARAM, Expr);
  245. }
  246. ValidatorResult visitUMinExpr(const SCEVUMinExpr *Expr) {
  247. // We do not support unsigned min operations. If 'Expr' is constant during
  248. // Scop execution we treat this as a parameter, otherwise we bail out.
  249. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  250. ValidatorResult Op = visit(Expr->getOperand(i));
  251. if (!Op.isConstant()) {
  252. LLVM_DEBUG(dbgs() << "INVALID: UMinExpr has a non-constant operand");
  253. return ValidatorResult(SCEVType::INVALID);
  254. }
  255. }
  256. return ValidatorResult(SCEVType::PARAM, Expr);
  257. }
  258. ValidatorResult visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
  259. // We do not support unsigned min operations. If 'Expr' is constant during
  260. // Scop execution we treat this as a parameter, otherwise we bail out.
  261. for (int i = 0, e = Expr->getNumOperands(); i < e; ++i) {
  262. ValidatorResult Op = visit(Expr->getOperand(i));
  263. if (!Op.isConstant()) {
  264. LLVM_DEBUG(
  265. dbgs()
  266. << "INVALID: SCEVSequentialUMinExpr has a non-constant operand");
  267. return ValidatorResult(SCEVType::INVALID);
  268. }
  269. }
  270. return ValidatorResult(SCEVType::PARAM, Expr);
  271. }
  272. ValidatorResult visitGenericInst(Instruction *I, const SCEV *S) {
  273. if (R->contains(I)) {
  274. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an instruction "
  275. "within the region\n");
  276. return ValidatorResult(SCEVType::INVALID);
  277. }
  278. return ValidatorResult(SCEVType::PARAM, S);
  279. }
  280. ValidatorResult visitLoadInstruction(Instruction *I, const SCEV *S) {
  281. if (R->contains(I) && ILS) {
  282. ILS->insert(cast<LoadInst>(I));
  283. return ValidatorResult(SCEVType::PARAM, S);
  284. }
  285. return visitGenericInst(I, S);
  286. }
  287. ValidatorResult visitDivision(const SCEV *Dividend, const SCEV *Divisor,
  288. const SCEV *DivExpr,
  289. Instruction *SDiv = nullptr) {
  290. // First check if we might be able to model the division, thus if the
  291. // divisor is constant. If so, check the dividend, otherwise check if
  292. // the whole division can be seen as a parameter.
  293. if (isa<SCEVConstant>(Divisor) && !Divisor->isZero())
  294. return visit(Dividend);
  295. // For signed divisions use the SDiv instruction to check for a parameter
  296. // division, for unsigned divisions check the operands.
  297. if (SDiv)
  298. return visitGenericInst(SDiv, DivExpr);
  299. ValidatorResult LHS = visit(Dividend);
  300. ValidatorResult RHS = visit(Divisor);
  301. if (LHS.isConstant() && RHS.isConstant())
  302. return ValidatorResult(SCEVType::PARAM, DivExpr);
  303. LLVM_DEBUG(
  304. dbgs() << "INVALID: unsigned division of non-constant expressions");
  305. return ValidatorResult(SCEVType::INVALID);
  306. }
  307. ValidatorResult visitUDivExpr(const SCEVUDivExpr *Expr) {
  308. if (!PollyAllowUnsignedOperations)
  309. return ValidatorResult(SCEVType::INVALID);
  310. auto *Dividend = Expr->getLHS();
  311. auto *Divisor = Expr->getRHS();
  312. return visitDivision(Dividend, Divisor, Expr);
  313. }
  314. ValidatorResult visitSDivInstruction(Instruction *SDiv, const SCEV *Expr) {
  315. assert(SDiv->getOpcode() == Instruction::SDiv &&
  316. "Assumed SDiv instruction!");
  317. auto *Dividend = SE.getSCEV(SDiv->getOperand(0));
  318. auto *Divisor = SE.getSCEV(SDiv->getOperand(1));
  319. return visitDivision(Dividend, Divisor, Expr, SDiv);
  320. }
  321. ValidatorResult visitSRemInstruction(Instruction *SRem, const SCEV *S) {
  322. assert(SRem->getOpcode() == Instruction::SRem &&
  323. "Assumed SRem instruction!");
  324. auto *Divisor = SRem->getOperand(1);
  325. auto *CI = dyn_cast<ConstantInt>(Divisor);
  326. if (!CI || CI->isZeroValue())
  327. return visitGenericInst(SRem, S);
  328. auto *Dividend = SRem->getOperand(0);
  329. auto *DividendSCEV = SE.getSCEV(Dividend);
  330. return visit(DividendSCEV);
  331. }
  332. ValidatorResult visitUnknown(const SCEVUnknown *Expr) {
  333. Value *V = Expr->getValue();
  334. if (!Expr->getType()->isIntegerTy() && !Expr->getType()->isPointerTy()) {
  335. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr is not an integer or pointer");
  336. return ValidatorResult(SCEVType::INVALID);
  337. }
  338. if (isa<UndefValue>(V)) {
  339. LLVM_DEBUG(dbgs() << "INVALID: UnknownExpr references an undef value");
  340. return ValidatorResult(SCEVType::INVALID);
  341. }
  342. if (Instruction *I = dyn_cast<Instruction>(Expr->getValue())) {
  343. switch (I->getOpcode()) {
  344. case Instruction::IntToPtr:
  345. return visit(SE.getSCEVAtScope(I->getOperand(0), Scope));
  346. case Instruction::Load:
  347. return visitLoadInstruction(I, Expr);
  348. case Instruction::SDiv:
  349. return visitSDivInstruction(I, Expr);
  350. case Instruction::SRem:
  351. return visitSRemInstruction(I, Expr);
  352. default:
  353. return visitGenericInst(I, Expr);
  354. }
  355. }
  356. if (Expr->getType()->isPointerTy()) {
  357. if (isa<ConstantPointerNull>(V))
  358. return ValidatorResult(SCEVType::INT); // "int"
  359. }
  360. return ValidatorResult(SCEVType::PARAM, Expr);
  361. }
  362. };
  363. /// Check whether a SCEV refers to an SSA name defined inside a region.
  364. class SCEVInRegionDependences final {
  365. const Region *R;
  366. Loop *Scope;
  367. const InvariantLoadsSetTy &ILS;
  368. bool AllowLoops;
  369. bool HasInRegionDeps = false;
  370. public:
  371. SCEVInRegionDependences(const Region *R, Loop *Scope, bool AllowLoops,
  372. const InvariantLoadsSetTy &ILS)
  373. : R(R), Scope(Scope), ILS(ILS), AllowLoops(AllowLoops) {}
  374. bool follow(const SCEV *S) {
  375. if (auto Unknown = dyn_cast<SCEVUnknown>(S)) {
  376. Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
  377. if (Inst) {
  378. // When we invariant load hoist a load, we first make sure that there
  379. // can be no dependences created by it in the Scop region. So, we should
  380. // not consider scalar dependences to `LoadInst`s that are invariant
  381. // load hoisted.
  382. //
  383. // If this check is not present, then we create data dependences which
  384. // are strictly not necessary by tracking the invariant load as a
  385. // scalar.
  386. LoadInst *LI = dyn_cast<LoadInst>(Inst);
  387. if (LI && ILS.contains(LI))
  388. return false;
  389. }
  390. // Return true when Inst is defined inside the region R.
  391. if (!Inst || !R->contains(Inst))
  392. return true;
  393. HasInRegionDeps = true;
  394. return false;
  395. }
  396. if (auto AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
  397. if (AllowLoops)
  398. return true;
  399. auto *L = AddRec->getLoop();
  400. if (R->contains(L) && !L->contains(Scope)) {
  401. HasInRegionDeps = true;
  402. return false;
  403. }
  404. }
  405. return true;
  406. }
  407. bool isDone() { return false; }
  408. bool hasDependences() { return HasInRegionDeps; }
  409. };
  410. /// Find all loops referenced in SCEVAddRecExprs.
  411. class SCEVFindLoops final {
  412. SetVector<const Loop *> &Loops;
  413. public:
  414. SCEVFindLoops(SetVector<const Loop *> &Loops) : Loops(Loops) {}
  415. bool follow(const SCEV *S) {
  416. if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S))
  417. Loops.insert(AddRec->getLoop());
  418. return true;
  419. }
  420. bool isDone() { return false; }
  421. };
  422. void polly::findLoops(const SCEV *Expr, SetVector<const Loop *> &Loops) {
  423. SCEVFindLoops FindLoops(Loops);
  424. SCEVTraversal<SCEVFindLoops> ST(FindLoops);
  425. ST.visitAll(Expr);
  426. }
  427. /// Find all values referenced in SCEVUnknowns.
  428. class SCEVFindValues final {
  429. ScalarEvolution &SE;
  430. SetVector<Value *> &Values;
  431. public:
  432. SCEVFindValues(ScalarEvolution &SE, SetVector<Value *> &Values)
  433. : SE(SE), Values(Values) {}
  434. bool follow(const SCEV *S) {
  435. const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(S);
  436. if (!Unknown)
  437. return true;
  438. Values.insert(Unknown->getValue());
  439. Instruction *Inst = dyn_cast<Instruction>(Unknown->getValue());
  440. if (!Inst || (Inst->getOpcode() != Instruction::SRem &&
  441. Inst->getOpcode() != Instruction::SDiv))
  442. return false;
  443. auto *Dividend = SE.getSCEV(Inst->getOperand(1));
  444. if (!isa<SCEVConstant>(Dividend))
  445. return false;
  446. auto *Divisor = SE.getSCEV(Inst->getOperand(0));
  447. SCEVFindValues FindValues(SE, Values);
  448. SCEVTraversal<SCEVFindValues> ST(FindValues);
  449. ST.visitAll(Dividend);
  450. ST.visitAll(Divisor);
  451. return false;
  452. }
  453. bool isDone() { return false; }
  454. };
  455. void polly::findValues(const SCEV *Expr, ScalarEvolution &SE,
  456. SetVector<Value *> &Values) {
  457. SCEVFindValues FindValues(SE, Values);
  458. SCEVTraversal<SCEVFindValues> ST(FindValues);
  459. ST.visitAll(Expr);
  460. }
  461. bool polly::hasScalarDepsInsideRegion(const SCEV *Expr, const Region *R,
  462. llvm::Loop *Scope, bool AllowLoops,
  463. const InvariantLoadsSetTy &ILS) {
  464. SCEVInRegionDependences InRegionDeps(R, Scope, AllowLoops, ILS);
  465. SCEVTraversal<SCEVInRegionDependences> ST(InRegionDeps);
  466. ST.visitAll(Expr);
  467. return InRegionDeps.hasDependences();
  468. }
  469. bool polly::isAffineExpr(const Region *R, llvm::Loop *Scope, const SCEV *Expr,
  470. ScalarEvolution &SE, InvariantLoadsSetTy *ILS) {
  471. if (isa<SCEVCouldNotCompute>(Expr))
  472. return false;
  473. SCEVValidator Validator(R, Scope, SE, ILS);
  474. LLVM_DEBUG({
  475. dbgs() << "\n";
  476. dbgs() << "Expr: " << *Expr << "\n";
  477. dbgs() << "Region: " << R->getNameStr() << "\n";
  478. dbgs() << " -> ";
  479. });
  480. ValidatorResult Result = Validator.visit(Expr);
  481. LLVM_DEBUG({
  482. if (Result.isValid())
  483. dbgs() << "VALID\n";
  484. dbgs() << "\n";
  485. });
  486. return Result.isValid();
  487. }
  488. static bool isAffineExpr(Value *V, const Region *R, Loop *Scope,
  489. ScalarEvolution &SE, ParameterSetTy &Params) {
  490. auto *E = SE.getSCEV(V);
  491. if (isa<SCEVCouldNotCompute>(E))
  492. return false;
  493. SCEVValidator Validator(R, Scope, SE, nullptr);
  494. ValidatorResult Result = Validator.visit(E);
  495. if (!Result.isValid())
  496. return false;
  497. auto ResultParams = Result.getParameters();
  498. Params.insert(ResultParams.begin(), ResultParams.end());
  499. return true;
  500. }
  501. bool polly::isAffineConstraint(Value *V, const Region *R, Loop *Scope,
  502. ScalarEvolution &SE, ParameterSetTy &Params,
  503. bool OrExpr) {
  504. if (auto *ICmp = dyn_cast<ICmpInst>(V)) {
  505. return isAffineConstraint(ICmp->getOperand(0), R, Scope, SE, Params,
  506. true) &&
  507. isAffineConstraint(ICmp->getOperand(1), R, Scope, SE, Params, true);
  508. } else if (auto *BinOp = dyn_cast<BinaryOperator>(V)) {
  509. auto Opcode = BinOp->getOpcode();
  510. if (Opcode == Instruction::And || Opcode == Instruction::Or)
  511. return isAffineConstraint(BinOp->getOperand(0), R, Scope, SE, Params,
  512. false) &&
  513. isAffineConstraint(BinOp->getOperand(1), R, Scope, SE, Params,
  514. false);
  515. /* Fall through */
  516. }
  517. if (!OrExpr)
  518. return false;
  519. return ::isAffineExpr(V, R, Scope, SE, Params);
  520. }
  521. ParameterSetTy polly::getParamsInAffineExpr(const Region *R, Loop *Scope,
  522. const SCEV *Expr,
  523. ScalarEvolution &SE) {
  524. if (isa<SCEVCouldNotCompute>(Expr))
  525. return ParameterSetTy();
  526. InvariantLoadsSetTy ILS;
  527. SCEVValidator Validator(R, Scope, SE, &ILS);
  528. ValidatorResult Result = Validator.visit(Expr);
  529. assert(Result.isValid() && "Requested parameters for an invalid SCEV!");
  530. return Result.getParameters();
  531. }
  532. std::pair<const SCEVConstant *, const SCEV *>
  533. polly::extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
  534. auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
  535. if (auto *Constant = dyn_cast<SCEVConstant>(S))
  536. return std::make_pair(Constant, SE.getConstant(S->getType(), 1));
  537. auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
  538. if (AddRec) {
  539. auto *StartExpr = AddRec->getStart();
  540. if (StartExpr->isZero()) {
  541. auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
  542. auto *LeftOverAddRec =
  543. SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
  544. AddRec->getNoWrapFlags());
  545. return std::make_pair(StepPair.first, LeftOverAddRec);
  546. }
  547. return std::make_pair(ConstPart, S);
  548. }
  549. if (auto *Add = dyn_cast<SCEVAddExpr>(S)) {
  550. SmallVector<const SCEV *, 4> LeftOvers;
  551. auto Op0Pair = extractConstantFactor(Add->getOperand(0), SE);
  552. auto *Factor = Op0Pair.first;
  553. if (SE.isKnownNegative(Factor)) {
  554. Factor = cast<SCEVConstant>(SE.getNegativeSCEV(Factor));
  555. LeftOvers.push_back(SE.getNegativeSCEV(Op0Pair.second));
  556. } else {
  557. LeftOvers.push_back(Op0Pair.second);
  558. }
  559. for (unsigned u = 1, e = Add->getNumOperands(); u < e; u++) {
  560. auto OpUPair = extractConstantFactor(Add->getOperand(u), SE);
  561. // TODO: Use something smarter than equality here, e.g., gcd.
  562. if (Factor == OpUPair.first)
  563. LeftOvers.push_back(OpUPair.second);
  564. else if (Factor == SE.getNegativeSCEV(OpUPair.first))
  565. LeftOvers.push_back(SE.getNegativeSCEV(OpUPair.second));
  566. else
  567. return std::make_pair(ConstPart, S);
  568. }
  569. auto *NewAdd = SE.getAddExpr(LeftOvers, Add->getNoWrapFlags());
  570. return std::make_pair(Factor, NewAdd);
  571. }
  572. auto *Mul = dyn_cast<SCEVMulExpr>(S);
  573. if (!Mul)
  574. return std::make_pair(ConstPart, S);
  575. SmallVector<const SCEV *, 4> LeftOvers;
  576. for (auto *Op : Mul->operands())
  577. if (isa<SCEVConstant>(Op))
  578. ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
  579. else
  580. LeftOvers.push_back(Op);
  581. return std::make_pair(ConstPart, SE.getMulExpr(LeftOvers));
  582. }
  583. const SCEV *polly::tryForwardThroughPHI(const SCEV *Expr, Region &R,
  584. ScalarEvolution &SE,
  585. ScopDetection *SD) {
  586. if (auto *Unknown = dyn_cast<SCEVUnknown>(Expr)) {
  587. Value *V = Unknown->getValue();
  588. auto *PHI = dyn_cast<PHINode>(V);
  589. if (!PHI)
  590. return Expr;
  591. Value *Final = nullptr;
  592. for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
  593. BasicBlock *Incoming = PHI->getIncomingBlock(i);
  594. if (SD->isErrorBlock(*Incoming, R) && R.contains(Incoming))
  595. continue;
  596. if (Final)
  597. return Expr;
  598. Final = PHI->getIncomingValue(i);
  599. }
  600. if (Final)
  601. return SE.getSCEV(Final);
  602. }
  603. return Expr;
  604. }
  605. Value *polly::getUniqueNonErrorValue(PHINode *PHI, Region *R,
  606. ScopDetection *SD) {
  607. Value *V = nullptr;
  608. for (unsigned i = 0; i < PHI->getNumIncomingValues(); i++) {
  609. BasicBlock *BB = PHI->getIncomingBlock(i);
  610. if (!SD->isErrorBlock(*BB, *R)) {
  611. if (V)
  612. return nullptr;
  613. V = PHI->getIncomingValue(i);
  614. }
  615. }
  616. return V;
  617. }