SemaRISCVVectorLookup.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. //==- SemaRISCVVectorLookup.cpp - Name Lookup for RISC-V Vector Intrinsic -==//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
// This file implements name lookup for RISC-V vector intrinsics.
  10. //
  11. //===----------------------------------------------------------------------===//
#include "clang/AST/ASTContext.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/Builtins.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Lex/Preprocessor.h"
#include "clang/Sema/Lookup.h"
#include "clang/Sema/RISCVIntrinsicManager.h"
#include "clang/Sema/Sema.h"
#include "clang/Support/RISCVVIntrinsicUtils.h"
#include "llvm/ADT/SmallVector.h"
#include <optional>
#include <string>
#include <utility>
#include <vector>
  25. using namespace llvm;
  26. using namespace clang;
  27. using namespace clang::RISCV;
  28. namespace {
  29. // Function definition of a RVV intrinsic.
  30. struct RVVIntrinsicDef {
  31. /// Full function name with suffix, e.g. vadd_vv_i32m1.
  32. std::string Name;
  33. /// Overloaded function name, e.g. vadd.
  34. std::string OverloadName;
  35. /// Mapping to which clang built-in function, e.g. __builtin_rvv_vadd.
  36. std::string BuiltinName;
  37. /// Function signature, first element is return type.
  38. RVVTypes Signature;
  39. };
  40. struct RVVOverloadIntrinsicDef {
  41. // Indexes of RISCVIntrinsicManagerImpl::IntrinsicList.
  42. SmallVector<size_t, 8> Indexes;
  43. };
  44. } // namespace
// Shared pool of prototype descriptors, expanded from
// riscv_vector_builtin_sema.inc with DECL_SIGNATURE_TABLE defined.
// Records refer to a contiguous slice of this table by start index and
// length (see ProtoSeq2ArrayRef and its uses in InitIntrinsicList).
static const PrototypeDescriptor RVVSignatureTable[] = {
#define DECL_SIGNATURE_TABLE
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_SIGNATURE_TABLE
};
// One RVVIntrinsicRecord per intrinsic description, expanded from
// riscv_vector_builtin_sema.inc with DECL_INTRINSIC_RECORDS defined.
// InitIntrinsicList expands each record over its enabled base types and LMULs.
static const RVVIntrinsicRecord RVVIntrinsicRecords[] = {
#define DECL_INTRINSIC_RECORDS
#include "clang/Basic/riscv_vector_builtin_sema.inc"
#undef DECL_INTRINSIC_RECORDS
};
  55. // Get subsequence of signature table.
  56. static ArrayRef<PrototypeDescriptor> ProtoSeq2ArrayRef(uint16_t Index,
  57. uint8_t Length) {
  58. return ArrayRef(&RVVSignatureTable[Index], Length);
  59. }
  60. static QualType RVVType2Qual(ASTContext &Context, const RVVType *Type) {
  61. QualType QT;
  62. switch (Type->getScalarType()) {
  63. case ScalarTypeKind::Void:
  64. QT = Context.VoidTy;
  65. break;
  66. case ScalarTypeKind::Size_t:
  67. QT = Context.getSizeType();
  68. break;
  69. case ScalarTypeKind::Ptrdiff_t:
  70. QT = Context.getPointerDiffType();
  71. break;
  72. case ScalarTypeKind::UnsignedLong:
  73. QT = Context.UnsignedLongTy;
  74. break;
  75. case ScalarTypeKind::SignedLong:
  76. QT = Context.LongTy;
  77. break;
  78. case ScalarTypeKind::Boolean:
  79. QT = Context.BoolTy;
  80. break;
  81. case ScalarTypeKind::SignedInteger:
  82. QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), true);
  83. break;
  84. case ScalarTypeKind::UnsignedInteger:
  85. QT = Context.getIntTypeForBitwidth(Type->getElementBitwidth(), false);
  86. break;
  87. case ScalarTypeKind::Float:
  88. switch (Type->getElementBitwidth()) {
  89. case 64:
  90. QT = Context.DoubleTy;
  91. break;
  92. case 32:
  93. QT = Context.FloatTy;
  94. break;
  95. case 16:
  96. QT = Context.Float16Ty;
  97. break;
  98. default:
  99. llvm_unreachable("Unsupported floating point width.");
  100. }
  101. break;
  102. case Invalid:
  103. llvm_unreachable("Unhandled type.");
  104. }
  105. if (Type->isVector())
  106. QT = Context.getScalableVectorType(QT, *Type->getScale());
  107. if (Type->isConstant())
  108. QT = Context.getConstType(QT);
  109. // Transform the type to a pointer as the last step, if necessary.
  110. if (Type->isPointer())
  111. QT = Context.getPointerType(QT);
  112. return QT;
  113. }
  114. namespace {
  115. class RISCVIntrinsicManagerImpl : public sema::RISCVIntrinsicManager {
  116. private:
  117. Sema &S;
  118. ASTContext &Context;
  119. RVVTypeCache TypeCache;
  120. // List of all RVV intrinsic.
  121. std::vector<RVVIntrinsicDef> IntrinsicList;
  122. // Mapping function name to index of IntrinsicList.
  123. StringMap<size_t> Intrinsics;
  124. // Mapping function name to RVVOverloadIntrinsicDef.
  125. StringMap<RVVOverloadIntrinsicDef> OverloadIntrinsics;
  126. // Create IntrinsicList
  127. void InitIntrinsicList();
  128. // Create RVVIntrinsicDef.
  129. void InitRVVIntrinsic(const RVVIntrinsicRecord &Record, StringRef SuffixStr,
  130. StringRef OverloadedSuffixStr, bool IsMask,
  131. RVVTypes &Types, bool HasPolicy, Policy PolicyAttrs);
  132. // Create FunctionDecl for a vector intrinsic.
  133. void CreateRVVIntrinsicDecl(LookupResult &LR, IdentifierInfo *II,
  134. Preprocessor &PP, unsigned Index,
  135. bool IsOverload);
  136. public:
  137. RISCVIntrinsicManagerImpl(clang::Sema &S) : S(S), Context(S.Context) {
  138. InitIntrinsicList();
  139. }
  140. // Create RISC-V vector intrinsic and insert into symbol table if found, and
  141. // return true, otherwise return false.
  142. bool CreateIntrinsicIfFound(LookupResult &LR, IdentifierInfo *II,
  143. Preprocessor &PP) override;
  144. };
  145. } // namespace
  146. void RISCVIntrinsicManagerImpl::InitIntrinsicList() {
  147. const TargetInfo &TI = Context.getTargetInfo();
  148. bool HasVectorFloat32 = TI.hasFeature("zve32f");
  149. bool HasVectorFloat64 = TI.hasFeature("zve64d");
  150. bool HasZvfh = TI.hasFeature("experimental-zvfh");
  151. bool HasRV64 = TI.hasFeature("64bit");
  152. bool HasFullMultiply = TI.hasFeature("v");
  153. // Construction of RVVIntrinsicRecords need to sync with createRVVIntrinsics
  154. // in RISCVVEmitter.cpp.
  155. for (auto &Record : RVVIntrinsicRecords) {
  156. // Create Intrinsics for each type and LMUL.
  157. BasicType BaseType = BasicType::Unknown;
  158. ArrayRef<PrototypeDescriptor> BasicProtoSeq =
  159. ProtoSeq2ArrayRef(Record.PrototypeIndex, Record.PrototypeLength);
  160. ArrayRef<PrototypeDescriptor> SuffixProto =
  161. ProtoSeq2ArrayRef(Record.SuffixIndex, Record.SuffixLength);
  162. ArrayRef<PrototypeDescriptor> OverloadedSuffixProto = ProtoSeq2ArrayRef(
  163. Record.OverloadedSuffixIndex, Record.OverloadedSuffixSize);
  164. PolicyScheme UnMaskedPolicyScheme =
  165. static_cast<PolicyScheme>(Record.UnMaskedPolicyScheme);
  166. PolicyScheme MaskedPolicyScheme =
  167. static_cast<PolicyScheme>(Record.MaskedPolicyScheme);
  168. const Policy DefaultPolicy;
  169. llvm::SmallVector<PrototypeDescriptor> ProtoSeq =
  170. RVVIntrinsic::computeBuiltinTypes(BasicProtoSeq, /*IsMasked=*/false,
  171. /*HasMaskedOffOperand=*/false,
  172. Record.HasVL, Record.NF,
  173. UnMaskedPolicyScheme, DefaultPolicy);
  174. llvm::SmallVector<PrototypeDescriptor> ProtoMaskSeq =
  175. RVVIntrinsic::computeBuiltinTypes(
  176. BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
  177. Record.HasVL, Record.NF, MaskedPolicyScheme, DefaultPolicy);
  178. bool UnMaskedHasPolicy = UnMaskedPolicyScheme != PolicyScheme::SchemeNone;
  179. bool MaskedHasPolicy = MaskedPolicyScheme != PolicyScheme::SchemeNone;
  180. SmallVector<Policy> SupportedUnMaskedPolicies =
  181. RVVIntrinsic::getSupportedUnMaskedPolicies();
  182. SmallVector<Policy> SupportedMaskedPolicies =
  183. RVVIntrinsic::getSupportedMaskedPolicies(Record.HasTailPolicy,
  184. Record.HasMaskPolicy);
  185. for (unsigned int TypeRangeMaskShift = 0;
  186. TypeRangeMaskShift <= static_cast<unsigned int>(BasicType::MaxOffset);
  187. ++TypeRangeMaskShift) {
  188. unsigned int BaseTypeI = 1 << TypeRangeMaskShift;
  189. BaseType = static_cast<BasicType>(BaseTypeI);
  190. if ((BaseTypeI & Record.TypeRangeMask) != BaseTypeI)
  191. continue;
  192. // Check requirement.
  193. if (BaseType == BasicType::Float16 && !HasZvfh)
  194. continue;
  195. if (BaseType == BasicType::Float32 && !HasVectorFloat32)
  196. continue;
  197. if (BaseType == BasicType::Float64 && !HasVectorFloat64)
  198. continue;
  199. if (((Record.RequiredExtensions & RVV_REQ_RV64) == RVV_REQ_RV64) &&
  200. !HasRV64)
  201. continue;
  202. if ((BaseType == BasicType::Int64) &&
  203. ((Record.RequiredExtensions & RVV_REQ_FullMultiply) ==
  204. RVV_REQ_FullMultiply) &&
  205. !HasFullMultiply)
  206. continue;
  207. // Expanded with different LMUL.
  208. for (int Log2LMUL = -3; Log2LMUL <= 3; Log2LMUL++) {
  209. if (!(Record.Log2LMULMask & (1 << (Log2LMUL + 3))))
  210. continue;
  211. std::optional<RVVTypes> Types =
  212. TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoSeq);
  213. // Ignored to create new intrinsic if there are any illegal types.
  214. if (!Types.has_value())
  215. continue;
  216. std::string SuffixStr = RVVIntrinsic::getSuffixStr(
  217. TypeCache, BaseType, Log2LMUL, SuffixProto);
  218. std::string OverloadedSuffixStr = RVVIntrinsic::getSuffixStr(
  219. TypeCache, BaseType, Log2LMUL, OverloadedSuffixProto);
  220. // Create non-masked intrinsic.
  221. InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, false, *Types,
  222. UnMaskedHasPolicy, DefaultPolicy);
  223. // Create non-masked policy intrinsic.
  224. if (Record.UnMaskedPolicyScheme != PolicyScheme::SchemeNone) {
  225. for (auto P : SupportedUnMaskedPolicies) {
  226. llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
  227. RVVIntrinsic::computeBuiltinTypes(
  228. BasicProtoSeq, /*IsMasked=*/false,
  229. /*HasMaskedOffOperand=*/false, Record.HasVL, Record.NF,
  230. UnMaskedPolicyScheme, P);
  231. std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
  232. BaseType, Log2LMUL, Record.NF, PolicyPrototype);
  233. InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
  234. /*IsMask=*/false, *PolicyTypes, UnMaskedHasPolicy,
  235. P);
  236. }
  237. }
  238. if (!Record.HasMasked)
  239. continue;
  240. // Create masked intrinsic.
  241. std::optional<RVVTypes> MaskTypes =
  242. TypeCache.computeTypes(BaseType, Log2LMUL, Record.NF, ProtoMaskSeq);
  243. InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr, true,
  244. *MaskTypes, MaskedHasPolicy, DefaultPolicy);
  245. if (Record.MaskedPolicyScheme == PolicyScheme::SchemeNone)
  246. continue;
  247. // Create masked policy intrinsic.
  248. for (auto P : SupportedMaskedPolicies) {
  249. llvm::SmallVector<PrototypeDescriptor> PolicyPrototype =
  250. RVVIntrinsic::computeBuiltinTypes(
  251. BasicProtoSeq, /*IsMasked=*/true, Record.HasMaskedOffOperand,
  252. Record.HasVL, Record.NF, MaskedPolicyScheme, P);
  253. std::optional<RVVTypes> PolicyTypes = TypeCache.computeTypes(
  254. BaseType, Log2LMUL, Record.NF, PolicyPrototype);
  255. InitRVVIntrinsic(Record, SuffixStr, OverloadedSuffixStr,
  256. /*IsMask=*/true, *PolicyTypes, MaskedHasPolicy, P);
  257. }
  258. } // End for different LMUL
  259. } // End for different TypeRange
  260. }
  261. }
  262. // Compute name and signatures for intrinsic with practical types.
  263. void RISCVIntrinsicManagerImpl::InitRVVIntrinsic(
  264. const RVVIntrinsicRecord &Record, StringRef SuffixStr,
  265. StringRef OverloadedSuffixStr, bool IsMasked, RVVTypes &Signature,
  266. bool HasPolicy, Policy PolicyAttrs) {
  267. // Function name, e.g. vadd_vv_i32m1.
  268. std::string Name = Record.Name;
  269. if (!SuffixStr.empty())
  270. Name += "_" + SuffixStr.str();
  271. // Overloaded function name, e.g. vadd.
  272. std::string OverloadedName;
  273. if (!Record.OverloadedName)
  274. OverloadedName = StringRef(Record.Name).split("_").first.str();
  275. else
  276. OverloadedName = Record.OverloadedName;
  277. if (!OverloadedSuffixStr.empty())
  278. OverloadedName += "_" + OverloadedSuffixStr.str();
  279. // clang built-in function name, e.g. __builtin_rvv_vadd.
  280. std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name);
  281. RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName,
  282. OverloadedName, PolicyAttrs);
  283. // Put into IntrinsicList.
  284. size_t Index = IntrinsicList.size();
  285. IntrinsicList.push_back({Name, OverloadedName, BuiltinName, Signature});
  286. // Creating mapping to Intrinsics.
  287. Intrinsics.insert({Name, Index});
  288. // Get the RVVOverloadIntrinsicDef.
  289. RVVOverloadIntrinsicDef &OverloadIntrinsicDef =
  290. OverloadIntrinsics[OverloadedName];
  291. // And added the index.
  292. OverloadIntrinsicDef.Indexes.push_back(Index);
  293. }
  294. void RISCVIntrinsicManagerImpl::CreateRVVIntrinsicDecl(LookupResult &LR,
  295. IdentifierInfo *II,
  296. Preprocessor &PP,
  297. unsigned Index,
  298. bool IsOverload) {
  299. ASTContext &Context = S.Context;
  300. RVVIntrinsicDef &IDef = IntrinsicList[Index];
  301. RVVTypes Sigs = IDef.Signature;
  302. size_t SigLength = Sigs.size();
  303. RVVType *ReturnType = Sigs[0];
  304. QualType RetType = RVVType2Qual(Context, ReturnType);
  305. SmallVector<QualType, 8> ArgTypes;
  306. QualType BuiltinFuncType;
  307. // Skip return type, and convert RVVType to QualType for arguments.
  308. for (size_t i = 1; i < SigLength; ++i)
  309. ArgTypes.push_back(RVVType2Qual(Context, Sigs[i]));
  310. FunctionProtoType::ExtProtoInfo PI(
  311. Context.getDefaultCallingConvention(false, false, true));
  312. PI.Variadic = false;
  313. SourceLocation Loc = LR.getNameLoc();
  314. BuiltinFuncType = Context.getFunctionType(RetType, ArgTypes, PI);
  315. DeclContext *Parent = Context.getTranslationUnitDecl();
  316. FunctionDecl *RVVIntrinsicDecl = FunctionDecl::Create(
  317. Context, Parent, Loc, Loc, II, BuiltinFuncType, /*TInfo=*/nullptr,
  318. SC_Extern, S.getCurFPFeatures().isFPConstrained(),
  319. /*isInlineSpecified*/ false,
  320. /*hasWrittenPrototype*/ true);
  321. // Create Decl objects for each parameter, adding them to the
  322. // FunctionDecl.
  323. const auto *FP = cast<FunctionProtoType>(BuiltinFuncType);
  324. SmallVector<ParmVarDecl *, 8> ParmList;
  325. for (unsigned IParm = 0, E = FP->getNumParams(); IParm != E; ++IParm) {
  326. ParmVarDecl *Parm =
  327. ParmVarDecl::Create(Context, RVVIntrinsicDecl, Loc, Loc, nullptr,
  328. FP->getParamType(IParm), nullptr, SC_None, nullptr);
  329. Parm->setScopeInfo(0, IParm);
  330. ParmList.push_back(Parm);
  331. }
  332. RVVIntrinsicDecl->setParams(ParmList);
  333. // Add function attributes.
  334. if (IsOverload)
  335. RVVIntrinsicDecl->addAttr(OverloadableAttr::CreateImplicit(Context));
  336. // Setup alias to __builtin_rvv_*
  337. IdentifierInfo &IntrinsicII = PP.getIdentifierTable().get(IDef.BuiltinName);
  338. RVVIntrinsicDecl->addAttr(
  339. BuiltinAliasAttr::CreateImplicit(S.Context, &IntrinsicII));
  340. // Add to symbol table.
  341. LR.addDecl(RVVIntrinsicDecl);
  342. }
  343. bool RISCVIntrinsicManagerImpl::CreateIntrinsicIfFound(LookupResult &LR,
  344. IdentifierInfo *II,
  345. Preprocessor &PP) {
  346. StringRef Name = II->getName();
  347. // Lookup the function name from the overload intrinsics first.
  348. auto OvIItr = OverloadIntrinsics.find(Name);
  349. if (OvIItr != OverloadIntrinsics.end()) {
  350. const RVVOverloadIntrinsicDef &OvIntrinsicDef = OvIItr->second;
  351. for (auto Index : OvIntrinsicDef.Indexes)
  352. CreateRVVIntrinsicDecl(LR, II, PP, Index,
  353. /*IsOverload*/ true);
  354. // If we added overloads, need to resolve the lookup result.
  355. LR.resolveKind();
  356. return true;
  357. }
  358. // Lookup the function name from the intrinsics.
  359. auto Itr = Intrinsics.find(Name);
  360. if (Itr != Intrinsics.end()) {
  361. CreateRVVIntrinsicDecl(LR, II, PP, Itr->second,
  362. /*IsOverload*/ false);
  363. return true;
  364. }
  365. // It's not an RVV intrinsics.
  366. return false;
  367. }
  368. namespace clang {
  369. std::unique_ptr<clang::sema::RISCVIntrinsicManager>
  370. CreateRISCVIntrinsicManager(Sema &S) {
  371. return std::make_unique<RISCVIntrinsicManagerImpl>(S);
  372. }
  373. } // namespace clang