RISCVVIntrinsicUtils.cpp 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087
  1. //===- RISCVVIntrinsicUtils.cpp - RISC-V Vector Intrinsic Utils -*- C++ -*-===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. #include "clang/Support/RISCVVIntrinsicUtils.h"
  9. #include "llvm/ADT/ArrayRef.h"
  10. #include "llvm/ADT/SmallSet.h"
  11. #include "llvm/ADT/StringExtras.h"
  12. #include "llvm/ADT/StringMap.h"
  13. #include "llvm/ADT/StringSet.h"
  14. #include "llvm/ADT/Twine.h"
  15. #include "llvm/Support/ErrorHandling.h"
  16. #include "llvm/Support/raw_ostream.h"
  17. #include <numeric>
  18. #include <optional>
  19. using namespace llvm;
  20. namespace clang {
  21. namespace RISCV {
  22. const PrototypeDescriptor PrototypeDescriptor::Mask = PrototypeDescriptor(
  23. BaseTypeModifier::Vector, VectorTypeModifier::MaskVector);
  24. const PrototypeDescriptor PrototypeDescriptor::VL =
  25. PrototypeDescriptor(BaseTypeModifier::SizeT);
  26. const PrototypeDescriptor PrototypeDescriptor::Vector =
  27. PrototypeDescriptor(BaseTypeModifier::Vector);
  28. //===----------------------------------------------------------------------===//
  29. // Type implementation
  30. //===----------------------------------------------------------------------===//
  31. LMULType::LMULType(int NewLog2LMUL) {
  32. // Check Log2LMUL is -3, -2, -1, 0, 1, 2, 3
  33. assert(NewLog2LMUL <= 3 && NewLog2LMUL >= -3 && "Bad LMUL number!");
  34. Log2LMUL = NewLog2LMUL;
  35. }
  36. std::string LMULType::str() const {
  37. if (Log2LMUL < 0)
  38. return "mf" + utostr(1ULL << (-Log2LMUL));
  39. return "m" + utostr(1ULL << Log2LMUL);
  40. }
  41. VScaleVal LMULType::getScale(unsigned ElementBitwidth) const {
  42. int Log2ScaleResult = 0;
  43. switch (ElementBitwidth) {
  44. default:
  45. break;
  46. case 8:
  47. Log2ScaleResult = Log2LMUL + 3;
  48. break;
  49. case 16:
  50. Log2ScaleResult = Log2LMUL + 2;
  51. break;
  52. case 32:
  53. Log2ScaleResult = Log2LMUL + 1;
  54. break;
  55. case 64:
  56. Log2ScaleResult = Log2LMUL;
  57. break;
  58. }
  59. // Illegal vscale result would be less than 1
  60. if (Log2ScaleResult < 0)
  61. return std::nullopt;
  62. return 1 << Log2ScaleResult;
  63. }
  64. void LMULType::MulLog2LMUL(int log2LMUL) { Log2LMUL += log2LMUL; }
  65. RVVType::RVVType(BasicType BT, int Log2LMUL,
  66. const PrototypeDescriptor &prototype)
  67. : BT(BT), LMUL(LMULType(Log2LMUL)) {
  68. applyBasicType();
  69. applyModifier(prototype);
  70. Valid = verifyType();
  71. if (Valid) {
  72. initBuiltinStr();
  73. initTypeStr();
  74. if (isVector()) {
  75. initClangBuiltinStr();
  76. }
  77. }
  78. }
  79. // clang-format off
  // Boolean (mask) types are encoded by the ratio n = SEW/LMUL.
  // SEW/LMUL | 64 | 32 | 16 | 8 | 4 | 2 | 1
  82. // c type | vbool64_t | vbool32_t | vbool16_t | vbool8_t | vbool4_t | vbool2_t | vbool1_t
  83. // IR type | nxv1i1 | nxv2i1 | nxv4i1 | nxv8i1 | nxv16i1 | nxv32i1 | nxv64i1
  84. // type\lmul | 1/8 | 1/4 | 1/2 | 1 | 2 | 4 | 8
  85. // -------- |------ | -------- | ------- | ------- | -------- | -------- | --------
  86. // i64 | N/A | N/A | N/A | nxv1i64 | nxv2i64 | nxv4i64 | nxv8i64
  87. // i32 | N/A | N/A | nxv1i32 | nxv2i32 | nxv4i32 | nxv8i32 | nxv16i32
  88. // i16 | N/A | nxv1i16 | nxv2i16 | nxv4i16 | nxv8i16 | nxv16i16 | nxv32i16
  89. // i8 | nxv1i8 | nxv2i8 | nxv4i8 | nxv8i8 | nxv16i8 | nxv32i8 | nxv64i8
  90. // double | N/A | N/A | N/A | nxv1f64 | nxv2f64 | nxv4f64 | nxv8f64
  91. // float | N/A | N/A | nxv1f32 | nxv2f32 | nxv4f32 | nxv8f32 | nxv16f32
  92. // half | N/A | nxv1f16 | nxv2f16 | nxv4f16 | nxv8f16 | nxv16f16 | nxv32f16
  93. // clang-format on
  94. bool RVVType::verifyType() const {
  95. if (ScalarType == Invalid)
  96. return false;
  97. if (isScalar())
  98. return true;
  99. if (!Scale)
  100. return false;
  101. if (isFloat() && ElementBitwidth == 8)
  102. return false;
  103. unsigned V = *Scale;
  104. switch (ElementBitwidth) {
  105. case 1:
  106. case 8:
  107. // Check Scale is 1,2,4,8,16,32,64
  108. return (V <= 64 && isPowerOf2_32(V));
  109. case 16:
  110. // Check Scale is 1,2,4,8,16,32
  111. return (V <= 32 && isPowerOf2_32(V));
  112. case 32:
  113. // Check Scale is 1,2,4,8,16
  114. return (V <= 16 && isPowerOf2_32(V));
  115. case 64:
  116. // Check Scale is 1,2,4,8
  117. return (V <= 8 && isPowerOf2_32(V));
  118. }
  119. return false;
  120. }
  121. void RVVType::initBuiltinStr() {
  122. assert(isValid() && "RVVType is invalid");
  123. switch (ScalarType) {
  124. case ScalarTypeKind::Void:
  125. BuiltinStr = "v";
  126. return;
  127. case ScalarTypeKind::Size_t:
  128. BuiltinStr = "z";
  129. if (IsImmediate)
  130. BuiltinStr = "I" + BuiltinStr;
  131. if (IsPointer)
  132. BuiltinStr += "*";
  133. return;
  134. case ScalarTypeKind::Ptrdiff_t:
  135. BuiltinStr = "Y";
  136. return;
  137. case ScalarTypeKind::UnsignedLong:
  138. BuiltinStr = "ULi";
  139. return;
  140. case ScalarTypeKind::SignedLong:
  141. BuiltinStr = "Li";
  142. return;
  143. case ScalarTypeKind::Boolean:
  144. assert(ElementBitwidth == 1);
  145. BuiltinStr += "b";
  146. break;
  147. case ScalarTypeKind::SignedInteger:
  148. case ScalarTypeKind::UnsignedInteger:
  149. switch (ElementBitwidth) {
  150. case 8:
  151. BuiltinStr += "c";
  152. break;
  153. case 16:
  154. BuiltinStr += "s";
  155. break;
  156. case 32:
  157. BuiltinStr += "i";
  158. break;
  159. case 64:
  160. BuiltinStr += "Wi";
  161. break;
  162. default:
  163. llvm_unreachable("Unhandled ElementBitwidth!");
  164. }
  165. if (isSignedInteger())
  166. BuiltinStr = "S" + BuiltinStr;
  167. else
  168. BuiltinStr = "U" + BuiltinStr;
  169. break;
  170. case ScalarTypeKind::Float:
  171. switch (ElementBitwidth) {
  172. case 16:
  173. BuiltinStr += "x";
  174. break;
  175. case 32:
  176. BuiltinStr += "f";
  177. break;
  178. case 64:
  179. BuiltinStr += "d";
  180. break;
  181. default:
  182. llvm_unreachable("Unhandled ElementBitwidth!");
  183. }
  184. break;
  185. default:
  186. llvm_unreachable("ScalarType is invalid!");
  187. }
  188. if (IsImmediate)
  189. BuiltinStr = "I" + BuiltinStr;
  190. if (isScalar()) {
  191. if (IsConstant)
  192. BuiltinStr += "C";
  193. if (IsPointer)
  194. BuiltinStr += "*";
  195. return;
  196. }
  197. BuiltinStr = "q" + utostr(*Scale) + BuiltinStr;
  198. // Pointer to vector types. Defined for segment load intrinsics.
  199. // segment load intrinsics have pointer type arguments to store the loaded
  200. // vector values.
  201. if (IsPointer)
  202. BuiltinStr += "*";
  203. }
  204. void RVVType::initClangBuiltinStr() {
  205. assert(isValid() && "RVVType is invalid");
  206. assert(isVector() && "Handle Vector type only");
  207. ClangBuiltinStr = "__rvv_";
  208. switch (ScalarType) {
  209. case ScalarTypeKind::Boolean:
  210. ClangBuiltinStr += "bool" + utostr(64 / *Scale) + "_t";
  211. return;
  212. case ScalarTypeKind::Float:
  213. ClangBuiltinStr += "float";
  214. break;
  215. case ScalarTypeKind::SignedInteger:
  216. ClangBuiltinStr += "int";
  217. break;
  218. case ScalarTypeKind::UnsignedInteger:
  219. ClangBuiltinStr += "uint";
  220. break;
  221. default:
  222. llvm_unreachable("ScalarTypeKind is invalid");
  223. }
  224. ClangBuiltinStr += utostr(ElementBitwidth) + LMUL.str() + "_t";
  225. }
  226. void RVVType::initTypeStr() {
  227. assert(isValid() && "RVVType is invalid");
  228. if (IsConstant)
  229. Str += "const ";
  230. auto getTypeString = [&](StringRef TypeStr) {
  231. if (isScalar())
  232. return Twine(TypeStr + Twine(ElementBitwidth) + "_t").str();
  233. return Twine("v" + TypeStr + Twine(ElementBitwidth) + LMUL.str() + "_t")
  234. .str();
  235. };
  236. switch (ScalarType) {
  237. case ScalarTypeKind::Void:
  238. Str = "void";
  239. return;
  240. case ScalarTypeKind::Size_t:
  241. Str = "size_t";
  242. if (IsPointer)
  243. Str += " *";
  244. return;
  245. case ScalarTypeKind::Ptrdiff_t:
  246. Str = "ptrdiff_t";
  247. return;
  248. case ScalarTypeKind::UnsignedLong:
  249. Str = "unsigned long";
  250. return;
  251. case ScalarTypeKind::SignedLong:
  252. Str = "long";
  253. return;
  254. case ScalarTypeKind::Boolean:
  255. if (isScalar())
  256. Str += "bool";
  257. else
  258. // Vector bool is special case, the formulate is
  259. // `vbool<N>_t = MVT::nxv<64/N>i1` ex. vbool16_t = MVT::4i1
  260. Str += "vbool" + utostr(64 / *Scale) + "_t";
  261. break;
  262. case ScalarTypeKind::Float:
  263. if (isScalar()) {
  264. if (ElementBitwidth == 64)
  265. Str += "double";
  266. else if (ElementBitwidth == 32)
  267. Str += "float";
  268. else if (ElementBitwidth == 16)
  269. Str += "_Float16";
  270. else
  271. llvm_unreachable("Unhandled floating type.");
  272. } else
  273. Str += getTypeString("float");
  274. break;
  275. case ScalarTypeKind::SignedInteger:
  276. Str += getTypeString("int");
  277. break;
  278. case ScalarTypeKind::UnsignedInteger:
  279. Str += getTypeString("uint");
  280. break;
  281. default:
  282. llvm_unreachable("ScalarType is invalid!");
  283. }
  284. if (IsPointer)
  285. Str += " *";
  286. }
  287. void RVVType::initShortStr() {
  288. switch (ScalarType) {
  289. case ScalarTypeKind::Boolean:
  290. assert(isVector());
  291. ShortStr = "b" + utostr(64 / *Scale);
  292. return;
  293. case ScalarTypeKind::Float:
  294. ShortStr = "f" + utostr(ElementBitwidth);
  295. break;
  296. case ScalarTypeKind::SignedInteger:
  297. ShortStr = "i" + utostr(ElementBitwidth);
  298. break;
  299. case ScalarTypeKind::UnsignedInteger:
  300. ShortStr = "u" + utostr(ElementBitwidth);
  301. break;
  302. default:
  303. llvm_unreachable("Unhandled case!");
  304. }
  305. if (isVector())
  306. ShortStr += LMUL.str();
  307. }
  308. void RVVType::applyBasicType() {
  309. switch (BT) {
  310. case BasicType::Int8:
  311. ElementBitwidth = 8;
  312. ScalarType = ScalarTypeKind::SignedInteger;
  313. break;
  314. case BasicType::Int16:
  315. ElementBitwidth = 16;
  316. ScalarType = ScalarTypeKind::SignedInteger;
  317. break;
  318. case BasicType::Int32:
  319. ElementBitwidth = 32;
  320. ScalarType = ScalarTypeKind::SignedInteger;
  321. break;
  322. case BasicType::Int64:
  323. ElementBitwidth = 64;
  324. ScalarType = ScalarTypeKind::SignedInteger;
  325. break;
  326. case BasicType::Float16:
  327. ElementBitwidth = 16;
  328. ScalarType = ScalarTypeKind::Float;
  329. break;
  330. case BasicType::Float32:
  331. ElementBitwidth = 32;
  332. ScalarType = ScalarTypeKind::Float;
  333. break;
  334. case BasicType::Float64:
  335. ElementBitwidth = 64;
  336. ScalarType = ScalarTypeKind::Float;
  337. break;
  338. default:
  339. llvm_unreachable("Unhandled type code!");
  340. }
  341. assert(ElementBitwidth != 0 && "Bad element bitwidth!");
  342. }
  343. std::optional<PrototypeDescriptor>
  344. PrototypeDescriptor::parsePrototypeDescriptor(
  345. llvm::StringRef PrototypeDescriptorStr) {
  346. PrototypeDescriptor PD;
  347. BaseTypeModifier PT = BaseTypeModifier::Invalid;
  348. VectorTypeModifier VTM = VectorTypeModifier::NoModifier;
  349. if (PrototypeDescriptorStr.empty())
  350. return PD;
  351. // Handle base type modifier
  352. auto PType = PrototypeDescriptorStr.back();
  353. switch (PType) {
  354. case 'e':
  355. PT = BaseTypeModifier::Scalar;
  356. break;
  357. case 'v':
  358. PT = BaseTypeModifier::Vector;
  359. break;
  360. case 'w':
  361. PT = BaseTypeModifier::Vector;
  362. VTM = VectorTypeModifier::Widening2XVector;
  363. break;
  364. case 'q':
  365. PT = BaseTypeModifier::Vector;
  366. VTM = VectorTypeModifier::Widening4XVector;
  367. break;
  368. case 'o':
  369. PT = BaseTypeModifier::Vector;
  370. VTM = VectorTypeModifier::Widening8XVector;
  371. break;
  372. case 'm':
  373. PT = BaseTypeModifier::Vector;
  374. VTM = VectorTypeModifier::MaskVector;
  375. break;
  376. case '0':
  377. PT = BaseTypeModifier::Void;
  378. break;
  379. case 'z':
  380. PT = BaseTypeModifier::SizeT;
  381. break;
  382. case 't':
  383. PT = BaseTypeModifier::Ptrdiff;
  384. break;
  385. case 'u':
  386. PT = BaseTypeModifier::UnsignedLong;
  387. break;
  388. case 'l':
  389. PT = BaseTypeModifier::SignedLong;
  390. break;
  391. default:
  392. llvm_unreachable("Illegal primitive type transformers!");
  393. }
  394. PD.PT = static_cast<uint8_t>(PT);
  395. PrototypeDescriptorStr = PrototypeDescriptorStr.drop_back();
  396. // Compute the vector type transformers, it can only appear one time.
  397. if (PrototypeDescriptorStr.startswith("(")) {
  398. assert(VTM == VectorTypeModifier::NoModifier &&
  399. "VectorTypeModifier should only have one modifier");
  400. size_t Idx = PrototypeDescriptorStr.find(')');
  401. assert(Idx != StringRef::npos);
  402. StringRef ComplexType = PrototypeDescriptorStr.slice(1, Idx);
  403. PrototypeDescriptorStr = PrototypeDescriptorStr.drop_front(Idx + 1);
  404. assert(!PrototypeDescriptorStr.contains('(') &&
  405. "Only allow one vector type modifier");
  406. auto ComplexTT = ComplexType.split(":");
  407. if (ComplexTT.first == "Log2EEW") {
  408. uint32_t Log2EEW;
  409. if (ComplexTT.second.getAsInteger(10, Log2EEW)) {
  410. llvm_unreachable("Invalid Log2EEW value!");
  411. return std::nullopt;
  412. }
  413. switch (Log2EEW) {
  414. case 3:
  415. VTM = VectorTypeModifier::Log2EEW3;
  416. break;
  417. case 4:
  418. VTM = VectorTypeModifier::Log2EEW4;
  419. break;
  420. case 5:
  421. VTM = VectorTypeModifier::Log2EEW5;
  422. break;
  423. case 6:
  424. VTM = VectorTypeModifier::Log2EEW6;
  425. break;
  426. default:
  427. llvm_unreachable("Invalid Log2EEW value, should be [3-6]");
  428. return std::nullopt;
  429. }
  430. } else if (ComplexTT.first == "FixedSEW") {
  431. uint32_t NewSEW;
  432. if (ComplexTT.second.getAsInteger(10, NewSEW)) {
  433. llvm_unreachable("Invalid FixedSEW value!");
  434. return std::nullopt;
  435. }
  436. switch (NewSEW) {
  437. case 8:
  438. VTM = VectorTypeModifier::FixedSEW8;
  439. break;
  440. case 16:
  441. VTM = VectorTypeModifier::FixedSEW16;
  442. break;
  443. case 32:
  444. VTM = VectorTypeModifier::FixedSEW32;
  445. break;
  446. case 64:
  447. VTM = VectorTypeModifier::FixedSEW64;
  448. break;
  449. default:
  450. llvm_unreachable("Invalid FixedSEW value, should be 8, 16, 32 or 64");
  451. return std::nullopt;
  452. }
  453. } else if (ComplexTT.first == "LFixedLog2LMUL") {
  454. int32_t Log2LMUL;
  455. if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
  456. llvm_unreachable("Invalid LFixedLog2LMUL value!");
  457. return std::nullopt;
  458. }
  459. switch (Log2LMUL) {
  460. case -3:
  461. VTM = VectorTypeModifier::LFixedLog2LMULN3;
  462. break;
  463. case -2:
  464. VTM = VectorTypeModifier::LFixedLog2LMULN2;
  465. break;
  466. case -1:
  467. VTM = VectorTypeModifier::LFixedLog2LMULN1;
  468. break;
  469. case 0:
  470. VTM = VectorTypeModifier::LFixedLog2LMUL0;
  471. break;
  472. case 1:
  473. VTM = VectorTypeModifier::LFixedLog2LMUL1;
  474. break;
  475. case 2:
  476. VTM = VectorTypeModifier::LFixedLog2LMUL2;
  477. break;
  478. case 3:
  479. VTM = VectorTypeModifier::LFixedLog2LMUL3;
  480. break;
  481. default:
  482. llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
  483. return std::nullopt;
  484. }
  485. } else if (ComplexTT.first == "SFixedLog2LMUL") {
  486. int32_t Log2LMUL;
  487. if (ComplexTT.second.getAsInteger(10, Log2LMUL)) {
  488. llvm_unreachable("Invalid SFixedLog2LMUL value!");
  489. return std::nullopt;
  490. }
  491. switch (Log2LMUL) {
  492. case -3:
  493. VTM = VectorTypeModifier::SFixedLog2LMULN3;
  494. break;
  495. case -2:
  496. VTM = VectorTypeModifier::SFixedLog2LMULN2;
  497. break;
  498. case -1:
  499. VTM = VectorTypeModifier::SFixedLog2LMULN1;
  500. break;
  501. case 0:
  502. VTM = VectorTypeModifier::SFixedLog2LMUL0;
  503. break;
  504. case 1:
  505. VTM = VectorTypeModifier::SFixedLog2LMUL1;
  506. break;
  507. case 2:
  508. VTM = VectorTypeModifier::SFixedLog2LMUL2;
  509. break;
  510. case 3:
  511. VTM = VectorTypeModifier::SFixedLog2LMUL3;
  512. break;
  513. default:
  514. llvm_unreachable("Invalid LFixedLog2LMUL value, should be [-3, 3]");
  515. return std::nullopt;
  516. }
  517. } else {
  518. llvm_unreachable("Illegal complex type transformers!");
  519. }
  520. }
  521. PD.VTM = static_cast<uint8_t>(VTM);
  522. // Compute the remain type transformers
  523. TypeModifier TM = TypeModifier::NoModifier;
  524. for (char I : PrototypeDescriptorStr) {
  525. switch (I) {
  526. case 'P':
  527. if ((TM & TypeModifier::Const) == TypeModifier::Const)
  528. llvm_unreachable("'P' transformer cannot be used after 'C'");
  529. if ((TM & TypeModifier::Pointer) == TypeModifier::Pointer)
  530. llvm_unreachable("'P' transformer cannot be used twice");
  531. TM |= TypeModifier::Pointer;
  532. break;
  533. case 'C':
  534. TM |= TypeModifier::Const;
  535. break;
  536. case 'K':
  537. TM |= TypeModifier::Immediate;
  538. break;
  539. case 'U':
  540. TM |= TypeModifier::UnsignedInteger;
  541. break;
  542. case 'I':
  543. TM |= TypeModifier::SignedInteger;
  544. break;
  545. case 'F':
  546. TM |= TypeModifier::Float;
  547. break;
  548. case 'S':
  549. TM |= TypeModifier::LMUL1;
  550. break;
  551. default:
  552. llvm_unreachable("Illegal non-primitive type transformer!");
  553. }
  554. }
  555. PD.TM = static_cast<uint8_t>(TM);
  556. return PD;
  557. }
  558. void RVVType::applyModifier(const PrototypeDescriptor &Transformer) {
  559. // Handle primitive type transformer
  560. switch (static_cast<BaseTypeModifier>(Transformer.PT)) {
  561. case BaseTypeModifier::Scalar:
  562. Scale = 0;
  563. break;
  564. case BaseTypeModifier::Vector:
  565. Scale = LMUL.getScale(ElementBitwidth);
  566. break;
  567. case BaseTypeModifier::Void:
  568. ScalarType = ScalarTypeKind::Void;
  569. break;
  570. case BaseTypeModifier::SizeT:
  571. ScalarType = ScalarTypeKind::Size_t;
  572. break;
  573. case BaseTypeModifier::Ptrdiff:
  574. ScalarType = ScalarTypeKind::Ptrdiff_t;
  575. break;
  576. case BaseTypeModifier::UnsignedLong:
  577. ScalarType = ScalarTypeKind::UnsignedLong;
  578. break;
  579. case BaseTypeModifier::SignedLong:
  580. ScalarType = ScalarTypeKind::SignedLong;
  581. break;
  582. case BaseTypeModifier::Invalid:
  583. ScalarType = ScalarTypeKind::Invalid;
  584. return;
  585. }
  586. switch (static_cast<VectorTypeModifier>(Transformer.VTM)) {
  587. case VectorTypeModifier::Widening2XVector:
  588. ElementBitwidth *= 2;
  589. LMUL.MulLog2LMUL(1);
  590. Scale = LMUL.getScale(ElementBitwidth);
  591. break;
  592. case VectorTypeModifier::Widening4XVector:
  593. ElementBitwidth *= 4;
  594. LMUL.MulLog2LMUL(2);
  595. Scale = LMUL.getScale(ElementBitwidth);
  596. break;
  597. case VectorTypeModifier::Widening8XVector:
  598. ElementBitwidth *= 8;
  599. LMUL.MulLog2LMUL(3);
  600. Scale = LMUL.getScale(ElementBitwidth);
  601. break;
  602. case VectorTypeModifier::MaskVector:
  603. ScalarType = ScalarTypeKind::Boolean;
  604. Scale = LMUL.getScale(ElementBitwidth);
  605. ElementBitwidth = 1;
  606. break;
  607. case VectorTypeModifier::Log2EEW3:
  608. applyLog2EEW(3);
  609. break;
  610. case VectorTypeModifier::Log2EEW4:
  611. applyLog2EEW(4);
  612. break;
  613. case VectorTypeModifier::Log2EEW5:
  614. applyLog2EEW(5);
  615. break;
  616. case VectorTypeModifier::Log2EEW6:
  617. applyLog2EEW(6);
  618. break;
  619. case VectorTypeModifier::FixedSEW8:
  620. applyFixedSEW(8);
  621. break;
  622. case VectorTypeModifier::FixedSEW16:
  623. applyFixedSEW(16);
  624. break;
  625. case VectorTypeModifier::FixedSEW32:
  626. applyFixedSEW(32);
  627. break;
  628. case VectorTypeModifier::FixedSEW64:
  629. applyFixedSEW(64);
  630. break;
  631. case VectorTypeModifier::LFixedLog2LMULN3:
  632. applyFixedLog2LMUL(-3, FixedLMULType::LargerThan);
  633. break;
  634. case VectorTypeModifier::LFixedLog2LMULN2:
  635. applyFixedLog2LMUL(-2, FixedLMULType::LargerThan);
  636. break;
  637. case VectorTypeModifier::LFixedLog2LMULN1:
  638. applyFixedLog2LMUL(-1, FixedLMULType::LargerThan);
  639. break;
  640. case VectorTypeModifier::LFixedLog2LMUL0:
  641. applyFixedLog2LMUL(0, FixedLMULType::LargerThan);
  642. break;
  643. case VectorTypeModifier::LFixedLog2LMUL1:
  644. applyFixedLog2LMUL(1, FixedLMULType::LargerThan);
  645. break;
  646. case VectorTypeModifier::LFixedLog2LMUL2:
  647. applyFixedLog2LMUL(2, FixedLMULType::LargerThan);
  648. break;
  649. case VectorTypeModifier::LFixedLog2LMUL3:
  650. applyFixedLog2LMUL(3, FixedLMULType::LargerThan);
  651. break;
  652. case VectorTypeModifier::SFixedLog2LMULN3:
  653. applyFixedLog2LMUL(-3, FixedLMULType::SmallerThan);
  654. break;
  655. case VectorTypeModifier::SFixedLog2LMULN2:
  656. applyFixedLog2LMUL(-2, FixedLMULType::SmallerThan);
  657. break;
  658. case VectorTypeModifier::SFixedLog2LMULN1:
  659. applyFixedLog2LMUL(-1, FixedLMULType::SmallerThan);
  660. break;
  661. case VectorTypeModifier::SFixedLog2LMUL0:
  662. applyFixedLog2LMUL(0, FixedLMULType::SmallerThan);
  663. break;
  664. case VectorTypeModifier::SFixedLog2LMUL1:
  665. applyFixedLog2LMUL(1, FixedLMULType::SmallerThan);
  666. break;
  667. case VectorTypeModifier::SFixedLog2LMUL2:
  668. applyFixedLog2LMUL(2, FixedLMULType::SmallerThan);
  669. break;
  670. case VectorTypeModifier::SFixedLog2LMUL3:
  671. applyFixedLog2LMUL(3, FixedLMULType::SmallerThan);
  672. break;
  673. case VectorTypeModifier::NoModifier:
  674. break;
  675. }
  676. for (unsigned TypeModifierMaskShift = 0;
  677. TypeModifierMaskShift <= static_cast<unsigned>(TypeModifier::MaxOffset);
  678. ++TypeModifierMaskShift) {
  679. unsigned TypeModifierMask = 1 << TypeModifierMaskShift;
  680. if ((static_cast<unsigned>(Transformer.TM) & TypeModifierMask) !=
  681. TypeModifierMask)
  682. continue;
  683. switch (static_cast<TypeModifier>(TypeModifierMask)) {
  684. case TypeModifier::Pointer:
  685. IsPointer = true;
  686. break;
  687. case TypeModifier::Const:
  688. IsConstant = true;
  689. break;
  690. case TypeModifier::Immediate:
  691. IsImmediate = true;
  692. IsConstant = true;
  693. break;
  694. case TypeModifier::UnsignedInteger:
  695. ScalarType = ScalarTypeKind::UnsignedInteger;
  696. break;
  697. case TypeModifier::SignedInteger:
  698. ScalarType = ScalarTypeKind::SignedInteger;
  699. break;
  700. case TypeModifier::Float:
  701. ScalarType = ScalarTypeKind::Float;
  702. break;
  703. case TypeModifier::LMUL1:
  704. LMUL = LMULType(0);
  705. // Update ElementBitwidth need to update Scale too.
  706. Scale = LMUL.getScale(ElementBitwidth);
  707. break;
  708. default:
  709. llvm_unreachable("Unknown type modifier mask!");
  710. }
  711. }
  712. }
  713. void RVVType::applyLog2EEW(unsigned Log2EEW) {
  714. // update new elmul = (eew/sew) * lmul
  715. LMUL.MulLog2LMUL(Log2EEW - Log2_32(ElementBitwidth));
  716. // update new eew
  717. ElementBitwidth = 1 << Log2EEW;
  718. ScalarType = ScalarTypeKind::SignedInteger;
  719. Scale = LMUL.getScale(ElementBitwidth);
  720. }
  721. void RVVType::applyFixedSEW(unsigned NewSEW) {
  722. // Set invalid type if src and dst SEW are same.
  723. if (ElementBitwidth == NewSEW) {
  724. ScalarType = ScalarTypeKind::Invalid;
  725. return;
  726. }
  727. // Update new SEW
  728. ElementBitwidth = NewSEW;
  729. Scale = LMUL.getScale(ElementBitwidth);
  730. }
  731. void RVVType::applyFixedLog2LMUL(int Log2LMUL, enum FixedLMULType Type) {
  732. switch (Type) {
  733. case FixedLMULType::LargerThan:
  734. if (Log2LMUL < LMUL.Log2LMUL) {
  735. ScalarType = ScalarTypeKind::Invalid;
  736. return;
  737. }
  738. break;
  739. case FixedLMULType::SmallerThan:
  740. if (Log2LMUL > LMUL.Log2LMUL) {
  741. ScalarType = ScalarTypeKind::Invalid;
  742. return;
  743. }
  744. break;
  745. }
  746. // Update new LMUL
  747. LMUL = LMULType(Log2LMUL);
  748. Scale = LMUL.getScale(ElementBitwidth);
  749. }
  750. std::optional<RVVTypes>
  751. RVVTypeCache::computeTypes(BasicType BT, int Log2LMUL, unsigned NF,
  752. ArrayRef<PrototypeDescriptor> Prototype) {
  753. // LMUL x NF must be less than or equal to 8.
  754. if ((Log2LMUL >= 1) && (1 << Log2LMUL) * NF > 8)
  755. return std::nullopt;
  756. RVVTypes Types;
  757. for (const PrototypeDescriptor &Proto : Prototype) {
  758. auto T = computeType(BT, Log2LMUL, Proto);
  759. if (!T)
  760. return std::nullopt;
  761. // Record legal type index
  762. Types.push_back(*T);
  763. }
  764. return Types;
  765. }
  766. // Compute the hash value of RVVType, used for cache the result of computeType.
  767. static uint64_t computeRVVTypeHashValue(BasicType BT, int Log2LMUL,
  768. PrototypeDescriptor Proto) {
  769. // Layout of hash value:
  770. // 0 8 16 24 32 40
  771. // | Log2LMUL + 3 | BT | Proto.PT | Proto.TM | Proto.VTM |
  772. assert(Log2LMUL >= -3 && Log2LMUL <= 3);
  773. return (Log2LMUL + 3) | (static_cast<uint64_t>(BT) & 0xff) << 8 |
  774. ((uint64_t)(Proto.PT & 0xff) << 16) |
  775. ((uint64_t)(Proto.TM & 0xff) << 24) |
  776. ((uint64_t)(Proto.VTM & 0xff) << 32);
  777. }
  778. std::optional<RVVTypePtr> RVVTypeCache::computeType(BasicType BT, int Log2LMUL,
  779. PrototypeDescriptor Proto) {
  780. uint64_t Idx = computeRVVTypeHashValue(BT, Log2LMUL, Proto);
  781. // Search first
  782. auto It = LegalTypes.find(Idx);
  783. if (It != LegalTypes.end())
  784. return &(It->second);
  785. if (IllegalTypes.count(Idx))
  786. return std::nullopt;
  787. // Compute type and record the result.
  788. RVVType T(BT, Log2LMUL, Proto);
  789. if (T.isValid()) {
  790. // Record legal type index and value.
  791. std::pair<std::unordered_map<uint64_t, RVVType>::iterator, bool>
  792. InsertResult = LegalTypes.insert({Idx, T});
  793. return &(InsertResult.first->second);
  794. }
  795. // Record illegal type index.
  796. IllegalTypes.insert(Idx);
  797. return std::nullopt;
  798. }
  799. //===----------------------------------------------------------------------===//
  800. // RVVIntrinsic implementation
  801. //===----------------------------------------------------------------------===//
  802. RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix,
  803. StringRef NewOverloadedName,
  804. StringRef OverloadedSuffix, StringRef IRName,
  805. bool IsMasked, bool HasMaskedOffOperand, bool HasVL,
  806. PolicyScheme Scheme, bool SupportOverloading,
  807. bool HasBuiltinAlias, StringRef ManualCodegen,
  808. const RVVTypes &OutInTypes,
  809. const std::vector<int64_t> &NewIntrinsicTypes,
  810. const std::vector<StringRef> &RequiredFeatures,
  811. unsigned NF, Policy NewPolicyAttrs)
  812. : IRName(IRName), IsMasked(IsMasked),
  813. HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme),
  814. SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias),
  815. ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) {
  816. // Init BuiltinName, Name and OverloadedName
  817. BuiltinName = NewName.str();
  818. Name = BuiltinName;
  819. if (NewOverloadedName.empty())
  820. OverloadedName = NewName.split("_").first.str();
  821. else
  822. OverloadedName = NewOverloadedName.str();
  823. if (!Suffix.empty())
  824. Name += "_" + Suffix.str();
  825. if (!OverloadedSuffix.empty())
  826. OverloadedName += "_" + OverloadedSuffix.str();
  827. updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName,
  828. PolicyAttrs);
  829. // Init OutputType and InputTypes
  830. OutputType = OutInTypes[0];
  831. InputTypes.assign(OutInTypes.begin() + 1, OutInTypes.end());
  832. // IntrinsicTypes is unmasked TA version index. Need to update it
  833. // if there is merge operand (It is always in first operand).
  834. IntrinsicTypes = NewIntrinsicTypes;
  835. if ((IsMasked && hasMaskedOffOperand()) ||
  836. (!IsMasked && hasPassthruOperand())) {
  837. for (auto &I : IntrinsicTypes) {
  838. if (I >= 0)
  839. I += NF;
  840. }
  841. }
  842. }
  843. std::string RVVIntrinsic::getBuiltinTypeStr() const {
  844. std::string S;
  845. S += OutputType->getBuiltinStr();
  846. for (const auto &T : InputTypes) {
  847. S += T->getBuiltinStr();
  848. }
  849. return S;
  850. }
  851. std::string RVVIntrinsic::getSuffixStr(
  852. RVVTypeCache &TypeCache, BasicType Type, int Log2LMUL,
  853. llvm::ArrayRef<PrototypeDescriptor> PrototypeDescriptors) {
  854. SmallVector<std::string> SuffixStrs;
  855. for (auto PD : PrototypeDescriptors) {
  856. auto T = TypeCache.computeType(Type, Log2LMUL, PD);
  857. SuffixStrs.push_back((*T)->getShortStr());
  858. }
  859. return join(SuffixStrs, "_");
  860. }
// Computes a builtin's operand descriptor list from the intrinsic's base
// prototype. Element 0 of the prototype describes the result, and elements 1..
// describe the inputs; the insert positions below rely on this layout.
//
// Starting from a copy of Prototype, this adds, as the variant requires:
//   - for masked variants: masked-off operand(s) (not under TAMA policy) and
//     a Mask operand;
//   - for unmasked TU variants with PolicyScheme::HasPassthruOperand:
//     passthru operand(s);
//   - a trailing VL operand when HasVL is set.
// When NF > 1 (segment operations), one masked-off/passthru operand is added
// per field, placed after the NF field operands.
//
// Parameters:
//   Prototype           - base descriptors; copied, never modified.
//   IsMasked            - build the masked variant.
//   HasMaskedOffOperand - the masked variant takes masked-off operand(s).
//   HasVL               - append a VL operand.
//   NF                  - number of fields.
//   DefaultScheme       - only compared against HasPassthruOperand here.
//   PolicyAttrs         - tail/mask policy of the variant being built.
// Returns the expanded descriptor list.
llvm::SmallVector<PrototypeDescriptor> RVVIntrinsic::computeBuiltinTypes(
    llvm::ArrayRef<PrototypeDescriptor> Prototype, bool IsMasked,
    bool HasMaskedOffOperand, bool HasVL, unsigned NF,
    PolicyScheme DefaultScheme, Policy PolicyAttrs) {
  SmallVector<PrototypeDescriptor> NewPrototype(Prototype.begin(),
                                                Prototype.end());
  bool HasPassthruOp = DefaultScheme == PolicyScheme::HasPassthruOperand;
  if (IsMasked) {
    // If HasMaskedOffOperand, insert the result type as the first input
    // operand if needed. Under TAMA (tail and mask agnostic) policy no
    // masked-off operand is added.
    if (HasMaskedOffOperand && !PolicyAttrs.isTAMAPolicy()) {
      if (NF == 1) {
        // NOTE(review): the inserted value refers to NewPrototype[0]. This
        // relies on SmallVector::insert accepting a reference into the same
        // vector.
        NewPrototype.insert(NewPrototype.begin() + 1, NewPrototype[0]);
      } else if (NF > 1) {
        // Convert
        // (void, op0 address, op1 address, ...)
        // to
        // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
        // Each masked-off operand uses the field descriptor with the Pointer
        // modifier cleared.
        PrototypeDescriptor MaskoffType = NewPrototype[1];
        MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
        NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
      }
    }
    if (HasMaskedOffOperand && NF > 1) {
      // Convert
      // (void, op0 address, op1 address, ..., maskedoff0, maskedoff1, ...)
      // to
      // (void, op0 address, op1 address, ..., mask, maskedoff0, maskedoff1,
      // ...)
      // Under TAMA no masked-off operands were inserted above, so the mask
      // simply follows the NF field operands.
      NewPrototype.insert(NewPrototype.begin() + NF + 1,
                          PrototypeDescriptor::Mask);
    } else {
      // If IsMasked, insert PrototypeDescriptor::Mask as the first input
      // operand.
      NewPrototype.insert(NewPrototype.begin() + 1, PrototypeDescriptor::Mask);
    }
  } else {
    if (NF == 1) {
      // Unmasked TU variant with a passthru operand: duplicate the result
      // descriptor. Inserting a copy of element 0 at begin() yields
      // (result, result, inputs...), so the first input is the passthru.
      if (PolicyAttrs.isTUPolicy() && HasPassthruOp)
        NewPrototype.insert(NewPrototype.begin(), NewPrototype[0]);
    } else if (PolicyAttrs.isTUPolicy() && HasPassthruOp) {
      // NF > 1 cases for segment load operations.
      // Convert
      // (void, op0 address, op1 address, ...)
      // to
      // (void, op0 address, op1 address, maskedoff0, maskedoff1, ...)
      // NewPrototype has not been modified on this path, so Prototype[1] is
      // the same descriptor as NewPrototype[1].
      PrototypeDescriptor MaskoffType = Prototype[1];
      MaskoffType.TM &= ~static_cast<uint8_t>(TypeModifier::Pointer);
      NewPrototype.insert(NewPrototype.begin() + NF + 1, NF, MaskoffType);
    }
  }
  // If HasVL, append PrototypeDescriptor::VL as the last operand.
  if (HasVL)
    NewPrototype.push_back(PrototypeDescriptor::VL);
  return NewPrototype;
}
  916. llvm::SmallVector<Policy> RVVIntrinsic::getSupportedUnMaskedPolicies() {
  917. return {Policy(Policy::PolicyType::Undisturbed)}; // TU
  918. }
  919. llvm::SmallVector<Policy>
  920. RVVIntrinsic::getSupportedMaskedPolicies(bool HasTailPolicy,
  921. bool HasMaskPolicy) {
  922. if (HasTailPolicy && HasMaskPolicy)
  923. return {Policy(Policy::PolicyType::Undisturbed,
  924. Policy::PolicyType::Agnostic), // TUM
  925. Policy(Policy::PolicyType::Undisturbed,
  926. Policy::PolicyType::Undisturbed), // TUMU
  927. Policy(Policy::PolicyType::Agnostic,
  928. Policy::PolicyType::Undisturbed)}; // MU
  929. if (HasTailPolicy && !HasMaskPolicy)
  930. return {Policy(Policy::PolicyType::Undisturbed,
  931. Policy::PolicyType::Agnostic)}; // TU
  932. if (!HasTailPolicy && HasMaskPolicy)
  933. return {Policy(Policy::PolicyType::Agnostic,
  934. Policy::PolicyType::Undisturbed)}; // MU
  935. llvm_unreachable("An RVV instruction should not be without both tail policy "
  936. "and mask policy");
  937. }
  938. void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy,
  939. std::string &Name,
  940. std::string &BuiltinName,
  941. std::string &OverloadedName,
  942. Policy &PolicyAttrs) {
  943. auto appendPolicySuffix = [&](const std::string &suffix) {
  944. Name += suffix;
  945. BuiltinName += suffix;
  946. OverloadedName += suffix;
  947. };
  948. // This follows the naming guideline under riscv-c-api-doc to add the
  949. // `__riscv_` suffix for all RVV intrinsics.
  950. Name = "__riscv_" + Name;
  951. OverloadedName = "__riscv_" + OverloadedName;
  952. if (IsMasked) {
  953. if (PolicyAttrs.isTUMUPolicy())
  954. appendPolicySuffix("_tumu");
  955. else if (PolicyAttrs.isTUMAPolicy())
  956. appendPolicySuffix("_tum");
  957. else if (PolicyAttrs.isTAMUPolicy())
  958. appendPolicySuffix("_mu");
  959. else if (PolicyAttrs.isTAMAPolicy()) {
  960. Name += "_m";
  961. if (HasPolicy)
  962. BuiltinName += "_tama";
  963. else
  964. BuiltinName += "_m";
  965. } else
  966. llvm_unreachable("Unhandled policy condition");
  967. } else {
  968. if (PolicyAttrs.isTUPolicy())
  969. appendPolicySuffix("_tu");
  970. else if (PolicyAttrs.isTAPolicy()) {
  971. if (HasPolicy)
  972. BuiltinName += "_ta";
  973. } else
  974. llvm_unreachable("Unhandled policy condition");
  975. }
  976. }
  977. SmallVector<PrototypeDescriptor> parsePrototypes(StringRef Prototypes) {
  978. SmallVector<PrototypeDescriptor> PrototypeDescriptors;
  979. const StringRef Primaries("evwqom0ztul");
  980. while (!Prototypes.empty()) {
  981. size_t Idx = 0;
  982. // Skip over complex prototype because it could contain primitive type
  983. // character.
  984. if (Prototypes[0] == '(')
  985. Idx = Prototypes.find_first_of(')');
  986. Idx = Prototypes.find_first_of(Primaries, Idx);
  987. assert(Idx != StringRef::npos);
  988. auto PD = PrototypeDescriptor::parsePrototypeDescriptor(
  989. Prototypes.slice(0, Idx + 1));
  990. if (!PD)
  991. llvm_unreachable("Error during parsing prototype.");
  992. PrototypeDescriptors.push_back(*PD);
  993. Prototypes = Prototypes.drop_front(Idx + 1);
  994. }
  995. return PrototypeDescriptors;
  996. }
  997. raw_ostream &operator<<(raw_ostream &OS, const RVVIntrinsicRecord &Record) {
  998. OS << "{";
  999. OS << "\"" << Record.Name << "\",";
  1000. if (Record.OverloadedName == nullptr ||
  1001. StringRef(Record.OverloadedName).empty())
  1002. OS << "nullptr,";
  1003. else
  1004. OS << "\"" << Record.OverloadedName << "\",";
  1005. OS << Record.PrototypeIndex << ",";
  1006. OS << Record.SuffixIndex << ",";
  1007. OS << Record.OverloadedSuffixIndex << ",";
  1008. OS << (int)Record.PrototypeLength << ",";
  1009. OS << (int)Record.SuffixLength << ",";
  1010. OS << (int)Record.OverloadedSuffixSize << ",";
  1011. OS << (int)Record.RequiredExtensions << ",";
  1012. OS << (int)Record.TypeRangeMask << ",";
  1013. OS << (int)Record.Log2LMULMask << ",";
  1014. OS << (int)Record.NF << ",";
  1015. OS << (int)Record.HasMasked << ",";
  1016. OS << (int)Record.HasVL << ",";
  1017. OS << (int)Record.HasMaskedOffOperand << ",";
  1018. OS << (int)Record.HasTailPolicy << ",";
  1019. OS << (int)Record.HasMaskPolicy << ",";
  1020. OS << (int)Record.UnMaskedPolicyScheme << ",";
  1021. OS << (int)Record.MaskedPolicyScheme << ",";
  1022. OS << "},\n";
  1023. return OS;
  1024. }
  1025. } // end namespace RISCV
  1026. } // end namespace clang