TestParseTreeTransforms.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282
  1. import os
  2. from Cython.TestUtils import TransformTest
  3. from Cython.Compiler.ParseTreeTransforms import *
  4. from Cython.Compiler.Nodes import *
  5. from Cython.Compiler import Main, Symtab
  6. class TestNormalizeTree(TransformTest):
  7. def test_parserbehaviour_is_what_we_coded_for(self):
  8. t = self.fragment(u"if x: y").root
  9. self.assertLines(u"""
  10. (root): StatListNode
  11. stats[0]: IfStatNode
  12. if_clauses[0]: IfClauseNode
  13. condition: NameNode
  14. body: ExprStatNode
  15. expr: NameNode
  16. """, self.treetypes(t))
  17. def test_wrap_singlestat(self):
  18. t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
  19. self.assertLines(u"""
  20. (root): StatListNode
  21. stats[0]: IfStatNode
  22. if_clauses[0]: IfClauseNode
  23. condition: NameNode
  24. body: StatListNode
  25. stats[0]: ExprStatNode
  26. expr: NameNode
  27. """, self.treetypes(t))
  28. def test_wrap_multistat(self):
  29. t = self.run_pipeline([NormalizeTree(None)], u"""
  30. if z:
  31. x
  32. y
  33. """)
  34. self.assertLines(u"""
  35. (root): StatListNode
  36. stats[0]: IfStatNode
  37. if_clauses[0]: IfClauseNode
  38. condition: NameNode
  39. body: StatListNode
  40. stats[0]: ExprStatNode
  41. expr: NameNode
  42. stats[1]: ExprStatNode
  43. expr: NameNode
  44. """, self.treetypes(t))
  45. def test_statinexpr(self):
  46. t = self.run_pipeline([NormalizeTree(None)], u"""
  47. a, b = x, y
  48. """)
  49. self.assertLines(u"""
  50. (root): StatListNode
  51. stats[0]: SingleAssignmentNode
  52. lhs: TupleNode
  53. args[0]: NameNode
  54. args[1]: NameNode
  55. rhs: TupleNode
  56. args[0]: NameNode
  57. args[1]: NameNode
  58. """, self.treetypes(t))
  59. def test_wrap_offagain(self):
  60. t = self.run_pipeline([NormalizeTree(None)], u"""
  61. x
  62. y
  63. if z:
  64. x
  65. """)
  66. self.assertLines(u"""
  67. (root): StatListNode
  68. stats[0]: ExprStatNode
  69. expr: NameNode
  70. stats[1]: ExprStatNode
  71. expr: NameNode
  72. stats[2]: IfStatNode
  73. if_clauses[0]: IfClauseNode
  74. condition: NameNode
  75. body: StatListNode
  76. stats[0]: ExprStatNode
  77. expr: NameNode
  78. """, self.treetypes(t))
  79. def test_pass_eliminated(self):
  80. t = self.run_pipeline([NormalizeTree(None)], u"pass")
  81. self.assertTrue(len(t.stats) == 0)
  82. class TestWithTransform(object): # (TransformTest): # Disabled!
  83. def test_simplified(self):
  84. t = self.run_pipeline([WithTransform(None)], u"""
  85. with x:
  86. y = z ** 3
  87. """)
  88. self.assertCode(u"""
  89. $0_0 = x
  90. $0_2 = $0_0.__exit__
  91. $0_0.__enter__()
  92. $0_1 = True
  93. try:
  94. try:
  95. $1_0 = None
  96. y = z ** 3
  97. except:
  98. $0_1 = False
  99. if (not $0_2($1_0)):
  100. raise
  101. finally:
  102. if $0_1:
  103. $0_2(None, None, None)
  104. """, t)
  105. def test_basic(self):
  106. t = self.run_pipeline([WithTransform(None)], u"""
  107. with x as y:
  108. y = z ** 3
  109. """)
  110. self.assertCode(u"""
  111. $0_0 = x
  112. $0_2 = $0_0.__exit__
  113. $0_3 = $0_0.__enter__()
  114. $0_1 = True
  115. try:
  116. try:
  117. $1_0 = None
  118. y = $0_3
  119. y = z ** 3
  120. except:
  121. $0_1 = False
  122. if (not $0_2($1_0)):
  123. raise
  124. finally:
  125. if $0_1:
  126. $0_2(None, None, None)
  127. """, t)
  128. class TestInterpretCompilerDirectives(TransformTest):
  129. """
  130. This class tests the parallel directives AST-rewriting and importing.
  131. """
  132. # Test the parallel directives (c)importing
  133. import_code = u"""
  134. cimport cython.parallel
  135. cimport cython.parallel as par
  136. from cython cimport parallel as par2
  137. from cython cimport parallel
  138. from cython.parallel cimport threadid as tid
  139. from cython.parallel cimport threadavailable as tavail
  140. from cython.parallel cimport prange
  141. """
  142. expected_directives_dict = {
  143. u'cython.parallel': u'cython.parallel',
  144. u'par': u'cython.parallel',
  145. u'par2': u'cython.parallel',
  146. u'parallel': u'cython.parallel',
  147. u"tid": u"cython.parallel.threadid",
  148. u"tavail": u"cython.parallel.threadavailable",
  149. u"prange": u"cython.parallel.prange",
  150. }
  151. def setUp(self):
  152. super(TestInterpretCompilerDirectives, self).setUp()
  153. compilation_options = Main.CompilationOptions(Main.default_options)
  154. ctx = compilation_options.create_context()
  155. transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
  156. transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
  157. self.pipeline = [transform]
  158. self.debug_exception_on_error = DebugFlags.debug_exception_on_error
  159. def tearDown(self):
  160. DebugFlags.debug_exception_on_error = self.debug_exception_on_error
  161. def test_parallel_directives_cimports(self):
  162. self.run_pipeline(self.pipeline, self.import_code)
  163. parallel_directives = self.pipeline[0].parallel_directives
  164. self.assertEqual(parallel_directives, self.expected_directives_dict)
  165. def test_parallel_directives_imports(self):
  166. self.run_pipeline(self.pipeline,
  167. self.import_code.replace(u'cimport', u'import'))
  168. parallel_directives = self.pipeline[0].parallel_directives
  169. self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust.
# The constant-false condition is the switch that disables the debugger
# tests: the imports in the first branch never run, so DebugWriter stays
# undefined and DebuggerTestCase is bound to ``object`` below.
if False:
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # skip test, don't let it inherit unittest.TestCase
    DebuggerTestCase = object
  177. class TestDebugTransform(DebuggerTestCase):
  178. def elem_hasattrs(self, elem, attrs):
  179. return all(attr in elem.attrib for attr in attrs)
  180. def test_debug_info(self):
  181. try:
  182. assert os.path.exists(self.debug_dest)
  183. t = DebugWriter.etree.parse(self.debug_dest)
  184. # the xpath of the standard ElementTree is primitive, don't use
  185. # anything fancy
  186. L = list(t.find('/Module/Globals'))
  187. assert L
  188. xml_globals = dict((e.attrib['name'], e.attrib['type']) for e in L)
  189. self.assertEqual(len(L), len(xml_globals))
  190. L = list(t.find('/Module/Functions'))
  191. assert L
  192. xml_funcs = dict((e.attrib['qualified_name'], e) for e in L)
  193. self.assertEqual(len(L), len(xml_funcs))
  194. # test globals
  195. self.assertEqual('CObject', xml_globals.get('c_var'))
  196. self.assertEqual('PythonObject', xml_globals.get('python_var'))
  197. # test functions
  198. funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
  199. 'codefile.closure', 'codefile.inner')
  200. required_xml_attrs = 'name', 'cname', 'qualified_name'
  201. assert all(f in xml_funcs for f in funcnames)
  202. spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames]
  203. self.assertEqual(spam.attrib['name'], 'spam')
  204. self.assertNotEqual('spam', spam.attrib['cname'])
  205. assert self.elem_hasattrs(spam, required_xml_attrs)
  206. # test locals of functions
  207. spam_locals = list(spam.find('Locals'))
  208. assert spam_locals
  209. spam_locals.sort(key=lambda e: e.attrib['name'])
  210. names = [e.attrib['name'] for e in spam_locals]
  211. self.assertEqual(list('abcd'), names)
  212. assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)
  213. # test arguments of functions
  214. spam_arguments = list(spam.find('Arguments'))
  215. assert spam_arguments
  216. self.assertEqual(1, len(list(spam_arguments)))
  217. # test step-into functions
  218. step_into = spam.find('StepIntoFunctions')
  219. spam_stepinto = [x.attrib['name'] for x in step_into]
  220. assert spam_stepinto
  221. self.assertEqual(2, len(spam_stepinto))
  222. assert 'puts' in spam_stepinto
  223. assert 'some_c_function' in spam_stepinto
  224. except:
  225. f = open(self.debug_dest)
  226. try:
  227. print(f.read())
  228. finally:
  229. f.close()
  230. raise
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    import unittest
    unittest.main()