//===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines a TargetTransformInfo::Concept conforming object specific
/// to the RISC-V target machine. It uses the target's detailed information to
/// provide more precise answers to certain TTI queries, while letting the
/// target independent and default TTI implementations handle the rest.
///
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H

#include "RISCVSubtarget.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/IR/Function.h"
#include <optional>

namespace llvm {

class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
  using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
  using TTI = TargetTransformInfo;

  friend BaseT;

  const RISCVSubtarget *ST;
  const RISCVTargetLowering *TLI;

  const RISCVSubtarget *getST() const { return ST; }
  const RISCVTargetLowering *getTLI() const { return TLI; }

  /// This function returns an estimate for VL to be used in VL-based terms
  /// of the cost model. For fixed-length vectors, this is simply the
  /// vector length. For scalable vectors, we return results consistent
  /// with getVScaleForTuning, under the assumption that clients are also
  /// using that when comparing costs between scalar and vector
  /// representations. This does unfortunately mean that we can both
  /// undershoot and overshoot the true cost significantly if
  /// getVScaleForTuning is wildly off for the actual target hardware.
  unsigned getEstimatedVLFor(VectorType *Ty);
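  // Illustration (the vscale-for-tuning value here is an assumption for
  // exposition, not a target guarantee): a fixed <8 x i32> yields an
  // estimated VL of 8, while a scalable <vscale x 4 x i32> on a target whose
  // getVScaleForTuning() returns 2 yields an estimate of 4 * 2 = 8 lanes.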

  /// Return the cost of LMUL. The larger the LMUL, the higher the cost.
  InstructionCost getLMULCost(MVT VT);
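  // (LMUL is RVV's register-grouping factor: at LMUL=8 an operation occupies
  // a group of eight vector registers and generally does proportionally more
  // work than the same operation at LMUL=1.)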

public:
  explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
      : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
        TLI(ST->getTargetLowering()) {}
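
  // This implementation is typically instantiated by the target machine's
  // TTI hook, along the lines of the following sketch (the usual LLVM target
  // pattern, not quoted verbatim from RISCVTargetMachine.cpp):
  //
  //   TargetTransformInfo
  //   RISCVTargetMachine::getTargetTransformInfo(const Function &F) const {
  //     return TargetTransformInfo(RISCVTTIImpl(this, F));
  //   }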

  /// Return the cost of materializing an immediate for a value operand of
  /// a store instruction.
  InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
                                  TTI::TargetCostKind CostKind);

  InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
                                TTI::TargetCostKind CostKind);
  InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind,
                                    Instruction *Inst = nullptr);
  InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                      const APInt &Imm, Type *Ty,
                                      TTI::TargetCostKind CostKind);

  TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);

  bool shouldExpandReduction(const IntrinsicInst *II) const;
  bool supportsScalableVectors() const { return ST->hasVInstructions(); }
  bool enableScalableVectorization() const { return ST->hasVInstructions(); }
  PredicationStyle emitGetActiveLaneMask() const {
    return ST->hasVInstructions() ? PredicationStyle::Data
                                  : PredicationStyle::None;
  }
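  // With the V extension this reports data-style predication, which lets the
  // loop vectorizer tail-fold using an active lane mask. Illustrative IR
  // (value names are made up for the example):
  //
  //   %mask = call <vscale x 4 x i1>
  //       @llvm.get.active.lane.mask.nxv4i1.i64(i64 %index, i64 %trip.count)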

  std::optional<unsigned> getMaxVScale() const;
  std::optional<unsigned> getVScaleForTuning() const;

  TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;

  unsigned getRegUsageForType(Type *Ty);

  unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const;

  bool preferEpilogueVectorization() const {
    // Epilogue vectorization is usually unprofitable - tail folding or
    // a smaller VF would have been better. This is a blunt hammer - we
    // should re-examine this once vectorization is better tuned.
    return false;
  }

  InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
                                        Align Alignment, unsigned AddressSpace,
                                        TTI::TargetCostKind CostKind);

  void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
                               TTI::UnrollingPreferences &UP,
                               OptimizationRemarkEmitter *ORE);

  void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                             TTI::PeelingPreferences &PP);

  unsigned getMinVectorRegisterBitWidth() const {
    return ST->useRVVForFixedLengthVectors() ? 16 : 0;
  }
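  // An assumption for exposition, not something the header states: the 16-bit
  // floor corresponds to the narrowest fixed-length vector worth forming,
  // such as <2 x i8>.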

  InstructionCost getSpliceCost(VectorType *Tp, int Index);
  InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
                                 ArrayRef<int> Mask,
                                 TTI::TargetCostKind CostKind, int Index,
                                 VectorType *SubTp,
                                 ArrayRef<const Value *> Args = std::nullopt);

  InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                        TTI::TargetCostKind CostKind);

  InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
                                         const Value *Ptr, bool VariableMask,
                                         Align Alignment,
                                         TTI::TargetCostKind CostKind,
                                         const Instruction *I);

  InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
                                   TTI::CastContextHint CCH,
                                   TTI::TargetCostKind CostKind,
                                   const Instruction *I = nullptr);

  InstructionCost getMinMaxReductionCost(VectorType *Ty, VectorType *CondTy,
                                         bool IsUnsigned,
                                         TTI::TargetCostKind CostKind);

  InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
                                             std::optional<FastMathFlags> FMF,
                                             TTI::TargetCostKind CostKind);

  InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
                                           Type *ResTy, VectorType *ValTy,
                                           std::optional<FastMathFlags> FMF,
                                           TTI::TargetCostKind CostKind);

  InstructionCost
  getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
                  unsigned AddressSpace, TTI::TargetCostKind CostKind,
                  TTI::OperandValueInfo OpdInfo = {TTI::OK_AnyValue,
                                                   TTI::OP_None},
                  const Instruction *I = nullptr);

  InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                     Type *CondTy, CmpInst::Predicate VecPred,
                                     TTI::TargetCostKind CostKind,
                                     const Instruction *I = nullptr);

  using BaseT::getVectorInstrCost;
  InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
                                     TTI::TargetCostKind CostKind,
                                     unsigned Index, Value *Op0, Value *Op1);

  InstructionCost getArithmeticInstrCost(
      unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
      TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
      TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
      ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
      const Instruction *CxtI = nullptr);

  bool isElementTypeLegalForScalableVector(Type *Ty) const {
    return TLI->isLegalElementTypeForRVV(Ty);
  }

  bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    // Only support fixed vectors if we know the minimum vector size.
    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
      return false;

    // Don't allow elements larger than the ELEN.
    // FIXME: How to limit for scalable vectors?
    if (isa<FixedVectorType>(DataType) &&
        DataType->getScalarSizeInBits() > ST->getELEN())
      return false;

    if (Alignment <
        DL.getTypeStoreSize(DataType->getScalarType()).getFixedValue())
      return false;

    return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
  }
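  // Worked example of the checks above (the subtarget parameters are
  // assumptions, and fixed-length RVV is taken to be enabled): with ELEN=32,
  // a masked load of <4 x i64> fails the ELEN check, and a <4 x i32> load
  // with 2-byte alignment fails the element-alignment check; a naturally
  // aligned <4 x i32> falls through to the element-type legality query.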

  bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }
  bool isLegalMaskedStore(Type *DataType, Align Alignment) {
    return isLegalMaskedLoadStore(DataType, Alignment);
  }

  bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
    if (!ST->hasVInstructions())
      return false;

    // Only support fixed vectors if we know the minimum vector size.
    if (isa<FixedVectorType>(DataType) && !ST->useRVVForFixedLengthVectors())
      return false;

    // Don't allow elements larger than the ELEN.
    // FIXME: How to limit for scalable vectors?
    if (isa<FixedVectorType>(DataType) &&
        DataType->getScalarSizeInBits() > ST->getELEN())
      return false;

    if (Alignment <
        DL.getTypeStoreSize(DataType->getScalarType()).getFixedValue())
      return false;

    return TLI->isLegalElementTypeForRVV(DataType->getScalarType());
  }

  bool isLegalMaskedGather(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }
  bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
    return isLegalMaskedGatherScatter(DataType, Alignment);
  }

  bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
    // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }
  bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
    // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
    return ST->is64Bit() && !ST->hasVInstructionsI64();
  }
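  // Rationale sketch for the two checks above: on RV64, gather/scatter
  // addressing needs 64-bit (EEW=64) index elements, and a subtarget without
  // 64-bit vector elements (e.g. a Zve32x-only configuration) cannot hold
  // those indices in vector registers, so the operation is scalarized.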

  /// \returns How the target needs this vector-predicated operation to be
  /// transformed.
  TargetTransformInfo::VPLegalization
  getVPLegalizationStrategy(const VPIntrinsic &PI) const {
    using VPLegalization = TargetTransformInfo::VPLegalization;
    if (!ST->hasVInstructions() ||
        (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
         cast<VectorType>(PI.getArgOperand(1)->getType())
                 ->getElementType()
                 ->getIntegerBitWidth() != 1))
      return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
    return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
  }
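  // Concrete reading of the rule above: an llvm.vp.reduce.mul over anything
  // wider than i1 (say <vscale x 4 x i32>) gets its mask discarded and the
  // call converted to an unpredicated reduction, since RVV has no vector
  // multiply-reduction instruction; other VP intrinsics are reported legal
  // whenever the V extension is present.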

  bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
                                   ElementCount VF) const {
    if (!VF.isScalable())
      return true;

    Type *Ty = RdxDesc.getRecurrenceType();
    if (!TLI->isLegalElementTypeForRVV(Ty))
      return false;

    switch (RdxDesc.getRecurrenceKind()) {
    case RecurKind::Add:
    case RecurKind::FAdd:
    case RecurKind::And:
    case RecurKind::Or:
    case RecurKind::Xor:
    case RecurKind::SMin:
    case RecurKind::SMax:
    case RecurKind::UMin:
    case RecurKind::UMax:
    case RecurKind::FMin:
    case RecurKind::FMax:
    case RecurKind::SelectICmp:
    case RecurKind::SelectFCmp:
    case RecurKind::FMulAdd:
      return true;
    default:
      return false;
    }
  }
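  // Example under the scalable path (subtarget features are assumptions): an
  // FAdd recurrence over double is vectorizable only when double is a legal
  // RVV element type (i.e. the D extension is available), while a multiply
  // recurrence (RecurKind::Mul) hits the default case and is rejected.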

  unsigned getMaxInterleaveFactor(unsigned VF) {
    // If the loop is not going to be vectorized, don't interleave it;
    // let regular loop unrolling handle it instead.
    return VF == 1 ? 1 : ST->getMaxInterleaveFactor();
  }

  enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
  unsigned getNumberOfRegisters(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      // 31 = 32 GPR - x0 (zero register)
      // FIXME: Should we exclude fixed registers like SP, TP or GP?
      return 31;
    case RISCVRegisterClass::FPRRC:
      if (ST->hasStdExtF())
        return 32;
      return 0;
    case RISCVRegisterClass::VRRC:
      // Although there are 32 vector registers, v0 is special in that it is
      // the only register that can be used to hold a mask.
      // FIXME: Should we conservatively return 31 as the number of usable
      // vector registers?
      return ST->hasVInstructions() ? 32 : 0;
    }
    llvm_unreachable("unknown register class");
  }

  unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
    if (Vector)
      return RISCVRegisterClass::VRRC;
    if (!Ty)
      return RISCVRegisterClass::GPRRC;
    Type *ScalarTy = Ty->getScalarType();
    if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhOrZfhmin()) ||
        (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
        (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
      return RISCVRegisterClass::FPRRC;
    }
    return RISCVRegisterClass::GPRRC;
  }
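  // For instance (which extensions are present is a per-subtarget
  // assumption): a scalar double maps to FPRRC only when the D extension is
  // available; without it the query falls back to GPRRC.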

  const char *getRegisterClassName(unsigned ClassID) const {
    switch (ClassID) {
    case RISCVRegisterClass::GPRRC:
      return "RISCV::GPRRC";
    case RISCVRegisterClass::FPRRC:
      return "RISCV::FPRRC";
    case RISCVRegisterClass::VRRC:
      return "RISCV::VRRC";
    }
    llvm_unreachable("unknown register class");
  }

  bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
                     const TargetTransformInfo::LSRCost &C2);
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H