Optimize.py 205 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857
  1. from __future__ import absolute_import
  2. import re
  3. import sys
  4. import copy
  5. import codecs
  6. import itertools
  7. from . import TypeSlots
  8. from .ExprNodes import not_a_constant
  9. import cython
  10. cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object,
  11. Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
  12. UtilNodes=object, _py_int_types=object)
  13. if sys.version_info[0] >= 3:
  14. _py_int_types = int
  15. _py_string_types = (bytes, str)
  16. else:
  17. _py_int_types = (int, long)
  18. _py_string_types = (bytes, unicode)
  19. from . import Nodes
  20. from . import ExprNodes
  21. from . import PyrexTypes
  22. from . import Visitor
  23. from . import Builtin
  24. from . import UtilNodes
  25. from . import Options
  26. from .Code import UtilityCode, TempitaUtilityCode
  27. from .StringEncoding import EncodedString, bytes_literal, encoded_string
  28. from .Errors import error, warning
  29. from .ParseTreeTransforms import SkipDeclarations
  30. try:
  31. from __builtin__ import reduce
  32. except ImportError:
  33. from functools import reduce
  34. try:
  35. from __builtin__ import basestring
  36. except ImportError:
  37. basestring = str # Python 3
  38. def load_c_utility(name):
  39. return UtilityCode.load_cached(name, "Optimize.c")
  40. def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
  41. if isinstance(node, coercion_nodes):
  42. return node.arg
  43. return node
  44. def unwrap_node(node):
  45. while isinstance(node, UtilNodes.ResultRefNode):
  46. node = node.expression
  47. return node
  48. def is_common_value(a, b):
  49. a = unwrap_node(a)
  50. b = unwrap_node(b)
  51. if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
  52. return a.name == b.name
  53. if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
  54. return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
  55. return False
  56. def filter_none_node(node):
  57. if node is not None and node.constant_result is None:
  58. return None
  59. return node
  60. class _YieldNodeCollector(Visitor.TreeVisitor):
  61. """
  62. YieldExprNode finder for generator expressions.
  63. """
  64. def __init__(self):
  65. Visitor.TreeVisitor.__init__(self)
  66. self.yield_stat_nodes = {}
  67. self.yield_nodes = []
  68. visit_Node = Visitor.TreeVisitor.visitchildren
  69. def visit_YieldExprNode(self, node):
  70. self.yield_nodes.append(node)
  71. self.visitchildren(node)
  72. def visit_ExprStatNode(self, node):
  73. self.visitchildren(node)
  74. if node.expr in self.yield_nodes:
  75. self.yield_stat_nodes[node.expr] = node
  76. # everything below these nodes is out of scope:
  77. def visit_GeneratorExpressionNode(self, node):
  78. pass
  79. def visit_LambdaNode(self, node):
  80. pass
  81. def visit_FuncDefNode(self, node):
  82. pass
  83. def _find_single_yield_expression(node):
  84. yield_statements = _find_yield_statements(node)
  85. if len(yield_statements) != 1:
  86. return None, None
  87. return yield_statements[0]
  88. def _find_yield_statements(node):
  89. collector = _YieldNodeCollector()
  90. collector.visitchildren(node)
  91. try:
  92. yield_statements = [
  93. (yield_node.arg, collector.yield_stat_nodes[yield_node])
  94. for yield_node in collector.yield_nodes
  95. ]
  96. except KeyError:
  97. # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
  98. yield_statements = []
  99. return yield_statements
  100. class IterationTransform(Visitor.EnvTransform):
  101. """Transform some common for-in loop patterns into efficient C loops:
  102. - for-in-dict loop becomes a while loop calling PyDict_Next()
  103. - for-in-enumerate is replaced by an external counter variable
  104. - for-in-range loop becomes a plain C for loop
  105. """
    def visit_PrimaryCmpNode(self, node):
        """Expand a containment test for which ``node.is_ptr_contains()`` is
        true into an explicit search loop over the right operand.

        The replacement is a TempResultFromStatNode whose value is set by the
        loop sketched below; for the ``not_in`` operator it is wrapped in a
        NotNode.  Other comparisons only have their children visited.
        """
        if node.is_ptr_contains():

            # for t in operand2:
            #     if operand1 == t:
            #         res = True
            #         break
            # else:
            #     res = False

            pos = node.pos
            # holds the boolean result of the generated loop
            result_ref = UtilNodes.ResultRefNode(node)
            # element type of the iterated C type; for a subscript/slice operand
            # it is taken from the subscripted base expression instead
            if node.operand2.is_subscript:
                base_type = node.operand2.base.type.base_type
            else:
                base_type = node.operand2.type.base_type
            # temporary loop variable "t" of the element type
            target_handle = UtilNodes.TempHandle(base_type)
            target = target_handle.ref(pos)
            cmp_node = ExprNodes.PrimaryCmpNode(
                pos, operator=u'==', operand1=node.operand1, operand2=target)
            if_body = Nodes.StatListNode(
                pos,
                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                         Nodes.BreakStatNode(pos)])
            if_node = Nodes.IfStatNode(
                pos,
                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
                else_clause=None)
            for_loop = UtilNodes.TempsBlockNode(
                pos,
                temps = [target_handle],
                body = Nodes.ForInStatNode(
                    pos,
                    target=target,
                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                    body=if_node,
                    # loop finished without a match (no break) => False
                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
            # the new subtree is analysed in the current scope, then visited so
            # that this transform's loop optimisations also apply to it
            for_loop = for_loop.analyse_expressions(self.current_env())
            for_loop = self.visit(for_loop)
            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

            if node.operator == 'not_in':
                new_node = ExprNodes.NotNode(pos, operand=new_node)
            return new_node

        else:
            self.visitchildren(node)
            return node
  150. def visit_ForInStatNode(self, node):
  151. self.visitchildren(node)
  152. return self._optimise_for_loop(node, node.iterator.sequence)
    def _optimise_for_loop(self, node, iterable, reversed=False):
        """Try to replace the for-in loop ``node`` over ``iterable`` by a
        specialised loop; return ``node`` unchanged if no pattern applies.

        Checks run in this order and the first match wins:
          * dict (by type, or a ``Dict``/``typing.Dict`` annotation)
          * set/frozenset (by type, or a ``Set``/``FrozenSet`` annotation)
          * C pointer / C array
          * bytes, unicode
          * call patterns: ``.keys()/.values()/.items()`` and ``.iter*()``,
            ``enumerate()``, ``reversed()``, ``range()``/``xrange()``

        node: the ForInStatNode being optimised.
        iterable: the expression iterated over (the loop's sequence, or the
            argument of a ``reversed()`` call when called from
            _transform_reversed_iteration()).
        reversed: True for reverse iteration; iterables that CPython cannot
            reverse (dict, set, enumerate, nested reversed) are left unchanged.
        """
        # A container annotation on a name/attribute can identify a dict/set
        # even when the expression's own type does not.
        annotation_type = None
        if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
            annotation = iterable.entry.annotation
            if annotation.is_subscript:
                annotation = annotation.base  # container base type
            # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
            if annotation.is_name:
                # accept either the resolved 'typing.*' name or the bare name
                if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
                    annotation_type = Builtin.dict_type
                elif annotation.name == 'Dict':
                    annotation_type = Builtin.dict_type
                # FrozenSet annotations are mapped to set_type as well
                if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
                    annotation_type = Builtin.set_type
                elif annotation.name in ('Set', 'FrozenSet'):
                    annotation_type = Builtin.set_type

        if Builtin.dict_type in (iterable.type, annotation_type):
            # like iterating over dict.keys()
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_dict_iteration(
                node, dict_obj=iterable, method=None, keys=True, values=False)

        if (Builtin.set_type in (iterable.type, annotation_type) or
                Builtin.frozenset_type in (iterable.type, annotation_type)):
            if reversed:
                # CPython raises an error here: not a sequence
                return node
            return self._transform_set_iteration(node, iterable)

        # C array (slice) iteration?
        if iterable.type.is_ptr or iterable.type.is_array:
            return self._transform_carray_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.bytes_type:
            return self._transform_bytes_iteration(node, iterable, reversed=reversed)
        if iterable.type is Builtin.unicode_type:
            return self._transform_unicode_iteration(node, iterable, reversed=reversed)

        # the rest is based on function calls
        if not isinstance(iterable, ExprNodes.SimpleCallNode):
            return node

        # Number of call arguments, from either ``arg_tuple`` or ``args``;
        # a bound ``self`` included in ``args`` is not counted.
        # NOTE(review): `x and len(...) or 0` yields 0 when arg_tuple is None.
        if iterable.args is None:
            arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0
        else:
            arg_count = len(iterable.args)
            if arg_count and iterable.self is not None:
                arg_count -= 1

        function = iterable.function
        # dict iteration?
        if function.is_attribute and not reversed and not arg_count:
            base_obj = iterable.self or function.obj
            method = function.attribute
            # in Py3, items() is equivalent to Py2's iteritems()
            is_safe_iter = self.global_scope().context.language_level >= 3

            if not is_safe_iter and method in ('keys', 'values', 'items'):
                # try to reduce this to the corresponding .iter*() methods
                if isinstance(base_obj, ExprNodes.CallNode):
                    inner_function = base_obj.function
                    if (inner_function.is_name and inner_function.name == 'dict'
                            and inner_function.entry
                            and inner_function.entry.is_builtin):
                        # e.g. dict(something).items() => safe to use .iter*()
                        is_safe_iter = True

            # decide which parts of each dict item the loop needs
            keys = values = False
            if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_safe_iter and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_safe_iter and method == 'items'):
                keys = values = True

            if keys or values:
                return self._transform_dict_iteration(
                    node, base_obj, method, keys, values)

        # enumerate/reversed ?
        if iterable.self is None and function.is_name and \
               function.entry and function.entry.is_builtin:
            if function.name == 'enumerate':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_enumerate_iteration(node, iterable)
            elif function.name == 'reversed':
                if reversed:
                    # CPython raises an error here: not a sequence
                    return node
                return self._transform_reversed_iteration(node, iterable)

        # range() iteration?
        if Options.convert_range and 1 <= arg_count <= 3 and (
                iterable.self is None and
                function.is_name and function.name in ('range', 'xrange') and
                function.entry and function.entry.is_builtin):
            if node.target.type.is_int or node.target.type.is_enum:
                return self._transform_range_iteration(node, iterable, reversed=reversed)
            if node.target.type.is_pyobject:
                # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
                # Only convert when every argument is an int literal in [-2**30, 2**30);
                # the for-else runs only if no argument triggered the break.
                for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args):
                    if isinstance(arg, ExprNodes.IntNode):
                        if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30:
                            continue
                    break
                else:
                    return self._transform_range_iteration(node, iterable, reversed=reversed)

        return node
  254. def _transform_reversed_iteration(self, node, reversed_function):
  255. args = reversed_function.arg_tuple.args
  256. if len(args) == 0:
  257. error(reversed_function.pos,
  258. "reversed() requires an iterable argument")
  259. return node
  260. elif len(args) > 1:
  261. error(reversed_function.pos,
  262. "reversed() takes exactly 1 argument")
  263. return node
  264. arg = args[0]
  265. # reversed(list/tuple) ?
  266. if arg.type in (Builtin.tuple_type, Builtin.list_type):
  267. node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
  268. node.iterator.reversed = True
  269. return node
  270. return self._optimise_for_loop(node, arg, reversed=True)
  271. PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
  272. PyrexTypes.c_char_ptr_type, [
  273. PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
  274. ])
  275. PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
  276. PyrexTypes.c_py_ssize_t_type, [
  277. PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
  278. ])
  279. def _transform_bytes_iteration(self, node, slice_node, reversed=False):
  280. target_type = node.target.type
  281. if not target_type.is_int and target_type is not Builtin.bytes_type:
  282. # bytes iteration returns bytes objects in Py2, but
  283. # integers in Py3
  284. return node
  285. unpack_temp_node = UtilNodes.LetRefNode(
  286. slice_node.as_none_safe_node("'NoneType' is not iterable"))
  287. slice_base_node = ExprNodes.PythonCapiCallNode(
  288. slice_node.pos, "PyBytes_AS_STRING",
  289. self.PyBytes_AS_STRING_func_type,
  290. args = [unpack_temp_node],
  291. is_temp = 0,
  292. )
  293. len_node = ExprNodes.PythonCapiCallNode(
  294. slice_node.pos, "PyBytes_GET_SIZE",
  295. self.PyBytes_GET_SIZE_func_type,
  296. args = [unpack_temp_node],
  297. is_temp = 0,
  298. )
  299. return UtilNodes.LetNode(
  300. unpack_temp_node,
  301. self._transform_carray_iteration(
  302. node,
  303. ExprNodes.SliceIndexNode(
  304. slice_node.pos,
  305. base = slice_base_node,
  306. start = None,
  307. step = None,
  308. stop = len_node,
  309. type = slice_base_node.type,
  310. is_temp = 1,
  311. ),
  312. reversed = reversed))
  # C function type for the "__Pyx_PyUnicode_READ" call built in
  # _transform_unicode_iteration: (kind, data, index) -> Py_UCS4.
  # Presumably reads the code point at ``index`` from a unicode buffer
  # described by (kind, data) -- TODO confirm against the C helper.
  PyUnicode_READ_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_py_ucs4_type, [
          PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
          PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
          PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
      ])

  # C function type for "__Pyx_init_unicode_iteration" (utility code
  # "unicode_iter" in Optimize.c): takes the unicode object plus pointers to
  # length/data/kind temps, presumably filling them as out-parameters.
  # Returns a C int where -1 signals an exception.
  init_unicode_iteration_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_int_type, [
          PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
          PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
          PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
      ],
      exception_value = '-1')
  def _transform_unicode_iteration(self, node, slice_node, reversed=False):
      """
      Rewrite a for-loop over a unicode string into an indexed C loop.

      If ``slice_node`` is a literal that encodes to Latin-1, the loop is
      reduced to C array iteration over the encoded bytes.  Otherwise the
      string is evaluated once into a LetRefNode, __Pyx_init_unicode_iteration
      fills length/data/kind temps, and a ForFromStatNode counts an index
      over [0, length) (downwards when ``reversed``), assigning
      __Pyx_PyUnicode_READ(kind, data, index) to the loop target each time.
      """
      if slice_node.is_literal:
          # try to reduce to byte iteration for plain Latin-1 strings
          try:
              bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
          except UnicodeEncodeError:
              pass
          else:
              bytes_slice = ExprNodes.SliceIndexNode(
                  slice_node.pos,
                  base=ExprNodes.BytesNode(
                      slice_node.pos, value=bytes_value,
                      constant_result=bytes_value,
                      type=PyrexTypes.c_const_char_ptr_type).coerce_to(
                          PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
                  start=None,
                  stop=ExprNodes.IntNode(
                      slice_node.pos, value=str(len(bytes_value)),
                      constant_result=len(bytes_value),
                      type=PyrexTypes.c_py_ssize_t_type),
                  type=Builtin.unicode_type, # hint for Python conversion
              )
              return self._transform_carray_iteration(node, bytes_slice, reversed)

      # Non-literal (or non-Latin-1) string: evaluate it once, None-checked.
      unpack_temp_node = UtilNodes.LetRefNode(
          slice_node.as_none_safe_node("'NoneType' is not iterable"))

      start_node = ExprNodes.IntNode(
          node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
      length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
      end_node = length_temp.ref(node.pos)
      if reversed:
          # "length > i >= 0": count down from the end.
          relation1, relation2 = '>', '>='
          start_node, end_node = end_node, start_node
      else:
          # "0 <= i < length"
          relation1, relation2 = '<=', '<'

      kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
      data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
      counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

      # Per-iteration character read, coerced to the target's type if needed.
      target_value = ExprNodes.PythonCapiCallNode(
          slice_node.pos, "__Pyx_PyUnicode_READ",
          self.PyUnicode_READ_func_type,
          args = [kind_temp.ref(slice_node.pos),
                  data_temp.ref(slice_node.pos),
                  counter_temp.ref(node.target.pos)],
          is_temp = False,
          )
      if target_value.type != node.target.type:
          target_value = target_value.coerce_to(node.target.type,
                                                self.current_env())
      target_assign = Nodes.SingleAssignmentNode(
          pos = node.target.pos,
          lhs = node.target,
          rhs = target_value)
      body = Nodes.StatListNode(
          node.pos,
          stats = [target_assign, node.body])

      loop_node = Nodes.ForFromStatNode(
          node.pos,
          bound1=start_node, relation1=relation1,
          target=counter_temp.ref(node.target.pos),
          relation2=relation2, bound2=end_node,
          step=None, body=body,
          else_clause=node.else_clause,
          from_range=True)

      # Runs before the loop; its result is discarded (result_is_used=False).
      setup_node = Nodes.ExprStatNode(
          node.pos,
          expr = ExprNodes.PythonCapiCallNode(
              slice_node.pos, "__Pyx_init_unicode_iteration",
              self.init_unicode_iteration_func_type,
              args = [unpack_temp_node,
                      ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_py_ssize_t_ptr_type),
                      ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_void_ptr_ptr_type),
                      ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                              type=PyrexTypes.c_int_ptr_type),
                      ],
              is_temp = True,
              result_is_used = False,
              utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
              ))
      return UtilNodes.LetNode(
          unpack_temp_node,
          UtilNodes.TempsBlockNode(
              node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
              body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
  def _transform_carray_iteration(self, node, slice_node, reversed=False):
      """
      Rewrite iteration over a C array/pointer slice into a pointer for-loop.

      Accepts a SliceIndexNode (``p[a:b]``), a subscript with a SliceNode
      index (``p[a:b:c]``), or a C array of known size.  Start/stop/step are
      normalised to Py_ssize_t and turned into start/stop pointers, and a
      ForFromStatNode walks a pointer temp between them.  On each iteration
      the pointed-to element is assigned to the loop target: a 1-char bytes
      slice or a Py_UCS4 for Python targets over char*, the pointer itself
      for a pointer target that cannot take the element, else ``ptr[0]``.

      Returns ``node`` unchanged when the bounds or step cannot be determined.
      In that case a compile error is reported, except for Python-object
      bases, where the generic loop is kept silently.
      """
      neg_step = False
      if isinstance(slice_node, ExprNodes.SliceIndexNode):
          slice_base = slice_node.base
          start = filter_none_node(slice_node.start)
          stop = filter_none_node(slice_node.stop)
          step = None
          if not stop:
              if not slice_base.type.is_pyobject:
                  error(slice_node.pos, "C array iteration requires known end index")
              return node

      elif slice_node.is_subscript:
          assert isinstance(slice_node.index, ExprNodes.SliceNode)
          slice_base = slice_node.base
          index = slice_node.index
          start = filter_none_node(index.start)
          stop = filter_none_node(index.stop)
          step = filter_none_node(index.step)
          if step:
              # The step must be a non-zero compile-time int, and the bound
              # it runs towards must be given.
              if not isinstance(step.constant_result, _py_int_types) \
                     or step.constant_result == 0 \
                     or step.constant_result > 0 and not stop \
                     or step.constant_result < 0 and not start:
                  if not slice_base.type.is_pyobject:
                      error(step.pos, "C array iteration requires known step size and end index")
                  return node
              else:
                  # step sign is handled internally by ForFromStatNode
                  step_value = step.constant_result
                  if reversed:
                      step_value = -step_value
                  neg_step = step_value < 0
                  step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                           value=str(abs(step_value)),
                                           constant_result=abs(step_value))

      elif slice_node.type.is_array:
          if slice_node.type.size is None:
              error(slice_node.pos, "C array iteration requires known end index")
              return node
          # Whole fixed-size array: iterate [0, size).
          slice_base = slice_node
          start = None
          stop = ExprNodes.IntNode(
              slice_node.pos, value=str(slice_node.type.size),
              type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
          step = None

      else:
          if not slice_node.type.is_pyobject:
              error(slice_node.pos, "C array iteration requires known end index")
          return node

      if start:
          start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
      if stop:
          stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
      if stop is None:
          if neg_step:
              # Negative step without stop: run down to (and including) index 0.
              stop = ExprNodes.IntNode(
                  slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
          else:
              error(slice_node.pos, "C array iteration requires known step size and end index")
              return node

      if reversed:
          if not start:
              start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
                                        type=PyrexTypes.c_py_ssize_t_type)
          # if step was provided, it was already negated above
          start, stop = stop, start

      ptr_type = slice_base.type
      if ptr_type.is_array:
          ptr_type = ptr_type.element_ptr_type()
      carray_ptr = slice_base.coerce_to_simple(self.current_env())

      # Skip the pointer addition when the offset is known to be 0.
      if start and start.constant_result != 0:
          start_ptr_node = ExprNodes.AddNode(
              start.pos,
              operand1=carray_ptr,
              operator='+',
              operand2=start,
              type=ptr_type)
      else:
          start_ptr_node = carray_ptr

      if stop and stop.constant_result != 0:
          stop_ptr_node = ExprNodes.AddNode(
              stop.pos,
              operand1=ExprNodes.CloneNode(carray_ptr),
              operator='+',
              operand2=stop,
              type=ptr_type
              ).coerce_to_simple(self.current_env())
      else:
          stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

      counter = UtilNodes.TempHandle(ptr_type)
      counter_temp = counter.ref(node.target.pos)

      if slice_base.type.is_string and node.target.type.is_pyobject:
          # special case: char* -> bytes/unicode
          if slice_node.type is Builtin.unicode_type:
              target_value = ExprNodes.CastNode(
                  ExprNodes.DereferenceNode(
                      node.target.pos, operand=counter_temp,
                      type=ptr_type.base_type),
                  PyrexTypes.c_py_ucs4_type).coerce_to(
                      node.target.type, self.current_env())
          else:
              # char* -> bytes coercion requires slicing, not indexing
              target_value = ExprNodes.SliceIndexNode(
                  node.target.pos,
                  start=ExprNodes.IntNode(node.target.pos, value='0',
                                          constant_result=0,
                                          type=PyrexTypes.c_int_type),
                  stop=ExprNodes.IntNode(node.target.pos, value='1',
                                         constant_result=1,
                                         type=PyrexTypes.c_int_type),
                  base=counter_temp,
                  type=Builtin.bytes_type,
                  is_temp=1)
      elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
          # Allow iteration with pointer target to avoid copy.
          target_value = counter_temp
      else:
          # TODO: can this safely be replaced with DereferenceNode() as above?
          target_value = ExprNodes.IndexNode(
              node.target.pos,
              index=ExprNodes.IntNode(node.target.pos, value='0',
                                      constant_result=0,
                                      type=PyrexTypes.c_int_type),
              base=counter_temp,
              type=ptr_type.base_type)

      if target_value.type != node.target.type:
          target_value = target_value.coerce_to(node.target.type,
                                                self.current_env())

      target_assign = Nodes.SingleAssignmentNode(
          pos = node.target.pos,
          lhs = node.target,
          rhs = target_value)

      body = Nodes.StatListNode(
          node.pos,
          stats = [target_assign, node.body])

      relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

      for_node = Nodes.ForFromStatNode(
          node.pos,
          bound1=start_ptr_node, relation1=relation1,
          target=counter_temp,
          relation2=relation2, bound2=stop_ptr_node,
          step=step, body=body,
          else_clause=node.else_clause,
          from_range=True)

      return UtilNodes.TempsBlockNode(
          node.pos, temps=[counter],
          body=for_node)
  def _transform_enumerate_iteration(self, node, enumerate_function):
      """
      Rewrite ``for i, x in enumerate(seq[, start])`` to count ``i`` explicitly.

      The counter lives in a LetRefNode temp initialised to ``start`` (or 0).
      Two statements are prepended to the loop body: assign the counter to
      the enumerate target, then increment it by 1.  The loop itself is
      changed in place to iterate ``seq`` directly into the second target,
      and is then passed back through _optimise_for_loop.

      Returns ``node`` unchanged if enumerate() has 0 or more than 2 args (an
      error is reported), if the target is not a 2-element sequence, or if
      the counter target is neither a Python object nor a C integer.
      """
      args = enumerate_function.arg_tuple.args
      if len(args) == 0:
          error(enumerate_function.pos,
                "enumerate() requires an iterable argument")
          return node
      elif len(args) > 2:
          error(enumerate_function.pos,
                "enumerate() takes at most 2 arguments")
          return node

      if not node.target.is_sequence_constructor:
          # leave this untouched for now
          return node
      targets = node.target.args
      if len(targets) != 2:
          # leave this untouched for now
          return node

      enumerate_target, iterable_target = targets
      counter_type = enumerate_target.type

      if not counter_type.is_pyobject and not counter_type.is_int:
          # nothing we can do here, I guess
          return node

      if len(args) == 2:
          start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
      else:
          start = ExprNodes.IntNode(enumerate_function.pos,
                                    value='0',
                                    type=counter_type,
                                    constant_result=0)
      temp = UtilNodes.LetRefNode(start)

      inc_expression = ExprNodes.AddNode(
          enumerate_function.pos,
          operand1 = temp,
          operand2 = ExprNodes.IntNode(node.pos, value='1',
                                       type=counter_type,
                                       constant_result=1),
          operator = '+',
          type = counter_type,
          #inplace = True, # not worth using in-place operation for Py ints
          is_temp = counter_type.is_pyobject
          )

      # Runs at the top of every iteration: publish the counter, then advance it.
      loop_body = [
          Nodes.SingleAssignmentNode(
              pos = enumerate_target.pos,
              lhs = enumerate_target,
              rhs = temp),
          Nodes.SingleAssignmentNode(
              pos = enumerate_target.pos,
              lhs = temp,
              rhs = inc_expression)
          ]

      if isinstance(node.body, Nodes.StatListNode):
          node.body.stats = loop_body + node.body.stats
      else:
          loop_body.append(node.body)
          node.body = Nodes.StatListNode(
              node.body.pos,
              stats = loop_body)

      # The loop now iterates the wrapped iterable directly.
      node.target = iterable_target
      node.item = node.item.coerce_to(iterable_target.type, self.current_env())
      node.iterator.sequence = args[0]

      # recurse into loop to check for further optimisations
      return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
  622. def _find_for_from_node_relations(self, neg_step_value, reversed):
  623. if reversed:
  624. if neg_step_value:
  625. return '<', '<='
  626. else:
  627. return '>', '>='
  628. else:
  629. if neg_step_value:
  630. return '>=', '>'
  631. else:
  632. return '<=', '<'
  def _transform_range_iteration(self, node, range_function, reversed=False):
      """
      Rewrite ``for x in range(...)`` (optionally reversed) as a ForFromStatNode.

      Returns ``node`` unchanged when an explicit step is not a compile-time
      integer or is zero.  For a reversed range the bounds are swapped.  If
      additionally |step| != 1, the new start bound is recomputed from the
      forward bounds and step so the reversed loop visits the same values.
      That happens at compile time when both bounds are integer constants,
      and otherwise at runtime via _build_range_step_calculation.  The step
      node always ends up non-negative, because direction is encoded in the
      relations.  A non-literal stop bound is held in a LetRefNode so that
      it is evaluated only once.
      """
      args = range_function.arg_tuple.args
      if len(args) < 3:
          step_pos = range_function.pos
          step_value = 1
          step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
      else:
          step = args[2]
          step_pos = step.pos
          if not isinstance(step.constant_result, _py_int_types):
              # cannot determine step direction
              return node
          step_value = step.constant_result
          if step_value == 0:
              # will lead to an error elsewhere
              return node
          step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                   constant_result=step_value)

      if len(args) == 1:
          bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                     constant_result=0)
          bound2 = args[0].coerce_to_integer(self.current_env())
      else:
          bound1 = args[0].coerce_to_integer(self.current_env())
          bound2 = args[1].coerce_to_integer(self.current_env())

      relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

      # Set when the runtime step calculation below already wrapped bound2.
      bound2_ref_node = None
      if reversed:
          bound1, bound2 = bound2, bound1
          abs_step = abs(step_value)
          if abs_step != 1:
              if (isinstance(bound1.constant_result, _py_int_types) and
                      isinstance(bound2.constant_result, _py_int_types)):
                  # calculate final bounds now
                  if step_value < 0:
                      begin_value = bound2.constant_result
                      end_value = bound1.constant_result
                      bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
                  else:
                      begin_value = bound1.constant_result
                      end_value = bound2.constant_result
                      bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1

                  bound1 = ExprNodes.IntNode(
                      bound1.pos, value=str(bound1_value), constant_result=bound1_value,
                      type=PyrexTypes.spanning_type(bound1.type, bound2.type))
              else:
                  # evaluate the same expression as above at runtime
                  bound2_ref_node = UtilNodes.LetRefNode(bound2)
                  bound1 = self._build_range_step_calculation(
                      bound1, bound2_ref_node, step, step_value)

      # Direction is carried by relation1/relation2, so the step is made positive.
      if step_value < 0:
          step_value = -step_value
          step.value = str(step_value)
          step.constant_result = step_value
      step = step.coerce_to_integer(self.current_env())

      if not bound2.is_literal:
          # stop bound must be immutable => keep it in a temp var
          bound2_is_temp = True
          bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2)
      else:
          bound2_is_temp = False

      for_node = Nodes.ForFromStatNode(
          node.pos,
          target=node.target,
          bound1=bound1, relation1=relation1,
          relation2=relation2, bound2=bound2,
          step=step, body=node.body,
          else_clause=node.else_clause,
          from_range=True)
      for_node.set_up_loop(self.current_env())

      if bound2_is_temp:
          for_node = UtilNodes.LetNode(bound2, for_node)

      return for_node
  def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
      """
      Build the runtime start-bound expression for a reversed range().

      Constructs the node tree for
      ``bound2 (+|-) abs_step * ((begin - end - 1) // abs_step) (+|-) 1``,
      where (begin, end) and the sign depend on the sign of ``step_value``.
      This mirrors the compile-time formula used in
      _transform_range_iteration when both bounds are constants.
      """
      abs_step = abs(step_value)
      pos = bound1.pos
      spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
      # Small int steps use C int to avoid loss of integer precision warnings.
      step_partner_type = (PyrexTypes.c_int_type
                           if step.type.is_int and abs_step < 0x7FFF
                           else step.type)
      spanning_step_type = PyrexTypes.spanning_type(spanning_type, step_partner_type)

      if step_value < 0:
          begin_value, end_value, final_op = bound2_ref_node, bound1, '-'
      else:
          begin_value, end_value, final_op = bound1, bound2_ref_node, '+'

      # Nodes are created in the same order as the original nested expression.
      step_factor = ExprNodes.IntNode(
          pos, value=str(abs_step), constant_result=abs_step, type=spanning_step_type)
      distance = ExprNodes.SubNode(
          pos, operand1=begin_value, operator='-', operand2=end_value,
          type=spanning_type)
      distance_minus_one = ExprNodes.SubNode(
          pos, operand1=distance, operator='-',
          operand2=ExprNodes.IntNode(pos, value='1', constant_result=1),
          type=spanning_step_type)
      full_steps = ExprNodes.DivNode(
          pos, operand1=distance_minus_one, operator='//',
          operand2=ExprNodes.IntNode(
              pos, value=str(abs_step), constant_result=abs_step,
              type=spanning_step_type),
          type=spanning_step_type)
      offset = ExprNodes.MulNode(
          pos, operand1=step_factor, operator='*', operand2=full_steps,
          type=spanning_step_type)
      shifted_bound = ExprNodes.binop_node(
          pos, operand1=bound2_ref_node, operator=final_op, operand2=offset,
          type=spanning_step_type)
      return ExprNodes.binop_node(
          pos, operand1=shifted_bound, operator=final_op,
          operand2=ExprNodes.IntNode(pos, value='1', constant_result=1),
          type=spanning_type)
  def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
      """
      Rewrite iteration over a dict (or one of its iteration methods) as a C loop.

      Emits: ``pos = 0``, ``dict_temp = __Pyx_dict_iterator(...)``, and a
      condition-less WhileStatNode whose body starts with a
      DictIterationNextNode.  That node unpacks into the key target, the
      value target, or a single tuple target, as selected by ``keys`` and
      ``values``.  ``method``, when truthy, is the method name passed to the
      helper as a StringNode and used in the None-check error message.
      Returns ``node`` unchanged if keys and values are both requested and
      the target is a sequence whose length is not 2.

      Side effect: when ``node.body`` is already a StatListNode, the
      iteration node is inserted into it in place.
      """
      temps = []
      temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
      temps.append(temp)
      dict_temp = temp.ref(dict_obj.pos)
      temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
      temps.append(temp)
      pos_temp = temp.ref(node.pos)

      # Decide which of key/value/tuple targets the iteration fills.
      key_target = value_target = tuple_target = None
      if keys and values:
          if node.target.is_sequence_constructor:
              if len(node.target.args) == 2:
                  key_target, value_target = node.target.args
              else:
                  # unusual case that may or may not lead to an error
                  return node
          else:
              tuple_target = node.target
      elif keys:
          key_target = node.target
      else:
          value_target = node.target

      if isinstance(node.body, Nodes.StatListNode):
          body = node.body
      else:
          body = Nodes.StatListNode(pos = node.body.pos,
                                    stats = [node.body])

      # keep original length to guard against dict modification
      dict_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
      temps.append(dict_len_temp)
      dict_len_temp_addr = ExprNodes.AmpersandNode(
          node.pos, operand=dict_len_temp.ref(dict_obj.pos),
          type=PyrexTypes.c_ptr_type(dict_len_temp.type))
      temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
      temps.append(temp)
      is_dict_temp = temp.ref(node.pos)
      is_dict_temp_addr = ExprNodes.AmpersandNode(
          node.pos, operand=is_dict_temp,
          type=PyrexTypes.c_ptr_type(temp.type))

      iter_next_node = Nodes.DictIterationNextNode(
          dict_temp, dict_len_temp.ref(dict_obj.pos), pos_temp,
          key_target, value_target, tuple_target,
          is_dict_temp)
      iter_next_node = iter_next_node.analyse_expressions(self.current_env())
      body.stats[0:0] = [iter_next_node]

      if method:
          method_node = ExprNodes.StringNode(
              dict_obj.pos, is_identifier=True, value=method)
          # Names of at most 30 chars use '%.30s', longer ones '%s'; in both
          # cases the full method name ends up in the message.
          dict_obj = dict_obj.as_none_safe_node(
              "'NoneType' object has no attribute '%{0}s'".format('.30' if len(method) <= 30 else ''),
              error = "PyExc_AttributeError",
              format_args = [method])
      else:
          method_node = ExprNodes.NullNode(dict_obj.pos)
          dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

      def flag_node(value):
          # Normalise any truthy/falsy value into a C int literal 1/0.
          value = value and 1 or 0
          return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)

      result_code = [
          Nodes.SingleAssignmentNode(
              node.pos,
              lhs = pos_temp,
              rhs = ExprNodes.IntNode(node.pos, value='0',
                                      constant_result=0)),
          Nodes.SingleAssignmentNode(
              dict_obj.pos,
              lhs = dict_temp,
              rhs = ExprNodes.PythonCapiCallNode(
                  dict_obj.pos,
                  "__Pyx_dict_iterator",
                  self.PyDict_Iterator_func_type,
                  utility_code = UtilityCode.load_cached("dict_iter", "Optimize.c"),
                  args = [dict_obj, flag_node(dict_obj.type is Builtin.dict_type),
                          method_node, dict_len_temp_addr, is_dict_temp_addr,
                          ],
                  is_temp=True,
              )),
          # No condition: the loop is driven by the iteration node in its body.
          Nodes.WhileStatNode(
              node.pos,
              condition = None,
              body = body,
              else_clause = node.else_clause
              )
          ]

      return UtilNodes.TempsBlockNode(
          node.pos, temps=temps,
          body=Nodes.StatListNode(
              node.pos,
              stats = result_code
              ))
  # C function type for "__Pyx_dict_iterator" (utility code "dict_iter" in
  # Optimize.c), as called from _transform_dict_iteration.  Arguments: the
  # object, an is-exact-dict flag, the method name (or NULL), and pointers
  # that receive the original length and an is_dict flag.  Returns a Python
  # object.
  PyDict_Iterator_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type, [
          PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
          PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
          PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
          ])

  # C function type for "__Pyx_set_iterator" (utility code "set_iter" in
  # Optimize.c), as called from _transform_set_iteration.  Same shape as
  # the dict variant, but without a method name.
  PySet_Iterator_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type, [
          PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
          PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
          PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
          ])
  def _transform_set_iteration(self, node, set_obj):
      """
      Rewrite iteration over a set as a C loop driven by __Pyx_set_iterator.

      Emits: ``pos = 0``, ``set_temp = __Pyx_set_iterator(...)``, and a
      condition-less WhileStatNode whose body starts with a
      SetIterationNextNode that assigns each element to the loop target.
      The original length is kept in a temp, per the comment below, to
      guard against modification of the set during iteration.

      Side effect: when ``node.body`` is already a StatListNode, the
      iteration node is inserted into it in place.
      """
      temps = []
      temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
      temps.append(temp)
      set_temp = temp.ref(set_obj.pos)
      temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
      temps.append(temp)
      pos_temp = temp.ref(node.pos)

      if isinstance(node.body, Nodes.StatListNode):
          body = node.body
      else:
          body = Nodes.StatListNode(pos = node.body.pos,
                                    stats = [node.body])

      # keep original length to guard against set modification
      set_len_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
      temps.append(set_len_temp)
      set_len_temp_addr = ExprNodes.AmpersandNode(
          node.pos, operand=set_len_temp.ref(set_obj.pos),
          type=PyrexTypes.c_ptr_type(set_len_temp.type))
      temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
      temps.append(temp)
      is_set_temp = temp.ref(node.pos)
      is_set_temp_addr = ExprNodes.AmpersandNode(
          node.pos, operand=is_set_temp,
          type=PyrexTypes.c_ptr_type(temp.type))

      value_target = node.target
      iter_next_node = Nodes.SetIterationNextNode(
          set_temp, set_len_temp.ref(set_obj.pos), pos_temp, value_target, is_set_temp)
      iter_next_node = iter_next_node.analyse_expressions(self.current_env())
      body.stats[0:0] = [iter_next_node]

      def flag_node(value):
          # Normalise any truthy/falsy value into a C int literal 1/0.
          value = value and 1 or 0
          return ExprNodes.IntNode(node.pos, value=str(value), constant_result=value)

      result_code = [
          Nodes.SingleAssignmentNode(
              node.pos,
              lhs=pos_temp,
              rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0)),
          Nodes.SingleAssignmentNode(
              set_obj.pos,
              lhs=set_temp,
              rhs=ExprNodes.PythonCapiCallNode(
                  set_obj.pos,
                  "__Pyx_set_iterator",
                  self.PySet_Iterator_func_type,
                  utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
                  args=[set_obj, flag_node(set_obj.type is Builtin.set_type),
                        set_len_temp_addr, is_set_temp_addr,
                        ],
                  is_temp=True,
              )),
          # No condition: the loop is driven by the iteration node in its body.
          Nodes.WhileStatNode(
              node.pos,
              condition=None,
              body=body,
              else_clause=node.else_clause,
              )
          ]

      return UtilNodes.TempsBlockNode(
          node.pos, temps=temps,
          body=Nodes.StatListNode(
              node.pos,
              stats = result_code
              ))
  937. class SwitchTransform(Visitor.EnvTransform):
  938. """
  939. This transformation tries to turn long if statements into C switch statements.
  940. The requirement is that every clause be an (or of) var == value, where the var
  941. is common among all clauses and both var and value are ints.
  942. """
  943. NO_MATCH = (None, None, None)
  944. def extract_conditions(self, cond, allow_not_in):
  945. while True:
  946. if isinstance(cond, (ExprNodes.CoerceToTempNode,
  947. ExprNodes.CoerceToBooleanNode)):
  948. cond = cond.arg
  949. elif isinstance(cond, ExprNodes.BoolBinopResultNode):
  950. cond = cond.arg.arg
  951. elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
  952. # this is what we get from the FlattenInListTransform
  953. cond = cond.subexpression
  954. elif isinstance(cond, ExprNodes.TypecastNode):
  955. cond = cond.operand
  956. else:
  957. break
  958. if isinstance(cond, ExprNodes.PrimaryCmpNode):
  959. if cond.cascade is not None:
  960. return self.NO_MATCH
  961. elif cond.is_c_string_contains() and \
  962. isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
  963. not_in = cond.operator == 'not_in'
  964. if not_in and not allow_not_in:
  965. return self.NO_MATCH
  966. if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
  967. cond.operand2.contains_surrogates():
  968. # dealing with surrogates leads to different
  969. # behaviour on wide and narrow Unicode
  970. # platforms => refuse to optimise this case
  971. return self.NO_MATCH
  972. return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
  973. elif not cond.is_python_comparison():
  974. if cond.operator == '==':
  975. not_in = False
  976. elif allow_not_in and cond.operator == '!=':
  977. not_in = True
  978. else:
  979. return self.NO_MATCH
  980. # this looks somewhat silly, but it does the right
  981. # checks for NameNode and AttributeNode
  982. if is_common_value(cond.operand1, cond.operand1):
  983. if cond.operand2.is_literal:
  984. return not_in, cond.operand1, [cond.operand2]
  985. elif getattr(cond.operand2, 'entry', None) \
  986. and cond.operand2.entry.is_const:
  987. return not_in, cond.operand1, [cond.operand2]
  988. if is_common_value(cond.operand2, cond.operand2):
  989. if cond.operand1.is_literal:
  990. return not_in, cond.operand2, [cond.operand1]
  991. elif getattr(cond.operand1, 'entry', None) \
  992. and cond.operand1.entry.is_const:
  993. return not_in, cond.operand2, [cond.operand1]
  994. elif isinstance(cond, ExprNodes.BoolBinopNode):
  995. if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
  996. allow_not_in = (cond.operator == 'and')
  997. not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
  998. not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
  999. if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
  1000. if (not not_in_1) or allow_not_in:
  1001. return not_in_1, t1, c1+c2
  1002. return self.NO_MATCH
  1003. def extract_in_string_conditions(self, string_literal):
  1004. if isinstance(string_literal, ExprNodes.UnicodeNode):
  1005. charvals = list(map(ord, set(string_literal.value)))
  1006. charvals.sort()
  1007. return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
  1008. constant_result=charval)
  1009. for charval in charvals ]
  1010. else:
  1011. # this is a bit tricky as Py3's bytes type returns
  1012. # integers on iteration, whereas Py2 returns 1-char byte
  1013. # strings
  1014. characters = string_literal.value
  1015. characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
  1016. characters.sort()
  1017. return [ ExprNodes.CharNode(string_literal.pos, value=charval,
  1018. constant_result=charval)
  1019. for charval in characters ]
  1020. def extract_common_conditions(self, common_var, condition, allow_not_in):
  1021. not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
  1022. if var is None:
  1023. return self.NO_MATCH
  1024. elif common_var is not None and not is_common_value(var, common_var):
  1025. return self.NO_MATCH
  1026. elif not (var.type.is_int or var.type.is_enum) or sum([not (cond.type.is_int or cond.type.is_enum) for cond in conditions]):
  1027. return self.NO_MATCH
  1028. return not_in, var, conditions
  1029. def has_duplicate_values(self, condition_values):
  1030. # duplicated values don't work in a switch statement
  1031. seen = set()
  1032. for value in condition_values:
  1033. if value.has_constant_result():
  1034. if value.constant_result in seen:
  1035. return True
  1036. seen.add(value.constant_result)
  1037. else:
  1038. # this isn't completely safe as we don't know the
  1039. # final C value, but this is about the best we can do
  1040. try:
  1041. if value.entry.cname in seen:
  1042. return True
  1043. except AttributeError:
  1044. return True # play safe
  1045. seen.add(value.entry.cname)
  1046. return False
  1047. def visit_IfStatNode(self, node):
  1048. if not self.current_directives.get('optimize.use_switch'):
  1049. self.visitchildren(node)
  1050. return node
  1051. common_var = None
  1052. cases = []
  1053. for if_clause in node.if_clauses:
  1054. _, common_var, conditions = self.extract_common_conditions(
  1055. common_var, if_clause.condition, False)
  1056. if common_var is None:
  1057. self.visitchildren(node)
  1058. return node
  1059. cases.append(Nodes.SwitchCaseNode(pos=if_clause.pos,
  1060. conditions=conditions,
  1061. body=if_clause.body))
  1062. condition_values = [
  1063. cond for case in cases for cond in case.conditions]
  1064. if len(condition_values) < 2:
  1065. self.visitchildren(node)
  1066. return node
  1067. if self.has_duplicate_values(condition_values):
  1068. self.visitchildren(node)
  1069. return node
  1070. # Recurse into body subtrees that we left untouched so far.
  1071. self.visitchildren(node, 'else_clause')
  1072. for case in cases:
  1073. self.visitchildren(case, 'body')
  1074. common_var = unwrap_node(common_var)
  1075. switch_node = Nodes.SwitchStatNode(pos=node.pos,
  1076. test=common_var,
  1077. cases=cases,
  1078. else_clause=node.else_clause)
  1079. return switch_node
  1080. def visit_CondExprNode(self, node):
  1081. if not self.current_directives.get('optimize.use_switch'):
  1082. self.visitchildren(node)
  1083. return node
  1084. not_in, common_var, conditions = self.extract_common_conditions(
  1085. None, node.test, True)
  1086. if common_var is None \
  1087. or len(conditions) < 2 \
  1088. or self.has_duplicate_values(conditions):
  1089. self.visitchildren(node)
  1090. return node
  1091. return self.build_simple_switch_statement(
  1092. node, common_var, conditions, not_in,
  1093. node.true_val, node.false_val)
  1094. def visit_BoolBinopNode(self, node):
  1095. if not self.current_directives.get('optimize.use_switch'):
  1096. self.visitchildren(node)
  1097. return node
  1098. not_in, common_var, conditions = self.extract_common_conditions(
  1099. None, node, True)
  1100. if common_var is None \
  1101. or len(conditions) < 2 \
  1102. or self.has_duplicate_values(conditions):
  1103. self.visitchildren(node)
  1104. node.wrap_operands(self.current_env()) # in case we changed the operands
  1105. return node
  1106. return self.build_simple_switch_statement(
  1107. node, common_var, conditions, not_in,
  1108. ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
  1109. ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
  def visit_PrimaryCmpNode(self, node):
      """Lower a comparison over one variable to a boolean switch.

      Active only under ``optimize.use_switch``; the switch is built when
      extract_common_conditions() finds a common variable with at least two
      conditions and has_duplicate_values() is false.  The switch yields the
      constants True/False.  Every other case just visits the children.
      """
      if self.current_directives.get('optimize.use_switch'):
          not_in, common_var, conditions = self.extract_common_conditions(
              None, node, True)
          if (common_var is not None
                  and len(conditions) >= 2
                  and not self.has_duplicate_values(conditions)):
              return self.build_simple_switch_statement(
                  node, common_var, conditions, not_in,
                  ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
                  ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
      self.visitchildren(node)
      return node
  def build_simple_switch_statement(self, node, common_var, conditions,
                                    not_in, true_val, false_val):
      """Rebuild *node* as a single-case SwitchStatNode yielding a value.

      The switch tests ``unwrap_node(common_var)`` against *conditions*.
      The matching case assigns *true_val* to a result temp and the else
      clause assigns *false_val*; with *not_in* the two are exchanged.
      Both values are coerced to ``node.type``.  Returns a
      TempResultFromStatNode that runs the switch and yields the temp.
      """
      result_ref = UtilNodes.ResultRefNode(node)

      def store(value):
          # assignment of the coerced value into the shared result temp
          return Nodes.SingleAssignmentNode(
              node.pos,
              lhs=result_ref,
              rhs=value.coerce_to(node.type, self.current_env()),
              first=True)

      match_body = store(true_val)
      default_body = store(false_val)
      if not_in:
          # 'not in': hitting one of the cases means the test is false
          match_body, default_body = default_body, match_body

      single_case = Nodes.SwitchCaseNode(
          pos=node.pos, conditions=conditions, body=match_body)
      switch_test = unwrap_node(common_var)
      switch_node = Nodes.SwitchStatNode(
          pos=node.pos,
          test=switch_test,
          cases=[single_case],
          else_clause=default_body)
      return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
  def visit_EvalWithTempExprNode(self, node):
      """Drop an EvalWithTempExprNode whose temp became unused.

      FlattenInListTransform wraps expressions in these temp holders.  If
      visiting the children replaced the subexpression (e.g. by a switch)
      and the new subexpression no longer references the lazy temp, the
      wrapper is discarded and the subexpression returned directly.
      Without ``optimize.use_switch`` the node is only visited.
      """
      if not self.current_directives.get('optimize.use_switch'):
          self.visitchildren(node)
          return node

      original_subexpr = node.subexpression
      lazy_temp = node.lazy_temp
      self.visitchildren(node)
      new_subexpr = node.subexpression
      restructured = new_subexpr is not original_subexpr
      if restructured and not Visitor.tree_contains(new_subexpr, lazy_temp):
          return new_subexpr
      return node

  # all other node types: just recurse
  visit_Node = Visitor.VisitorTransform.recurse_to_children
  class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
      """
      This transformation flattens "x in [val1, ..., valn]" into a sequential list
      of comparisons.
      """

      def visit_PrimaryCmpNode(self, node):
          """Flatten ``x in <literal sequence>`` into chained comparisons.

          ``x in (a, b)``     becomes  ``(x == a) or (x == b)``
          ``x not in (a, b)`` becomes  ``(x != a) and (x != b)``

          ``x`` is wrapped in a ResultRefNode so it is evaluated once, and
          every item that is not simple is hoisted into a LetRefNode temp.
          The node is returned unchanged in these cases:
          - the comparison is cascaded;
          - the operator is neither ``in`` nor ``not_in``;
          - the right-hand side is not a tuple/list/set literal;
          - the literal is empty;
          - the literal contains starred items.
          """
          self.visitchildren(node)
          if node.cascade is not None:
              return node
          elif node.operator == 'in':
              conjunction = 'or'
              eq_or_neq = '=='
          elif node.operator == 'not_in':
              conjunction = 'and'
              eq_or_neq = '!='
          else:
              return node

          if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                            ExprNodes.ListNode,
                                            ExprNodes.SetNode)):
              return node

          args = node.operand2.args
          if len(args) == 0:
              # note: lhs may have side effects
              return node

          if any(arg.is_starred for arg in args):
              # Starred arguments do not directly translate to comparisons or "in" tests.
              return node

          lhs = UtilNodes.ResultRefNode(node.operand1)

          conds = []
          temps = []
          for arg in args:
              try:
                  # Trial optimisation to avoid redundant temp
                  # assignments.  However, since is_simple() is meant to
                  # be called after type analysis, we ignore any errors
                  # and just play safe in that case.
                  is_simple_arg = arg.is_simple()
              except Exception:
                  is_simple_arg = False
              if not is_simple_arg:
                  # must evaluate all non-simple RHS before doing the comparisons
                  arg = UtilNodes.LetRefNode(arg)
                  temps.append(arg)
              cond = ExprNodes.PrimaryCmpNode(
                  pos=node.pos,
                  operand1=lhs,
                  operator=eq_or_neq,
                  operand2=arg,
                  cascade=None)
              # each single comparison is cast to a C bint before chaining
              conds.append(ExprNodes.TypecastNode(
                  pos=node.pos,
                  operand=cond,
                  type=PyrexTypes.c_bint_type))

          def concat(left, right):
              return ExprNodes.BoolBinopNode(
                  pos=node.pos,
                  operator=conjunction,
                  operand1=left,
                  operand2=right)

          # left fold: ((c1 op c2) op c3) ...
          condition = reduce(concat, conds)
          new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
          # wrap in reverse order so the first hoisted temp ends up outermost
          for temp in reversed(temps):
              new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
          return new_node

      visit_Node = Visitor.VisitorTransform.recurse_to_children
  class DropRefcountingTransform(Visitor.VisitorTransform):
      """Turn off managed references for operands of safe parallel assignments.

      Only ParallelAssignmentNodes are inspected; everything else is just
      recursed into.
      """
      visit_Node = Visitor.VisitorTransform.recurse_to_children

      def visit_ParallelAssignmentNode(self, node):
          """Clear ``use_managed_ref`` on swap-style parallel assignments.

          Parallel swap assignments like 'a,b = b,a' are safe.  Every
          statement must be a SingleAssignmentNode whose two sides
          _extract_operand() accepts.  The left-hand names must be a
          duplicate-free permutation of the right-hand names.  Any indexed
          operand currently makes the transform give up (see below).
          """
          left_names, right_names = [], []
          left_indices, right_indices = [], []
          temps = []
          for stat in node.stats:
              if not isinstance(stat, Nodes.SingleAssignmentNode):
                  # CascadedAssignmentNode (FIXME) and everything else
                  return node
              if not self._extract_operand(stat.lhs, left_names,
                                           left_indices, temps):
                  return node
              if not self._extract_operand(stat.rhs, right_names,
                                           right_indices, temps):
                  return node

          if left_names or right_names:
              # lhs/rhs names must be a non-redundant permutation
              lhs_paths = [path for path, _ in left_names]
              rhs_paths = [path for path, _ in right_names]
              distinct_lhs = set(lhs_paths)
              if distinct_lhs != set(rhs_paths):
                  return node
              if len(distinct_lhs) != len(right_names):
                  return node

          if left_indices or right_indices:
              # base name and index of index nodes must be a
              # non-redundant permutation
              lhs_ids = []
              for lhs_node in left_indices:
                  index_id = self._extract_index_id(lhs_node)
                  if not index_id:
                      return node
                  lhs_ids.append(index_id)
              rhs_ids = []
              for rhs_node in right_indices:
                  index_id = self._extract_index_id(rhs_node)
                  if not index_id:
                      return node
                  rhs_ids.append(index_id)
              distinct_lhs_ids = set(lhs_ids)
              if distinct_lhs_ids != set(rhs_ids):
                  return node
              if len(distinct_lhs_ids) != len(right_indices):
                  return node
              # really supporting IndexNode requires support in
              # __Pyx_GetItemInt(), so let's stop short for now
              return node

          coerced_args = [temp.arg for temp in temps]
          for temp in temps:
              temp.use_managed_ref = False
          for _, name_node in left_names + right_names:
              # names that were wrapped in a coercion temp keep their refcounting
              if name_node not in coerced_args:
                  name_node.use_managed_ref = False
          for index_node in left_indices + right_indices:
              index_node.use_managed_ref = False
          return node

      def _extract_operand(self, node, names, indices, temps):
          """Classify one assignment operand, appending it to *names* or *indices*.

          A CoerceToTempNode wrapper is recorded in *temps* and looked through.
          Accepted operands are Python-object names or C attribute paths
          ending in a name (recorded as ``("a.b.c", node)``), and int-indexed
          subscripts of a list-typed name.  Returns False for anything else.
          """
          node = unwrap_node(node)
          if not node.type.is_pyobject:
              return False
          if isinstance(node, ExprNodes.CoerceToTempNode):
              temps.append(node)
              node = node.arg

          path_parts = []
          target = node
          while target.is_attribute:
              if target.is_py_attr:
                  # Python attribute lookups may run arbitrary code
                  return False
              path_parts.append(target.member)
              target = target.obj

          if target.is_name:
              path_parts.append(target.name)
              names.append(('.'.join(reversed(path_parts)), node))
              return True

          if not node.is_subscript:
              return False
          if (node.base.type != Builtin.list_type
                  or not node.index.type.is_int
                  or not node.base.is_name):
              return False
          indices.append(node)
          return True

      def _extract_index_id(self, index_node):
          """Return ``(base_name, index_name)`` for ``name[name]``, else None.

          Constant indices are not supported yet (FIXME).
          """
          index = index_node.index
          if not isinstance(index, ExprNodes.NameNode):
              return None
          return (index_node.base.name, index.name)
  class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
      """Optimize some common calls to builtin types *before* the type
      analysis phase and *after* the declarations analysis phase.

      This transform cannot make use of any argument types, but it can
      restructure the tree in a way that the type analysis phase can
      respond to.

      Introducing C function calls here may not be a good idea.  Move
      them to the OptimizeBuiltinCalls transform instead, which runs
      after type analysis.
      """
      # only intercept on call nodes
      visit_Node = Visitor.VisitorTransform.recurse_to_children

      def visit_SimpleCallNode(self, node):
          """Hand a positional call of a builtin name to its simple handler."""
          self.visitchildren(node)
          function = node.function
          if not self._function_is_builtin_name(function):
              return node
          return self._dispatch_to_handler(node, function, node.args)

      def visit_GeneralCallNode(self, node):
          """Hand a keyword call of a builtin name to its general handler.

          Only calls whose positional arguments form a literal tuple are
          considered.
          """
          self.visitchildren(node)
          function = node.function
          if not self._function_is_builtin_name(function):
              return node
          arg_tuple = node.positional_args
          if not isinstance(arg_tuple, ExprNodes.TupleNode):
              return node
          args = arg_tuple.args
          return self._dispatch_to_handler(
              node, function, args, node.keyword_args)

      def _function_is_builtin_name(self, function):
          """Return True if *function* is a name resolving to the builtin scope.

          The entry found from the current scope must be the very object the
          builtin scope holds for that name (both may be None).
          """
          if not function.is_name:
              return False
          env = self.current_env()
          entry = env.lookup(function.name)
          if entry is not env.builtin_scope().lookup_here(function.name):
              return False
          # if entry is None, it's at least an undeclared name, so likely builtin
          return True

      def _dispatch_to_handler(self, node, function, args, kwargs=None):
          """Route a builtin call to its handler method.

          Without *kwargs*, ``_handle_simple_function_<name>(node, args)`` is
          called.  With *kwargs*, ``_handle_general_function_<name>(node, args,
          kwargs)`` is called.  Returns the handler's result, or *node*
          unchanged when no such method exists.
          """
          if kwargs is None:
              handler_name = '_handle_simple_function_%s' % function.name
          else:
              handler_name = '_handle_general_function_%s' % function.name
          handle_call = getattr(self, handler_name, None)
          if handle_call is not None:
              if kwargs is None:
                  return handle_call(node, args)
              else:
                  return handle_call(node, args, kwargs)
          return node

      def _inject_capi_function(self, node, cname, func_type, utility_code=None):
          """Replace ``node.function`` in place by a PythonCapiFunctionNode for *cname*."""
          node.function = ExprNodes.PythonCapiFunctionNode(
              node.function.pos, node.function.name, cname, func_type,
              utility_code = utility_code)

      def _error_wrong_arg_count(self, function_name, node, args, expected=None):
          """Report a compile error for a builtin called with the wrong arg count.

          *expected* may be None/0 (no count shown), an int, or a string
          (shown verbatim); it only shapes the message text.
          """
          if not expected:  # None or 0
              arg_str = ''
          elif isinstance(expected, basestring) or expected > 1:
              arg_str = '...'
          elif expected == 1:
              arg_str = 'x'
          else:
              arg_str = ''
          if expected is not None:
              expected_str = 'expected %s, ' % expected
          else:
              expected_str = ''
          error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
              function_name, arg_str, expected_str, len(args)))

      # specific handlers for simple call nodes

      def _handle_simple_function_float(self, node, pos_args):
          """Simplify float() calls.

          - ``float()`` becomes the literal ``0.0``.
          - ``float(x)`` is replaced by ``x`` when x already has a C double or
            Python float type.  getattr() is used because nodes may not be
            typed yet in this phase.
          - A call with more than one argument is reported as an error and
            left unchanged.
          """
          if not pos_args:
              return ExprNodes.FloatNode(node.pos, value='0.0')
          if len(pos_args) > 1:
              self._error_wrong_arg_count('float', node, pos_args, 1)
              # Do not go on to replace an erroneous call by its first argument.
              return node
          arg_type = getattr(pos_args[0], 'type', None)
          if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
              return pos_args[0]
          return node

      def _handle_simple_function_slice(self, node, pos_args):
          """Replace slice(stop), slice(start, stop[, step]) by a SliceNode.

          Missing start/step become None nodes; other arg counts are reported.
          """
          arg_count = len(pos_args)
          start = step = None
          if arg_count == 1:
              stop, = pos_args
          elif arg_count == 2:
              start, stop = pos_args
          elif arg_count == 3:
              start, stop, step = pos_args
          else:
              self._error_wrong_arg_count('slice', node, pos_args)
              return node
          return ExprNodes.SliceNode(
              node.pos,
              start=start or ExprNodes.NoneNode(node.pos),
              stop=stop,
              step=step or ExprNodes.NoneNode(node.pos))

      def _handle_simple_function_ord(self, node, pos_args):
          """Unpack ord('X').

          Single-character unicode/bytes literals fold to a C long constant.
          Single-character str literals whose code point is at most 255
          fold to a C int constant.
          """
          if len(pos_args) != 1:
              return node
          arg = pos_args[0]
          if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
              if len(arg.value) == 1:
                  return ExprNodes.IntNode(
                      arg.pos, type=PyrexTypes.c_long_type,
                      value=str(ord(arg.value)),
                      constant_result=ord(arg.value)
                  )
          elif isinstance(arg, ExprNodes.StringNode):
              if arg.unicode_value and len(arg.unicode_value) == 1 \
                      and ord(arg.unicode_value) <= 255:  # Py2/3 portability
                  return ExprNodes.IntNode(
                      arg.pos, type=PyrexTypes.c_int_type,
                      value=str(ord(arg.unicode_value)),
                      constant_result=ord(arg.unicode_value)
                  )
          return node

      # sequence processing

      def _handle_simple_function_all(self, node, pos_args):
          """Transform

          _result = all(p(x) for L in LL for x in L)

          into

          for L in LL:
              for x in L:
                  if not p(x):
                      return False
          else:
              return True
          """
          return self._transform_any_all(node, pos_args, False)

      def _handle_simple_function_any(self, node, pos_args):
          """Transform

          _result = any(p(x) for L in LL for x in L)

          into

          for L in LL:
              for x in L:
                  if p(x):
                      return True
          else:
              return False
          """
          return self._transform_any_all(node, pos_args, True)

      def _transform_any_all(self, node, pos_args, is_any):
          """Inline any()/all() over a generator expression with a single yield.

          The yield statement is replaced by an early-returning if-statement.
          The generator loop's else clause returns the opposite constant.
          """
          if len(pos_args) != 1:
              return node
          if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
              return node
          gen_expr_node = pos_args[0]
          generator_body = gen_expr_node.def_node.gbody
          loop_node = generator_body.body
          yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
          if yield_expression is None:
              return node

          if is_any:
              condition = yield_expression
          else:
              condition = ExprNodes.NotNode(yield_expression.pos, operand=yield_expression)

          test_node = Nodes.IfStatNode(
              yield_expression.pos, else_clause=None, if_clauses=[
                  Nodes.IfClauseNode(
                      yield_expression.pos,
                      condition=condition,
                      body=Nodes.ReturnStatNode(
                          node.pos,
                          value=ExprNodes.BoolNode(yield_expression.pos, value=is_any, constant_result=is_any))
                  )]
          )
          loop_node.else_clause = Nodes.ReturnStatNode(
              node.pos,
              value=ExprNodes.BoolNode(yield_expression.pos, value=not is_any, constant_result=not is_any))

          Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)

          return ExprNodes.InlinedGeneratorExpressionNode(
              gen_expr_node.pos, gen=gen_expr_node, orig_func='any' if is_any else 'all')

      PySequence_List_func_type = PyrexTypes.CFuncType(
          Builtin.list_type,
          [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])

      def _handle_simple_function_sorted(self, node, pos_args):
          """Transform sorted(genexpr) and sorted([listcomp]) into
          [listcomp].sort().  CPython just reads the iterable into a
          list and calls .sort() on it.  Expanding the iterable in a
          listcomp is still faster and the result can be sorted in
          place.

          Literal sequences are first converted to a fresh list; any other
          argument is read into a list via PySequence_List().
          """
          if len(pos_args) != 1:
              return node

          arg = pos_args[0]
          if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
              list_node = pos_args[0]
              loop_node = list_node.loop

          elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
              gen_expr_node = arg
              loop_node = gen_expr_node.loop
              yield_statements = _find_yield_statements(loop_node)
              if not yield_statements:
                  return node

              list_node = ExprNodes.InlinedGeneratorExpressionNode(
                  node.pos, gen_expr_node, orig_func='sorted',
                  comprehension_type=Builtin.list_type)

              for yield_expression, yield_stat_node in yield_statements:
                  append_node = ExprNodes.ComprehensionAppendNode(
                      yield_expression.pos,
                      expr=yield_expression,
                      target=list_node.target)
                  Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

          elif arg.is_sequence_constructor:
              # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
              # so starting off with a fresh one is more efficient.
              list_node = loop_node = arg.as_list()

          else:
              # Interestingly, PySequence_List works on a lot of non-sequence
              # things as well.
              list_node = loop_node = ExprNodes.PythonCapiCallNode(
                  node.pos, "PySequence_List", self.PySequence_List_func_type,
                  args=pos_args, is_temp=True)

          result_node = UtilNodes.ResultRefNode(
              pos=loop_node.pos, type=Builtin.list_type, may_hold_none=False)
          list_assign_node = Nodes.SingleAssignmentNode(
              node.pos, lhs=result_node, rhs=list_node, first=True)

          sort_method = ExprNodes.AttributeNode(
              node.pos, obj=result_node, attribute=EncodedString('sort'),
              # entry ? type ?
              needs_none_check=False)
          sort_node = Nodes.ExprStatNode(
              node.pos, expr=ExprNodes.SimpleCallNode(
                  node.pos, function=sort_method, args=[]))

          sort_node.analyse_declarations(self.current_env())

          return UtilNodes.TempResultFromStatNode(
              result_node,
              Nodes.StatListNode(node.pos, stats=[list_assign_node, sort_node]))

      def __handle_simple_function_sum(self, node, pos_args):
          """Transform sum(genexpr) into an equivalent inlined aggregation loop.

          Disabled: the double underscore name-mangles this method, so
          _dispatch_to_handler() never finds '_handle_simple_function_sum'.
          """
          if len(pos_args) not in (1, 2):
              return node
          if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                          ExprNodes.ComprehensionNode)):
              return node
          gen_expr_node = pos_args[0]
          loop_node = gen_expr_node.loop

          if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
              yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
              # FIXME: currently nonfunctional
              yield_expression = None
              if yield_expression is None:
                  return node
          else:  # ComprehensionNode
              yield_stat_node = gen_expr_node.append
              yield_expression = yield_stat_node.expr
              try:
                  if not yield_expression.is_literal or not yield_expression.type.is_int:
                      return node
              except AttributeError:
                  return node  # in case we don't have a type yet
              # special case: old Py2 backwards compatible "sum([int_const for ...])"
              # can safely be unpacked into a genexpr

          if len(pos_args) == 1:
              start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
          else:
              start = pos_args[1]

          result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
          add_node = Nodes.SingleAssignmentNode(
              yield_expression.pos,
              lhs = result_ref,
              rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
              )

          Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

          exec_code = Nodes.StatListNode(
              node.pos,
              stats = [
                  Nodes.SingleAssignmentNode(
                      start.pos,
                      lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                      rhs = start,
                      first = True),
                  loop_node
                  ])

          return ExprNodes.InlinedGeneratorExpressionNode(
              gen_expr_node.pos, loop = exec_code, result_node = result_ref,
              expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
              has_local_scope = gen_expr_node.has_local_scope)

      def _handle_simple_function_min(self, node, pos_args):
          return self._optimise_min_max(node, pos_args, '<')

      def _handle_simple_function_max(self, node, pos_args):
          return self._optimise_min_max(node, pos_args, '>')

      def _optimise_min_max(self, node, args, operator):
          """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.

          A single literal-sequence argument is unpacked first.  Fewer than
          two values are left to Python.  Each later argument is bound to a
          ResultRefNode and compared against the running result with
          *operator* in a chain of conditional expressions.
          """
          if len(args) <= 1:
              if len(args) == 1 and args[0].is_sequence_constructor:
                  args = args[0].args
              if len(args) <= 1:
                  # leave this to Python
                  return node

          cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))

          last_result = args[0]
          for arg_node in cascaded_nodes:
              result_ref = UtilNodes.ResultRefNode(last_result)
              last_result = ExprNodes.CondExprNode(
                  arg_node.pos,
                  true_val = arg_node,
                  false_val = result_ref,
                  test = ExprNodes.PrimaryCmpNode(
                      arg_node.pos,
                      operand1 = arg_node,
                      operator = operator,
                      operand2 = result_ref,
                      )
                  )
              last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

          for ref_node in cascaded_nodes[::-1]:
              last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

          return last_result

      # builtin type creation

      def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
          """Disabled: _dispatch_to_handler() does not look up this name."""
          if not pos_args:
              return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
          # This is a bit special - for iterables (including genexps),
          # Python actually overallocates and resizes a newly created
          # tuple incrementally while reading items, which we can't
          # easily do without explicit node support.  Instead, we read
          # the items into a list and then copy them into a tuple of the
          # final size.  This takes up to twice as much memory, but will
          # have to do until we have real support for genexps.
          result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
          if result is not node:
              return ExprNodes.AsTupleNode(node.pos, arg=result)
          return node

      def _handle_simple_function_frozenset(self, node, pos_args):
          """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.

          An empty literal argument is dropped entirely.  *pos_args* is
          modified in place and the same call node is returned.
          """
          if len(pos_args) != 1:
              return node
          if pos_args[0].is_sequence_constructor and not pos_args[0].args:
              del pos_args[0]
          elif isinstance(pos_args[0], ExprNodes.ListNode):
              pos_args[0] = pos_args[0].as_tuple()
          return node

      def _handle_simple_function_list(self, node, pos_args):
          """list() becomes [], list(genexpr) an inlined list comprehension."""
          if not pos_args:
              return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
          return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)

      def _handle_simple_function_set(self, node, pos_args):
          """set() becomes an empty SetNode, set(genexpr) an inlined set comprehension."""
          if not pos_args:
              return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
          return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)

      def _transform_list_set_genexpr(self, node, pos_args, target_type):
          """Replace set(genexpr) and list(genexpr) by an inlined comprehension.

          Each yield statement of the generator is replaced by an append to
          the comprehension target.
          """
          if len(pos_args) > 1:
              return node
          if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
              return node
          gen_expr_node = pos_args[0]
          loop_node = gen_expr_node.loop

          yield_statements = _find_yield_statements(loop_node)
          if not yield_statements:
              return node

          result_node = ExprNodes.InlinedGeneratorExpressionNode(
              node.pos, gen_expr_node,
              orig_func='set' if target_type is Builtin.set_type else 'list',
              comprehension_type=target_type)

          for yield_expression, yield_stat_node in yield_statements:
              append_node = ExprNodes.ComprehensionAppendNode(
                  yield_expression.pos,
                  expr=yield_expression,
                  target=result_node.target)
              Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

          return result_node

      def _handle_simple_function_dict(self, node, pos_args):
          """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }

          dict() becomes an empty DictNode.  The generator form is only
          rewritten when every yield expression is a literal 2-tuple.
          """
          if len(pos_args) == 0:
              return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
          if len(pos_args) > 1:
              return node
          if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
              return node
          gen_expr_node = pos_args[0]
          loop_node = gen_expr_node.loop

          yield_statements = _find_yield_statements(loop_node)
          if not yield_statements:
              return node

          for yield_expression, _ in yield_statements:
              if not isinstance(yield_expression, ExprNodes.TupleNode):
                  return node
              if len(yield_expression.args) != 2:
                  return node

          result_node = ExprNodes.InlinedGeneratorExpressionNode(
              node.pos, gen_expr_node, orig_func='dict',
              comprehension_type=Builtin.dict_type)

          for yield_expression, yield_stat_node in yield_statements:
              append_node = ExprNodes.DictComprehensionAppendNode(
                  yield_expression.pos,
                  key_expr=yield_expression.args[0],
                  value_expr=yield_expression.args[1],
                  target=result_node.target)
              Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

          return result_node

      # specific handlers for general call nodes

      def _handle_general_function_dict(self, node, pos_args, kwargs):
          """Replace dict(a=b,c=d,...) by the underlying keyword dict
          construction which is done anyway.
          """
          if len(pos_args) > 0:
              return node
          if not isinstance(kwargs, ExprNodes.DictNode):
              return node
          return kwargs
  class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
      """Inline calls to a name bound exactly once to a Python function node.

      Enabled by the ``optimize.inline_defnode_calls`` directive.
      """
      visit_Node = Visitor.VisitorTransform.recurse_to_children

      def get_constant_value_node(self, name_node):
          """Return the single assigned value of *name_node*, or None.

          None is returned when control-flow state is missing or the name
          may be unassigned.  It is also returned when the scope entry is
          not found or has other than exactly one assignment.
          """
          flow_state = name_node.cf_state
          if flow_state is None or flow_state.cf_is_null:
              return None
          entry = self.current_env().lookup(name_node.name)
          if not entry:
              return None
          assignments = entry.cf_assignments
          if not assignments or len(assignments) != 1:
              # not just a single assignment in all closures
              return None
          return assignments[0].rhs

      def visit_SimpleCallNode(self, node):
          """Swap the call for an InlinedDefNodeCallNode when it can be inlined."""
          self.visitchildren(node)
          if not self.current_directives.get('optimize.inline_defnode_calls'):
              return node
          callee_name = node.function
          if not callee_name.is_name:
              return node
          callee = self.get_constant_value_node(callee_name)
          if not isinstance(callee, ExprNodes.PyCFunctionNode):
              return node
          candidate = ExprNodes.InlinedDefNodeCallNode(
              node.pos, function_name=callee_name,
              function=callee, args=node.args)
          return self.replace(node, candidate) if candidate.can_be_inlined() else node
  1770. class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
  1771. Visitor.MethodDispatcherTransform):
  1772. """Optimize some common methods calls and instantiation patterns
  1773. for builtin types *after* the type analysis phase.
  1774. Running after type analysis, this transform can only perform
  1775. function replacements that do not alter the function return type
  1776. in a way that was not anticipated by the type analysis.
  1777. """
  1778. ### cleanup to avoid redundant coercions to/from Python types
  def visit_PyTypeTestNode(self, node):
      """Re-evaluate a Python type test after its operand may have changed.

      The children are visited first; the node is then replaced by
      whatever its reanalyse() returns.
      """
      self.visitchildren(node)
      reanalysed = node.reanalyse()
      return reanalysed
  def _visit_TypecastNode(self, node):
      """
      Drop redundant type casts.
      """
      # disabled - the user may have had a reason to put a type
      # cast, even if it looks redundant to Cython
      # (presumably the leading underscore keeps it out of visit_* dispatch)
      self.visitchildren(node)
      operand = node.operand
      return operand if node.type == operand.type else node
  def visit_ExprStatNode(self, node):
      """
      Remove expression statements that do nothing.

      A top-level coercion to a Python object is stripped first.  The
      statement is then dropped when the expression is gone, is None or
      a literal, or is a bare reference to a local variable or argument.
      """
      self.visitchildren(node)
      if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
          node.expr = node.expr.arg
      expr = node.expr
      if expr is None or expr.is_none or expr.is_literal:
          # Expression was removed or is dead code => remove ExprStatNode as well.
          return None
      entry = expr.entry if expr.is_name else None
      if entry and (entry.is_local or entry.is_arg):
          # Ignore dead references to local variables etc.
          return None
      return node
  def visit_CoerceToBooleanNode(self, node):
      """Skip a C->Py->bool round trip left behind by earlier tree changes.

      When the operand (optionally behind a PyTypeTestNode) is a coercion to
      a generic object or bool, its C argument is coerced to boolean
      directly instead.
      """
      self.visitchildren(node)
      inner = node.arg
      if isinstance(inner, ExprNodes.PyTypeTestNode):
          inner = inner.arg
      redundant = (isinstance(inner, ExprNodes.CoerceToPyTypeNode)
                   and inner.type in (PyrexTypes.py_object_type, Builtin.bool_type))
      if not redundant:
          return node
      return inner.arg.coerce_to_boolean(self.current_env())
  # signature of __Pyx_PyNumber_Float: object -> object
  PyNumber_Float_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)])

  def visit_CoerceToPyTypeNode(self, node):
      """Undo a Py->C->Py round trip through a C-API ``float()`` call.

      If the operand (optionally behind a CoerceFromPyTypeNode) is a
      one-argument PythonCapiCallNode for ``float``, then:
      - an argument already typed as Python float is returned with a None
        check;
      - any other Python-object argument goes through __Pyx_PyNumber_Float.
      """
      self.visitchildren(node)
      inner = node.arg
      if isinstance(inner, ExprNodes.CoerceFromPyTypeNode):
          inner = inner.arg
      if not isinstance(inner, ExprNodes.PythonCapiCallNode):
          return node
      if inner.function.name != 'float' or len(inner.args) != 1:
          return node

      # undo redundant Py->C->Py coercion
      float_arg = inner.args[0]
      if float_arg.type is Builtin.float_type:
          return float_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
      if not float_arg.type.is_pyobject:
          return node
      py_float_call = ExprNodes.PythonCapiCallNode(
          node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
          args=[float_arg],
          py_name='float',
          is_temp=node.is_temp,
          result_is_used=node.result_is_used,
      )
      return py_float_call.coerce_to(node.type, self.current_env())
  def visit_CoerceFromPyTypeNode(self, node):
      """Simplify a Python -> C coercion after tree changes.

      Cases, checked in order:
      - the operand is no longer a Python object: coerce it directly in C;
      - int/float/bool literals: coerce the literal itself;
      - a C->Py->C round trip: coerce the original C value;
      - int()/float() style calls: handed to _optimise_numeric_cast_call();
      - integer subscripts: handed to _optimise_int_indexing().
      """
      self.visitchildren(node)
      arg = node.arg
      if not arg.type.is_pyobject:
          # no Python conversion left at all, just do a C coercion instead
          if node.type != arg.type:
              arg = arg.coerce_to(node.type, self.current_env())
          return arg

      if isinstance(arg, ExprNodes.PyTypeTestNode):
          arg = arg.arg

      target_type = node.type
      if arg.is_literal:
          literal_fits = (
              (target_type.is_int and isinstance(arg, (ExprNodes.IntNode, ExprNodes.BoolNode)))
              or (target_type.is_float and isinstance(arg, ExprNodes.FloatNode)))
          if literal_fits:
              return arg.coerce_to(target_type, self.current_env())
          return node

      if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
          c_value = arg.arg
          if arg.type is PyrexTypes.py_object_type:
              if target_type.assignable_from(c_value.type):
                  # completely redundant C->Py->C coercion
                  return c_value.coerce_to(target_type, self.current_env())
          elif arg.type is Builtin.unicode_type:
              if c_value.type.is_unicode_char and target_type.is_unicode_char:
                  return c_value.coerce_to(target_type, self.current_env())
          return node

      if isinstance(arg, ExprNodes.SimpleCallNode):
          if target_type.is_int or target_type.is_float:
              return self._optimise_numeric_cast_call(node, arg)
          return node

      if arg.is_subscript:
          index_node = arg.index
          if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
              index_node = index_node.arg
          if index_node.type.is_int:
              return self._optimise_int_indexing(node, arg, index_node)
      return node
  # C signature of __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds) -> char.
  # ((char)-1) is the error return value, with exception checking enabled.
  PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_char_type,
      [
          PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
          PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
          PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
      ],
      exception_value="((char)-1)",
      exception_check=True,
  )
  def _optimise_int_indexing(self, coerce_node, arg, index_node):
      """Replace ``some_bytes[int_index]`` coerced to a C char by a C-API call.

      :param coerce_node: the Python-to-C coercion wrapping the subscript.
      :param arg: the subscript node; ``arg.base`` is the indexed object.
      :param index_node: the (unwrapped) C integer index.
      :returns: a ``__Pyx_PyBytes_GetItemInt`` call, coerced to the target
          type if that is ``unsigned char``, or ``coerce_node`` unchanged
          when the pattern does not apply.
      """
      env = self.current_env()
      # The helper receives the 'boundscheck' directive as a C int flag.
      bound_check_bool = 1 if env.directives['boundscheck'] else 0
      if arg.base.type is not Builtin.bytes_type:
          return coerce_node
      if coerce_node.type not in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
          return coerce_node

      # bytes[index] -> char
      bound_check_node = ExprNodes.IntNode(
          coerce_node.pos, value=str(bound_check_bool),
          constant_result=bound_check_bool)
      node = ExprNodes.PythonCapiCallNode(
          coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
          self.PyBytes_GetItemInt_func_type,
          args=[
              arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
              index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
              bound_check_node,
          ],
          is_temp=True,
          utility_code=UtilityCode.load_cached(
              'bytes_index', 'StringTools.c'))
      if coerce_node.type is not PyrexTypes.c_char_type:
          # the helper returns a plain char; convert for an unsigned char target
          node = node.coerce_to(coerce_node.type, env)
      return node
  # Maps each C floating point type T to the C function type "T f(T arg)";
  # used for the trunc*() calls emitted by _optimise_numeric_cast_call().
  float_float_func_types = {
      float_type: PyrexTypes.CFuncType(
          float_type, [PyrexTypes.CFuncTypeArg("arg", float_type, None)])
      for float_type in (
          PyrexTypes.c_float_type,
          PyrexTypes.c_double_type,
          PyrexTypes.c_longdouble_type,
      )
  }
  def _optimise_numeric_cast_call(self, node, arg):
      """Simplify ``int(x)`` / ``float(x)`` on a C value that is coerced back to C.

      :param node: the Python-to-C coercion node wrapping the call.
      :param arg: the call node for the builtin int() or float().
      :returns: a replacement node, or ``node`` unchanged if the call does
          not have exactly one argument, the argument is a genuine Python
          object, or no cheaper C form applies.
      """
      function = arg.function
      if isinstance(arg, ExprNodes.PythonCapiCallNode):
          call_args = arg.args
      elif (isinstance(function, ExprNodes.NameNode)
              and function.type.is_builtin_type
              and isinstance(arg.arg_tuple, ExprNodes.TupleNode)):
          call_args = arg.arg_tuple.args
      else:
          call_args = None
      if call_args is None or len(call_args) != 1:
          return node

      value = call_args[0]
      if isinstance(value, ExprNodes.CoerceToPyTypeNode):
          value = value.arg
      elif value.type.is_pyobject:
          # play it safe: Python conversion might work on all sorts of things
          return node

      source_type = value.type
      target_type = node.type

      def cast_to_target():
          # Same type: use the C value directly.  Otherwise use a C cast when
          # the target is assignable from the source or the source is a C
          # float.  Returns None when neither applies.
          if source_type == target_type:
              return value
          if target_type.assignable_from(source_type) or source_type.is_float:
              return ExprNodes.TypecastNode(node.pos, operand=value, type=target_type)
          return None

      builtin_name = function.name
      if builtin_name == 'int':
          if source_type.is_int or target_type.is_int:
              replacement = cast_to_target()
              return node if replacement is None else replacement
          if source_type.is_float and target_type.is_numeric:
              modifier = source_type.math_h_modifier
              # Work around missing Cygwin definition.
              trunc_cname = '__Pyx_truncl' if modifier == 'l' else 'trunc' + modifier
              return ExprNodes.PythonCapiCallNode(
                  node.pos, trunc_cname,
                  func_type=self.float_float_func_types[source_type],
                  args=[value],
                  py_name='int',
                  is_temp=node.is_temp,
                  result_is_used=node.result_is_used,
              ).coerce_to(target_type, self.current_env())
      elif builtin_name == 'float':
          if source_type.is_float or target_type.is_float:
              replacement = cast_to_target()
              return node if replacement is None else replacement
      return node
  def _error_wrong_arg_count(self, function_name, node, args, expected=None):
      """Report a compile-time error for a builtin called with a bad argument count.

      :param function_name: name of the builtin, used in the message.
      :param node: the call node; its position is used for the error.
      :param args: the positional arguments that were passed.
      :param expected: the expected count, either a number or a descriptive
          string such as '0 or 1'; None leaves it out of the message.
      """
      if expected and (isinstance(expected, basestring) or expected > 1):
          arg_str = '...'
      elif expected == 1:
          arg_str = 'x'
      else:
          arg_str = ''
      expected_str = '' if expected is None else 'expected %s, ' % expected
      error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
          function_name, arg_str, expected_str, len(args)))
  1980. ### generic fallbacks
  1981. def _handle_function(self, node, function_name, function, arg_list, kwargs):
  1982. return node
  def _handle_method(self, node, type_name, attr_name, function,
                     arg_list, is_unbound_method, kwargs):
      """
      Try to inject C-API calls for unbound method calls to builtin types.

      While the method declarations in Builtin.py already handle this, this
      also resolves bound and unbound methods that were assigned to
      variables ahead of time.  If the type has no C-level declaration of
      the method, the generic cached-method optimisation is tried instead.
      """
      # Keyword arguments are not supported.  Unbound method calls cannot be
      # tracked over more than one indirection, as the names might have been
      # reassigned in the meantime.
      if kwargs or not (function and function.is_attribute and function.obj.is_name):
          return node

      env = self.current_env()
      type_entry = env.lookup(type_name)
      if not type_entry:
          return node

      type_name_node = ExprNodes.NameNode(
          function.pos,
          name=type_name,
          entry=type_entry,
          type=type_entry.type)
      method = ExprNodes.AttributeNode(
          node.function.pos,
          obj=type_name_node,
          attribute=attr_name,
          is_called=True,
      ).analyse_as_type_attribute(env)
      if method is None:
          return self._optimise_generic_builtin_method_call(
              node, attr_name, function, arg_list, is_unbound_method)

      call_args = node.args
      if call_args is None and node.arg_tuple:
          call_args = node.arg_tuple.args
      call_node = ExprNodes.SimpleCallNode(node.pos, function=method, args=call_args)
      if not is_unbound_method:
          # bound call: the object the method was looked up on becomes 'self'
          call_node.self = function.obj
      call_node.analyse_c_function_call(env)
      call_node.analysed = True
      return call_node.coerce_to(node.type, env)
  ### builtin types

  def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
      """
      Try to inject an unbound method call for a call to a method of a known builtin type.
      This enables caching the underlying C function of the method at runtime.

      Only bound Python attribute calls with at most two arguments on an
      object of a builtin type are rewritten.
      """
      if is_unbound_method or len(arg_list) > 2:
          return node
      if not (function.is_attribute and function.is_py_attr):
          return node
      obj_type = function.obj.type
      if not obj_type.is_builtin_type:
          return node
      if obj_type.name in ('basestring', 'type'):
          # these allow different actual types => unsafe
          return node
      return ExprNodes.CachedBuiltinMethodCallNode(
          node, function.obj, attr_name, arg_list)
  # C signature shared by the unicode() helpers: unicode f(object obj)
  PyObject_Unicode_func_type = PyrexTypes.CFuncType(
      Builtin.unicode_type,
      [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_unicode(self, node, function, pos_args):
      """Optimise calls to unicode() with zero or one positional argument.

      ``unicode()`` becomes an empty literal.  ``unicode(u)`` for a unicode
      value that cannot be None is just ``u``.  Anything else becomes a call
      to a C helper, with a unicode-specific one when the type is known.
      """
      if not pos_args:
          return ExprNodes.UnicodeNode(node.pos, value=EncodedString(), constant_result=u'')
      if len(pos_args) > 1:
          return node

      arg = pos_args[0]
      if arg.type is Builtin.unicode_type:
          if not arg.may_be_none():
              return arg
          cname, utility_name = "__Pyx_PyUnicode_Unicode", 'PyUnicode_Unicode'
      else:
          cname, utility_name = "__Pyx_PyObject_Unicode", 'PyObject_Unicode'

      return ExprNodes.PythonCapiCallNode(
          node.pos, cname, self.PyObject_Unicode_func_type,
          args=pos_args,
          is_temp=node.is_temp,
          utility_code=UtilityCode.load_cached(utility_name, 'StringTools.c'),
          py_name="unicode")
  def visit_FormattedValueNode(self, node):
      """Simplify or avoid plain string formatting of a unicode value.

      This seems misplaced here, but plain unicode formatting is essentially
      a call to the unicode() builtin, so it is routed through
      _handle_simple_function_unicode().  It applies only without any
      format spec and with no conversion or '!s'.
      """
      self.visitchildren(node)
      if node.value.type is not Builtin.unicode_type:
          return node
      if node.c_format_spec or node.format_spec:
          return node
      if node.conversion_char and node.conversion_char != 's':
          return node
      # value is definitely a unicode string and we don't format it any special
      return self._handle_simple_function_unicode(node, None, [node.value])
  # C signature of PyDict_Copy(): dict f(dict dict)
  PyDict_Copy_func_type = PyrexTypes.CFuncType(
      Builtin.dict_type,
      [PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)],
  )
  def _handle_simple_function_dict(self, node, function, pos_args):
      """Replace dict(some_dict) by PyDict_Copy(some_dict).

      Only a single positional argument statically typed as dict is
      rewritten.  A None check is inserted first because the value could
      still be None at runtime.  Other calls are returned unchanged.
      """
      if len(pos_args) != 1:
          return node
      arg = pos_args[0]
      if arg.type is not Builtin.dict_type:
          return node
      # Same TypeError message as CPython's dict(None) and the tuple() handler.
      arg = arg.as_none_safe_node("'NoneType' object is not iterable")
      return ExprNodes.PythonCapiCallNode(
          node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
          args=[arg],
          is_temp=node.is_temp)
  # C signature of PySequence_List(): list f(object it)
  PySequence_List_func_type = PyrexTypes.CFuncType(
      Builtin.list_type,
      [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_list(self, node, function, pos_args):
      """Turn list(ob) into PySequence_List(ob).

      Only the single-argument form is rewritten; any other call is
      returned unchanged.
      """
      if len(pos_args) != 1:
          return node
      return ExprNodes.PythonCapiCallNode(
          node.pos, "PySequence_List", self.PySequence_List_func_type,
          args=pos_args, is_temp=node.is_temp)
  # C signature of PyList_AsTuple(): tuple f(list list)
  PyList_AsTuple_func_type = PyrexTypes.CFuncType(
      Builtin.tuple_type,
      [PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)],
  )
  def _handle_simple_function_tuple(self, node, function, pos_args):
      """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.

      A tuple argument that cannot be None is returned as is.  A list
      argument is None-checked and passed to PyList_AsTuple().  Anything
      else is wrapped in an AsTupleNode.  Only single-argument calls whose
      result is a temp are handled.
      """
      if len(pos_args) != 1 or not node.is_temp:
          return node

      arg = pos_args[0]
      arg_type = arg.type
      if arg_type is Builtin.tuple_type and not arg.may_be_none():
          return arg
      if arg_type is not Builtin.list_type:
          return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)

      pos_args[0] = arg.as_none_safe_node(
          "'NoneType' object is not iterable")
      return ExprNodes.PythonCapiCallNode(
          node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
          args=pos_args, is_temp=node.is_temp)
  # C signature of PySet_New(): set f(object it)
  PySet_New_func_type = PyrexTypes.CFuncType(
      Builtin.set_type,
      [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_set(self, node, function, pos_args):
      """Optimise single-argument calls to set().

      A list/tuple literal argument such as ``set([x, y, z])`` becomes a set
      literal.  Any other argument, including a multiplied literal such as
      ``[x] * n``, is passed to PySet_New().  Other argument counts are left
      alone.
      """
      if len(pos_args) != 1:
          return node
      iterable = pos_args[0]
      # A multiplied sequence (``[x] * n``) must not become a set literal:
      # the repeat count has to be evaluated, and for n <= 0 the result is
      # an empty set rather than the set of the listed items.
      if iterable.is_sequence_constructor and getattr(iterable, 'mult_factor', None) is None:
          # We can optimise set([x,y,z]) safely into a set literal,
          # but only if we create all items before adding them -
          # adding an item may raise an exception if it is not
          # hashable, but creating the later items may have
          # side-effects.
          args = []
          temps = []
          for arg in iterable.args:
              if not arg.is_simple():
                  arg = UtilNodes.LetRefNode(arg)
                  temps.append(arg)
              args.append(arg)
          result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
          self.replace(node, result)
          # wrap so that the first item's temp is evaluated first (outermost)
          for temp in temps[::-1]:
              result = UtilNodes.EvalWithTempExprNode(temp, result)
          return result

      # PySet_New(it) is better than a generic Python call to set(it)
      return self.replace(node, ExprNodes.PythonCapiCallNode(
          node.pos, "PySet_New",
          self.PySet_New_func_type,
          args=pos_args,
          is_temp=node.is_temp,
          py_name="set"))
  # C signature of __Pyx_PyFrozenSet_New(): frozenset f(object it)
  PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
      Builtin.frozenset_type,
      [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_frozenset(self, node, function, pos_args):
      """Replace frozenset() / frozenset(it) by a C helper call.

      A single argument that already is a non-None frozenset is returned
      unchanged.  Calls with more than one argument are left alone.
      """
      if len(pos_args) > 1:
          return node
      if pos_args:
          arg = pos_args[0]
          if arg.type is Builtin.frozenset_type and not arg.may_be_none():
              return arg
      else:
          # no argument: the helper is passed a NULL iterable
          # (presumably yielding an empty frozenset - see Builtins.c)
          pos_args = [ExprNodes.NullNode(node.pos)]

      # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
      return ExprNodes.PythonCapiCallNode(
          node.pos, "__Pyx_PyFrozenSet_New",
          self.PyFrozenSet_New_func_type,
          args=pos_args,
          is_temp=node.is_temp,
          utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
          py_name="frozenset")
  # C signature of __Pyx_PyObject_AsDouble(): double f(object obj).
  # ((double)-1) is the error return value, with exception checking enabled.
  PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_double_type,
      [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
      exception_value="((double)-1)",
      exception_check=True,
  )
  def _handle_simple_function_float(self, node, function, pos_args):
      """Transform float() into either a C type cast or a faster C
      function call.

      - ``float()`` becomes the constant 0.0;
      - an argument that is (a Python coercion of) a C double is used as is;
      - other C values that are numeric, or assignable to the result type,
        are cast;
      - anything else goes through __Pyx_PyObject_AsDouble().
      More than one argument is reported as a compile-time error.
      """
      # Note: this requires the float() function to be typed as
      # returning a C 'double'
      if len(pos_args) == 0:
          # The constant needs the call's source position (node.pos), not
          # the call node itself.
          return ExprNodes.FloatNode(
              node.pos, value="0.0", constant_result=0.0
          ).coerce_to(Builtin.float_type, self.current_env())
      elif len(pos_args) != 1:
          self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
          return node
      func_arg = pos_args[0]
      if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
          # look through the Python coercion at the underlying C value
          func_arg = func_arg.arg
      if func_arg.type is PyrexTypes.c_double_type:
          return func_arg
      elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
          return ExprNodes.TypecastNode(
              node.pos, operand=func_arg, type=node.type)
      return ExprNodes.PythonCapiCallNode(
          node.pos, "__Pyx_PyObject_AsDouble",
          self.PyObject_AsDouble_func_type,
          args=pos_args,
          is_temp=node.is_temp,
          utility_code=load_c_utility('pyobject_as_double'),
          py_name="float")
  # C signature of __Pyx_PyNumber_Int(): object f(object o)
  PyNumber_Int_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)],
  )

  # C signature of __Pyx_PyInt_FromDouble(): object f(double value)
  PyInt_FromDouble_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)],
  )
  def _handle_simple_function_int(self, node, function, pos_args):
      """Transform int() into a faster C function call.

      - ``int()`` becomes the constant 0;
      - ``int(<C float>)`` calls __Pyx_PyInt_FromDouble();
      - ``int(<Python object>)`` calls __Pyx_PyNumber_Int() when the result
        is also a Python object.
      ``int(x, base)`` and other C arguments are left alone.
      """
      if not pos_args:
          return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
                                   type=PyrexTypes.py_object_type)
      if len(pos_args) > 1:
          return node  # int(x, base)

      func_arg = pos_args[0]
      if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
          if not func_arg.arg.type.is_float:
              return node  # handled in visit_CoerceFromPyTypeNode()
          return ExprNodes.PythonCapiCallNode(
              node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type,
              args=[func_arg.arg], is_temp=True, py_name='int',
              utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c"))

      if not (func_arg.type.is_pyobject and node.type.is_pyobject):
          return node
      return ExprNodes.PythonCapiCallNode(
          node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
          args=pos_args, is_temp=True, py_name='int')
  def _handle_simple_function_bool(self, node, function, pos_args):
      """Transform bool(x) into a type coercion to a boolean.

      ``bool()`` becomes the constant False.  ``bool(x)`` becomes a double
      negation of the C truth value of x, converted back to a Python
      object.  More than one argument is reported as a compile-time error.
      """
      arg_count = len(pos_args)
      env = self.current_env()
      if arg_count == 0:
          return ExprNodes.BoolNode(
              node.pos, value=False, constant_result=False
          ).coerce_to(Builtin.bool_type, env)
      if arg_count != 1:
          self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
          return node

      # => !!<bint>(x) to make sure it's exactly 0 or 1
      truth_value = pos_args[0].coerce_to_boolean(env)
      for _ in range(2):
          truth_value = ExprNodes.NotNode(node.pos, operand=truth_value)
      # coerce back to Python object as that's the result we are expecting
      return truth_value.coerce_to_pyobject(env)
  ### builtin functions

  # C signature of strlen(): size_t f(const char *bytes)
  Pyx_strlen_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_size_t_type,
      [PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)],
  )

  # C signature of __Pyx_Py_UNICODE_strlen(): size_t f(const Py_UNICODE *unicode)
  Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_size_t_type,
      [PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)],
  )

  # Py_ssize_t f(object obj), with -1 as the error value; used for the
  # C-API length functions/macros picked by _handle_simple_function_len().
  PyObject_Size_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_py_ssize_t_type,
      [PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)],
      exception_value="-1",
  )
  # Bound dict.get() mapping a builtin type to the C-API function/macro that
  # returns its length; unknown types map to None.
  _map_to_capi_len_function = {
      Builtin.unicode_type: "__Pyx_PyUnicode_GET_LENGTH",
      Builtin.bytes_type: "PyBytes_GET_SIZE",
      Builtin.bytearray_type: 'PyByteArray_GET_SIZE',
      Builtin.list_type: "PyList_GET_SIZE",
      Builtin.tuple_type: "PyTuple_GET_SIZE",
      Builtin.set_type: "PySet_GET_SIZE",
      Builtin.frozenset_type: "PySet_GET_SIZE",
      Builtin.dict_type: "PyDict_Size",
  }.get

  # Qualified names of extension/builtin types whose len() is Py_SIZE(obj).
  _ext_types_with_pysize = {"cpython.array.array"}
  def _handle_simple_function_len(self, node, function, pos_args):
      """Replace len() by a C-level length computation where possible.

      - ``char*``                 -> strlen()
      - ``Py_UNICODE*``           -> __Pyx_Py_UNICODE_strlen()
      - memoryview slice          -> __Pyx_MemoryView_Len()
      - known builtin/ext. type   -> matching C-API size call (None-checked)
      - single unicode character  -> the constant 1
      Other arguments are left alone.  A wrong argument count is reported
      as a compile-time error.
      """
      if len(pos_args) != 1:
          self._error_wrong_arg_count('len', node, pos_args, 1)
          return node

      arg = pos_args[0]
      if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
          # look through the Python coercion at the underlying C value
          arg = arg.arg
      arg_type = arg.type

      if arg_type.is_string:
          new_node = ExprNodes.PythonCapiCallNode(
              node.pos, "strlen", self.Pyx_strlen_func_type,
              args=[arg],
              is_temp=node.is_temp,
              utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
      elif arg_type.is_pyunicode_ptr:
          new_node = ExprNodes.PythonCapiCallNode(
              node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
              args=[arg],
              is_temp=node.is_temp)
      elif arg_type.is_memoryviewslice:
          memview_len_type = PyrexTypes.CFuncType(
              PyrexTypes.c_size_t_type,
              [PyrexTypes.CFuncTypeArg("memoryviewslice", arg_type, None)],
              nogil=True)
          new_node = ExprNodes.PythonCapiCallNode(
              node.pos, "__Pyx_MemoryView_Len", memview_len_type,
              args=[arg], is_temp=node.is_temp)
      elif arg_type.is_pyobject:
          cfunc_name = self._map_to_capi_len_function(arg_type)
          if cfunc_name is None:
              has_pysize = (
                  (arg_type.is_extension_type or arg_type.is_builtin_type)
                  and arg_type.entry.qualified_name in self._ext_types_with_pysize)
              if not has_pysize:
                  return node
              cfunc_name = 'Py_SIZE'
          arg = arg.as_none_safe_node(
              "object of type 'NoneType' has no len()")
          new_node = ExprNodes.PythonCapiCallNode(
              node.pos, cfunc_name, self.PyObject_Size_func_type,
              args=[arg], is_temp=node.is_temp)
      elif arg_type.is_unicode_char:
          return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                   type=node.type)
      else:
          return node

      if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
          new_node = new_node.coerce_to(node.type, self.current_env())
      return new_node
  # C signature used for Py_TYPE(): type f(object object)
  Pyx_Type_func_type = PyrexTypes.CFuncType(
      Builtin.type_type,
      [PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_type(self, node, function, pos_args):
      """Replace type(o) by a macro call to Py_TYPE(o).

      The result is cast back to a generic Python object.  Calls with any
      other number of arguments are left alone.
      """
      if len(pos_args) == 1:
          py_type_call = ExprNodes.PythonCapiCallNode(
              node.pos, "Py_TYPE", self.Pyx_Type_func_type,
              args=pos_args,
              is_temp=False)
          return ExprNodes.CastNode(py_type_call, PyrexTypes.py_object_type)
      return node
  # bint f(object arg): function type used for the C type check calls built
  # by _handle_simple_function_isinstance().
  # NOTE(review): declared with one argument, but isinstance() also uses it
  # for two-argument calls (__Pyx_TypeCheck, PyObject_IsInstance) - confirm
  # PythonCapiCallNode does not rely on the declared arity.
  Py_type_check_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_bint_type,
      [PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)],
  )
  def _handle_simple_function_isinstance(self, node, function, pos_args):
      """Replace isinstance() checks against builtin types by the
      corresponding C-API call.

      ``isinstance(x, T)`` and ``isinstance(x, (T1, T2, ...))`` are expanded
      into an 'or' chain of C type checks:
      - a known builtin type uses its own type check function;
      - other values of type 'type' use __Pyx_TypeCheck();
      - anything else uses PyObject_IsInstance().
      Operands used more than once are stored in temporaries first.  An
      empty tuple of types is left to the runtime.
      """
      if len(pos_args) != 2:
          return node
      arg, types = pos_args
      temps = []
      if isinstance(types, ExprNodes.TupleNode):
          types = types.args
          if not types:
              # isinstance(x, ()) has no test to build; reducing an empty
              # list of tests below would fail, so keep the runtime call.
              return node
          if len(types) == 1 and types[0].type is not Builtin.type_type:
              return node  # nothing to improve here
          if arg.is_attribute or not arg.is_simple():
              # 'arg' is used in every test: evaluate it only once
              arg = UtilNodes.ResultRefNode(arg)
              temps.append(arg)
      elif types.type is Builtin.type_type:
          types = [types]
      else:
          return node

      tests = []
      test_nodes = []
      env = self.current_env()
      for test_type_node in types:
          builtin_type = None
          if test_type_node.is_name:
              if test_type_node.entry:
                  entry = env.lookup(test_type_node.entry.name)
                  if entry and entry.type and entry.type.is_builtin_type:
                      builtin_type = entry.type
          if builtin_type is Builtin.type_type:
              # all types have type "type", but there's only one 'type'
              if entry.name != 'type' or not (
                      entry.scope and entry.scope.is_builtin_scope):
                  builtin_type = None
          if builtin_type is not None:
              type_check_function = entry.type.type_check_function(exact=False)
              if type_check_function in tests:
                  # same C check already emitted for an earlier type
                  continue
              tests.append(type_check_function)
              type_check_args = [arg]
          elif test_type_node.type is Builtin.type_type:
              type_check_function = '__Pyx_TypeCheck'
              type_check_args = [arg, test_type_node]
          else:
              if not test_type_node.is_literal:
                  test_type_node = UtilNodes.ResultRefNode(test_type_node)
                  temps.append(test_type_node)
              type_check_function = 'PyObject_IsInstance'
              type_check_args = [arg, test_type_node]
          test_nodes.append(
              ExprNodes.PythonCapiCallNode(
                  test_type_node.pos, type_check_function, self.Py_type_check_func_type,
                  args=type_check_args,
                  is_temp=True,
              ))

      def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
          or_node = make_binop_node(node.pos, 'or', a, b)
          or_node.type = PyrexTypes.c_bint_type
          or_node.wrap_operands(env)
          return or_node

      test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
      for temp in temps[::-1]:
          test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
      return test_node
  def _handle_simple_function_ord(self, node, function, pos_args):
      """Unpack ord(Py_UNICODE) and ord('X').

      A C unicode character is cast to a C long.  A one-character unicode
      literal is folded into an integer constant.  So is a one-character
      byte string literal, if its code point is at most 255.
      """
      if len(pos_args) != 1:
          return node
      arg = pos_args[0]
      env = self.current_env()

      def char_constant(char):
          # integer constant holding the code point of a single character
          code_point = ord(char)
          return ExprNodes.IntNode(
              arg.pos, type=PyrexTypes.c_int_type,
              value=str(code_point),
              constant_result=code_point
          ).coerce_to(node.type, env)

      if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
          if arg.arg.type.is_unicode_char:
              return ExprNodes.TypecastNode(
                  arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type
              ).coerce_to(node.type, env)
      elif isinstance(arg, ExprNodes.UnicodeNode):
          if len(arg.value) == 1:
              return char_constant(arg.value)
      elif isinstance(arg, ExprNodes.StringNode):
          text = arg.unicode_value
          if text and len(text) == 1 and ord(text) <= 255:  # Py2/3 portability
              return char_constant(text)
      return node
  ### special methods

  # C signature of __Pyx_tp_new(): object f(object type, tuple args)
  Pyx_tp_new_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [
          PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
      ],
  )

  # C signature of __Pyx_tp_new_kwargs(): object f(object type, tuple args, dict kwargs)
  Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [
          PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
          PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
      ],
  )
  def _handle_any_slot__new__(self, node, function, args,
                              is_unbound_method, kwargs=None):
      """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

      Only unbound calls are rewritten where both the receiver and the first
      argument are plain names typed as 'type' that refer to the same type.
      If that type is an extension type of the current module with a tp_new
      slot function, the slot function is called directly.  Otherwise the
      call goes through the __Pyx_tp_new / __Pyx_tp_new_kwargs utility
      functions.  Any call that does not match is returned unchanged.

      :param args: positional arguments; args[0] is the type being instantiated.
      :param kwargs: keyword arguments node, or None.
      """
      obj = function.obj
      if not is_unbound_method or len(args) < 1:
          return node
      type_arg = args[0]
      if not obj.is_name or not type_arg.is_name:
          # play safe
          return node
      if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
          # not a known type, play safe
          return node
      if not type_arg.type_entry or not obj.type_entry:
          if obj.name != type_arg.name:
              return node
          # otherwise, we know it's a type and we know it's the same
          # type for both - that should do
      elif type_arg.type_entry != obj.type_entry:
          # different types - may or may not lead to an error at runtime
          return node

      # all remaining positional arguments are passed on as a tuple
      args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
      args_tuple = args_tuple.analyse_types(
          self.current_env(), skip_children=True)

      if type_arg.type_entry:
          ext_type = type_arg.type_entry.type
          if (ext_type.is_extension_type and ext_type.typeobj_cname and
                  ext_type.scope.global_scope() == self.current_env().global_scope()):
              # known type in current module
              tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
              slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
              if slot_func_cname:
                  # call the slot function directly as
                  # ext_type f(PyTypeObject *type, object args, object kwargs)
                  cython_scope = self.context.cython_scope
                  PyTypeObjectPtr = PyrexTypes.CPtrType(
                      cython_scope.lookup('PyTypeObject').type)
                  pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                      ext_type, [
                          PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
                          PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
                          PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                      ])

                  type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                  if not kwargs:
                      # no keyword arguments: pass a NULL object
                      kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                  return ExprNodes.PythonCapiCallNode(
                      node.pos, slot_func_cname,
                      pyx_tp_new_kwargs_func_type,
                      args=[type_arg, args_tuple, kwargs],
                      may_return_none=False,
                      is_temp=True)
          # no direct slot call possible: fall through to the utility functions
      else:
          # arbitrary variable, needs a None check for safety
          type_arg = type_arg.as_none_safe_node(
              "object.__new__(X): X is not a type object (NoneType)")

      utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
      if kwargs:
          return ExprNodes.PythonCapiCallNode(
              node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
              args=[type_arg, args_tuple, kwargs],
              utility_code=utility_code,
              is_temp=node.is_temp
          )
      else:
          return ExprNodes.PythonCapiCallNode(
              node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
              args=[type_arg, args_tuple],
              utility_code=utility_code,
              is_temp=node.is_temp
          )
  ### methods of builtin types

  # int f(object list, object item) returning a C return code, -1 on error.
  # Shared by the __Pyx_PyObject_Append / __Pyx_PyList_Append /
  # __Pyx_ListComp_Append calls generated by the append/extend handlers.
  PyObject_Append_func_type = PyrexTypes.CFuncType(
      PyrexTypes.c_returncode_type,
      [
          PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
      ],
      exception_value="-1",
  )
  def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
      """Optimistic optimisation as X.append() is almost always
      referring to a list.

      Applies only to ``X.append(item)`` whose result is unused and whose
      ``node.function`` has no entry.  The call is then replaced by
      __Pyx_PyObject_Append(X, item).
      """
      applicable = (len(args) == 2
                    and not node.result_is_used
                    and not node.function.entry)
      if not applicable:
          return node
      return ExprNodes.PythonCapiCallNode(
          node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
          args=args,
          may_return_none=False,
          is_temp=node.is_temp,
          result_is_used=False,
          utility_code=load_c_utility('append'),
      )
  def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
      """Replace list.extend([...]) for short sequence literals values by sequential appends
      to avoid creating an intermediate sequence argument.

      Applies only when the argument is a sequence literal with no
      multiplication factor and at most 8 items.  An empty literal reduces
      to the list object as returned by self._wrap_self_arg().  Otherwise
      the items are appended one by one: __Pyx_ListComp_Append() for all but
      the last item and __Pyx_PyList_Append() for the last one.  The C
      return codes are combined with '|'.  Non-simple items, and the list
      object when it is used more than once, go into temporaries so that
      each is evaluated once, the list first and then the items in order.
      """
      if len(args) != 2:
          return node
      obj, value = args
      if not value.is_sequence_constructor:
          return node
      items = list(value.args)
      if value.mult_factor is not None or len(items) > 8:
          # Appending wins for short sequences but slows down when multiple resize operations are needed.
          # This seems to be a good enough limit that avoids repeated resizing.
          # NOTE: the branch below is disabled ('False and ...') and kept for reference only.
          if False and isinstance(value, ExprNodes.ListNode):
              # One would expect that tuples are more efficient here, but benchmarking with
              # Py3.5 and Py3.7 suggests that they are not. Probably worth revisiting at some point.
              # Might be related to the usage of PySequence_FAST() in CPython's list.extend(),
              # which is probably tuned more towards lists than tuples (and rightly so).
              tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True)
              Visitor.recursively_replace_node(node, args[1], tuple_node)
          return node
      wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
      if not items:
          # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
          # NOTE(review): if node.result_is_used, this replacement evaluates to the
          # list object rather than to list.extend()'s result - confirm this is intended.
          wrapped_obj.result_is_used = node.result_is_used
          return wrapped_obj
      cloned_obj = obj = wrapped_obj
      if len(items) > 1 and not obj.is_simple():
          # the list is referenced by several append calls: evaluate it once
          cloned_obj = UtilNodes.LetRefNode(obj)
      # Use ListComp_Append() for all but the last item and finish with PyList_Append()
      # to shrink the list storage size at the very end if necessary.
      temps = []
      arg = items[-1]
      if not arg.is_simple():
          arg = UtilNodes.LetRefNode(arg)
          temps.append(arg)
      new_node = ExprNodes.PythonCapiCallNode(
          node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
          args=[cloned_obj, arg],
          is_temp=True,
          utility_code=load_c_utility("ListAppend"))
      # Walk the remaining items backwards, so each new call becomes the left
      # operand of '|' and the resulting chain appends the items in order.
      for arg in items[-2::-1]:
          if not arg.is_simple():
              arg = UtilNodes.LetRefNode(arg)
              temps.append(arg)
          new_node = ExprNodes.binop_node(
              node.pos, '|',
              ExprNodes.PythonCapiCallNode(
                  node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
                  args=[cloned_obj, arg], py_name="extend",
                  is_temp=True,
                  utility_code=load_c_utility("ListCompAppend")),
              new_node,
              type=PyrexTypes.c_returncode_type,
          )
          new_node.result_is_used = node.result_is_used
      if cloned_obj is not obj:
          temps.append(cloned_obj)
      # temps holds the items from last to first, then the list.  Wrapping in
      # this order makes the list the outermost temp, evaluated first.
      for temp in temps:
          new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
          new_node.result_is_used = node.result_is_used
      return new_node
  2611. PyByteArray_Append_func_type = PyrexTypes.CFuncType(
  2612. PyrexTypes.c_returncode_type, [
  2613. PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
  2614. PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
  2615. ],
  2616. exception_value="-1")
  2617. PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
  2618. PyrexTypes.c_returncode_type, [
  2619. PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
  2620. PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
  2621. ],
  2622. exception_value="-1")
  2623. def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
  2624. if len(args) != 2:
  2625. return node
  2626. func_name = "__Pyx_PyByteArray_Append"
  2627. func_type = self.PyByteArray_Append_func_type
  2628. value = unwrap_coerced_node(args[1])
  2629. if value.type.is_int or isinstance(value, ExprNodes.IntNode):
  2630. value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
  2631. utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
  2632. elif value.is_string_literal:
  2633. if not value.can_coerce_to_char_literal():
  2634. return node
  2635. value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
  2636. utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
  2637. elif value.type.is_pyobject:
  2638. func_name = "__Pyx_PyByteArray_AppendObject"
  2639. func_type = self.PyByteArray_AppendObject_func_type
  2640. utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
  2641. else:
  2642. return node
  2643. new_node = ExprNodes.PythonCapiCallNode(
  2644. node.pos, func_name, func_type,
  2645. args=[args[0], value],
  2646. may_return_none=False,
  2647. is_temp=node.is_temp,
  2648. utility_code=utility_code,
  2649. )
  2650. if node.result_is_used:
  2651. new_node = new_node.coerce_to(node.type, self.current_env())
  2652. return new_node
  2653. PyObject_Pop_func_type = PyrexTypes.CFuncType(
  2654. PyrexTypes.py_object_type, [
  2655. PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
  2656. ])
  2657. PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
  2658. PyrexTypes.py_object_type, [
  2659. PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
  2660. PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
  2661. PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
  2662. PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
  2663. ],
  2664. has_varargs=True) # to fake the additional macro args that lack a proper C type
  2665. def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
  2666. return self._handle_simple_method_object_pop(
  2667. node, function, args, is_unbound_method, is_list=True)
def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
    """Optimistic optimisation as X.pop([n]) is almost always
    referring to a list.

    ``args[0]`` is the object being popped from and the optional ``args[1]`` is
    the index.  With ``is_list=True`` the object is wrapped in a None check that
    raises AttributeError, and the ``__Pyx_PyList_*`` helpers are used.
    Otherwise the ``__Pyx_PyObject_*`` helpers are used.

    Returns a PythonCapiCallNode for ``X.pop()`` / ``X.pop(index)``.  Returns
    the original ``node`` when the argument count or the index type cannot be
    handled.
    """
    if not args:
        return node
    obj = args[0]
    if is_list:
        type_name = 'List'
        obj = obj.as_none_safe_node(
            "'NoneType' object has no attribute '%.30s'",
            error="PyExc_AttributeError",
            format_args=['pop'])
    else:
        type_name = 'Object'
    if len(args) == 1:
        # X.pop() without an index.
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_Pop" % type_name,
            self.PyObject_Pop_func_type,
            args=[obj],
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility('pop'),
        )
    elif len(args) == 2:
        # X.pop(index).  The helper receives a Python index object (py_index,
        # a NoneNode unless one is available) and a C Py_ssize_t index.
        index = unwrap_coerced_node(args[1])
        py_index = ExprNodes.NoneNode(index.pos)
        orig_index_type = index.type
        if not index.type.is_int:
            if isinstance(index, ExprNodes.IntNode):
                # Integer literal without a C integer type: provide it both as
                # a Python object and as a C Py_ssize_t.
                py_index = index.coerce_to_pyobject(self.current_env())
                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            elif is_list:
                if index.type.is_pyobject:
                    # The C index is derived from a CloneNode of the simple
                    # Python index, presumably so that the index expression
                    # is evaluated only once.
                    py_index = index.coerce_to_simple(self.current_env())
                    index = ExprNodes.CloneNode(py_index)
                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            else:
                # Non-integer index on an arbitrary object: leave unchanged.
                return node
        elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type):
            # C integer types that do not fit into Py_ssize_t are not handled.
            return node
        elif isinstance(index, ExprNodes.IntNode):
            py_index = index.coerce_to_pyobject(self.current_env())
        # real type might still be larger at runtime
        if not orig_index_type.is_int:
            orig_index_type = index.type
        # The helper also needs a C-to-Python converter for the original index
        # type.  Give up if one cannot be created.
        if not orig_index_type.create_to_py_utility_code(self.current_env()):
            return node
        convert_func = orig_index_type.to_py_function
        conversion_type = PyrexTypes.CFuncType(
            PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)])
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_PopIndex" % type_name,
            self.PyObject_PopIndex_func_type,
            # The first four args match the declared func type: obj, py_index,
            # C index, and a 1/0 "is signed" flag.  The last two (the C index
            # type as raw code, and its to-Python converter) are passed via the
            # func type's varargs.
            args=[obj, py_index, index,
                  ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0),
                                    constant_result=orig_index_type.signed and 1 or 0,
                                    type=PyrexTypes.c_int_type),
                  ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
                                             orig_index_type.empty_declaration_code()),
                  ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)],
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility("pop_index"),
        )
    # More than two arguments: not optimised.
    return node
  2734. single_param_func_type = PyrexTypes.CFuncType(
  2735. PyrexTypes.c_returncode_type, [
  2736. PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
  2737. ],
  2738. exception_value = "-1")
  2739. def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
  2740. """Call PyList_Sort() instead of the 0-argument l.sort().
  2741. """
  2742. if len(args) != 1:
  2743. return node
  2744. return self._substitute_method_call(
  2745. node, function, "PyList_Sort", self.single_param_func_type,
  2746. 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
  2747. Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
  2748. PyrexTypes.py_object_type, [
  2749. PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
  2750. PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
  2751. PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
  2752. ])
  2753. def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
  2754. """Replace dict.get() by a call to PyDict_GetItem().
  2755. """
  2756. if len(args) == 2:
  2757. args.append(ExprNodes.NoneNode(node.pos))
  2758. elif len(args) != 3:
  2759. self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
  2760. return node
  2761. return self._substitute_method_call(
  2762. node, function,
  2763. "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
  2764. 'get', is_unbound_method, args,
  2765. may_return_none = True,
  2766. utility_code = load_c_utility("dict_getitem_default"))
  2767. Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
  2768. PyrexTypes.py_object_type, [
  2769. PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
  2770. PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
  2771. PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
  2772. PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
  2773. ])
  2774. def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
  2775. """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
  2776. """
  2777. if len(args) == 2:
  2778. args.append(ExprNodes.NoneNode(node.pos))
  2779. elif len(args) != 3:
  2780. self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
  2781. return node
  2782. key_type = args[1].type
  2783. if key_type.is_builtin_type:
  2784. is_safe_type = int(key_type.name in
  2785. 'str bytes unicode float int long bool')
  2786. elif key_type is PyrexTypes.py_object_type:
  2787. is_safe_type = -1 # don't know
  2788. else:
  2789. is_safe_type = 0 # definitely not
  2790. args.append(ExprNodes.IntNode(
  2791. node.pos, value=str(is_safe_type), constant_result=is_safe_type))
  2792. return self._substitute_method_call(
  2793. node, function,
  2794. "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
  2795. 'setdefault', is_unbound_method, args,
  2796. may_return_none=True,
  2797. utility_code=load_c_utility('dict_setdefault'))
  2798. PyDict_Pop_func_type = PyrexTypes.CFuncType(
  2799. PyrexTypes.py_object_type, [
  2800. PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
  2801. PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
  2802. PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
  2803. ])
  2804. def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
  2805. """Replace dict.pop() by a call to _PyDict_Pop().
  2806. """
  2807. if len(args) == 2:
  2808. args.append(ExprNodes.NullNode(node.pos))
  2809. elif len(args) != 3:
  2810. self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
  2811. return node
  2812. return self._substitute_method_call(
  2813. node, function,
  2814. "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type,
  2815. 'pop', is_unbound_method, args,
  2816. may_return_none=True,
  2817. utility_code=load_c_utility('py_dict_pop'))
  2818. Pyx_BinopInt_func_types = dict(
  2819. ((ctype, ret_type), PyrexTypes.CFuncType(
  2820. ret_type, [
  2821. PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
  2822. PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
  2823. PyrexTypes.CFuncTypeArg("cval", ctype, None),
  2824. PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
  2825. PyrexTypes.CFuncTypeArg("zerodiv_check", PyrexTypes.c_bint_type, None),
  2826. ], exception_value=None if ret_type.is_pyobject else ret_type.exception_value))
  2827. for ctype in (PyrexTypes.c_long_type, PyrexTypes.c_double_type)
  2828. for ret_type in (PyrexTypes.py_object_type, PyrexTypes.c_bint_type)
  2829. )
  2830. def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
  2831. return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
  2832. def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
  2833. return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
  2834. def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
  2835. return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
  2836. def _handle_simple_method_object___ne__(self, node, function, args, is_unbound_method):
  2837. return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
  2838. def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
  2839. return self._optimise_num_binop('And', node, function, args, is_unbound_method)
  2840. def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
  2841. return self._optimise_num_binop('Or', node, function, args, is_unbound_method)
  2842. def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
  2843. return self._optimise_num_binop('Xor', node, function, args, is_unbound_method)
  2844. def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
  2845. if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
  2846. return node
  2847. if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
  2848. return node
  2849. return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)
  2850. def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
  2851. if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
  2852. return node
  2853. if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
  2854. return node
  2855. return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)
  2856. def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
  2857. return self._optimise_num_div('Remainder', node, function, args, is_unbound_method)
  2858. def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
  2859. return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method)
  2860. def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
  2861. return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method)
  2862. def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
  2863. return self._optimise_num_div('Divide', node, function, args, is_unbound_method)
  2864. def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
  2865. if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
  2866. return node
  2867. if isinstance(args[1], ExprNodes.IntNode):
  2868. if not (-2**30 <= args[1].constant_result <= 2**30):
  2869. return node
  2870. elif isinstance(args[1], ExprNodes.FloatNode):
  2871. if not (-2**53 <= args[1].constant_result <= 2**53):
  2872. return node
  2873. else:
  2874. return node
  2875. return self._optimise_num_binop(operator, node, function, args, is_unbound_method)
  2876. def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
  2877. return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
  2878. def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
  2879. return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
  2880. def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
  2881. return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method)
  2882. def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
  2883. return self._optimise_num_binop('Divide', node, function, args, is_unbound_method)
  2884. def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
  2885. return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method)
  2886. def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
  2887. return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
  2888. def _handle_simple_method_float___ne__(self, node, function, args, is_unbound_method):
  2889. return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
    """
    Optimise math operators for (likely) float or small integer operations.

    Exactly one of the two operands must be an IntNode/FloatNode with a known
    constant value, and the other must be an untyped Python object.  The call
    is replaced by a ``__Pyx_Py{Int,Float}_[Bool]<operator><ObjC|CObj>`` helper
    from Optimize.c.  In every other case the original ``node`` is returned.
    """
    if len(args) != 2:
        return node

    # Python-object result normally; a C bint result only for Eq/Ne when the
    # node is already typed as a C bint.
    if node.type.is_pyobject:
        ret_type = PyrexTypes.py_object_type
    elif node.type is PyrexTypes.c_bint_type and operator in ('Eq', 'Ne'):
        ret_type = PyrexTypes.c_bint_type
    else:
        return node

    # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
    # Prefer constants on RHS as they allows better size control for some operators.
    num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
    if isinstance(args[1], num_nodes):
        if args[0].type is not PyrexTypes.py_object_type:
            return node
        numval = args[1]
        arg_order = 'ObjC'
    elif isinstance(args[0], num_nodes):
        if args[1].type is not PyrexTypes.py_object_type:
            return node
        numval = args[0]
        arg_order = 'CObj'
    else:
        return node

    if not numval.has_constant_result():
        return node

    is_float = isinstance(numval, ExprNodes.FloatNode)
    num_type = PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type
    if is_float:
        if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
            return node
    elif operator == 'Divide':
        # mixed old-/new-style division is not currently optimised for integers
        return node
    elif abs(numval.constant_result) > 2**30:
        # Cut off at an integer border that is still safe for all operations.
        return node

    if operator in ('TrueDivide', 'FloorDivide', 'Divide', 'Remainder'):
        # NOTE(review): this inspects args[1] even for 'CObj' order, where
        # args[1] is the Python operand rather than the constant.  Confirm
        # that this is intended.
        if args[1].constant_result == 0:
            # Don't optimise division by 0. :)
            return node

    # Extra helper arguments: the constant as a C long/double, then the
    # "inplace" flag.
    args = list(args)
    args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
        numval.pos, value=numval.value, constant_result=numval.constant_result,
        type=num_type))
    inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
    args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
    if is_float or operator not in ('Eq', 'Ne'):
        # "PyFloatBinop" and "PyIntBinop" take an additional "check for zero division" argument.
        zerodivision_check = arg_order == 'CObj' and (
            not node.cdivision if isinstance(node, ExprNodes.DivNode) else False)
        args.append(ExprNodes.BoolNode(node.pos, value=zerodivision_check, constant_result=zerodivision_check))

    utility_code = TempitaUtilityCode.load_cached(
        "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
        "Optimize.c",
        context=dict(op=operator, order=arg_order, ret_type=ret_type))

    # NOTE(review): the attr_name below is built from the first three letters
    # of the operator, so it is not always the real dunder name (e.g.
    # 'TrueDivide' -> '__tru__', 'Remainder' -> '__rem__').  It is presumably
    # unused because with_none_check=False; confirm in _substitute_method_call.
    call_node = self._substitute_method_call(
        node, function,
        "__Pyx_Py%s_%s%s%s" % (
            'Float' if is_float else 'Int',
            '' if ret_type.is_pyobject else 'Bool',
            operator,
            arg_order),
        self.Pyx_BinopInt_func_types[(num_type, ret_type)],
        '__%s__' % operator[:3].lower(), is_unbound_method, args,
        may_return_none=True,
        with_none_check=False,
        utility_code=utility_code)

    # A C bint result must be converted back if the node expects an object.
    if node.type.is_pyobject and not ret_type.is_pyobject:
        call_node = ExprNodes.CoerceToPyTypeNode(call_node, self.current_env(), node.type)
    return call_node
  2964. ### unicode type methods
  2965. PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
  2966. PyrexTypes.c_bint_type, [
  2967. PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
  2968. ])
  2969. def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
  2970. if is_unbound_method or len(args) != 1:
  2971. return node
  2972. ustring = args[0]
  2973. if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
  2974. not ustring.arg.type.is_unicode_char:
  2975. return node
  2976. uchar = ustring.arg
  2977. method_name = function.attribute
  2978. if method_name == 'istitle':
  2979. # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
  2980. utility_code = UtilityCode.load_cached(
  2981. "py_unicode_istitle", "StringTools.c")
  2982. function_name = '__Pyx_Py_UNICODE_ISTITLE'
  2983. else:
  2984. utility_code = None
  2985. function_name = 'Py_UNICODE_%s' % method_name.upper()
  2986. func_call = self._substitute_method_call(
  2987. node, function,
  2988. function_name, self.PyUnicode_uchar_predicate_func_type,
  2989. method_name, is_unbound_method, [uchar],
  2990. utility_code = utility_code)
  2991. if node.type.is_pyobject:
  2992. func_call = func_call.coerce_to_pyobject(self.current_env)
  2993. return func_call
  2994. _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
  2995. _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
  2996. _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
  2997. _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
  2998. _handle_simple_method_unicode_islower = _inject_unicode_predicate
  2999. _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
  3000. _handle_simple_method_unicode_isspace = _inject_unicode_predicate
  3001. _handle_simple_method_unicode_istitle = _inject_unicode_predicate
  3002. _handle_simple_method_unicode_isupper = _inject_unicode_predicate
  3003. PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
  3004. PyrexTypes.c_py_ucs4_type, [
  3005. PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
  3006. ])
  3007. def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
  3008. if is_unbound_method or len(args) != 1:
  3009. return node
  3010. ustring = args[0]
  3011. if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
  3012. not ustring.arg.type.is_unicode_char:
  3013. return node
  3014. uchar = ustring.arg
  3015. method_name = function.attribute
  3016. function_name = 'Py_UNICODE_TO%s' % method_name.upper()
  3017. func_call = self._substitute_method_call(
  3018. node, function,
  3019. function_name, self.PyUnicode_uchar_conversion_func_type,
  3020. method_name, is_unbound_method, [uchar])
  3021. if node.type.is_pyobject:
  3022. func_call = func_call.coerce_to_pyobject(self.current_env)
  3023. return func_call
  3024. _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
  3025. _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
  3026. _handle_simple_method_unicode_title = _inject_unicode_character_conversion
  3027. PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
  3028. Builtin.list_type, [
  3029. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3030. PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
  3031. ])
  3032. def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
  3033. """Replace unicode.splitlines(...) by a direct call to the
  3034. corresponding C-API function.
  3035. """
  3036. if len(args) not in (1,2):
  3037. self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
  3038. return node
  3039. self._inject_bint_default_argument(node, args, 1, False)
  3040. return self._substitute_method_call(
  3041. node, function,
  3042. "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
  3043. 'splitlines', is_unbound_method, args)
  3044. PyUnicode_Split_func_type = PyrexTypes.CFuncType(
  3045. Builtin.list_type, [
  3046. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3047. PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
  3048. PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
  3049. ]
  3050. )
  3051. def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
  3052. """Replace unicode.split(...) by a direct call to the
  3053. corresponding C-API function.
  3054. """
  3055. if len(args) not in (1,2,3):
  3056. self._error_wrong_arg_count('unicode.split', node, args, "1-3")
  3057. return node
  3058. if len(args) < 2:
  3059. args.append(ExprNodes.NullNode(node.pos))
  3060. self._inject_int_default_argument(
  3061. node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
  3062. return self._substitute_method_call(
  3063. node, function,
  3064. "PyUnicode_Split", self.PyUnicode_Split_func_type,
  3065. 'split', is_unbound_method, args)
  3066. PyUnicode_Join_func_type = PyrexTypes.CFuncType(
  3067. Builtin.unicode_type, [
  3068. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3069. PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
  3070. ])
  3071. def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
  3072. """
  3073. unicode.join() builds a list first => see if we can do this more efficiently
  3074. """
  3075. if len(args) != 2:
  3076. self._error_wrong_arg_count('unicode.join', node, args, "2")
  3077. return node
  3078. if isinstance(args[1], ExprNodes.GeneratorExpressionNode):
  3079. gen_expr_node = args[1]
  3080. loop_node = gen_expr_node.loop
  3081. yield_statements = _find_yield_statements(loop_node)
  3082. if yield_statements:
  3083. inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
  3084. node.pos, gen_expr_node, orig_func='list',
  3085. comprehension_type=Builtin.list_type)
  3086. for yield_expression, yield_stat_node in yield_statements:
  3087. append_node = ExprNodes.ComprehensionAppendNode(
  3088. yield_expression.pos,
  3089. expr=yield_expression,
  3090. target=inlined_genexpr.target)
  3091. Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
  3092. args[1] = inlined_genexpr
  3093. return self._substitute_method_call(
  3094. node, function,
  3095. "PyUnicode_Join", self.PyUnicode_Join_func_type,
  3096. 'join', is_unbound_method, args)
  3097. PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
  3098. PyrexTypes.c_bint_type, [
  3099. PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None), # bytes/str/unicode
  3100. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3101. PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
  3102. PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
  3103. PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
  3104. ],
  3105. exception_value = '-1')
  3106. def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
  3107. return self._inject_tailmatch(
  3108. node, function, args, is_unbound_method, 'unicode', 'endswith',
  3109. unicode_tailmatch_utility_code, +1)
  3110. def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
  3111. return self._inject_tailmatch(
  3112. node, function, args, is_unbound_method, 'unicode', 'startswith',
  3113. unicode_tailmatch_utility_code, -1)
  3114. def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
  3115. method_name, utility_code, direction):
  3116. """Replace unicode.startswith(...) and unicode.endswith(...)
  3117. by a direct call to the corresponding C-API function.
  3118. """
  3119. if len(args) not in (2,3,4):
  3120. self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
  3121. return node
  3122. self._inject_int_default_argument(
  3123. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3124. self._inject_int_default_argument(
  3125. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3126. args.append(ExprNodes.IntNode(
  3127. node.pos, value=str(direction), type=PyrexTypes.c_int_type))
  3128. method_call = self._substitute_method_call(
  3129. node, function,
  3130. "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
  3131. self.PyString_Tailmatch_func_type,
  3132. method_name, is_unbound_method, args,
  3133. utility_code = utility_code)
  3134. return method_call.coerce_to(Builtin.bool_type, self.current_env())
  3135. PyUnicode_Find_func_type = PyrexTypes.CFuncType(
  3136. PyrexTypes.c_py_ssize_t_type, [
  3137. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3138. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3139. PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
  3140. PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
  3141. PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
  3142. ],
  3143. exception_value = '-2')
  3144. def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
  3145. return self._inject_unicode_find(
  3146. node, function, args, is_unbound_method, 'find', +1)
  3147. def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
  3148. return self._inject_unicode_find(
  3149. node, function, args, is_unbound_method, 'rfind', -1)
  3150. def _inject_unicode_find(self, node, function, args, is_unbound_method,
  3151. method_name, direction):
  3152. """Replace unicode.find(...) and unicode.rfind(...) by a
  3153. direct call to the corresponding C-API function.
  3154. """
  3155. if len(args) not in (2,3,4):
  3156. self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
  3157. return node
  3158. self._inject_int_default_argument(
  3159. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3160. self._inject_int_default_argument(
  3161. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3162. args.append(ExprNodes.IntNode(
  3163. node.pos, value=str(direction), type=PyrexTypes.c_int_type))
  3164. method_call = self._substitute_method_call(
  3165. node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
  3166. method_name, is_unbound_method, args)
  3167. return method_call.coerce_to_pyobject(self.current_env())
  3168. PyUnicode_Count_func_type = PyrexTypes.CFuncType(
  3169. PyrexTypes.c_py_ssize_t_type, [
  3170. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3171. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3172. PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
  3173. PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
  3174. ],
  3175. exception_value = '-1')
  3176. def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
  3177. """Replace unicode.count(...) by a direct call to the
  3178. corresponding C-API function.
  3179. """
  3180. if len(args) not in (2,3,4):
  3181. self._error_wrong_arg_count('unicode.count', node, args, "2-4")
  3182. return node
  3183. self._inject_int_default_argument(
  3184. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3185. self._inject_int_default_argument(
  3186. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3187. method_call = self._substitute_method_call(
  3188. node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
  3189. 'count', is_unbound_method, args)
  3190. return method_call.coerce_to_pyobject(self.current_env())
  3191. PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
  3192. Builtin.unicode_type, [
  3193. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3194. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3195. PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
  3196. PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
  3197. ])
  3198. def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
  3199. """Replace unicode.replace(...) by a direct call to the
  3200. corresponding C-API function.
  3201. """
  3202. if len(args) not in (3,4):
  3203. self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
  3204. return node
  3205. self._inject_int_default_argument(
  3206. node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
  3207. return self._substitute_method_call(
  3208. node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
  3209. 'replace', is_unbound_method, args)
    # C signature of PyUnicode_AsEncodedString(obj, encoding, errors) -> bytes,
    # the generic fallback used by _handle_simple_method_unicode_encode().
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ])

    # C signature shared by the codec-specific PyUnicode_As<Codec>String(obj) -> bytes
    # functions picked by _handle_simple_method_unicode_encode().
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Encodings that get a dedicated C-API encode/decode function.
    _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, Python encoder) pairs; _find_special_codec_name() compares encoders,
    # so any alias that resolves to the same encoder maps back to these names.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
  3224. def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
  3225. """Replace unicode.encode(...) by a direct C-API call to the
  3226. corresponding codec.
  3227. """
  3228. if len(args) < 1 or len(args) > 3:
  3229. self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
  3230. return node
  3231. string_node = args[0]
  3232. if len(args) == 1:
  3233. null_node = ExprNodes.NullNode(node.pos)
  3234. return self._substitute_method_call(
  3235. node, function, "PyUnicode_AsEncodedString",
  3236. self.PyUnicode_AsEncodedString_func_type,
  3237. 'encode', is_unbound_method, [string_node, null_node, null_node])
  3238. parameters = self._unpack_encoding_and_error_mode(node.pos, args)
  3239. if parameters is None:
  3240. return node
  3241. encoding, encoding_node, error_handling, error_handling_node = parameters
  3242. if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
  3243. # constant, so try to do the encoding at compile time
  3244. try:
  3245. value = string_node.value.encode(encoding, error_handling)
  3246. except:
  3247. # well, looks like we can't
  3248. pass
  3249. else:
  3250. value = bytes_literal(value, encoding)
  3251. return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
  3252. if encoding and error_handling == 'strict':
  3253. # try to find a specific encoder function
  3254. codec_name = self._find_special_codec_name(encoding)
  3255. if codec_name is not None and '-' not in codec_name:
  3256. encode_function = "PyUnicode_As%sString" % codec_name
  3257. return self._substitute_method_call(
  3258. node, function, encode_function,
  3259. self.PyUnicode_AsXyzString_func_type,
  3260. 'encode', is_unbound_method, [string_node])
  3261. return self._substitute_method_call(
  3262. node, function, "PyUnicode_AsEncodedString",
  3263. self.PyUnicode_AsEncodedString_func_type,
  3264. 'encode', is_unbound_method,
  3265. [string_node, encoding_node, error_handling_node])
    # Pointer to a PyUnicode_Decode<Codec>(string, size, errors) style C function;
    # passed as the "decode_func" argument of the __Pyx_decode_* helpers below.
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ]))

    # Signature of the __Pyx_decode_c_string() helper (char* input).
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of the __Pyx_decode_bytes()/__Pyx_decode_bytearray() helpers
    # (Python object input).
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of __Pyx_decode_cpp_string(); it needs the concrete C++ string type,
    # so _handle_simple_method_bytes_decode() builds it on first use and stores it
    # on the instance.
    _decode_cpp_string_func_type = None # lazy init
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles bytes and bytearray objects, C strings (char*) and C++ strings.
        Any other receiver type is returned unchanged. The call becomes
        __Pyx_decode_<kind>(string, start, stop, encoding, errors, decode_func).
        *node* is also returned unchanged after reporting a wrong argument count,
        or when the encoding/errors arguments cannot be unpacked.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # normalise input nodes
        string_node = args[0]
        start = stop = None
        if isinstance(string_node, ExprNodes.SliceIndexNode):
            # s[start:stop].decode(...) -> pass the slice bounds to the helper
            # instead of building the sliced object first.
            index_node = string_node
            string_node = index_node.base
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
            # look through an implicit C -> Python coercion to the C value
            string_node = string_node.arg

        string_type = string_node.type
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
            if is_unbound_method:
                string_node = string_node.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', string_type.name])
            else:
                string_node = string_node.as_none_safe_node(
                    "'NoneType' object has no attribute '%.30s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])
        elif not string_type.is_string and not string_type.is_cpp_string:
            # nothing to optimise here
            return node

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # The helpers take Py_ssize_t bounds: default start to 0, coerce others.
        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # try to find a specific encoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            # UTF-16 variants use __Pyx_ wrappers (named without the '-');
            # all other special codecs map directly to PyUnicode_Decode<Codec>.
            if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
                codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
            else:
                codec_cname = "PyUnicode_Decode%s" % codec_name
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
            # With a dedicated decoder the encoding name is not passed (NULL).
            encoding_node = ExprNodes.NullNode(node.pos)
        else:
            decode_function = ExprNodes.NullNode(node.pos)

        # build the helper function call
        temps = []
        if string_type.is_string:
            # C string
            if not stop:
                # use strlen() to find the string length, just as CPython would
                if not string_node.is_name:
                    string_node = UtilNodes.LetRefNode(string_node) # used twice
                    temps.append(string_node)
                stop = ExprNodes.PythonCapiCallNode(
                    string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[string_node],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            utility_code_name = 'decode_c_string'
        elif string_type.is_cpp_string:
            # C++ std::string
            if not stop:
                # NOTE(review): PY_SSIZE_T_MAX presumably means "to the end";
                # the clamping is up to the C helper.
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # lazy init to reuse the C++ string type
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type, [
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            utility_code_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            if string_type is Builtin.bytes_type:
                utility_code_name = 'decode_bytes'
            else:
                utility_code_name = 'decode_bytearray'

        node = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

        # Bind the LetRefNode temps (innermost last) around the final call.
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node

    # bytearray.decode() is optimised exactly like bytes.decode().
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
  3402. def _find_special_codec_name(self, encoding):
  3403. try:
  3404. requested_codec = codecs.getencoder(encoding)
  3405. except LookupError:
  3406. return None
  3407. for name, codec in self._special_codecs:
  3408. if codec == requested_codec:
  3409. if '_' in name:
  3410. name = ''.join([s.capitalize()
  3411. for s in name.split('_')])
  3412. return name
  3413. return None
  3414. def _unpack_encoding_and_error_mode(self, pos, args):
  3415. null_node = ExprNodes.NullNode(pos)
  3416. if len(args) >= 2:
  3417. encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
  3418. if encoding_node is None:
  3419. return None
  3420. else:
  3421. encoding = None
  3422. encoding_node = null_node
  3423. if len(args) == 3:
  3424. error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
  3425. if error_handling_node is None:
  3426. return None
  3427. if error_handling == 'strict':
  3428. error_handling_node = null_node
  3429. else:
  3430. error_handling = 'strict'
  3431. error_handling_node = null_node
  3432. return (encoding, encoding_node, error_handling, error_handling_node)
  3433. def _unpack_string_and_cstring_node(self, node):
  3434. if isinstance(node, ExprNodes.CoerceToPyTypeNode):
  3435. node = node.arg
  3436. if isinstance(node, ExprNodes.UnicodeNode):
  3437. encoding = node.value
  3438. node = ExprNodes.BytesNode(
  3439. node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type)
  3440. elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
  3441. encoding = node.value.decode('ISO-8859-1')
  3442. node = ExprNodes.BytesNode(
  3443. node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type)
  3444. elif node.type is Builtin.bytes_type:
  3445. encoding = None
  3446. node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env())
  3447. elif node.type.is_string:
  3448. encoding = None
  3449. else:
  3450. encoding = node = None
  3451. return encoding, node
  3452. def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
  3453. return self._inject_tailmatch(
  3454. node, function, args, is_unbound_method, 'str', 'endswith',
  3455. str_tailmatch_utility_code, +1)
  3456. def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
  3457. return self._inject_tailmatch(
  3458. node, function, args, is_unbound_method, 'str', 'startswith',
  3459. str_tailmatch_utility_code, -1)
  3460. def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
  3461. return self._inject_tailmatch(
  3462. node, function, args, is_unbound_method, 'bytes', 'endswith',
  3463. bytes_tailmatch_utility_code, +1)
  3464. def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
  3465. return self._inject_tailmatch(
  3466. node, function, args, is_unbound_method, 'bytes', 'startswith',
  3467. bytes_tailmatch_utility_code, -1)
    # The bytearray startswith()/endswith() handlers below are wrapped in an unused
    # string literal, so they are NOT defined on the class (intentionally disabled).
    ''' # disabled for now, enable when we consider it worth it (see StringTools.c)
    def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'endswith',
            bytes_tailmatch_utility_code, +1)

    def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
        return self._inject_tailmatch(
            node, function, args, is_unbound_method, 'bytearray', 'startswith',
            bytes_tailmatch_utility_code, -1)
    '''

    ### helpers
  3479. def _substitute_method_call(self, node, function, name, func_type,
  3480. attr_name, is_unbound_method, args=(),
  3481. utility_code=None, is_temp=None,
  3482. may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
  3483. with_none_check=True):
  3484. args = list(args)
  3485. if with_none_check and args:
  3486. args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name)
  3487. if is_temp is None:
  3488. is_temp = node.is_temp
  3489. return ExprNodes.PythonCapiCallNode(
  3490. node.pos, name, func_type,
  3491. args = args,
  3492. is_temp = is_temp,
  3493. utility_code = utility_code,
  3494. may_return_none = may_return_none,
  3495. result_is_used = node.result_is_used,
  3496. )
  3497. def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
  3498. if self_arg.is_literal:
  3499. return self_arg
  3500. if is_unbound_method:
  3501. self_arg = self_arg.as_none_safe_node(
  3502. "descriptor '%s' requires a '%s' object but received a 'NoneType'",
  3503. format_args=[attr_name, self_arg.type.name])
  3504. else:
  3505. self_arg = self_arg.as_none_safe_node(
  3506. "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''),
  3507. error="PyExc_AttributeError",
  3508. format_args=[attr_name])
  3509. return self_arg
  3510. def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
  3511. assert len(args) >= arg_index
  3512. if len(args) == arg_index:
  3513. args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
  3514. type=type, constant_result=default_value))
  3515. else:
  3516. args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
  3517. def _inject_bint_default_argument(self, node, args, arg_index, default_value):
  3518. assert len(args) >= arg_index
  3519. if len(args) == arg_index:
  3520. default_value = bool(default_value)
  3521. args.append(ExprNodes.BoolNode(node.pos, value=default_value,
  3522. constant_result=default_value))
  3523. else:
  3524. args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
# Cached C helper code (from StringTools.c) behind the optimised
# startswith()/endswith() calls; referenced as module globals by the handlers.
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
  3528. class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
  3529. """Calculate the result of constant expressions to store it in
  3530. ``expr_node.constant_result``, and replace trivial cases by their
  3531. constant result.
  3532. General rules:
  3533. - We calculate float constants to make them available to the
  3534. compiler, but we do not aggregate them into a single literal
  3535. node to prevent any loss of precision.
  3536. - We recursively calculate constants from non-literal nodes to
  3537. make them available to the compiler, but we only aggregate
  3538. literal nodes at each step. Non-literal nodes are never merged
  3539. into a single node.
  3540. """
    def __init__(self, reevaluate=False):
        """
        The reevaluate argument specifies whether constant values that were
        previously computed should be recomputed.

        :param reevaluate: if true, _calculate_const() recomputes
            ``constant_result`` even for nodes where it is already set.
        """
        super(ConstantFolding, self).__init__()
        self.reevaluate = reevaluate
  3548. def _calculate_const(self, node):
  3549. if (not self.reevaluate and
  3550. node.constant_result is not ExprNodes.constant_value_not_set):
  3551. return
  3552. # make sure we always set the value
  3553. not_a_constant = ExprNodes.not_a_constant
  3554. node.constant_result = not_a_constant
  3555. # check if all children are constant
  3556. children = self.visitchildren(node)
  3557. for child_result in children.values():
  3558. if type(child_result) is list:
  3559. for child in child_result:
  3560. if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
  3561. return
  3562. elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
  3563. return
  3564. # now try to calculate the real constant value
  3565. try:
  3566. node.calculate_constant_result()
  3567. # if node.constant_result is not ExprNodes.not_a_constant:
  3568. # print node.__class__.__name__, node.constant_result
  3569. except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
  3570. # ignore all 'normal' errors here => no constant result
  3571. pass
  3572. except Exception:
  3573. # this looks like a real error
  3574. import traceback, sys
  3575. traceback.print_exc(file=sys.stdout)
    # Literal node classes from narrowest to widest; _widest_node_class() uses the
    # position in this list to pick the node class for a folded result.
    NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
                       ExprNodes.IntNode, ExprNodes.FloatNode]
  3578. def _widest_node_class(self, *nodes):
  3579. try:
  3580. return self.NODE_TYPE_ORDER[
  3581. max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
  3582. except ValueError:
  3583. return None
  3584. def _bool_node(self, node, value):
  3585. value = bool(value)
  3586. return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)
  3587. def visit_ExprNode(self, node):
  3588. self._calculate_const(node)
  3589. return node
  3590. def visit_UnopNode(self, node):
  3591. self._calculate_const(node)
  3592. if not node.has_constant_result():
  3593. if node.operator == '!':
  3594. return self._handle_NotNode(node)
  3595. return node
  3596. if not node.operand.is_literal:
  3597. return node
  3598. if node.operator == '!':
  3599. return self._bool_node(node, node.constant_result)
  3600. elif isinstance(node.operand, ExprNodes.BoolNode):
  3601. return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
  3602. type=PyrexTypes.c_int_type,
  3603. constant_result=int(node.constant_result))
  3604. elif node.operator == '+':
  3605. return self._handle_UnaryPlusNode(node)
  3606. elif node.operator == '-':
  3607. return self._handle_UnaryMinusNode(node)
  3608. return node
    # Maps a membership/identity operator to its logical negation. Bound as
    # dict.get, so operators without a negation here yield None (see _handle_NotNode).
    _negate_operator = {
        'in': 'not_in',
        'not_in': 'in',
        'is': 'is_not',
        'is_not': 'is'
    }.get
  3615. def _handle_NotNode(self, node):
  3616. operand = node.operand
  3617. if isinstance(operand, ExprNodes.PrimaryCmpNode):
  3618. operator = self._negate_operator(operand.operator)
  3619. if operator:
  3620. node = copy.copy(operand)
  3621. node.operator = operator
  3622. node = self.visit_PrimaryCmpNode(node)
  3623. return node
  3624. def _handle_UnaryMinusNode(self, node):
  3625. def _negate(value):
  3626. if value.startswith('-'):
  3627. value = value[1:]
  3628. else:
  3629. value = '-' + value
  3630. return value
  3631. node_type = node.operand.type
  3632. if isinstance(node.operand, ExprNodes.FloatNode):
  3633. # this is a safe operation
  3634. return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
  3635. type=node_type,
  3636. constant_result=node.constant_result)
  3637. if node_type.is_int and node_type.signed or \
  3638. isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
  3639. return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
  3640. type=node_type,
  3641. longness=node.operand.longness,
  3642. constant_result=node.constant_result)
  3643. return node
  3644. def _handle_UnaryPlusNode(self, node):
  3645. if (node.operand.has_constant_result() and
  3646. node.constant_result == node.operand.constant_result):
  3647. return node.operand
  3648. return node
  3649. def visit_BoolBinopNode(self, node):
  3650. self._calculate_const(node)
  3651. if not node.operand1.has_constant_result():
  3652. return node
  3653. if node.operand1.constant_result:
  3654. if node.operator == 'and':
  3655. return node.operand2
  3656. else:
  3657. return node.operand1
  3658. else:
  3659. if node.operator == 'and':
  3660. return node.operand1
  3661. else:
  3662. return node.operand2
    def visit_BinopNode(self, node):
        """Replace a binary operation on two literal operands by one literal node.

        Float results, non-constant results and non-literal operands only get
        ``constant_result`` computed; the node itself is kept. The replacement
        node class is the widest of the operand classes, promoted to IntNode
        where C arithmetic would produce an int.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type

        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        # NOTE(review): the two checks below are substring tests on the operator,
        # e.g. '//' and '**' match because they occur in the string.
        elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode
        elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
            # C arithmetic results in at least an int type
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            # Keep the C literal suffixes: unsigned only if both operands are,
            # and the longer of the two 'L'/'LL' longness markers.
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned=unsigned, longness=longness,
                                         value=str(int(node.constant_result)),
                                         constant_result=int(node.constant_result))
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node
    def visit_AddNode(self, node):
        """Fold ``literal + literal``.

        unicode + unicode and bytes + bytes literals are merged into a single
        string literal (bytes only when both encodings agree). Everything else
        goes through visit_BinopNode().
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if node.operand1.is_string_literal and node.operand2.is_string_literal:
            # some people combine string literals with a '+'
            str1, str2 = node.operand1, node.operand2
            if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
                # Keep a combined bytes_value only if both parts have one
                # with the same encoding.
                bytes_value = None
                if str1.bytes_value is not None and str2.bytes_value is not None:
                    if str1.bytes_value.encoding == str2.bytes_value.encoding:
                        bytes_value = bytes_literal(
                            str1.bytes_value + str2.bytes_value,
                            str1.bytes_value.encoding)
                string_value = EncodedString(node.constant_result)
                return ExprNodes.UnicodeNode(
                    str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
            elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
                if str1.value.encoding == str2.value.encoding:
                    bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
                    return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
            # all other combinations are rather complicated
            # to get right in Py2/3: encodings, unicode escapes, ...
        return self.visit_BinopNode(node)
  3740. def visit_MulNode(self, node):
  3741. self._calculate_const(node)
  3742. if node.operand1.is_sequence_constructor:
  3743. return self._calculate_constant_seq(node, node.operand1, node.operand2)
  3744. if isinstance(node.operand1, ExprNodes.IntNode) and \
  3745. node.operand2.is_sequence_constructor:
  3746. return self._calculate_constant_seq(node, node.operand2, node.operand1)
  3747. if node.operand1.is_string_literal:
  3748. return self._multiply_string(node, node.operand1, node.operand2)
  3749. elif node.operand2.is_string_literal:
  3750. return self._multiply_string(node, node.operand2, node.operand1)
  3751. return self.visit_BinopNode(node)
    def _multiply_string(self, node, string_node, multiplier_node):
        """Fold ``string_literal * int`` into a single string literal.

        *string_node* is mutated in place (value, plus unicode_value/bytes_value
        when present) and returned. *node* is returned unchanged if the multiplier
        is not a Python int, if the folded result is not a known string, or if it
        is longer than 256 characters.
        """
        multiplier = multiplier_node.constant_result
        if not isinstance(multiplier, _py_int_types):
            return node
        if not (node.has_constant_result() and isinstance(node.constant_result, _py_string_types)):
            return node
        if len(node.constant_result) > 256:
            # Too long for static creation, leave it to runtime. (-> arbitrary limit)
            return node

        # Pick the constructor matching the literal kind of string_node.value.
        build_string = encoded_string
        if isinstance(string_node, ExprNodes.BytesNode):
            build_string = bytes_literal
        elif isinstance(string_node, ExprNodes.StringNode):
            if string_node.unicode_value is not None:
                string_node.unicode_value = encoded_string(
                    string_node.unicode_value * multiplier,
                    string_node.unicode_value.encoding)
            build_string = encoded_string if string_node.value.is_unicode else bytes_literal
        elif isinstance(string_node, ExprNodes.UnicodeNode):
            if string_node.bytes_value is not None:
                string_node.bytes_value = bytes_literal(
                    string_node.bytes_value * multiplier,
                    string_node.bytes_value.encoding)
        else:
            assert False, "unknown string node type: %s" % type(string_node)
        string_node.value = build_string(
            string_node.value * multiplier,
            string_node.value.encoding)
        # follow constant-folding and use unicode_value in preference
        if isinstance(string_node, ExprNodes.StringNode) and string_node.unicode_value is not None:
            string_node.constant_result = string_node.unicode_value
        else:
            string_node.constant_result = string_node.value
        return string_node
    def _calculate_constant_seq(self, node, sequence_node, factor):
        """Fold ``sequence_constructor * factor`` into the sequence node.

        A factor of 1, or an empty sequence, returns the sequence unchanged. An
        int factor <= 0 empties the sequence. An existing mult_factor is combined
        with the new one when both are ints; otherwise the multiplication is left
        to visit_BinopNode(). In the remaining case *factor* becomes the
        sequence's mult_factor. Returns the (mutated) *sequence_node* except in
        the fallback case.
        """
        if factor.constant_result != 1 and sequence_node.args:
            if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0:
                del sequence_node.args[:]
                sequence_node.mult_factor = None
            elif sequence_node.mult_factor is not None:
                if (isinstance(factor.constant_result, _py_int_types) and
                        isinstance(sequence_node.mult_factor.constant_result, _py_int_types)):
                    value = sequence_node.mult_factor.constant_result * factor.constant_result
                    sequence_node.mult_factor = ExprNodes.IntNode(
                        sequence_node.mult_factor.pos,
                        value=str(value), constant_result=value)
                else:
                    # don't know if we can combine the factors, so don't
                    return self.visit_BinopNode(node)
            else:
                sequence_node.mult_factor = factor
        return sequence_node
  3804. def visit_ModNode(self, node):
  3805. self.visitchildren(node)
  3806. if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode):
  3807. if not node.operand2.mult_factor:
  3808. fstring = self._build_fstring(node.operand1.pos, node.operand1.value, node.operand2.args)
  3809. if fstring is not None:
  3810. return fstring
  3811. return self.visit_BinopNode(node)
    # Used with re.split() in _build_fstring(): the outer capturing group makes
    # each '%' placeholder ('%%', '%s', '%-5d', '%.3f', ...) its own list item,
    # interleaved with the literal text between placeholders.
    _parse_string_format_regex = (
        u'(%(?:'              # %...
        u'(?:[-0-9]+|[ ])?'   # width (optional) or space prefix fill character (optional)
        u'(?:[.][0-9]+)?'     # precision (optional)
        u')?.)'               # format type (or something different for unsupported formats)
    )
    def _build_fstring(self, pos, ustring, format_args):
        """Translate ``ustring % format_args`` into a JoinedStrNode (f-string).

        Supports the placeholder types a, s, r, f, d, o, x, X with optional
        width/fill/precision. Returns None if the format cannot be translated:
        unsupported type, starred argument, precision on an integer type,
        incomplete trailing '%', or too few or too many arguments. Warnings are
        issued for these where possible.
        """
        # Issues formatting warnings instead of errors since we really only catch a few errors by accident.
        args = iter(format_args)
        substrings = []
        can_be_optimised = True
        for s in re.split(self._parse_string_format_regex, ustring):
            if not s:
                continue
            if s == u'%%':
                substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u'%'), constant_result=u'%'))
                continue
            if s[0] != u'%':
                # literal text; a trailing '%' means a placeholder was cut off
                if s[-1] == u'%':
                    warning(pos, "Incomplete format: '...%s'" % s[-3:], level=1)
                    can_be_optimised = False
                substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(s), constant_result=s))
                continue
            format_type = s[-1]
            try:
                arg = next(args)
            except StopIteration:
                warning(pos, "Too few arguments for format placeholders", level=1)
                can_be_optimised = False
                break
            if arg.is_starred:
                can_be_optimised = False
                break
            if format_type in u'asrfdoxX':
                format_spec = s[1:]
                conversion_char = None
                if format_type in u'doxX' and u'.' in format_spec:
                    # Precision is not allowed for integers in format(), but ok in %-formatting.
                    can_be_optimised = False
                elif format_type in u'ars':
                    # '%r' -> '{!r}' etc.: the type letter becomes the conversion
                    format_spec = format_spec[:-1]
                    conversion_char = format_type
                    if format_spec.startswith('0'):
                        format_spec = '>' + format_spec[1:]  # right-alignment '%05s' spells '{:>5}'
                elif format_type == u'd':
                    # '%d' formatting supports float, but '{obj:d}' does not => convert to int first.
                    conversion_char = 'd'

                if format_spec.startswith('-'):
                    format_spec = '<' + format_spec[1:]  # left-alignment '%-5s' spells '{:<5}'

                substrings.append(ExprNodes.FormattedValueNode(
                    arg.pos, value=arg,
                    conversion_char=conversion_char,
                    format_spec=ExprNodes.UnicodeNode(
                        pos, value=EncodedString(format_spec), constant_result=format_spec)
                        if format_spec else None,
                ))
            else:
                # keep it simple for now ...
                can_be_optimised = False
                break

        if not can_be_optimised:
            # Print all warnings we can find before finally giving up here.
            return None

        # Any argument left over means the format had too few placeholders.
        try:
            next(args)
        except StopIteration: pass
        else:
            warning(pos, "Too many arguments for format placeholders", level=1)
            return None

        node = ExprNodes.JoinedStrNode(pos, values=substrings)
        return self.visit_JoinedStrNode(node)
  3883. def visit_FormattedValueNode(self, node):
  3884. self.visitchildren(node)
  3885. conversion_char = node.conversion_char or 's'
  3886. if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value:
  3887. node.format_spec = None
  3888. if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
  3889. value = EncodedString(node.value.value)
  3890. if value.isdigit():
  3891. return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
  3892. if node.format_spec is None and conversion_char == 's':
  3893. value = None
  3894. if isinstance(node.value, ExprNodes.UnicodeNode):
  3895. value = node.value.value
  3896. elif isinstance(node.value, ExprNodes.StringNode):
  3897. value = node.value.unicode_value
  3898. if value is not None:
  3899. return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
  3900. return node
  3901. def visit_JoinedStrNode(self, node):
  3902. """
  3903. Clean up after the parser by discarding empty Unicode strings and merging
  3904. substring sequences. Empty or single-value join lists are not uncommon
  3905. because f-string format specs are always parsed into JoinedStrNodes.
  3906. """
  3907. self.visitchildren(node)
  3908. unicode_node = ExprNodes.UnicodeNode
  3909. values = []
  3910. for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)):
  3911. if is_unode_group:
  3912. substrings = list(substrings)
  3913. unode = substrings[0]
  3914. if len(substrings) > 1:
  3915. value = EncodedString(u''.join(value.value for value in substrings))
  3916. unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value)
  3917. # ignore empty Unicode strings
  3918. if unode.value:
  3919. values.append(unode)
  3920. else:
  3921. values.extend(substrings)
  3922. if not values:
  3923. value = EncodedString('')
  3924. node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value)
  3925. elif len(values) == 1:
  3926. node = values[0]
  3927. elif len(values) == 2:
  3928. # reduce to string concatenation
  3929. node = ExprNodes.binop_node(node.pos, '+', *values)
  3930. else:
  3931. node.values = values
  3932. return node
  3933. def visit_MergedDictNode(self, node):
  3934. """Unpack **args in place if we can."""
  3935. self.visitchildren(node)
  3936. args = []
  3937. items = []
  3938. def add(arg):
  3939. if arg.is_dict_literal:
  3940. if items:
  3941. items[0].key_value_pairs.extend(arg.key_value_pairs)
  3942. else:
  3943. items.append(arg)
  3944. elif isinstance(arg, ExprNodes.MergedDictNode):
  3945. for child_arg in arg.keyword_args:
  3946. add(child_arg)
  3947. else:
  3948. if items:
  3949. args.append(items[0])
  3950. del items[:]
  3951. args.append(arg)
  3952. for arg in node.keyword_args:
  3953. add(arg)
  3954. if items:
  3955. args.append(items[0])
  3956. if len(args) == 1:
  3957. arg = args[0]
  3958. if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
  3959. return arg
  3960. node.keyword_args[:] = args
  3961. self._calculate_const(node)
  3962. return node
  3963. def visit_MergedSequenceNode(self, node):
  3964. """Unpack *args in place if we can."""
  3965. self.visitchildren(node)
  3966. is_set = node.type is Builtin.set_type
  3967. args = []
  3968. values = []
  3969. def add(arg):
  3970. if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
  3971. if values:
  3972. values[0].args.extend(arg.args)
  3973. else:
  3974. values.append(arg)
  3975. elif isinstance(arg, ExprNodes.MergedSequenceNode):
  3976. for child_arg in arg.args:
  3977. add(child_arg)
  3978. else:
  3979. if values:
  3980. args.append(values[0])
  3981. del values[:]
  3982. args.append(arg)
  3983. for arg in node.args:
  3984. add(arg)
  3985. if values:
  3986. args.append(values[0])
  3987. if len(args) == 1:
  3988. arg = args[0]
  3989. if ((is_set and arg.is_set_literal) or
  3990. (arg.is_sequence_constructor and arg.type is node.type) or
  3991. isinstance(arg, ExprNodes.MergedSequenceNode)):
  3992. return arg
  3993. node.args[:] = args
  3994. self._calculate_const(node)
  3995. return node
  3996. def visit_SequenceNode(self, node):
  3997. """Unpack *args in place if we can."""
  3998. self.visitchildren(node)
  3999. args = []
  4000. for arg in node.args:
  4001. if not arg.is_starred:
  4002. args.append(arg)
  4003. elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
  4004. args.extend(arg.target.args)
  4005. else:
  4006. args.append(arg)
  4007. node.args[:] = args
  4008. self._calculate_const(node)
  4009. return node
  4010. def visit_PrimaryCmpNode(self, node):
  4011. # calculate constant partial results in the comparison cascade
  4012. self.visitchildren(node, ['operand1'])
  4013. left_node = node.operand1
  4014. cmp_node = node
  4015. while cmp_node is not None:
  4016. self.visitchildren(cmp_node, ['operand2'])
  4017. right_node = cmp_node.operand2
  4018. cmp_node.constant_result = not_a_constant
  4019. if left_node.has_constant_result() and right_node.has_constant_result():
  4020. try:
  4021. cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
  4022. except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
  4023. pass # ignore all 'normal' errors here => no constant result
  4024. left_node = right_node
  4025. cmp_node = cmp_node.cascade
  4026. if not node.cascade:
  4027. if node.has_constant_result():
  4028. return self._bool_node(node, node.constant_result)
  4029. return node
  4030. # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
  4031. cascades = [[node.operand1]]
  4032. final_false_result = []
  4033. def split_cascades(cmp_node):
  4034. if cmp_node.has_constant_result():
  4035. if not cmp_node.constant_result:
  4036. # False => short-circuit
  4037. final_false_result.append(self._bool_node(cmp_node, False))
  4038. return
  4039. else:
  4040. # True => discard and start new cascade
  4041. cascades.append([cmp_node.operand2])
  4042. else:
  4043. # not constant => append to current cascade
  4044. cascades[-1].append(cmp_node)
  4045. if cmp_node.cascade:
  4046. split_cascades(cmp_node.cascade)
  4047. split_cascades(node)
  4048. cmp_nodes = []
  4049. for cascade in cascades:
  4050. if len(cascade) < 2:
  4051. continue
  4052. cmp_node = cascade[1]
  4053. pcmp_node = ExprNodes.PrimaryCmpNode(
  4054. cmp_node.pos,
  4055. operand1=cascade[0],
  4056. operator=cmp_node.operator,
  4057. operand2=cmp_node.operand2,
  4058. constant_result=not_a_constant)
  4059. cmp_nodes.append(pcmp_node)
  4060. last_cmp_node = pcmp_node
  4061. for cmp_node in cascade[2:]:
  4062. last_cmp_node.cascade = cmp_node
  4063. last_cmp_node = cmp_node
  4064. last_cmp_node.cascade = None
  4065. if final_false_result:
  4066. # last cascade was constant False
  4067. cmp_nodes.append(final_false_result[0])
  4068. elif not cmp_nodes:
  4069. # only constants, but no False result
  4070. return self._bool_node(node, True)
  4071. node = cmp_nodes[0]
  4072. if len(cmp_nodes) == 1:
  4073. if node.has_constant_result():
  4074. return self._bool_node(node, node.constant_result)
  4075. else:
  4076. for cmp_node in cmp_nodes[1:]:
  4077. node = ExprNodes.BoolBinopNode(
  4078. node.pos,
  4079. operand1=node,
  4080. operator='and',
  4081. operand2=cmp_node,
  4082. constant_result=not_a_constant)
  4083. return node
  4084. def visit_CondExprNode(self, node):
  4085. self._calculate_const(node)
  4086. if not node.test.has_constant_result():
  4087. return node
  4088. if node.test.constant_result:
  4089. return node.true_val
  4090. else:
  4091. return node.false_val
  4092. def visit_IfStatNode(self, node):
  4093. self.visitchildren(node)
  4094. # eliminate dead code based on constant condition results
  4095. if_clauses = []
  4096. for if_clause in node.if_clauses:
  4097. condition = if_clause.condition
  4098. if condition.has_constant_result():
  4099. if condition.constant_result:
  4100. # always true => subsequent clauses can safely be dropped
  4101. node.else_clause = if_clause.body
  4102. break
  4103. # else: false => drop clause
  4104. else:
  4105. # unknown result => normal runtime evaluation
  4106. if_clauses.append(if_clause)
  4107. if if_clauses:
  4108. node.if_clauses = if_clauses
  4109. return node
  4110. elif node.else_clause:
  4111. return node.else_clause
  4112. else:
  4113. return Nodes.StatListNode(node.pos, stats=[])
  4114. def visit_SliceIndexNode(self, node):
  4115. self._calculate_const(node)
  4116. # normalise start/stop values
  4117. if node.start is None or node.start.constant_result is None:
  4118. start = node.start = None
  4119. else:
  4120. start = node.start.constant_result
  4121. if node.stop is None or node.stop.constant_result is None:
  4122. stop = node.stop = None
  4123. else:
  4124. stop = node.stop.constant_result
  4125. # cut down sliced constant sequences
  4126. if node.constant_result is not not_a_constant:
  4127. base = node.base
  4128. if base.is_sequence_constructor and base.mult_factor is None:
  4129. base.args = base.args[start:stop]
  4130. return base
  4131. elif base.is_string_literal:
  4132. base = base.as_sliced_node(start, stop)
  4133. if base is not None:
  4134. return base
  4135. return node
  4136. def visit_ComprehensionNode(self, node):
  4137. self.visitchildren(node)
  4138. if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
  4139. # loop was pruned already => transform into literal
  4140. if node.type is Builtin.list_type:
  4141. return ExprNodes.ListNode(
  4142. node.pos, args=[], constant_result=[])
  4143. elif node.type is Builtin.set_type:
  4144. return ExprNodes.SetNode(
  4145. node.pos, args=[], constant_result=set())
  4146. elif node.type is Builtin.dict_type:
  4147. return ExprNodes.DictNode(
  4148. node.pos, key_value_pairs=[], constant_result={})
  4149. return node
  4150. def visit_ForInStatNode(self, node):
  4151. self.visitchildren(node)
  4152. sequence = node.iterator.sequence
  4153. if isinstance(sequence, ExprNodes.SequenceNode):
  4154. if not sequence.args:
  4155. if node.else_clause:
  4156. return node.else_clause
  4157. else:
  4158. # don't break list comprehensions
  4159. return Nodes.StatListNode(node.pos, stats=[])
  4160. # iterating over a list literal? => tuples are more efficient
  4161. if isinstance(sequence, ExprNodes.ListNode):
  4162. node.iterator.sequence = sequence.as_tuple()
  4163. return node
  4164. def visit_WhileStatNode(self, node):
  4165. self.visitchildren(node)
  4166. if node.condition and node.condition.has_constant_result():
  4167. if node.condition.constant_result:
  4168. node.condition = None
  4169. node.else_clause = None
  4170. else:
  4171. return node.else_clause
  4172. return node
  4173. def visit_ExprStatNode(self, node):
  4174. self.visitchildren(node)
  4175. if not isinstance(node.expr, ExprNodes.ExprNode):
  4176. # ParallelRangeTransform does this ...
  4177. return node
  4178. # drop unused constant expressions
  4179. if node.expr.has_constant_result():
  4180. return None
  4181. return node
    # Fallback for node types without a dedicated visit_* method in this class:
    # delegate to VisitorTransform.recurse_to_children (presumably it visits the
    # children and keeps the node itself; confirm in Visitor).
    # In the future, other nodes can have their own handler method here
    # that can replace them with a constant result node.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
  4185. class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
  4186. """
  4187. This visitor handles several commuting optimizations, and is run
  4188. just before the C code generation phase.
  4189. The optimizations currently implemented in this class are:
  4190. - eliminate None assignment and refcounting for first assignment.
  4191. - isinstance -> typecheck for cdef types
  4192. - eliminate checks for None and/or types that became redundant after tree changes
  4193. - eliminate useless string formatting steps
  4194. - replace Python function calls that look like method calls by a faster PyMethodCallNode
  4195. """
  4196. in_loop = False
  4197. def visit_SingleAssignmentNode(self, node):
  4198. """Avoid redundant initialisation of local variables before their
  4199. first assignment.
  4200. """
  4201. self.visitchildren(node)
  4202. if node.first:
  4203. lhs = node.lhs
  4204. lhs.lhs_of_first_assignment = True
  4205. return node
  4206. def visit_SimpleCallNode(self, node):
  4207. """
  4208. Replace generic calls to isinstance(x, type) by a more efficient type check.
  4209. Replace likely Python method calls by a specialised PyMethodCallNode.
  4210. """
  4211. self.visitchildren(node)
  4212. function = node.function
  4213. if function.type.is_cfunction and function.is_name:
  4214. if function.name == 'isinstance' and len(node.args) == 2:
  4215. type_arg = node.args[1]
  4216. if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
  4217. cython_scope = self.context.cython_scope
  4218. function.entry = cython_scope.lookup('PyObject_TypeCheck')
  4219. function.type = function.entry.type
  4220. PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
  4221. node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
  4222. elif (node.is_temp and function.type.is_pyobject and self.current_directives.get(
  4223. "optimize.unpack_method_calls_in_pyinit"
  4224. if not self.in_loop and self.current_env().is_module_scope
  4225. else "optimize.unpack_method_calls")):
  4226. # optimise simple Python methods calls
  4227. if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not (
  4228. node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and len(node.arg_tuple.args) > 1)):
  4229. # simple call, now exclude calls to objects that are definitely not methods
  4230. may_be_a_method = True
  4231. if function.type is Builtin.type_type:
  4232. may_be_a_method = False
  4233. elif function.is_attribute:
  4234. if function.entry and function.entry.type.is_cfunction:
  4235. # optimised builtin method
  4236. may_be_a_method = False
  4237. elif function.is_name:
  4238. entry = function.entry
  4239. if entry.is_builtin or entry.type.is_cfunction:
  4240. may_be_a_method = False
  4241. elif entry.cf_assignments:
  4242. # local functions/classes are definitely not methods
  4243. non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
  4244. may_be_a_method = any(
  4245. assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
  4246. for assignment in entry.cf_assignments)
  4247. if may_be_a_method:
  4248. if (node.self and function.is_attribute and
  4249. isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
  4250. # function self object was moved into a CloneNode => undo
  4251. function.obj = function.obj.arg
  4252. node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
  4253. node, function=function, arg_tuple=node.arg_tuple, type=node.type))
  4254. return node
  4255. def visit_NumPyMethodCallNode(self, node):
  4256. # Exclude from replacement above.
  4257. self.visitchildren(node)
  4258. return node
  4259. def visit_PyTypeTestNode(self, node):
  4260. """Remove tests for alternatively allowed None values from
  4261. type tests when we know that the argument cannot be None
  4262. anyway.
  4263. """
  4264. self.visitchildren(node)
  4265. if not node.notnone:
  4266. if not node.arg.may_be_none():
  4267. node.notnone = True
  4268. return node
  4269. def visit_NoneCheckNode(self, node):
  4270. """Remove None checks from expressions that definitely do not
  4271. carry a None value.
  4272. """
  4273. self.visitchildren(node)
  4274. if not node.arg.may_be_none():
  4275. return node.arg
  4276. return node
  4277. def visit_LoopNode(self, node):
  4278. """Remember when we enter a loop as some expensive optimisations might still be worth it there.
  4279. """
  4280. old_val = self.in_loop
  4281. self.in_loop = True
  4282. self.visitchildren(node)
  4283. self.in_loop = old_val
  4284. return node
  4285. class ConsolidateOverflowCheck(Visitor.CythonTransform):
  4286. """
  4287. This class facilitates the sharing of overflow checking among all nodes
  4288. of a nested arithmetic expression. For example, given the expression
  4289. a*b + c, where a, b, and x are all possibly overflowing ints, the entire
  4290. sequence will be evaluated and the overflow bit checked only at the end.
  4291. """
  4292. overflow_bit_node = None
  4293. def visit_Node(self, node):
  4294. if self.overflow_bit_node is not None:
  4295. saved = self.overflow_bit_node
  4296. self.overflow_bit_node = None
  4297. self.visitchildren(node)
  4298. self.overflow_bit_node = saved
  4299. else:
  4300. self.visitchildren(node)
  4301. return node
  4302. def visit_NumBinopNode(self, node):
  4303. if node.overflow_check and node.overflow_fold:
  4304. top_level_overflow = self.overflow_bit_node is None
  4305. if top_level_overflow:
  4306. self.overflow_bit_node = node
  4307. else:
  4308. node.overflow_bit_node = self.overflow_bit_node
  4309. node.overflow_check = False
  4310. self.visitchildren(node)
  4311. if top_level_overflow:
  4312. self.overflow_bit_node = None
  4313. else:
  4314. self.visitchildren(node)
  4315. return node