ast_opt.c 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137
  1. /* AST Optimizer */
  2. #include "Python.h"
  3. #include "pycore_ast.h" // _PyAST_GetDocString()
  4. #include "pycore_compile.h" // _PyASTOptimizeState
  5. #include "pycore_long.h" // _PyLong
  6. #include "pycore_pystate.h" // _PyThreadState_GET()
  7. #include "pycore_format.h" // F_LJUST
  8. static int
  9. make_const(expr_ty node, PyObject *val, PyArena *arena)
  10. {
  11. // Even if no new value was calculated, make_const may still
  12. // need to clear an error (e.g. for division by zero)
  13. if (val == NULL) {
  14. if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
  15. return 0;
  16. }
  17. PyErr_Clear();
  18. return 1;
  19. }
  20. if (_PyArena_AddPyObject(arena, val) < 0) {
  21. Py_DECREF(val);
  22. return 0;
  23. }
  24. node->kind = Constant_kind;
  25. node->v.Constant.kind = NULL;
  26. node->v.Constant.value = val;
  27. return 1;
  28. }
  29. #define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
  30. static int
  31. has_starred(asdl_expr_seq *elts)
  32. {
  33. Py_ssize_t n = asdl_seq_LEN(elts);
  34. for (Py_ssize_t i = 0; i < n; i++) {
  35. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  36. if (e->kind == Starred_kind) {
  37. return 1;
  38. }
  39. }
  40. return 0;
  41. }
  42. static PyObject*
  43. unary_not(PyObject *v)
  44. {
  45. int r = PyObject_IsTrue(v);
  46. if (r < 0)
  47. return NULL;
  48. return PyBool_FromLong(!r);
  49. }
  50. static int
  51. fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
  52. {
  53. expr_ty arg = node->v.UnaryOp.operand;
  54. if (arg->kind != Constant_kind) {
  55. /* Fold not into comparison */
  56. if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
  57. asdl_seq_LEN(arg->v.Compare.ops) == 1) {
  58. /* Eq and NotEq are often implemented in terms of one another, so
  59. folding not (self == other) into self != other breaks implementation
  60. of !=. Detecting such cases doesn't seem worthwhile.
  61. Python uses </> for 'is subset'/'is superset' operations on sets.
  62. They don't satisfy not folding laws. */
  63. cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
  64. switch (op) {
  65. case Is:
  66. op = IsNot;
  67. break;
  68. case IsNot:
  69. op = Is;
  70. break;
  71. case In:
  72. op = NotIn;
  73. break;
  74. case NotIn:
  75. op = In;
  76. break;
  77. // The remaining comparison operators can't be safely inverted
  78. case Eq:
  79. case NotEq:
  80. case Lt:
  81. case LtE:
  82. case Gt:
  83. case GtE:
  84. op = 0; // The AST enums leave "0" free as an "unused" marker
  85. break;
  86. // No default case, so the compiler will emit a warning if new
  87. // comparison operators are added without being handled here
  88. }
  89. if (op) {
  90. asdl_seq_SET(arg->v.Compare.ops, 0, op);
  91. COPY_NODE(node, arg);
  92. return 1;
  93. }
  94. }
  95. return 1;
  96. }
  97. typedef PyObject *(*unary_op)(PyObject*);
  98. static const unary_op ops[] = {
  99. [Invert] = PyNumber_Invert,
  100. [Not] = unary_not,
  101. [UAdd] = PyNumber_Positive,
  102. [USub] = PyNumber_Negative,
  103. };
  104. PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
  105. return make_const(node, newval, arena);
  106. }
  107. /* Check whether a collection doesn't containing too much items (including
  108. subcollections). This protects from creating a constant that needs
  109. too much time for calculating a hash.
  110. "limit" is the maximal number of items.
  111. Returns the negative number if the total number of items exceeds the
  112. limit. Otherwise returns the limit minus the total number of items.
  113. */
  114. static Py_ssize_t
  115. check_complexity(PyObject *obj, Py_ssize_t limit)
  116. {
  117. if (PyTuple_Check(obj)) {
  118. Py_ssize_t i;
  119. limit -= PyTuple_GET_SIZE(obj);
  120. for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
  121. limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
  122. }
  123. return limit;
  124. }
  125. else if (PyFrozenSet_Check(obj)) {
  126. Py_ssize_t i = 0;
  127. PyObject *item;
  128. Py_hash_t hash;
  129. limit -= PySet_GET_SIZE(obj);
  130. while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
  131. limit = check_complexity(item, limit);
  132. }
  133. }
  134. return limit;
  135. }
  136. #define MAX_INT_SIZE 128 /* bits */
  137. #define MAX_COLLECTION_SIZE 256 /* items */
  138. #define MAX_STR_SIZE 4096 /* characters */
  139. #define MAX_TOTAL_ITEMS 1024 /* including nested collections */
  140. static PyObject *
  141. safe_multiply(PyObject *v, PyObject *w)
  142. {
  143. if (PyLong_Check(v) && PyLong_Check(w) &&
  144. !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
  145. ) {
  146. size_t vbits = _PyLong_NumBits(v);
  147. size_t wbits = _PyLong_NumBits(w);
  148. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  149. return NULL;
  150. }
  151. if (vbits + wbits > MAX_INT_SIZE) {
  152. return NULL;
  153. }
  154. }
  155. else if (PyLong_Check(v) && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
  156. Py_ssize_t size = PyTuple_Check(w) ? PyTuple_GET_SIZE(w) :
  157. PySet_GET_SIZE(w);
  158. if (size) {
  159. long n = PyLong_AsLong(v);
  160. if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
  161. return NULL;
  162. }
  163. if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
  164. return NULL;
  165. }
  166. }
  167. }
  168. else if (PyLong_Check(v) && (PyUnicode_Check(w) || PyBytes_Check(w))) {
  169. Py_ssize_t size = PyUnicode_Check(w) ? PyUnicode_GET_LENGTH(w) :
  170. PyBytes_GET_SIZE(w);
  171. if (size) {
  172. long n = PyLong_AsLong(v);
  173. if (n < 0 || n > MAX_STR_SIZE / size) {
  174. return NULL;
  175. }
  176. }
  177. }
  178. else if (PyLong_Check(w) &&
  179. (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
  180. PyUnicode_Check(v) || PyBytes_Check(v)))
  181. {
  182. return safe_multiply(w, v);
  183. }
  184. return PyNumber_Multiply(v, w);
  185. }
  186. static PyObject *
  187. safe_power(PyObject *v, PyObject *w)
  188. {
  189. if (PyLong_Check(v) && PyLong_Check(w) &&
  190. !_PyLong_IsZero((PyLongObject *)v) && _PyLong_IsPositive((PyLongObject *)w)
  191. ) {
  192. size_t vbits = _PyLong_NumBits(v);
  193. size_t wbits = PyLong_AsSize_t(w);
  194. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  195. return NULL;
  196. }
  197. if (vbits > MAX_INT_SIZE / wbits) {
  198. return NULL;
  199. }
  200. }
  201. return PyNumber_Power(v, w, Py_None);
  202. }
  203. static PyObject *
  204. safe_lshift(PyObject *v, PyObject *w)
  205. {
  206. if (PyLong_Check(v) && PyLong_Check(w) &&
  207. !_PyLong_IsZero((PyLongObject *)v) && !_PyLong_IsZero((PyLongObject *)w)
  208. ) {
  209. size_t vbits = _PyLong_NumBits(v);
  210. size_t wbits = PyLong_AsSize_t(w);
  211. if (vbits == (size_t)-1 || wbits == (size_t)-1) {
  212. return NULL;
  213. }
  214. if (wbits > MAX_INT_SIZE || vbits > MAX_INT_SIZE - wbits) {
  215. return NULL;
  216. }
  217. }
  218. return PyNumber_Lshift(v, w);
  219. }
  220. static PyObject *
  221. safe_mod(PyObject *v, PyObject *w)
  222. {
  223. if (PyUnicode_Check(v) || PyBytes_Check(v)) {
  224. return NULL;
  225. }
  226. return PyNumber_Remainder(v, w);
  227. }
  228. static expr_ty
  229. parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
  230. {
  231. const void *data = PyUnicode_DATA(fmt);
  232. int kind = PyUnicode_KIND(fmt);
  233. Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
  234. Py_ssize_t start, pos;
  235. int has_percents = 0;
  236. start = pos = *ppos;
  237. while (pos < size) {
  238. if (PyUnicode_READ(kind, data, pos) != '%') {
  239. pos++;
  240. }
  241. else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
  242. has_percents = 1;
  243. pos += 2;
  244. }
  245. else {
  246. break;
  247. }
  248. }
  249. *ppos = pos;
  250. if (pos == start) {
  251. return NULL;
  252. }
  253. PyObject *str = PyUnicode_Substring(fmt, start, pos);
  254. /* str = str.replace('%%', '%') */
  255. if (str && has_percents) {
  256. _Py_DECLARE_STR(percent, "%");
  257. _Py_DECLARE_STR(dbl_percent, "%%");
  258. Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
  259. &_Py_STR(percent), -1));
  260. }
  261. if (!str) {
  262. return NULL;
  263. }
  264. if (_PyArena_AddPyObject(arena, str) < 0) {
  265. Py_DECREF(str);
  266. return NULL;
  267. }
  268. return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
  269. }
  270. #define MAXDIGITS 3
  271. static int
  272. simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
  273. int *spec, int *flags, int *width, int *prec)
  274. {
  275. Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
  276. Py_UCS4 ch;
  277. #define NEXTC do { \
  278. if (pos >= len) { \
  279. return 0; \
  280. } \
  281. ch = PyUnicode_READ_CHAR(fmt, pos); \
  282. pos++; \
  283. } while (0)
  284. *flags = 0;
  285. while (1) {
  286. NEXTC;
  287. switch (ch) {
  288. case '-': *flags |= F_LJUST; continue;
  289. case '+': *flags |= F_SIGN; continue;
  290. case ' ': *flags |= F_BLANK; continue;
  291. case '#': *flags |= F_ALT; continue;
  292. case '0': *flags |= F_ZERO; continue;
  293. }
  294. break;
  295. }
  296. if ('0' <= ch && ch <= '9') {
  297. *width = 0;
  298. int digits = 0;
  299. while ('0' <= ch && ch <= '9') {
  300. *width = *width * 10 + (ch - '0');
  301. NEXTC;
  302. if (++digits >= MAXDIGITS) {
  303. return 0;
  304. }
  305. }
  306. }
  307. if (ch == '.') {
  308. NEXTC;
  309. *prec = 0;
  310. if ('0' <= ch && ch <= '9') {
  311. int digits = 0;
  312. while ('0' <= ch && ch <= '9') {
  313. *prec = *prec * 10 + (ch - '0');
  314. NEXTC;
  315. if (++digits >= MAXDIGITS) {
  316. return 0;
  317. }
  318. }
  319. }
  320. }
  321. *spec = ch;
  322. *ppos = pos;
  323. return 1;
  324. #undef NEXTC
  325. }
  326. static expr_ty
  327. parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
  328. {
  329. int spec, flags, width = -1, prec = -1;
  330. if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
  331. // Unsupported format.
  332. return NULL;
  333. }
  334. if (spec == 's' || spec == 'r' || spec == 'a') {
  335. char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
  336. if (!(flags & F_LJUST) && width > 0) {
  337. *p++ = '>';
  338. }
  339. if (width >= 0) {
  340. p += snprintf(p, MAXDIGITS + 1, "%d", width);
  341. }
  342. if (prec >= 0) {
  343. p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
  344. }
  345. expr_ty format_spec = NULL;
  346. if (p != buf) {
  347. PyObject *str = PyUnicode_FromString(buf);
  348. if (str == NULL) {
  349. return NULL;
  350. }
  351. if (_PyArena_AddPyObject(arena, str) < 0) {
  352. Py_DECREF(str);
  353. return NULL;
  354. }
  355. format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
  356. if (format_spec == NULL) {
  357. return NULL;
  358. }
  359. }
  360. return _PyAST_FormattedValue(arg, spec, format_spec,
  361. arg->lineno, arg->col_offset,
  362. arg->end_lineno, arg->end_col_offset,
  363. arena);
  364. }
  365. // Unsupported format.
  366. return NULL;
  367. }
  368. static int
  369. optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
  370. {
  371. Py_ssize_t pos = 0;
  372. Py_ssize_t cnt = 0;
  373. asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
  374. if (!seq) {
  375. return 0;
  376. }
  377. seq->size = 0;
  378. while (1) {
  379. expr_ty lit = parse_literal(fmt, &pos, arena);
  380. if (lit) {
  381. asdl_seq_SET(seq, seq->size++, lit);
  382. }
  383. else if (PyErr_Occurred()) {
  384. return 0;
  385. }
  386. if (pos >= PyUnicode_GET_LENGTH(fmt)) {
  387. break;
  388. }
  389. if (cnt >= asdl_seq_LEN(elts)) {
  390. // More format units than items.
  391. return 1;
  392. }
  393. assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
  394. pos++;
  395. expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
  396. cnt++;
  397. if (!expr) {
  398. return !PyErr_Occurred();
  399. }
  400. asdl_seq_SET(seq, seq->size++, expr);
  401. }
  402. if (cnt < asdl_seq_LEN(elts)) {
  403. // More items than format units.
  404. return 1;
  405. }
  406. expr_ty res = _PyAST_JoinedStr(seq,
  407. node->lineno, node->col_offset,
  408. node->end_lineno, node->end_col_offset,
  409. arena);
  410. if (!res) {
  411. return 0;
  412. }
  413. COPY_NODE(node, res);
  414. // PySys_FormatStderr("format = %R\n", fmt);
  415. return 1;
  416. }
  417. static int
  418. fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
  419. {
  420. expr_ty lhs, rhs;
  421. lhs = node->v.BinOp.left;
  422. rhs = node->v.BinOp.right;
  423. if (lhs->kind != Constant_kind) {
  424. return 1;
  425. }
  426. PyObject *lv = lhs->v.Constant.value;
  427. if (node->v.BinOp.op == Mod &&
  428. rhs->kind == Tuple_kind &&
  429. PyUnicode_Check(lv) &&
  430. !has_starred(rhs->v.Tuple.elts))
  431. {
  432. return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
  433. }
  434. if (rhs->kind != Constant_kind) {
  435. return 1;
  436. }
  437. PyObject *rv = rhs->v.Constant.value;
  438. PyObject *newval = NULL;
  439. switch (node->v.BinOp.op) {
  440. case Add:
  441. newval = PyNumber_Add(lv, rv);
  442. break;
  443. case Sub:
  444. newval = PyNumber_Subtract(lv, rv);
  445. break;
  446. case Mult:
  447. newval = safe_multiply(lv, rv);
  448. break;
  449. case Div:
  450. newval = PyNumber_TrueDivide(lv, rv);
  451. break;
  452. case FloorDiv:
  453. newval = PyNumber_FloorDivide(lv, rv);
  454. break;
  455. case Mod:
  456. newval = safe_mod(lv, rv);
  457. break;
  458. case Pow:
  459. newval = safe_power(lv, rv);
  460. break;
  461. case LShift:
  462. newval = safe_lshift(lv, rv);
  463. break;
  464. case RShift:
  465. newval = PyNumber_Rshift(lv, rv);
  466. break;
  467. case BitOr:
  468. newval = PyNumber_Or(lv, rv);
  469. break;
  470. case BitXor:
  471. newval = PyNumber_Xor(lv, rv);
  472. break;
  473. case BitAnd:
  474. newval = PyNumber_And(lv, rv);
  475. break;
  476. // No builtin constants implement the following operators
  477. case MatMult:
  478. return 1;
  479. // No default case, so the compiler will emit a warning if new binary
  480. // operators are added without being handled here
  481. }
  482. return make_const(node, newval, arena);
  483. }
  484. static PyObject*
  485. make_const_tuple(asdl_expr_seq *elts)
  486. {
  487. for (int i = 0; i < asdl_seq_LEN(elts); i++) {
  488. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  489. if (e->kind != Constant_kind) {
  490. return NULL;
  491. }
  492. }
  493. PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
  494. if (newval == NULL) {
  495. return NULL;
  496. }
  497. for (int i = 0; i < asdl_seq_LEN(elts); i++) {
  498. expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
  499. PyObject *v = e->v.Constant.value;
  500. PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
  501. }
  502. return newval;
  503. }
  504. static int
  505. fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
  506. {
  507. PyObject *newval;
  508. if (node->v.Tuple.ctx != Load)
  509. return 1;
  510. newval = make_const_tuple(node->v.Tuple.elts);
  511. return make_const(node, newval, arena);
  512. }
  513. static int
  514. fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
  515. {
  516. PyObject *newval;
  517. expr_ty arg, idx;
  518. arg = node->v.Subscript.value;
  519. idx = node->v.Subscript.slice;
  520. if (node->v.Subscript.ctx != Load ||
  521. arg->kind != Constant_kind ||
  522. idx->kind != Constant_kind)
  523. {
  524. return 1;
  525. }
  526. newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
  527. return make_const(node, newval, arena);
  528. }
  529. /* Change literal list or set of constants into constant
  530. tuple or frozenset respectively. Change literal list of
  531. non-constants into tuple.
  532. Used for right operand of "in" and "not in" tests and for iterable
  533. in "for" loop and comprehensions.
  534. */
  535. static int
  536. fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
  537. {
  538. PyObject *newval;
  539. if (arg->kind == List_kind) {
  540. /* First change a list into tuple. */
  541. asdl_expr_seq *elts = arg->v.List.elts;
  542. if (has_starred(elts)) {
  543. return 1;
  544. }
  545. expr_context_ty ctx = arg->v.List.ctx;
  546. arg->kind = Tuple_kind;
  547. arg->v.Tuple.elts = elts;
  548. arg->v.Tuple.ctx = ctx;
  549. /* Try to create a constant tuple. */
  550. newval = make_const_tuple(elts);
  551. }
  552. else if (arg->kind == Set_kind) {
  553. newval = make_const_tuple(arg->v.Set.elts);
  554. if (newval) {
  555. Py_SETREF(newval, PyFrozenSet_New(newval));
  556. }
  557. }
  558. else {
  559. return 1;
  560. }
  561. return make_const(arg, newval, arena);
  562. }
  563. static int
  564. fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
  565. {
  566. asdl_int_seq *ops;
  567. asdl_expr_seq *args;
  568. Py_ssize_t i;
  569. ops = node->v.Compare.ops;
  570. args = node->v.Compare.comparators;
  571. /* Change literal list or set in 'in' or 'not in' into
  572. tuple or frozenset respectively. */
  573. i = asdl_seq_LEN(ops) - 1;
  574. int op = asdl_seq_GET(ops, i);
  575. if (op == In || op == NotIn) {
  576. if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
  577. return 0;
  578. }
  579. }
  580. return 1;
  581. }
  582. static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  583. static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  584. static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  585. static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  586. static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  587. static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  588. static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  589. static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  590. static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  591. static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  592. static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  593. static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
  594. #define CALL(FUNC, TYPE, ARG) \
  595. if (!FUNC((ARG), ctx_, state)) \
  596. return 0;
  597. #define CALL_OPT(FUNC, TYPE, ARG) \
  598. if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) \
  599. return 0;
  600. #define CALL_SEQ(FUNC, TYPE, ARG) { \
  601. int i; \
  602. asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
  603. for (i = 0; i < asdl_seq_LEN(seq); i++) { \
  604. TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
  605. if (elt != NULL && !FUNC(elt, ctx_, state)) \
  606. return 0; \
  607. } \
  608. }
  609. static int
  610. astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
  611. {
  612. int docstring = _PyAST_GetDocString(stmts) != NULL;
  613. CALL_SEQ(astfold_stmt, stmt, stmts);
  614. if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
  615. stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
  616. asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
  617. if (!values) {
  618. return 0;
  619. }
  620. asdl_seq_SET(values, 0, st->v.Expr.value);
  621. expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
  622. st->end_lineno, st->end_col_offset,
  623. ctx_);
  624. if (!expr) {
  625. return 0;
  626. }
  627. st->v.Expr.value = expr;
  628. }
  629. return 1;
  630. }
  631. static int
  632. astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  633. {
  634. switch (node_->kind) {
  635. case Module_kind:
  636. CALL(astfold_body, asdl_seq, node_->v.Module.body);
  637. break;
  638. case Interactive_kind:
  639. CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
  640. break;
  641. case Expression_kind:
  642. CALL(astfold_expr, expr_ty, node_->v.Expression.body);
  643. break;
  644. // The following top level nodes don't participate in constant folding
  645. case FunctionType_kind:
  646. break;
  647. // No default case, so the compiler will emit a warning if new top level
  648. // compilation nodes are added without being handled here
  649. }
  650. return 1;
  651. }
  652. static int
  653. astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  654. {
  655. if (++state->recursion_depth > state->recursion_limit) {
  656. PyErr_SetString(PyExc_RecursionError,
  657. "maximum recursion depth exceeded during compilation");
  658. return 0;
  659. }
  660. switch (node_->kind) {
  661. case BoolOp_kind:
  662. CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
  663. break;
  664. case BinOp_kind:
  665. CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
  666. CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
  667. CALL(fold_binop, expr_ty, node_);
  668. break;
  669. case UnaryOp_kind:
  670. CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
  671. CALL(fold_unaryop, expr_ty, node_);
  672. break;
  673. case Lambda_kind:
  674. CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
  675. CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
  676. break;
  677. case IfExp_kind:
  678. CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
  679. CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
  680. CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
  681. break;
  682. case Dict_kind:
  683. CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
  684. CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
  685. break;
  686. case Set_kind:
  687. CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
  688. break;
  689. case ListComp_kind:
  690. CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
  691. CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
  692. break;
  693. case SetComp_kind:
  694. CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
  695. CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
  696. break;
  697. case DictComp_kind:
  698. CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
  699. CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
  700. CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
  701. break;
  702. case GeneratorExp_kind:
  703. CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
  704. CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
  705. break;
  706. case Await_kind:
  707. CALL(astfold_expr, expr_ty, node_->v.Await.value);
  708. break;
  709. case Yield_kind:
  710. CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
  711. break;
  712. case YieldFrom_kind:
  713. CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
  714. break;
  715. case Compare_kind:
  716. CALL(astfold_expr, expr_ty, node_->v.Compare.left);
  717. CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
  718. CALL(fold_compare, expr_ty, node_);
  719. break;
  720. case Call_kind:
  721. CALL(astfold_expr, expr_ty, node_->v.Call.func);
  722. CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
  723. CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
  724. break;
  725. case FormattedValue_kind:
  726. CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
  727. CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
  728. break;
  729. case JoinedStr_kind:
  730. CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
  731. break;
  732. case Attribute_kind:
  733. CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
  734. break;
  735. case Subscript_kind:
  736. CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
  737. CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
  738. CALL(fold_subscr, expr_ty, node_);
  739. break;
  740. case Starred_kind:
  741. CALL(astfold_expr, expr_ty, node_->v.Starred.value);
  742. break;
  743. case Slice_kind:
  744. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
  745. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
  746. CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
  747. break;
  748. case List_kind:
  749. CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
  750. break;
  751. case Tuple_kind:
  752. CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
  753. CALL(fold_tuple, expr_ty, node_);
  754. break;
  755. case Name_kind:
  756. if (node_->v.Name.ctx == Load &&
  757. _PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
  758. state->recursion_depth--;
  759. return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
  760. }
  761. break;
  762. case NamedExpr_kind:
  763. CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
  764. break;
  765. case Constant_kind:
  766. // Already a constant, nothing further to do
  767. break;
  768. // No default case, so the compiler will emit a warning if new expression
  769. // kinds are added without being handled here
  770. }
  771. state->recursion_depth--;
  772. return 1;
  773. }
  774. static int
  775. astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  776. {
  777. CALL(astfold_expr, expr_ty, node_->value);
  778. return 1;
  779. }
  780. static int
  781. astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  782. {
  783. CALL(astfold_expr, expr_ty, node_->target);
  784. CALL(astfold_expr, expr_ty, node_->iter);
  785. CALL_SEQ(astfold_expr, expr, node_->ifs);
  786. CALL(fold_iter, expr_ty, node_->iter);
  787. return 1;
  788. }
  789. static int
  790. astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  791. {
  792. CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
  793. CALL_SEQ(astfold_arg, arg, node_->args);
  794. CALL_OPT(astfold_arg, arg_ty, node_->vararg);
  795. CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
  796. CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
  797. CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
  798. CALL_SEQ(astfold_expr, expr, node_->defaults);
  799. return 1;
  800. }
  801. static int
  802. astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  803. {
  804. if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
  805. CALL_OPT(astfold_expr, expr_ty, node_->annotation);
  806. }
  807. return 1;
  808. }
  809. static int
  810. astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  811. {
  812. if (++state->recursion_depth > state->recursion_limit) {
  813. PyErr_SetString(PyExc_RecursionError,
  814. "maximum recursion depth exceeded during compilation");
  815. return 0;
  816. }
  817. switch (node_->kind) {
  818. case FunctionDef_kind:
  819. CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
  820. CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
  821. CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
  822. CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
  823. if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
  824. CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
  825. }
  826. break;
  827. case AsyncFunctionDef_kind:
  828. CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
  829. CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
  830. CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
  831. CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
  832. if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
  833. CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
  834. }
  835. break;
  836. case ClassDef_kind:
  837. CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
  838. CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
  839. CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
  840. CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
  841. CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
  842. break;
  843. case Return_kind:
  844. CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
  845. break;
  846. case Delete_kind:
  847. CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
  848. break;
  849. case Assign_kind:
  850. CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
  851. CALL(astfold_expr, expr_ty, node_->v.Assign.value);
  852. break;
  853. case AugAssign_kind:
  854. CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
  855. CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
  856. break;
  857. case AnnAssign_kind:
  858. CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
  859. if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
  860. CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
  861. }
  862. CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
  863. break;
  864. case TypeAlias_kind:
  865. CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
  866. CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
  867. CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
  868. break;
  869. case For_kind:
  870. CALL(astfold_expr, expr_ty, node_->v.For.target);
  871. CALL(astfold_expr, expr_ty, node_->v.For.iter);
  872. CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
  873. CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);
  874. CALL(fold_iter, expr_ty, node_->v.For.iter);
  875. break;
  876. case AsyncFor_kind:
  877. CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
  878. CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
  879. CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
  880. CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
  881. break;
  882. case While_kind:
  883. CALL(astfold_expr, expr_ty, node_->v.While.test);
  884. CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
  885. CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
  886. break;
  887. case If_kind:
  888. CALL(astfold_expr, expr_ty, node_->v.If.test);
  889. CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
  890. CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
  891. break;
  892. case With_kind:
  893. CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
  894. CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
  895. break;
  896. case AsyncWith_kind:
  897. CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
  898. CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
  899. break;
  900. case Raise_kind:
  901. CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
  902. CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
  903. break;
  904. case Try_kind:
  905. CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
  906. CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
  907. CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
  908. CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
  909. break;
  910. case TryStar_kind:
  911. CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
  912. CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
  913. CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
  914. CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
  915. break;
  916. case Assert_kind:
  917. CALL(astfold_expr, expr_ty, node_->v.Assert.test);
  918. CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
  919. break;
  920. case Expr_kind:
  921. CALL(astfold_expr, expr_ty, node_->v.Expr.value);
  922. break;
  923. case Match_kind:
  924. CALL(astfold_expr, expr_ty, node_->v.Match.subject);
  925. CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
  926. break;
  927. // The following statements don't contain any subexpressions to be folded
  928. case Import_kind:
  929. case ImportFrom_kind:
  930. case Global_kind:
  931. case Nonlocal_kind:
  932. case Pass_kind:
  933. case Break_kind:
  934. case Continue_kind:
  935. break;
  936. // No default case, so the compiler will emit a warning if new statement
  937. // kinds are added without being handled here
  938. }
  939. state->recursion_depth--;
  940. return 1;
  941. }
  942. static int
  943. astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  944. {
  945. switch (node_->kind) {
  946. case ExceptHandler_kind:
  947. CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
  948. CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
  949. break;
  950. // No default case, so the compiler will emit a warning if new handler
  951. // kinds are added without being handled here
  952. }
  953. return 1;
  954. }
  955. static int
  956. astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  957. {
  958. CALL(astfold_expr, expr_ty, node_->context_expr);
  959. CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
  960. return 1;
  961. }
  962. static int
  963. astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  964. {
  965. // Currently, this is really only used to form complex/negative numeric
  966. // constants in MatchValue and MatchMapping nodes
  967. // We still recurse into all subexpressions and subpatterns anyway
  968. if (++state->recursion_depth > state->recursion_limit) {
  969. PyErr_SetString(PyExc_RecursionError,
  970. "maximum recursion depth exceeded during compilation");
  971. return 0;
  972. }
  973. switch (node_->kind) {
  974. case MatchValue_kind:
  975. CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
  976. break;
  977. case MatchSingleton_kind:
  978. break;
  979. case MatchSequence_kind:
  980. CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
  981. break;
  982. case MatchMapping_kind:
  983. CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
  984. CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
  985. break;
  986. case MatchClass_kind:
  987. CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
  988. CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
  989. CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
  990. break;
  991. case MatchStar_kind:
  992. break;
  993. case MatchAs_kind:
  994. if (node_->v.MatchAs.pattern) {
  995. CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
  996. }
  997. break;
  998. case MatchOr_kind:
  999. CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
  1000. break;
  1001. // No default case, so the compiler will emit a warning if new pattern
  1002. // kinds are added without being handled here
  1003. }
  1004. state->recursion_depth--;
  1005. return 1;
  1006. }
  1007. static int
  1008. astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  1009. {
  1010. CALL(astfold_pattern, expr_ty, node_->pattern);
  1011. CALL_OPT(astfold_expr, expr_ty, node_->guard);
  1012. CALL_SEQ(astfold_stmt, stmt, node_->body);
  1013. return 1;
  1014. }
  1015. static int
  1016. astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
  1017. {
  1018. switch (node_->kind) {
  1019. case TypeVar_kind:
  1020. CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
  1021. break;
  1022. case ParamSpec_kind:
  1023. break;
  1024. case TypeVarTuple_kind:
  1025. break;
  1026. }
  1027. return 1;
  1028. }
/* The traversal helper macros used by the astfold_* functions above are
   removed here, so they are not visible to the rest of the file. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ
/* See comments in symtable.c. */
/* _PyAST_Optimize() multiplies both the starting recursion depth and the
   recursion limit by this factor. */
#define COMPILER_STACK_FRAME_SCALE 2
  1034. int
  1035. _PyAST_Optimize(mod_ty mod, PyArena *arena, _PyASTOptimizeState *state)
  1036. {
  1037. PyThreadState *tstate;
  1038. int starting_recursion_depth;
  1039. /* Setup recursion depth check counters */
  1040. tstate = _PyThreadState_GET();
  1041. if (!tstate) {
  1042. return 0;
  1043. }
  1044. /* Be careful here to prevent overflow. */
  1045. int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
  1046. starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
  1047. state->recursion_depth = starting_recursion_depth;
  1048. state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
  1049. int ret = astfold_mod(mod, arena, state);
  1050. assert(ret || PyErr_Occurred());
  1051. /* Check that the recursion depth counting balanced correctly */
  1052. if (ret && state->recursion_depth != starting_recursion_depth) {
  1053. PyErr_Format(PyExc_SystemError,
  1054. "AST optimizer recursion depth mismatch (before=%d, after=%d)",
  1055. starting_recursion_depth, state->recursion_depth);
  1056. return 0;
  1057. }
  1058. return ret;
  1059. }