mkql_ensure.cpp 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. #include "mkql_ensure.h"
  2. #include <yql/essentials/minikql/computation/mkql_computation_node_codegen.h> // Y_IGNORE
  3. #include <yql/essentials/minikql/mkql_node_cast.h>
  4. #include <yql/essentials/minikql/mkql_program_builder.h>
  5. #include <yql/essentials/public/udf/udf_terminator.h>
  6. #include <yql/essentials/public/udf/udf_type_builder.h>
  7. namespace NKikimr {
  8. namespace NMiniKQL {
  9. namespace {
  10. class TEnsureWrapper : public TMutableCodegeneratorNode<TEnsureWrapper> {
  11. typedef TMutableCodegeneratorNode<TEnsureWrapper> TBaseComputation;
  12. public:
  13. TEnsureWrapper(TComputationMutables& mutables, IComputationNode* value, IComputationNode* predicate,
  14. IComputationNode* message, const NUdf::TSourcePosition& pos)
  15. : TBaseComputation(mutables, value->GetRepresentation())
  16. , Arg(value)
  17. , Predicate(predicate)
  18. , Message(message)
  19. , Pos(pos)
  20. {
  21. }
  22. NUdf::TUnboxedValuePod DoCalculate(TComputationContext& ctx) const {
  23. const auto& predicate = Predicate->GetValue(ctx);
  24. if (predicate && predicate.Get<bool>()) {
  25. return Arg->GetValue(ctx).Release();
  26. }
  27. Throw(this, &ctx);
  28. }
  29. #ifndef MKQL_DISABLE_CODEGEN
  30. Value* DoGenerateGetValue(const TCodegenContext& ctx, BasicBlock*& block) const {
  31. auto& context = ctx.Codegen.GetContext();
  32. const auto predicate = GetNodeValue(Predicate, ctx, block);
  33. const auto pass = CastInst::Create(Instruction::Trunc, predicate, Type::getInt1Ty(context), "bool", block);
  34. const auto kill = BasicBlock::Create(context, "kill", ctx.Func);
  35. const auto good = BasicBlock::Create(context, "good", ctx.Func);
  36. BranchInst::Create(good, kill, pass, block);
  37. block = kill;
  38. const auto doFunc = ConstantInt::get(Type::getInt64Ty(context), GetMethodPtr(&TEnsureWrapper::Throw));
  39. const auto doFuncArg = ConstantInt::get(Type::getInt64Ty(context), (ui64)this);
  40. const auto doFuncType = FunctionType::get(Type::getVoidTy(context), { Type::getInt64Ty(context), ctx.Ctx->getType() }, false);
  41. const auto doFuncPtr = CastInst::Create(Instruction::IntToPtr, doFunc, PointerType::getUnqual(doFuncType), "thrower", block);
  42. CallInst::Create(doFuncType, doFuncPtr, { doFuncArg, ctx.Ctx }, "", block)->setTailCall();
  43. new UnreachableInst(context, block);
  44. block = good;
  45. return GetNodeValue(Arg, ctx, block);;
  46. }
  47. #endif
  48. private:
  49. [[noreturn]] static void Throw(TEnsureWrapper const* thisPtr, TComputationContext* ctxPtr) {
  50. auto message = thisPtr->Message->GetValue(*ctxPtr);
  51. auto messageStr = message.AsStringRef();
  52. TStringBuilder res;
  53. res << thisPtr->Pos << " Condition violated";
  54. if (messageStr.Size() > 0) {
  55. res << ":\n\n" << TStringBuf(messageStr) << "\n\n";
  56. }
  57. UdfTerminate(res.data());
  58. }
  59. void RegisterDependencies() const final {
  60. DependsOn(Arg);
  61. DependsOn(Predicate);
  62. }
  63. IComputationNode* const Arg;
  64. IComputationNode* const Predicate;
  65. IComputationNode* const Message;
  66. const NUdf::TSourcePosition Pos;
  67. };
  68. }
  69. IComputationNode* WrapEnsure(TCallable& callable, const TComputationNodeFactoryContext& ctx) {
  70. MKQL_ENSURE(callable.GetInputsCount() == 6, "Expected 6 args");
  71. bool isOptional;
  72. auto unpackedType = UnpackOptionalData(callable.GetInput(1), isOptional);
  73. MKQL_ENSURE(unpackedType->GetSchemeType() == NUdf::TDataType<bool>::Id, "Expected bool");
  74. auto value = LocateNode(ctx.NodeLocator, callable, 0);
  75. auto predicate = LocateNode(ctx.NodeLocator, callable, 1);
  76. auto message = LocateNode(ctx.NodeLocator, callable, 2);
  77. const TStringBuf file = AS_VALUE(TDataLiteral, callable.GetInput(3))->AsValue().AsStringRef();
  78. const ui32 row = AS_VALUE(TDataLiteral, callable.GetInput(4))->AsValue().Get<ui32>();
  79. const ui32 column = AS_VALUE(TDataLiteral, callable.GetInput(5))->AsValue().Get<ui32>();
  80. return new TEnsureWrapper(ctx.Mutables, value, predicate, message, NUdf::TSourcePosition(row, column, file));
  81. }
  82. }
  83. }