SMTAPI.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- SMTAPI.h -------------------------------------------------*- C++ -*-===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. //
  14. // This file defines a SMT generic Solver API, which will be the base class
  15. // for every SMT solver specific class.
  16. //
  17. //===----------------------------------------------------------------------===//
  18. #ifndef LLVM_SUPPORT_SMTAPI_H
  19. #define LLVM_SUPPORT_SMTAPI_H
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <memory>
#include <optional>
  25. namespace llvm {
  26. /// Generic base class for SMT sorts
  27. class SMTSort {
  28. public:
  29. SMTSort() = default;
  30. virtual ~SMTSort() = default;
  31. /// Returns true if the sort is a bitvector, calls isBitvectorSortImpl().
  32. virtual bool isBitvectorSort() const { return isBitvectorSortImpl(); }
  33. /// Returns true if the sort is a floating-point, calls isFloatSortImpl().
  34. virtual bool isFloatSort() const { return isFloatSortImpl(); }
  35. /// Returns true if the sort is a boolean, calls isBooleanSortImpl().
  36. virtual bool isBooleanSort() const { return isBooleanSortImpl(); }
  37. /// Returns the bitvector size, fails if the sort is not a bitvector
  38. /// Calls getBitvectorSortSizeImpl().
  39. virtual unsigned getBitvectorSortSize() const {
  40. assert(isBitvectorSort() && "Not a bitvector sort!");
  41. unsigned Size = getBitvectorSortSizeImpl();
  42. assert(Size && "Size is zero!");
  43. return Size;
  44. };
  45. /// Returns the floating-point size, fails if the sort is not a floating-point
  46. /// Calls getFloatSortSizeImpl().
  47. virtual unsigned getFloatSortSize() const {
  48. assert(isFloatSort() && "Not a floating-point sort!");
  49. unsigned Size = getFloatSortSizeImpl();
  50. assert(Size && "Size is zero!");
  51. return Size;
  52. };
  53. virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
  54. bool operator<(const SMTSort &Other) const {
  55. llvm::FoldingSetNodeID ID1, ID2;
  56. Profile(ID1);
  57. Other.Profile(ID2);
  58. return ID1 < ID2;
  59. }
  60. friend bool operator==(SMTSort const &LHS, SMTSort const &RHS) {
  61. return LHS.equal_to(RHS);
  62. }
  63. virtual void print(raw_ostream &OS) const = 0;
  64. LLVM_DUMP_METHOD void dump() const;
  65. protected:
  66. /// Query the SMT solver and returns true if two sorts are equal (same kind
  67. /// and bit width). This does not check if the two sorts are the same objects.
  68. virtual bool equal_to(SMTSort const &other) const = 0;
  69. /// Query the SMT solver and checks if a sort is bitvector.
  70. virtual bool isBitvectorSortImpl() const = 0;
  71. /// Query the SMT solver and checks if a sort is floating-point.
  72. virtual bool isFloatSortImpl() const = 0;
  73. /// Query the SMT solver and checks if a sort is boolean.
  74. virtual bool isBooleanSortImpl() const = 0;
  75. /// Query the SMT solver and returns the sort bit width.
  76. virtual unsigned getBitvectorSortSizeImpl() const = 0;
  77. /// Query the SMT solver and returns the sort bit width.
  78. virtual unsigned getFloatSortSizeImpl() const = 0;
  79. };
  80. /// Shared pointer for SMTSorts, used by SMTSolver API.
  81. using SMTSortRef = const SMTSort *;
  82. /// Generic base class for SMT exprs
  83. class SMTExpr {
  84. public:
  85. SMTExpr() = default;
  86. virtual ~SMTExpr() = default;
  87. bool operator<(const SMTExpr &Other) const {
  88. llvm::FoldingSetNodeID ID1, ID2;
  89. Profile(ID1);
  90. Other.Profile(ID2);
  91. return ID1 < ID2;
  92. }
  93. virtual void Profile(llvm::FoldingSetNodeID &ID) const = 0;
  94. friend bool operator==(SMTExpr const &LHS, SMTExpr const &RHS) {
  95. return LHS.equal_to(RHS);
  96. }
  97. virtual void print(raw_ostream &OS) const = 0;
  98. LLVM_DUMP_METHOD void dump() const;
  99. protected:
  100. /// Query the SMT solver and returns true if two sorts are equal (same kind
  101. /// and bit width). This does not check if the two sorts are the same objects.
  102. virtual bool equal_to(SMTExpr const &other) const = 0;
  103. };
  104. /// Shared pointer for SMTExprs, used by SMTSolver API.
  105. using SMTExprRef = const SMTExpr *;
  106. /// Generic base class for SMT Solvers
  107. ///
  108. /// This class is responsible for wrapping all sorts and expression generation,
  109. /// through the mk* methods. It also provides methods to create SMT expressions
  110. /// straight from clang's AST, through the from* methods.
  111. class SMTSolver {
  112. public:
  113. SMTSolver() = default;
  114. virtual ~SMTSolver() = default;
  115. LLVM_DUMP_METHOD void dump() const;
  116. // Returns an appropriate floating-point sort for the given bitwidth.
  117. SMTSortRef getFloatSort(unsigned BitWidth) {
  118. switch (BitWidth) {
  119. case 16:
  120. return getFloat16Sort();
  121. case 32:
  122. return getFloat32Sort();
  123. case 64:
  124. return getFloat64Sort();
  125. case 128:
  126. return getFloat128Sort();
  127. default:;
  128. }
  129. llvm_unreachable("Unsupported floating-point bitwidth!");
  130. }
  131. // Returns a boolean sort.
  132. virtual SMTSortRef getBoolSort() = 0;
  133. // Returns an appropriate bitvector sort for the given bitwidth.
  134. virtual SMTSortRef getBitvectorSort(const unsigned BitWidth) = 0;
  135. // Returns a floating-point sort of width 16
  136. virtual SMTSortRef getFloat16Sort() = 0;
  137. // Returns a floating-point sort of width 32
  138. virtual SMTSortRef getFloat32Sort() = 0;
  139. // Returns a floating-point sort of width 64
  140. virtual SMTSortRef getFloat64Sort() = 0;
  141. // Returns a floating-point sort of width 128
  142. virtual SMTSortRef getFloat128Sort() = 0;
  143. // Returns an appropriate sort for the given AST.
  144. virtual SMTSortRef getSort(const SMTExprRef &AST) = 0;
  145. /// Given a constraint, adds it to the solver
  146. virtual void addConstraint(const SMTExprRef &Exp) const = 0;
  147. /// Creates a bitvector addition operation
  148. virtual SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  149. /// Creates a bitvector subtraction operation
  150. virtual SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  151. /// Creates a bitvector multiplication operation
  152. virtual SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  153. /// Creates a bitvector signed modulus operation
  154. virtual SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  155. /// Creates a bitvector unsigned modulus operation
  156. virtual SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  157. /// Creates a bitvector signed division operation
  158. virtual SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  159. /// Creates a bitvector unsigned division operation
  160. virtual SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  161. /// Creates a bitvector logical shift left operation
  162. virtual SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  163. /// Creates a bitvector arithmetic shift right operation
  164. virtual SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  165. /// Creates a bitvector logical shift right operation
  166. virtual SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  167. /// Creates a bitvector negation operation
  168. virtual SMTExprRef mkBVNeg(const SMTExprRef &Exp) = 0;
  169. /// Creates a bitvector not operation
  170. virtual SMTExprRef mkBVNot(const SMTExprRef &Exp) = 0;
  171. /// Creates a bitvector xor operation
  172. virtual SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  173. /// Creates a bitvector or operation
  174. virtual SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  175. /// Creates a bitvector and operation
  176. virtual SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  177. /// Creates a bitvector unsigned less-than operation
  178. virtual SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  179. /// Creates a bitvector signed less-than operation
  180. virtual SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  181. /// Creates a bitvector unsigned greater-than operation
  182. virtual SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  183. /// Creates a bitvector signed greater-than operation
  184. virtual SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  185. /// Creates a bitvector unsigned less-equal-than operation
  186. virtual SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  187. /// Creates a bitvector signed less-equal-than operation
  188. virtual SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  189. /// Creates a bitvector unsigned greater-equal-than operation
  190. virtual SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  191. /// Creates a bitvector signed greater-equal-than operation
  192. virtual SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  193. /// Creates a boolean not operation
  194. virtual SMTExprRef mkNot(const SMTExprRef &Exp) = 0;
  195. /// Creates a boolean equality operation
  196. virtual SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  197. /// Creates a boolean and operation
  198. virtual SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  199. /// Creates a boolean or operation
  200. virtual SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  201. /// Creates a boolean ite operation
  202. virtual SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
  203. const SMTExprRef &F) = 0;
  204. /// Creates a bitvector sign extension operation
  205. virtual SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) = 0;
  206. /// Creates a bitvector zero extension operation
  207. virtual SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) = 0;
  208. /// Creates a bitvector extract operation
  209. virtual SMTExprRef mkBVExtract(unsigned High, unsigned Low,
  210. const SMTExprRef &Exp) = 0;
  211. /// Creates a bitvector concat operation
  212. virtual SMTExprRef mkBVConcat(const SMTExprRef &LHS,
  213. const SMTExprRef &RHS) = 0;
  214. /// Creates a predicate that checks for overflow in a bitvector addition
  215. /// operation
  216. virtual SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS,
  217. const SMTExprRef &RHS,
  218. bool isSigned) = 0;
  219. /// Creates a predicate that checks for underflow in a signed bitvector
  220. /// addition operation
  221. virtual SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
  222. const SMTExprRef &RHS) = 0;
  223. /// Creates a predicate that checks for overflow in a signed bitvector
  224. /// subtraction operation
  225. virtual SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
  226. const SMTExprRef &RHS) = 0;
  227. /// Creates a predicate that checks for underflow in a bitvector subtraction
  228. /// operation
  229. virtual SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS,
  230. const SMTExprRef &RHS,
  231. bool isSigned) = 0;
  232. /// Creates a predicate that checks for overflow in a signed bitvector
  233. /// division/modulus operation
  234. virtual SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
  235. const SMTExprRef &RHS) = 0;
  236. /// Creates a predicate that checks for overflow in a bitvector negation
  237. /// operation
  238. virtual SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) = 0;
  239. /// Creates a predicate that checks for overflow in a bitvector multiplication
  240. /// operation
  241. virtual SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS,
  242. const SMTExprRef &RHS,
  243. bool isSigned) = 0;
  244. /// Creates a predicate that checks for underflow in a signed bitvector
  245. /// multiplication operation
  246. virtual SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
  247. const SMTExprRef &RHS) = 0;
  248. /// Creates a floating-point negation operation
  249. virtual SMTExprRef mkFPNeg(const SMTExprRef &Exp) = 0;
  250. /// Creates a floating-point isInfinite operation
  251. virtual SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) = 0;
  252. /// Creates a floating-point isNaN operation
  253. virtual SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) = 0;
  254. /// Creates a floating-point isNormal operation
  255. virtual SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) = 0;
  256. /// Creates a floating-point isZero operation
  257. virtual SMTExprRef mkFPIsZero(const SMTExprRef &Exp) = 0;
  258. /// Creates a floating-point multiplication operation
  259. virtual SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  260. /// Creates a floating-point division operation
  261. virtual SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  262. /// Creates a floating-point remainder operation
  263. virtual SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  264. /// Creates a floating-point addition operation
  265. virtual SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  266. /// Creates a floating-point subtraction operation
  267. virtual SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  268. /// Creates a floating-point less-than operation
  269. virtual SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  270. /// Creates a floating-point greater-than operation
  271. virtual SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  272. /// Creates a floating-point less-than-or-equal operation
  273. virtual SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  274. /// Creates a floating-point greater-than-or-equal operation
  275. virtual SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) = 0;
  276. /// Creates a floating-point equality operation
  277. virtual SMTExprRef mkFPEqual(const SMTExprRef &LHS,
  278. const SMTExprRef &RHS) = 0;
  279. /// Creates a floating-point conversion from floatint-point to floating-point
  280. /// operation
  281. virtual SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) = 0;
  282. /// Creates a floating-point conversion from signed bitvector to
  283. /// floatint-point operation
  284. virtual SMTExprRef mkSBVtoFP(const SMTExprRef &From,
  285. const SMTSortRef &To) = 0;
  286. /// Creates a floating-point conversion from unsigned bitvector to
  287. /// floatint-point operation
  288. virtual SMTExprRef mkUBVtoFP(const SMTExprRef &From,
  289. const SMTSortRef &To) = 0;
  290. /// Creates a floating-point conversion from floatint-point to signed
  291. /// bitvector operation
  292. virtual SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) = 0;
  293. /// Creates a floating-point conversion from floatint-point to unsigned
  294. /// bitvector operation
  295. virtual SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) = 0;
  296. /// Creates a new symbol, given a name and a sort
  297. virtual SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) = 0;
  298. // Returns an appropriate floating-point rounding mode.
  299. virtual SMTExprRef getFloatRoundingMode() = 0;
  300. // If the a model is available, returns the value of a given bitvector symbol
  301. virtual llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
  302. bool isUnsigned) = 0;
  303. // If the a model is available, returns the value of a given boolean symbol
  304. virtual bool getBoolean(const SMTExprRef &Exp) = 0;
  305. /// Constructs an SMTExprRef from a boolean.
  306. virtual SMTExprRef mkBoolean(const bool b) = 0;
  307. /// Constructs an SMTExprRef from a finite APFloat.
  308. virtual SMTExprRef mkFloat(const llvm::APFloat Float) = 0;
  309. /// Constructs an SMTExprRef from an APSInt and its bit width
  310. virtual SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) = 0;
  311. /// Given an expression, extract the value of this operand in the model.
  312. virtual bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) = 0;
  313. /// Given an expression extract the value of this operand in the model.
  314. virtual bool getInterpretation(const SMTExprRef &Exp,
  315. llvm::APFloat &Float) = 0;
  316. /// Check if the constraints are satisfiable
  317. virtual std::optional<bool> check() const = 0;
  318. /// Push the current solver state
  319. virtual void push() = 0;
  320. /// Pop the previous solver state
  321. virtual void pop(unsigned NumStates = 1) = 0;
  322. /// Reset the solver and remove all constraints.
  323. virtual void reset() = 0;
  324. /// Checks if the solver supports floating-points.
  325. virtual bool isFPSupported() = 0;
  326. virtual void print(raw_ostream &OS) const = 0;
  327. };
  328. /// Shared pointer for SMTSolvers.
  329. using SMTSolverRef = std::shared_ptr<SMTSolver>;
  330. /// Convenience method to create and Z3Solver object
  331. SMTSolverRef CreateZ3Solver();
  332. } // namespace llvm
  333. #endif
  334. #ifdef __GNUC__
  335. #pragma GCC diagnostic pop
  336. #endif