#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
//===- llvm/Analysis/VectorUtils.h - Vector utilities -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines some vectorizer utilities.
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_ANALYSIS_VECTORUTILS_H
#define LLVM_ANALYSIS_VECTORUTILS_H

#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Support/CheckedArithmetic.h"

namespace llvm {
class TargetLibraryInfo;

/// Describes the type of Parameters
enum class VFParamKind {
  Vector,            // No semantic information.
  OMP_Linear,        // declare simd linear(i)
  OMP_LinearRef,     // declare simd linear(ref(i))
  OMP_LinearVal,     // declare simd linear(val(i))
  OMP_LinearUVal,    // declare simd linear(uval(i))
  OMP_LinearPos,     // declare simd linear(i:c) uniform(c)
  OMP_LinearValPos,  // declare simd linear(val(i:c)) uniform(c)
  OMP_LinearRefPos,  // declare simd linear(ref(i:c)) uniform(c)
  OMP_LinearUValPos, // declare simd linear(uval(i:c)) uniform(c)
  OMP_Uniform,       // declare simd uniform(i)
  GlobalPredicate,   // Global logical predicate that acts on all lanes
                     // of the input and output mask concurrently. For
                     // example, it is implied by the `M` token in the
                     // Vector Function ABI mangled name.
  Unknown
};

/// Describes the type of Instruction Set Architecture
enum class VFISAKind {
  AdvancedSIMD, // AArch64 Advanced SIMD (NEON)
  SVE,          // AArch64 Scalable Vector Extension
  SSE,          // x86 SSE
  AVX,          // x86 AVX
  AVX2,         // x86 AVX2
  AVX512,       // x86 AVX512
  LLVM,         // LLVM internal ISA for functions that are not
                // attached to an existing ABI via name mangling.
  Unknown       // Unknown ISA
};

/// Encapsulates information needed to describe a parameter.
///
/// The description of the parameter is not linked directly to
/// OpenMP or any other vector function description. This structure
/// is extendible to handle other paradigms that describe vector
/// functions and their parameters.
struct VFParameter {
  unsigned ParamPos;         // Parameter Position in Scalar Function.
  VFParamKind ParamKind;     // Kind of Parameter.
  int LinearStepOrPos = 0;   // Step or Position of the Parameter.
  Align Alignment = Align(); // Optional alignment in bytes, defaulted to 1.

  // Comparison operator.
  bool operator==(const VFParameter &Other) const {
    return std::tie(ParamPos, ParamKind, LinearStepOrPos, Alignment) ==
           std::tie(Other.ParamPos, Other.ParamKind, Other.LinearStepOrPos,
                    Other.Alignment);
  }
};

/// Contains the information about the kind of vectorization
/// available.
///
/// This object is independent of the paradigm used to
/// represent vector functions. In particular, it is not attached to
/// any target-specific ABI.
struct VFShape {
  ElementCount VF;                        // Vectorization factor.
  SmallVector<VFParameter, 8> Parameters; // List of parameter information.

  // Comparison operator.
  bool operator==(const VFShape &Other) const {
    return std::tie(VF, Parameters) == std::tie(Other.VF, Other.Parameters);
  }

  /// Update the parameter in position P.ParamPos to P.
  void updateParam(VFParameter P) {
    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
    Parameters[P.ParamPos] = P;
    assert(hasValidParameterList() && "Invalid parameter list");
  }

  // Retrieve the VFShape that can be used to map a (scalar) function to itself,
  // with VF = 1.
  static VFShape getScalarShape(const CallInst &CI) {
    return VFShape::get(CI, ElementCount::getFixed(1),
                        /*HasGlobalPredicate*/ false);
  }

  // Retrieve the basic vectorization shape of the function, where all
  // parameters are mapped to VFParamKind::Vector with \p EC
  // lanes. Specifies whether the function has a Global Predicate
  // argument via \p HasGlobalPred.
  static VFShape get(const CallInst &CI, ElementCount EC, bool HasGlobalPred) {
    SmallVector<VFParameter, 8> Parameters;
    for (unsigned I = 0; I < CI.arg_size(); ++I)
      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
    if (HasGlobalPred)
      Parameters.push_back(
          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));

    return {EC, Parameters};
  }
  /// Validation check on the Parameters in the VFShape.
  bool hasValidParameterList() const;
};
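
// Illustrative note (not from the original header): for a scalar call such as
// `call float @foo(float %x, i32 %y)`, the helpers above would produce, e.g.:
//
//   VFShape S = VFShape::get(CI, ElementCount::getFixed(4),
//                            /*HasGlobalPred=*/true);
//   // S.VF == 4 (fixed); S.Parameters == {{0, Vector}, {1, Vector},
//   //                                     {2, GlobalPredicate}}.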

/// Holds the VFShape for a specific scalar to vector function mapping.
struct VFInfo {
  VFShape Shape;          /// Classification of the vector function.
  std::string ScalarName; /// Scalar Function Name.
  std::string VectorName; /// Vector Function Name associated to this VFInfo.
  VFISAKind ISA;          /// Instruction Set Architecture.
};

namespace VFABI {
/// LLVM Internal VFABI ISA token for vector functions.
static constexpr char const *_LLVM_ = "_LLVM_";
/// Prefix for internal name redirection for vector functions that
/// tells the compiler to scalarize the call using the scalar name
/// of the function. For example, a mangled name like
/// `_ZGV_LLVM_N2v_foo(_LLVM_Scalarize_foo)` would tell the
/// vectorizer to vectorize the scalar call `foo`, and to scalarize
/// it once vectorization is done.
static constexpr char const *_LLVM_Scalarize_ = "_LLVM_Scalarize_";

/// Function to construct a VFInfo out of a mangled name in the
/// following format:
///
///   <VFABI_name>{(<redirection>)}
///
/// where <VFABI_name> is the name of the vector function, mangled according
/// to the rules described in the Vector Function ABI of the target vector
/// extension (or <isa> from now on). The <VFABI_name> is in the following
/// format:
///
///   _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
///
/// This method supports demangling rules for the following <isa>:
///
/// * AArch64: https://developer.arm.com/docs/101129/latest
///
/// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
///   https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
///
/// \param MangledName -> input string in the format
/// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
/// \param M -> Module used to retrieve information about the vector
/// function that is not possible to retrieve from the mangled
/// name. At the moment, this parameter is needed only to retrieve the
/// Vectorization Factor of scalable vector functions from their
/// respective IR declarations.
std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
                                          const Module &M);
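
// Illustrative example (assuming the module declares the redirection target):
// demangling "_ZGV_LLVM_N2v_foo(vector_foo)" yields a VFInfo with
// ISA == VFISAKind::LLVM, a fixed VF of 2, one VFParamKind::Vector parameter,
// ScalarName == "foo" and VectorName == "vector_foo".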

/// This routine mangles the given VectorName according to the LangRef
/// specification for vector-function-abi-variant attribute and is specific to
/// the TLI mappings. It is the responsibility of the caller to make sure that
/// this is only used if all parameters in the vector function are vector type.
/// This returned string holds scalar-to-vector mapping:
///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
///
/// where:
///
/// <isa> = "_LLVM_"
/// <mask> = "N". Note: TLI does not support masked interfaces.
/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
///          field of the `VecDesc` struct. If the number of lanes is scalable
///          then 'x' is printed instead.
/// <vparams> = "v", one for each of the numArgs scalar arguments.
/// <scalarname> = the name of the scalar function.
/// <vectorname> = the name of the vector function.
std::string mangleTLIVectorName(StringRef VectorName, StringRef ScalarName,
                                unsigned numArgs, ElementCount VF);
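
// Illustrative example of the scheme above (the names are hypothetical):
//
//   mangleTLIVectorName("vsinf", "sinf", /*numArgs=*/1,
//                       ElementCount::getFixed(4));
//   // Returns "_ZGV_LLVM_N4v_sinf(vsinf)".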

/// Retrieve the `VFParamKind` from a string token.
VFParamKind getVFParamKindFromString(const StringRef Token);

// Name of the attribute where the variant mappings are stored.
static constexpr char const *MappingsAttrName = "vector-function-abi-variant";

/// Populates a set of strings representing the Vector Function ABI variants
/// associated to the CallInst CI. If the CI does not contain the
/// vector-function-abi-variant attribute, we return without populating
/// VariantMappings, i.e. callers of getVectorVariantNames need not check for
/// the presence of the attribute (see InjectTLIMappings).
void getVectorVariantNames(const CallInst &CI,
                           SmallVectorImpl<std::string> &VariantMappings);
} // end namespace VFABI

/// The Vector Function Database.
///
/// Helper class used to find the vector functions associated to a
/// scalar CallInst.
class VFDatabase {
  /// The Module of the CallInst CI.
  const Module *M;
  /// The CallInst instance being queried for scalar to vector mappings.
  const CallInst &CI;
  /// List of vector function descriptors associated to the call
  /// instruction.
  const SmallVector<VFInfo, 8> ScalarToVectorMappings;

  /// Retrieve the scalar-to-vector mappings associated to the rules of
  /// a Vector Function ABI.
  static void getVFABIMappings(const CallInst &CI,
                               SmallVectorImpl<VFInfo> &Mappings) {
    if (!CI.getCalledFunction())
      return;

    const StringRef ScalarName = CI.getCalledFunction()->getName();

    SmallVector<std::string, 8> ListOfStrings;
    // The check for the vector-function-abi-variant attribute is done when
    // retrieving the vector variant names here.
    VFABI::getVectorVariantNames(CI, ListOfStrings);
    if (ListOfStrings.empty())
      return;
    for (const auto &MangledName : ListOfStrings) {
      const std::optional<VFInfo> Shape =
          VFABI::tryDemangleForVFABI(MangledName, *(CI.getModule()));
      // A match is found via scalar and vector names, and also by
      // ensuring that the variant described in the attribute has a
      // corresponding definition or declaration of the vector
      // function in the Module M.
      if (Shape && (Shape->ScalarName == ScalarName)) {
        assert(CI.getModule()->getFunction(Shape->VectorName) &&
               "Vector function is missing.");
        Mappings.push_back(*Shape);
      }
    }
  }

public:
  /// Retrieve all the VFInfo instances associated to the CallInst CI.
  static SmallVector<VFInfo, 8> getMappings(const CallInst &CI) {
    SmallVector<VFInfo, 8> Ret;

    // Get mappings from the Vector Function ABI variants.
    getVFABIMappings(CI, Ret);

    // Other non-VFABI variants should be retrieved here.

    return Ret;
  }

  /// Constructor, requires a CallInst instance.
  VFDatabase(CallInst &CI)
      : M(CI.getModule()), CI(CI),
        ScalarToVectorMappings(VFDatabase::getMappings(CI)) {}
  /// \defgroup VFDatabase query interface.
  ///
  /// @{
  /// Retrieve the Function with VFShape \p Shape.
  Function *getVectorizedFunction(const VFShape &Shape) const {
    if (Shape == VFShape::getScalarShape(CI))
      return CI.getCalledFunction();

    for (const auto &Info : ScalarToVectorMappings)
      if (Info.Shape == Shape)
        return M->getFunction(Info.VectorName);

    return nullptr;
  }
  /// @}
};
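
// Minimal usage sketch (illustrative only; `CI` is a scalar CallInst): query
// the database for a 4-lane, unpredicated variant of the scalar call.
//
//   VFDatabase DB(CI);
//   VFShape Shape = VFShape::get(CI, ElementCount::getFixed(4),
//                                /*HasGlobalPred=*/false);
//   if (Function *VecF = DB.getVectorizedFunction(Shape))
//     ; // A vector variant with this shape exists in the module.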

template <typename T> class ArrayRef;
class DemandedBits;
class GetElementPtrInst;
template <typename InstTy> class InterleaveGroup;
class IRBuilderBase;
class Loop;
class ScalarEvolution;
class TargetTransformInfo;
class Type;
class Value;

namespace Intrinsic {
typedef unsigned ID;
}

/// A helper function for converting Scalar types to vector types. If
/// the incoming type is void, we return void. If the EC represents a
/// scalar, we return the scalar type.
inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
  if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
    return Scalar;
  return VectorType::get(Scalar, EC);
}

inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
  return ToVectorTy(Scalar, ElementCount::getFixed(VF));
}
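
// For example (illustrative): ToVectorTy(Int32Ty, 4) yields <4 x i32>,
// ToVectorTy(Int32Ty, ElementCount::getScalable(2)) yields <vscale x 2 x i32>,
// and a void or scalar (EC == 1) request returns the input type unchanged.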

/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
/// isVectorIntrinsicWithScalarOpAtArg) for the vector form of the intrinsic.
bool isTriviallyVectorizable(Intrinsic::ID ID);

/// Identifies if the vector form of the intrinsic has a scalar operand.
bool isVectorIntrinsicWithScalarOpAtArg(Intrinsic::ID ID,
                                        unsigned ScalarOpdIdx);

/// Identifies if the vector form of the intrinsic has an operand that has
/// an overloaded type.
bool isVectorIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID, unsigned OpdIdx);

/// Returns intrinsic ID for call.
/// For the input call instruction it finds the mapping intrinsic and returns
/// its intrinsic ID; if it cannot find one, it returns not_intrinsic.
Intrinsic::ID getVectorIntrinsicIDForCall(const CallInst *CI,
                                          const TargetLibraryInfo *TLI);

/// Find the operand of the GEP that should be checked for consecutive
/// stores. This ignores trailing indices that have no effect on the final
/// pointer.
unsigned getGEPInductionOperand(const GetElementPtrInst *Gep);

/// If the argument is a GEP, then returns the operand identified by
/// getGEPInductionOperand. However, if there is some other non-loop-invariant
/// operand, it returns that instead.
Value *stripGetElementPtr(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// If a value has only one user that is a CastInst, return it.
Value *getUniqueCastUse(Value *Ptr, Loop *Lp, Type *Ty);

/// Get the stride of a pointer access in a loop. Looks for symbolic
/// strides "a[i*stride]". Returns the symbolic stride, or null otherwise.
Value *getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp);

/// Given a vector and an element number, see if the scalar value is
/// already around as a register, for example if it were inserted then extracted
/// from the vector.
Value *findScalarElement(Value *V, unsigned EltNo);

/// If all non-negative \p Mask elements are the same value, return that value.
/// If all elements are negative (undefined) or \p Mask contains different
/// non-negative values, return -1.
int getSplatIndex(ArrayRef<int> Mask);

/// Get splat value if the input is a splat vector or return nullptr.
/// The value may be extracted from a splat constants vector or from
/// a sequence of instructions that broadcast a single value into a vector.
Value *getSplatValue(const Value *V);

/// Return true if each element of the vector value \p V is poisoned or equal to
/// every other non-poisoned element. If an index element is specified, either
/// every element of the vector is poisoned or the element at that index is not
/// poisoned and equal to every other non-poisoned element.
/// This may be more powerful than the related getSplatValue() because it is
/// not limited by finding a scalar source value to a splatted vector.
bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
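
// For example (illustrative): a shufflevector whose mask is all zeroes is a
// splat of lane 0 of its first operand, and isSplatValue() can report this
// even when getSplatValue() cannot recover a single scalar source value.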

/// Transform a shuffle mask's output demanded element mask into demanded
/// element masks for the 2 operands, returns false if the mask isn't valid.
/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth].
/// \p AllowUndefElts permits "-1" indices to be treated as undef.
bool getShuffleDemandedElts(int SrcWidth, ArrayRef<int> Mask,
                            const APInt &DemandedElts, APInt &DemandedLHS,
                            APInt &DemandedRHS, bool AllowUndefElts = false);
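
// For example (illustrative): with SrcWidth = 4, Mask = <0, 4, 1, 5> and
// DemandedElts = 0b0011 (the two low output lanes), the result is
// DemandedLHS = 0b0001 and DemandedRHS = 0b0001, since output lane 0 reads
// LHS element 0 and output lane 1 reads RHS element 0 (input element 4).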

/// Replace each shuffle mask index with the scaled sequential indices for an
/// equivalent mask of narrowed elements. Mask elements that are less than 0
/// (sentinel values) are repeated in the output mask.
///
/// Example with Scale = 4:
///   <4 x i32> <3, 2, 0, -1> -->
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1>
///
/// This is the reverse process of widening shuffle mask elements, but it always
/// succeeds because the indexes can always be multiplied (scaled up) to map to
/// narrower vector elements.
void narrowShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                           SmallVectorImpl<int> &ScaledMask);

/// Try to transform a shuffle mask by replacing elements with the scaled index
/// for an equivalent mask of widened elements. If all mask elements that would
/// map to a wider element of the new mask are the same negative number
/// (sentinel value), that element of the new mask is the same value. If any
/// element in a given slice is negative and some other element in that slice is
/// not the same value, return false (partial matches with sentinel values are
/// not allowed).
///
/// Example with Scale = 4:
///   <16 x i8> <12, 13, 14, 15, 8, 9, 10, 11, 0, 1, 2, 3, -1, -1, -1, -1> -->
///   <4 x i32> <3, 2, 0, -1>
///
/// This is the reverse process of narrowing shuffle mask elements if it
/// succeeds. This transform is not always possible because indexes may not
/// divide evenly (scale down) to map to wider vector elements.
bool widenShuffleMaskElts(int Scale, ArrayRef<int> Mask,
                          SmallVectorImpl<int> &ScaledMask);

/// Repetitively apply `widenShuffleMaskElts()` for as long as it succeeds,
/// to get the shuffle mask with widest possible elements.
void getShuffleMaskWithWidestElts(ArrayRef<int> Mask,
                                  SmallVectorImpl<int> &ScaledMask);

/// Splits and processes shuffle mask depending on the number of input and
/// output registers. The function does 2 main things: 1) splits the
/// source/destination vectors into real registers; 2) does the mask analysis to
/// identify which real registers are permuted. Then the function processes the
/// resulting registers mask using the provided action items. If no input
/// register is defined, \p NoInputAction action is used. If only 1 input
/// register is used, \p SingleInputAction is used, otherwise \p ManyInputsAction
/// is used to process > 2 input registers and masks.
/// \param Mask Original shuffle mask.
/// \param NumOfSrcRegs Number of source registers.
/// \param NumOfDestRegs Number of destination registers.
/// \param NumOfUsedRegs Number of actually used destination registers.
void processShuffleMasks(
    ArrayRef<int> Mask, unsigned NumOfSrcRegs, unsigned NumOfDestRegs,
    unsigned NumOfUsedRegs, function_ref<void()> NoInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> SingleInputAction,
    function_ref<void(ArrayRef<int>, unsigned, unsigned)> ManyInputsAction);

/// Compute a map of integer instructions to their minimum legal type
/// size.
///
/// C semantics force sub-int-sized values (e.g. i8, i16) to be promoted to int
/// type (e.g. i32) whenever arithmetic is performed on them.
///
/// For targets with native i8 or i16 operations, usually InstCombine can shrink
/// the arithmetic type down again. However InstCombine refuses to create
/// illegal types, so for targets without i8 or i16 registers, the lengthening
/// and shrinking remains.
///
/// Most SIMD ISAs (e.g. NEON) however support vectors of i8 or i16 even when
/// their scalar equivalents do not, so during vectorization it is important to
/// remove these lengthens and truncates when deciding the profitability of
/// vectorization.
///
/// This function analyzes the given range of instructions and determines the
/// minimum type size each can be converted to. It attempts to remove or
/// minimize type size changes across each def-use chain, so for example in the
/// following code:
///
///   %1 = load i8, i8*
///   %2 = add i8 %1, 2
///   %3 = load i16, i16*
///   %4 = zext i8 %2 to i32
///   %5 = zext i16 %3 to i32
///   %6 = add i32 %4, %5
///   %7 = trunc i32 %6 to i16
///
/// Instruction %6 must be done at least in i16, so computeMinimumValueSizes
/// will return: {%1: 16, %2: 16, %3: 16, %4: 16, %5: 16, %6: 16, %7: 16}.
///
/// If the optional TargetTransformInfo is provided, this function tries harder
/// to do less work by only looking at illegal types.
MapVector<Instruction*, uint64_t>
computeMinimumValueSizes(ArrayRef<BasicBlock*> Blocks,
                         DemandedBits &DB,
                         const TargetTransformInfo *TTI=nullptr);

/// Compute the union of two access-group lists.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2);

/// Compute the access-group list of access groups that @p Inst1 and @p Inst2
/// are both in. If either instruction does not access memory at all, it is
/// considered to be in every list.
///
/// If the list contains just one access group, it is returned directly. If the
/// list is empty, returns nullptr.
MDNode *intersectAccessGroups(const Instruction *Inst1,
                              const Instruction *Inst2);

/// Specifically, let Kinds = [MD_tbaa, MD_alias_scope, MD_noalias, MD_fpmath,
/// MD_nontemporal, MD_access_group].
/// For K in Kinds, we get the MDNode for K from each of the
/// elements of VL, compute their "intersection" (i.e., the most generic
/// metadata value that covers all of the individual values), and set I's
/// metadata for K equal to the intersection value.
///
/// This function always sets a (possibly null) value for each K in Kinds.
Instruction *propagateMetadata(Instruction *I, ArrayRef<Value *> VL);

/// Create a mask that filters the members of an interleave group where there
/// are gaps.
///
/// For example, the mask for \p Group with interleave-factor 3
/// and \p VF 4, that has only its first member present is:
///
///   <1,0,0,1,0,0,1,0,0,1,0,0>
///
/// Note: The result is a mask of 0's and 1's, as opposed to the other
/// create[*]Mask() utilities which create a shuffle mask (mask that
/// consists of indices).
Constant *createBitMaskForGaps(IRBuilderBase &Builder, unsigned VF,
                               const InterleaveGroup<Instruction> &Group);

/// Create a mask with replicated elements.
///
/// This function creates a shuffle mask for replicating each of the \p VF
/// elements in a vector \p ReplicationFactor times. It can be used to
/// transform a mask of \p VF elements into a mask of
/// \p VF * \p ReplicationFactor elements used by a predicated
/// interleaved-group of loads/stores whose Interleaved-factor ==
/// \p ReplicationFactor.
///
/// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
///
///   <0,0,0,1,1,1,2,2,2,3,3,3>
llvm::SmallVector<int, 16> createReplicatedMask(unsigned ReplicationFactor,
                                                unsigned VF);

/// Create an interleave shuffle mask.
///
/// This function creates a shuffle mask for interleaving \p NumVecs vectors of
/// vectorization factor \p VF into a single wide vector. The mask is of the
/// form:
///
///   <0, VF, VF * 2, ..., VF * (NumVecs - 1), 1, VF + 1, VF * 2 + 1, ...>
///
/// For example, the mask for VF = 4 and NumVecs = 2 is:
///
///   <0, 4, 1, 5, 2, 6, 3, 7>.
llvm::SmallVector<int, 16> createInterleaveMask(unsigned VF, unsigned NumVecs);

/// Create a stride shuffle mask.
///
/// This function creates a shuffle mask whose elements begin at \p Start and
/// are incremented by \p Stride. The mask can be used to deinterleave an
/// interleaved vector into separate vectors of vectorization factor \p VF. The
/// mask is of the form:
///
///   <Start, Start + Stride, ..., Start + Stride * (VF - 1)>
///
/// For example, the mask for Start = 0, Stride = 2, and VF = 4 is:
///
///   <0, 2, 4, 6>
llvm::SmallVector<int, 16> createStrideMask(unsigned Start, unsigned Stride,
                                            unsigned VF);

/// Create a sequential shuffle mask.
///
/// This function creates a shuffle mask whose elements are sequential and begin
/// at \p Start. The mask contains \p NumInts integers and is padded with \p
/// NumUndefs undef values. The mask is of the form:
///
///   <Start, Start + 1, ... Start + NumInts - 1, undef_1, ... undef_NumUndefs>
///
/// For example, the mask for Start = 0, NumInts = 4, and NumUndefs = 4 is:
///
///   <0, 1, 2, 3, undef, undef, undef, undef>
llvm::SmallVector<int, 16>
createSequentialMask(unsigned Start, unsigned NumInts, unsigned NumUndefs);

/// Given a shuffle mask for a binary shuffle, create the equivalent shuffle
/// mask assuming both operands are identical. This assumes that the unary
/// shuffle will use elements from operand 0 (operand 1 will be unused).
llvm::SmallVector<int, 16> createUnaryMask(ArrayRef<int> Mask,
                                           unsigned NumElts);
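
// For example (illustrative): with NumElts = 4, the binary mask <0, 4, 1, 5>
// becomes the unary mask <0, 0, 1, 1>, because element i of operand 1 maps to
// element i of operand 0 when both operands are identical.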

/// Concatenate a list of vectors.
///
/// This function generates code that concatenates the vectors in \p Vecs into a
/// single large vector. The number of vectors should be greater than one, and
/// their element types should be the same. The number of elements in the
/// vectors should also be the same; however, if the last vector has fewer
/// elements, it will be padded with undefs.
Value *concatenateVectors(IRBuilderBase &Builder, ArrayRef<Value *> Vecs);
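
// For example (illustrative): concatenating two <2 x i32> values %a and %b
// produces a single <4 x i32> value holding the lanes of %a followed by the
// lanes of %b.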

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be false or undef. That is, return true if all
/// lanes can be assumed inactive.
bool maskIsAllZeroOrUndef(Value *Mask);

/// Given a mask vector of i1, return true if all of the elements of this
/// predicate mask are known to be true or undef. That is, return true if all
/// lanes can be assumed active.
bool maskIsAllOneOrUndef(Value *Mask);

/// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
/// for each lane which may be active.
APInt possiblyDemandedEltsInMask(Value *Mask);
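
// For example (illustrative): for the constant mask
// <4 x i1> <i1 1, i1 0, i1 undef, i1 1>, possiblyDemandedEltsInMask() returns
// an APInt with bits 0, 2 and 3 set; lane 1 is known inactive, while the
// undef lane is conservatively treated as possibly active.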

/// The group of interleaved loads/stores sharing the same stride and
/// close to each other.
///
/// Each member in this group has an index starting from 0, and the largest
/// index should be less than the interleave factor, which is equal to the
/// absolute value of the access's stride.
///
/// E.g. An interleaved load group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          a = A[i];                           // Member of index 0
///          b = A[i+1];                         // Member of index 1
///          d = A[i+3];                         // Member of index 3
///          ...
///        }
///
///      An interleaved store group of factor 4:
///        for (unsigned i = 0; i < 1024; i+=4) {
///          ...
///          A[i]   = a;                         // Member of index 0
///          A[i+1] = b;                         // Member of index 1
///          A[i+2] = c;                         // Member of index 2
///          A[i+3] = d;                         // Member of index 3
///        }
///
/// Note: the interleaved load group could have gaps (missing members), but
/// the interleaved store group doesn't allow gaps.
template <typename InstTy> class InterleaveGroup {
public:
  InterleaveGroup(uint32_t Factor, bool Reverse, Align Alignment)
      : Factor(Factor), Reverse(Reverse), Alignment(Alignment),
        InsertPos(nullptr) {}

  InterleaveGroup(InstTy *Instr, int32_t Stride, Align Alignment)
      : Alignment(Alignment), InsertPos(Instr) {
    Factor = std::abs(Stride);
    assert(Factor > 1 && "Invalid interleave factor");

    Reverse = Stride < 0;
    Members[0] = Instr;
  }

  bool isReverse() const { return Reverse; }
  uint32_t getFactor() const { return Factor; }
  Align getAlign() const { return Alignment; }
  uint32_t getNumMembers() const { return Members.size(); }

  /// Try to insert a new member \p Instr with index \p Index and
  /// alignment \p NewAlign. The index is related to the leader and it could be
  /// negative if it is the new leader.
  ///
  /// \returns false if the instruction doesn't belong to the group.
  bool insertMember(InstTy *Instr, int32_t Index, Align NewAlign) {
    // Make sure the key fits in an int32_t.
    std::optional<int32_t> MaybeKey = checkedAdd(Index, SmallestKey);
    if (!MaybeKey)
      return false;
    int32_t Key = *MaybeKey;

    // Skip if the key is used for either the tombstone or empty special values.
    if (DenseMapInfo<int32_t>::getTombstoneKey() == Key ||
        DenseMapInfo<int32_t>::getEmptyKey() == Key)
      return false;

    // Skip if there is already a member with the same index.
    if (Members.find(Key) != Members.end())
      return false;

    if (Key > LargestKey) {
      // The largest index is always less than the interleave factor.
      if (Index >= static_cast<int32_t>(Factor))
        return false;

      LargestKey = Key;
    } else if (Key < SmallestKey) {
      // Make sure the largest index fits in an int32_t.
      std::optional<int32_t> MaybeLargestIndex = checkedSub(LargestKey, Key);
      if (!MaybeLargestIndex)
        return false;

      // The largest index is always less than the interleave factor.
      if (*MaybeLargestIndex >= static_cast<int64_t>(Factor))
        return false;

      SmallestKey = Key;
    }

    // It's always safe to select the minimum alignment.
    Alignment = std::min(Alignment, NewAlign);
    Members[Key] = Instr;
    return true;
  }

  /// Get the member with the given index \p Index.
  ///
  /// \returns nullptr if it contains no such member.
  InstTy *getMember(uint32_t Index) const {
    int32_t Key = SmallestKey + Index;
    return Members.lookup(Key);
  }

  /// Get the index for the given member. Unlike the key in the member
  /// map, the index starts from 0.
  uint32_t getIndex(const InstTy *Instr) const {
    for (auto I : Members) {
      if (I.second == Instr)
        return I.first - SmallestKey;
    }

    llvm_unreachable("InterleaveGroup contains no such member");
  }

  InstTy *getInsertPos() const { return InsertPos; }
  void setInsertPos(InstTy *Inst) { InsertPos = Inst; }

  /// Add metadata (e.g. alias info) from the instructions in this group to \p
  /// NewInst.
  ///
  /// FIXME: this function currently does not add noalias metadata a'la
  /// addNewMetadata. To do that we need to compute the intersection of the
  /// noalias info from all members.
  void addMetadata(InstTy *NewInst) const;

  /// Returns true if this Group requires a scalar iteration to handle gaps.
  bool requiresScalarEpilogue() const {
    // If the last member of the Group exists, then a scalar epilog is not
    // needed for this group.
    if (getMember(getFactor() - 1))
      return false;

    // We have a group with gaps. It therefore can't be a reversed access,
    // because such groups get invalidated (TODO).
    assert(!isReverse() && "Group should have been invalidated");

    // This is a group of loads, with gaps, and without a last member.
    return true;
  }

private:
  uint32_t Factor; // Interleave Factor.
  bool Reverse;
  Align Alignment;
  DenseMap<int32_t, InstTy *> Members;
  int32_t SmallestKey = 0;
  int32_t LargestKey = 0;

  // To avoid breaking dependences, vectorized instructions of an interleave
  // group should be inserted at either the first load or the last store in
  // program order.
  //
  // E.g. %even = load i32             // Insert Position
  //      %add = add i32 %even         // Use of %even
  //      %odd = load i32
  //
  //      store i32 %even
  //      %odd = add i32               // Def of %odd
  //      store i32 %odd               // Insert Position
  InstTy *InsertPos;
};
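
// Minimal usage sketch (illustrative only; `LoadA` and `LoadB` are
// hypothetical adjacent loads accessed with stride 2):
//
//   InterleaveGroup<Instruction> Group(LoadA, /*Stride=*/2, Align(4));
//   Group.insertMember(LoadB, /*Index=*/1, Align(4)); // Member at index 1.
//   Instruction *Leader = Group.getMember(0);         // Returns LoadA.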

/// Drive the analysis of interleaved memory accesses in the loop.
///
/// Use this class to analyze interleaved accesses only when we can vectorize
/// a loop. Otherwise it's meaningless to do analysis as the vectorization
/// on interleaved accesses is unsafe.
///
/// The analysis collects interleave groups and records the relationships
/// between the member and the group in a map.
class InterleavedAccessInfo {
public:
  InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L,
                        DominatorTree *DT, LoopInfo *LI,
                        const LoopAccessInfo *LAI)
      : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(LAI) {}

  ~InterleavedAccessInfo() { invalidateGroups(); }

  /// Analyze the interleaved accesses and collect them in interleave
  /// groups. Substitute symbolic strides using \p Strides.
  /// Consider also predicated loads/stores in the analysis if
  /// \p EnableMaskedInterleavedGroup is true.
  void analyzeInterleaving(bool EnableMaskedInterleavedGroup);

  /// Invalidate groups, e.g., in case all blocks in loop will be predicated
  /// contrary to original assumption. Although we currently prevent group
  /// formation for predicated accesses, we may be able to relax this limitation
  /// in the future once we handle more complicated blocks. Returns true if any
  /// groups were invalidated.
  bool invalidateGroups() {
    if (InterleaveGroups.empty()) {
      assert(
          !RequiresScalarEpilogue &&
          "RequiresScalarEpilog should not be set without interleave groups");
      return false;
    }

    InterleaveGroupMap.clear();
    for (auto *Ptr : InterleaveGroups)
      delete Ptr;
    InterleaveGroups.clear();
    RequiresScalarEpilogue = false;
    return true;
  }

  /// Check if \p Instr belongs to any interleave group.
  bool isInterleaved(Instruction *Instr) const {
    return InterleaveGroupMap.find(Instr) != InterleaveGroupMap.end();
  }

  /// Get the interleave group that \p Instr belongs to.
  ///
  /// \returns nullptr if it doesn't have such a group.
  InterleaveGroup<Instruction> *
  getInterleaveGroup(const Instruction *Instr) const {
    return InterleaveGroupMap.lookup(Instr);
  }

  iterator_range<SmallPtrSetIterator<llvm::InterleaveGroup<Instruction> *>>
  getInterleaveGroups() {
    return make_range(InterleaveGroups.begin(), InterleaveGroups.end());
  }

  /// Returns true if an interleaved group that may access memory
  /// out-of-bounds requires a scalar epilogue iteration for correctness.
  bool requiresScalarEpilogue() const { return RequiresScalarEpilogue; }

  /// Invalidate groups that require a scalar epilogue (due to gaps). This can
  /// happen when optimizing for size forbids a scalar epilogue, and the gap
  /// cannot be filtered by masking the load/store.
  void invalidateGroupsRequiringScalarEpilogue();

  /// Returns true if we have any interleave groups.
  bool hasGroups() const { return !InterleaveGroups.empty(); }

private:
  /// A wrapper around ScalarEvolution, used to add runtime SCEV checks.
  /// Simplifies SCEV expressions in the context of existing SCEV assumptions.
  /// The interleaved access analysis can also add new predicates (for example
  /// by versioning strides of pointers).
  PredicatedScalarEvolution &PSE;

  Loop *TheLoop;
  DominatorTree *DT;
  LoopInfo *LI;
  const LoopAccessInfo *LAI;

  /// True if the loop may contain non-reversed interleaved groups with
  /// out-of-bounds accesses. We ensure we don't speculatively access memory
  /// out-of-bounds by executing at least one scalar epilogue iteration.
  bool RequiresScalarEpilogue = false;

  /// Holds the relationships between the members and the interleave group.
  DenseMap<Instruction *, InterleaveGroup<Instruction> *> InterleaveGroupMap;
  SmallPtrSet<InterleaveGroup<Instruction> *, 4> InterleaveGroups;

  /// Holds dependences among the memory accesses in the loop. It maps a source
  /// access to a set of dependent sink accesses.
  DenseMap<Instruction *, SmallPtrSet<Instruction *, 2>> Dependences;

  /// The descriptor for a strided memory access.
  struct StrideDescriptor {
    StrideDescriptor() = default;
    StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size,
                     Align Alignment)
        : Stride(Stride), Scev(Scev), Size(Size), Alignment(Alignment) {}

    // The access's stride. It is negative for a reverse access.
    int64_t Stride = 0;

    // The scalar expression of this access.
    const SCEV *Scev = nullptr;

    // The size of the memory object.
    uint64_t Size = 0;

    // The alignment of this access.
    Align Alignment;
  };

  /// A type for holding instructions and their stride descriptors.
  using StrideEntry = std::pair<Instruction *, StrideDescriptor>;

  /// Create a new interleave group with the given instruction \p Instr,
  /// stride \p Stride and alignment \p Align.
  ///
  /// \returns the newly created interleave group.
  InterleaveGroup<Instruction> *
  createInterleaveGroup(Instruction *Instr, int Stride, Align Alignment) {
    assert(!InterleaveGroupMap.count(Instr) &&
           "Already in an interleaved access group");
    InterleaveGroupMap[Instr] =
        new InterleaveGroup<Instruction>(Instr, Stride, Alignment);
    InterleaveGroups.insert(InterleaveGroupMap[Instr]);
    return InterleaveGroupMap[Instr];
  }

  /// Release the group and remove all the relationships.
  void releaseGroup(InterleaveGroup<Instruction> *Group) {
    for (unsigned i = 0; i < Group->getFactor(); i++)
      if (Instruction *Member = Group->getMember(i))
        InterleaveGroupMap.erase(Member);

    InterleaveGroups.erase(Group);
    delete Group;
  }

  /// Collect all the accesses with a constant stride in program order.
  void collectConstStrideAccesses(
      MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo,
      const ValueToValueMap &Strides);

  /// Returns true if \p Stride is allowed in an interleaved group.
  static bool isStrided(int Stride);

  /// Returns true if \p BB is a predicated block.
  bool isPredicated(BasicBlock *BB) const {
    return LoopAccessInfo::blockNeedsPredication(BB, TheLoop, DT);
  }

  /// Returns true if LoopAccessInfo can be used for dependence queries.
  bool areDependencesValid() const {
    return LAI && LAI->getDepChecker().getDependences();
  }

  /// Returns true if memory accesses \p A and \p B can be reordered, if
  /// necessary, when constructing interleaved groups.
  ///
  /// \p A must precede \p B in program order. We return false if reordering is
  /// not necessary or is prevented because \p A and \p B may be dependent.
  bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A,
                                                 StrideEntry *B) const {
    // Code motion for interleaved accesses can potentially hoist strided loads
    // and sink strided stores. The code below checks the legality of the
    // following two conditions:
    //
    // 1. Potentially moving a strided load (B) before any store (A) that
    //    precedes B, or
    //
    // 2. Potentially moving a strided store (A) after any load or store (B)
    //    that A precedes.
    //
    // It's legal to reorder A and B if we know there isn't a dependence from A
    // to B. Note that this determination is conservative since some
    // dependences could potentially be reordered safely.

    // A is potentially the source of a dependence.
    auto *Src = A->first;
    auto SrcDes = A->second;

    // B is potentially the sink of a dependence.
    auto *Sink = B->first;
    auto SinkDes = B->second;

    // Code motion for interleaved accesses can't violate WAR dependences.
    // Thus, reordering is legal if the source isn't a write.
    if (!Src->mayWriteToMemory())
      return true;

    // At least one of the accesses must be strided.
    if (!isStrided(SrcDes.Stride) && !isStrided(SinkDes.Stride))
      return true;

    // If dependence information is not available from LoopAccessInfo,
    // conservatively assume the instructions can't be reordered.
    if (!areDependencesValid())
      return false;

    // If we know there is a dependence from source to sink, assume the
    // instructions can't be reordered. Otherwise, reordering is legal.
    return Dependences.find(Src) == Dependences.end() ||
           !Dependences.lookup(Src).count(Sink);
  }

  /// Collect the dependences from LoopAccessInfo.
  ///
  /// We process the dependences once during the interleaved access analysis to
  /// enable constant-time dependence queries.
  void collectDependences() {
    if (!areDependencesValid())
      return;

    auto *Deps = LAI->getDepChecker().getDependences();
    for (auto Dep : *Deps)
      Dependences[Dep.getSource(*LAI)].insert(Dep.getDestination(*LAI));
  }
};
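
// Minimal usage sketch (illustrative only; the analyses passed in are assumed
// to be available from the surrounding vectorizer):
//
//   InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
//   IAI.analyzeInterleaving(/*EnableMaskedInterleavedGroup=*/false);
//   if (auto *Group = IAI.getInterleaveGroup(SomeLoadOrStore))
//     (void)Group->getFactor(); // e.g. 4 for the examples documented above.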

} // llvm namespace

#endif
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif