TestParseTreeTransforms.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import os.path
  2. import unittest
  3. from Cython.TestUtils import TransformTest
  4. from Cython.Compiler.ParseTreeTransforms import *
  5. from Cython.Compiler.ParseTreeTransforms import _calculate_pickle_checksums
  6. from Cython.Compiler.Nodes import *
  7. from Cython.Compiler import Main, Symtab
  8. class TestNormalizeTree(TransformTest):
  9. def test_parserbehaviour_is_what_we_coded_for(self):
  10. t = self.fragment(u"if x: y").root
  11. self.assertLines(u"""
  12. (root): StatListNode
  13. stats[0]: IfStatNode
  14. if_clauses[0]: IfClauseNode
  15. condition: NameNode
  16. body: ExprStatNode
  17. expr: NameNode
  18. """, self.treetypes(t))
  19. def test_wrap_singlestat(self):
  20. t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
  21. self.assertLines(u"""
  22. (root): StatListNode
  23. stats[0]: IfStatNode
  24. if_clauses[0]: IfClauseNode
  25. condition: NameNode
  26. body: StatListNode
  27. stats[0]: ExprStatNode
  28. expr: NameNode
  29. """, self.treetypes(t))
  30. def test_wrap_multistat(self):
  31. t = self.run_pipeline([NormalizeTree(None)], u"""
  32. if z:
  33. x
  34. y
  35. """)
  36. self.assertLines(u"""
  37. (root): StatListNode
  38. stats[0]: IfStatNode
  39. if_clauses[0]: IfClauseNode
  40. condition: NameNode
  41. body: StatListNode
  42. stats[0]: ExprStatNode
  43. expr: NameNode
  44. stats[1]: ExprStatNode
  45. expr: NameNode
  46. """, self.treetypes(t))
  47. def test_statinexpr(self):
  48. t = self.run_pipeline([NormalizeTree(None)], u"""
  49. a, b = x, y
  50. """)
  51. self.assertLines(u"""
  52. (root): StatListNode
  53. stats[0]: SingleAssignmentNode
  54. lhs: TupleNode
  55. args[0]: NameNode
  56. args[1]: NameNode
  57. rhs: TupleNode
  58. args[0]: NameNode
  59. args[1]: NameNode
  60. """, self.treetypes(t))
  61. def test_wrap_offagain(self):
  62. t = self.run_pipeline([NormalizeTree(None)], u"""
  63. x
  64. y
  65. if z:
  66. x
  67. """)
  68. self.assertLines(u"""
  69. (root): StatListNode
  70. stats[0]: ExprStatNode
  71. expr: NameNode
  72. stats[1]: ExprStatNode
  73. expr: NameNode
  74. stats[2]: IfStatNode
  75. if_clauses[0]: IfClauseNode
  76. condition: NameNode
  77. body: StatListNode
  78. stats[0]: ExprStatNode
  79. expr: NameNode
  80. """, self.treetypes(t))
  81. def test_pass_eliminated(self):
  82. t = self.run_pipeline([NormalizeTree(None)], u"pass")
  83. self.assertTrue(len(t.stats) == 0)
  84. class TestWithTransform(object): # (TransformTest): # Disabled!
  85. def test_simplified(self):
  86. t = self.run_pipeline([WithTransform(None)], u"""
  87. with x:
  88. y = z ** 3
  89. """)
  90. self.assertCode(u"""
  91. $0_0 = x
  92. $0_2 = $0_0.__exit__
  93. $0_0.__enter__()
  94. $0_1 = True
  95. try:
  96. try:
  97. $1_0 = None
  98. y = z ** 3
  99. except:
  100. $0_1 = False
  101. if (not $0_2($1_0)):
  102. raise
  103. finally:
  104. if $0_1:
  105. $0_2(None, None, None)
  106. """, t)
  107. def test_basic(self):
  108. t = self.run_pipeline([WithTransform(None)], u"""
  109. with x as y:
  110. y = z ** 3
  111. """)
  112. self.assertCode(u"""
  113. $0_0 = x
  114. $0_2 = $0_0.__exit__
  115. $0_3 = $0_0.__enter__()
  116. $0_1 = True
  117. try:
  118. try:
  119. $1_0 = None
  120. y = $0_3
  121. y = z ** 3
  122. except:
  123. $0_1 = False
  124. if (not $0_2($1_0)):
  125. raise
  126. finally:
  127. if $0_1:
  128. $0_2(None, None, None)
  129. """, t)
  130. class TestInterpretCompilerDirectives(TransformTest):
  131. """
  132. This class tests the parallel directives AST-rewriting and importing.
  133. """
  134. # Test the parallel directives (c)importing
  135. import_code = u"""
  136. cimport cython.parallel
  137. cimport cython.parallel as par
  138. from cython cimport parallel as par2
  139. from cython cimport parallel
  140. from cython.parallel cimport threadid as tid
  141. from cython.parallel cimport threadavailable as tavail
  142. from cython.parallel cimport prange
  143. """
  144. expected_directives_dict = {
  145. u'cython.parallel': u'cython.parallel',
  146. u'par': u'cython.parallel',
  147. u'par2': u'cython.parallel',
  148. u'parallel': u'cython.parallel',
  149. u"tid": u"cython.parallel.threadid",
  150. u"tavail": u"cython.parallel.threadavailable",
  151. u"prange": u"cython.parallel.prange",
  152. }
  153. def setUp(self):
  154. super(TestInterpretCompilerDirectives, self).setUp()
  155. compilation_options = Main.CompilationOptions(Main.default_options)
  156. ctx = compilation_options.create_context()
  157. transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
  158. transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
  159. self.pipeline = [transform]
  160. self.debug_exception_on_error = DebugFlags.debug_exception_on_error
  161. def tearDown(self):
  162. DebugFlags.debug_exception_on_error = self.debug_exception_on_error
  163. def test_parallel_directives_cimports(self):
  164. self.run_pipeline(self.pipeline, self.import_code)
  165. parallel_directives = self.pipeline[0].parallel_directives
  166. self.assertEqual(parallel_directives, self.expected_directives_dict)
  167. def test_parallel_directives_imports(self):
  168. self.run_pipeline(self.pipeline,
  169. self.import_code.replace(u'cimport', u'import'))
  170. parallel_directives = self.pipeline[0].parallel_directives
  171. self.assertEqual(parallel_directives, self.expected_directives_dict)
  172. # TODO: Re-enable once they're more robust.
  173. if False:
  174. from Cython.Debugger import DebugWriter
  175. from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
  176. else:
  177. # skip test, don't let it inherit unittest.TestCase
  178. DebuggerTestCase = object
  179. class TestDebugTransform(DebuggerTestCase):
  180. def elem_hasattrs(self, elem, attrs):
  181. return all(attr in elem.attrib for attr in attrs)
  182. def test_debug_info(self):
  183. try:
  184. assert os.path.exists(self.debug_dest)
  185. t = DebugWriter.etree.parse(self.debug_dest)
  186. # the xpath of the standard ElementTree is primitive, don't use
  187. # anything fancy
  188. L = list(t.find('/Module/Globals'))
  189. assert L
  190. xml_globals = dict((e.attrib['name'], e.attrib['type']) for e in L)
  191. self.assertEqual(len(L), len(xml_globals))
  192. L = list(t.find('/Module/Functions'))
  193. assert L
  194. xml_funcs = dict((e.attrib['qualified_name'], e) for e in L)
  195. self.assertEqual(len(L), len(xml_funcs))
  196. # test globals
  197. self.assertEqual('CObject', xml_globals.get('c_var'))
  198. self.assertEqual('PythonObject', xml_globals.get('python_var'))
  199. # test functions
  200. funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
  201. 'codefile.closure', 'codefile.inner')
  202. required_xml_attrs = 'name', 'cname', 'qualified_name'
  203. assert all(f in xml_funcs for f in funcnames)
  204. spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames]
  205. self.assertEqual(spam.attrib['name'], 'spam')
  206. self.assertNotEqual('spam', spam.attrib['cname'])
  207. assert self.elem_hasattrs(spam, required_xml_attrs)
  208. # test locals of functions
  209. spam_locals = list(spam.find('Locals'))
  210. assert spam_locals
  211. spam_locals.sort(key=lambda e: e.attrib['name'])
  212. names = [e.attrib['name'] for e in spam_locals]
  213. self.assertEqual(list('abcd'), names)
  214. assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)
  215. # test arguments of functions
  216. spam_arguments = list(spam.find('Arguments'))
  217. assert spam_arguments
  218. self.assertEqual(1, len(list(spam_arguments)))
  219. # test step-into functions
  220. step_into = spam.find('StepIntoFunctions')
  221. spam_stepinto = [x.attrib['name'] for x in step_into]
  222. assert spam_stepinto
  223. self.assertEqual(2, len(spam_stepinto))
  224. assert 'puts' in spam_stepinto
  225. assert 'some_c_function' in spam_stepinto
  226. except:
  227. f = open(self.debug_dest)
  228. try:
  229. print(f.read())
  230. finally:
  231. f.close()
  232. raise
  233. class TestAnalyseDeclarationsTransform(unittest.TestCase):
  234. def test_calculate_pickle_checksums(self):
  235. checksums = _calculate_pickle_checksums(['member1', 'member2', 'member3'])
  236. assert 2 <= len(checksums) <= 3, checksums # expecting ['0xc0af380' (MD5), '0x0c75bd4', '0xa7a7b94']
  237. if __name__ == "__main__":
  238. import unittest
  239. unittest.main()