test_parser_tree.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # -*- coding: utf-8 # This file contains Unicode characters.
  2. from textwrap import dedent
  3. import pytest
  4. from parso import parse
  5. from parso.python import tree
  6. from parso.tree import search_ancestor
  7. class TestsFunctionAndLambdaParsing:
  8. FIXTURES = [
  9. ('def my_function(x, y, z) -> str:\n return x + y * z\n', {
  10. 'name': 'my_function',
  11. 'call_sig': 'my_function(x, y, z)',
  12. 'params': ['x', 'y', 'z'],
  13. 'annotation': "str",
  14. }),
  15. ('lambda x, y, z: x + y * z\n', {
  16. 'name': '<lambda>',
  17. 'call_sig': '<lambda>(x, y, z)',
  18. 'params': ['x', 'y', 'z'],
  19. }),
  20. ]
  21. @pytest.fixture(params=FIXTURES)
  22. def node(self, request):
  23. parsed = parse(dedent(request.param[0]), version='3.10')
  24. request.keywords['expected'] = request.param[1]
  25. child = parsed.children[0]
  26. if child.type == 'simple_stmt':
  27. child = child.children[0]
  28. return child
  29. @pytest.fixture()
  30. def expected(self, request, node):
  31. return request.keywords['expected']
  32. def test_name(self, node, expected):
  33. if node.type != 'lambdef':
  34. assert isinstance(node.name, tree.Name)
  35. assert node.name.value == expected['name']
  36. def test_params(self, node, expected):
  37. assert isinstance(node.get_params(), list)
  38. assert all(isinstance(x, tree.Param) for x in node.get_params())
  39. assert [str(x.name.value) for x in node.get_params()] == [x for x in expected['params']]
  40. def test_is_generator(self, node, expected):
  41. assert node.is_generator() is expected.get('is_generator', False)
  42. def test_yields(self, node, expected):
  43. assert node.is_generator() == expected.get('yields', False)
  44. def test_annotation(self, node, expected):
  45. expected_annotation = expected.get('annotation', None)
  46. if expected_annotation is None:
  47. assert node.annotation is None
  48. else:
  49. assert node.annotation.value == expected_annotation
  50. def test_end_pos_line(each_version):
  51. # jedi issue #150
  52. s = "x()\nx( )\nx( )\nx ( )\n"
  53. module = parse(s, version=each_version)
  54. for i, simple_stmt in enumerate(module.children[:-1]):
  55. expr_stmt = simple_stmt.children[0]
  56. assert expr_stmt.end_pos == (i + 1, i + 3)
  57. def test_default_param(each_version):
  58. func = parse('def x(foo=42): pass', version=each_version).children[0]
  59. param, = func.get_params()
  60. assert param.default.value == '42'
  61. assert param.annotation is None
  62. assert not param.star_count
  63. def test_annotation_param(each_version):
  64. func = parse('def x(foo: 3): pass', version=each_version).children[0]
  65. param, = func.get_params()
  66. assert param.default is None
  67. assert param.annotation.value == '3'
  68. assert not param.star_count
  69. def test_annotation_params(each_version):
  70. func = parse('def x(foo: 3, bar: 4): pass', version=each_version).children[0]
  71. param1, param2 = func.get_params()
  72. assert param1.default is None
  73. assert param1.annotation.value == '3'
  74. assert not param1.star_count
  75. assert param2.default is None
  76. assert param2.annotation.value == '4'
  77. assert not param2.star_count
  78. def test_default_and_annotation_param(each_version):
  79. func = parse('def x(foo:3=42): pass', version=each_version).children[0]
  80. param, = func.get_params()
  81. assert param.default.value == '42'
  82. assert param.annotation.value == '3'
  83. assert not param.star_count
  84. def get_yield_exprs(code, version):
  85. return list(parse(code, version=version).children[0].iter_yield_exprs())
  86. def get_return_stmts(code):
  87. return list(parse(code).children[0].iter_return_stmts())
  88. def get_raise_stmts(code, child):
  89. return list(parse(code).children[child].iter_raise_stmts())
  90. def test_yields(each_version):
  91. y, = get_yield_exprs('def x(): yield', each_version)
  92. assert y.value == 'yield'
  93. assert y.type == 'keyword'
  94. y, = get_yield_exprs('def x(): (yield 1)', each_version)
  95. assert y.type == 'yield_expr'
  96. y, = get_yield_exprs('def x(): [1, (yield)]', each_version)
  97. assert y.type == 'keyword'
  98. def test_yield_from():
  99. y, = get_yield_exprs('def x(): (yield from 1)', '3.8')
  100. assert y.type == 'yield_expr'
  101. def test_returns():
  102. r, = get_return_stmts('def x(): return')
  103. assert r.value == 'return'
  104. assert r.type == 'keyword'
  105. r, = get_return_stmts('def x(): return 1')
  106. assert r.type == 'return_stmt'
  107. def test_raises():
  108. code = """
  109. def single_function():
  110. raise Exception
  111. def top_function():
  112. def inner_function():
  113. raise NotImplementedError()
  114. inner_function()
  115. raise Exception
  116. def top_function_three():
  117. try:
  118. raise NotImplementedError()
  119. except NotImplementedError:
  120. pass
  121. raise Exception
  122. """
  123. r = get_raise_stmts(code, 0) # Lists in a simple Function
  124. assert len(list(r)) == 1
  125. r = get_raise_stmts(code, 1) # Doesn't Exceptions list in closures
  126. assert len(list(r)) == 1
  127. r = get_raise_stmts(code, 2) # Lists inside try-catch
  128. assert len(list(r)) == 2
  129. @pytest.mark.parametrize(
  130. 'code, name_index, is_definition, include_setitem', [
  131. ('x = 3', 0, True, False),
  132. ('x.y = 3', 0, False, False),
  133. ('x.y = 3', 1, True, False),
  134. ('x.y = u.v = z', 0, False, False),
  135. ('x.y = u.v = z', 1, True, False),
  136. ('x.y = u.v = z', 2, False, False),
  137. ('x.y = u.v, w = z', 3, True, False),
  138. ('x.y = u.v, w = z', 4, True, False),
  139. ('x.y = u.v, w = z', 5, False, False),
  140. ('x, y = z', 0, True, False),
  141. ('x, y = z', 1, True, False),
  142. ('x, y = z', 2, False, False),
  143. ('x, y = z', 2, False, False),
  144. ('x[0], y = z', 2, False, False),
  145. ('x[0] = z', 0, False, False),
  146. ('x[0], y = z', 0, False, False),
  147. ('x[0], y = z', 2, False, True),
  148. ('x[0] = z', 0, True, True),
  149. ('x[0], y = z', 0, True, True),
  150. ('x: int = z', 0, True, False),
  151. ('x: int = z', 1, False, False),
  152. ('x: int = z', 2, False, False),
  153. ('x: int', 0, True, False),
  154. ('x: int', 1, False, False),
  155. ]
  156. )
  157. def test_is_definition(code, name_index, is_definition, include_setitem):
  158. module = parse(code, version='3.8')
  159. name = module.get_first_leaf()
  160. while True:
  161. if name.type == 'name':
  162. if name_index == 0:
  163. break
  164. name_index -= 1
  165. name = name.get_next_leaf()
  166. assert name.is_definition(include_setitem=include_setitem) == is_definition
  167. def test_iter_funcdefs():
  168. code = dedent('''
  169. def normal(): ...
  170. async def asyn(): ...
  171. @dec
  172. def dec_normal(): ...
  173. @dec1
  174. @dec2
  175. async def dec_async(): ...
  176. def broken
  177. ''')
  178. module = parse(code, version='3.8')
  179. func_names = [f.name.value for f in module.iter_funcdefs()]
  180. assert func_names == ['normal', 'asyn', 'dec_normal', 'dec_async']
  181. def test_with_stmt_get_test_node_from_name():
  182. code = "with A as X.Y, B as (Z), C as Q[0], D as Q['foo']: pass"
  183. with_stmt = parse(code, version='3').children[0]
  184. tests = [
  185. with_stmt.get_test_node_from_name(name).value
  186. for name in with_stmt.get_defined_names(include_setitem=True)
  187. ]
  188. assert tests == ["A", "B", "C", "D"]
  189. sample_module = parse('x + y')
  190. sample_node = sample_module.children[0]
  191. sample_leaf = sample_node.children[0]
  192. @pytest.mark.parametrize(
  193. 'node,node_types,expected_ancestor', [
  194. (sample_module, ('file_input',), None),
  195. (sample_node, ('arith_expr',), None),
  196. (sample_node, ('file_input', 'eval_input'), sample_module),
  197. (sample_leaf, ('name',), None),
  198. (sample_leaf, ('arith_expr',), sample_node),
  199. (sample_leaf, ('file_input',), sample_module),
  200. (sample_leaf, ('file_input', 'arith_expr'), sample_node),
  201. (sample_leaf, ('shift_expr',), None),
  202. (sample_leaf, ('name', 'shift_expr',), None),
  203. (sample_leaf, (), None),
  204. ]
  205. )
  206. def test_search_ancestor(node, node_types, expected_ancestor):
  207. assert node.search_ancestor(*node_types) is expected_ancestor
  208. assert search_ancestor(node, *node_types) is expected_ancestor # deprecated