codegen_ut.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. #include <yql/essentials/minikql/codegen/codegen.h>
  2. #include <codegen_ut_llvm_deps.h> // Y_IGNORE
  3. #include <library/cpp/testing/unittest/registar.h>
  4. #include <library/cpp/resource/resource.h>
  5. using namespace NYql::NCodegen;
  6. using namespace llvm;
  7. extern "C" int mul(int x, int y) {
  8. return x * y;
  9. }
  10. extern "C" int sum(int x, int y) {
  11. return x + y;
  12. }
  13. namespace {
  14. struct T128 {
  15. ui64 Lo;
  16. ui64 Hi;
  17. T128(ui64 x)
  18. : Lo(x)
  19. , Hi(0)
  20. {}
  21. bool operator==(const T128& other) const {
  22. return Lo == other.Lo && Hi == other.Hi;
  23. }
  24. };
  25. Function *CreateFibFunction(Module &M, LLVMContext &Context) {
  26. const auto funcType = FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context)}, false);
  27. // Create the fib function and insert it into module M. This function is said
  28. // to return an int and take an int parameter.
  29. Function *FibF = cast<Function>(M.getOrInsertFunction("fib", funcType).getCallee());
  30. // Add a basic block to the function.
  31. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", FibF);
  32. // Get pointers to the constants.
  33. Value *One = ConstantInt::get(Type::getInt32Ty(Context), 1);
  34. Value *Two = ConstantInt::get(Type::getInt32Ty(Context), 2);
  35. // Get pointer to the integer argument of the add1 function...
  36. auto ArgX = FibF->arg_begin(); // Get the arg.
  37. ArgX->setName("AnArg"); // Give it a nice symbolic name for fun.
  38. // Create the true_block.
  39. BasicBlock *RetBB = BasicBlock::Create(Context, "return", FibF);
  40. // Create an exit block.
  41. BasicBlock* RecurseBB = BasicBlock::Create(Context, "recurse", FibF);
  42. // Create the "if (arg <= 2) goto exitbb"
  43. Value *CondInst = CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_SLE, &*ArgX, Two, "cond", BB);
  44. BranchInst::Create(RetBB, RecurseBB, CondInst, BB);
  45. // Create: ret int 1
  46. ReturnInst::Create(Context, One, RetBB);
  47. // create fib(x-1)
  48. Value *Sub = BinaryOperator::CreateSub(&*ArgX, One, "arg", RecurseBB);
  49. CallInst *CallFibX1 = CallInst::Create(FibF, Sub, "fibx1", RecurseBB);
  50. CallFibX1->setTailCall();
  51. // create fib(x-2)
  52. Sub = BinaryOperator::CreateSub(&*ArgX, Two, "arg", RecurseBB);
  53. CallInst *CallFibX2 = CallInst::Create(FibF, Sub, "fibx2", RecurseBB);
  54. CallFibX2->setTailCall();
  55. // fib(x-1)+fib(x-2)
  56. Value *Sum = BinaryOperator::CreateAdd(CallFibX1, CallFibX2,
  57. "addresult", RecurseBB);
  58. // Create the return instruction and add it to the basic block
  59. ReturnInst::Create(Context, Sum, RecurseBB);
  60. return FibF;
  61. }
  62. Function *CreateBadFibFunction(Module &M, LLVMContext &Context) {
  63. const auto funcType = FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context)}, false);
  64. // Create the fib function and insert it into module M. This function is said
  65. // to return an int and take an int parameter.
  66. Function *FibF = cast<Function>(M.getOrInsertFunction("bad_fib", funcType).getCallee());
  67. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", FibF);
  68. // Get pointers to the constants.
  69. Value *One = ConstantInt::get(Type::getInt64Ty(Context), 1);
  70. ReturnInst::Create(Context, One, BB);
  71. return FibF;
  72. }
  73. Function *CreateMulFunction(Module &M, LLVMContext &Context) {
  74. const auto funcType = FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context), Type::getInt32Ty(Context)}, false);
  75. Function *MulF = cast<Function>(M.getOrInsertFunction("mul", funcType).getCallee());
  76. // Add a basic block to the function.
  77. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", MulF);
  78. auto args = MulF->arg_begin();
  79. auto ArgX = args; // Get the arg 1.
  80. ArgX->setName("x");
  81. auto ArgY = ++args; // Get the arg 2.
  82. ArgY->setName("y");
  83. // arg1 * arg2
  84. Value *Mul = BinaryOperator::CreateMul(&*ArgX, &*ArgY, "res", BB);
  85. // Create the return instruction and add it to the basic block
  86. ReturnInst::Create(Context, Mul, BB);
  87. return MulF;
  88. }
  89. Function *CreateUseNativeFunction(Module &M, LLVMContext &Context) {
  90. const auto funcType = FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context), Type::getInt32Ty(Context)}, false);
  91. Function *func = cast<Function>(M.getOrInsertFunction("add", funcType).getCallee());
  92. // Add a basic block to the function.
  93. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", func);
  94. auto args = func->arg_begin();
  95. auto ArgX = args; // Get the arg 1.
  96. ArgX->setName("x");
  97. auto ArgY = ++args; // Get the arg 2.
  98. ArgY->setName("y");
  99. Function* func_mul = M.getFunction("mul");
  100. if (!func_mul) {
  101. func_mul = Function::Create(
  102. /*Type=*/FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context), Type::getInt32Ty(Context)}, false),
  103. /*Linkage=*/GlobalValue::ExternalLinkage,
  104. /*Name=*/"mul", &M); // (external, no body)
  105. func_mul->setCallingConv(CallingConv::C);
  106. }
  107. // arg1 * arg2
  108. Value *Mul = CallInst::Create(func_mul, {&*ArgX, &*ArgY}, "res", BB);
  109. // Create the return instruction and add it to the basic block
  110. ReturnInst::Create(Context, Mul, BB);
  111. return func;
  112. }
  113. Function *CreateUseExternalFromGeneratedFunction(Module& main, LLVMContext &Context) {
  114. const auto funcType = FunctionType::get(Type::getInt32Ty(Context), {Type::getInt32Ty(Context), Type::getInt32Ty(Context), Type::getInt32Ty(Context)}, false);
  115. Function *func = cast<Function>(main.getOrInsertFunction("sum_sqr_3", funcType).getCallee());
  116. // Add a basic block to the function.
  117. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", func);
  118. auto args = func->arg_begin();
  119. auto ArgX = args; // Get the arg 1.
  120. ArgX->setName("x");
  121. auto ArgY = ++args; // Get the arg 2.
  122. ArgY->setName("y");
  123. auto ArgZ = ++args; // Get the arg 3.
  124. ArgZ->setName("z");
  125. Function* sum_sqr = main.getFunction("sum_sqr");
  126. Value *tmp = CallInst::Create(sum_sqr, {&*ArgX, &*ArgY}, "tmp", BB);
  127. Value *res = CallInst::Create(sum_sqr, {&*ArgZ, tmp}, "res", BB);
  128. // Create the return instruction and add it to the basic block
  129. ReturnInst::Create(Context, res, BB);
  130. return func;
  131. }
  132. Function *CreateUseExternalFromGeneratedFunction128(const ICodegen::TPtr& codegen, bool ir) {
  133. Module& main = codegen->GetModule();
  134. LLVMContext &Context = codegen->GetContext();
  135. auto typeInt128 = Type::getInt128Ty(Context);
  136. auto pointerInt128 = PointerType::getUnqual(typeInt128);
  137. const auto funcType = codegen->GetEffectiveTarget() != NYql::NCodegen::ETarget::Windows ?
  138. FunctionType::get(typeInt128, {typeInt128, typeInt128, typeInt128}, false):
  139. FunctionType::get(Type::getVoidTy(Context), {pointerInt128, pointerInt128, pointerInt128, pointerInt128}, false);
  140. Function *func = cast<Function>(main.getOrInsertFunction("sum_sqr_3", funcType).getCallee());
  141. auto args = func->arg_begin();
  142. // Add a basic block to the function.
  143. BasicBlock *BB = BasicBlock::Create(Context, "EntryBlock", func);
  144. llvm::Argument* retArg = nullptr;
  145. if (codegen->GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows) {
  146. retArg = &*args++;
  147. retArg->addAttr(Attribute::StructRet);
  148. retArg->addAttr(Attribute::NoAlias);
  149. }
  150. auto ArgX = args++; // Get the arg 1.
  151. ArgX->setName("x");
  152. auto ArgY = args++; // Get the arg 2.
  153. ArgY->setName("y");
  154. auto ArgZ = args++; // Get the arg 3.
  155. ArgZ->setName("z");
  156. const auto type = FunctionType::get(Type::getVoidTy(Context), { pointerInt128, pointerInt128, pointerInt128 }, false);
  157. const auto sum_sqr = main.getOrInsertFunction(ir ? "sum_sqr_128_ir" : "sum_sqr_128", type);
  158. if (codegen->GetEffectiveTarget() == NYql::NCodegen::ETarget::Windows) {
  159. Value* tmp1 = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "tmp1", BB);
  160. Value* tmp2 = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "tmp2", BB);
  161. CallInst::Create(sum_sqr, { &*tmp1, &*ArgX, &*ArgY }, "", BB);
  162. CallInst::Create(sum_sqr, { &*tmp2, &*ArgZ, &*tmp1 }, "", BB);
  163. auto res = new LoadInst(typeInt128, tmp2, "load_res", BB);
  164. new StoreInst(res, retArg, BB);
  165. // Create the return instruction and add it to the basic block
  166. ReturnInst::Create(Context, BB);
  167. } else {
  168. Value* tmp1 = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "tmp1", BB);
  169. Value* tmp2 = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "tmp2", BB);
  170. Value* argXPtr = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "argXptr", BB);
  171. Value* argYPtr = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "argYptr", BB);
  172. Value* argZPtr = new AllocaInst(typeInt128, 0U, nullptr, llvm::Align(16), "argZptr", BB);
  173. new StoreInst(&*ArgX, argXPtr, BB);
  174. new StoreInst(&*ArgY, argYPtr, BB);
  175. new StoreInst(&*ArgZ, argZPtr, BB);
  176. CallInst::Create(sum_sqr, { &*tmp1, &*argXPtr, &*argYPtr }, "", BB);
  177. CallInst::Create(sum_sqr, { &*tmp2, &*argZPtr, &*tmp1 }, "", BB);
  178. auto res = new LoadInst(typeInt128, tmp2, "load_res", BB);
  179. // Create the return instruction and add it to the basic block
  180. ReturnInst::Create(Context, res, BB);
  181. }
  182. return func;
  183. }
  184. }
  185. #if !defined(_ubsan_enabled_) && !defined(HAVE_VALGRIND)
  186. Y_UNIT_TEST_SUITE(TCodegenTests) {
  187. Y_UNIT_TEST(FibNative) {
  188. auto codegen = ICodegen::Make(ETarget::Native);
  189. auto func = CreateFibFunction(codegen->GetModule(), codegen->GetContext());
  190. codegen->Verify();
  191. codegen->Compile();
  192. typedef int(*TFunc)(int);
  193. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  194. UNIT_ASSERT_VALUES_EQUAL(funcPtr(24), 46368);
  195. }
  196. Y_UNIT_TEST(FibCurrentOS) {
  197. auto codegen = ICodegen::Make(ETarget::CurrentOS);
  198. auto func = CreateFibFunction(codegen->GetModule(), codegen->GetContext());
  199. codegen->Verify();
  200. codegen->Compile();
  201. typedef int(*TFunc)(int);
  202. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  203. UNIT_ASSERT_VALUES_EQUAL(funcPtr(24), 46368);
  204. }
  205. Y_UNIT_TEST(BadFib) {
  206. auto codegen = ICodegen::Make(ETarget::Native);
  207. auto func = CreateBadFibFunction(codegen->GetModule(), codegen->GetContext());
  208. Y_UNUSED(func);
  209. UNIT_ASSERT_EXCEPTION(codegen->Verify(), yexception);
  210. }
  211. Y_UNIT_TEST(FibFromBitCode) {
  212. auto codegen = ICodegen::Make(ETarget::Native);
  213. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  214. codegen->LoadBitCode(bitcode, "Funcs");
  215. auto func = codegen->GetModule().getFunction("fib");
  216. codegen->Verify();
  217. codegen->ExportSymbol(func);
  218. codegen->Compile();
  219. typedef int(*TFunc)(int);
  220. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  221. UNIT_ASSERT_VALUES_EQUAL(funcPtr(24), 46368);
  222. }
  223. Y_UNIT_TEST(LinkWithNativeFunction) {
  224. auto codegen = ICodegen::Make(ETarget::Native);
  225. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  226. codegen->LoadBitCode(bitcode, "Funcs");
  227. auto func = codegen->GetModule().getFunction("sum_sqr");
  228. codegen->AddGlobalMapping("mul", (void*)&sum);
  229. codegen->ExportSymbol(func);
  230. codegen->Verify();
  231. codegen->Compile();
  232. typedef int(*TFunc)(int, int);
  233. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  234. UNIT_ASSERT_VALUES_EQUAL(funcPtr(3, 4), 14);
  235. }
  236. Y_UNIT_TEST(LinkWithGeneratedFunction) {
  237. auto codegen = ICodegen::Make(ETarget::Native);
  238. auto mulFunc = CreateMulFunction(codegen->GetModule(), codegen->GetContext());
  239. Y_UNUSED(mulFunc);
  240. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  241. codegen->LoadBitCode(bitcode, "Funcs");
  242. auto func = codegen->GetModule().getFunction("sum_sqr");
  243. codegen->ExportSymbol(func);
  244. codegen->Verify();
  245. codegen->Compile();
  246. typedef int(*TFunc)(int, int);
  247. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  248. UNIT_ASSERT_VALUES_EQUAL(funcPtr(3, 4), 25);
  249. }
  250. Y_UNIT_TEST(ReuseExternalCode) {
  251. auto codegen = ICodegen::Make(ETarget::Native);
  252. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  253. codegen->LoadBitCode(bitcode, "Funcs");
  254. auto func = codegen->GetModule().getFunction("sum_sqr2");
  255. codegen->ExportSymbol(func);
  256. codegen->Verify();
  257. codegen->Compile();
  258. typedef int(*TFunc)(int, int);
  259. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  260. UNIT_ASSERT_VALUES_EQUAL(funcPtr(3, 4), 25);
  261. }
  262. Y_UNIT_TEST(UseObjectReference) {
  263. auto codegen = ICodegen::Make(ETarget::Native);
  264. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  265. codegen->LoadBitCode(bitcode, "Funcs");
  266. auto func = codegen->GetModule().getFunction("str_size");
  267. codegen->ExportSymbol(func);
  268. codegen->Verify();
  269. codegen->Compile();
  270. typedef size_t(*TFunc)(const std::string&);
  271. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  272. const std::string hw("Hello World!");
  273. UNIT_ASSERT_VALUES_EQUAL(funcPtr(hw), 12);
  274. }
  275. Y_UNIT_TEST(UseNativeFromGeneratedFunction) {
  276. auto codegen = ICodegen::Make(ETarget::Native);
  277. auto func = CreateUseNativeFunction(codegen->GetModule(), codegen->GetContext());
  278. codegen->AddGlobalMapping("mul", (void*)&mul);
  279. codegen->ExportSymbol(func);
  280. codegen->Verify();
  281. codegen->Compile();
  282. typedef int(*TFunc)(int, int);
  283. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  284. UNIT_ASSERT_VALUES_EQUAL(funcPtr(3, 4), 12);
  285. }
  286. Y_UNIT_TEST(UseExternalFromGeneratedFunction) {
  287. auto codegen = ICodegen::Make(ETarget::Native);
  288. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  289. codegen->LoadBitCode(bitcode, "Funcs");
  290. auto func = CreateUseExternalFromGeneratedFunction(codegen->GetModule(), codegen->GetContext());
  291. codegen->ExportSymbol(func);
  292. codegen->AddGlobalMapping("mul", (void*)&mul);
  293. codegen->Verify();
  294. codegen->Compile();
  295. typedef int(*TFunc)(int, int, int);
  296. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  297. UNIT_ASSERT_VALUES_EQUAL(funcPtr(7, 4, 8), 4289);
  298. }
  299. Y_UNIT_TEST(UseExternalFromGeneratedFunction_128bit_Compiled) {
  300. auto codegen = ICodegen::Make(ETarget::Native);
  301. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  302. codegen->LoadBitCode(bitcode, "Funcs");
  303. auto func = CreateUseExternalFromGeneratedFunction128(codegen, false);
  304. codegen->ExportSymbol(func);
  305. codegen->Verify();
  306. codegen->Compile();
  307. TStringStream str;
  308. codegen->ShowGeneratedFunctions(&str);
  309. #ifdef _win_
  310. typedef T128 (*TFunc)(T128, T128, T128);
  311. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  312. UNIT_ASSERT(funcPtr(T128(7), T128(4), T128(8)) == T128(4289));
  313. #else
  314. typedef unsigned __int128(*TFunc)(__int128, __int128, __int128);
  315. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  316. UNIT_ASSERT(funcPtr(7, 4, 8) == 4289);
  317. #endif
  318. #if !defined(_asan_enabled_) && !defined(_msan_enabled_) && !defined(_tsan_enabled_) && !defined(_hardening_enabled_)
  319. if (str.Str().Contains("call")) {
  320. UNIT_FAIL("Expected inline, disasm:\n" + str.Str());
  321. }
  322. #endif
  323. }
  324. Y_UNIT_TEST(UseExternalFromGeneratedFunction_128bit_Bitcode) {
  325. auto codegen = ICodegen::Make(ETarget::Native);
  326. auto bitcode = NResource::Find("/llvm_bc/Funcs");
  327. codegen->LoadBitCode(bitcode, "Funcs");
  328. auto func = CreateUseExternalFromGeneratedFunction128(codegen, true);
  329. codegen->ExportSymbol(func);
  330. codegen->Verify();
  331. codegen->Compile();
  332. TStringStream str;
  333. codegen->ShowGeneratedFunctions(&str);
  334. #ifdef _win_
  335. typedef T128(*TFunc)(T128, T128, T128);
  336. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  337. UNIT_ASSERT(funcPtr(T128(7), T128(4), T128(8)) == T128(4289));
  338. #else
  339. typedef unsigned __int128(*TFunc)(__int128, __int128, __int128);
  340. auto funcPtr = (TFunc)codegen->GetPointerToFunction(func);
  341. UNIT_ASSERT(funcPtr(7, 4, 8) == 4289);
  342. #endif
  343. #if !defined(_asan_enabled_) && !defined(_msan_enabled_) && !defined(_tsan_enabled_)
  344. if (str.Str().Contains("call")) {
  345. UNIT_FAIL("Expected inline, disasm:\n" + str.Str());
  346. }
  347. #endif
  348. }
  349. }
  350. #endif