ParseTreeTransforms.py 136 KB


  1. from __future__ import absolute_import
  2. import cython
  3. cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
  4. Options=object, UtilNodes=object, LetNode=object,
  5. LetRefNode=object, TreeFragment=object, EncodedString=object,
  6. error=object, warning=object, copy=object, hashlib=object, sys=object,
  7. _unicode=object)
  8. import copy
  9. import hashlib
  10. import sys
  11. from . import PyrexTypes
  12. from . import Naming
  13. from . import ExprNodes
  14. from . import Nodes
  15. from . import Options
  16. from . import Builtin
  17. from . import Errors
  18. from .Visitor import VisitorTransform, TreeVisitor
  19. from .Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
  20. from .UtilNodes import LetNode, LetRefNode
  21. from .TreeFragment import TreeFragment
  22. from .StringEncoding import EncodedString, _unicode
  23. from .Errors import error, warning, CompileError, InternalError
  24. from .Code import UtilityCode
  25. class SkipDeclarations(object):
  26. """
  27. Variable and function declarations can often have a deep tree structure,
  28. and yet most transformations don't need to descend to this depth.
  29. Declaration nodes are removed after AnalyseDeclarationsTransform, so there
  30. is no need to use this for transformations after that point.
  31. """
  32. def visit_CTypeDefNode(self, node):
  33. return node
  34. def visit_CVarDefNode(self, node):
  35. return node
  36. def visit_CDeclaratorNode(self, node):
  37. return node
  38. def visit_CBaseTypeNode(self, node):
  39. return node
  40. def visit_CEnumDefNode(self, node):
  41. return node
  42. def visit_CStructOrUnionDefNode(self, node):
  43. return node
  44. class NormalizeTree(CythonTransform):
  45. """
  46. This transform fixes up a few things after parsing
  47. in order to make the parse tree more suitable for
  48. transforms.
  49. a) After parsing, blocks with only one statement will
  50. be represented by that statement, not by a StatListNode.
  51. When doing transforms this is annoying and inconsistent,
  52. as one cannot in general remove a statement in a consistent
  53. way and so on. This transform wraps any single statements
  54. in a StatListNode containing a single statement.
  55. b) The PassStatNode is a noop and serves no purpose beyond
  56. plugging such one-statement blocks; i.e., once parsed a
  57. ` "pass" can just as well be represented using an empty
  58. StatListNode. This means less special cases to worry about
  59. in subsequent transforms (one always checks to see if a
  60. StatListNode has no children to see if the block is empty).
  61. """
  62. def __init__(self, context):
  63. super(NormalizeTree, self).__init__(context)
  64. self.is_in_statlist = False
  65. self.is_in_expr = False
  66. def visit_ExprNode(self, node):
  67. stacktmp = self.is_in_expr
  68. self.is_in_expr = True
  69. self.visitchildren(node)
  70. self.is_in_expr = stacktmp
  71. return node
  72. def visit_StatNode(self, node, is_listcontainer=False):
  73. stacktmp = self.is_in_statlist
  74. self.is_in_statlist = is_listcontainer
  75. self.visitchildren(node)
  76. self.is_in_statlist = stacktmp
  77. if not self.is_in_statlist and not self.is_in_expr:
  78. return Nodes.StatListNode(pos=node.pos, stats=[node])
  79. else:
  80. return node
  81. def visit_StatListNode(self, node):
  82. self.is_in_statlist = True
  83. self.visitchildren(node)
  84. self.is_in_statlist = False
  85. return node
  86. def visit_ParallelAssignmentNode(self, node):
  87. return self.visit_StatNode(node, True)
  88. def visit_CEnumDefNode(self, node):
  89. return self.visit_StatNode(node, True)
  90. def visit_CStructOrUnionDefNode(self, node):
  91. return self.visit_StatNode(node, True)
  92. def visit_PassStatNode(self, node):
  93. """Eliminate PassStatNode"""
  94. if not self.is_in_statlist:
  95. return Nodes.StatListNode(pos=node.pos, stats=[])
  96. else:
  97. return []
  98. def visit_ExprStatNode(self, node):
  99. """Eliminate useless string literals"""
  100. if node.expr.is_string_literal:
  101. return self.visit_PassStatNode(node)
  102. else:
  103. return self.visit_StatNode(node)
  104. def visit_CDeclaratorNode(self, node):
  105. return node
  106. class PostParseError(CompileError): pass
  107. # error strings checked by unit tests, so define them
  108. ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
  109. ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
  110. ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
  111. class PostParse(ScopeTrackingTransform):
  112. """
  113. Basic interpretation of the parse tree, as well as validity
  114. checking that can be done on a very basic level on the parse
  115. tree (while still not being a problem with the basic syntax,
  116. as such).
  117. Specifically:
  118. - Default values to cdef assignments are turned into single
  119. assignments following the declaration (everywhere but in class
  120. bodies, where they raise a compile error)
  121. - Interpret some node structures into Python runtime values.
  122. Some nodes take compile-time arguments (currently:
  123. TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
  124. which should be interpreted. This happens in a general way
  125. and other steps should be taken to ensure validity.
  126. Type arguments cannot be interpreted in this way.
  127. - For __cythonbufferdefaults__ the arguments are checked for
  128. validity.
  129. TemplatedTypeNode has its directives interpreted:
  130. Any first positional argument goes into the "dtype" attribute,
  131. any "ndim" keyword argument goes into the "ndim" attribute and
  132. so on. Also it is checked that the directive combination is valid.
  133. - __cythonbufferdefaults__ attributes are parsed and put into the
  134. type information.
  135. Note: Currently Parsing.py does a lot of interpretation and
  136. reorganization that can be refactored into this transform
  137. if a more pure Abstract Syntax Tree is wanted.
  138. """
  139. def __init__(self, context):
  140. super(PostParse, self).__init__(context)
  141. self.specialattribute_handlers = {
  142. '__cythonbufferdefaults__' : self.handle_bufferdefaults
  143. }
  144. def visit_LambdaNode(self, node):
  145. # unpack a lambda expression into the corresponding DefNode
  146. collector = YieldNodeCollector()
  147. collector.visitchildren(node.result_expr)
  148. if collector.has_yield or collector.has_await or isinstance(node.result_expr, ExprNodes.YieldExprNode):
  149. body = Nodes.ExprStatNode(
  150. node.result_expr.pos, expr=node.result_expr)
  151. else:
  152. body = Nodes.ReturnStatNode(
  153. node.result_expr.pos, value=node.result_expr)
  154. node.def_node = Nodes.DefNode(
  155. node.pos, name=node.name,
  156. args=node.args, star_arg=node.star_arg,
  157. starstar_arg=node.starstar_arg,
  158. body=body, doc=None)
  159. self.visitchildren(node)
  160. return node
  161. def visit_GeneratorExpressionNode(self, node):
  162. # unpack a generator expression into the corresponding DefNode
  163. collector = YieldNodeCollector()
  164. collector.visitchildren(node.loop)
  165. node.def_node = Nodes.DefNode(
  166. node.pos, name=node.name, doc=None,
  167. args=[], star_arg=None, starstar_arg=None,
  168. body=node.loop, is_async_def=collector.has_await)
  169. self.visitchildren(node)
  170. return node
  171. def visit_ComprehensionNode(self, node):
  172. # enforce local scope also in Py2 for async generators (seriously, that's a Py3.6 feature...)
  173. if not node.has_local_scope:
  174. collector = YieldNodeCollector()
  175. collector.visitchildren(node.loop)
  176. if collector.has_await:
  177. node.has_local_scope = True
  178. self.visitchildren(node)
  179. return node
  180. # cdef variables
  181. def handle_bufferdefaults(self, decl):
  182. if not isinstance(decl.default, ExprNodes.DictNode):
  183. raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
  184. self.scope_node.buffer_defaults_node = decl.default
  185. self.scope_node.buffer_defaults_pos = decl.pos
  186. def visit_CVarDefNode(self, node):
  187. # This assumes only plain names and pointers are assignable on
  188. # declaration. Also, it makes use of the fact that a cdef decl
  189. # must appear before the first use, so we don't have to deal with
  190. # "i = 3; cdef int i = i" and can simply move the nodes around.
  191. try:
  192. self.visitchildren(node)
  193. stats = [node]
  194. newdecls = []
  195. for decl in node.declarators:
  196. declbase = decl
  197. while isinstance(declbase, Nodes.CPtrDeclaratorNode):
  198. declbase = declbase.base
  199. if isinstance(declbase, Nodes.CNameDeclaratorNode):
  200. if declbase.default is not None:
  201. if self.scope_type in ('cclass', 'pyclass', 'struct'):
  202. if isinstance(self.scope_node, Nodes.CClassDefNode):
  203. handler = self.specialattribute_handlers.get(decl.name)
  204. if handler:
  205. if decl is not declbase:
  206. raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
  207. handler(decl)
  208. continue # Remove declaration
  209. raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
  210. first_assignment = self.scope_type != 'module'
  211. stats.append(Nodes.SingleAssignmentNode(node.pos,
  212. lhs=ExprNodes.NameNode(node.pos, name=declbase.name),
  213. rhs=declbase.default, first=first_assignment))
  214. declbase.default = None
  215. newdecls.append(decl)
  216. node.declarators = newdecls
  217. return stats
  218. except PostParseError as e:
  219. # An error in a cdef clause is ok, simply remove the declaration
  220. # and try to move on to report more errors
  221. self.context.nonfatal_error(e)
  222. return None
  223. # Split parallel assignments (a,b = b,a) into separate partial
  224. # assignments that are executed rhs-first using temps. This
  225. # restructuring must be applied before type analysis so that known
  226. # types on rhs and lhs can be matched directly. It is required in
  227. # the case that the types cannot be coerced to a Python type in
  228. # order to assign from a tuple.
  229. def visit_SingleAssignmentNode(self, node):
  230. self.visitchildren(node)
  231. return self._visit_assignment_node(node, [node.lhs, node.rhs])
  232. def visit_CascadedAssignmentNode(self, node):
  233. self.visitchildren(node)
  234. return self._visit_assignment_node(node, node.lhs_list + [node.rhs])
  235. def _visit_assignment_node(self, node, expr_list):
  236. """Flatten parallel assignments into separate single
  237. assignments or cascaded assignments.
  238. """
  239. if sum([ 1 for expr in expr_list
  240. if expr.is_sequence_constructor or expr.is_string_literal ]) < 2:
  241. # no parallel assignments => nothing to do
  242. return node
  243. expr_list_list = []
  244. flatten_parallel_assignments(expr_list, expr_list_list)
  245. temp_refs = []
  246. eliminate_rhs_duplicates(expr_list_list, temp_refs)
  247. nodes = []
  248. for expr_list in expr_list_list:
  249. lhs_list = expr_list[:-1]
  250. rhs = expr_list[-1]
  251. if len(lhs_list) == 1:
  252. node = Nodes.SingleAssignmentNode(rhs.pos,
  253. lhs = lhs_list[0], rhs = rhs)
  254. else:
  255. node = Nodes.CascadedAssignmentNode(rhs.pos,
  256. lhs_list = lhs_list, rhs = rhs)
  257. nodes.append(node)
  258. if len(nodes) == 1:
  259. assign_node = nodes[0]
  260. else:
  261. assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
  262. if temp_refs:
  263. duplicates_and_temps = [ (temp.expression, temp)
  264. for temp in temp_refs ]
  265. sort_common_subsequences(duplicates_and_temps)
  266. for _, temp_ref in duplicates_and_temps[::-1]:
  267. assign_node = LetNode(temp_ref, assign_node)
  268. return assign_node
  269. def _flatten_sequence(self, seq, result):
  270. for arg in seq.args:
  271. if arg.is_sequence_constructor:
  272. self._flatten_sequence(arg, result)
  273. else:
  274. result.append(arg)
  275. return result
  276. def visit_DelStatNode(self, node):
  277. self.visitchildren(node)
  278. node.args = self._flatten_sequence(node, [])
  279. return node
  280. def visit_ExceptClauseNode(self, node):
  281. if node.is_except_as:
  282. # except-as must delete NameNode target at the end
  283. del_target = Nodes.DelStatNode(
  284. node.pos,
  285. args=[ExprNodes.NameNode(
  286. node.target.pos, name=node.target.name)],
  287. ignore_nonexisting=True)
  288. node.body = Nodes.StatListNode(
  289. node.pos,
  290. stats=[Nodes.TryFinallyStatNode(
  291. node.pos,
  292. body=node.body,
  293. finally_clause=Nodes.StatListNode(
  294. node.pos,
  295. stats=[del_target]))])
  296. self.visitchildren(node)
  297. return node
  298. def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
  299. """Replace rhs items by LetRefNodes if they appear more than once.
  300. Creates a sequence of LetRefNodes that set up the required temps
  301. and appends them to ref_node_sequence. The input list is modified
  302. in-place.
  303. """
  304. seen_nodes = set()
  305. ref_nodes = {}
  306. def find_duplicates(node):
  307. if node.is_literal or node.is_name:
  308. # no need to replace those; can't include attributes here
  309. # as their access is not necessarily side-effect free
  310. return
  311. if node in seen_nodes:
  312. if node not in ref_nodes:
  313. ref_node = LetRefNode(node)
  314. ref_nodes[node] = ref_node
  315. ref_node_sequence.append(ref_node)
  316. else:
  317. seen_nodes.add(node)
  318. if node.is_sequence_constructor:
  319. for item in node.args:
  320. find_duplicates(item)
  321. for expr_list in expr_list_list:
  322. rhs = expr_list[-1]
  323. find_duplicates(rhs)
  324. if not ref_nodes:
  325. return
  326. def substitute_nodes(node):
  327. if node in ref_nodes:
  328. return ref_nodes[node]
  329. elif node.is_sequence_constructor:
  330. node.args = list(map(substitute_nodes, node.args))
  331. return node
  332. # replace nodes inside of the common subexpressions
  333. for node in ref_nodes:
  334. if node.is_sequence_constructor:
  335. node.args = list(map(substitute_nodes, node.args))
  336. # replace common subexpressions on all rhs items
  337. for expr_list in expr_list_list:
  338. expr_list[-1] = substitute_nodes(expr_list[-1])
  339. def sort_common_subsequences(items):
  340. """Sort items/subsequences so that all items and subsequences that
  341. an item contains appear before the item itself. This is needed
  342. because each rhs item must only be evaluated once, so its value
  343. must be evaluated first and then reused when packing sequences
  344. that contain it.
  345. This implies a partial order, and the sort must be stable to
  346. preserve the original order as much as possible, so we use a
  347. simple insertion sort (which is very fast for short sequences, the
  348. normal case in practice).
  349. """
  350. def contains(seq, x):
  351. for item in seq:
  352. if item is x:
  353. return True
  354. elif item.is_sequence_constructor and contains(item.args, x):
  355. return True
  356. return False
  357. def lower_than(a,b):
  358. return b.is_sequence_constructor and contains(b.args, a)
  359. for pos, item in enumerate(items):
  360. key = item[1] # the ResultRefNode which has already been injected into the sequences
  361. new_pos = pos
  362. for i in range(pos-1, -1, -1):
  363. if lower_than(key, items[i][0]):
  364. new_pos = i
  365. if new_pos != pos:
  366. for i in range(pos, new_pos, -1):
  367. items[i] = items[i-1]
  368. items[new_pos] = item
  369. def unpack_string_to_character_literals(literal):
  370. chars = []
  371. pos = literal.pos
  372. stype = literal.__class__
  373. sval = literal.value
  374. sval_type = sval.__class__
  375. for char in sval:
  376. cval = sval_type(char)
  377. chars.append(stype(pos, value=cval, constant_result=cval))
  378. return chars
def flatten_parallel_assignments(input, output):
    """Split one (possibly cascaded) assignment into element-wise assignments.

    The input is a list of expression nodes, representing the LHSs
    and RHS of one (possibly cascaded) assignment statement. For
    sequence constructors, rearranges the matching parts of both
    sides into a list of equivalent assignments between the
    individual elements. This transformation is applied
    recursively, so that nested structures get matched as well.

    Every resulting assignment is appended to ``output`` as a list
    ``[lhs, ..., rhs]``; nothing is returned.  Appended in this order:
    lhs items that are unpacked as a whole together with the full rhs,
    then one (recursively flattened) assignment per rhs position, then
    the starred targets.  Unpacking errors are reported via error() and
    the offending lhs is emitted unflattened as ``[lhs, rhs]``.
    """
    rhs = input[-1]
    # Element-wise matching is only possible for a sequence-constructor rhs
    # or a UnicodeNode literal (split into characters below), and only if
    # at least one lhs is itself a sequence constructor.  Anything else is
    # passed through unchanged.
    if (not (rhs.is_sequence_constructor or isinstance(rhs, ExprNodes.UnicodeNode))
        or not sum([lhs.is_sequence_constructor for lhs in input[:-1]])):
        output.append(input)
        return

    # lhs items that take the whole rhs value (plain names, attributes, ...)
    complete_assignments = []

    if rhs.is_sequence_constructor:
        rhs_args = rhs.args
    elif rhs.is_string_literal:
        rhs_args = unpack_string_to_character_literals(rhs)

    rhs_size = len(rhs_args)
    # lhs_targets[i] collects every lhs element that receives rhs_args[i]
    lhs_targets = [[] for _ in range(rhs_size)]
    # [target, ListNode] pairs produced for starred targets
    starred_assignments = []
    for lhs in input[:-1]:
        if not lhs.is_sequence_constructor:
            if lhs.is_starred:
                error(lhs.pos, "starred assignment target must be in a list or tuple")
            complete_assignments.append(lhs)
            continue
        lhs_size = len(lhs.args)
        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
        if starred_targets > 1:
            error(lhs.pos, "more than 1 starred expression in assignment")
            output.append([lhs,rhs])
            continue
        elif lhs_size - starred_targets > rhs_size:
            error(lhs.pos, "need more than %d value%s to unpack"
                  % (rhs_size, (rhs_size != 1) and 's' or ''))
            output.append([lhs,rhs])
            continue
        elif starred_targets:
            # exactly one starred target and enough values for the others
            map_starred_assignment(lhs_targets, starred_assignments,
                                   lhs.args, rhs_args)
        elif lhs_size < rhs_size:
            error(lhs.pos, "too many values to unpack (expected %d, got %d)"
                  % (lhs_size, rhs_size))
            output.append([lhs,rhs])
            continue
        else:
            # sizes match: pair lhs elements with rhs positions one-to-one
            for targets, expr in zip(lhs_targets, lhs.args):
                targets.append(expr)

    if complete_assignments:
        complete_assignments.append(rhs)
        output.append(complete_assignments)

    # recursively flatten partial assignments
    # (note: this loop rebinds 'rhs' to the individual rhs items)
    for cascade, rhs in zip(lhs_targets, rhs_args):
        if cascade:
            cascade.append(rhs)
            flatten_parallel_assignments(cascade, output)

    # recursively flatten starred assignments
    for cascade in starred_assignments:
        if cascade[0].is_sequence_constructor:
            flatten_parallel_assignments(cascade, output)
        else:
            output.append(cascade)
  441. def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
  442. # Appends the fixed-position LHS targets to the target list that
  443. # appear left and right of the starred argument.
  444. #
  445. # The starred_assignments list receives a new tuple
  446. # (lhs_target, rhs_values_list) that maps the remaining arguments
  447. # (those that match the starred target) to a list.
  448. # left side of the starred target
  449. for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
  450. if expr.is_starred:
  451. starred = i
  452. lhs_remaining = len(lhs_args) - i - 1
  453. break
  454. targets.append(expr)
  455. else:
  456. raise InternalError("no starred arg found when splitting starred assignment")
  457. # right side of the starred target
  458. for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
  459. lhs_args[starred + 1:])):
  460. targets.append(expr)
  461. # the starred target itself, must be assigned a (potentially empty) list
  462. target = lhs_args[starred].target # unpack starred node
  463. starred_rhs = rhs_args[starred:]
  464. if lhs_remaining:
  465. starred_rhs = starred_rhs[:-lhs_remaining]
  466. if starred_rhs:
  467. pos = starred_rhs[0].pos
  468. else:
  469. pos = target.pos
  470. starred_assignments.append([
  471. target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
  472. class PxdPostParse(CythonTransform, SkipDeclarations):
  473. """
  474. Basic interpretation/validity checking that should only be
  475. done on pxd trees.
  476. A lot of this checking currently happens in the parser; but
  477. what is listed below happens here.
  478. - "def" functions are let through only if they fill the
  479. getbuffer/releasebuffer slots
  480. - cdef functions are let through only if they are on the
  481. top level and are declared "inline"
  482. """
  483. ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
  484. ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"
  485. def __call__(self, node):
  486. self.scope_type = 'pxd'
  487. return super(PxdPostParse, self).__call__(node)
  488. def visit_CClassDefNode(self, node):
  489. old = self.scope_type
  490. self.scope_type = 'cclass'
  491. self.visitchildren(node)
  492. self.scope_type = old
  493. return node
  494. def visit_FuncDefNode(self, node):
  495. # FuncDefNode always come with an implementation (without
  496. # an imp they are CVarDefNodes..)
  497. err = self.ERR_INLINE_ONLY
  498. if (isinstance(node, Nodes.DefNode) and self.scope_type == 'cclass'
  499. and node.name in ('__getbuffer__', '__releasebuffer__')):
  500. err = None # allow these slots
  501. if isinstance(node, Nodes.CFuncDefNode):
  502. if (u'inline' in node.modifiers and
  503. self.scope_type in ('pxd', 'cclass')):
  504. node.inline_in_pxd = True
  505. if node.visibility != 'private':
  506. err = self.ERR_NOGO_WITH_INLINE % node.visibility
  507. elif node.api:
  508. err = self.ERR_NOGO_WITH_INLINE % 'api'
  509. else:
  510. err = None # allow inline function
  511. else:
  512. err = self.ERR_INLINE_ONLY
  513. if err:
  514. self.context.nonfatal_error(PostParseError(node.pos, err))
  515. return None
  516. else:
  517. return node
  518. class TrackNumpyAttributes(VisitorTransform, SkipDeclarations):
  519. # TODO: Make name handling as good as in InterpretCompilerDirectives() below - probably best to merge the two.
  520. def __init__(self):
  521. super(TrackNumpyAttributes, self).__init__()
  522. self.numpy_module_names = set()
  523. def visit_CImportStatNode(self, node):
  524. if node.module_name == u"numpy":
  525. self.numpy_module_names.add(node.as_name or u"numpy")
  526. return node
  527. def visit_AttributeNode(self, node):
  528. self.visitchildren(node)
  529. obj = node.obj
  530. if (obj.is_name and obj.name in self.numpy_module_names) or obj.is_numpy_attribute:
  531. node.is_numpy_attribute = True
  532. return node
  533. visit_Node = VisitorTransform.recurse_to_children
  534. class InterpretCompilerDirectives(CythonTransform):
  535. """
  536. After parsing, directives can be stored in a number of places:
  537. - #cython-comments at the top of the file (stored in ModuleNode)
  538. - Command-line arguments overriding these
  539. - @cython.directivename decorators
  540. - with cython.directivename: statements
  541. This transform is responsible for interpreting these various sources
  542. and store the directive in two ways:
  543. - Set the directives attribute of the ModuleNode for global directives.
  544. - Use a CompilerDirectivesNode to override directives for a subtree.
  545. (The first one is primarily to not have to modify with the tree
  546. structure, so that ModuleNode stay on top.)
  547. The directives are stored in dictionaries from name to value in effect.
  548. Each such dictionary is always filled in for all possible directives,
  549. using default values where no value is given by the user.
  550. The available directives are controlled in Options.py.
  551. Note that we have to run this prior to analysis, and so some minor
  552. duplication of functionality has to occur: We manually track cimports
  553. and which names the "cython" module may have been imported to.
  554. """
  555. unop_method_nodes = {
  556. 'typeof': ExprNodes.TypeofNode,
  557. 'operator.address': ExprNodes.AmpersandNode,
  558. 'operator.dereference': ExprNodes.DereferenceNode,
  559. 'operator.preincrement' : ExprNodes.inc_dec_constructor(True, '++'),
  560. 'operator.predecrement' : ExprNodes.inc_dec_constructor(True, '--'),
  561. 'operator.postincrement': ExprNodes.inc_dec_constructor(False, '++'),
  562. 'operator.postdecrement': ExprNodes.inc_dec_constructor(False, '--'),
  563. 'operator.typeid' : ExprNodes.TypeidNode,
  564. # For backwards compatibility.
  565. 'address': ExprNodes.AmpersandNode,
  566. }
  567. binop_method_nodes = {
  568. 'operator.comma' : ExprNodes.c_binop_constructor(','),
  569. }
  570. special_methods = set(['declare', 'union', 'struct', 'typedef',
  571. 'sizeof', 'cast', 'pointer', 'compiled',
  572. 'NULL', 'fused_type', 'parallel'])
  573. special_methods.update(unop_method_nodes)
  574. valid_parallel_directives = set([
  575. "parallel",
  576. "prange",
  577. "threadid",
  578. #"threadsavailable",
  579. ])
  580. def __init__(self, context, compilation_directive_defaults):
  581. super(InterpretCompilerDirectives, self).__init__(context)
  582. self.cython_module_names = set()
  583. self.directive_names = {'staticmethod': 'staticmethod'}
  584. self.parallel_directives = {}
  585. directives = copy.deepcopy(Options.get_directive_defaults())
  586. for key, value in compilation_directive_defaults.items():
  587. directives[_unicode(key)] = copy.deepcopy(value)
  588. self.directives = directives
  589. def check_directive_scope(self, pos, directive, scope):
  590. legal_scopes = Options.directive_scopes.get(directive, None)
  591. if legal_scopes and scope not in legal_scopes:
  592. self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
  593. 'is not allowed in %s scope' % (directive, scope)))
  594. return False
  595. else:
  596. if directive not in Options.directive_types:
  597. error(pos, "Invalid directive: '%s'." % (directive,))
  598. return True
  599. # Set up processing and handle the cython: comments.
  600. def visit_ModuleNode(self, node):
  601. for key in sorted(node.directive_comments):
  602. if not self.check_directive_scope(node.pos, key, 'module'):
  603. self.wrong_scope_error(node.pos, key, 'module')
  604. del node.directive_comments[key]
  605. self.module_scope = node.scope
  606. self.directives.update(node.directive_comments)
  607. node.directives = self.directives
  608. node.parallel_directives = self.parallel_directives
  609. self.visitchildren(node)
  610. node.cython_module_names = self.cython_module_names
  611. return node
  612. # The following four functions track imports and cimports that
  613. # begin with "cython"
  614. def is_cython_directive(self, name):
  615. return (name in Options.directive_types or
  616. name in self.special_methods or
  617. PyrexTypes.parse_basic_type(name))
  618. def is_parallel_directive(self, full_name, pos):
  619. """
  620. Checks to see if fullname (e.g. cython.parallel.prange) is a valid
  621. parallel directive. If it is a star import it also updates the
  622. parallel_directives.
  623. """
  624. result = (full_name + ".").startswith("cython.parallel.")
  625. if result:
  626. directive = full_name.split('.')
  627. if full_name == u"cython.parallel":
  628. self.parallel_directives[u"parallel"] = u"cython.parallel"
  629. elif full_name == u"cython.parallel.*":
  630. for name in self.valid_parallel_directives:
  631. self.parallel_directives[name] = u"cython.parallel.%s" % name
  632. elif (len(directive) != 3 or
  633. directive[-1] not in self.valid_parallel_directives):
  634. error(pos, "No such directive: %s" % full_name)
  635. self.module_scope.use_utility_code(
  636. UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))
  637. return result
  638. def visit_CImportStatNode(self, node):
  639. if node.module_name == u"cython":
  640. self.cython_module_names.add(node.as_name or u"cython")
  641. elif node.module_name.startswith(u"cython."):
  642. if node.module_name.startswith(u"cython.parallel."):
  643. error(node.pos, node.module_name + " is not a module")
  644. if node.module_name == u"cython.parallel":
  645. if node.as_name and node.as_name != u"cython":
  646. self.parallel_directives[node.as_name] = node.module_name
  647. else:
  648. self.cython_module_names.add(u"cython")
  649. self.parallel_directives[
  650. u"cython.parallel"] = node.module_name
  651. self.module_scope.use_utility_code(
  652. UtilityCode.load_cached("InitThreads", "ModuleSetupCode.c"))
  653. elif node.as_name:
  654. self.directive_names[node.as_name] = node.module_name[7:]
  655. else:
  656. self.cython_module_names.add(u"cython")
  657. # if this cimport was a compiler directive, we don't
  658. # want to leave the cimport node sitting in the tree
  659. return None
  660. return node
  661. def visit_FromCImportStatNode(self, node):
  662. if not node.relative_level and (
  663. node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
  664. submodule = (node.module_name + u".")[7:]
  665. newimp = []
  666. for pos, name, as_name, kind in node.imported_names:
  667. full_name = submodule + name
  668. qualified_name = u"cython." + full_name
  669. if self.is_parallel_directive(qualified_name, node.pos):
  670. # from cython cimport parallel, or
  671. # from cython.parallel cimport parallel, prange, ...
  672. self.parallel_directives[as_name or name] = qualified_name
  673. elif self.is_cython_directive(full_name):
  674. self.directive_names[as_name or name] = full_name
  675. if kind is not None:
  676. self.context.nonfatal_error(PostParseError(pos,
  677. "Compiler directive imports must be plain imports"))
  678. else:
  679. newimp.append((pos, name, as_name, kind))
  680. if not newimp:
  681. return None
  682. node.imported_names = newimp
  683. return node
  684. def visit_FromImportStatNode(self, node):
  685. if (node.module.module_name.value == u"cython") or \
  686. node.module.module_name.value.startswith(u"cython."):
  687. submodule = (node.module.module_name.value + u".")[7:]
  688. newimp = []
  689. for name, name_node in node.items:
  690. full_name = submodule + name
  691. qualified_name = u"cython." + full_name
  692. if self.is_parallel_directive(qualified_name, node.pos):
  693. self.parallel_directives[name_node.name] = qualified_name
  694. elif self.is_cython_directive(full_name):
  695. self.directive_names[name_node.name] = full_name
  696. else:
  697. newimp.append((name, name_node))
  698. if not newimp:
  699. return None
  700. node.items = newimp
  701. return node
  702. def visit_SingleAssignmentNode(self, node):
  703. if isinstance(node.rhs, ExprNodes.ImportNode):
  704. module_name = node.rhs.module_name.value
  705. is_parallel = (module_name + u".").startswith(u"cython.parallel.")
  706. if module_name != u"cython" and not is_parallel:
  707. return node
  708. module_name = node.rhs.module_name.value
  709. as_name = node.lhs.name
  710. node = Nodes.CImportStatNode(node.pos,
  711. module_name = module_name,
  712. as_name = as_name)
  713. node = self.visit_CImportStatNode(node)
  714. else:
  715. self.visitchildren(node)
  716. return node
  717. def visit_NameNode(self, node):
  718. if node.name in self.cython_module_names:
  719. node.is_cython_module = True
  720. else:
  721. directive = self.directive_names.get(node.name)
  722. if directive is not None:
  723. node.cython_attribute = directive
  724. return node
  725. def visit_NewExprNode(self, node):
  726. self.visit(node.cppclass)
  727. self.visitchildren(node)
  728. return node
    def try_to_parse_directives(self, node):
        # If node is the contents of a directive (in a with statement or
        # decorator), returns a list of (directivename, value) pairs.
        # Otherwise, returns None
        if isinstance(node, ExprNodes.CallNode):
            # Called directive, e.g. "cython.boundscheck(False)".
            self.visit(node.function)
            optname = node.function.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype:
                    args, kwds = node.explicit_args_kwds()
                    directives = []
                    key_value_pairs = []
                    if kwds is not None and directivetype is not dict:
                        # Keywords naming a known sub-directive ("<optname>.<key>")
                        # become separate directives; the remaining keywords stay
                        # as arguments of optname itself.
                        for keyvalue in kwds.key_value_pairs:
                            key, value = keyvalue
                            sub_optname = "%s.%s" % (optname, key.value)
                            if Options.directive_types.get(sub_optname):
                                directives.append(self.try_to_parse_directive(sub_optname, [value], None, keyvalue.pos))
                            else:
                                key_value_pairs.append(keyvalue)
                        if not key_value_pairs:
                            kwds = None
                        else:
                            kwds.key_value_pairs = key_value_pairs
                        # Only sub-directives were given: nothing is left to
                        # parse for optname itself.
                        if directives and not kwds and not args:
                            return directives
                    directives.append(self.try_to_parse_directive(optname, args, kwds, node.function.pos))
                    return directives
        elif isinstance(node, (ExprNodes.AttributeNode, ExprNodes.NameNode)):
            # Uncalled directive name or attribute.
            self.visit(node)
            optname = node.as_cython_attribute()
            if optname:
                directivetype = Options.directive_types.get(optname)
                if directivetype is bool:
                    # A bare boolean directive is treated as "directive(True)".
                    arg = ExprNodes.BoolNode(node.pos, value=True)
                    return [self.try_to_parse_directive(optname, [arg], None, node.pos)]
                elif directivetype is None:
                    return [(optname, None)]
                else:
                    raise PostParseError(
                        node.pos, "The '%s' directive should be used as a function call." % optname)
        return None
  772. def try_to_parse_directive(self, optname, args, kwds, pos):
  773. if optname == 'np_pythran' and not self.context.cpp:
  774. raise PostParseError(pos, 'The %s directive can only be used in C++ mode.' % optname)
  775. elif optname == 'exceptval':
  776. # default: exceptval(None, check=True)
  777. arg_error = len(args) > 1
  778. check = True
  779. if kwds and kwds.key_value_pairs:
  780. kw = kwds.key_value_pairs[0]
  781. if (len(kwds.key_value_pairs) == 1 and
  782. kw.key.is_string_literal and kw.key.value == 'check' and
  783. isinstance(kw.value, ExprNodes.BoolNode)):
  784. check = kw.value.value
  785. else:
  786. arg_error = True
  787. if arg_error:
  788. raise PostParseError(
  789. pos, 'The exceptval directive takes 0 or 1 positional arguments and the boolean keyword "check"')
  790. return ('exceptval', (args[0] if args else None, check))
  791. directivetype = Options.directive_types.get(optname)
  792. if len(args) == 1 and isinstance(args[0], ExprNodes.NoneNode):
  793. return optname, Options.get_directive_defaults()[optname]
  794. elif directivetype is bool:
  795. if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.BoolNode):
  796. raise PostParseError(pos,
  797. 'The %s directive takes one compile-time boolean argument' % optname)
  798. return (optname, args[0].value)
  799. elif directivetype is int:
  800. if kwds is not None or len(args) != 1 or not isinstance(args[0], ExprNodes.IntNode):
  801. raise PostParseError(pos,
  802. 'The %s directive takes one compile-time integer argument' % optname)
  803. return (optname, int(args[0].value))
  804. elif directivetype is str:
  805. if kwds is not None or len(args) != 1 or not isinstance(
  806. args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
  807. raise PostParseError(pos,
  808. 'The %s directive takes one compile-time string argument' % optname)
  809. return (optname, str(args[0].value))
  810. elif directivetype is type:
  811. if kwds is not None or len(args) != 1:
  812. raise PostParseError(pos,
  813. 'The %s directive takes one type argument' % optname)
  814. return (optname, args[0])
  815. elif directivetype is dict:
  816. if len(args) != 0:
  817. raise PostParseError(pos,
  818. 'The %s directive takes no prepositional arguments' % optname)
  819. return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
  820. elif directivetype is list:
  821. if kwds and len(kwds.key_value_pairs) != 0:
  822. raise PostParseError(pos,
  823. 'The %s directive takes no keyword arguments' % optname)
  824. return optname, [ str(arg.value) for arg in args ]
  825. elif callable(directivetype):
  826. if kwds is not None or len(args) != 1 or not isinstance(
  827. args[0], (ExprNodes.StringNode, ExprNodes.UnicodeNode)):
  828. raise PostParseError(pos,
  829. 'The %s directive takes one compile-time string argument' % optname)
  830. return (optname, directivetype(optname, str(args[0].value)))
  831. else:
  832. assert False
  833. def visit_with_directives(self, node, directives):
  834. if not directives:
  835. return self.visit_Node(node)
  836. old_directives = self.directives
  837. new_directives = dict(old_directives)
  838. new_directives.update(directives)
  839. if new_directives == old_directives:
  840. return self.visit_Node(node)
  841. self.directives = new_directives
  842. retbody = self.visit_Node(node)
  843. self.directives = old_directives
  844. if not isinstance(retbody, Nodes.StatListNode):
  845. retbody = Nodes.StatListNode(node.pos, stats=[retbody])
  846. return Nodes.CompilerDirectivesNode(
  847. pos=retbody.pos, body=retbody, directives=new_directives)
  848. # Handle decorators
  849. def visit_FuncDefNode(self, node):
  850. directives = self._extract_directives(node, 'function')
  851. return self.visit_with_directives(node, directives)
  852. def visit_CVarDefNode(self, node):
  853. directives = self._extract_directives(node, 'function')
  854. for name, value in directives.items():
  855. if name == 'locals':
  856. node.directive_locals = value
  857. elif name not in ('final', 'staticmethod'):
  858. self.context.nonfatal_error(PostParseError(
  859. node.pos,
  860. "Cdef functions can only take cython.locals(), "
  861. "staticmethod, or final decorators, got %s." % name))
  862. return self.visit_with_directives(node, directives)
  863. def visit_CClassDefNode(self, node):
  864. directives = self._extract_directives(node, 'cclass')
  865. return self.visit_with_directives(node, directives)
  866. def visit_CppClassNode(self, node):
  867. directives = self._extract_directives(node, 'cppclass')
  868. return self.visit_with_directives(node, directives)
  869. def visit_PyClassDefNode(self, node):
  870. directives = self._extract_directives(node, 'class')
  871. return self.visit_with_directives(node, directives)
  872. def _extract_directives(self, node, scope_name):
  873. if not node.decorators:
  874. return {}
  875. # Split the decorators into two lists -- real decorators and directives
  876. directives = []
  877. realdecs = []
  878. both = []
  879. # Decorators coming first take precedence.
  880. for dec in node.decorators[::-1]:
  881. new_directives = self.try_to_parse_directives(dec.decorator)
  882. if new_directives is not None:
  883. for directive in new_directives:
  884. if self.check_directive_scope(node.pos, directive[0], scope_name):
  885. name, value = directive
  886. if self.directives.get(name, object()) != value:
  887. directives.append(directive)
  888. if directive[0] == 'staticmethod':
  889. both.append(dec)
  890. # Adapt scope type based on decorators that change it.
  891. if directive[0] == 'cclass' and scope_name == 'class':
  892. scope_name = 'cclass'
  893. else:
  894. realdecs.append(dec)
  895. if realdecs and (scope_name == 'cclass' or
  896. isinstance(node, (Nodes.CFuncDefNode, Nodes.CClassDefNode, Nodes.CVarDefNode))):
  897. raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
  898. node.decorators = realdecs[::-1] + both[::-1]
  899. # merge or override repeated directives
  900. optdict = {}
  901. for directive in directives:
  902. name, value = directive
  903. if name in optdict:
  904. old_value = optdict[name]
  905. # keywords and arg lists can be merged, everything
  906. # else overrides completely
  907. if isinstance(old_value, dict):
  908. old_value.update(value)
  909. elif isinstance(old_value, list):
  910. old_value.extend(value)
  911. else:
  912. optdict[name] = value
  913. else:
  914. optdict[name] = value
  915. return optdict
    # Handle with-statements
    def visit_WithStatNode(self, node):
        # Turn "with <directive>:" into a directive scope around the body.
        # "with nogil" / "with gil" become a GILStatNode instead.  A directive
        # used with an "as" target is reported as a non-fatal error (once per
        # parsed directive) and the statement is then visited as-is.
        # Managers that are not directives are visited normally.
        directive_dict = {}
        for directive in self.try_to_parse_directives(node.manager) or []:
            if directive is not None:
                if node.target is not None:
                    self.context.nonfatal_error(
                        PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
                else:
                    name, value = directive
                    if name in ('nogil', 'gil'):
                        # special case: in pure mode, "with nogil" spells "with cython.nogil"
                        node = Nodes.GILStatNode(node.pos, state = name, body = node.body)
                        return self.visit_Node(node)
                    if self.check_directive_scope(node.pos, name, 'with statement'):
                        directive_dict[name] = value
        if directive_dict:
            # The with statement itself is dropped; only its body is kept,
            # visited under the collected directives.
            return self.visit_with_directives(node.body, directive_dict)
        return self.visit_Node(node)
  935. class ParallelRangeTransform(CythonTransform, SkipDeclarations):
  936. """
  937. Transform cython.parallel stuff. The parallel_directives come from the
  938. module node, set there by InterpretCompilerDirectives.
  939. x = cython.parallel.threadavailable() -> ParallelThreadAvailableNode
  940. with nogil, cython.parallel.parallel(): -> ParallelWithBlockNode
  941. print cython.parallel.threadid() -> ParallelThreadIdNode
  942. for i in cython.parallel.prange(...): -> ParallelRangeNode
  943. ...
  944. """
  945. # a list of names, maps 'cython.parallel.prange' in the code to
  946. # ['cython', 'parallel', 'prange']
  947. parallel_directive = None
  948. # Indicates whether a namenode in an expression is the cython module
  949. namenode_is_cython_module = False
  950. # Keep track of whether we are the context manager of a 'with' statement
  951. in_context_manager_section = False
  952. # One of 'prange' or 'with parallel'. This is used to disallow closely
  953. # nested 'with parallel:' blocks
  954. state = None
  955. directive_to_node = {
  956. u"cython.parallel.parallel": Nodes.ParallelWithBlockNode,
  957. # u"cython.parallel.threadsavailable": ExprNodes.ParallelThreadsAvailableNode,
  958. u"cython.parallel.threadid": ExprNodes.ParallelThreadIdNode,
  959. u"cython.parallel.prange": Nodes.ParallelRangeNode,
  960. }
  961. def node_is_parallel_directive(self, node):
  962. return node.name in self.parallel_directives or node.is_cython_module
  963. def get_directive_class_node(self, node):
  964. """
  965. Figure out which parallel directive was used and return the associated
  966. Node class.
  967. E.g. for a cython.parallel.prange() call we return ParallelRangeNode
  968. """
  969. if self.namenode_is_cython_module:
  970. directive = '.'.join(self.parallel_directive)
  971. else:
  972. directive = self.parallel_directives[self.parallel_directive[0]]
  973. directive = '%s.%s' % (directive,
  974. '.'.join(self.parallel_directive[1:]))
  975. directive = directive.rstrip('.')
  976. cls = self.directive_to_node.get(directive)
  977. if cls is None and not (self.namenode_is_cython_module and
  978. self.parallel_directive[0] != 'parallel'):
  979. error(node.pos, "Invalid directive: %s" % directive)
  980. self.namenode_is_cython_module = False
  981. self.parallel_directive = None
  982. return cls
  983. def visit_ModuleNode(self, node):
  984. """
  985. If any parallel directives were imported, copy them over and visit
  986. the AST
  987. """
  988. if node.parallel_directives:
  989. self.parallel_directives = node.parallel_directives
  990. return self.visit_Node(node)
  991. # No parallel directives were imported, so they can't be used :)
  992. return node
  993. def visit_NameNode(self, node):
  994. if self.node_is_parallel_directive(node):
  995. self.parallel_directive = [node.name]
  996. self.namenode_is_cython_module = node.is_cython_module
  997. return node
  998. def visit_AttributeNode(self, node):
  999. self.visitchildren(node)
  1000. if self.parallel_directive:
  1001. self.parallel_directive.append(node.attribute)
  1002. return node
  1003. def visit_CallNode(self, node):
  1004. self.visit(node.function)
  1005. if not self.parallel_directive:
  1006. self.visitchildren(node, exclude=('function',))
  1007. return node
  1008. # We are a parallel directive, replace this node with the
  1009. # corresponding ParallelSomethingSomething node
  1010. if isinstance(node, ExprNodes.GeneralCallNode):
  1011. args = node.positional_args.args
  1012. kwargs = node.keyword_args
  1013. else:
  1014. args = node.args
  1015. kwargs = {}
  1016. parallel_directive_class = self.get_directive_class_node(node)
  1017. if parallel_directive_class:
  1018. # Note: in case of a parallel() the body is set by
  1019. # visit_WithStatNode
  1020. node = parallel_directive_class(node.pos, args=args, kwargs=kwargs)
  1021. return node
  1022. def visit_WithStatNode(self, node):
  1023. "Rewrite with cython.parallel.parallel() blocks"
  1024. newnode = self.visit(node.manager)
  1025. if isinstance(newnode, Nodes.ParallelWithBlockNode):
  1026. if self.state == 'parallel with':
  1027. error(node.manager.pos,
  1028. "Nested parallel with blocks are disallowed")
  1029. self.state = 'parallel with'
  1030. body = self.visit(node.body)
  1031. self.state = None
  1032. newnode.body = body
  1033. return newnode
  1034. elif self.parallel_directive:
  1035. parallel_directive_class = self.get_directive_class_node(node)
  1036. if not parallel_directive_class:
  1037. # There was an error, stop here and now
  1038. return None
  1039. if parallel_directive_class is Nodes.ParallelWithBlockNode:
  1040. error(node.pos, "The parallel directive must be called")
  1041. return None
  1042. node.body = self.visit(node.body)
  1043. return node
  1044. def visit_ForInStatNode(self, node):
  1045. "Rewrite 'for i in cython.parallel.prange(...):'"
  1046. self.visit(node.iterator)
  1047. self.visit(node.target)
  1048. in_prange = isinstance(node.iterator.sequence,
  1049. Nodes.ParallelRangeNode)
  1050. previous_state = self.state
  1051. if in_prange:
  1052. # This will replace the entire ForInStatNode, so copy the
  1053. # attributes
  1054. parallel_range_node = node.iterator.sequence
  1055. parallel_range_node.target = node.target
  1056. parallel_range_node.body = node.body
  1057. parallel_range_node.else_clause = node.else_clause
  1058. node = parallel_range_node
  1059. if not isinstance(node.target, ExprNodes.NameNode):
  1060. error(node.target.pos,
  1061. "Can only iterate over an iteration variable")
  1062. self.state = 'prange'
  1063. self.visit(node.body)
  1064. self.state = previous_state
  1065. self.visit(node.else_clause)
  1066. return node
  1067. def visit(self, node):
  1068. "Visit a node that may be None"
  1069. if node is not None:
  1070. return super(ParallelRangeTransform, self).visit(node)
  1071. class WithTransform(CythonTransform, SkipDeclarations):
  1072. def visit_WithStatNode(self, node):
  1073. self.visitchildren(node, 'body')
  1074. pos = node.pos
  1075. is_async = node.is_async
  1076. body, target, manager = node.body, node.target, node.manager
  1077. node.enter_call = ExprNodes.SimpleCallNode(
  1078. pos, function=ExprNodes.AttributeNode(
  1079. pos, obj=ExprNodes.CloneNode(manager),
  1080. attribute=EncodedString('__aenter__' if is_async else '__enter__'),
  1081. is_special_lookup=True),
  1082. args=[],
  1083. is_temp=True)
  1084. if is_async:
  1085. node.enter_call = ExprNodes.AwaitExprNode(pos, arg=node.enter_call)
  1086. if target is not None:
  1087. body = Nodes.StatListNode(
  1088. pos, stats=[
  1089. Nodes.WithTargetAssignmentStatNode(
  1090. pos, lhs=target, with_node=node),
  1091. body])
  1092. excinfo_target = ExprNodes.TupleNode(pos, slow=True, args=[
  1093. ExprNodes.ExcValueNode(pos) for _ in range(3)])
  1094. except_clause = Nodes.ExceptClauseNode(
  1095. pos, body=Nodes.IfStatNode(
  1096. pos, if_clauses=[
  1097. Nodes.IfClauseNode(
  1098. pos, condition=ExprNodes.NotNode(
  1099. pos, operand=ExprNodes.WithExitCallNode(
  1100. pos, with_stat=node,
  1101. test_if_run=False,
  1102. args=excinfo_target,
  1103. await_expr=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
  1104. body=Nodes.ReraiseStatNode(pos),
  1105. ),
  1106. ],
  1107. else_clause=None),
  1108. pattern=None,
  1109. target=None,
  1110. excinfo_target=excinfo_target,
  1111. )
  1112. node.body = Nodes.TryFinallyStatNode(
  1113. pos, body=Nodes.TryExceptStatNode(
  1114. pos, body=body,
  1115. except_clauses=[except_clause],
  1116. else_clause=None,
  1117. ),
  1118. finally_clause=Nodes.ExprStatNode(
  1119. pos, expr=ExprNodes.WithExitCallNode(
  1120. pos, with_stat=node,
  1121. test_if_run=True,
  1122. args=ExprNodes.TupleNode(
  1123. pos, args=[ExprNodes.NoneNode(pos) for _ in range(3)]),
  1124. await_expr=ExprNodes.AwaitExprNode(pos, arg=None) if is_async else None)),
  1125. handle_error_case=False,
  1126. )
  1127. return node
  1128. def visit_ExprNode(self, node):
  1129. # With statements are never inside expressions.
  1130. return node
  1131. class DecoratorTransform(ScopeTrackingTransform, SkipDeclarations):
  1132. """
  1133. Transforms method decorators in cdef classes into nested calls or properties.
  1134. Python-style decorator properties are transformed into a PropertyNode
  1135. with up to the three getter, setter and deleter DefNodes.
  1136. The functional style isn't supported yet.
  1137. """
  1138. _properties = None
  1139. _map_property_attribute = {
  1140. 'getter': '__get__',
  1141. 'setter': '__set__',
  1142. 'deleter': '__del__',
  1143. }.get
  1144. def visit_CClassDefNode(self, node):
  1145. if self._properties is None:
  1146. self._properties = []
  1147. self._properties.append({})
  1148. super(DecoratorTransform, self).visit_CClassDefNode(node)
  1149. self._properties.pop()
  1150. return node
  1151. def visit_PropertyNode(self, node):
  1152. # Low-level warning for other code until we can convert all our uses over.
  1153. level = 2 if isinstance(node.pos[0], str) else 0
  1154. warning(node.pos, "'property %s:' syntax is deprecated, use '@property'" % node.name, level)
  1155. return node
  1156. def visit_DefNode(self, node):
  1157. scope_type = self.scope_type
  1158. node = self.visit_FuncDefNode(node)
  1159. if scope_type != 'cclass' or not node.decorators:
  1160. return node
  1161. # transform @property decorators
  1162. properties = self._properties[-1]
  1163. for decorator_node in node.decorators[::-1]:
  1164. decorator = decorator_node.decorator
  1165. if decorator.is_name and decorator.name == 'property':
  1166. if len(node.decorators) > 1:
  1167. return self._reject_decorated_property(node, decorator_node)
  1168. name = node.name
  1169. node.name = EncodedString('__get__')
  1170. node.decorators.remove(decorator_node)
  1171. stat_list = [node]
  1172. if name in properties:
  1173. prop = properties[name]
  1174. prop.pos = node.pos
  1175. prop.doc = node.doc
  1176. prop.body.stats = stat_list
  1177. return []
  1178. prop = Nodes.PropertyNode(node.pos, name=name)
  1179. prop.doc = node.doc
  1180. prop.body = Nodes.StatListNode(node.pos, stats=stat_list)
  1181. properties[name] = prop
  1182. return [prop]
  1183. elif decorator.is_attribute and decorator.obj.name in properties:
  1184. handler_name = self._map_property_attribute(decorator.attribute)
  1185. if handler_name:
  1186. if decorator.obj.name != node.name:
  1187. # CPython does not generate an error or warning, but not something useful either.
  1188. error(decorator_node.pos,
  1189. "Mismatching property names, expected '%s', got '%s'" % (
  1190. decorator.obj.name, node.name))
  1191. elif len(node.decorators) > 1:
  1192. return self._reject_decorated_property(node, decorator_node)
  1193. else:
  1194. return self._add_to_property(properties, node, handler_name, decorator_node)
  1195. # we clear node.decorators, so we need to set the
  1196. # is_staticmethod/is_classmethod attributes now
  1197. for decorator in node.decorators:
  1198. func = decorator.decorator
  1199. if func.is_name:
  1200. node.is_classmethod |= func.name == 'classmethod'
  1201. node.is_staticmethod |= func.name == 'staticmethod'
  1202. # transform normal decorators
  1203. decs = node.decorators
  1204. node.decorators = None
  1205. return self.chain_decorators(node, decs, node.name)
  1206. @staticmethod
  1207. def _reject_decorated_property(node, decorator_node):
  1208. # restrict transformation to outermost decorator as wrapped properties will probably not work
  1209. for deco in node.decorators:
  1210. if deco != decorator_node:
  1211. error(deco.pos, "Property methods with additional decorators are not supported")
  1212. return node
  1213. @staticmethod
  1214. def _add_to_property(properties, node, name, decorator):
  1215. prop = properties[node.name]
  1216. node.name = name
  1217. node.decorators.remove(decorator)
  1218. stats = prop.body.stats
  1219. for i, stat in enumerate(stats):
  1220. if stat.name == name:
  1221. stats[i] = node
  1222. break
  1223. else:
  1224. stats.append(node)
  1225. return []
  1226. @staticmethod
  1227. def chain_decorators(node, decorators, name):
  1228. """
  1229. Decorators are applied directly in DefNode and PyClassDefNode to avoid
  1230. reassignments to the function/class name - except for cdef class methods.
  1231. For those, the reassignment is required as methods are originally
  1232. defined in the PyMethodDef struct.
  1233. The IndirectionNode allows DefNode to override the decorator.
  1234. """
  1235. decorator_result = ExprNodes.NameNode(node.pos, name=name)
  1236. for decorator in decorators[::-1]:
  1237. decorator_result = ExprNodes.SimpleCallNode(
  1238. decorator.pos,
  1239. function=decorator.decorator,
  1240. args=[decorator_result])
  1241. name_node = ExprNodes.NameNode(node.pos, name=name)
  1242. reassignment = Nodes.SingleAssignmentNode(
  1243. node.pos,
  1244. lhs=name_node,
  1245. rhs=decorator_result)
  1246. reassignment = Nodes.IndirectionNode([reassignment])
  1247. node.decorator_indirection = reassignment
  1248. return [node, reassignment]
  1249. class CnameDirectivesTransform(CythonTransform, SkipDeclarations):
  1250. """
  1251. Only part of the CythonUtilityCode pipeline. Must be run before
  1252. DecoratorTransform in case this is a decorator for a cdef class.
  1253. It filters out @cname('my_cname') decorators and rewrites them to
  1254. CnameDecoratorNodes.
  1255. """
  1256. def handle_function(self, node):
  1257. if not getattr(node, 'decorators', None):
  1258. return self.visit_Node(node)
  1259. for i, decorator in enumerate(node.decorators):
  1260. decorator = decorator.decorator
  1261. if (isinstance(decorator, ExprNodes.CallNode) and
  1262. decorator.function.is_name and
  1263. decorator.function.name == 'cname'):
  1264. args, kwargs = decorator.explicit_args_kwds()
  1265. if kwargs:
  1266. raise AssertionError(
  1267. "cname decorator does not take keyword arguments")
  1268. if len(args) != 1:
  1269. raise AssertionError(
  1270. "cname decorator takes exactly one argument")
  1271. if not (args[0].is_literal and
  1272. args[0].type == Builtin.str_type):
  1273. raise AssertionError(
  1274. "argument to cname decorator must be a string literal")
  1275. cname = args[0].compile_time_value(None)
  1276. del node.decorators[i]
  1277. node = Nodes.CnameDecoratorNode(pos=node.pos, node=node,
  1278. cname=cname)
  1279. break
  1280. return self.visit_Node(node)
  1281. visit_FuncDefNode = handle_function
  1282. visit_CClassDefNode = handle_function
  1283. visit_CEnumDefNode = handle_function
  1284. visit_CStructOrUnionDefNode = handle_function
  1285. class ForwardDeclareTypes(CythonTransform):
  1286. def visit_CompilerDirectivesNode(self, node):
  1287. env = self.module_scope
  1288. old = env.directives
  1289. env.directives = node.directives
  1290. self.visitchildren(node)
  1291. env.directives = old
  1292. return node
  1293. def visit_ModuleNode(self, node):
  1294. self.module_scope = node.scope
  1295. self.module_scope.directives = node.directives
  1296. self.visitchildren(node)
  1297. return node
  1298. def visit_CDefExternNode(self, node):
  1299. old_cinclude_flag = self.module_scope.in_cinclude
  1300. self.module_scope.in_cinclude = 1
  1301. self.visitchildren(node)
  1302. self.module_scope.in_cinclude = old_cinclude_flag
  1303. return node
  1304. def visit_CEnumDefNode(self, node):
  1305. node.declare(self.module_scope)
  1306. return node
  1307. def visit_CStructOrUnionDefNode(self, node):
  1308. if node.name not in self.module_scope.entries:
  1309. node.declare(self.module_scope)
  1310. return node
  1311. def visit_CClassDefNode(self, node):
  1312. if node.class_name not in self.module_scope.entries:
  1313. node.declare(self.module_scope)
  1314. # Expand fused methods of .pxd declared types to construct the final vtable order.
  1315. type = self.module_scope.entries[node.class_name].type
  1316. if type is not None and type.is_extension_type and not type.is_builtin_type and type.scope:
  1317. scope = type.scope
  1318. for entry in scope.cfunc_entries:
  1319. if entry.type and entry.type.is_fused:
  1320. entry.type.get_all_specialized_function_types()
  1321. return node
  1322. class AnalyseDeclarationsTransform(EnvTransform):
  1323. basic_property = TreeFragment(u"""
  1324. property NAME:
  1325. def __get__(self):
  1326. return ATTR
  1327. def __set__(self, value):
  1328. ATTR = value
  1329. """, level='c_class', pipeline=[NormalizeTree(None)])
  1330. basic_pyobject_property = TreeFragment(u"""
  1331. property NAME:
  1332. def __get__(self):
  1333. return ATTR
  1334. def __set__(self, value):
  1335. ATTR = value
  1336. def __del__(self):
  1337. ATTR = None
  1338. """, level='c_class', pipeline=[NormalizeTree(None)])
  1339. basic_property_ro = TreeFragment(u"""
  1340. property NAME:
  1341. def __get__(self):
  1342. return ATTR
  1343. """, level='c_class', pipeline=[NormalizeTree(None)])
  1344. struct_or_union_wrapper = TreeFragment(u"""
  1345. cdef class NAME:
  1346. cdef TYPE value
  1347. def __init__(self, MEMBER=None):
  1348. cdef int count
  1349. count = 0
  1350. INIT_ASSIGNMENTS
  1351. if IS_UNION and count > 1:
  1352. raise ValueError, "At most one union member should be specified."
  1353. def __str__(self):
  1354. return STR_FORMAT % MEMBER_TUPLE
  1355. def __repr__(self):
  1356. return REPR_FORMAT % MEMBER_TUPLE
  1357. """, pipeline=[NormalizeTree(None)])
  1358. init_assignment = TreeFragment(u"""
  1359. if VALUE is not None:
  1360. ATTR = VALUE
  1361. count += 1
  1362. """, pipeline=[NormalizeTree(None)])
  1363. fused_function = None
  1364. in_lambda = 0
  1365. def __call__(self, root):
  1366. # needed to determine if a cdef var is declared after it's used.
  1367. self.seen_vars_stack = []
  1368. self.fused_error_funcs = set()
  1369. super_class = super(AnalyseDeclarationsTransform, self)
  1370. self._super_visit_FuncDefNode = super_class.visit_FuncDefNode
  1371. return super_class.__call__(root)
  1372. def visit_NameNode(self, node):
  1373. self.seen_vars_stack[-1].add(node.name)
  1374. return node
  1375. def visit_ModuleNode(self, node):
  1376. # Pickling support requires injecting module-level nodes.
  1377. self.extra_module_declarations = []
  1378. self.seen_vars_stack.append(set())
  1379. node.analyse_declarations(self.current_env())
  1380. self.visitchildren(node)
  1381. self.seen_vars_stack.pop()
  1382. node.body.stats.extend(self.extra_module_declarations)
  1383. return node
  1384. def visit_LambdaNode(self, node):
  1385. self.in_lambda += 1
  1386. node.analyse_declarations(self.current_env())
  1387. self.visitchildren(node)
  1388. self.in_lambda -= 1
  1389. return node
  1390. def visit_CClassDefNode(self, node):
  1391. node = self.visit_ClassDefNode(node)
  1392. if node.scope and node.scope.implemented and node.body:
  1393. stats = []
  1394. for entry in node.scope.var_entries:
  1395. if entry.needs_property:
  1396. property = self.create_Property(entry)
  1397. property.analyse_declarations(node.scope)
  1398. self.visit(property)
  1399. stats.append(property)
  1400. if stats:
  1401. node.body.stats += stats
  1402. if (node.visibility != 'extern'
  1403. and not node.scope.lookup('__reduce__')
  1404. and not node.scope.lookup('__reduce_ex__')):
  1405. self._inject_pickle_methods(node)
  1406. return node
  1407. def _inject_pickle_methods(self, node):
  1408. env = self.current_env()
  1409. if node.scope.directives['auto_pickle'] is False: # None means attempt it.
  1410. # Old behavior of not doing anything.
  1411. return
  1412. auto_pickle_forced = node.scope.directives['auto_pickle'] is True
  1413. all_members = []
  1414. cls = node.entry.type
  1415. cinit = None
  1416. inherited_reduce = None
  1417. while cls is not None:
  1418. all_members.extend(e for e in cls.scope.var_entries if e.name not in ('__weakref__', '__dict__'))
  1419. cinit = cinit or cls.scope.lookup('__cinit__')
  1420. inherited_reduce = inherited_reduce or cls.scope.lookup('__reduce__') or cls.scope.lookup('__reduce_ex__')
  1421. cls = cls.base_type
  1422. all_members.sort(key=lambda e: e.name)
  1423. if inherited_reduce:
  1424. # This is not failsafe, as we may not know whether a cimported class defines a __reduce__.
  1425. # This is why we define __reduce_cython__ and only replace __reduce__
  1426. # (via ExtensionTypes.SetupReduce utility code) at runtime on class creation.
  1427. return
  1428. non_py = [
  1429. e for e in all_members
  1430. if not e.type.is_pyobject and (not e.type.can_coerce_to_pyobject(env)
  1431. or not e.type.can_coerce_from_pyobject(env))
  1432. ]
  1433. structs = [e for e in all_members if e.type.is_struct_or_union]
  1434. if cinit or non_py or (structs and not auto_pickle_forced):
  1435. if cinit:
  1436. # TODO(robertwb): We could allow this if __cinit__ has no require arguments.
  1437. msg = 'no default __reduce__ due to non-trivial __cinit__'
  1438. elif non_py:
  1439. msg = "%s cannot be converted to a Python object for pickling" % ','.join("self.%s" % e.name for e in non_py)
  1440. else:
  1441. # Extern structs may be only partially defined.
  1442. # TODO(robertwb): Limit the restriction to extern
  1443. # (and recursively extern-containing) structs.
  1444. msg = ("Pickling of struct members such as %s must be explicitly requested "
  1445. "with @auto_pickle(True)" % ','.join("self.%s" % e.name for e in structs))
  1446. if auto_pickle_forced:
  1447. error(node.pos, msg)
  1448. pickle_func = TreeFragment(u"""
  1449. def __reduce_cython__(self):
  1450. raise TypeError("%(msg)s")
  1451. def __setstate_cython__(self, __pyx_state):
  1452. raise TypeError("%(msg)s")
  1453. """ % {'msg': msg},
  1454. level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
  1455. pickle_func.analyse_declarations(node.scope)
  1456. self.visit(pickle_func)
  1457. node.body.stats.append(pickle_func)
  1458. else:
  1459. for e in all_members:
  1460. if not e.type.is_pyobject:
  1461. e.type.create_to_py_utility_code(env)
  1462. e.type.create_from_py_utility_code(env)
  1463. all_members_names = [e.name for e in all_members]
  1464. checksums = _calculate_pickle_checksums(all_members_names)
  1465. unpickle_func_name = '__pyx_unpickle_%s' % node.class_name
  1466. # TODO(robertwb): Move the state into the third argument
  1467. # so it can be pickled *after* self is memoized.
  1468. unpickle_func = TreeFragment(u"""
  1469. def %(unpickle_func_name)s(__pyx_type, long __pyx_checksum, __pyx_state):
  1470. cdef object __pyx_PickleError
  1471. cdef object __pyx_result
  1472. if __pyx_checksum not in %(checksums)s:
  1473. from pickle import PickleError as __pyx_PickleError
  1474. raise __pyx_PickleError("Incompatible checksums (0x%%x vs %(checksums)s = (%(members)s))" %% __pyx_checksum)
  1475. __pyx_result = %(class_name)s.__new__(__pyx_type)
  1476. if __pyx_state is not None:
  1477. %(unpickle_func_name)s__set_state(<%(class_name)s> __pyx_result, __pyx_state)
  1478. return __pyx_result
  1479. cdef %(unpickle_func_name)s__set_state(%(class_name)s __pyx_result, tuple __pyx_state):
  1480. %(assignments)s
  1481. if len(__pyx_state) > %(num_members)d and hasattr(__pyx_result, '__dict__'):
  1482. __pyx_result.__dict__.update(__pyx_state[%(num_members)d])
  1483. """ % {
  1484. 'unpickle_func_name': unpickle_func_name,
  1485. 'checksums': "(%s)" % ', '.join(checksums),
  1486. 'members': ', '.join(all_members_names),
  1487. 'class_name': node.class_name,
  1488. 'assignments': '; '.join(
  1489. '__pyx_result.%s = __pyx_state[%s]' % (v, ix)
  1490. for ix, v in enumerate(all_members_names)),
  1491. 'num_members': len(all_members_names),
  1492. }, level='module', pipeline=[NormalizeTree(None)]).substitute({})
  1493. unpickle_func.analyse_declarations(node.entry.scope)
  1494. self.visit(unpickle_func)
  1495. self.extra_module_declarations.append(unpickle_func)
  1496. pickle_func = TreeFragment(u"""
  1497. def __reduce_cython__(self):
  1498. cdef tuple state
  1499. cdef object _dict
  1500. cdef bint use_setstate
  1501. state = (%(members)s)
  1502. _dict = getattr(self, '__dict__', None)
  1503. if _dict is not None:
  1504. state += (_dict,)
  1505. use_setstate = True
  1506. else:
  1507. use_setstate = %(any_notnone_members)s
  1508. if use_setstate:
  1509. return %(unpickle_func_name)s, (type(self), %(checksum)s, None), state
  1510. else:
  1511. return %(unpickle_func_name)s, (type(self), %(checksum)s, state)
  1512. def __setstate_cython__(self, __pyx_state):
  1513. %(unpickle_func_name)s__set_state(self, __pyx_state)
  1514. """ % {
  1515. 'unpickle_func_name': unpickle_func_name,
  1516. 'checksum': checksums[0],
  1517. 'members': ', '.join('self.%s' % v for v in all_members_names) + (',' if len(all_members_names) == 1 else ''),
  1518. # Even better, we could check PyType_IS_GC.
  1519. 'any_notnone_members' : ' or '.join(['self.%s is not None' % e.name for e in all_members if e.type.is_pyobject] or ['False']),
  1520. },
  1521. level='c_class', pipeline=[NormalizeTree(None)]).substitute({})
  1522. pickle_func.analyse_declarations(node.scope)
  1523. self.enter_scope(node, node.scope) # functions should be visited in the class scope
  1524. self.visit(pickle_func)
  1525. self.exit_scope()
  1526. node.body.stats.append(pickle_func)
  1527. def _handle_fused_def_decorators(self, old_decorators, env, node):
  1528. """
  1529. Create function calls to the decorators and reassignments to
  1530. the function.
  1531. """
  1532. # Delete staticmethod and classmethod decorators, this is
  1533. # handled directly by the fused function object.
  1534. decorators = []
  1535. for decorator in old_decorators:
  1536. func = decorator.decorator
  1537. if (not func.is_name or
  1538. func.name not in ('staticmethod', 'classmethod') or
  1539. env.lookup_here(func.name)):
  1540. # not a static or classmethod
  1541. decorators.append(decorator)
  1542. if decorators:
  1543. transform = DecoratorTransform(self.context)
  1544. def_node = node.node
  1545. _, reassignments = transform.chain_decorators(
  1546. def_node, decorators, def_node.name)
  1547. reassignments.analyse_declarations(env)
  1548. node = [node, reassignments]
  1549. return node
  1550. def _handle_def(self, decorators, env, node):
  1551. "Handle def or cpdef fused functions"
  1552. # Create PyCFunction nodes for each specialization
  1553. node.stats.insert(0, node.py_func)
  1554. node.py_func = self.visit(node.py_func)
  1555. node.update_fused_defnode_entry(env)
  1556. pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, binding=True)
  1557. pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
  1558. node.resulting_fused_function = pycfunc
  1559. # Create assignment node for our def function
  1560. node.fused_func_assignment = self._create_assignment(
  1561. node.py_func, ExprNodes.CloneNode(pycfunc), env)
  1562. if decorators:
  1563. node = self._handle_fused_def_decorators(decorators, env, node)
  1564. return node
  1565. def _create_fused_function(self, env, node):
  1566. "Create a fused function for a DefNode with fused arguments"
  1567. from . import FusedNode
  1568. if self.fused_function or self.in_lambda:
  1569. if self.fused_function not in self.fused_error_funcs:
  1570. if self.in_lambda:
  1571. error(node.pos, "Fused lambdas not allowed")
  1572. else:
  1573. error(node.pos, "Cannot nest fused functions")
  1574. self.fused_error_funcs.add(self.fused_function)
  1575. node.body = Nodes.PassStatNode(node.pos)
  1576. for arg in node.args:
  1577. if arg.type.is_fused:
  1578. arg.type = arg.type.get_fused_types()[0]
  1579. return node
  1580. decorators = getattr(node, 'decorators', None)
  1581. node = FusedNode.FusedCFuncDefNode(node, env)
  1582. self.fused_function = node
  1583. self.visitchildren(node)
  1584. self.fused_function = None
  1585. if node.py_func:
  1586. node = self._handle_def(decorators, env, node)
  1587. return node
  1588. def _handle_nogil_cleanup(self, lenv, node):
  1589. "Handle cleanup for 'with gil' blocks in nogil functions."
  1590. if lenv.nogil and lenv.has_with_gil_block:
  1591. # Acquire the GIL for cleanup in 'nogil' functions, by wrapping
  1592. # the entire function body in try/finally.
  1593. # The corresponding release will be taken care of by
  1594. # Nodes.FuncDefNode.generate_function_definitions()
  1595. node.body = Nodes.NogilTryFinallyStatNode(
  1596. node.body.pos,
  1597. body=node.body,
  1598. finally_clause=Nodes.EnsureGILNode(node.body.pos),
  1599. finally_except_clause=Nodes.EnsureGILNode(node.body.pos))
  1600. def _handle_fused(self, node):
  1601. if node.is_generator and node.has_fused_arguments:
  1602. node.has_fused_arguments = False
  1603. error(node.pos, "Fused generators not supported")
  1604. node.gbody = Nodes.StatListNode(node.pos,
  1605. stats=[],
  1606. body=Nodes.PassStatNode(node.pos))
  1607. return node.has_fused_arguments
  1608. def visit_FuncDefNode(self, node):
  1609. """
  1610. Analyse a function and its body, as that hasn't happened yet. Also
  1611. analyse the directive_locals set by @cython.locals().
  1612. Then, if we are a function with fused arguments, replace the function
  1613. (after it has declared itself in the symbol table!) with a
  1614. FusedCFuncDefNode, and analyse its children (which are in turn normal
  1615. functions). If we're a normal function, just analyse the body of the
  1616. function.
  1617. """
  1618. env = self.current_env()
  1619. self.seen_vars_stack.append(set())
  1620. lenv = node.local_scope
  1621. node.declare_arguments(lenv)
  1622. # @cython.locals(...)
  1623. for var, type_node in node.directive_locals.items():
  1624. if not lenv.lookup_here(var): # don't redeclare args
  1625. type = type_node.analyse_as_type(lenv)
  1626. if type:
  1627. lenv.declare_var(var, type, type_node.pos)
  1628. else:
  1629. error(type_node.pos, "Not a type")
  1630. if self._handle_fused(node):
  1631. node = self._create_fused_function(env, node)
  1632. else:
  1633. node.body.analyse_declarations(lenv)
  1634. self._handle_nogil_cleanup(lenv, node)
  1635. self._super_visit_FuncDefNode(node)
  1636. self.seen_vars_stack.pop()
  1637. return node
  1638. def visit_DefNode(self, node):
  1639. node = self.visit_FuncDefNode(node)
  1640. env = self.current_env()
  1641. if isinstance(node, Nodes.DefNode) and node.is_wrapper:
  1642. env = env.parent_scope
  1643. if (not isinstance(node, Nodes.DefNode) or
  1644. node.fused_py_func or node.is_generator_body or
  1645. not node.needs_assignment_synthesis(env)):
  1646. return node
  1647. return [node, self._synthesize_assignment(node, env)]
  1648. def visit_GeneratorBodyDefNode(self, node):
  1649. return self.visit_FuncDefNode(node)
  1650. def _synthesize_assignment(self, node, env):
  1651. # Synthesize assignment node and put it right after defnode
  1652. genv = env
  1653. while genv.is_py_class_scope or genv.is_c_class_scope:
  1654. genv = genv.outer_scope
  1655. if genv.is_closure_scope:
  1656. rhs = node.py_cfunc_node = ExprNodes.InnerFunctionNode(
  1657. node.pos, def_node=node,
  1658. pymethdef_cname=node.entry.pymethdef_cname,
  1659. code_object=ExprNodes.CodeObjectNode(node))
  1660. else:
  1661. binding = self.current_directives.get('binding')
  1662. rhs = ExprNodes.PyCFunctionNode.from_defnode(node, binding)
  1663. node.code_object = rhs.code_object
  1664. if node.is_generator:
  1665. node.gbody.code_object = node.code_object
  1666. if env.is_py_class_scope:
  1667. rhs.binding = True
  1668. node.is_cyfunction = rhs.binding
  1669. return self._create_assignment(node, rhs, env)
  1670. def _create_assignment(self, def_node, rhs, env):
  1671. if def_node.decorators:
  1672. for decorator in def_node.decorators[::-1]:
  1673. rhs = ExprNodes.SimpleCallNode(
  1674. decorator.pos,
  1675. function = decorator.decorator,
  1676. args = [rhs])
  1677. def_node.decorators = None
  1678. assmt = Nodes.SingleAssignmentNode(
  1679. def_node.pos,
  1680. lhs=ExprNodes.NameNode(def_node.pos, name=def_node.name),
  1681. rhs=rhs)
  1682. assmt.analyse_declarations(env)
  1683. return assmt
  1684. def visit_ScopedExprNode(self, node):
  1685. env = self.current_env()
  1686. node.analyse_declarations(env)
  1687. # the node may or may not have a local scope
  1688. if node.has_local_scope:
  1689. self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
  1690. self.enter_scope(node, node.expr_scope)
  1691. node.analyse_scoped_declarations(node.expr_scope)
  1692. self.visitchildren(node)
  1693. self.exit_scope()
  1694. self.seen_vars_stack.pop()
  1695. else:
  1696. node.analyse_scoped_declarations(env)
  1697. self.visitchildren(node)
  1698. return node
  1699. def visit_TempResultFromStatNode(self, node):
  1700. self.visitchildren(node)
  1701. node.analyse_declarations(self.current_env())
  1702. return node
  1703. def visit_CppClassNode(self, node):
  1704. if node.visibility == 'extern':
  1705. return None
  1706. else:
  1707. return self.visit_ClassDefNode(node)
  1708. def visit_CStructOrUnionDefNode(self, node):
  1709. # Create a wrapper node if needed.
  1710. # We want to use the struct type information (so it can't happen
  1711. # before this phase) but also create new objects to be declared
  1712. # (so it can't happen later).
  1713. # Note that we don't return the original node, as it is
  1714. # never used after this phase.
  1715. if True: # private (default)
  1716. return None
  1717. self_value = ExprNodes.AttributeNode(
  1718. pos = node.pos,
  1719. obj = ExprNodes.NameNode(pos=node.pos, name=u"self"),
  1720. attribute = EncodedString(u"value"))
  1721. var_entries = node.entry.type.scope.var_entries
  1722. attributes = []
  1723. for entry in var_entries:
  1724. attributes.append(ExprNodes.AttributeNode(pos = entry.pos,
  1725. obj = self_value,
  1726. attribute = entry.name))
  1727. # __init__ assignments
  1728. init_assignments = []
  1729. for entry, attr in zip(var_entries, attributes):
  1730. # TODO: branch on visibility
  1731. init_assignments.append(self.init_assignment.substitute({
  1732. u"VALUE": ExprNodes.NameNode(entry.pos, name = entry.name),
  1733. u"ATTR": attr,
  1734. }, pos = entry.pos))
  1735. # create the class
  1736. str_format = u"%s(%s)" % (node.entry.type.name, ("%s, " * len(attributes))[:-2])
  1737. wrapper_class = self.struct_or_union_wrapper.substitute({
  1738. u"INIT_ASSIGNMENTS": Nodes.StatListNode(node.pos, stats = init_assignments),
  1739. u"IS_UNION": ExprNodes.BoolNode(node.pos, value = not node.entry.type.is_struct),
  1740. u"MEMBER_TUPLE": ExprNodes.TupleNode(node.pos, args=attributes),
  1741. u"STR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format)),
  1742. u"REPR_FORMAT": ExprNodes.StringNode(node.pos, value = EncodedString(str_format.replace("%s", "%r"))),
  1743. }, pos = node.pos).stats[0]
  1744. wrapper_class.class_name = node.name
  1745. wrapper_class.shadow = True
  1746. class_body = wrapper_class.body.stats
  1747. # fix value type
  1748. assert isinstance(class_body[0].base_type, Nodes.CSimpleBaseTypeNode)
  1749. class_body[0].base_type.name = node.name
  1750. # fix __init__ arguments
  1751. init_method = class_body[1]
  1752. assert isinstance(init_method, Nodes.DefNode) and init_method.name == '__init__'
  1753. arg_template = init_method.args[1]
  1754. if not node.entry.type.is_struct:
  1755. arg_template.kw_only = True
  1756. del init_method.args[1]
  1757. for entry, attr in zip(var_entries, attributes):
  1758. arg = copy.deepcopy(arg_template)
  1759. arg.declarator.name = entry.name
  1760. init_method.args.append(arg)
  1761. # setters/getters
  1762. for entry, attr in zip(var_entries, attributes):
  1763. # TODO: branch on visibility
  1764. if entry.type.is_pyobject:
  1765. template = self.basic_pyobject_property
  1766. else:
  1767. template = self.basic_property
  1768. property = template.substitute({
  1769. u"ATTR": attr,
  1770. }, pos = entry.pos).stats[0]
  1771. property.name = entry.name
  1772. wrapper_class.body.stats.append(property)
  1773. wrapper_class.analyse_declarations(self.current_env())
  1774. return self.visit_CClassDefNode(wrapper_class)
  1775. # Some nodes are no longer needed after declaration
  1776. # analysis and can be dropped. The analysis was performed
  1777. # on these nodes in a separate recursive process from the
  1778. # enclosing function or module, so we can simply drop them.
  1779. def visit_CDeclaratorNode(self, node):
  1780. # necessary to ensure that all CNameDeclaratorNodes are visited.
  1781. self.visitchildren(node)
  1782. return node
  1783. def visit_CTypeDefNode(self, node):
  1784. return node
  1785. def visit_CBaseTypeNode(self, node):
  1786. return None
  1787. def visit_CEnumDefNode(self, node):
  1788. if node.visibility == 'public':
  1789. return node
  1790. else:
  1791. return None
  1792. def visit_CNameDeclaratorNode(self, node):
  1793. if node.name in self.seen_vars_stack[-1]:
  1794. entry = self.current_env().lookup(node.name)
  1795. if (entry is None or entry.visibility != 'extern'
  1796. and not entry.scope.is_c_class_scope):
  1797. warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
  1798. self.visitchildren(node)
  1799. return node
  1800. def visit_CVarDefNode(self, node):
  1801. # to ensure all CNameDeclaratorNodes are visited.
  1802. self.visitchildren(node)
  1803. return None
  1804. def visit_CnameDecoratorNode(self, node):
  1805. child_node = self.visit(node.node)
  1806. if not child_node:
  1807. return None
  1808. if type(child_node) is list: # Assignment synthesized
  1809. node.child_node = child_node[0]
  1810. return [node] + child_node[1:]
  1811. node.node = child_node
  1812. return node
  1813. def create_Property(self, entry):
  1814. if entry.visibility == 'public':
  1815. if entry.type.is_pyobject:
  1816. template = self.basic_pyobject_property
  1817. else:
  1818. template = self.basic_property
  1819. elif entry.visibility == 'readonly':
  1820. template = self.basic_property_ro
  1821. property = template.substitute({
  1822. u"ATTR": ExprNodes.AttributeNode(pos=entry.pos,
  1823. obj=ExprNodes.NameNode(pos=entry.pos, name="self"),
  1824. attribute=entry.name),
  1825. }, pos=entry.pos).stats[0]
  1826. property.name = entry.name
  1827. property.doc = entry.doc
  1828. return property
  1829. def _calculate_pickle_checksums(member_names):
  1830. # Cython 0.x used MD5 for the checksum, which a few Python installations remove for security reasons.
  1831. # SHA-256 should be ok for years to come, but early Cython 3.0 alpha releases used SHA-1,
  1832. # which may not be.
  1833. member_names_string = ' '.join(member_names).encode('utf-8')
  1834. hash_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
  1835. checksums = []
  1836. for algo_name in ['md5', 'sha256', 'sha1']:
  1837. try:
  1838. mkchecksum = getattr(hashlib, algo_name)
  1839. checksum = mkchecksum(member_names_string, **hash_kwargs).hexdigest()
  1840. except (AttributeError, ValueError):
  1841. # The algorithm (i.e. MD5) might not be there at all, or might be blocked at runtime.
  1842. continue
  1843. checksums.append('0x' + checksum[:7])
  1844. return checksums
  1845. class CalculateQualifiedNamesTransform(EnvTransform):
  1846. """
  1847. Calculate and store the '__qualname__' and the global
  1848. module name on some nodes.
  1849. """
  1850. def visit_ModuleNode(self, node):
  1851. self.module_name = self.global_scope().qualified_name
  1852. self.qualified_name = []
  1853. _super = super(CalculateQualifiedNamesTransform, self)
  1854. self._super_visit_FuncDefNode = _super.visit_FuncDefNode
  1855. self._super_visit_ClassDefNode = _super.visit_ClassDefNode
  1856. self.visitchildren(node)
  1857. return node
  1858. def _set_qualname(self, node, name=None):
  1859. if name:
  1860. qualname = self.qualified_name[:]
  1861. qualname.append(name)
  1862. else:
  1863. qualname = self.qualified_name
  1864. node.qualname = EncodedString('.'.join(qualname))
  1865. node.module_name = self.module_name
  1866. def _append_entry(self, entry):
  1867. if entry.is_pyglobal and not entry.is_pyclass_attr:
  1868. self.qualified_name = [entry.name]
  1869. else:
  1870. self.qualified_name.append(entry.name)
  1871. def visit_ClassNode(self, node):
  1872. self._set_qualname(node, node.name)
  1873. self.visitchildren(node)
  1874. return node
  1875. def visit_PyClassNamespaceNode(self, node):
  1876. # class name was already added by parent node
  1877. self._set_qualname(node)
  1878. self.visitchildren(node)
  1879. return node
  1880. def visit_PyCFunctionNode(self, node):
  1881. orig_qualified_name = self.qualified_name[:]
  1882. if node.def_node.is_wrapper and self.qualified_name and self.qualified_name[-1] == '<locals>':
  1883. self.qualified_name.pop()
  1884. self._set_qualname(node)
  1885. else:
  1886. self._set_qualname(node, node.def_node.name)
  1887. self.visitchildren(node)
  1888. self.qualified_name = orig_qualified_name
  1889. return node
  1890. def visit_DefNode(self, node):
  1891. if node.is_wrapper and self.qualified_name:
  1892. assert self.qualified_name[-1] == '<locals>', self.qualified_name
  1893. orig_qualified_name = self.qualified_name[:]
  1894. self.qualified_name.pop()
  1895. self._set_qualname(node)
  1896. self._super_visit_FuncDefNode(node)
  1897. self.qualified_name = orig_qualified_name
  1898. else:
  1899. self._set_qualname(node, node.name)
  1900. self.visit_FuncDefNode(node)
  1901. return node
  1902. def visit_FuncDefNode(self, node):
  1903. orig_qualified_name = self.qualified_name[:]
  1904. if getattr(node, 'name', None) == '<lambda>':
  1905. self.qualified_name.append('<lambda>')
  1906. else:
  1907. self._append_entry(node.entry)
  1908. self.qualified_name.append('<locals>')
  1909. self._super_visit_FuncDefNode(node)
  1910. self.qualified_name = orig_qualified_name
  1911. return node
  1912. def visit_ClassDefNode(self, node):
  1913. orig_qualified_name = self.qualified_name[:]
  1914. entry = (getattr(node, 'entry', None) or # PyClass
  1915. self.current_env().lookup_here(node.name)) # CClass
  1916. self._append_entry(entry)
  1917. self._super_visit_ClassDefNode(node)
  1918. self.qualified_name = orig_qualified_name
  1919. return node
  1920. class AnalyseExpressionsTransform(CythonTransform):
  1921. def visit_ModuleNode(self, node):
  1922. node.scope.infer_types()
  1923. node.body = node.body.analyse_expressions(node.scope)
  1924. self.visitchildren(node)
  1925. return node
  1926. def visit_FuncDefNode(self, node):
  1927. node.local_scope.infer_types()
  1928. node.body = node.body.analyse_expressions(node.local_scope)
  1929. self.visitchildren(node)
  1930. return node
  1931. def visit_ScopedExprNode(self, node):
  1932. if node.has_local_scope:
  1933. node.expr_scope.infer_types()
  1934. node = node.analyse_scoped_expressions(node.expr_scope)
  1935. self.visitchildren(node)
  1936. return node
  1937. def visit_IndexNode(self, node):
  1938. """
  1939. Replace index nodes used to specialize cdef functions with fused
  1940. argument types with the Attribute- or NameNode referring to the
  1941. function. We then need to copy over the specialization properties to
  1942. the attribute or name node.
  1943. Because the indexing might be a Python indexing operation on a fused
  1944. function, or (usually) a Cython indexing operation, we need to
  1945. re-analyse the types.
  1946. """
  1947. self.visit_Node(node)
  1948. if node.is_fused_index and not node.type.is_error:
  1949. node = node.base
  1950. return node
  1951. class FindInvalidUseOfFusedTypes(CythonTransform):
  1952. def visit_FuncDefNode(self, node):
  1953. # Errors related to use in functions with fused args will already
  1954. # have been detected
  1955. if not node.has_fused_arguments:
  1956. if not node.is_generator_body and node.return_type.is_fused:
  1957. error(node.pos, "Return type is not specified as argument type")
  1958. else:
  1959. self.visitchildren(node)
  1960. return node
  1961. def visit_ExprNode(self, node):
  1962. if node.type and node.type.is_fused:
  1963. error(node.pos, "Invalid use of fused types, type cannot be specialized")
  1964. else:
  1965. self.visitchildren(node)
  1966. return node
  1967. class ExpandInplaceOperators(EnvTransform):
  1968. def visit_InPlaceAssignmentNode(self, node):
  1969. lhs = node.lhs
  1970. rhs = node.rhs
  1971. if lhs.type.is_cpp_class:
  1972. # No getting around this exact operator here.
  1973. return node
  1974. if isinstance(lhs, ExprNodes.BufferIndexNode):
  1975. # There is code to handle this case in InPlaceAssignmentNode
  1976. return node
  1977. env = self.current_env()
  1978. def side_effect_free_reference(node, setting=False):
  1979. if node.is_name:
  1980. return node, []
  1981. elif node.type.is_pyobject and not setting:
  1982. node = LetRefNode(node)
  1983. return node, [node]
  1984. elif node.is_subscript:
  1985. base, temps = side_effect_free_reference(node.base)
  1986. index = LetRefNode(node.index)
  1987. return ExprNodes.IndexNode(node.pos, base=base, index=index), temps + [index]
  1988. elif node.is_attribute:
  1989. obj, temps = side_effect_free_reference(node.obj)
  1990. return ExprNodes.AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
  1991. elif isinstance(node, ExprNodes.BufferIndexNode):
  1992. raise ValueError("Don't allow things like attributes of buffer indexing operations")
  1993. else:
  1994. node = LetRefNode(node)
  1995. return node, [node]
  1996. try:
  1997. lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
  1998. except ValueError:
  1999. return node
  2000. dup = lhs.__class__(**lhs.__dict__)
  2001. binop = ExprNodes.binop_node(node.pos,
  2002. operator = node.operator,
  2003. operand1 = dup,
  2004. operand2 = rhs,
  2005. inplace=True)
  2006. # Manually analyse types for new node.
  2007. lhs.analyse_target_types(env)
  2008. dup.analyse_types(env)
  2009. binop.analyse_operation(env)
  2010. node = Nodes.SingleAssignmentNode(
  2011. node.pos,
  2012. lhs = lhs,
  2013. rhs=binop.coerce_to(lhs.type, env))
  2014. # Use LetRefNode to avoid side effects.
  2015. let_ref_nodes.reverse()
  2016. for t in let_ref_nodes:
  2017. node = LetNode(t, node)
  2018. return node
  2019. def visit_ExprNode(self, node):
  2020. # In-place assignments can't happen within an expression.
  2021. return node
  2022. class AdjustDefByDirectives(CythonTransform, SkipDeclarations):
  2023. """
  2024. Adjust function and class definitions by the decorator directives:
  2025. @cython.cfunc
  2026. @cython.cclass
  2027. @cython.ccall
  2028. @cython.inline
  2029. @cython.nogil
  2030. """
  2031. def visit_ModuleNode(self, node):
  2032. self.directives = node.directives
  2033. self.in_py_class = False
  2034. self.visitchildren(node)
  2035. return node
  2036. def visit_CompilerDirectivesNode(self, node):
  2037. old_directives = self.directives
  2038. self.directives = node.directives
  2039. self.visitchildren(node)
  2040. self.directives = old_directives
  2041. return node
  2042. def visit_DefNode(self, node):
  2043. modifiers = []
  2044. if 'inline' in self.directives:
  2045. modifiers.append('inline')
  2046. nogil = self.directives.get('nogil')
  2047. except_val = self.directives.get('exceptval')
  2048. return_type_node = self.directives.get('returns')
  2049. if return_type_node is None and self.directives['annotation_typing']:
  2050. return_type_node = node.return_type_annotation
  2051. # for Python anntations, prefer safe exception handling by default
  2052. if return_type_node is not None and except_val is None:
  2053. except_val = (None, True) # except *
  2054. elif except_val is None:
  2055. # backward compatible default: no exception check
  2056. except_val = (None, False)
  2057. if 'ccall' in self.directives:
  2058. node = node.as_cfunction(
  2059. overridable=True, modifiers=modifiers, nogil=nogil,
  2060. returns=return_type_node, except_val=except_val)
  2061. return self.visit(node)
  2062. if 'cfunc' in self.directives:
  2063. if self.in_py_class:
  2064. error(node.pos, "cfunc directive is not allowed here")
  2065. else:
  2066. node = node.as_cfunction(
  2067. overridable=False, modifiers=modifiers, nogil=nogil,
  2068. returns=return_type_node, except_val=except_val)
  2069. return self.visit(node)
  2070. if 'inline' in modifiers:
  2071. error(node.pos, "Python functions cannot be declared 'inline'")
  2072. if nogil:
  2073. # TODO: turn this into a "with gil" declaration.
  2074. error(node.pos, "Python functions cannot be declared 'nogil'")
  2075. self.visitchildren(node)
  2076. return node
  2077. def visit_LambdaNode(self, node):
  2078. # No directives should modify lambdas or generator expressions (and also nothing in them).
  2079. return node
  2080. def visit_PyClassDefNode(self, node):
  2081. if 'cclass' in self.directives:
  2082. node = node.as_cclass()
  2083. return self.visit(node)
  2084. else:
  2085. old_in_pyclass = self.in_py_class
  2086. self.in_py_class = True
  2087. self.visitchildren(node)
  2088. self.in_py_class = old_in_pyclass
  2089. return node
  2090. def visit_CClassDefNode(self, node):
  2091. old_in_pyclass = self.in_py_class
  2092. self.in_py_class = False
  2093. self.visitchildren(node)
  2094. self.in_py_class = old_in_pyclass
  2095. return node
  2096. class AlignFunctionDefinitions(CythonTransform):
  2097. """
  2098. This class takes the signatures from a .pxd file and applies them to
  2099. the def methods in a .py file.
  2100. """
  2101. def visit_ModuleNode(self, node):
  2102. self.scope = node.scope
  2103. self.directives = node.directives
  2104. self.imported_names = set() # hack, see visit_FromImportStatNode()
  2105. self.visitchildren(node)
  2106. return node
  2107. def visit_PyClassDefNode(self, node):
  2108. pxd_def = self.scope.lookup(node.name)
  2109. if pxd_def:
  2110. if pxd_def.is_cclass:
  2111. return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
  2112. elif not pxd_def.scope or not pxd_def.scope.is_builtin_scope:
  2113. error(node.pos, "'%s' redeclared" % node.name)
  2114. if pxd_def.pos:
  2115. error(pxd_def.pos, "previous declaration here")
  2116. return None
  2117. return node
  2118. def visit_CClassDefNode(self, node, pxd_def=None):
  2119. if pxd_def is None:
  2120. pxd_def = self.scope.lookup(node.class_name)
  2121. if pxd_def:
  2122. if not pxd_def.defined_in_pxd:
  2123. return node
  2124. outer_scope = self.scope
  2125. self.scope = pxd_def.type.scope
  2126. self.visitchildren(node)
  2127. if pxd_def:
  2128. self.scope = outer_scope
  2129. return node
  2130. def visit_DefNode(self, node):
  2131. pxd_def = self.scope.lookup(node.name)
  2132. if pxd_def and (not pxd_def.scope or not pxd_def.scope.is_builtin_scope):
  2133. if not pxd_def.is_cfunction:
  2134. error(node.pos, "'%s' redeclared" % node.name)
  2135. if pxd_def.pos:
  2136. error(pxd_def.pos, "previous declaration here")
  2137. return None
  2138. node = node.as_cfunction(pxd_def)
  2139. elif (self.scope.is_module_scope and self.directives['auto_cpdef']
  2140. and not node.name in self.imported_names
  2141. and node.is_cdef_func_compatible()):
  2142. # FIXME: cpdef-ing should be done in analyse_declarations()
  2143. node = node.as_cfunction(scope=self.scope)
  2144. # Enable this when nested cdef functions are allowed.
  2145. # self.visitchildren(node)
  2146. return node
  2147. def visit_FromImportStatNode(self, node):
  2148. # hack to prevent conditional import fallback functions from
  2149. # being cdpef-ed (global Python variables currently conflict
  2150. # with imports)
  2151. if self.scope.is_module_scope:
  2152. for name, _ in node.items:
  2153. self.imported_names.add(name)
  2154. return node
  2155. def visit_ExprNode(self, node):
  2156. # ignore lambdas and everything else that appears in expressions
  2157. return node
  2158. class RemoveUnreachableCode(CythonTransform):
  2159. def visit_StatListNode(self, node):
  2160. if not self.current_directives['remove_unreachable']:
  2161. return node
  2162. self.visitchildren(node)
  2163. for idx, stat in enumerate(node.stats):
  2164. idx += 1
  2165. if stat.is_terminator:
  2166. if idx < len(node.stats):
  2167. if self.current_directives['warn.unreachable']:
  2168. warning(node.stats[idx].pos, "Unreachable code", 2)
  2169. node.stats = node.stats[:idx]
  2170. node.is_terminator = True
  2171. break
  2172. return node
  2173. def visit_IfClauseNode(self, node):
  2174. self.visitchildren(node)
  2175. if node.body.is_terminator:
  2176. node.is_terminator = True
  2177. return node
  2178. def visit_IfStatNode(self, node):
  2179. self.visitchildren(node)
  2180. if node.else_clause and node.else_clause.is_terminator:
  2181. for clause in node.if_clauses:
  2182. if not clause.is_terminator:
  2183. break
  2184. else:
  2185. node.is_terminator = True
  2186. return node
  2187. def visit_TryExceptStatNode(self, node):
  2188. self.visitchildren(node)
  2189. if node.body.is_terminator and node.else_clause:
  2190. if self.current_directives['warn.unreachable']:
  2191. warning(node.else_clause.pos, "Unreachable code", 2)
  2192. node.else_clause = None
  2193. return node
  2194. def visit_TryFinallyStatNode(self, node):
  2195. self.visitchildren(node)
  2196. if node.finally_clause.is_terminator:
  2197. node.is_terminator = True
  2198. return node
  2199. class YieldNodeCollector(TreeVisitor):
  2200. def __init__(self):
  2201. super(YieldNodeCollector, self).__init__()
  2202. self.yields = []
  2203. self.returns = []
  2204. self.finallys = []
  2205. self.excepts = []
  2206. self.has_return_value = False
  2207. self.has_yield = False
  2208. self.has_await = False
  2209. def visit_Node(self, node):
  2210. self.visitchildren(node)
  2211. def visit_YieldExprNode(self, node):
  2212. self.yields.append(node)
  2213. self.has_yield = True
  2214. self.visitchildren(node)
  2215. def visit_AwaitExprNode(self, node):
  2216. self.yields.append(node)
  2217. self.has_await = True
  2218. self.visitchildren(node)
  2219. def visit_ReturnStatNode(self, node):
  2220. self.visitchildren(node)
  2221. if node.value:
  2222. self.has_return_value = True
  2223. self.returns.append(node)
  2224. def visit_TryFinallyStatNode(self, node):
  2225. self.visitchildren(node)
  2226. self.finallys.append(node)
  2227. def visit_TryExceptStatNode(self, node):
  2228. self.visitchildren(node)
  2229. self.excepts.append(node)
  2230. def visit_ClassDefNode(self, node):
  2231. pass
  2232. def visit_FuncDefNode(self, node):
  2233. pass
  2234. def visit_LambdaNode(self, node):
  2235. pass
  2236. def visit_GeneratorExpressionNode(self, node):
  2237. pass
  2238. def visit_CArgDeclNode(self, node):
  2239. # do not look into annotations
  2240. # FIXME: support (yield) in default arguments (currently crashes)
  2241. pass
  2242. class MarkClosureVisitor(CythonTransform):
  2243. def visit_ModuleNode(self, node):
  2244. self.needs_closure = False
  2245. self.visitchildren(node)
  2246. return node
  2247. def visit_FuncDefNode(self, node):
  2248. self.needs_closure = False
  2249. self.visitchildren(node)
  2250. node.needs_closure = self.needs_closure
  2251. self.needs_closure = True
  2252. collector = YieldNodeCollector()
  2253. collector.visitchildren(node)
  2254. if node.is_async_def:
  2255. coroutine_type = Nodes.AsyncDefNode
  2256. if collector.has_yield:
  2257. coroutine_type = Nodes.AsyncGenNode
  2258. for yield_expr in collector.yields + collector.returns:
  2259. yield_expr.in_async_gen = True
  2260. elif self.current_directives['iterable_coroutine']:
  2261. coroutine_type = Nodes.IterableAsyncDefNode
  2262. elif collector.has_await:
  2263. found = next(y for y in collector.yields if y.is_await)
  2264. error(found.pos, "'await' not allowed in generators (use 'yield')")
  2265. return node
  2266. elif collector.has_yield:
  2267. coroutine_type = Nodes.GeneratorDefNode
  2268. else:
  2269. return node
  2270. for i, yield_expr in enumerate(collector.yields, 1):
  2271. yield_expr.label_num = i
  2272. for retnode in collector.returns + collector.finallys + collector.excepts:
  2273. retnode.in_generator = True
  2274. gbody = Nodes.GeneratorBodyDefNode(
  2275. pos=node.pos, name=node.name, body=node.body,
  2276. is_async_gen_body=node.is_async_def and collector.has_yield)
  2277. coroutine = coroutine_type(
  2278. pos=node.pos, name=node.name, args=node.args,
  2279. star_arg=node.star_arg, starstar_arg=node.starstar_arg,
  2280. doc=node.doc, decorators=node.decorators,
  2281. gbody=gbody, lambda_name=node.lambda_name,
  2282. return_type_annotation=node.return_type_annotation)
  2283. return coroutine
  2284. def visit_CFuncDefNode(self, node):
  2285. self.needs_closure = False
  2286. self.visitchildren(node)
  2287. node.needs_closure = self.needs_closure
  2288. self.needs_closure = True
  2289. if node.needs_closure and node.overridable:
  2290. error(node.pos, "closures inside cpdef functions not yet supported")
  2291. return node
  2292. def visit_LambdaNode(self, node):
  2293. self.needs_closure = False
  2294. self.visitchildren(node)
  2295. node.needs_closure = self.needs_closure
  2296. self.needs_closure = True
  2297. return node
  2298. def visit_ClassDefNode(self, node):
  2299. self.visitchildren(node)
  2300. self.needs_closure = True
  2301. return node
  2302. class CreateClosureClasses(CythonTransform):
  2303. # Output closure classes in module scope for all functions
  2304. # that really need it.
  2305. def __init__(self, context):
  2306. super(CreateClosureClasses, self).__init__(context)
  2307. self.path = []
  2308. self.in_lambda = False
  2309. def visit_ModuleNode(self, node):
  2310. self.module_scope = node.scope
  2311. self.visitchildren(node)
  2312. return node
  2313. def find_entries_used_in_closures(self, node):
  2314. from_closure = []
  2315. in_closure = []
  2316. for scope in node.local_scope.iter_local_scopes():
  2317. for name, entry in scope.entries.items():
  2318. if not name:
  2319. continue
  2320. if entry.from_closure:
  2321. from_closure.append((name, entry))
  2322. elif entry.in_closure:
  2323. in_closure.append((name, entry))
  2324. return from_closure, in_closure
  2325. def create_class_from_scope(self, node, target_module_scope, inner_node=None):
  2326. # move local variables into closure
  2327. if node.is_generator:
  2328. for scope in node.local_scope.iter_local_scopes():
  2329. for entry in scope.entries.values():
  2330. if not (entry.from_closure or entry.is_pyglobal or entry.is_cglobal):
  2331. entry.in_closure = True
  2332. from_closure, in_closure = self.find_entries_used_in_closures(node)
  2333. in_closure.sort()
  2334. # Now from the beginning
  2335. node.needs_closure = False
  2336. node.needs_outer_scope = False
  2337. func_scope = node.local_scope
  2338. cscope = node.entry.scope
  2339. while cscope.is_py_class_scope or cscope.is_c_class_scope:
  2340. cscope = cscope.outer_scope
  2341. if not from_closure and (self.path or inner_node):
  2342. if not inner_node:
  2343. if not node.py_cfunc_node:
  2344. raise InternalError("DefNode does not have assignment node")
  2345. inner_node = node.py_cfunc_node
  2346. inner_node.needs_self_code = False
  2347. node.needs_outer_scope = False
  2348. if node.is_generator:
  2349. pass
  2350. elif not in_closure and not from_closure:
  2351. return
  2352. elif not in_closure:
  2353. func_scope.is_passthrough = True
  2354. func_scope.scope_class = cscope.scope_class
  2355. node.needs_outer_scope = True
  2356. return
  2357. # entry.cname can contain periods (eg. a derived C method of a class).
  2358. # We want to use the cname as part of a C struct name, so we replace
  2359. # periods with double underscores.
  2360. as_name = '%s_%s' % (
  2361. target_module_scope.next_id(Naming.closure_class_prefix),
  2362. node.entry.cname.replace('.','__'))
  2363. entry = target_module_scope.declare_c_class(
  2364. name=as_name, pos=node.pos, defining=True,
  2365. implementing=True)
  2366. entry.type.is_final_type = True
  2367. func_scope.scope_class = entry
  2368. class_scope = entry.type.scope
  2369. class_scope.is_internal = True
  2370. class_scope.is_closure_class_scope = True
  2371. if node.is_async_def or node.is_generator:
  2372. # Generators need their closure intact during cleanup as they resume to handle GeneratorExit
  2373. class_scope.directives['no_gc_clear'] = True
  2374. if Options.closure_freelist_size:
  2375. class_scope.directives['freelist'] = Options.closure_freelist_size
  2376. if from_closure:
  2377. assert cscope.is_closure_scope
  2378. class_scope.declare_var(pos=node.pos,
  2379. name=Naming.outer_scope_cname,
  2380. cname=Naming.outer_scope_cname,
  2381. type=cscope.scope_class.type,
  2382. is_cdef=True)
  2383. node.needs_outer_scope = True
  2384. for name, entry in in_closure:
  2385. closure_entry = class_scope.declare_var(
  2386. pos=entry.pos,
  2387. name=entry.name if not entry.in_subscope else None,
  2388. cname=entry.cname,
  2389. type=entry.type,
  2390. is_cdef=True)
  2391. if entry.is_declared_generic:
  2392. closure_entry.is_declared_generic = 1
  2393. node.needs_closure = True
  2394. # Do it here because other classes are already checked
  2395. target_module_scope.check_c_class(func_scope.scope_class)
  2396. def visit_LambdaNode(self, node):
  2397. if not isinstance(node.def_node, Nodes.DefNode):
  2398. # fused function, an error has been previously issued
  2399. return node
  2400. was_in_lambda = self.in_lambda
  2401. self.in_lambda = True
  2402. self.create_class_from_scope(node.def_node, self.module_scope, node)
  2403. self.visitchildren(node)
  2404. self.in_lambda = was_in_lambda
  2405. return node
  2406. def visit_FuncDefNode(self, node):
  2407. if self.in_lambda:
  2408. self.visitchildren(node)
  2409. return node
  2410. if node.needs_closure or self.path:
  2411. self.create_class_from_scope(node, self.module_scope)
  2412. self.path.append(node)
  2413. self.visitchildren(node)
  2414. self.path.pop()
  2415. return node
  2416. def visit_GeneratorBodyDefNode(self, node):
  2417. self.visitchildren(node)
  2418. return node
  2419. def visit_CFuncDefNode(self, node):
  2420. if not node.overridable:
  2421. return self.visit_FuncDefNode(node)
  2422. else:
  2423. self.visitchildren(node)
  2424. return node
  2425. class InjectGilHandling(VisitorTransform, SkipDeclarations):
  2426. """
  2427. Allow certain Python operations inside of nogil blocks by implicitly acquiring the GIL.
  2428. Must run before the AnalyseDeclarationsTransform to make sure the GILStatNodes get
  2429. set up, parallel sections know that the GIL is acquired inside of them, etc.
  2430. """
  2431. def __call__(self, root):
  2432. self.nogil = False
  2433. return super(InjectGilHandling, self).__call__(root)
  2434. # special node handling
  2435. def visit_RaiseStatNode(self, node):
  2436. """Allow raising exceptions in nogil sections by wrapping them in a 'with gil' block."""
  2437. if self.nogil:
  2438. node = Nodes.GILStatNode(node.pos, state='gil', body=node)
  2439. return node
  2440. # further candidates:
  2441. # def visit_AssertStatNode(self, node):
  2442. # def visit_ReraiseStatNode(self, node):
  2443. # nogil tracking
  2444. def visit_GILStatNode(self, node):
  2445. was_nogil = self.nogil
  2446. self.nogil = (node.state == 'nogil')
  2447. self.visitchildren(node)
  2448. self.nogil = was_nogil
  2449. return node
  2450. def visit_CFuncDefNode(self, node):
  2451. was_nogil = self.nogil
  2452. if isinstance(node.declarator, Nodes.CFuncDeclaratorNode):
  2453. self.nogil = node.declarator.nogil and not node.declarator.with_gil
  2454. self.visitchildren(node)
  2455. self.nogil = was_nogil
  2456. return node
  2457. def visit_ParallelRangeNode(self, node):
  2458. was_nogil = self.nogil
  2459. self.nogil = node.nogil
  2460. self.visitchildren(node)
  2461. self.nogil = was_nogil
  2462. return node
  2463. def visit_ExprNode(self, node):
  2464. # No special GIL handling inside of expressions for now.
  2465. return node
  2466. visit_Node = VisitorTransform.recurse_to_children
class GilCheck(VisitorTransform):
    """
    Call `node.gil_check(env)` on each node to make sure we hold the
    GIL when we need it.  Raise an error when on Python operations
    inside a `nogil` environment.

    Additionally, raise exceptions for closely nested with gil or with nogil
    statements. The latter would abort Python.

    State tracked while walking the tree:
      - env_stack: the scopes entered so far (module scope first).
      - nogil: whether the code currently being visited runs without the GIL.
      - nogil_declarator_only: set while inside a function whose local scope
        is nogil; suppresses the redundant acquire/release errors.
    """

    def __call__(self, root):
        self.env_stack = [root.scope]
        self.nogil = False

        # True for 'cdef func() nogil:' functions, as the GIL may be held while
        # calling this function (thus contained 'nogil' blocks may be valid).
        self.nogil_declarator_only = False
        return super(GilCheck, self).__call__(root)

    def _visit_scoped_children(self, node, gil_state):
        """
        Visit the children of *node*. The attributes listed in
        ``node.outer_attrs`` are visited with the nogil state of the enclosing
        scope (when one exists). All other children are visited with
        *gil_state*. ``self.nogil`` is restored afterwards.
        """
        was_nogil = self.nogil
        outer_attrs = node.outer_attrs
        if outer_attrs and len(self.env_stack) > 1:
            # Outer attributes belong to the enclosing scope (env_stack[-2]).
            self.nogil = self.env_stack[-2].nogil
            self.visitchildren(node, outer_attrs)

        self.nogil = gil_state
        self.visitchildren(node, attrs=None, exclude=outer_attrs)
        self.nogil = was_nogil

    def visit_FuncDefNode(self, node):
        """Check a function body using the nogil state of its local scope."""
        self.env_stack.append(node.local_scope)
        inner_nogil = node.local_scope.nogil

        if inner_nogil:
            self.nogil_declarator_only = True

        if inner_nogil and node.nogil_check:
            node.nogil_check(node.local_scope)

        self._visit_scoped_children(node, inner_nogil)

        # This cannot be nested, so it doesn't need backup/restore
        self.nogil_declarator_only = False

        self.env_stack.pop()
        return node

    def visit_GILStatNode(self, node):
        """
        Check a 'with gil' / 'with nogil' block: report acquiring a GIL that
        is already held, or releasing one that is already released, unless
        inside a nogil-declared function. Then visit the body in the new state.
        """
        if self.nogil and node.nogil_check:
            node.nogil_check()

        was_nogil = self.nogil
        is_nogil = (node.state == 'nogil')

        if was_nogil == is_nogil and not self.nogil_declarator_only:
            if not was_nogil:
                error(node.pos, "Trying to acquire the GIL while it is "
                                "already held.")
            else:
                error(node.pos, "Trying to release the GIL while it was "
                                "previously released.")

        if isinstance(node.finally_clause, Nodes.StatListNode):
            # The finally clause of the GILStatNode is a GILExitNode,
            # which is wrapped in a StatListNode. Just unpack that.
            node.finally_clause, = node.finally_clause.stats

        self._visit_scoped_children(node, is_nogil)
        return node

    def visit_ParallelRangeNode(self, node):
        """
        A prange() with nogil=True is rewritten into a 'with nogil' block
        around itself (and checked as such). Otherwise it must already run
        without the GIL. If it does not, an error is reported and the node is
        removed (None is returned).
        """
        if node.nogil:
            node.nogil = False
            node = Nodes.GILStatNode(node.pos, state='nogil', body=node)
            return self.visit_GILStatNode(node)

        if not self.nogil:
            error(node.pos, "prange() can only be used without the GIL")
            # Forget about any GIL-related errors that may occur in the body
            return None

        node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

    def visit_ParallelWithBlockNode(self, node):
        """
        A parallel block must run without the GIL. Otherwise an error is
        reported and the node is removed (None is returned).
        """
        if not self.nogil:
            error(node.pos, "The parallel section may only be used without "
                            "the GIL")
            return None

        if node.nogil_check:
            # It does not currently implement this, but test for it anyway to
            # avoid potential future surprises
            node.nogil_check(self.env_stack[-1])

        self.visitchildren(node)
        return node

    def visit_TryFinallyStatNode(self, node):
        """
        Take care of try/finally statements in nogil code sections.

        Outside nogil sections (and for GILStatNode instances) this falls back
        to the generic visit_Node(). Inside them, the node's own nogil_check is
        disabled and it is flagged with ``is_try_finally_in_nogil``.
        """
        if not self.nogil or isinstance(node, Nodes.GILStatNode):
            return self.visit_Node(node)

        node.nogil_check = None
        node.is_try_finally_in_nogil = True
        self.visitchildren(node)
        return node

    def visit_Node(self, node):
        """
        Generic handler: run the node's nogil_check in nogil sections, visit
        its children, and mark it with ``in_nogil_context`` when the GIL is
        not held.
        """
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        if node.outer_attrs:
            self._visit_scoped_children(node, self.nogil)
        else:
            self.visitchildren(node)
        if self.nogil:
            node.in_nogil_context = True
        return node
  2564. class TransformBuiltinMethods(EnvTransform):
  2565. """
  2566. Replace Cython's own cython.* builtins by the corresponding tree nodes.
  2567. """
  2568. def visit_SingleAssignmentNode(self, node):
  2569. if node.declaration_only:
  2570. return None
  2571. else:
  2572. self.visitchildren(node)
  2573. return node
  2574. def visit_AttributeNode(self, node):
  2575. self.visitchildren(node)
  2576. return self.visit_cython_attribute(node)
  2577. def visit_NameNode(self, node):
  2578. return self.visit_cython_attribute(node)
  2579. def visit_cython_attribute(self, node):
  2580. attribute = node.as_cython_attribute()
  2581. if attribute:
  2582. if attribute == u'compiled':
  2583. node = ExprNodes.BoolNode(node.pos, value=True)
  2584. elif attribute == u'__version__':
  2585. from .. import __version__ as version
  2586. node = ExprNodes.StringNode(node.pos, value=EncodedString(version))
  2587. elif attribute == u'NULL':
  2588. node = ExprNodes.NullNode(node.pos)
  2589. elif attribute in (u'set', u'frozenset', u'staticmethod'):
  2590. node = ExprNodes.NameNode(node.pos, name=EncodedString(attribute),
  2591. entry=self.current_env().builtin_scope().lookup_here(attribute))
  2592. elif PyrexTypes.parse_basic_type(attribute):
  2593. pass
  2594. elif self.context.cython_scope.lookup_qualified_name(attribute):
  2595. pass
  2596. else:
  2597. error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
  2598. return node
  2599. def visit_ExecStatNode(self, node):
  2600. lenv = self.current_env()
  2601. self.visitchildren(node)
  2602. if len(node.args) == 1:
  2603. node.args.append(ExprNodes.GlobalsExprNode(node.pos))
  2604. if not lenv.is_module_scope:
  2605. node.args.append(
  2606. ExprNodes.LocalsExprNode(
  2607. node.pos, self.current_scope_node(), lenv))
  2608. return node
  2609. def _inject_locals(self, node, func_name):
  2610. # locals()/dir()/vars() builtins
  2611. lenv = self.current_env()
  2612. entry = lenv.lookup_here(func_name)
  2613. if entry:
  2614. # not the builtin
  2615. return node
  2616. pos = node.pos
  2617. if func_name in ('locals', 'vars'):
  2618. if func_name == 'locals' and len(node.args) > 0:
  2619. error(self.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d"
  2620. % len(node.args))
  2621. return node
  2622. elif func_name == 'vars':
  2623. if len(node.args) > 1:
  2624. error(self.pos, "Builtin 'vars()' called with wrong number of args, expected 0-1, got %d"
  2625. % len(node.args))
  2626. if len(node.args) > 0:
  2627. return node # nothing to do
  2628. return ExprNodes.LocalsExprNode(pos, self.current_scope_node(), lenv)
  2629. else: # dir()
  2630. if len(node.args) > 1:
  2631. error(self.pos, "Builtin 'dir()' called with wrong number of args, expected 0-1, got %d"
  2632. % len(node.args))
  2633. if len(node.args) > 0:
  2634. # optimised in Builtin.py
  2635. return node
  2636. if lenv.is_py_class_scope or lenv.is_module_scope:
  2637. if lenv.is_py_class_scope:
  2638. pyclass = self.current_scope_node()
  2639. locals_dict = ExprNodes.CloneNode(pyclass.dict)
  2640. else:
  2641. locals_dict = ExprNodes.GlobalsExprNode(pos)
  2642. return ExprNodes.SortedDictKeysNode(locals_dict)
  2643. local_names = sorted(var.name for var in lenv.entries.values() if var.name)
  2644. items = [ExprNodes.IdentifierStringNode(pos, value=var)
  2645. for var in local_names]
  2646. return ExprNodes.ListNode(pos, args=items)
  2647. def visit_PrimaryCmpNode(self, node):
  2648. # special case: for in/not-in test, we do not need to sort locals()
  2649. self.visitchildren(node)
  2650. if node.operator in 'not_in': # in/not_in
  2651. if isinstance(node.operand2, ExprNodes.SortedDictKeysNode):
  2652. arg = node.operand2.arg
  2653. if isinstance(arg, ExprNodes.NoneCheckNode):
  2654. arg = arg.arg
  2655. node.operand2 = arg
  2656. return node
  2657. def visit_CascadedCmpNode(self, node):
  2658. return self.visit_PrimaryCmpNode(node)
  2659. def _inject_eval(self, node, func_name):
  2660. lenv = self.current_env()
  2661. entry = lenv.lookup_here(func_name)
  2662. if entry or len(node.args) != 1:
  2663. return node
  2664. # Inject globals and locals
  2665. node.args.append(ExprNodes.GlobalsExprNode(node.pos))
  2666. if not lenv.is_module_scope:
  2667. node.args.append(
  2668. ExprNodes.LocalsExprNode(
  2669. node.pos, self.current_scope_node(), lenv))
  2670. return node
  2671. def _inject_super(self, node, func_name):
  2672. lenv = self.current_env()
  2673. entry = lenv.lookup_here(func_name)
  2674. if entry or node.args:
  2675. return node
  2676. # Inject no-args super
  2677. def_node = self.current_scope_node()
  2678. if (not isinstance(def_node, Nodes.DefNode) or not def_node.args or
  2679. len(self.env_stack) < 2):
  2680. return node
  2681. class_node, class_scope = self.env_stack[-2]
  2682. if class_scope.is_py_class_scope:
  2683. def_node.requires_classobj = True
  2684. class_node.class_cell.is_active = True
  2685. node.args = [
  2686. ExprNodes.ClassCellNode(
  2687. node.pos, is_generator=def_node.is_generator),
  2688. ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
  2689. ]
  2690. elif class_scope.is_c_class_scope:
  2691. node.args = [
  2692. ExprNodes.NameNode(
  2693. node.pos, name=class_node.scope.name,
  2694. entry=class_node.entry),
  2695. ExprNodes.NameNode(node.pos, name=def_node.args[0].name)
  2696. ]
  2697. return node
  2698. def visit_SimpleCallNode(self, node):
  2699. # cython.foo
  2700. function = node.function.as_cython_attribute()
  2701. if function:
  2702. if function in InterpretCompilerDirectives.unop_method_nodes:
  2703. if len(node.args) != 1:
  2704. error(node.function.pos, u"%s() takes exactly one argument" % function)
  2705. else:
  2706. node = InterpretCompilerDirectives.unop_method_nodes[function](
  2707. node.function.pos, operand=node.args[0])
  2708. elif function in InterpretCompilerDirectives.binop_method_nodes:
  2709. if len(node.args) != 2:
  2710. error(node.function.pos, u"%s() takes exactly two arguments" % function)
  2711. else:
  2712. node = InterpretCompilerDirectives.binop_method_nodes[function](
  2713. node.function.pos, operand1=node.args[0], operand2=node.args[1])
  2714. elif function == u'cast':
  2715. if len(node.args) != 2:
  2716. error(node.function.pos,
  2717. u"cast() takes exactly two arguments and an optional typecheck keyword")
  2718. else:
  2719. type = node.args[0].analyse_as_type(self.current_env())
  2720. if type:
  2721. node = ExprNodes.TypecastNode(
  2722. node.function.pos, type=type, operand=node.args[1], typecheck=False)
  2723. else:
  2724. error(node.args[0].pos, "Not a type")
  2725. elif function == u'sizeof':
  2726. if len(node.args) != 1:
  2727. error(node.function.pos, u"sizeof() takes exactly one argument")
  2728. else:
  2729. type = node.args[0].analyse_as_type(self.current_env())
  2730. if type:
  2731. node = ExprNodes.SizeofTypeNode(node.function.pos, arg_type=type)
  2732. else:
  2733. node = ExprNodes.SizeofVarNode(node.function.pos, operand=node.args[0])
  2734. elif function == 'cmod':
  2735. if len(node.args) != 2:
  2736. error(node.function.pos, u"cmod() takes exactly two arguments")
  2737. else:
  2738. node = ExprNodes.binop_node(node.function.pos, '%', node.args[0], node.args[1])
  2739. node.cdivision = True
  2740. elif function == 'cdiv':
  2741. if len(node.args) != 2:
  2742. error(node.function.pos, u"cdiv() takes exactly two arguments")
  2743. else:
  2744. node = ExprNodes.binop_node(node.function.pos, '/', node.args[0], node.args[1])
  2745. node.cdivision = True
  2746. elif function == u'set':
  2747. node.function = ExprNodes.NameNode(node.pos, name=EncodedString('set'))
  2748. elif function == u'staticmethod':
  2749. node.function = ExprNodes.NameNode(node.pos, name=EncodedString('staticmethod'))
  2750. elif self.context.cython_scope.lookup_qualified_name(function):
  2751. pass
  2752. else:
  2753. error(node.function.pos,
  2754. u"'%s' not a valid cython language construct" % function)
  2755. self.visitchildren(node)
  2756. if isinstance(node, ExprNodes.SimpleCallNode) and node.function.is_name:
  2757. func_name = node.function.name
  2758. if func_name in ('dir', 'locals', 'vars'):
  2759. return self._inject_locals(node, func_name)
  2760. if func_name == 'eval':
  2761. return self._inject_eval(node, func_name)
  2762. if func_name == 'super':
  2763. return self._inject_super(node, func_name)
  2764. return node
  2765. def visit_GeneralCallNode(self, node):
  2766. function = node.function.as_cython_attribute()
  2767. if function == u'cast':
  2768. # NOTE: assuming simple tuple/dict nodes for positional_args and keyword_args
  2769. args = node.positional_args.args
  2770. kwargs = node.keyword_args.compile_time_value(None)
  2771. if (len(args) != 2 or len(kwargs) > 1 or
  2772. (len(kwargs) == 1 and 'typecheck' not in kwargs)):
  2773. error(node.function.pos,
  2774. u"cast() takes exactly two arguments and an optional typecheck keyword")
  2775. else:
  2776. type = args[0].analyse_as_type(self.current_env())
  2777. if type:
  2778. typecheck = kwargs.get('typecheck', False)
  2779. node = ExprNodes.TypecastNode(
  2780. node.function.pos, type=type, operand=args[1], typecheck=typecheck)
  2781. else:
  2782. error(args[0].pos, "Not a type")
  2783. self.visitchildren(node)
  2784. return node
  2785. class ReplaceFusedTypeChecks(VisitorTransform):
  2786. """
  2787. This is not a transform in the pipeline. It is invoked on the specific
  2788. versions of a cdef function with fused argument types. It filters out any
  2789. type branches that don't match. e.g.
  2790. if fused_t is mytype:
  2791. ...
  2792. elif fused_t in other_fused_type:
  2793. ...
  2794. """
  2795. def __init__(self, local_scope):
  2796. super(ReplaceFusedTypeChecks, self).__init__()
  2797. self.local_scope = local_scope
  2798. # defer the import until now to avoid circular import time dependencies
  2799. from .Optimize import ConstantFolding
  2800. self.transform = ConstantFolding(reevaluate=True)
  2801. def visit_IfStatNode(self, node):
  2802. """
  2803. Filters out any if clauses with false compile time type check
  2804. expression.
  2805. """
  2806. self.visitchildren(node)
  2807. return self.transform(node)
  2808. def visit_PrimaryCmpNode(self, node):
  2809. with Errors.local_errors(ignore=True):
  2810. type1 = node.operand1.analyse_as_type(self.local_scope)
  2811. type2 = node.operand2.analyse_as_type(self.local_scope)
  2812. if type1 and type2:
  2813. false_node = ExprNodes.BoolNode(node.pos, value=False)
  2814. true_node = ExprNodes.BoolNode(node.pos, value=True)
  2815. type1 = self.specialize_type(type1, node.operand1.pos)
  2816. op = node.operator
  2817. if op in ('is', 'is_not', '==', '!='):
  2818. type2 = self.specialize_type(type2, node.operand2.pos)
  2819. is_same = type1.same_as(type2)
  2820. eq = op in ('is', '==')
  2821. if (is_same and eq) or (not is_same and not eq):
  2822. return true_node
  2823. elif op in ('in', 'not_in'):
  2824. # We have to do an instance check directly, as operand2
  2825. # needs to be a fused type and not a type with a subtype
  2826. # that is fused. First unpack the typedef
  2827. if isinstance(type2, PyrexTypes.CTypedefType):
  2828. type2 = type2.typedef_base_type
  2829. if type1.is_fused:
  2830. error(node.operand1.pos, "Type is fused")
  2831. elif not type2.is_fused:
  2832. error(node.operand2.pos,
  2833. "Can only use 'in' or 'not in' on a fused type")
  2834. else:
  2835. types = PyrexTypes.get_specialized_types(type2)
  2836. for specialized_type in types:
  2837. if type1.same_as(specialized_type):
  2838. if op == 'in':
  2839. return true_node
  2840. else:
  2841. return false_node
  2842. if op == 'not_in':
  2843. return true_node
  2844. return false_node
  2845. return node
  2846. def specialize_type(self, type, pos):
  2847. try:
  2848. return type.specialize(self.local_scope.fused_to_specific)
  2849. except KeyError:
  2850. error(pos, "Type is not specific")
  2851. return type
  2852. def visit_Node(self, node):
  2853. self.visitchildren(node)
  2854. return node
class DebugTransform(CythonTransform):
    """
    Write debug information for this Cython module.

    Output goes to ``self.context.gdb_debug_outputwriter`` (the tree builder
    ``self.tb``). visit_ModuleNode opens a 'Module' element and writes a
    'Functions' section with one 'Function' element per function (locals,
    arguments and called C functions), followed by a 'Globals' section. The
    'Module' element itself is not closed here.
    """

    def __init__(self, context, options, result):
        super(DebugTransform, self).__init__(context)
        # Qualified names of all function scopes encountered so far,
        # including wrappers and deferred nested functions. Entries with
        # these names are excluded from the 'Globals' section.
        self.visited = set()
        # our treebuilder and debug output writer
        # (see Cython.Debugger.debug_output.CythonDebugWriter)
        self.tb = self.context.gdb_debug_outputwriter
        #self.c_output_file = options.output_file
        # NOTE(review): 'options' is not used by this class.
        self.c_output_file = result.c_file

        # Closure support, basically treat nested functions as if the AST were
        # never nested
        self.nested_funcdefs = []

        # tells visit_NameNode whether it should register step-into functions
        self.register_stepinto = False

    def visit_ModuleNode(self, node):
        """Serialize the module: its functions, its init function(s) and its globals."""
        self.tb.module_name = node.full_module_name
        attrs = dict(
            module_name=node.full_module_name,
            filename=node.pos[0].filename,
            c_filename=self.c_output_file)

        self.tb.start('Module', attrs)

        # serialize functions
        self.tb.start('Functions')
        # First, serialize functions normally...
        self.visitchildren(node)

        # ... then, serialize nested functions
        # (visit_FuncDefNode may append further nested functions to this
        # list while it is being iterated, so deeper nesting is handled too)
        for nested_funcdef in self.nested_funcdefs:
            self.visit_FuncDefNode(nested_funcdef)

        self.register_stepinto = True
        self.serialize_modulenode_as_function(node)
        self.register_stepinto = False
        self.tb.end('Functions')

        # 2.3 compatibility. Serialize global variables
        self.tb.start('Globals')
        entries = {}

        # Skip function scopes already serialized, internal '__pyx_' names,
        # C functions and extension types.
        for k, v in node.scope.entries.items():
            if (v.qualified_name not in self.visited and not
                v.name.startswith('__pyx_') and not
                v.type.is_cfunction and not
                v.type.is_extension_type):
                entries[k]= v

        self.serialize_local_variables(entries)
        self.tb.end('Globals')
        # self.tb.end('Module') # end Module after the line number mapping in
        # Cython.Compiler.ModuleNode.ModuleNode._serialize_lineno_map
        return node

    def visit_FuncDefNode(self, node):
        """
        Serialize one function as a 'Function' element, unless it is a wrapper
        (skipped) or it is found while step-into functions are being
        registered (deferred to ``self.nested_funcdefs``).
        """
        self.visited.add(node.local_scope.qualified_name)

        if getattr(node, 'is_wrapper', False):
            return node

        if self.register_stepinto:
            self.nested_funcdefs.append(node)
            return node

        # node.entry.visibility = 'extern'
        if node.py_func is None:
            pf_cname = ''
        else:
            pf_cname = node.py_func.entry.func_cname

        attrs = dict(
            name=node.entry.name or getattr(node, 'name', '<unknown>'),
            cname=node.entry.func_cname,
            pf_cname=pf_cname,
            qualified_name=node.local_scope.qualified_name,
            lineno=str(node.pos[1]))

        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.local_scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        for arg in node.local_scope.arg_entries:
            self.tb.start(arg.name)
            self.tb.end(arg.name)
        self.tb.end('Arguments')

        # Calls found in the body are recorded by visit_NameNode; nested
        # function definitions are deferred by the branch above.
        self.tb.start('StepIntoFunctions')
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')
        self.tb.end('Function')

        return node

    def visit_NameNode(self, node):
        """
        While step-into registration is active, record each called C function
        (by its func_cname) as a 'StepIntoFunction' element.
        """
        if (self.register_stepinto and
            node.type is not None and
            node.type.is_cfunction and
            getattr(node, 'is_called', False) and
            node.entry.func_cname is not None):
            # don't check node.entry.in_cinclude, as 'cdef extern: ...'
            # declared functions are not 'in_cinclude'.
            # This means we will list called 'cdef' functions as
            # "step into functions", but this is not an issue as they will be
            # recognized as Cython functions anyway.
            attrs = dict(name=node.entry.func_cname)
            self.tb.start('StepIntoFunction', attrs=attrs)
            self.tb.end('StepIntoFunction')

        self.visitchildren(node)
        return node

    def serialize_modulenode_as_function(self, node):
        """
        Serialize the module-level code as a function so the debugger will know
        it's a "relevant frame" and it will know where to set the breakpoint
        for 'break modulename'.

        Two 'Function' elements are written, one named after the Python 2
        init function ('init<name>') and one after the Python 3 one
        ('PyInit_<name>').
        """
        name = node.full_module_name.rpartition('.')[-1]

        cname_py2 = 'init' + name
        cname_py3 = 'PyInit_' + name

        py2_attrs = dict(
            name=name,
            cname=cname_py2,
            pf_cname='',
            # Ignore the qualified_name, breakpoints should be set using
            # `cy break modulename:lineno` for module-level breakpoints.
            qualified_name='',
            lineno='1',
            is_initmodule_function="True",
        )

        py3_attrs = dict(py2_attrs, cname=cname_py3)

        self._serialize_modulenode_as_function(node, py2_attrs)
        self._serialize_modulenode_as_function(node, py3_attrs)

    def _serialize_modulenode_as_function(self, node, attrs):
        """
        Write one 'Function' element for the module body: module-scope
        entries as locals, no arguments, and the C functions called at
        module level as step-into functions.
        """
        self.tb.start('Function', attrs=attrs)

        self.tb.start('Locals')
        self.serialize_local_variables(node.scope.entries)
        self.tb.end('Locals')

        self.tb.start('Arguments')
        self.tb.end('Arguments')

        self.tb.start('StepIntoFunctions')
        # NOTE(review): module-level FuncDefNodes visited here are appended
        # to self.nested_funcdefs (register_stepinto is True), after that
        # list has already been processed -- apparently harmless; confirm.
        self.register_stepinto = True
        self.visitchildren(node)
        self.register_stepinto = False
        self.tb.end('StepIntoFunctions')

        self.tb.end('Function')

    def serialize_local_variables(self, entries):
        """
        Write a 'LocalVar' element for every entry in the *entries* dict that
        has a C name. The element carries the C expression used to access the
        variable (through the closure scope object for closure variables),
        its qualified name, whether it is a Python object, and its line
        number ('0' when the entry has no position).
        """
        for entry in entries.values():
            if not entry.cname:
                # not a local variable
                continue
            if entry.type.is_pyobject:
                vartype = 'PythonObject'
            else:
                vartype = 'CObject'

            if entry.from_closure:
                # We're dealing with a closure where a variable from an outer
                # scope is accessed, get it from the scope object.
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.outer_entry.cname)

                qname = '%s.%s.%s' % (entry.scope.outer_scope.qualified_name,
                                      entry.scope.name,
                                      entry.name)
            elif entry.in_closure:
                cname = '%s->%s' % (Naming.cur_scope_cname,
                                    entry.cname)
                qname = entry.qualified_name
            else:
                cname = entry.cname
                qname = entry.qualified_name

            if not entry.pos:
                # this happens for variables that are not in the user's code,
                # e.g. for the global __builtins__, __doc__, etc. We can just
                # set the lineno to 0 for those.
                lineno = '0'
            else:
                lineno = str(entry.pos[1])

            attrs = dict(
                name=entry.name,
                cname=cname,
                qualified_name=qname,
                type=vartype,
                lineno=lineno)

            self.tb.start('LocalVar', attrs)
            self.tb.end('LocalVar')
  3027. self.tb.end('LocalVar')