FusedNode.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. from __future__ import absolute_import
  2. import copy
  3. from . import (ExprNodes, PyrexTypes, MemoryView,
  4. ParseTreeTransforms, StringEncoding, Errors,
  5. Naming)
  6. from .ExprNodes import CloneNode, ProxyNode, TupleNode
  7. from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode
  8. from ..Utils import OrderedSet
  9. class FusedCFuncDefNode(StatListNode):
  10. """
  11. This node replaces a function with fused arguments. It deep-copies the
  12. function for every permutation of fused types, and allocates a new local
  13. scope for it. It keeps track of the original function in self.node, and
  14. the entry of the original function in the symbol table is given the
  15. 'fused_cfunction' attribute which points back to us.
  16. Then when a function lookup occurs (to e.g. call it), the call can be
  17. dispatched to the right function.
  18. node FuncDefNode the original function
  19. nodes [FuncDefNode] list of copies of node with different specific types
  20. py_func DefNode the fused python function subscriptable from
  21. Python space
  22. __signatures__ A DictNode mapping signature specialization strings
  23. to PyCFunction nodes
  24. resulting_fused_function PyCFunction for the fused DefNode that delegates
  25. to specializations
  26. fused_func_assignment Assignment of the fused function to the function name
  27. defaults_tuple TupleNode of defaults (letting PyCFunctionNode build
  28. defaults would result in many different tuples)
  29. specialized_pycfuncs List of synthesized pycfunction nodes for the
  30. specializations
  31. code_object CodeObjectNode shared by all specializations and the
  32. fused function
  33. fused_compound_types All fused (compound) types (e.g. floating[:])
  34. """
  35. __signatures__ = None
  36. resulting_fused_function = None
  37. fused_func_assignment = None
  38. defaults_tuple = None
  39. decorators = None
  40. child_attrs = StatListNode.child_attrs + [
  41. '__signatures__', 'resulting_fused_function', 'fused_func_assignment']
  def __init__(self, node, env):
      """
      Specialize the fused function 'node' declared in scope 'env'.

      The function is copied by copy_def() when it is a DefNode and by
      copy_cdef() otherwise.  The copies end up in self.nodes, and the
      entry of the original function gets a 'fused_cfunction' attribute
      pointing back to this node.
      """
      super(FusedCFuncDefNode, self).__init__(node.pos)

      self.node = node
      self.nodes = []

      wraps_def = isinstance(self.node, DefNode)
      make_copies = self.copy_def if wraps_def else self.copy_cdef
      make_copies(env)

      # Sanity checks on the copies.  If anything fails here, it's a bug.
      for specialized in self.nodes:
          assert not specialized.entry.type.is_fused
          assert not specialized.local_scope.return_type.is_fused
          if node.return_type.is_fused:
              assert not specialized.return_type.is_fused
          if not wraps_def and specialized.cfunc_declarator.optional_arg_count:
              assert specialized.type.op_arg_struct

      node.entry.fused_cfunction = self

      # self.stats must be a separate list: AnalyseDeclarationsTransform
      # will prepend self.py_func to self.stats, while self.nodes should
      # only ever hold the specialized function nodes.
      self.stats = list(self.nodes)
  def copy_def(self, env):
      """
      Specialize the original def or lambda function.

      One deep copy of self.node is made per permutation of the fused
      base types of its arguments.  Each copy is analysed, gets its own
      local scope and a specialized entry in env.entries.  Copying stops
      early as soon as replace_fused_typechecks() reports errors.
      Finally the dispatching function is built and stored in
      self.py_func.
      """
      original = self.node

      compound_types = PyrexTypes.unique(
          [arg.type for arg in original.args if arg.type.is_fused])
      base_types = self._get_fused_base_types(compound_types)
      permutations = PyrexTypes.get_all_specialized_permutations(base_types)

      self.fused_compound_types = compound_types

      # The fused original is taken out of the scope's Python functions.
      if original.entry in env.pyfunc_entries:
          env.pyfunc_entries.remove(original.entry)

      for cname, fused_to_specific in permutations:
          specialized = copy.deepcopy(original)
          # keep signature object identity for special casing in
          # DefNode.analyse_declarations()
          specialized.entry.signature = original.entry.signature

          self._specialize_function_args(specialized.args, fused_to_specific)
          specialized.return_type = original.return_type.specialize(
              fused_to_specific)

          specialized.analyse_declarations(env)
          self.create_new_local_scope(specialized, env, fused_to_specific)
          self.specialize_copied_def(
              specialized, cname, original.entry,
              fused_to_specific, compound_types)

          PyrexTypes.specialize_entry(specialized.entry, cname)
          specialized.entry.used = True
          env.entries[specialized.entry.name] = specialized.entry

          if not self.replace_fused_typechecks(specialized):
              # Errors were reported; stop to avoid a flood of them.
              break

      self.orig_py_func = original
      self.py_func = self.make_fused_cpdef(original, env, is_def=True)
  def copy_cdef(self, env):
      """
      Specialize the original c(p)def function.

      One deep copy of self.node is made per permutation returned by the
      function type.  Entries already present in env.cfunc_entries with a
      matching cname and type are reused; all others are collected and
      spliced into env.cfunc_entries in place of the original entry.
      If the original had a Python wrapper (py_func), a dispatching
      function is built for it and stored in self.py_func.
      """
      original = self.node
      permutations = original.type.get_all_specialized_permutations()

      # Detach the Python wrapper so that deepcopy() below does not copy it.
      orig_py_func = original.py_func
      self.orig_py_func = orig_py_func
      original.py_func = None
      if orig_py_func:
          env.pyfunc_entries.remove(orig_py_func.entry)

      fused_types = original.type.get_fused_types()
      self.fused_compound_types = fused_types

      new_cfunc_entries = []
      for cname, fused_to_specific in permutations:
          copied_node = copy.deepcopy(original)

          # Make the types in our CFuncType specific.
          func_type = copied_node.type.specialize(fused_to_specific)
          entry = copied_node.entry
          func_type.specialize_entry(entry, cname)

          # Reuse existing Entries (e.g. from .pxd files).
          for known_entry in env.cfunc_entries:
              if (entry.cname == known_entry.cname
                      and func_type.same_as_resolved_type(known_entry.type)):
                  copied_node.entry = known_entry
                  if not known_entry.func_cname:
                      known_entry.func_cname = entry.func_cname
                  entry = known_entry
                  func_type = entry.type
                  break
          else:
              # No existing entry matched: this one is new to the scope.
              new_cfunc_entries.append(entry)

          copied_node.type = func_type
          entry.type = func_type
          func_type.entry = entry

          entry.used = (entry.used or
                        original.entry.defined_in_pxd or
                        env.is_c_class_scope or
                        entry.is_cmethod)

          declarator = original.cfunc_declarator
          if declarator.optional_arg_count:
              declarator.declare_optional_arg_struct(
                  func_type, env, fused_cname=cname)

          copied_node.return_type = func_type.return_type
          self.create_new_local_scope(copied_node, env, fused_to_specific)

          # Make the argument types in the CFuncDeclarator specific
          self._specialize_function_args(
              copied_node.cfunc_declarator.args, fused_to_specific)

          # If a cpdef, declare all specialized cpdefs (this
          # also calls analyse_declarations)
          copied_node.declare_cpdef_wrapper(env)
          if copied_node.py_func:
              env.pyfunc_entries.remove(copied_node.py_func.entry)
              self.specialize_copied_def(
                  copied_node.py_func, cname, original.entry.as_variable,
                  fused_to_specific, fused_types)

          if not self.replace_fused_typechecks(copied_node):
              # Errors were reported; stop to avoid a flood of them.
              break

      # replace old entry with new entries
      if original.entry in env.cfunc_entries:
          position = env.cfunc_entries.index(original.entry)
          env.cfunc_entries[position:position + 1] = new_cfunc_entries
      else:
          env.cfunc_entries.extend(new_cfunc_entries)

      if orig_py_func:
          self.py_func = self.make_fused_cpdef(
              orig_py_func, env, is_def=False)
      else:
          self.py_func = orig_py_func
  166. def _get_fused_base_types(self, fused_compound_types):
  167. """
  168. Get a list of unique basic fused types, from a list of
  169. (possibly) compound fused types.
  170. """
  171. base_types = []
  172. seen = set()
  173. for fused_type in fused_compound_types:
  174. fused_type.get_fused_types(result=base_types, seen=seen)
  175. return base_types
  176. def _specialize_function_args(self, args, fused_to_specific):
  177. for arg in args:
  178. if arg.type.is_fused:
  179. arg.type = arg.type.specialize(fused_to_specific)
  180. if arg.type.is_memoryviewslice:
  181. arg.type.validate_memslice_dtype(arg.pos)
  def create_new_local_scope(self, node, env, f2s):
      """
      Give the copied 'node' a fresh local scope and append the node to
      self.nodes.

      A new scope is needed because the arguments with the fused types
      are already declared in the copied local scope, whereas we need the
      specialized entries that analyse_declarations creates for each
      specialized version of the (CFunc)DefNode.

      f2s is a dict mapping each fused type to its specialized version;
      it is stored on the new scope as 'fused_to_specific'.
      """
      node.create_local_scope(env)
      scope = node.local_scope
      scope.fused_to_specific = f2s

      # The flag was copied over from the original function; clear it
      # so that the copy is not specialized again (stops recursion).
      node.has_fused_arguments = False

      self.nodes.append(node)
  def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
      """
      Specialize the copy of a DefNode.

      node                  the copied node
      cname                 the specialization cname
      py_entry              the entry of the original DefNode; its doc and
                            doc_cname are carried over to the copy
      f2s                   dict mapping fused types to specific types
      fused_compound_types  the fused (compound) types of the function

      Sets node.specialized_signature_string to the '|'-joined signature
      strings of the fused base types.
      """
      base_types = self._get_fused_base_types(fused_compound_types)
      node.specialized_signature_string = '|'.join([
          PyrexTypes.specialization_signature_string(base_type, f2s)
          for base_type in base_types
      ])

      entry = node.entry
      entry.pymethdef_cname = PyrexTypes.get_fused_cname(
          cname, entry.pymethdef_cname)
      entry.doc = py_entry.doc
      entry.doc_cname = py_entry.doc_cname
  def replace_fused_typechecks(self, copied_node):
      """
      Branch-prune fused type checks like

          if fused_t is int:
              ...

      by running ReplaceFusedTypeChecks over 'copied_node'.

      Returns True if the transform reported no new errors, and False if
      it did, in which case the caller should stop in order to prevent a
      flood of errors.
      """
      errors_before = Errors.num_errors
      prune = ParseTreeTransforms.ReplaceFusedTypeChecks(
          copied_node.local_scope)
      prune(copied_node)
      return Errors.num_errors <= errors_before
  225. def _fused_instance_checks(self, normal_types, pyx_code, env):
  226. """
  227. Generate Cython code for instance checks, matching an object to
  228. specialized types.
  229. """
  230. for specialized_type in normal_types:
  231. # all_numeric = all_numeric and specialized_type.is_numeric
  232. pyx_code.context.update(
  233. py_type_name=specialized_type.py_type_name(),
  234. specialized_type_name=specialized_type.specialization_string,
  235. )
  236. pyx_code.put_chunk(
  237. u"""
  238. if isinstance(arg, {{py_type_name}}):
  239. dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'; break
  240. """)
  241. def _dtype_name(self, dtype):
  242. name = str(dtype).replace('_', '__').replace(' ', '_')
  243. if dtype.is_typedef:
  244. name = Naming.fused_dtype_prefix + name
  245. return name
  246. def _dtype_type(self, dtype):
  247. if dtype.is_typedef:
  248. return self._dtype_name(dtype)
  249. return str(dtype)
  250. def _sizeof_dtype(self, dtype):
  251. if dtype.is_pyobject:
  252. return 'sizeof(void *)'
  253. else:
  254. return "sizeof(%s)" % self._dtype_type(dtype)
  255. def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
  256. "Setup some common cases to match dtypes against specializations"
  257. if pyx_code.indenter("if kind in b'iu':"):
  258. pyx_code.putln("pass")
  259. pyx_code.named_insertion_point("dtype_int")
  260. pyx_code.dedent()
  261. if pyx_code.indenter("elif kind == b'f':"):
  262. pyx_code.putln("pass")
  263. pyx_code.named_insertion_point("dtype_float")
  264. pyx_code.dedent()
  265. if pyx_code.indenter("elif kind == b'c':"):
  266. pyx_code.putln("pass")
  267. pyx_code.named_insertion_point("dtype_complex")
  268. pyx_code.dedent()
  269. if pyx_code.indenter("elif kind == b'O':"):
  270. pyx_code.putln("pass")
  271. pyx_code.named_insertion_point("dtype_object")
  272. pyx_code.dedent()
  273. match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
  274. no_match = "dest_sig[{{dest_sig_idx}}] = None"
  275. def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types, pythran_types):
  276. """
  277. Match a numpy dtype object to the individual specializations.
  278. """
  279. self._buffer_check_numpy_dtype_setup_cases(pyx_code)
  280. for specialized_type in pythran_types+specialized_buffer_types:
  281. final_type = specialized_type
  282. if specialized_type.is_pythran_expr:
  283. specialized_type = specialized_type.org_buffer
  284. dtype = specialized_type.dtype
  285. pyx_code.context.update(
  286. itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
  287. signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
  288. dtype=dtype,
  289. specialized_type_name=final_type.specialization_string)
  290. dtypes = [
  291. (dtype.is_int, pyx_code.dtype_int),
  292. (dtype.is_float, pyx_code.dtype_float),
  293. (dtype.is_complex, pyx_code.dtype_complex)
  294. ]
  295. for dtype_category, codewriter in dtypes:
  296. if dtype_category:
  297. cond = '{{itemsize_match}} and (<Py_ssize_t>arg.ndim) == %d' % (
  298. specialized_type.ndim,)
  299. if dtype.is_int:
  300. cond += ' and {{signed_match}}'
  301. if final_type.is_pythran_expr:
  302. cond += ' and arg_is_pythran_compatible'
  303. if codewriter.indenter("if %s:" % cond):
  304. #codewriter.putln("print 'buffer match found based on numpy dtype'")
  305. codewriter.putln(self.match)
  306. codewriter.putln("break")
  307. codewriter.dedent()
  def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                        specialized_type, env):
      """
      Generate the code that tries to coerce the object to a memoryview
      slice of 'specialized_type'.  This means obtaining a buffer and
      parsing the format string.

      The coercion function of a matching MemoryViewSliceType is declared
      in 'decl_code' and called from 'pyx_code'.  A successful coercion
      releases the slice again and records the match; a failed one just
      clears the error.

      TODO: separate buffer acquisition from format parsing
      """
      dtype = specialized_type.dtype

      # Buffer types carry no axes specification of their own.
      if specialized_type.is_buffer:
          axes = [('direct', 'strided')] * specialized_type.ndim
      else:
          axes = specialized_type.axes

      memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
      memslice_type.create_from_py_utility_code(env)

      pyx_code.context.update(
          coerce_from_py_func=memslice_type.from_py_function,
          dtype=dtype)
      decl_code.putln(
          "{{memviewslice_cname}} {{coerce_from_py_func}}(object, int)")

      pyx_code.context.update(
          specialized_type_name=specialized_type.specialization_string,
          sizeof_dtype=self._sizeof_dtype(dtype))

      coercion_template = u"""
          # try {{dtype}}
          if itemsize == -1 or itemsize == {{sizeof_dtype}}:
              memslice = {{coerce_from_py_func}}(arg, 0)
              if memslice.memview:
                  __PYX_XDEC_MEMVIEW(&memslice, 1)
                  # print 'found a match for the buffer through format parsing'
                  %s
                  break
              else:
                  __pyx_PyErr_Clear()
      """
      pyx_code.put_chunk(coercion_template % self.match)
  344. def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env):
  345. """
  346. Generate Cython code to match objects to buffer specializations.
  347. First try to get a numpy dtype object and match it against the individual
  348. specializations. If that fails, try naively to coerce the object
  349. to each specialization, which obtains the buffer each time and tries
  350. to match the format string.
  351. """
  352. # The first thing to find a match in this loop breaks out of the loop
  353. pyx_code.put_chunk(
  354. u"""
  355. """ + (u"arg_is_pythran_compatible = False" if pythran_types else u"") + u"""
  356. if ndarray is not None:
  357. if isinstance(arg, ndarray):
  358. dtype = arg.dtype
  359. """ + (u"arg_is_pythran_compatible = True" if pythran_types else u"") + u"""
  360. elif __pyx_memoryview_check(arg):
  361. arg_base = arg.base
  362. if isinstance(arg_base, ndarray):
  363. dtype = arg_base.dtype
  364. else:
  365. dtype = None
  366. else:
  367. dtype = None
  368. itemsize = -1
  369. if dtype is not None:
  370. itemsize = dtype.itemsize
  371. kind = ord(dtype.kind)
  372. dtype_signed = kind == 'i'
  373. """)
  374. pyx_code.indent(2)
  375. if pythran_types:
  376. pyx_code.put_chunk(
  377. u"""
  378. # Pythran only supports the endianness of the current compiler
  379. byteorder = dtype.byteorder
  380. if byteorder == "<" and not __Pyx_Is_Little_Endian():
  381. arg_is_pythran_compatible = False
  382. elif byteorder == ">" and __Pyx_Is_Little_Endian():
  383. arg_is_pythran_compatible = False
  384. if arg_is_pythran_compatible:
  385. cur_stride = itemsize
  386. shape = arg.shape
  387. strides = arg.strides
  388. for i in range(arg.ndim-1, -1, -1):
  389. if (<Py_ssize_t>strides[i]) != cur_stride:
  390. arg_is_pythran_compatible = False
  391. break
  392. cur_stride *= <Py_ssize_t> shape[i]
  393. else:
  394. arg_is_pythran_compatible = not (arg.flags.f_contiguous and (<Py_ssize_t>arg.ndim) > 1)
  395. """)
  396. pyx_code.named_insertion_point("numpy_dtype_checks")
  397. self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
  398. pyx_code.dedent(2)
  399. for specialized_type in buffer_types:
  400. self._buffer_parse_format_string_check(
  401. pyx_code, decl_code, specialized_type, env)
  402. def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
  403. """
  404. If we have any buffer specializations, write out some variable
  405. declarations and imports.
  406. """
  407. decl_code.put_chunk(
  408. u"""
  409. ctypedef struct {{memviewslice_cname}}:
  410. void *memview
  411. void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
  412. bint __pyx_memoryview_check(object)
  413. """)
  414. pyx_code.local_variable_declarations.put_chunk(
  415. u"""
  416. cdef {{memviewslice_cname}} memslice
  417. cdef Py_ssize_t itemsize
  418. cdef bint dtype_signed
  419. cdef char kind
  420. itemsize = -1
  421. """)
  422. if pythran_types:
  423. pyx_code.local_variable_declarations.put_chunk(u"""
  424. cdef bint arg_is_pythran_compatible
  425. cdef Py_ssize_t cur_stride
  426. """)
  427. pyx_code.imports.put_chunk(
  428. u"""
  429. cdef type ndarray
  430. ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable()
  431. """)
  432. seen_typedefs = set()
  433. seen_int_dtypes = set()
  434. for buffer_type in all_buffer_types:
  435. dtype = buffer_type.dtype
  436. dtype_name = self._dtype_name(dtype)
  437. if dtype.is_typedef:
  438. if dtype_name not in seen_typedefs:
  439. seen_typedefs.add(dtype_name)
  440. decl_code.putln(
  441. 'ctypedef %s %s "%s"' % (dtype.resolve(), dtype_name,
  442. dtype.empty_declaration_code()))
  443. if buffer_type.dtype.is_int:
  444. if str(dtype) not in seen_int_dtypes:
  445. seen_int_dtypes.add(str(dtype))
  446. pyx_code.context.update(dtype_name=dtype_name,
  447. dtype_type=self._dtype_type(dtype))
  448. pyx_code.local_variable_declarations.put_chunk(
  449. u"""
  450. cdef bint {{dtype_name}}_is_signed
  451. {{dtype_name}}_is_signed = not (<{{dtype_type}}> -1 > 0)
  452. """)
  def _split_fused_types(self, arg):
      """
      Specialize the fused type of 'arg' and split the specializations
      into normal types, buffer types and Pythran types.

      Returns a tuple (normal_types, buffer_types, pythran_types,
      has_object_fallback).  Among the specializations that have a Python
      type name, only the first one per name is kept; the name 'object'
      is not stored as a normal type but reported through
      has_object_fallback.
      """
      candidates = PyrexTypes.get_specialized_types(arg.type)

      # Prefer long over int, etc by sorting (see type classes in PyrexTypes.py)
      candidates.sort()

      normal_types = []
      buffer_types = []
      pythran_types = []
      has_object_fallback = False
      seen_py_type_names = set()

      for candidate in candidates:
          py_type_name = candidate.py_type_name()

          if not py_type_name:
              # No Python type name: sort by kind of buffer, if any.
              if candidate.is_pythran_expr:
                  pythran_types.append(candidate)
              elif candidate.is_buffer or candidate.is_memoryviewslice:
                  buffer_types.append(candidate)
              continue

          if py_type_name in seen_py_type_names:
              continue
          seen_py_type_names.add(py_type_name)

          if py_type_name == 'object':
              has_object_fallback = True
          else:
              normal_types.append(candidate)

      return normal_types, buffer_types, pythran_types, has_object_fallback
  478. def _unpack_argument(self, pyx_code):
  479. pyx_code.put_chunk(
  480. u"""
  481. # PROCESSING ARGUMENT {{arg_tuple_idx}}
  482. if {{arg_tuple_idx}} < len(<tuple>args):
  483. arg = (<tuple>args)[{{arg_tuple_idx}}]
  484. elif kwargs is not None and '{{arg.name}}' in <dict>kwargs:
  485. arg = (<dict>kwargs)['{{arg.name}}']
  486. else:
  487. {{if arg.default}}
  488. arg = (<tuple>defaults)[{{default_idx}}]
  489. {{else}}
  490. {{if arg_tuple_idx < min_positional_args}}
  491. raise TypeError("Expected at least %d argument%s, got %d" % (
  492. {{min_positional_args}}, {{'"s"' if min_positional_args != 1 else '""'}}, len(<tuple>args)))
  493. {{else}}
  494. raise TypeError("Missing keyword-only argument: '%s'" % "{{arg.default}}")
  495. {{endif}}
  496. {{endif}}
  497. """)
  def make_fused_cpdef(self, orig_py_func, env, is_def):
      """
      This creates the function that is indexable from Python and does
      runtime dispatch based on the argument types. The function gets the
      arg tuple and kwargs dict (or None) and the defaults tuple
      as arguments from the Binding Fused Function's tp_call.

      orig_py_func    the original Python-visible function; its entry name
                      is put into the template context
      env             the scope the dispatcher is analysed in
      is_def          whether the fused function is a def function, which
                      decides how 'min_positional_args' is counted

      Returns the DefNode of the generated '__pyx_fused_cpdef' function,
      with 'specialized_cpdefs' set to the specialized Python functions.

      Raises NotImplementedError for an argument whose type involves more
      than one fused base type.
      """
      from . import TreeFragment, Code, UtilityCode

      fused_types = self._get_fused_base_types([
          arg.type for arg in self.node.args if arg.type.is_fused])

      if is_def:
          min_positional_args = (
              self.node.num_required_args - self.node.num_required_kw_args)
      else:
          min_positional_args = sum(
              1 for arg in self.node.args if arg.default is None)

      context = {
          'memviewslice_cname': MemoryView.memviewslice_cname,
          'func_args': self.node.args,
          'n_fused': len(fused_types),
          'min_positional_args': min_positional_args,
          'name': orig_py_func.entry.name,
      }

      # Both writers share the same context dict.
      pyx_code = Code.PyxCodeWriter(context=context)
      decl_code = Code.PyxCodeWriter(context=context)
      decl_code.put_chunk(
          u"""
              cdef extern from *:
                  void __pyx_PyErr_Clear "PyErr_Clear" ()
                  type __Pyx_ImportNumPyArrayTypeIfAvailable()
                  int __Pyx_Is_Little_Endian()
          """)
      decl_code.indent()

      pyx_code.put_chunk(
          u"""
              def __pyx_fused_cpdef(signatures, args, kwargs, defaults):
                  # FIXME: use a typed signature - currently fails badly because
                  #        default arguments inherit the types we specify here!
                  dest_sig = [None] * {{n_fused}}
                  if kwargs is not None and not kwargs:
                      kwargs = None
                  cdef Py_ssize_t i
                  # instance check body
          """)

      pyx_code.indent()  # indent following code to function body
      pyx_code.named_insertion_point("imports")
      pyx_code.named_insertion_point("func_defs")
      pyx_code.named_insertion_point("local_variable_declarations")

      fused_index = 0
      default_idx = 0
      all_buffer_types = OrderedSet()
      # Pythran types of *all* arguments.  Passing only those of the last
      # processed argument would skip the Pythran variable declarations
      # whenever that argument happens to have none.
      all_pythran_types = []
      seen_fused_types = set()

      for i, arg in enumerate(self.node.args):
          if arg.type.is_fused:
              arg_fused_types = arg.type.get_fused_types()
              if len(arg_fused_types) > 1:
                  raise NotImplementedError("Determination of more than one fused base "
                                            "type per argument is not implemented.")
              fused_type = arg_fused_types[0]

              # Each fused base type is only matched once, on the first
              # argument that uses it.
              if fused_type not in seen_fused_types:
                  seen_fused_types.add(fused_type)

                  context.update(
                      arg_tuple_idx=i,
                      arg=arg,
                      dest_sig_idx=fused_index,
                      default_idx=default_idx,
                  )

                  normal_types, buffer_types, pythran_types, has_object_fallback = (
                      self._split_fused_types(arg))
                  self._unpack_argument(pyx_code)

                  # 'unrolled' loop, first match breaks out of it
                  if pyx_code.indenter("while 1:"):
                      if normal_types:
                          self._fused_instance_checks(normal_types, pyx_code, env)
                      if buffer_types or pythran_types:
                          env.use_utility_code(Code.UtilityCode.load_cached(
                              "IsLittleEndian", "ModuleSetupCode.c"))
                          self._buffer_checks(
                              buffer_types, pythran_types, pyx_code, decl_code, env)
                      if has_object_fallback:
                          pyx_code.context.update(specialized_type_name='object')
                          pyx_code.putln(self.match)
                      else:
                          pyx_code.putln(self.no_match)
                      pyx_code.putln("break")
                      pyx_code.dedent()

                  fused_index += 1
                  all_buffer_types.update(buffer_types)
                  all_buffer_types.update(ty.org_buffer for ty in pythran_types)
                  all_pythran_types.extend(pythran_types)

          if arg.default:
              default_idx += 1

      if all_buffer_types:
          self._buffer_declarations(
              pyx_code, decl_code, all_buffer_types, all_pythran_types)
          env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
          env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))

      pyx_code.put_chunk(
          u"""
              candidates = []
              for sig in <dict>signatures:
                  match_found = False
                  src_sig = sig.strip('()').split('|')
                  for i in range(len(dest_sig)):
                      dst_type = dest_sig[i]
                      if dst_type is not None:
                          if src_sig[i] == dst_type:
                              match_found = True
                          else:
                              match_found = False
                              break
                  if match_found:
                      candidates.append(sig)
              if not candidates:
                  raise TypeError("No matching signature found")
              elif len(candidates) > 1:
                  raise TypeError("Function call with ambiguous argument types")
              else:
                  return (<dict>signatures)[candidates[0]]
          """)

      fragment_code = pyx_code.getvalue()

      from .Optimize import ConstantFolding
      fragment = TreeFragment.TreeFragment(
          fragment_code, level='module', pipeline=[ConstantFolding()])
      ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
      UtilityCode.declare_declarations_in_scope(
          decl_code.getvalue(), env.global_scope())
      ast.scope = env
      # FIXME: for static methods of cdef classes, we build the wrong signature here: first arg becomes 'self'
      ast.analyse_declarations(env)
      py_func = ast.stats[-1]  # the DefNode
      self.fragment_scope = ast.scope

      if isinstance(self.node, DefNode):
          py_func.specialized_cpdefs = self.nodes[:]
      else:
          py_func.specialized_cpdefs = [n.py_func for n in self.nodes]

      return py_func
  def update_fused_defnode_entry(self, env):
      """
      Make the entry of the dispatching function (self.py_func) take the
      place of the original Python function's entry in 'env'.

      Naming and doc attributes are copied over from the original entry,
      the temporary '__pyx_fused_cpdef' entry is dropped, and all
      specializations are linked to the dispatching function.  Finally,
      the __signatures__ dict is synthesized and appended to self.stats.
      """
      copied_attributes = (
          'name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
          'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
          'scope'
      )
      wraps_def = isinstance(self.node, DefNode)

      entry = self.py_func.entry
      for attr_name in copied_attributes:
          setattr(entry, attr_name, getattr(self.orig_py_func.entry, attr_name))

      self.py_func.name = self.orig_py_func.name
      self.py_func.doc = self.orig_py_func.doc

      env.entries.pop('__pyx_fused_cpdef', None)
      if wraps_def:
          env.entries[entry.name] = entry
      else:
          # The C function keeps its entry; only its Python variable
          # is replaced.
          env.entries[entry.name].as_variable = entry

      env.pyfunc_entries.append(entry)

      self.py_func.entry.fused_cfunction = self
      for specialized in self.nodes:
          if wraps_def:
              specialized.fused_py_func = self.py_func
          else:
              specialized.py_func.fused_py_func = self.py_func
              specialized.entry.as_variable = entry

      self.synthesize_defnodes()
      self.stats.append(self.__signatures__)
  def analyse_expressions(self, env):
      """
      Analyse the expressions. Take care to only evaluate default arguments
      once and clone the result for all specializations.

      Each default value of the original function is analysed once and
      wrapped in a ProxyNode; the specializations, the defaults tuple and
      the fused function only receive CloneNodes of it.  Likewise, a
      single code object is shared by all synthesized PyCFunction nodes.
      Returns self.
      """
      # Complex specializations need their declaration utility code.
      for compound_type in self.fused_compound_types:
          for fused_type in compound_type.get_fused_types():
              for specific_type in fused_type.types:
                  if specific_type.is_complex:
                      specific_type.create_declaration_utility_code(env)

      if self.py_func:
          self.__signatures__ = self.__signatures__.analyse_expressions(env)
          self.py_func = self.py_func.analyse_expressions(env)
          self.resulting_fused_function = (
              self.resulting_fused_function.analyse_expressions(env))
          self.fused_func_assignment = (
              self.fused_func_assignment.analyse_expressions(env))

      # One slot per argument: a ProxyNode of the default, or None.
      defaults = []
      self.defaults = defaults
      for arg in self.node.args:
          if arg.default:
              arg.default = arg.default.analyse_expressions(env)
              defaults.append(ProxyNode(arg.default))
          else:
              defaults.append(None)

      for index, stat in enumerate(self.stats):
          stat = stat.analyse_expressions(env)
          self.stats[index] = stat
          if not isinstance(stat, FuncDefNode):
              continue
          for arg, default in zip(stat.args, defaults):
              if default is not None:
                  arg.default = CloneNode(default).coerce_to(arg.type, env)

      if self.py_func:
          clones = [CloneNode(default) for default in defaults if default]
          defaults_tuple = TupleNode(self.pos, args=clones)
          self.defaults_tuple = defaults_tuple
          defaults_tuple = defaults_tuple.analyse_types(
              env, skip_children=True).coerce_to_pyobject(env)
          self.defaults_tuple = defaults_tuple
          self.defaults_tuple = ProxyNode(defaults_tuple)
          self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

          fused_func = self.resulting_fused_function.arg
          fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
          fused_func.code_object = CloneNode(self.code_object)

          for index, pycfunc in enumerate(self.specialized_pycfuncs):
              pycfunc.code_object = CloneNode(self.code_object)
              pycfunc = pycfunc.analyse_types(env)
              self.specialized_pycfuncs[index] = pycfunc
              pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)

      return self
  def synthesize_defnodes(self):
      """
      Create the __signatures__ dict of PyCFunctionNode specializations.

      The dict maps the specialized signature string of each specialized
      Python function to a PyCFunctionNode built from it.  The
      PyCFunctionNodes are also stored in self.specialized_pycfuncs and
      flagged with 'is_specialization'.
      """
      # For c(p)def functions, the Python wrappers are what gets exposed.
      if isinstance(self.nodes[0], CFuncDefNode):
          py_nodes = [cfunc.py_func for cfunc in self.nodes]
      else:
          py_nodes = self.nodes

      signature_strings = [
          StringEncoding.EncodedString(py_node.specialized_signature_string)
          for py_node in py_nodes
      ]
      keys = [
          ExprNodes.StringNode(py_node.pos, value=signature)
          for py_node, signature in zip(py_nodes, signature_strings)
      ]
      pycfuncs = [
          ExprNodes.PyCFunctionNode.from_defnode(py_node, binding=True)
          for py_node in py_nodes
      ]

      self.__signatures__ = ExprNodes.DictNode.from_pairs(
          self.pos, zip(keys, pycfuncs))

      self.specialized_pycfuncs = pycfuncs
      for pycfunc in pycfuncs:
          pycfunc.is_specialization = True
  def generate_function_definitions(self, env, code):
      """
      Generate the function definitions of the fused function assignment
      (if there is a dispatching Python function) and of every function
      node in self.stats whose entry is marked as used.
      """
      if self.py_func:
          self.py_func.pymethdef_required = True
          self.fused_func_assignment.generate_function_definitions(env, code)

      for stat in self.stats:
          # Skip non-function statements and unused specializations.
          if not isinstance(stat, FuncDefNode) or not stat.entry.used:
              continue
          code.mark_pos(stat.pos)
          stat.generate_function_definitions(env, code)
  def generate_execution_code(self, code):
      """
      Generate the module execution code of the fused function.

      The shared default values, defaults tuple and code object are
      evaluated first, followed by all statements in self.stats.  If
      there is a __signatures__ dict, the fused function is created, gets
      the dict assigned and is bound to its name.  Finally, all shared
      results are disposed of.
      """
      # Note: all def function specialization are wrapped in PyCFunction
      # nodes in the self.__signatures__ dictnode.
      shared_defaults = [
          default for default in self.defaults if default is not None]

      for default in shared_defaults:
          default.generate_evaluation_code(code)

      if self.py_func:
          self.defaults_tuple.generate_evaluation_code(code)
          self.code_object.generate_evaluation_code(code)

      for stat in self.stats:
          code.mark_pos(stat.pos)
          if isinstance(stat, ExprNodes.ExprNode):
              stat.generate_evaluation_code(code)
          else:
              stat.generate_execution_code(code)

      if self.__signatures__:
          fused_function = self.resulting_fused_function
          signatures = self.__signatures__

          fused_function.generate_evaluation_code(code)

          code.putln(
              "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" % (
                  fused_function.result(), signatures.result()))
          code.put_giveref(signatures.result())
          signatures.generate_post_assignment_code(code)
          signatures.free_temps(code)

          self.fused_func_assignment.generate_execution_code(code)

          # Dispose of results
          for shared_node in (fused_function, self.defaults_tuple, self.code_object):
              shared_node.generate_disposal_code(code)
              shared_node.free_temps(code)

      for default in shared_defaults:
          default.generate_disposal_code(code)
          default.free_temps(code)
  def annotate(self, code):
      """Let every statement in self.stats annotate 'code'."""
      for child in self.stats:
          child.annotate(code)