OpenMPOpt.cpp 207 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545
  1. //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
  2. //
  3. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  4. // See https://llvm.org/LICENSE.txt for license information.
  5. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  6. //
  7. //===----------------------------------------------------------------------===//
  8. //
  9. // OpenMP specific optimizations:
  10. //
  11. // - Deduplication of runtime calls, e.g., omp_get_thread_num.
  12. // - Replacing globalized device memory with stack memory.
  13. // - Replacing globalized device memory with shared memory.
  14. // - Parallel region merging.
  15. // - Transforming generic-mode device kernels to SPMD mode.
  16. // - Specializing the state machine for generic-mode device kernels.
  17. //
  18. //===----------------------------------------------------------------------===//
  19. #include "llvm/Transforms/IPO/OpenMPOpt.h"
  20. #include "llvm/ADT/EnumeratedArray.h"
  21. #include "llvm/ADT/PostOrderIterator.h"
  22. #include "llvm/ADT/SetVector.h"
  23. #include "llvm/ADT/SmallVector.h"
  24. #include "llvm/ADT/Statistic.h"
  25. #include "llvm/ADT/StringRef.h"
  26. #include "llvm/Analysis/CallGraph.h"
  27. #include "llvm/Analysis/CallGraphSCCPass.h"
  28. #include "llvm/Analysis/MemoryLocation.h"
  29. #include "llvm/Analysis/OptimizationRemarkEmitter.h"
  30. #include "llvm/Analysis/ValueTracking.h"
  31. #include "llvm/Frontend/OpenMP/OMPConstants.h"
  32. #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
  33. #include "llvm/IR/Assumptions.h"
  34. #include "llvm/IR/BasicBlock.h"
  35. #include "llvm/IR/Constants.h"
  36. #include "llvm/IR/DiagnosticInfo.h"
  37. #include "llvm/IR/GlobalValue.h"
  38. #include "llvm/IR/GlobalVariable.h"
  39. #include "llvm/IR/Instruction.h"
  40. #include "llvm/IR/Instructions.h"
  41. #include "llvm/IR/IntrinsicInst.h"
  42. #include "llvm/IR/IntrinsicsAMDGPU.h"
  43. #include "llvm/IR/IntrinsicsNVPTX.h"
  44. #include "llvm/IR/LLVMContext.h"
  45. #include "llvm/InitializePasses.h"
  46. #include "llvm/Support/CommandLine.h"
  47. #include "llvm/Support/Debug.h"
  48. #include "llvm/Transforms/IPO/Attributor.h"
  49. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  50. #include "llvm/Transforms/Utils/CallGraphUpdater.h"
  51. #include <algorithm>
  52. #include <optional>
  53. #include <string>
  54. using namespace llvm;
  55. using namespace omp;
  56. #define DEBUG_TYPE "openmp-opt"
  57. static cl::opt<bool> DisableOpenMPOptimizations(
  58. "openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),
  59. cl::Hidden, cl::init(false));
  60. static cl::opt<bool> EnableParallelRegionMerging(
  61. "openmp-opt-enable-merging",
  62. cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
  63. cl::init(false));
  64. static cl::opt<bool>
  65. DisableInternalization("openmp-opt-disable-internalization",
  66. cl::desc("Disable function internalization."),
  67. cl::Hidden, cl::init(false));
  68. static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",
  69. cl::init(false), cl::Hidden);
  70. static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
  71. cl::Hidden);
  72. static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
  73. cl::init(false), cl::Hidden);
  74. static cl::opt<bool> HideMemoryTransferLatency(
  75. "openmp-hide-memory-transfer-latency",
  76. cl::desc("[WIP] Tries to hide the latency of host to device memory"
  77. " transfers"),
  78. cl::Hidden, cl::init(false));
  79. static cl::opt<bool> DisableOpenMPOptDeglobalization(
  80. "openmp-opt-disable-deglobalization",
  81. cl::desc("Disable OpenMP optimizations involving deglobalization."),
  82. cl::Hidden, cl::init(false));
  83. static cl::opt<bool> DisableOpenMPOptSPMDization(
  84. "openmp-opt-disable-spmdization",
  85. cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
  86. cl::Hidden, cl::init(false));
  87. static cl::opt<bool> DisableOpenMPOptFolding(
  88. "openmp-opt-disable-folding",
  89. cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
  90. cl::init(false));
  91. static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
  92. "openmp-opt-disable-state-machine-rewrite",
  93. cl::desc("Disable OpenMP optimizations that replace the state machine."),
  94. cl::Hidden, cl::init(false));
  95. static cl::opt<bool> DisableOpenMPOptBarrierElimination(
  96. "openmp-opt-disable-barrier-elimination",
  97. cl::desc("Disable OpenMP optimizations that eliminate barriers."),
  98. cl::Hidden, cl::init(false));
  99. static cl::opt<bool> PrintModuleAfterOptimizations(
  100. "openmp-opt-print-module-after",
  101. cl::desc("Print the current module after OpenMP optimizations."),
  102. cl::Hidden, cl::init(false));
  103. static cl::opt<bool> PrintModuleBeforeOptimizations(
  104. "openmp-opt-print-module-before",
  105. cl::desc("Print the current module before OpenMP optimizations."),
  106. cl::Hidden, cl::init(false));
  107. static cl::opt<bool> AlwaysInlineDeviceFunctions(
  108. "openmp-opt-inline-device",
  109. cl::desc("Inline all applicible functions on the device."), cl::Hidden,
  110. cl::init(false));
  111. static cl::opt<bool>
  112. EnableVerboseRemarks("openmp-opt-verbose-remarks",
  113. cl::desc("Enables more verbose remarks."), cl::Hidden,
  114. cl::init(false));
  115. static cl::opt<unsigned>
  116. SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,
  117. cl::desc("Maximal number of attributor iterations."),
  118. cl::init(256));
  119. static cl::opt<unsigned>
  120. SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,
  121. cl::desc("Maximum amount of shared memory to use."),
  122. cl::init(std::numeric_limits<unsigned>::max()));
  123. STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
  124. "Number of OpenMP runtime calls deduplicated");
  125. STATISTIC(NumOpenMPParallelRegionsDeleted,
  126. "Number of OpenMP parallel regions deleted");
  127. STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
  128. "Number of OpenMP runtime functions identified");
  129. STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
  130. "Number of OpenMP runtime function uses identified");
  131. STATISTIC(NumOpenMPTargetRegionKernels,
  132. "Number of OpenMP target region entry points (=kernels) identified");
  133. STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
  134. "Number of OpenMP target region entry points (=kernels) executed in "
  135. "SPMD-mode instead of generic-mode");
  136. STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
  137. "Number of OpenMP target region entry points (=kernels) executed in "
  138. "generic-mode without a state machines");
  139. STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
  140. "Number of OpenMP target region entry points (=kernels) executed in "
  141. "generic-mode with customized state machines with fallback");
  142. STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
  143. "Number of OpenMP target region entry points (=kernels) executed in "
  144. "generic-mode with customized state machines without fallback");
  145. STATISTIC(
  146. NumOpenMPParallelRegionsReplacedInGPUStateMachine,
  147. "Number of OpenMP parallel regions replaced with ID in GPU state machines");
  148. STATISTIC(NumOpenMPParallelRegionsMerged,
  149. "Number of OpenMP parallel regions merged");
  150. STATISTIC(NumBytesMovedToSharedMemory,
  151. "Amount of memory pushed to shared memory");
  152. STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");
  153. #if !defined(NDEBUG)
  154. static constexpr auto TAG = "[" DEBUG_TYPE "]";
  155. #endif
  156. namespace {
  // Forward declarations, so these types can be named before their
  // definitions appear.
  struct AAICVTracker;
  struct AAHeapToShared;
  159. /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
  160. /// Attributor runs.
  161. struct OMPInformationCache : public InformationCache {
  162. OMPInformationCache(Module &M, AnalysisGetter &AG,
  163. BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,
  164. KernelSet &Kernels, bool OpenMPPostLink)
  165. : InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),
  166. Kernels(Kernels), OpenMPPostLink(OpenMPPostLink) {
  167. OMPBuilder.initialize();
  168. initializeRuntimeFunctions(M);
  169. initializeInternalControlVars();
  170. }
  171. /// Generic information that describes an internal control variable.
  172. struct InternalControlVarInfo {
  173. /// The kind, as described by InternalControlVar enum.
  174. InternalControlVar Kind;
  175. /// The name of the ICV.
  176. StringRef Name;
  177. /// Environment variable associated with this ICV.
  178. StringRef EnvVarName;
  179. /// Initial value kind.
  180. ICVInitValue InitKind;
  181. /// Initial value.
  182. ConstantInt *InitValue;
  183. /// Setter RTL function associated with this ICV.
  184. RuntimeFunction Setter;
  185. /// Getter RTL function associated with this ICV.
  186. RuntimeFunction Getter;
  187. /// RTL Function corresponding to the override clause of this ICV
  188. RuntimeFunction Clause;
  189. };
  190. /// Generic information that describes a runtime function
  191. struct RuntimeFunctionInfo {
  192. /// The kind, as described by the RuntimeFunction enum.
  193. RuntimeFunction Kind;
  194. /// The name of the function.
  195. StringRef Name;
  196. /// Flag to indicate a variadic function.
  197. bool IsVarArg;
  198. /// The return type of the function.
  199. Type *ReturnType;
  200. /// The argument types of the function.
  201. SmallVector<Type *, 8> ArgumentTypes;
  202. /// The declaration if available.
  203. Function *Declaration = nullptr;
  204. /// Uses of this runtime function per function containing the use.
  205. using UseVector = SmallVector<Use *, 16>;
  206. /// Clear UsesMap for runtime function.
  207. void clearUsesMap() { UsesMap.clear(); }
  208. /// Boolean conversion that is true if the runtime function was found.
  209. operator bool() const { return Declaration; }
  210. /// Return the vector of uses in function \p F.
  211. UseVector &getOrCreateUseVector(Function *F) {
  212. std::shared_ptr<UseVector> &UV = UsesMap[F];
  213. if (!UV)
  214. UV = std::make_shared<UseVector>();
  215. return *UV;
  216. }
  217. /// Return the vector of uses in function \p F or `nullptr` if there are
  218. /// none.
  219. const UseVector *getUseVector(Function &F) const {
  220. auto I = UsesMap.find(&F);
  221. if (I != UsesMap.end())
  222. return I->second.get();
  223. return nullptr;
  224. }
  225. /// Return how many functions contain uses of this runtime function.
  226. size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
  227. /// Return the number of arguments (or the minimal number for variadic
  228. /// functions).
  229. size_t getNumArgs() const { return ArgumentTypes.size(); }
  230. /// Run the callback \p CB on each use and forget the use if the result is
  231. /// true. The callback will be fed the function in which the use was
  232. /// encountered as second argument.
  233. void foreachUse(SmallVectorImpl<Function *> &SCC,
  234. function_ref<bool(Use &, Function &)> CB) {
  235. for (Function *F : SCC)
  236. foreachUse(CB, F);
  237. }
  238. /// Run the callback \p CB on each use within the function \p F and forget
  239. /// the use if the result is true.
  240. void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
  241. SmallVector<unsigned, 8> ToBeDeleted;
  242. ToBeDeleted.clear();
  243. unsigned Idx = 0;
  244. UseVector &UV = getOrCreateUseVector(F);
  245. for (Use *U : UV) {
  246. if (CB(*U, *F))
  247. ToBeDeleted.push_back(Idx);
  248. ++Idx;
  249. }
  250. // Remove the to-be-deleted indices in reverse order as prior
  251. // modifications will not modify the smaller indices.
  252. while (!ToBeDeleted.empty()) {
  253. unsigned Idx = ToBeDeleted.pop_back_val();
  254. UV[Idx] = UV.back();
  255. UV.pop_back();
  256. }
  257. }
  258. private:
  259. /// Map from functions to all uses of this runtime function contained in
  260. /// them.
  261. DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
  262. public:
  263. /// Iterators for the uses of this runtime function.
  264. decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
  265. decltype(UsesMap)::iterator end() { return UsesMap.end(); }
  266. };
  267. /// An OpenMP-IR-Builder instance
  268. OpenMPIRBuilder OMPBuilder;
  269. /// Map from runtime function kind to the runtime function description.
  270. EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
  271. RuntimeFunction::OMPRTL___last>
  272. RFIs;
  273. /// Map from function declarations/definitions to their runtime enum type.
  274. DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
  275. /// Map from ICV kind to the ICV description.
  276. EnumeratedArray<InternalControlVarInfo, InternalControlVar,
  277. InternalControlVar::ICV___last>
  278. ICVs;
  /// Helper to initialize all internal control variable information for those
  /// defined in OMPKinds.def.
  ///
  /// The ICV_* macros below are defined immediately before OMPKinds.def is
  /// included. The .def file is expected to invoke them for its ICV entries:
  ///  - ICV_RT_SET(ICV kind, RTL) records the setter runtime function,
  ///  - ICV_RT_GET(ICV kind, RTL) records the getter runtime function,
  ///  - ICV_DATA_ENV(Enum, Name, EnvVarName, Init) records the ICV's name,
  ///    kind, environment variable name, init kind and initial value.
  void initializeInternalControlVars() {
    // NOTE: Keep comments outside of the macro bodies below. A `//` comment
    // inside a backslash-continued #define would break the macro.
#define ICV_RT_SET(_Name, RTL)                                                 \
  {                                                                            \
    auto &ICV = ICVs[_Name];                                                   \
    ICV.Setter = RTL;                                                          \
  }
#define ICV_RT_GET(Name, RTL)                                                  \
  {                                                                            \
    auto &ICV = ICVs[Name];                                                    \
    ICV.Getter = RTL;                                                          \
  }
    // The initial value is derived from the symbolic init kind:
    //  - implementation-defined ICVs get no constant (nullptr),
    //  - ICV_ZERO becomes an i32 zero,
    //  - ICV_FALSE becomes an i1 false.
    // ICV_LAST (presumably a sentinel) leaves InitValue unassigned.
#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
  {                                                                            \
    auto &ICV = ICVs[Enum];                                                    \
    ICV.Name = _Name;                                                          \
    ICV.Kind = Enum;                                                           \
    ICV.InitKind = Init;                                                       \
    ICV.EnvVarName = _EnvVarName;                                              \
    switch (ICV.InitKind) {                                                    \
    case ICV_IMPLEMENTATION_DEFINED:                                           \
      ICV.InitValue = nullptr;                                                 \
      break;                                                                   \
    case ICV_ZERO:                                                             \
      ICV.InitValue = ConstantInt::get(                                        \
          Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
      break;                                                                   \
    case ICV_FALSE:                                                            \
      ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
      break;                                                                   \
    case ICV_LAST:                                                             \
      break;                                                                   \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }
  316. /// Returns true if the function declaration \p F matches the runtime
  317. /// function types, that is, return type \p RTFRetType, and argument types
  318. /// \p RTFArgTypes.
  319. static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
  320. SmallVector<Type *, 8> &RTFArgTypes) {
  321. // TODO: We should output information to the user (under debug output
  322. // and via remarks).
  323. if (!F)
  324. return false;
  325. if (F->getReturnType() != RTFRetType)
  326. return false;
  327. if (F->arg_size() != RTFArgTypes.size())
  328. return false;
  329. auto *RTFTyIt = RTFArgTypes.begin();
  330. for (Argument &Arg : F->args()) {
  331. if (Arg.getType() != *RTFTyIt)
  332. return false;
  333. ++RTFTyIt;
  334. }
  335. return true;
  336. }
  337. // Helper to collect all uses of the declaration in the UsesMap.
  338. unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
  339. unsigned NumUses = 0;
  340. if (!RFI.Declaration)
  341. return NumUses;
  342. OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
  343. if (CollectStats) {
  344. NumOpenMPRuntimeFunctionsIdentified += 1;
  345. NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
  346. }
  347. // TODO: We directly convert uses into proper calls and unknown uses.
  348. for (Use &U : RFI.Declaration->uses()) {
  349. if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
  350. if (ModuleSlice.empty() || ModuleSlice.count(UserI->getFunction())) {
  351. RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
  352. ++NumUses;
  353. }
  354. } else {
  355. RFI.getOrCreateUseVector(nullptr).push_back(&U);
  356. ++NumUses;
  357. }
  358. }
  359. return NumUses;
  360. }
  361. // Helper function to recollect uses of a runtime function.
  362. void recollectUsesForFunction(RuntimeFunction RTF) {
  363. auto &RFI = RFIs[RTF];
  364. RFI.clearUsesMap();
  365. collectUses(RFI, /*CollectStats*/ false);
  366. }
  367. // Helper function to recollect uses of all runtime functions.
  368. void recollectUses() {
  369. for (int Idx = 0; Idx < RFIs.size(); ++Idx)
  370. recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
  371. }
  372. // Helper function to inherit the calling convention of the function callee.
  373. void setCallingConvention(FunctionCallee Callee, CallInst *CI) {
  374. if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))
  375. CI->setCallingConv(Fn->getCallingConv());
  376. }
  377. // Helper function to determine if it's legal to create a call to the runtime
  378. // functions.
  379. bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {
  380. // We can always emit calls if we haven't yet linked in the runtime.
  381. if (!OpenMPPostLink)
  382. return true;
  383. // Once the runtime has been already been linked in we cannot emit calls to
  384. // any undefined functions.
  385. for (RuntimeFunction Fn : Fns) {
  386. RuntimeFunctionInfo &RFI = RFIs[Fn];
  387. if (RFI.Declaration && RFI.Declaration->isDeclaration())
  388. return false;
  389. }
  390. return true;
  391. }
  /// Helper to initialize all runtime function information for those defined
  /// in OMPKinds.def.
  ///
  /// For every OMP_RTL entry, the function of that name is looked up in \p M
  /// and recorded in RTLFunctions. If the function is absent, a null pointer
  /// is recorded. Only a function whose signature matches the expected
  /// runtime types is registered in RuntimeFunctionIDMap and RFIs, and only
  /// then are its uses collected. On device modules, `noinline` is afterwards
  /// removed from runtime-prefixed functions that are not `optnone`.
  void initializeRuntimeFunctions(Module &M) {

    // Helper macros for handling __VA_ARGS__ in OMP_RTL
    // Each OMP_*TYPE macro introduces a local variable named after the type,
    // aliasing the corresponding OMPBuilder member. Presumably this lets the
    // type names used in OMP_RTL argument lists resolve in this scope. The
    // `(void)` casts silence unused-variable warnings. Comments must stay
    // outside the backslash-continued macro bodies.
#define OMP_TYPE(VarName, ...)                                                 \
  Type *VarName = OMPBuilder.VarName;                                          \
  (void)VarName;

#define OMP_ARRAY_TYPE(VarName, ...)                                           \
  ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
  (void)VarName##Ty;                                                           \
  PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
  (void)VarName##PtrTy;

#define OMP_FUNCTION_TYPE(VarName, ...)                                        \
  FunctionType *VarName = OMPBuilder.VarName;                                  \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

#define OMP_STRUCT_TYPE(VarName, ...)                                          \
  StructType *VarName = OMPBuilder.VarName;                                    \
  (void)VarName;                                                               \
  PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
  (void)VarName##Ptr;

    // For each runtime function entry:
    //  - look the function up by name and always insert the result (possibly
    //    null) into RTLFunctions;
    //  - if the declaration matches the expected signature, fill in its RFIs
    //    slot, map it to its enum in RuntimeFunctionIDMap, and collect its uses.
#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
  {                                                                            \
    SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
    Function *F = M.getFunction(_Name);                                        \
    RTLFunctions.insert(F);                                                    \
    if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
      RuntimeFunctionIDMap[F] = _Enum;                                         \
      auto &RFI = RFIs[_Enum];                                                 \
      RFI.Kind = _Enum;                                                        \
      RFI.Name = _Name;                                                        \
      RFI.IsVarArg = _IsVarArg;                                                \
      RFI.ReturnType = OMPBuilder._ReturnType;                                 \
      RFI.ArgumentTypes = std::move(ArgsTypes);                                \
      RFI.Declaration = F;                                                     \
      unsigned NumUses = collectUses(RFI);                                     \
      (void)NumUses;                                                           \
      LLVM_DEBUG({                                                             \
        dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
               << " found\n";                                                  \
        if (RFI.Declaration)                                                   \
          dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
                 << RFI.getNumFunctionsWithUses()                              \
                 << " different functions.\n";                                 \
      });                                                                      \
    }                                                                          \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"

    // Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`
    // functions, except if `optnone` is present. ("_ZN4ompx" is the
    // Itanium-mangled prefix of names in namespace `ompx`.)
    if (isOpenMPDevice(M)) {
      for (Function &F : M) {
        for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})
          if (F.hasFnAttribute(Attribute::NoInline) &&
              F.getName().startswith(Prefix) &&
              !F.hasFnAttribute(Attribute::OptimizeNone))
            F.removeFnAttr(Attribute::NoInline);
      }
    }

    // TODO: We should attach the attributes defined in OMPKinds.def.
  }
  454. /// Collection of known kernels (\see Kernel) in the module.
  455. KernelSet &Kernels;
  456. /// Collection of known OpenMP runtime functions..
  457. DenseSet<const Function *> RTLFunctions;
  458. /// Indicates if we have already linked in the OpenMP device library.
  459. bool OpenMPPostLink = false;
  460. };
  461. template <typename Ty, bool InsertInvalidates = true>
  462. struct BooleanStateWithSetVector : public BooleanState {
  463. bool contains(const Ty &Elem) const { return Set.contains(Elem); }
  464. bool insert(const Ty &Elem) {
  465. if (InsertInvalidates)
  466. BooleanState::indicatePessimisticFixpoint();
  467. return Set.insert(Elem);
  468. }
  469. const Ty &operator[](int Idx) const { return Set[Idx]; }
  470. bool operator==(const BooleanStateWithSetVector &RHS) const {
  471. return BooleanState::operator==(RHS) && Set == RHS.Set;
  472. }
  473. bool operator!=(const BooleanStateWithSetVector &RHS) const {
  474. return !(*this == RHS);
  475. }
  476. bool empty() const { return Set.empty(); }
  477. size_t size() const { return Set.size(); }
  478. /// "Clamp" this state with \p RHS.
  479. BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
  480. BooleanState::operator^=(RHS);
  481. Set.insert(RHS.Set.begin(), RHS.Set.end());
  482. return *this;
  483. }
  484. private:
  485. /// A set to keep track of elements.
  486. SetVector<Ty> Set;
  487. public:
  488. typename decltype(Set)::iterator begin() { return Set.begin(); }
  489. typename decltype(Set)::iterator end() { return Set.end(); }
  490. typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
  491. typename decltype(Set)::const_iterator end() const { return Set.end(); }
  492. };
  493. template <typename Ty, bool InsertInvalidates = true>
  494. using BooleanStateWithPtrSetVector =
  495. BooleanStateWithSetVector<Ty *, InsertInvalidates>;
  496. struct KernelInfoState : AbstractState {
  497. /// Flag to track if we reached a fixpoint.
  498. bool IsAtFixpoint = false;
  499. /// The parallel regions (identified by the outlined parallel functions) that
  500. /// can be reached from the associated function.
  501. BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
  502. ReachedKnownParallelRegions;
  503. /// State to track what parallel region we might reach.
  504. BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
  505. /// State to track if we are in SPMD-mode, assumed or know, and why we decided
  506. /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
  507. /// false.
  508. BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
  509. /// The __kmpc_target_init call in this kernel, if any. If we find more than
  510. /// one we abort as the kernel is malformed.
  511. CallBase *KernelInitCB = nullptr;
  512. /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
  513. /// one we abort as the kernel is malformed.
  514. CallBase *KernelDeinitCB = nullptr;
  515. /// Flag to indicate if the associated function is a kernel entry.
  516. bool IsKernelEntry = false;
  517. /// State to track what kernel entries can reach the associated function.
  518. BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
  519. /// State to indicate if we can track parallel level of the associated
  520. /// function. We will give up tracking if we encounter unknown caller or the
  521. /// caller is __kmpc_parallel_51.
  522. BooleanStateWithSetVector<uint8_t> ParallelLevels;
  523. /// Flag that indicates if the kernel has nested Parallelism
  524. bool NestedParallelism = false;
  525. /// Abstract State interface
  526. ///{
  527. KernelInfoState() = default;
  528. KernelInfoState(bool BestState) {
  529. if (!BestState)
  530. indicatePessimisticFixpoint();
  531. }
  532. /// See AbstractState::isValidState(...)
  533. bool isValidState() const override { return true; }
  534. /// See AbstractState::isAtFixpoint(...)
  535. bool isAtFixpoint() const override { return IsAtFixpoint; }
  536. /// See AbstractState::indicatePessimisticFixpoint(...)
  537. ChangeStatus indicatePessimisticFixpoint() override {
  538. IsAtFixpoint = true;
  539. ParallelLevels.indicatePessimisticFixpoint();
  540. ReachingKernelEntries.indicatePessimisticFixpoint();
  541. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  542. ReachedKnownParallelRegions.indicatePessimisticFixpoint();
  543. ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
  544. return ChangeStatus::CHANGED;
  545. }
  546. /// See AbstractState::indicateOptimisticFixpoint(...)
  547. ChangeStatus indicateOptimisticFixpoint() override {
  548. IsAtFixpoint = true;
  549. ParallelLevels.indicateOptimisticFixpoint();
  550. ReachingKernelEntries.indicateOptimisticFixpoint();
  551. SPMDCompatibilityTracker.indicateOptimisticFixpoint();
  552. ReachedKnownParallelRegions.indicateOptimisticFixpoint();
  553. ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
  554. return ChangeStatus::UNCHANGED;
  555. }
  556. /// Return the assumed state
  557. KernelInfoState &getAssumed() { return *this; }
  558. const KernelInfoState &getAssumed() const { return *this; }
  559. bool operator==(const KernelInfoState &RHS) const {
  560. if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
  561. return false;
  562. if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
  563. return false;
  564. if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
  565. return false;
  566. if (ReachingKernelEntries != RHS.ReachingKernelEntries)
  567. return false;
  568. if (ParallelLevels != RHS.ParallelLevels)
  569. return false;
  570. return true;
  571. }
  572. /// Returns true if this kernel contains any OpenMP parallel regions.
  573. bool mayContainParallelRegion() {
  574. return !ReachedKnownParallelRegions.empty() ||
  575. !ReachedUnknownParallelRegions.empty();
  576. }
  577. /// Return empty set as the best state of potential values.
  578. static KernelInfoState getBestState() { return KernelInfoState(true); }
  579. static KernelInfoState getBestState(KernelInfoState &KIS) {
  580. return getBestState();
  581. }
  582. /// Return full set as the worst state of potential values.
  583. static KernelInfoState getWorstState() { return KernelInfoState(false); }
  584. /// "Clamp" this state with \p KIS.
  585. KernelInfoState operator^=(const KernelInfoState &KIS) {
  586. // Do not merge two different _init and _deinit call sites.
  587. if (KIS.KernelInitCB) {
  588. if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
  589. llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
  590. "assumptions.");
  591. KernelInitCB = KIS.KernelInitCB;
  592. }
  593. if (KIS.KernelDeinitCB) {
  594. if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
  595. llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "
  596. "assumptions.");
  597. KernelDeinitCB = KIS.KernelDeinitCB;
  598. }
  599. SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
  600. ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
  601. ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
  602. NestedParallelism |= KIS.NestedParallelism;
  603. return *this;
  604. }
  605. KernelInfoState operator&=(const KernelInfoState &KIS) {
  606. return (*this ^= KIS);
  607. }
  608. ///}
  609. };
  610. /// Used to map the values physically (in the IR) stored in an offload
  611. /// array, to a vector in memory.
  612. struct OffloadArray {
  613. /// Physical array (in the IR).
  614. AllocaInst *Array = nullptr;
  615. /// Mapped values.
  616. SmallVector<Value *, 8> StoredValues;
  617. /// Last stores made in the offload array.
  618. SmallVector<StoreInst *, 8> LastAccesses;
  619. OffloadArray() = default;
  620. /// Initializes the OffloadArray with the values stored in \p Array before
  621. /// instruction \p Before is reached. Returns false if the initialization
  622. /// fails.
  623. /// This MUST be used immediately after the construction of the object.
  624. bool initialize(AllocaInst &Array, Instruction &Before) {
  625. if (!Array.getAllocatedType()->isArrayTy())
  626. return false;
  627. if (!getValues(Array, Before))
  628. return false;
  629. this->Array = &Array;
  630. return true;
  631. }
  632. static const unsigned DeviceIDArgNum = 1;
  633. static const unsigned BasePtrsArgNum = 3;
  634. static const unsigned PtrsArgNum = 4;
  635. static const unsigned SizesArgNum = 5;
  636. private:
  637. /// Traverses the BasicBlock where \p Array is, collecting the stores made to
  638. /// \p Array, leaving StoredValues with the values stored before the
  639. /// instruction \p Before is reached.
  640. bool getValues(AllocaInst &Array, Instruction &Before) {
  641. // Initialize container.
  642. const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
  643. StoredValues.assign(NumValues, nullptr);
  644. LastAccesses.assign(NumValues, nullptr);
  645. // TODO: This assumes the instruction \p Before is in the same
  646. // BasicBlock as Array. Make it general, for any control flow graph.
  647. BasicBlock *BB = Array.getParent();
  648. if (BB != Before.getParent())
  649. return false;
  650. const DataLayout &DL = Array.getModule()->getDataLayout();
  651. const unsigned int PointerSize = DL.getPointerSize();
  652. for (Instruction &I : *BB) {
  653. if (&I == &Before)
  654. break;
  655. if (!isa<StoreInst>(&I))
  656. continue;
  657. auto *S = cast<StoreInst>(&I);
  658. int64_t Offset = -1;
  659. auto *Dst =
  660. GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
  661. if (Dst == &Array) {
  662. int64_t Idx = Offset / PointerSize;
  663. StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
  664. LastAccesses[Idx] = S;
  665. }
  666. }
  667. return isFilled();
  668. }
  669. /// Returns true if all values in StoredValues and
  670. /// LastAccesses are not nullptrs.
  671. bool isFilled() {
  672. const unsigned NumValues = StoredValues.size();
  673. for (unsigned I = 0; I < NumValues; ++I) {
  674. if (!StoredValues[I] || !LastAccesses[I])
  675. return false;
  676. }
  677. return true;
  678. }
  679. };
  680. struct OpenMPOpt {
  681. using OptimizationRemarkGetter =
  682. function_ref<OptimizationRemarkEmitter &(Function *)>;
  683. OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
  684. OptimizationRemarkGetter OREGetter,
  685. OMPInformationCache &OMPInfoCache, Attributor &A)
  686. : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
  687. OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
  688. /// Check if any remarks are enabled for openmp-opt
  689. bool remarksEnabled() {
  690. auto &Ctx = M.getContext();
  691. return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
  692. }
  693. /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
  694. bool run(bool IsModulePass) {
  695. if (SCC.empty())
  696. return false;
  697. bool Changed = false;
  698. LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
  699. << " functions in a slice with "
  700. << OMPInfoCache.ModuleSlice.size() << " functions\n");
  701. if (IsModulePass) {
  702. Changed |= runAttributor(IsModulePass);
  703. // Recollect uses, in case Attributor deleted any.
  704. OMPInfoCache.recollectUses();
  705. // TODO: This should be folded into buildCustomStateMachine.
  706. Changed |= rewriteDeviceCodeStateMachine();
  707. if (remarksEnabled())
  708. analysisGlobalization();
  709. } else {
  710. if (PrintICVValues)
  711. printICVs();
  712. if (PrintOpenMPKernels)
  713. printKernels();
  714. Changed |= runAttributor(IsModulePass);
  715. // Recollect uses, in case Attributor deleted any.
  716. OMPInfoCache.recollectUses();
  717. Changed |= deleteParallelRegions();
  718. if (HideMemoryTransferLatency)
  719. Changed |= hideMemTransfersLatency();
  720. Changed |= deduplicateRuntimeCalls();
  721. if (EnableParallelRegionMerging) {
  722. if (mergeParallelRegions()) {
  723. deduplicateRuntimeCalls();
  724. Changed = true;
  725. }
  726. }
  727. }
  728. return Changed;
  729. }
  730. /// Print initial ICV values for testing.
  731. /// FIXME: This should be done from the Attributor once it is added.
  732. void printICVs() const {
  733. InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
  734. ICV_proc_bind};
  735. for (Function *F : SCC) {
  736. for (auto ICV : ICVs) {
  737. auto ICVInfo = OMPInfoCache.ICVs[ICV];
  738. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  739. return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
  740. << " Value: "
  741. << (ICVInfo.InitValue
  742. ? toString(ICVInfo.InitValue->getValue(), 10, true)
  743. : "IMPLEMENTATION_DEFINED");
  744. };
  745. emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
  746. }
  747. }
  748. }
  749. /// Print OpenMP GPU kernels for testing.
  750. void printKernels() const {
  751. for (Function *F : SCC) {
  752. if (!OMPInfoCache.Kernels.count(F))
  753. continue;
  754. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  755. return ORA << "OpenMP GPU kernel "
  756. << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
  757. };
  758. emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
  759. }
  760. }
  761. /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  762. /// given it has to be the callee or a nullptr is returned.
  763. static CallInst *getCallIfRegularCall(
  764. Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
  765. CallInst *CI = dyn_cast<CallInst>(U.getUser());
  766. if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
  767. (!RFI ||
  768. (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
  769. return CI;
  770. return nullptr;
  771. }
  772. /// Return the call if \p V is a regular call. If \p RFI is given it has to be
  773. /// the callee or a nullptr is returned.
  774. static CallInst *getCallIfRegularCall(
  775. Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
  776. CallInst *CI = dyn_cast<CallInst>(&V);
  777. if (CI && !CI->hasOperandBundles() &&
  778. (!RFI ||
  779. (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
  780. return CI;
  781. return nullptr;
  782. }
  783. private:
  784. /// Merge parallel regions when it is safe.
  785. bool mergeParallelRegions() {
  786. const unsigned CallbackCalleeOperand = 2;
  787. const unsigned CallbackFirstArgOperand = 3;
  788. using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  789. // Check if there are any __kmpc_fork_call calls to merge.
  790. OMPInformationCache::RuntimeFunctionInfo &RFI =
  791. OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
  792. if (!RFI.Declaration)
  793. return false;
  794. // Unmergable calls that prevent merging a parallel region.
  795. OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
  796. OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
  797. OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
  798. };
  799. bool Changed = false;
  800. LoopInfo *LI = nullptr;
  801. DominatorTree *DT = nullptr;
  802. SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
  803. BasicBlock *StartBB = nullptr, *EndBB = nullptr;
  804. auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
  805. BasicBlock *CGStartBB = CodeGenIP.getBlock();
  806. BasicBlock *CGEndBB =
  807. SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
  808. assert(StartBB != nullptr && "StartBB should not be null");
  809. CGStartBB->getTerminator()->setSuccessor(0, StartBB);
  810. assert(EndBB != nullptr && "EndBB should not be null");
  811. EndBB->getTerminator()->setSuccessor(0, CGEndBB);
  812. };
  813. auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
  814. Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
  815. ReplacementValue = &Inner;
  816. return CodeGenIP;
  817. };
  818. auto FiniCB = [&](InsertPointTy CodeGenIP) {};
  819. /// Create a sequential execution region within a merged parallel region,
  820. /// encapsulated in a master construct with a barrier for synchronization.
  821. auto CreateSequentialRegion = [&](Function *OuterFn,
  822. BasicBlock *OuterPredBB,
  823. Instruction *SeqStartI,
  824. Instruction *SeqEndI) {
  825. // Isolate the instructions of the sequential region to a separate
  826. // block.
  827. BasicBlock *ParentBB = SeqStartI->getParent();
  828. BasicBlock *SeqEndBB =
  829. SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
  830. BasicBlock *SeqAfterBB =
  831. SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
  832. BasicBlock *SeqStartBB =
  833. SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
  834. assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
  835. "Expected a different CFG");
  836. const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
  837. ParentBB->getTerminator()->eraseFromParent();
  838. auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
  839. BasicBlock *CGStartBB = CodeGenIP.getBlock();
  840. BasicBlock *CGEndBB =
  841. SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
  842. assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
  843. CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
  844. assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
  845. SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
  846. };
  847. auto FiniCB = [&](InsertPointTy CodeGenIP) {};
  848. // Find outputs from the sequential region to outside users and
  849. // broadcast their values to them.
  850. for (Instruction &I : *SeqStartBB) {
  851. SmallPtrSet<Instruction *, 4> OutsideUsers;
  852. for (User *Usr : I.users()) {
  853. Instruction &UsrI = *cast<Instruction>(Usr);
  854. // Ignore outputs to LT intrinsics, code extraction for the merged
  855. // parallel region will fix them.
  856. if (UsrI.isLifetimeStartOrEnd())
  857. continue;
  858. if (UsrI.getParent() != SeqStartBB)
  859. OutsideUsers.insert(&UsrI);
  860. }
  861. if (OutsideUsers.empty())
  862. continue;
  863. // Emit an alloca in the outer region to store the broadcasted
  864. // value.
  865. const DataLayout &DL = M.getDataLayout();
  866. AllocaInst *AllocaI = new AllocaInst(
  867. I.getType(), DL.getAllocaAddrSpace(), nullptr,
  868. I.getName() + ".seq.output.alloc", &OuterFn->front().front());
  869. // Emit a store instruction in the sequential BB to update the
  870. // value.
  871. new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
  872. // Emit a load instruction and replace the use of the output value
  873. // with it.
  874. for (Instruction *UsrI : OutsideUsers) {
  875. LoadInst *LoadI = new LoadInst(
  876. I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
  877. UsrI->replaceUsesOfWith(&I, LoadI);
  878. }
  879. }
  880. OpenMPIRBuilder::LocationDescription Loc(
  881. InsertPointTy(ParentBB, ParentBB->end()), DL);
  882. InsertPointTy SeqAfterIP =
  883. OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
  884. OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
  885. BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
  886. LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
  887. << "\n");
  888. };
  889. // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
  890. // contained in BB and only separated by instructions that can be
  891. // redundantly executed in parallel. The block BB is split before the first
  892. // call (in MergableCIs) and after the last so the entire region we merge
  893. // into a single parallel region is contained in a single basic block
  894. // without any other instructions. We use the OpenMPIRBuilder to outline
  895. // that block and call the resulting function via __kmpc_fork_call.
  896. auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,
  897. BasicBlock *BB) {
  898. // TODO: Change the interface to allow single CIs expanded, e.g, to
  899. // include an outer loop.
  900. assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
  901. auto Remark = [&](OptimizationRemark OR) {
  902. OR << "Parallel region merged with parallel region"
  903. << (MergableCIs.size() > 2 ? "s" : "") << " at ";
  904. for (auto *CI : llvm::drop_begin(MergableCIs)) {
  905. OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
  906. if (CI != MergableCIs.back())
  907. OR << ", ";
  908. }
  909. return OR << ".";
  910. };
  911. emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
  912. Function *OriginalFn = BB->getParent();
  913. LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
  914. << " parallel regions in " << OriginalFn->getName()
  915. << "\n");
  916. // Isolate the calls to merge in a separate block.
  917. EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
  918. BasicBlock *AfterBB =
  919. SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
  920. StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
  921. "omp.par.merged");
  922. assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
  923. const DebugLoc DL = BB->getTerminator()->getDebugLoc();
  924. BB->getTerminator()->eraseFromParent();
  925. // Create sequential regions for sequential instructions that are
  926. // in-between mergable parallel regions.
  927. for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
  928. It != End; ++It) {
  929. Instruction *ForkCI = *It;
  930. Instruction *NextForkCI = *(It + 1);
  931. // Continue if there are not in-between instructions.
  932. if (ForkCI->getNextNode() == NextForkCI)
  933. continue;
  934. CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
  935. NextForkCI->getPrevNode());
  936. }
  937. OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
  938. DL);
  939. IRBuilder<>::InsertPoint AllocaIP(
  940. &OriginalFn->getEntryBlock(),
  941. OriginalFn->getEntryBlock().getFirstInsertionPt());
  942. // Create the merged parallel region with default proc binding, to
  943. // avoid overriding binding settings, and without explicit cancellation.
  944. InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
  945. Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
  946. OMP_PROC_BIND_default, /* IsCancellable */ false);
  947. BranchInst::Create(AfterBB, AfterIP.getBlock());
  948. // Perform the actual outlining.
  949. OMPInfoCache.OMPBuilder.finalize(OriginalFn);
  950. Function *OutlinedFn = MergableCIs.front()->getCaller();
  951. // Replace the __kmpc_fork_call calls with direct calls to the outlined
  952. // callbacks.
  953. SmallVector<Value *, 8> Args;
  954. for (auto *CI : MergableCIs) {
  955. Value *Callee = CI->getArgOperand(CallbackCalleeOperand);
  956. FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;
  957. Args.clear();
  958. Args.push_back(OutlinedFn->getArg(0));
  959. Args.push_back(OutlinedFn->getArg(1));
  960. for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
  961. ++U)
  962. Args.push_back(CI->getArgOperand(U));
  963. CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
  964. if (CI->getDebugLoc())
  965. NewCI->setDebugLoc(CI->getDebugLoc());
  966. // Forward parameter attributes from the callback to the callee.
  967. for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;
  968. ++U)
  969. for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
  970. NewCI->addParamAttr(
  971. U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
  972. // Emit an explicit barrier to replace the implicit fork-join barrier.
  973. if (CI != MergableCIs.back()) {
  974. // TODO: Remove barrier if the merged parallel region includes the
  975. // 'nowait' clause.
  976. OMPInfoCache.OMPBuilder.createBarrier(
  977. InsertPointTy(NewCI->getParent(),
  978. NewCI->getNextNode()->getIterator()),
  979. OMPD_parallel);
  980. }
  981. CI->eraseFromParent();
  982. }
  983. assert(OutlinedFn != OriginalFn && "Outlining failed");
  984. CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
  985. CGUpdater.reanalyzeFunction(*OriginalFn);
  986. NumOpenMPParallelRegionsMerged += MergableCIs.size();
  987. return true;
  988. };
  989. // Helper function that identifes sequences of
  990. // __kmpc_fork_call uses in a basic block.
  991. auto DetectPRsCB = [&](Use &U, Function &F) {
  992. CallInst *CI = getCallIfRegularCall(U, &RFI);
  993. BB2PRMap[CI->getParent()].insert(CI);
  994. return false;
  995. };
  996. BB2PRMap.clear();
  997. RFI.foreachUse(SCC, DetectPRsCB);
  998. SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
  999. // Find mergable parallel regions within a basic block that are
  1000. // safe to merge, that is any in-between instructions can safely
  1001. // execute in parallel after merging.
  1002. // TODO: support merging across basic-blocks.
  1003. for (auto &It : BB2PRMap) {
  1004. auto &CIs = It.getSecond();
  1005. if (CIs.size() < 2)
  1006. continue;
  1007. BasicBlock *BB = It.getFirst();
  1008. SmallVector<CallInst *, 4> MergableCIs;
  1009. /// Returns true if the instruction is mergable, false otherwise.
  1010. /// A terminator instruction is unmergable by definition since merging
  1011. /// works within a BB. Instructions before the mergable region are
  1012. /// mergable if they are not calls to OpenMP runtime functions that may
  1013. /// set different execution parameters for subsequent parallel regions.
  1014. /// Instructions in-between parallel regions are mergable if they are not
  1015. /// calls to any non-intrinsic function since that may call a non-mergable
  1016. /// OpenMP runtime function.
  1017. auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
  1018. // We do not merge across BBs, hence return false (unmergable) if the
  1019. // instruction is a terminator.
  1020. if (I.isTerminator())
  1021. return false;
  1022. if (!isa<CallInst>(&I))
  1023. return true;
  1024. CallInst *CI = cast<CallInst>(&I);
  1025. if (IsBeforeMergableRegion) {
  1026. Function *CalledFunction = CI->getCalledFunction();
  1027. if (!CalledFunction)
  1028. return false;
  1029. // Return false (unmergable) if the call before the parallel
  1030. // region calls an explicit affinity (proc_bind) or number of
  1031. // threads (num_threads) compiler-generated function. Those settings
  1032. // may be incompatible with following parallel regions.
  1033. // TODO: ICV tracking to detect compatibility.
  1034. for (const auto &RFI : UnmergableCallsInfo) {
  1035. if (CalledFunction == RFI.Declaration)
  1036. return false;
  1037. }
  1038. } else {
  1039. // Return false (unmergable) if there is a call instruction
  1040. // in-between parallel regions when it is not an intrinsic. It
  1041. // may call an unmergable OpenMP runtime function in its callpath.
  1042. // TODO: Keep track of possible OpenMP calls in the callpath.
  1043. if (!isa<IntrinsicInst>(CI))
  1044. return false;
  1045. }
  1046. return true;
  1047. };
  1048. // Find maximal number of parallel region CIs that are safe to merge.
  1049. for (auto It = BB->begin(), End = BB->end(); It != End;) {
  1050. Instruction &I = *It;
  1051. ++It;
  1052. if (CIs.count(&I)) {
  1053. MergableCIs.push_back(cast<CallInst>(&I));
  1054. continue;
  1055. }
  1056. // Continue expanding if the instruction is mergable.
  1057. if (IsMergable(I, MergableCIs.empty()))
  1058. continue;
  1059. // Forward the instruction iterator to skip the next parallel region
  1060. // since there is an unmergable instruction which can affect it.
  1061. for (; It != End; ++It) {
  1062. Instruction &SkipI = *It;
  1063. if (CIs.count(&SkipI)) {
  1064. LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
  1065. << " due to " << I << "\n");
  1066. ++It;
  1067. break;
  1068. }
  1069. }
  1070. // Store mergable regions found.
  1071. if (MergableCIs.size() > 1) {
  1072. MergableCIsVector.push_back(MergableCIs);
  1073. LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
  1074. << " parallel regions in block " << BB->getName()
  1075. << " of function " << BB->getParent()->getName()
  1076. << "\n";);
  1077. }
  1078. MergableCIs.clear();
  1079. }
  1080. if (!MergableCIsVector.empty()) {
  1081. Changed = true;
  1082. for (auto &MergableCIs : MergableCIsVector)
  1083. Merge(MergableCIs, BB);
  1084. MergableCIsVector.clear();
  1085. }
  1086. }
  1087. if (Changed) {
  1088. /// Re-collect use for fork calls, emitted barrier calls, and
  1089. /// any emitted master/end_master calls.
  1090. OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
  1091. OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
  1092. OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
  1093. OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
  1094. }
  1095. return Changed;
  1096. }
  1097. /// Try to delete parallel regions if possible.
  1098. bool deleteParallelRegions() {
  1099. const unsigned CallbackCalleeOperand = 2;
  1100. OMPInformationCache::RuntimeFunctionInfo &RFI =
  1101. OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
  1102. if (!RFI.Declaration)
  1103. return false;
  1104. bool Changed = false;
  1105. auto DeleteCallCB = [&](Use &U, Function &) {
  1106. CallInst *CI = getCallIfRegularCall(U);
  1107. if (!CI)
  1108. return false;
  1109. auto *Fn = dyn_cast<Function>(
  1110. CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
  1111. if (!Fn)
  1112. return false;
  1113. if (!Fn->onlyReadsMemory())
  1114. return false;
  1115. if (!Fn->hasFnAttribute(Attribute::WillReturn))
  1116. return false;
  1117. LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
  1118. << CI->getCaller()->getName() << "\n");
  1119. auto Remark = [&](OptimizationRemark OR) {
  1120. return OR << "Removing parallel region with no side-effects.";
  1121. };
  1122. emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
  1123. CGUpdater.removeCallSite(*CI);
  1124. CI->eraseFromParent();
  1125. Changed = true;
  1126. ++NumOpenMPParallelRegionsDeleted;
  1127. return true;
  1128. };
  1129. RFI.foreachUse(SCC, DeleteCallCB);
  1130. return Changed;
  1131. }
  1132. /// Try to eliminate runtime calls by reusing existing ones.
  1133. bool deduplicateRuntimeCalls() {
  1134. bool Changed = false;
  1135. RuntimeFunction DeduplicableRuntimeCallIDs[] = {
  1136. OMPRTL_omp_get_num_threads,
  1137. OMPRTL_omp_in_parallel,
  1138. OMPRTL_omp_get_cancellation,
  1139. OMPRTL_omp_get_thread_limit,
  1140. OMPRTL_omp_get_supported_active_levels,
  1141. OMPRTL_omp_get_level,
  1142. OMPRTL_omp_get_ancestor_thread_num,
  1143. OMPRTL_omp_get_team_size,
  1144. OMPRTL_omp_get_active_level,
  1145. OMPRTL_omp_in_final,
  1146. OMPRTL_omp_get_proc_bind,
  1147. OMPRTL_omp_get_num_places,
  1148. OMPRTL_omp_get_num_procs,
  1149. OMPRTL_omp_get_place_num,
  1150. OMPRTL_omp_get_partition_num_places,
  1151. OMPRTL_omp_get_partition_place_nums};
  1152. // Global-tid is handled separately.
  1153. SmallSetVector<Value *, 16> GTIdArgs;
  1154. collectGlobalThreadIdArguments(GTIdArgs);
  1155. LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
  1156. << " global thread ID arguments\n");
  1157. for (Function *F : SCC) {
  1158. for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
  1159. Changed |= deduplicateRuntimeCalls(
  1160. *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
  1161. // __kmpc_global_thread_num is special as we can replace it with an
  1162. // argument in enough cases to make it worth trying.
  1163. Value *GTIdArg = nullptr;
  1164. for (Argument &Arg : F->args())
  1165. if (GTIdArgs.count(&Arg)) {
  1166. GTIdArg = &Arg;
  1167. break;
  1168. }
  1169. Changed |= deduplicateRuntimeCalls(
  1170. *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
  1171. }
  1172. return Changed;
  1173. }
  1174. /// Tries to hide the latency of runtime calls that involve host to
  1175. /// device memory transfers by splitting them into their "issue" and "wait"
  1176. /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  1177. /// moved downards as much as possible. The "issue" issues the memory transfer
  1178. /// asynchronously, returning a handle. The "wait" waits in the returned
  1179. /// handle for the memory transfer to finish.
  1180. bool hideMemTransfersLatency() {
  1181. auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
  1182. bool Changed = false;
  1183. auto SplitMemTransfers = [&](Use &U, Function &Decl) {
  1184. auto *RTCall = getCallIfRegularCall(U, &RFI);
  1185. if (!RTCall)
  1186. return false;
  1187. OffloadArray OffloadArrays[3];
  1188. if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
  1189. return false;
  1190. LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
  1191. // TODO: Check if can be moved upwards.
  1192. bool WasSplit = false;
  1193. Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
  1194. if (WaitMovementPoint)
  1195. WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
  1196. Changed |= WasSplit;
  1197. return WasSplit;
  1198. };
  1199. if (OMPInfoCache.runtimeFnsAvailable(
  1200. {OMPRTL___tgt_target_data_begin_mapper_issue,
  1201. OMPRTL___tgt_target_data_begin_mapper_wait}))
  1202. RFI.foreachUse(SCC, SplitMemTransfers);
  1203. return Changed;
  1204. }
  1205. void analysisGlobalization() {
  1206. auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
  1207. auto CheckGlobalization = [&](Use &U, Function &Decl) {
  1208. if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
  1209. auto Remark = [&](OptimizationRemarkMissed ORM) {
  1210. return ORM
  1211. << "Found thread data sharing on the GPU. "
  1212. << "Expect degraded performance due to data globalization.";
  1213. };
  1214. emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
  1215. }
  1216. return false;
  1217. };
  1218. RFI.foreachUse(SCC, CheckGlobalization);
  1219. }
  1220. /// Maps the values stored in the offload arrays passed as arguments to
  1221. /// \p RuntimeCall into the offload arrays in \p OAs.
  1222. bool getValuesInOffloadArrays(CallInst &RuntimeCall,
  1223. MutableArrayRef<OffloadArray> OAs) {
  1224. assert(OAs.size() == 3 && "Need space for three offload arrays!");
  1225. // A runtime call that involves memory offloading looks something like:
  1226. // call void @__tgt_target_data_begin_mapper(arg0, arg1,
  1227. // i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
  1228. // ...)
  1229. // So, the idea is to access the allocas that allocate space for these
  1230. // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
  1231. // Therefore:
  1232. // i8** %offload_baseptrs.
  1233. Value *BasePtrsArg =
  1234. RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
  1235. // i8** %offload_ptrs.
  1236. Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
  1237. // i8** %offload_sizes.
  1238. Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
  1239. // Get values stored in **offload_baseptrs.
  1240. auto *V = getUnderlyingObject(BasePtrsArg);
  1241. if (!isa<AllocaInst>(V))
  1242. return false;
  1243. auto *BasePtrsArray = cast<AllocaInst>(V);
  1244. if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
  1245. return false;
  1246. // Get values stored in **offload_baseptrs.
  1247. V = getUnderlyingObject(PtrsArg);
  1248. if (!isa<AllocaInst>(V))
  1249. return false;
  1250. auto *PtrsArray = cast<AllocaInst>(V);
  1251. if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
  1252. return false;
  1253. // Get values stored in **offload_sizes.
  1254. V = getUnderlyingObject(SizesArg);
  1255. // If it's a [constant] global array don't analyze it.
  1256. if (isa<GlobalValue>(V))
  1257. return isa<Constant>(V);
  1258. if (!isa<AllocaInst>(V))
  1259. return false;
  1260. auto *SizesArray = cast<AllocaInst>(V);
  1261. if (!OAs[2].initialize(*SizesArray, RuntimeCall))
  1262. return false;
  1263. return true;
  1264. }
  1265. /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
  1266. /// For now this is a way to test that the function getValuesInOffloadArrays
  1267. /// is working properly.
  1268. /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
  1269. void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
  1270. assert(OAs.size() == 3 && "There are three offload arrays to debug!");
  1271. LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
  1272. std::string ValuesStr;
  1273. raw_string_ostream Printer(ValuesStr);
  1274. std::string Separator = " --- ";
  1275. for (auto *BP : OAs[0].StoredValues) {
  1276. BP->print(Printer);
  1277. Printer << Separator;
  1278. }
  1279. LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
  1280. ValuesStr.clear();
  1281. for (auto *P : OAs[1].StoredValues) {
  1282. P->print(Printer);
  1283. Printer << Separator;
  1284. }
  1285. LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
  1286. ValuesStr.clear();
  1287. for (auto *S : OAs[2].StoredValues) {
  1288. S->print(Printer);
  1289. Printer << Separator;
  1290. }
  1291. LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
  1292. }
  1293. /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
  1294. /// moved. Returns nullptr if the movement is not possible, or not worth it.
  1295. Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
  1296. // FIXME: This traverses only the BasicBlock where RuntimeCall is.
  1297. // Make it traverse the CFG.
  1298. Instruction *CurrentI = &RuntimeCall;
  1299. bool IsWorthIt = false;
  1300. while ((CurrentI = CurrentI->getNextNode())) {
  1301. // TODO: Once we detect the regions to be offloaded we should use the
  1302. // alias analysis manager to check if CurrentI may modify one of
  1303. // the offloaded regions.
  1304. if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
  1305. if (IsWorthIt)
  1306. return CurrentI;
  1307. return nullptr;
  1308. }
  1309. // FIXME: For now if we move it over anything without side effect
  1310. // is worth it.
  1311. IsWorthIt = true;
  1312. }
  1313. // Return end of BasicBlock.
  1314. return RuntimeCall.getParent()->getTerminator();
  1315. }
  1316. /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  1317. bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
  1318. Instruction &WaitMovementPoint) {
  1319. // Create stack allocated handle (__tgt_async_info) at the beginning of the
  1320. // function. Used for storing information of the async transfer, allowing to
  1321. // wait on it later.
  1322. auto &IRBuilder = OMPInfoCache.OMPBuilder;
  1323. Function *F = RuntimeCall.getCaller();
  1324. BasicBlock &Entry = F->getEntryBlock();
  1325. IRBuilder.Builder.SetInsertPoint(&Entry,
  1326. Entry.getFirstNonPHIOrDbgOrAlloca());
  1327. Value *Handle = IRBuilder.Builder.CreateAlloca(
  1328. IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
  1329. Handle =
  1330. IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);
  1331. // Add "issue" runtime call declaration:
  1332. // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
  1333. // i8**, i8**, i64*, i64*)
  1334. FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
  1335. M, OMPRTL___tgt_target_data_begin_mapper_issue);
  1336. // Change RuntimeCall call site for its asynchronous version.
  1337. SmallVector<Value *, 16> Args;
  1338. for (auto &Arg : RuntimeCall.args())
  1339. Args.push_back(Arg.get());
  1340. Args.push_back(Handle);
  1341. CallInst *IssueCallsite =
  1342. CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
  1343. OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
  1344. RuntimeCall.eraseFromParent();
  1345. // Add "wait" runtime call declaration:
  1346. // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
  1347. FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
  1348. M, OMPRTL___tgt_target_data_begin_mapper_wait);
  1349. Value *WaitParams[2] = {
  1350. IssueCallsite->getArgOperand(
  1351. OffloadArray::DeviceIDArgNum), // device_id.
  1352. Handle // handle to wait on.
  1353. };
  1354. CallInst *WaitCallsite = CallInst::Create(
  1355. WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
  1356. OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);
  1357. return true;
  1358. }
  1359. static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
  1360. bool GlobalOnly, bool &SingleChoice) {
  1361. if (CurrentIdent == NextIdent)
  1362. return CurrentIdent;
  1363. // TODO: Figure out how to actually combine multiple debug locations. For
  1364. // now we just keep an existing one if there is a single choice.
  1365. if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
  1366. SingleChoice = !CurrentIdent;
  1367. return NextIdent;
  1368. }
  1369. return nullptr;
  1370. }
  1371. /// Return an `struct ident_t*` value that represents the ones used in the
  1372. /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  1373. /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  1374. /// return value we create one from scratch. We also do not yet combine
  1375. /// information, e.g., the source locations, see combinedIdentStruct.
  1376. Value *
  1377. getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
  1378. Function &F, bool GlobalOnly) {
  1379. bool SingleChoice = true;
  1380. Value *Ident = nullptr;
  1381. auto CombineIdentStruct = [&](Use &U, Function &Caller) {
  1382. CallInst *CI = getCallIfRegularCall(U, &RFI);
  1383. if (!CI || &F != &Caller)
  1384. return false;
  1385. Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
  1386. /* GlobalOnly */ true, SingleChoice);
  1387. return false;
  1388. };
  1389. RFI.foreachUse(SCC, CombineIdentStruct);
  1390. if (!Ident || !SingleChoice) {
  1391. // The IRBuilder uses the insertion block to get to the module, this is
  1392. // unfortunate but we work around it for now.
  1393. if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
  1394. OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
  1395. &F.getEntryBlock(), F.getEntryBlock().begin()));
  1396. // Create a fallback location if non was found.
  1397. // TODO: Use the debug locations of the calls instead.
  1398. uint32_t SrcLocStrSize;
  1399. Constant *Loc =
  1400. OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  1401. Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
  1402. }
  1403. return Ident;
  1404. }
  1405. /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  1406. /// \p ReplVal if given.
  1407. bool deduplicateRuntimeCalls(Function &F,
  1408. OMPInformationCache::RuntimeFunctionInfo &RFI,
  1409. Value *ReplVal = nullptr) {
  1410. auto *UV = RFI.getUseVector(F);
  1411. if (!UV || UV->size() + (ReplVal != nullptr) < 2)
  1412. return false;
  1413. LLVM_DEBUG(
  1414. dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
  1415. << (ReplVal ? " with an existing value\n" : "\n") << "\n");
  1416. assert((!ReplVal || (isa<Argument>(ReplVal) &&
  1417. cast<Argument>(ReplVal)->getParent() == &F)) &&
  1418. "Unexpected replacement value!");
  1419. // TODO: Use dominance to find a good position instead.
  1420. auto CanBeMoved = [this](CallBase &CB) {
  1421. unsigned NumArgs = CB.arg_size();
  1422. if (NumArgs == 0)
  1423. return true;
  1424. if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
  1425. return false;
  1426. for (unsigned U = 1; U < NumArgs; ++U)
  1427. if (isa<Instruction>(CB.getArgOperand(U)))
  1428. return false;
  1429. return true;
  1430. };
  1431. if (!ReplVal) {
  1432. for (Use *U : *UV)
  1433. if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
  1434. if (!CanBeMoved(*CI))
  1435. continue;
  1436. // If the function is a kernel, dedup will move
  1437. // the runtime call right after the kernel init callsite. Otherwise,
  1438. // it will move it to the beginning of the caller function.
  1439. if (isKernel(F)) {
  1440. auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
  1441. auto *KernelInitUV = KernelInitRFI.getUseVector(F);
  1442. if (KernelInitUV->empty())
  1443. continue;
  1444. assert(KernelInitUV->size() == 1 &&
  1445. "Expected a single __kmpc_target_init in kernel\n");
  1446. CallInst *KernelInitCI =
  1447. getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
  1448. assert(KernelInitCI &&
  1449. "Expected a call to __kmpc_target_init in kernel\n");
  1450. CI->moveAfter(KernelInitCI);
  1451. } else
  1452. CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
  1453. ReplVal = CI;
  1454. break;
  1455. }
  1456. if (!ReplVal)
  1457. return false;
  1458. }
  1459. // If we use a call as a replacement value we need to make sure the ident is
  1460. // valid at the new location. For now we just pick a global one, either
  1461. // existing and used by one of the calls, or created from scratch.
  1462. if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
  1463. if (!CI->arg_empty() &&
  1464. CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
  1465. Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
  1466. /* GlobalOnly */ true);
  1467. CI->setArgOperand(0, Ident);
  1468. }
  1469. }
  1470. bool Changed = false;
  1471. auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
  1472. CallInst *CI = getCallIfRegularCall(U, &RFI);
  1473. if (!CI || CI == ReplVal || &F != &Caller)
  1474. return false;
  1475. assert(CI->getCaller() == &F && "Unexpected call!");
  1476. auto Remark = [&](OptimizationRemark OR) {
  1477. return OR << "OpenMP runtime call "
  1478. << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
  1479. };
  1480. if (CI->getDebugLoc())
  1481. emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
  1482. else
  1483. emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
  1484. CGUpdater.removeCallSite(*CI);
  1485. CI->replaceAllUsesWith(ReplVal);
  1486. CI->eraseFromParent();
  1487. ++NumOpenMPRuntimeCallsDeduplicated;
  1488. Changed = true;
  1489. return true;
  1490. };
  1491. RFI.foreachUse(SCC, ReplaceAndDeleteCB);
  1492. return Changed;
  1493. }
  1494. /// Collect arguments that represent the global thread id in \p GTIdArgs.
  1495. void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
  1496. // TODO: Below we basically perform a fixpoint iteration with a pessimistic
  1497. // initialization. We could define an AbstractAttribute instead and
  1498. // run the Attributor here once it can be run as an SCC pass.
  1499. // Helper to check the argument \p ArgNo at all call sites of \p F for
  1500. // a GTId.
  1501. auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
  1502. if (!F.hasLocalLinkage())
  1503. return false;
  1504. for (Use &U : F.uses()) {
  1505. if (CallInst *CI = getCallIfRegularCall(U)) {
  1506. Value *ArgOp = CI->getArgOperand(ArgNo);
  1507. if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
  1508. getCallIfRegularCall(
  1509. *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
  1510. continue;
  1511. }
  1512. return false;
  1513. }
  1514. return true;
  1515. };
  1516. // Helper to identify uses of a GTId as GTId arguments.
  1517. auto AddUserArgs = [&](Value &GTId) {
  1518. for (Use &U : GTId.uses())
  1519. if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
  1520. if (CI->isArgOperand(&U))
  1521. if (Function *Callee = CI->getCalledFunction())
  1522. if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
  1523. GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
  1524. };
  1525. // The argument users of __kmpc_global_thread_num calls are GTIds.
  1526. OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
  1527. OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
  1528. GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
  1529. if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
  1530. AddUserArgs(*CI);
  1531. return false;
  1532. });
  1533. // Transitively search for more arguments by looking at the users of the
  1534. // ones we know already. During the search the GTIdArgs vector is extended
  1535. // so we cannot cache the size nor can we use a range based for.
  1536. for (unsigned U = 0; U < GTIdArgs.size(); ++U)
  1537. AddUserArgs(*GTIdArgs[U]);
  1538. }
  1539. /// Kernel (=GPU) optimizations and utility functions
  1540. ///
  1541. ///{{
  1542. /// Check if \p F is a kernel, hence entry point for target offloading.
  1543. bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
  1544. /// Cache to remember the unique kernel for a function.
  1545. DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;
  1546. /// Find the unique kernel that will execute \p F, if any.
  1547. Kernel getUniqueKernelFor(Function &F);
  1548. /// Find the unique kernel that will execute \p I, if any.
  1549. Kernel getUniqueKernelFor(Instruction &I) {
  1550. return getUniqueKernelFor(*I.getFunction());
  1551. }
  1552. /// Rewrite the device (=GPU) code state machine create in non-SPMD mode in
  1553. /// the cases we can avoid taking the address of a function.
  1554. bool rewriteDeviceCodeStateMachine();
  1555. ///
  1556. ///}}
  1557. /// Emit a remark generically
  1558. ///
  1559. /// This template function can be used to generically emit a remark. The
  1560. /// RemarkKind should be one of the following:
  1561. /// - OptimizationRemark to indicate a successful optimization attempt
  1562. /// - OptimizationRemarkMissed to report a failed optimization attempt
  1563. /// - OptimizationRemarkAnalysis to provide additional information about an
  1564. /// optimization attempt
  1565. ///
  1566. /// The remark is built using a callback function provided by the caller that
  1567. /// takes a RemarkKind as input and returns a RemarkKind.
  1568. template <typename RemarkKind, typename RemarkCallBack>
  1569. void emitRemark(Instruction *I, StringRef RemarkName,
  1570. RemarkCallBack &&RemarkCB) const {
  1571. Function *F = I->getParent()->getParent();
  1572. auto &ORE = OREGetter(F);
  1573. if (RemarkName.startswith("OMP"))
  1574. ORE.emit([&]() {
  1575. return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
  1576. << " [" << RemarkName << "]";
  1577. });
  1578. else
  1579. ORE.emit(
  1580. [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
  1581. }
  1582. /// Emit a remark on a function.
  1583. template <typename RemarkKind, typename RemarkCallBack>
  1584. void emitRemark(Function *F, StringRef RemarkName,
  1585. RemarkCallBack &&RemarkCB) const {
  1586. auto &ORE = OREGetter(F);
  1587. if (RemarkName.startswith("OMP"))
  1588. ORE.emit([&]() {
  1589. return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
  1590. << " [" << RemarkName << "]";
  1591. });
  1592. else
  1593. ORE.emit(
  1594. [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
  1595. }
  1596. /// The underlying module.
  1597. Module &M;
  1598. /// The SCC we are operating on.
  1599. SmallVectorImpl<Function *> &SCC;
  1600. /// Callback to update the call graph, the first argument is a removed call,
  1601. /// the second an optional replacement call.
  1602. CallGraphUpdater &CGUpdater;
  1603. /// Callback to get an OptimizationRemarkEmitter from a Function *
  1604. OptimizationRemarkGetter OREGetter;
  1605. /// OpenMP-specific information cache. Also Used for Attributor runs.
  1606. OMPInformationCache &OMPInfoCache;
  1607. /// Attributor instance.
  1608. Attributor &A;
  1609. /// Helper function to run Attributor on SCC.
  1610. bool runAttributor(bool IsModulePass) {
  1611. if (SCC.empty())
  1612. return false;
  1613. registerAAs(IsModulePass);
  1614. ChangeStatus Changed = A.run();
  1615. LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
  1616. << " functions, result: " << Changed << ".\n");
  1617. return Changed == ChangeStatus::CHANGED;
  1618. }
  1619. void registerFoldRuntimeCall(RuntimeFunction RF);
  1620. /// Populate the Attributor with abstract attribute opportunities in the
  1621. /// functions.
  1622. void registerAAs(bool IsModulePass);
  1623. public:
  1624. /// Callback to register AAs for live functions, including internal functions
  1625. /// marked live during the traversal.
  1626. static void registerAAsForFunction(Attributor &A, const Function &F);
  1627. };
  1628. Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
  1629. if (!OMPInfoCache.ModuleSlice.empty() && !OMPInfoCache.ModuleSlice.count(&F))
  1630. return nullptr;
  1631. // Use a scope to keep the lifetime of the CachedKernel short.
  1632. {
  1633. std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
  1634. if (CachedKernel)
  1635. return *CachedKernel;
  1636. // TODO: We should use an AA to create an (optimistic and callback
  1637. // call-aware) call graph. For now we stick to simple patterns that
  1638. // are less powerful, basically the worst fixpoint.
  1639. if (isKernel(F)) {
  1640. CachedKernel = Kernel(&F);
  1641. return *CachedKernel;
  1642. }
  1643. CachedKernel = nullptr;
  1644. if (!F.hasLocalLinkage()) {
  1645. // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
  1646. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  1647. return ORA << "Potentially unknown OpenMP target region caller.";
  1648. };
  1649. emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
  1650. return nullptr;
  1651. }
  1652. }
  1653. auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
  1654. if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
  1655. // Allow use in equality comparisons.
  1656. if (Cmp->isEquality())
  1657. return getUniqueKernelFor(*Cmp);
  1658. return nullptr;
  1659. }
  1660. if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
  1661. // Allow direct calls.
  1662. if (CB->isCallee(&U))
  1663. return getUniqueKernelFor(*CB);
  1664. OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
  1665. OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
  1666. // Allow the use in __kmpc_parallel_51 calls.
  1667. if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
  1668. return getUniqueKernelFor(*CB);
  1669. return nullptr;
  1670. }
  1671. // Disallow every other use.
  1672. return nullptr;
  1673. };
  1674. // TODO: In the future we want to track more than just a unique kernel.
  1675. SmallPtrSet<Kernel, 2> PotentialKernels;
  1676. OMPInformationCache::foreachUse(F, [&](const Use &U) {
  1677. PotentialKernels.insert(GetUniqueKernelForUse(U));
  1678. });
  1679. Kernel K = nullptr;
  1680. if (PotentialKernels.size() == 1)
  1681. K = *PotentialKernels.begin();
  1682. // Cache the result.
  1683. UniqueKernelMap[&F] = K;
  1684. return K;
  1685. }
  1686. bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
  1687. OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
  1688. OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
  1689. bool Changed = false;
  1690. if (!KernelParallelRFI)
  1691. return Changed;
  1692. // If we have disabled state machine changes, exit
  1693. if (DisableOpenMPOptStateMachineRewrite)
  1694. return Changed;
  1695. for (Function *F : SCC) {
  1696. // Check if the function is a use in a __kmpc_parallel_51 call at
  1697. // all.
  1698. bool UnknownUse = false;
  1699. bool KernelParallelUse = false;
  1700. unsigned NumDirectCalls = 0;
  1701. SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
  1702. OMPInformationCache::foreachUse(*F, [&](Use &U) {
  1703. if (auto *CB = dyn_cast<CallBase>(U.getUser()))
  1704. if (CB->isCallee(&U)) {
  1705. ++NumDirectCalls;
  1706. return;
  1707. }
  1708. if (isa<ICmpInst>(U.getUser())) {
  1709. ToBeReplacedStateMachineUses.push_back(&U);
  1710. return;
  1711. }
  1712. // Find wrapper functions that represent parallel kernels.
  1713. CallInst *CI =
  1714. OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
  1715. const unsigned int WrapperFunctionArgNo = 6;
  1716. if (!KernelParallelUse && CI &&
  1717. CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
  1718. KernelParallelUse = true;
  1719. ToBeReplacedStateMachineUses.push_back(&U);
  1720. return;
  1721. }
  1722. UnknownUse = true;
  1723. });
  1724. // Do not emit a remark if we haven't seen a __kmpc_parallel_51
  1725. // use.
  1726. if (!KernelParallelUse)
  1727. continue;
  1728. // If this ever hits, we should investigate.
  1729. // TODO: Checking the number of uses is not a necessary restriction and
  1730. // should be lifted.
  1731. if (UnknownUse || NumDirectCalls != 1 ||
  1732. ToBeReplacedStateMachineUses.size() > 2) {
  1733. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  1734. return ORA << "Parallel region is used in "
  1735. << (UnknownUse ? "unknown" : "unexpected")
  1736. << " ways. Will not attempt to rewrite the state machine.";
  1737. };
  1738. emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
  1739. continue;
  1740. }
  1741. // Even if we have __kmpc_parallel_51 calls, we (for now) give
  1742. // up if the function is not called from a unique kernel.
  1743. Kernel K = getUniqueKernelFor(*F);
  1744. if (!K) {
  1745. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  1746. return ORA << "Parallel region is not called from a unique kernel. "
  1747. "Will not attempt to rewrite the state machine.";
  1748. };
  1749. emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
  1750. continue;
  1751. }
  1752. // We now know F is a parallel body function called only from the kernel K.
  1753. // We also identified the state machine uses in which we replace the
  1754. // function pointer by a new global symbol for identification purposes. This
  1755. // ensures only direct calls to the function are left.
  1756. Module &M = *F->getParent();
  1757. Type *Int8Ty = Type::getInt8Ty(M.getContext());
  1758. auto *ID = new GlobalVariable(
  1759. M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
  1760. UndefValue::get(Int8Ty), F->getName() + ".ID");
  1761. for (Use *U : ToBeReplacedStateMachineUses)
  1762. U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
  1763. ID, U->get()->getType()));
  1764. ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
  1765. Changed = true;
  1766. }
  1767. return Changed;
  1768. }
  1769. /// Abstract Attribute for tracking ICV values.
  1770. struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
  1771. using Base = StateWrapper<BooleanState, AbstractAttribute>;
  1772. AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
  1773. void initialize(Attributor &A) override {
  1774. Function *F = getAnchorScope();
  1775. if (!F || !A.isFunctionIPOAmendable(*F))
  1776. indicatePessimisticFixpoint();
  1777. }
  1778. /// Returns true if value is assumed to be tracked.
  1779. bool isAssumedTracked() const { return getAssumed(); }
  1780. /// Returns true if value is known to be tracked.
  1781. bool isKnownTracked() const { return getAssumed(); }
  1782. /// Create an abstract attribute biew for the position \p IRP.
  1783. static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
  1784. /// Return the value with which \p I can be replaced for specific \p ICV.
  1785. virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,
  1786. const Instruction *I,
  1787. Attributor &A) const {
  1788. return std::nullopt;
  1789. }
  1790. /// Return an assumed unique ICV value if a single candidate is found. If
  1791. /// there cannot be one, return a nullptr. If it is not clear yet, return
  1792. /// std::nullopt.
  1793. virtual std::optional<Value *>
  1794. getUniqueReplacementValue(InternalControlVar ICV) const = 0;
  1795. // Currently only nthreads is being tracked.
  1796. // this array will only grow with time.
  1797. InternalControlVar TrackableICVs[1] = {ICV_nthreads};
  1798. /// See AbstractAttribute::getName()
  1799. const std::string getName() const override { return "AAICVTracker"; }
  1800. /// See AbstractAttribute::getIdAddr()
  1801. const char *getIdAddr() const override { return &ID; }
  1802. /// This function should return true if the type of the \p AA is AAICVTracker
  1803. static bool classof(const AbstractAttribute *AA) {
  1804. return (AA->getIdAddr() == &ID);
  1805. }
  1806. static const char ID;
  1807. };
  1808. struct AAICVTrackerFunction : public AAICVTracker {
  1809. AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
  1810. : AAICVTracker(IRP, A) {}
  1811. // FIXME: come up with better string.
  1812. const std::string getAsStr() const override { return "ICVTrackerFunction"; }
  1813. // FIXME: come up with some stats.
  1814. void trackStatistics() const override {}
  1815. /// We don't manifest anything for this AA.
  1816. ChangeStatus manifest(Attributor &A) override {
  1817. return ChangeStatus::UNCHANGED;
  1818. }
  1819. // Map of ICV to their values at specific program point.
  1820. EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
  1821. InternalControlVar::ICV___last>
  1822. ICVReplacementValuesMap;
  1823. ChangeStatus updateImpl(Attributor &A) override {
  1824. ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
  1825. Function *F = getAnchorScope();
  1826. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  1827. for (InternalControlVar ICV : TrackableICVs) {
  1828. auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
  1829. auto &ValuesMap = ICVReplacementValuesMap[ICV];
  1830. auto TrackValues = [&](Use &U, Function &) {
  1831. CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
  1832. if (!CI)
  1833. return false;
  1834. // FIXME: handle setters with more that 1 arguments.
  1835. /// Track new value.
  1836. if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
  1837. HasChanged = ChangeStatus::CHANGED;
  1838. return false;
  1839. };
  1840. auto CallCheck = [&](Instruction &I) {
  1841. std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);
  1842. if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
  1843. HasChanged = ChangeStatus::CHANGED;
  1844. return true;
  1845. };
  1846. // Track all changes of an ICV.
  1847. SetterRFI.foreachUse(TrackValues, F);
  1848. bool UsedAssumedInformation = false;
  1849. A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
  1850. UsedAssumedInformation,
  1851. /* CheckBBLivenessOnly */ true);
  1852. /// TODO: Figure out a way to avoid adding entry in
  1853. /// ICVReplacementValuesMap
  1854. Instruction *Entry = &F->getEntryBlock().front();
  1855. if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
  1856. ValuesMap.insert(std::make_pair(Entry, nullptr));
  1857. }
  1858. return HasChanged;
  1859. }
  1860. /// Helper to check if \p I is a call and get the value for it if it is
  1861. /// unique.
  1862. std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,
  1863. InternalControlVar &ICV) const {
  1864. const auto *CB = dyn_cast<CallBase>(&I);
  1865. if (!CB || CB->hasFnAttr("no_openmp") ||
  1866. CB->hasFnAttr("no_openmp_routines"))
  1867. return std::nullopt;
  1868. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  1869. auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
  1870. auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
  1871. Function *CalledFunction = CB->getCalledFunction();
  1872. // Indirect call, assume ICV changes.
  1873. if (CalledFunction == nullptr)
  1874. return nullptr;
  1875. if (CalledFunction == GetterRFI.Declaration)
  1876. return std::nullopt;
  1877. if (CalledFunction == SetterRFI.Declaration) {
  1878. if (ICVReplacementValuesMap[ICV].count(&I))
  1879. return ICVReplacementValuesMap[ICV].lookup(&I);
  1880. return nullptr;
  1881. }
  1882. // Since we don't know, assume it changes the ICV.
  1883. if (CalledFunction->isDeclaration())
  1884. return nullptr;
  1885. const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
  1886. *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
  1887. if (ICVTrackingAA.isAssumedTracked()) {
  1888. std::optional<Value *> URV = ICVTrackingAA.getUniqueReplacementValue(ICV);
  1889. if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),
  1890. OMPInfoCache)))
  1891. return URV;
  1892. }
  1893. // If we don't know, assume it changes.
  1894. return nullptr;
  1895. }
  1896. // We don't check unique value for a function, so return std::nullopt.
  1897. std::optional<Value *>
  1898. getUniqueReplacementValue(InternalControlVar ICV) const override {
  1899. return std::nullopt;
  1900. }
  1901. /// Return the value with which \p I can be replaced for specific \p ICV.
  1902. std::optional<Value *> getReplacementValue(InternalControlVar ICV,
  1903. const Instruction *I,
  1904. Attributor &A) const override {
  1905. const auto &ValuesMap = ICVReplacementValuesMap[ICV];
  1906. if (ValuesMap.count(I))
  1907. return ValuesMap.lookup(I);
  1908. SmallVector<const Instruction *, 16> Worklist;
  1909. SmallPtrSet<const Instruction *, 16> Visited;
  1910. Worklist.push_back(I);
  1911. std::optional<Value *> ReplVal;
  1912. while (!Worklist.empty()) {
  1913. const Instruction *CurrInst = Worklist.pop_back_val();
  1914. if (!Visited.insert(CurrInst).second)
  1915. continue;
  1916. const BasicBlock *CurrBB = CurrInst->getParent();
  1917. // Go up and look for all potential setters/calls that might change the
  1918. // ICV.
  1919. while ((CurrInst = CurrInst->getPrevNode())) {
  1920. if (ValuesMap.count(CurrInst)) {
  1921. std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
  1922. // Unknown value, track new.
  1923. if (!ReplVal) {
  1924. ReplVal = NewReplVal;
  1925. break;
  1926. }
  1927. // If we found a new value, we can't know the icv value anymore.
  1928. if (NewReplVal)
  1929. if (ReplVal != NewReplVal)
  1930. return nullptr;
  1931. break;
  1932. }
  1933. std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);
  1934. if (!NewReplVal)
  1935. continue;
  1936. // Unknown value, track new.
  1937. if (!ReplVal) {
  1938. ReplVal = NewReplVal;
  1939. break;
  1940. }
  1941. // if (NewReplVal.hasValue())
  1942. // We found a new value, we can't know the icv value anymore.
  1943. if (ReplVal != NewReplVal)
  1944. return nullptr;
  1945. }
  1946. // If we are in the same BB and we have a value, we are done.
  1947. if (CurrBB == I->getParent() && ReplVal)
  1948. return ReplVal;
  1949. // Go through all predecessors and add terminators for analysis.
  1950. for (const BasicBlock *Pred : predecessors(CurrBB))
  1951. if (const Instruction *Terminator = Pred->getTerminator())
  1952. Worklist.push_back(Terminator);
  1953. }
  1954. return ReplVal;
  1955. }
  1956. };
  1957. struct AAICVTrackerFunctionReturned : AAICVTracker {
  1958. AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
  1959. : AAICVTracker(IRP, A) {}
  1960. // FIXME: come up with better string.
  1961. const std::string getAsStr() const override {
  1962. return "ICVTrackerFunctionReturned";
  1963. }
  1964. // FIXME: come up with some stats.
  1965. void trackStatistics() const override {}
  1966. /// We don't manifest anything for this AA.
  1967. ChangeStatus manifest(Attributor &A) override {
  1968. return ChangeStatus::UNCHANGED;
  1969. }
  1970. // Map of ICV to their values at specific program point.
  1971. EnumeratedArray<std::optional<Value *>, InternalControlVar,
  1972. InternalControlVar::ICV___last>
  1973. ICVReplacementValuesMap;
  1974. /// Return the value with which \p I can be replaced for specific \p ICV.
  1975. std::optional<Value *>
  1976. getUniqueReplacementValue(InternalControlVar ICV) const override {
  1977. return ICVReplacementValuesMap[ICV];
  1978. }
  1979. ChangeStatus updateImpl(Attributor &A) override {
  1980. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  1981. const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
  1982. *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
  1983. if (!ICVTrackingAA.isAssumedTracked())
  1984. return indicatePessimisticFixpoint();
  1985. for (InternalControlVar ICV : TrackableICVs) {
  1986. std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
  1987. std::optional<Value *> UniqueICVValue;
  1988. auto CheckReturnInst = [&](Instruction &I) {
  1989. std::optional<Value *> NewReplVal =
  1990. ICVTrackingAA.getReplacementValue(ICV, &I, A);
  1991. // If we found a second ICV value there is no unique returned value.
  1992. if (UniqueICVValue && UniqueICVValue != NewReplVal)
  1993. return false;
  1994. UniqueICVValue = NewReplVal;
  1995. return true;
  1996. };
  1997. bool UsedAssumedInformation = false;
  1998. if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
  1999. UsedAssumedInformation,
  2000. /* CheckBBLivenessOnly */ true))
  2001. UniqueICVValue = nullptr;
  2002. if (UniqueICVValue == ReplVal)
  2003. continue;
  2004. ReplVal = UniqueICVValue;
  2005. Changed = ChangeStatus::CHANGED;
  2006. }
  2007. return Changed;
  2008. }
  2009. };
  2010. struct AAICVTrackerCallSite : AAICVTracker {
  2011. AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
  2012. : AAICVTracker(IRP, A) {}
  2013. void initialize(Attributor &A) override {
  2014. Function *F = getAnchorScope();
  2015. if (!F || !A.isFunctionIPOAmendable(*F))
  2016. indicatePessimisticFixpoint();
  2017. // We only initialize this AA for getters, so we need to know which ICV it
  2018. // gets.
  2019. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2020. for (InternalControlVar ICV : TrackableICVs) {
  2021. auto ICVInfo = OMPInfoCache.ICVs[ICV];
  2022. auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
  2023. if (Getter.Declaration == getAssociatedFunction()) {
  2024. AssociatedICV = ICVInfo.Kind;
  2025. return;
  2026. }
  2027. }
  2028. /// Unknown ICV.
  2029. indicatePessimisticFixpoint();
  2030. }
  2031. ChangeStatus manifest(Attributor &A) override {
  2032. if (!ReplVal || !*ReplVal)
  2033. return ChangeStatus::UNCHANGED;
  2034. A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);
  2035. A.deleteAfterManifest(*getCtxI());
  2036. return ChangeStatus::CHANGED;
  2037. }
  2038. // FIXME: come up with better string.
  2039. const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
  2040. // FIXME: come up with some stats.
  2041. void trackStatistics() const override {}
  2042. InternalControlVar AssociatedICV;
  2043. std::optional<Value *> ReplVal;
  2044. ChangeStatus updateImpl(Attributor &A) override {
  2045. const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
  2046. *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
  2047. // We don't have any information, so we assume it changes the ICV.
  2048. if (!ICVTrackingAA.isAssumedTracked())
  2049. return indicatePessimisticFixpoint();
  2050. std::optional<Value *> NewReplVal =
  2051. ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
  2052. if (ReplVal == NewReplVal)
  2053. return ChangeStatus::UNCHANGED;
  2054. ReplVal = NewReplVal;
  2055. return ChangeStatus::CHANGED;
  2056. }
  2057. // Return the value with which associated value can be replaced for specific
  2058. // \p ICV.
  2059. std::optional<Value *>
  2060. getUniqueReplacementValue(InternalControlVar ICV) const override {
  2061. return ReplVal;
  2062. }
  2063. };
  2064. struct AAICVTrackerCallSiteReturned : AAICVTracker {
  2065. AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
  2066. : AAICVTracker(IRP, A) {}
  2067. // FIXME: come up with better string.
  2068. const std::string getAsStr() const override {
  2069. return "ICVTrackerCallSiteReturned";
  2070. }
  2071. // FIXME: come up with some stats.
  2072. void trackStatistics() const override {}
  2073. /// We don't manifest anything for this AA.
  2074. ChangeStatus manifest(Attributor &A) override {
  2075. return ChangeStatus::UNCHANGED;
  2076. }
  2077. // Map of ICV to their values at specific program point.
  2078. EnumeratedArray<std::optional<Value *>, InternalControlVar,
  2079. InternalControlVar::ICV___last>
  2080. ICVReplacementValuesMap;
  2081. /// Return the value with which associated value can be replaced for specific
  2082. /// \p ICV.
  2083. std::optional<Value *>
  2084. getUniqueReplacementValue(InternalControlVar ICV) const override {
  2085. return ICVReplacementValuesMap[ICV];
  2086. }
  2087. ChangeStatus updateImpl(Attributor &A) override {
  2088. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  2089. const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
  2090. *this, IRPosition::returned(*getAssociatedFunction()),
  2091. DepClassTy::REQUIRED);
  2092. // We don't have any information, so we assume it changes the ICV.
  2093. if (!ICVTrackingAA.isAssumedTracked())
  2094. return indicatePessimisticFixpoint();
  2095. for (InternalControlVar ICV : TrackableICVs) {
  2096. std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
  2097. std::optional<Value *> NewReplVal =
  2098. ICVTrackingAA.getUniqueReplacementValue(ICV);
  2099. if (ReplVal == NewReplVal)
  2100. continue;
  2101. ReplVal = NewReplVal;
  2102. Changed = ChangeStatus::CHANGED;
  2103. }
  2104. return Changed;
  2105. }
  2106. };
  2107. struct AAExecutionDomainFunction : public AAExecutionDomain {
  2108. AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
  2109. : AAExecutionDomain(IRP, A) {}
  2110. ~AAExecutionDomainFunction() {
  2111. delete RPOT;
  2112. }
  2113. void initialize(Attributor &A) override {
  2114. if (getAnchorScope()->isDeclaration()) {
  2115. indicatePessimisticFixpoint();
  2116. return;
  2117. }
  2118. RPOT = new ReversePostOrderTraversal<Function *>(getAnchorScope());
  2119. }
  2120. const std::string getAsStr() const override {
  2121. unsigned TotalBlocks = 0, InitialThreadBlocks = 0;
  2122. for (auto &It : BEDMap) {
  2123. TotalBlocks++;
  2124. InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;
  2125. }
  2126. return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +
  2127. std::to_string(TotalBlocks) + " executed by initial thread only";
  2128. }
  2129. /// See AbstractAttribute::trackStatistics().
  2130. void trackStatistics() const override {}
  2131. ChangeStatus manifest(Attributor &A) override {
  2132. LLVM_DEBUG({
  2133. for (const BasicBlock &BB : *getAnchorScope()) {
  2134. if (!isExecutedByInitialThreadOnly(BB))
  2135. continue;
  2136. dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
  2137. << BB.getName() << " is executed by a single thread.\n";
  2138. }
  2139. });
  2140. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  2141. if (DisableOpenMPOptBarrierElimination)
  2142. return Changed;
  2143. SmallPtrSet<CallBase *, 16> DeletedBarriers;
  2144. auto HandleAlignedBarrier = [&](CallBase *CB) {
  2145. const ExecutionDomainTy &ED = CEDMap[CB];
  2146. if (!ED.IsReachedFromAlignedBarrierOnly ||
  2147. ED.EncounteredNonLocalSideEffect)
  2148. return;
  2149. // We can remove this barrier, if it is one, or all aligned barriers
  2150. // reaching the kernel end. In the latter case we can transitively work
  2151. // our way back until we find a barrier that guards a side-effect if we
  2152. // are dealing with the kernel end here.
  2153. if (CB) {
  2154. DeletedBarriers.insert(CB);
  2155. A.deleteAfterManifest(*CB);
  2156. ++NumBarriersEliminated;
  2157. Changed = ChangeStatus::CHANGED;
  2158. } else if (!ED.AlignedBarriers.empty()) {
  2159. NumBarriersEliminated += ED.AlignedBarriers.size();
  2160. Changed = ChangeStatus::CHANGED;
  2161. SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),
  2162. ED.AlignedBarriers.end());
  2163. SmallSetVector<CallBase *, 16> Visited;
  2164. while (!Worklist.empty()) {
  2165. CallBase *LastCB = Worklist.pop_back_val();
  2166. if (!Visited.insert(LastCB))
  2167. continue;
  2168. if (!DeletedBarriers.count(LastCB)) {
  2169. A.deleteAfterManifest(*LastCB);
  2170. continue;
  2171. }
  2172. // The final aligned barrier (LastCB) reaching the kernel end was
  2173. // removed already. This means we can go one step further and remove
  2174. // the barriers encoutered last before (LastCB).
  2175. const ExecutionDomainTy &LastED = CEDMap[LastCB];
  2176. Worklist.append(LastED.AlignedBarriers.begin(),
  2177. LastED.AlignedBarriers.end());
  2178. }
  2179. }
  2180. // If we actually eliminated a barrier we need to eliminate the associated
  2181. // llvm.assumes as well to avoid creating UB.
  2182. if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))
  2183. for (auto *AssumeCB : ED.EncounteredAssumes)
  2184. A.deleteAfterManifest(*AssumeCB);
  2185. };
  2186. for (auto *CB : AlignedBarriers)
  2187. HandleAlignedBarrier(CB);
  2188. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2189. // Handle the "kernel end barrier" for kernels too.
  2190. if (OMPInfoCache.Kernels.count(getAnchorScope()))
  2191. HandleAlignedBarrier(nullptr);
  2192. return Changed;
  2193. }
  2194. /// Merge barrier and assumption information from \p PredED into the successor
  2195. /// \p ED.
  2196. void
  2197. mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,
  2198. const ExecutionDomainTy &PredED);
  2199. /// Merge all information from \p PredED into the successor \p ED. If
  2200. /// \p InitialEdgeOnly is set, only the initial edge will enter the block
  2201. /// represented by \p ED from this predecessor.
  2202. void mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,
  2203. const ExecutionDomainTy &PredED,
  2204. bool InitialEdgeOnly = false);
  2205. /// Accumulate information for the entry block in \p EntryBBED.
  2206. void handleEntryBB(Attributor &A, ExecutionDomainTy &EntryBBED);
  2207. /// See AbstractAttribute::updateImpl.
  2208. ChangeStatus updateImpl(Attributor &A) override;
  2209. /// Query interface, see AAExecutionDomain
  2210. ///{
  2211. bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
  2212. if (!isValidState())
  2213. return false;
  2214. return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;
  2215. }
  2216. bool isExecutedInAlignedRegion(Attributor &A,
  2217. const Instruction &I) const override {
  2218. assert(I.getFunction() == getAnchorScope() &&
  2219. "Instruction is out of scope!");
  2220. if (!isValidState())
  2221. return false;
  2222. const Instruction *CurI;
  2223. // Check forward until a call or the block end is reached.
  2224. CurI = &I;
  2225. do {
  2226. auto *CB = dyn_cast<CallBase>(CurI);
  2227. if (!CB)
  2228. continue;
  2229. if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
  2230. break;
  2231. }
  2232. const auto &It = CEDMap.find(CB);
  2233. if (It == CEDMap.end())
  2234. continue;
  2235. if (!It->getSecond().IsReachingAlignedBarrierOnly)
  2236. return false;
  2237. break;
  2238. } while ((CurI = CurI->getNextNonDebugInstruction()));
  2239. if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)
  2240. return false;
  2241. // Check backward until a call or the block beginning is reached.
  2242. CurI = &I;
  2243. do {
  2244. auto *CB = dyn_cast<CallBase>(CurI);
  2245. if (!CB)
  2246. continue;
  2247. if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB))) {
  2248. break;
  2249. }
  2250. const auto &It = CEDMap.find(CB);
  2251. if (It == CEDMap.end())
  2252. continue;
  2253. if (!AA::isNoSyncInst(A, *CB, *this)) {
  2254. if (It->getSecond().IsReachedFromAlignedBarrierOnly) {
  2255. break;
  2256. }
  2257. return false;
  2258. }
  2259. Function *Callee = CB->getCalledFunction();
  2260. if (!Callee || Callee->isDeclaration())
  2261. return false;
  2262. const auto &EDAA = A.getAAFor<AAExecutionDomain>(
  2263. *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
  2264. if (!EDAA.getState().isValidState())
  2265. return false;
  2266. if (!EDAA.getFunctionExecutionDomain().IsReachedFromAlignedBarrierOnly)
  2267. return false;
  2268. break;
  2269. } while ((CurI = CurI->getPrevNonDebugInstruction()));
  2270. if (!CurI &&
  2271. !llvm::all_of(
  2272. predecessors(I.getParent()), [&](const BasicBlock *PredBB) {
  2273. return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;
  2274. })) {
  2275. return false;
  2276. }
  2277. // On neither traversal we found a anything but aligned barriers.
  2278. return true;
  2279. }
  2280. ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {
  2281. assert(isValidState() &&
  2282. "No request should be made against an invalid state!");
  2283. return BEDMap.lookup(&BB);
  2284. }
  2285. ExecutionDomainTy getExecutionDomain(const CallBase &CB) const override {
  2286. assert(isValidState() &&
  2287. "No request should be made against an invalid state!");
  2288. return CEDMap.lookup(&CB);
  2289. }
  2290. ExecutionDomainTy getFunctionExecutionDomain() const override {
  2291. assert(isValidState() &&
  2292. "No request should be made against an invalid state!");
  2293. return BEDMap.lookup(nullptr);
  2294. }
  2295. ///}
  2296. // Check if the edge into the successor block contains a condition that only
  2297. // lets the main thread execute it.
  2298. static bool isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,
  2299. BasicBlock &SuccessorBB) {
  2300. if (!Edge || !Edge->isConditional())
  2301. return false;
  2302. if (Edge->getSuccessor(0) != &SuccessorBB)
  2303. return false;
  2304. auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
  2305. if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
  2306. return false;
  2307. ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
  2308. if (!C)
  2309. return false;
  2310. // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
  2311. if (C->isAllOnesValue()) {
  2312. auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
  2313. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2314. auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
  2315. CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
  2316. if (!CB)
  2317. return false;
  2318. const int InitModeArgNo = 1;
  2319. auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
  2320. return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
  2321. }
  2322. if (C->isZero()) {
  2323. // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
  2324. if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
  2325. if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
  2326. return true;
  2327. // Match: 0 == llvm.amdgcn.workitem.id.x()
  2328. if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
  2329. if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
  2330. return true;
  2331. }
  2332. return false;
  2333. };
  2334. /// Mapping containing information per block.
  2335. DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;
  2336. DenseMap<const CallBase *, ExecutionDomainTy> CEDMap;
  2337. SmallSetVector<CallBase *, 16> AlignedBarriers;
  2338. ReversePostOrderTraversal<Function *> *RPOT = nullptr;
  2339. };
  2340. void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(
  2341. Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {
  2342. for (auto *EA : PredED.EncounteredAssumes)
  2343. ED.addAssumeInst(A, *EA);
  2344. for (auto *AB : PredED.AlignedBarriers)
  2345. ED.addAlignedBarrier(A, *AB);
  2346. }
  2347. void AAExecutionDomainFunction::mergeInPredecessor(
  2348. Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,
  2349. bool InitialEdgeOnly) {
  2350. ED.IsExecutedByInitialThreadOnly =
  2351. InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&
  2352. ED.IsExecutedByInitialThreadOnly);
  2353. ED.IsReachedFromAlignedBarrierOnly = ED.IsReachedFromAlignedBarrierOnly &&
  2354. PredED.IsReachedFromAlignedBarrierOnly;
  2355. ED.EncounteredNonLocalSideEffect =
  2356. ED.EncounteredNonLocalSideEffect | PredED.EncounteredNonLocalSideEffect;
  2357. if (ED.IsReachedFromAlignedBarrierOnly)
  2358. mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);
  2359. else
  2360. ED.clearAssumeInstAndAlignedBarriers();
  2361. }
  2362. void AAExecutionDomainFunction::handleEntryBB(Attributor &A,
  2363. ExecutionDomainTy &EntryBBED) {
  2364. SmallVector<ExecutionDomainTy> PredExecDomains;
  2365. auto PredForCallSite = [&](AbstractCallSite ACS) {
  2366. const auto &EDAA = A.getAAFor<AAExecutionDomain>(
  2367. *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
  2368. DepClassTy::OPTIONAL);
  2369. if (!EDAA.getState().isValidState())
  2370. return false;
  2371. PredExecDomains.emplace_back(
  2372. EDAA.getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));
  2373. return true;
  2374. };
  2375. bool AllCallSitesKnown;
  2376. if (A.checkForAllCallSites(PredForCallSite, *this,
  2377. /* RequiresAllCallSites */ true,
  2378. AllCallSitesKnown)) {
  2379. for (const auto &PredED : PredExecDomains)
  2380. mergeInPredecessor(A, EntryBBED, PredED);
  2381. } else {
  2382. // We could not find all predecessors, so this is either a kernel or a
  2383. // function with external linkage (or with some other weird uses).
  2384. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2385. if (OMPInfoCache.Kernels.count(getAnchorScope())) {
  2386. EntryBBED.IsExecutedByInitialThreadOnly = false;
  2387. EntryBBED.IsReachedFromAlignedBarrierOnly = true;
  2388. EntryBBED.EncounteredNonLocalSideEffect = false;
  2389. } else {
  2390. EntryBBED.IsExecutedByInitialThreadOnly = false;
  2391. EntryBBED.IsReachedFromAlignedBarrierOnly = false;
  2392. EntryBBED.EncounteredNonLocalSideEffect = true;
  2393. }
  2394. }
  2395. auto &FnED = BEDMap[nullptr];
  2396. FnED.IsReachingAlignedBarrierOnly &=
  2397. EntryBBED.IsReachedFromAlignedBarrierOnly;
  2398. }
  2399. ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
  2400. bool Changed = false;
  2401. // Helper to deal with an aligned barrier encountered during the forward
  2402. // traversal. \p CB is the aligned barrier, \p ED is the execution domain when
  2403. // it was encountered.
  2404. auto HandleAlignedBarrier = [&](CallBase *CB, ExecutionDomainTy &ED) {
  2405. if (CB)
  2406. Changed |= AlignedBarriers.insert(CB);
  2407. // First, update the barrier ED kept in the separate CEDMap.
  2408. auto &CallED = CEDMap[CB];
  2409. mergeInPredecessor(A, CallED, ED);
  2410. // Next adjust the ED we use for the traversal.
  2411. ED.EncounteredNonLocalSideEffect = false;
  2412. ED.IsReachedFromAlignedBarrierOnly = true;
  2413. // Aligned barrier collection has to come last.
  2414. ED.clearAssumeInstAndAlignedBarriers();
  2415. if (CB)
  2416. ED.addAlignedBarrier(A, *CB);
  2417. };
  2418. auto &LivenessAA =
  2419. A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);
  2420. // Set \p R to \V and report true if that changed \p R.
  2421. auto SetAndRecord = [&](bool &R, bool V) {
  2422. bool Eq = (R == V);
  2423. R = V;
  2424. return !Eq;
  2425. };
  2426. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2427. Function *F = getAnchorScope();
  2428. BasicBlock &EntryBB = F->getEntryBlock();
  2429. bool IsKernel = OMPInfoCache.Kernels.count(F);
  2430. SmallVector<Instruction *> SyncInstWorklist;
  2431. for (auto &RIt : *RPOT) {
  2432. BasicBlock &BB = *RIt;
  2433. bool IsEntryBB = &BB == &EntryBB;
  2434. // TODO: We use local reasoning since we don't have a divergence analysis
  2435. // running as well. We could basically allow uniform branches here.
  2436. bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;
  2437. ExecutionDomainTy ED;
  2438. // Propagate "incoming edges" into information about this block.
  2439. if (IsEntryBB) {
  2440. handleEntryBB(A, ED);
  2441. } else {
  2442. // For live non-entry blocks we only propagate
  2443. // information via live edges.
  2444. if (LivenessAA.isAssumedDead(&BB))
  2445. continue;
  2446. for (auto *PredBB : predecessors(&BB)) {
  2447. if (LivenessAA.isEdgeDead(PredBB, &BB))
  2448. continue;
  2449. bool InitialEdgeOnly = isInitialThreadOnlyEdge(
  2450. A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);
  2451. mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);
  2452. }
  2453. }
  2454. // Now we traverse the block, accumulate effects in ED and attach
  2455. // information to calls.
  2456. for (Instruction &I : BB) {
  2457. bool UsedAssumedInformation;
  2458. if (A.isAssumedDead(I, *this, &LivenessAA, UsedAssumedInformation,
  2459. /* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,
  2460. /* CheckForDeadStore */ true))
  2461. continue;
  2462. // Asummes and "assume-like" (dbg, lifetime, ...) are handled first, the
  2463. // former is collected the latter is ignored.
  2464. if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
  2465. if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {
  2466. ED.addAssumeInst(A, *AI);
  2467. continue;
  2468. }
  2469. // TODO: Should we also collect and delete lifetime markers?
  2470. if (II->isAssumeLikeIntrinsic())
  2471. continue;
  2472. }
  2473. auto *CB = dyn_cast<CallBase>(&I);
  2474. bool IsNoSync = AA::isNoSyncInst(A, I, *this);
  2475. bool IsAlignedBarrier =
  2476. !IsNoSync && CB &&
  2477. AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);
  2478. AlignedBarrierLastInBlock &= IsNoSync;
  2479. // Next we check for calls. Aligned barriers are handled
  2480. // explicitly, everything else is kept for the backward traversal and will
  2481. // also affect our state.
  2482. if (CB) {
  2483. if (IsAlignedBarrier) {
  2484. HandleAlignedBarrier(CB, ED);
  2485. AlignedBarrierLastInBlock = true;
  2486. continue;
  2487. }
  2488. // Check the pointer(s) of a memory intrinsic explicitly.
  2489. if (isa<MemIntrinsic>(&I)) {
  2490. if (!ED.EncounteredNonLocalSideEffect &&
  2491. AA::isPotentiallyAffectedByBarrier(A, I, *this))
  2492. ED.EncounteredNonLocalSideEffect = true;
  2493. if (!IsNoSync) {
  2494. ED.IsReachedFromAlignedBarrierOnly = false;
  2495. SyncInstWorklist.push_back(&I);
  2496. }
  2497. continue;
  2498. }
  2499. // Record how we entered the call, then accumulate the effect of the
  2500. // call in ED for potential use by the callee.
  2501. auto &CallED = CEDMap[CB];
  2502. mergeInPredecessor(A, CallED, ED);
  2503. // If we have a sync-definition we can check if it starts/ends in an
  2504. // aligned barrier. If we are unsure we assume any sync breaks
  2505. // alignment.
  2506. Function *Callee = CB->getCalledFunction();
  2507. if (!IsNoSync && Callee && !Callee->isDeclaration()) {
  2508. const auto &EDAA = A.getAAFor<AAExecutionDomain>(
  2509. *this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);
  2510. if (EDAA.getState().isValidState()) {
  2511. const auto &CalleeED = EDAA.getFunctionExecutionDomain();
  2512. ED.IsReachedFromAlignedBarrierOnly =
  2513. CallED.IsReachedFromAlignedBarrierOnly =
  2514. CalleeED.IsReachedFromAlignedBarrierOnly;
  2515. AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;
  2516. if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)
  2517. ED.EncounteredNonLocalSideEffect |=
  2518. CalleeED.EncounteredNonLocalSideEffect;
  2519. else
  2520. ED.EncounteredNonLocalSideEffect =
  2521. CalleeED.EncounteredNonLocalSideEffect;
  2522. if (!CalleeED.IsReachingAlignedBarrierOnly)
  2523. SyncInstWorklist.push_back(&I);
  2524. if (CalleeED.IsReachedFromAlignedBarrierOnly)
  2525. mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);
  2526. continue;
  2527. }
  2528. }
  2529. if (!IsNoSync)
  2530. ED.IsReachedFromAlignedBarrierOnly =
  2531. CallED.IsReachedFromAlignedBarrierOnly = false;
  2532. AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;
  2533. ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();
  2534. if (!IsNoSync)
  2535. SyncInstWorklist.push_back(&I);
  2536. }
  2537. if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())
  2538. continue;
  2539. // If we have a callee we try to use fine-grained information to
  2540. // determine local side-effects.
  2541. if (CB) {
  2542. const auto &MemAA = A.getAAFor<AAMemoryLocation>(
  2543. *this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);
  2544. auto AccessPred = [&](const Instruction *I, const Value *Ptr,
  2545. AAMemoryLocation::AccessKind,
  2546. AAMemoryLocation::MemoryLocationsKind) {
  2547. return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);
  2548. };
  2549. if (MemAA.getState().isValidState() &&
  2550. MemAA.checkForAllAccessesToMemoryKind(
  2551. AccessPred, AAMemoryLocation::ALL_LOCATIONS))
  2552. continue;
  2553. }
  2554. if (!I.mayHaveSideEffects() && OMPInfoCache.isOnlyUsedByAssume(I))
  2555. continue;
  2556. if (auto *LI = dyn_cast<LoadInst>(&I))
  2557. if (LI->hasMetadata(LLVMContext::MD_invariant_load))
  2558. continue;
  2559. if (!ED.EncounteredNonLocalSideEffect &&
  2560. AA::isPotentiallyAffectedByBarrier(A, I, *this))
  2561. ED.EncounteredNonLocalSideEffect = true;
  2562. }
  2563. if (!isa<UnreachableInst>(BB.getTerminator()) &&
  2564. !BB.getTerminator()->getNumSuccessors()) {
  2565. auto &FnED = BEDMap[nullptr];
  2566. mergeInPredecessor(A, FnED, ED);
  2567. if (IsKernel)
  2568. HandleAlignedBarrier(nullptr, ED);
  2569. }
  2570. ExecutionDomainTy &StoredED = BEDMap[&BB];
  2571. ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly;
  2572. // Check if we computed anything different as part of the forward
  2573. // traversal. We do not take assumptions and aligned barriers into account
  2574. // as they do not influence the state we iterate. Backward traversal values
  2575. // are handled later on.
  2576. if (ED.IsExecutedByInitialThreadOnly !=
  2577. StoredED.IsExecutedByInitialThreadOnly ||
  2578. ED.IsReachedFromAlignedBarrierOnly !=
  2579. StoredED.IsReachedFromAlignedBarrierOnly ||
  2580. ED.EncounteredNonLocalSideEffect !=
  2581. StoredED.EncounteredNonLocalSideEffect)
  2582. Changed = true;
  2583. // Update the state with the new value.
  2584. StoredED = std::move(ED);
  2585. }
  2586. // Propagate (non-aligned) sync instruction effects backwards until the
  2587. // entry is hit or an aligned barrier.
  2588. SmallSetVector<BasicBlock *, 16> Visited;
  2589. while (!SyncInstWorklist.empty()) {
  2590. Instruction *SyncInst = SyncInstWorklist.pop_back_val();
  2591. Instruction *CurInst = SyncInst;
  2592. bool HitAlignedBarrier = false;
  2593. while ((CurInst = CurInst->getPrevNode())) {
  2594. auto *CB = dyn_cast<CallBase>(CurInst);
  2595. if (!CB)
  2596. continue;
  2597. auto &CallED = CEDMap[CB];
  2598. if (SetAndRecord(CallED.IsReachingAlignedBarrierOnly, false))
  2599. Changed = true;
  2600. HitAlignedBarrier = AlignedBarriers.count(CB);
  2601. if (HitAlignedBarrier)
  2602. break;
  2603. }
  2604. if (HitAlignedBarrier)
  2605. continue;
  2606. BasicBlock *SyncBB = SyncInst->getParent();
  2607. for (auto *PredBB : predecessors(SyncBB)) {
  2608. if (LivenessAA.isEdgeDead(PredBB, SyncBB))
  2609. continue;
  2610. if (!Visited.insert(PredBB))
  2611. continue;
  2612. SyncInstWorklist.push_back(PredBB->getTerminator());
  2613. auto &PredED = BEDMap[PredBB];
  2614. if (SetAndRecord(PredED.IsReachingAlignedBarrierOnly, false))
  2615. Changed = true;
  2616. }
  2617. if (SyncBB != &EntryBB)
  2618. continue;
  2619. auto &FnED = BEDMap[nullptr];
  2620. if (SetAndRecord(FnED.IsReachingAlignedBarrierOnly, false))
  2621. Changed = true;
  2622. }
  2623. return Changed ? ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;
  2624. }
  2625. /// Try to replace memory allocation calls called by a single thread with a
  2626. /// static buffer of shared memory.
  2627. struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
  2628. using Base = StateWrapper<BooleanState, AbstractAttribute>;
  2629. AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
  2630. /// Create an abstract attribute view for the position \p IRP.
  2631. static AAHeapToShared &createForPosition(const IRPosition &IRP,
  2632. Attributor &A);
  2633. /// Returns true if HeapToShared conversion is assumed to be possible.
  2634. virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
  2635. /// Returns true if HeapToShared conversion is assumed and the CB is a
  2636. /// callsite to a free operation to be removed.
  2637. virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
  2638. /// See AbstractAttribute::getName().
  2639. const std::string getName() const override { return "AAHeapToShared"; }
  2640. /// See AbstractAttribute::getIdAddr().
  2641. const char *getIdAddr() const override { return &ID; }
  2642. /// This function should return true if the type of the \p AA is
  2643. /// AAHeapToShared.
  2644. static bool classof(const AbstractAttribute *AA) {
  2645. return (AA->getIdAddr() == &ID);
  2646. }
  2647. /// Unique ID (due to the unique address)
  2648. static const char ID;
  2649. };
  2650. struct AAHeapToSharedFunction : public AAHeapToShared {
  2651. AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
  2652. : AAHeapToShared(IRP, A) {}
  2653. const std::string getAsStr() const override {
  2654. return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
  2655. " malloc calls eligible.";
  2656. }
  2657. /// See AbstractAttribute::trackStatistics().
  2658. void trackStatistics() const override {}
  2659. /// This functions finds free calls that will be removed by the
  2660. /// HeapToShared transformation.
  2661. void findPotentialRemovedFreeCalls(Attributor &A) {
  2662. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2663. auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
  2664. PotentialRemovedFreeCalls.clear();
  2665. // Update free call users of found malloc calls.
  2666. for (CallBase *CB : MallocCalls) {
  2667. SmallVector<CallBase *, 4> FreeCalls;
  2668. for (auto *U : CB->users()) {
  2669. CallBase *C = dyn_cast<CallBase>(U);
  2670. if (C && C->getCalledFunction() == FreeRFI.Declaration)
  2671. FreeCalls.push_back(C);
  2672. }
  2673. if (FreeCalls.size() != 1)
  2674. continue;
  2675. PotentialRemovedFreeCalls.insert(FreeCalls.front());
  2676. }
  2677. }
  2678. void initialize(Attributor &A) override {
  2679. if (DisableOpenMPOptDeglobalization) {
  2680. indicatePessimisticFixpoint();
  2681. return;
  2682. }
  2683. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2684. auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
  2685. if (!RFI.Declaration)
  2686. return;
  2687. Attributor::SimplifictionCallbackTy SCB =
  2688. [](const IRPosition &, const AbstractAttribute *,
  2689. bool &) -> std::optional<Value *> { return nullptr; };
  2690. Function *F = getAnchorScope();
  2691. for (User *U : RFI.Declaration->users())
  2692. if (CallBase *CB = dyn_cast<CallBase>(U)) {
  2693. if (CB->getFunction() != F)
  2694. continue;
  2695. MallocCalls.insert(CB);
  2696. A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),
  2697. SCB);
  2698. }
  2699. findPotentialRemovedFreeCalls(A);
  2700. }
  2701. bool isAssumedHeapToShared(CallBase &CB) const override {
  2702. return isValidState() && MallocCalls.count(&CB);
  2703. }
  2704. bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
  2705. return isValidState() && PotentialRemovedFreeCalls.count(&CB);
  2706. }
  2707. ChangeStatus manifest(Attributor &A) override {
  2708. if (MallocCalls.empty())
  2709. return ChangeStatus::UNCHANGED;
  2710. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2711. auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
  2712. Function *F = getAnchorScope();
  2713. auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
  2714. DepClassTy::OPTIONAL);
  2715. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  2716. for (CallBase *CB : MallocCalls) {
  2717. // Skip replacing this if HeapToStack has already claimed it.
  2718. if (HS && HS->isAssumedHeapToStack(*CB))
  2719. continue;
  2720. // Find the unique free call to remove it.
  2721. SmallVector<CallBase *, 4> FreeCalls;
  2722. for (auto *U : CB->users()) {
  2723. CallBase *C = dyn_cast<CallBase>(U);
  2724. if (C && C->getCalledFunction() == FreeCall.Declaration)
  2725. FreeCalls.push_back(C);
  2726. }
  2727. if (FreeCalls.size() != 1)
  2728. continue;
  2729. auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));
  2730. if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {
  2731. LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB
  2732. << " with shared memory."
  2733. << " Shared memory usage is limited to "
  2734. << SharedMemoryLimit << " bytes\n");
  2735. continue;
  2736. }
  2737. LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
  2738. << " with " << AllocSize->getZExtValue()
  2739. << " bytes of shared memory\n");
  2740. // Create a new shared memory buffer of the same size as the allocation
  2741. // and replace all the uses of the original allocation with it.
  2742. Module *M = CB->getModule();
  2743. Type *Int8Ty = Type::getInt8Ty(M->getContext());
  2744. Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
  2745. auto *SharedMem = new GlobalVariable(
  2746. *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
  2747. UndefValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,
  2748. GlobalValue::NotThreadLocal,
  2749. static_cast<unsigned>(AddressSpace::Shared));
  2750. auto *NewBuffer =
  2751. ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
  2752. auto Remark = [&](OptimizationRemark OR) {
  2753. return OR << "Replaced globalized variable with "
  2754. << ore::NV("SharedMemory", AllocSize->getZExtValue())
  2755. << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
  2756. << "of shared memory.";
  2757. };
  2758. A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
  2759. MaybeAlign Alignment = CB->getRetAlign();
  2760. assert(Alignment &&
  2761. "HeapToShared on allocation without alignment attribute");
  2762. SharedMem->setAlignment(MaybeAlign(Alignment));
  2763. A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);
  2764. A.deleteAfterManifest(*CB);
  2765. A.deleteAfterManifest(*FreeCalls.front());
  2766. SharedMemoryUsed += AllocSize->getZExtValue();
  2767. NumBytesMovedToSharedMemory = SharedMemoryUsed;
  2768. Changed = ChangeStatus::CHANGED;
  2769. }
  2770. return Changed;
  2771. }
  2772. ChangeStatus updateImpl(Attributor &A) override {
  2773. if (MallocCalls.empty())
  2774. return indicatePessimisticFixpoint();
  2775. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2776. auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
  2777. if (!RFI.Declaration)
  2778. return ChangeStatus::UNCHANGED;
  2779. Function *F = getAnchorScope();
  2780. auto NumMallocCalls = MallocCalls.size();
  2781. // Only consider malloc calls executed by a single thread with a constant.
  2782. for (User *U : RFI.Declaration->users()) {
  2783. if (CallBase *CB = dyn_cast<CallBase>(U)) {
  2784. if (CB->getCaller() != F)
  2785. continue;
  2786. if (!MallocCalls.count(CB))
  2787. continue;
  2788. if (!isa<ConstantInt>(CB->getArgOperand(0))) {
  2789. MallocCalls.remove(CB);
  2790. continue;
  2791. }
  2792. const auto &ED = A.getAAFor<AAExecutionDomain>(
  2793. *this, IRPosition::function(*F), DepClassTy::REQUIRED);
  2794. if (!ED.isExecutedByInitialThreadOnly(*CB))
  2795. MallocCalls.remove(CB);
  2796. }
  2797. }
  2798. findPotentialRemovedFreeCalls(A);
  2799. if (NumMallocCalls != MallocCalls.size())
  2800. return ChangeStatus::CHANGED;
  2801. return ChangeStatus::UNCHANGED;
  2802. }
  2803. /// Collection of all malloc calls in a function.
  2804. SmallSetVector<CallBase *, 4> MallocCalls;
  2805. /// Collection of potentially removed free calls in a function.
  2806. SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
  2807. /// The total amount of shared memory that has been used for HeapToShared.
  2808. unsigned SharedMemoryUsed = 0;
  2809. };
  2810. struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
  2811. using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
  2812. AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
  2813. /// Statistics are tracked as part of manifest for now.
  2814. void trackStatistics() const override {}
  2815. /// See AbstractAttribute::getAsStr()
  2816. const std::string getAsStr() const override {
  2817. if (!isValidState())
  2818. return "<invalid>";
  2819. return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
  2820. : "generic") +
  2821. std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
  2822. : "") +
  2823. std::string(" #PRs: ") +
  2824. (ReachedKnownParallelRegions.isValidState()
  2825. ? std::to_string(ReachedKnownParallelRegions.size())
  2826. : "<invalid>") +
  2827. ", #Unknown PRs: " +
  2828. (ReachedUnknownParallelRegions.isValidState()
  2829. ? std::to_string(ReachedUnknownParallelRegions.size())
  2830. : "<invalid>") +
  2831. ", #Reaching Kernels: " +
  2832. (ReachingKernelEntries.isValidState()
  2833. ? std::to_string(ReachingKernelEntries.size())
  2834. : "<invalid>") +
  2835. ", #ParLevels: " +
  2836. (ParallelLevels.isValidState()
  2837. ? std::to_string(ParallelLevels.size())
  2838. : "<invalid>");
  2839. }
  2840. /// Create an abstract attribute biew for the position \p IRP.
  2841. static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
  2842. /// See AbstractAttribute::getName()
  2843. const std::string getName() const override { return "AAKernelInfo"; }
  2844. /// See AbstractAttribute::getIdAddr()
  2845. const char *getIdAddr() const override { return &ID; }
  2846. /// This function should return true if the type of the \p AA is AAKernelInfo
  2847. static bool classof(const AbstractAttribute *AA) {
  2848. return (AA->getIdAddr() == &ID);
  2849. }
  2850. static const char ID;
  2851. };
  2852. /// The function kernel info abstract attribute, basically, what can we say
  2853. /// about a function with regards to the KernelInfoState.
  2854. struct AAKernelInfoFunction : AAKernelInfo {
  2855. AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
  2856. : AAKernelInfo(IRP, A) {}
  2857. SmallPtrSet<Instruction *, 4> GuardedInstructions;
  2858. SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
  2859. return GuardedInstructions;
  2860. }
  2861. /// See AbstractAttribute::initialize(...).
  2862. void initialize(Attributor &A) override {
  2863. // This is a high-level transform that might change the constant arguments
  2864. // of the init and dinit calls. We need to tell the Attributor about this
  2865. // to avoid other parts using the current constant value for simpliication.
  2866. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  2867. Function *Fn = getAnchorScope();
  2868. OMPInformationCache::RuntimeFunctionInfo &InitRFI =
  2869. OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
  2870. OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
  2871. OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
  2872. // For kernels we perform more initialization work, first we find the init
  2873. // and deinit calls.
  2874. auto StoreCallBase = [](Use &U,
  2875. OMPInformationCache::RuntimeFunctionInfo &RFI,
  2876. CallBase *&Storage) {
  2877. CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
  2878. assert(CB &&
  2879. "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
  2880. assert(!Storage &&
  2881. "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
  2882. Storage = CB;
  2883. return false;
  2884. };
  2885. InitRFI.foreachUse(
  2886. [&](Use &U, Function &) {
  2887. StoreCallBase(U, InitRFI, KernelInitCB);
  2888. return false;
  2889. },
  2890. Fn);
  2891. DeinitRFI.foreachUse(
  2892. [&](Use &U, Function &) {
  2893. StoreCallBase(U, DeinitRFI, KernelDeinitCB);
  2894. return false;
  2895. },
  2896. Fn);
  2897. // Ignore kernels without initializers such as global constructors.
  2898. if (!KernelInitCB || !KernelDeinitCB)
  2899. return;
  2900. // Add itself to the reaching kernel and set IsKernelEntry.
  2901. ReachingKernelEntries.insert(Fn);
  2902. IsKernelEntry = true;
  2903. // For kernels we might need to initialize/finalize the IsSPMD state and
  2904. // we need to register a simplification callback so that the Attributor
  2905. // knows the constant arguments to __kmpc_target_init and
  2906. // __kmpc_target_deinit might actually change.
  2907. Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
  2908. [&](const IRPosition &IRP, const AbstractAttribute *AA,
  2909. bool &UsedAssumedInformation) -> std::optional<Value *> {
  2910. // IRP represents the "use generic state machine" argument of an
  2911. // __kmpc_target_init call. We will answer this one with the internal
  2912. // state. As long as we are not in an invalid state, we will create a
  2913. // custom state machine so the value should be a `i1 false`. If we are
  2914. // in an invalid state, we won't change the value that is in the IR.
  2915. if (!ReachedKnownParallelRegions.isValidState())
  2916. return nullptr;
  2917. // If we have disabled state machine rewrites, don't make a custom one.
  2918. if (DisableOpenMPOptStateMachineRewrite)
  2919. return nullptr;
  2920. if (AA)
  2921. A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
  2922. UsedAssumedInformation = !isAtFixpoint();
  2923. auto *FalseVal =
  2924. ConstantInt::getBool(IRP.getAnchorValue().getContext(), false);
  2925. return FalseVal;
  2926. };
  2927. Attributor::SimplifictionCallbackTy ModeSimplifyCB =
  2928. [&](const IRPosition &IRP, const AbstractAttribute *AA,
  2929. bool &UsedAssumedInformation) -> std::optional<Value *> {
  2930. // IRP represents the "SPMDCompatibilityTracker" argument of an
  2931. // __kmpc_target_init or
  2932. // __kmpc_target_deinit call. We will answer this one with the internal
  2933. // state.
  2934. if (!SPMDCompatibilityTracker.isValidState())
  2935. return nullptr;
  2936. if (!SPMDCompatibilityTracker.isAtFixpoint()) {
  2937. if (AA)
  2938. A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
  2939. UsedAssumedInformation = true;
  2940. } else {
  2941. UsedAssumedInformation = false;
  2942. }
  2943. auto *Val = ConstantInt::getSigned(
  2944. IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
  2945. SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
  2946. : OMP_TGT_EXEC_MODE_GENERIC);
  2947. return Val;
  2948. };
  2949. constexpr const int InitModeArgNo = 1;
  2950. constexpr const int DeinitModeArgNo = 1;
  2951. constexpr const int InitUseStateMachineArgNo = 2;
  2952. A.registerSimplificationCallback(
  2953. IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
  2954. StateMachineSimplifyCB);
  2955. A.registerSimplificationCallback(
  2956. IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
  2957. ModeSimplifyCB);
  2958. A.registerSimplificationCallback(
  2959. IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
  2960. ModeSimplifyCB);
  2961. // Check if we know we are in SPMD-mode already.
  2962. ConstantInt *ModeArg =
  2963. dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
  2964. if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
  2965. SPMDCompatibilityTracker.indicateOptimisticFixpoint();
  2966. // This is a generic region but SPMDization is disabled so stop tracking.
  2967. else if (DisableOpenMPOptSPMDization)
  2968. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  2969. // Register virtual uses of functions we might need to preserve.
  2970. auto RegisterVirtualUse = [&](RuntimeFunction RFKind,
  2971. Attributor::VirtualUseCallbackTy &CB) {
  2972. if (!OMPInfoCache.RFIs[RFKind].Declaration)
  2973. return;
  2974. A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);
  2975. };
  2976. // Add a dependence to ensure updates if the state changes.
  2977. auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,
  2978. const AbstractAttribute *QueryingAA) {
  2979. if (QueryingAA) {
  2980. A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);
  2981. }
  2982. return true;
  2983. };
  2984. Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =
  2985. [&](Attributor &A, const AbstractAttribute *QueryingAA) {
  2986. // Whenever we create a custom state machine we will insert calls to
  2987. // __kmpc_get_hardware_num_threads_in_block,
  2988. // __kmpc_get_warp_size,
  2989. // __kmpc_barrier_simple_generic,
  2990. // __kmpc_kernel_parallel, and
  2991. // __kmpc_kernel_end_parallel.
  2992. // Not needed if we are on track for SPMDzation.
  2993. if (SPMDCompatibilityTracker.isValidState())
  2994. return AddDependence(A, this, QueryingAA);
  2995. // Not needed if we can't rewrite due to an invalid state.
  2996. if (!ReachedKnownParallelRegions.isValidState())
  2997. return AddDependence(A, this, QueryingAA);
  2998. return false;
  2999. };
  3000. // Not needed if we are pre-runtime merge.
  3001. if (!KernelInitCB->getCalledFunction()->isDeclaration()) {
  3002. RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,
  3003. CustomStateMachineUseCB);
  3004. RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);
  3005. RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,
  3006. CustomStateMachineUseCB);
  3007. RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,
  3008. CustomStateMachineUseCB);
  3009. RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,
  3010. CustomStateMachineUseCB);
  3011. }
  3012. // If we do not perform SPMDzation we do not need the virtual uses below.
  3013. if (SPMDCompatibilityTracker.isAtFixpoint())
  3014. return;
  3015. Attributor::VirtualUseCallbackTy HWThreadIdUseCB =
  3016. [&](Attributor &A, const AbstractAttribute *QueryingAA) {
  3017. // Whenever we perform SPMDzation we will insert
  3018. // __kmpc_get_hardware_thread_id_in_block calls.
  3019. if (!SPMDCompatibilityTracker.isValidState())
  3020. return AddDependence(A, this, QueryingAA);
  3021. return false;
  3022. };
  3023. RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,
  3024. HWThreadIdUseCB);
  3025. Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =
  3026. [&](Attributor &A, const AbstractAttribute *QueryingAA) {
  3027. // Whenever we perform SPMDzation with guarding we will insert
  3028. // __kmpc_simple_barrier_spmd calls. If SPMDzation failed, there is
  3029. // nothing to guard, or there are no parallel regions, we don't need
  3030. // the calls.
  3031. if (!SPMDCompatibilityTracker.isValidState())
  3032. return AddDependence(A, this, QueryingAA);
  3033. if (SPMDCompatibilityTracker.empty())
  3034. return AddDependence(A, this, QueryingAA);
  3035. if (!mayContainParallelRegion())
  3036. return AddDependence(A, this, QueryingAA);
  3037. return false;
  3038. };
  3039. RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);
  3040. }
  /// Sanitize the string \p S such that it is a suitable global symbol name.
  ///
  /// Every character that is not an ASCII letter, an ASCII digit or '_' is
  /// replaced by '.'; the length of the string is preserved.
  static std::string sanitizeForGlobalName(std::string S) {
    for (char &C : S) {
      const bool IsAsciiLower = 'a' <= C && C <= 'z';
      const bool IsAsciiUpper = 'A' <= C && C <= 'Z';
      const bool IsAsciiDigit = '0' <= C && C <= '9';
      if (IsAsciiLower || IsAsciiUpper || IsAsciiDigit || C == '_')
        continue;
      C = '.';
    }
    return S;
  }
  3052. /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
  3053. /// finished now.
  3054. ChangeStatus manifest(Attributor &A) override {
  3055. // If we are not looking at a kernel with __kmpc_target_init and
  3056. // __kmpc_target_deinit call we cannot actually manifest the information.
  3057. if (!KernelInitCB || !KernelDeinitCB)
  3058. return ChangeStatus::UNCHANGED;
  3059. /// Insert nested Parallelism global variable
  3060. Function *Kernel = getAnchorScope();
  3061. Module &M = *Kernel->getParent();
  3062. Type *Int8Ty = Type::getInt8Ty(M.getContext());
  3063. new GlobalVariable(M, Int8Ty, /* isConstant */ true,
  3064. GlobalValue::WeakAnyLinkage,
  3065. ConstantInt::get(Int8Ty, NestedParallelism ? 1 : 0),
  3066. Kernel->getName() + "_nested_parallelism");
  3067. // If we can we change the execution mode to SPMD-mode otherwise we build a
  3068. // custom state machine.
  3069. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  3070. if (!changeToSPMDMode(A, Changed)) {
  3071. if (!KernelInitCB->getCalledFunction()->isDeclaration())
  3072. return buildCustomStateMachine(A);
  3073. }
  3074. return Changed;
  3075. }
  3076. void insertInstructionGuardsHelper(Attributor &A) {
  3077. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3078. auto CreateGuardedRegion = [&](Instruction *RegionStartI,
  3079. Instruction *RegionEndI) {
  3080. LoopInfo *LI = nullptr;
  3081. DominatorTree *DT = nullptr;
  3082. MemorySSAUpdater *MSU = nullptr;
  3083. using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
  3084. BasicBlock *ParentBB = RegionStartI->getParent();
  3085. Function *Fn = ParentBB->getParent();
  3086. Module &M = *Fn->getParent();
  3087. // Create all the blocks and logic.
  3088. // ParentBB:
  3089. // goto RegionCheckTidBB
  3090. // RegionCheckTidBB:
  3091. // Tid = __kmpc_hardware_thread_id()
  3092. // if (Tid != 0)
  3093. // goto RegionBarrierBB
  3094. // RegionStartBB:
  3095. // <execute instructions guarded>
  3096. // goto RegionEndBB
  3097. // RegionEndBB:
  3098. // <store escaping values to shared mem>
  3099. // goto RegionBarrierBB
  3100. // RegionBarrierBB:
  3101. // __kmpc_simple_barrier_spmd()
  3102. // // second barrier is omitted if lacking escaping values.
  3103. // <load escaping values from shared mem>
  3104. // __kmpc_simple_barrier_spmd()
  3105. // goto RegionExitBB
  3106. // RegionExitBB:
  3107. // <execute rest of instructions>
  3108. BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
  3109. DT, LI, MSU, "region.guarded.end");
  3110. BasicBlock *RegionBarrierBB =
  3111. SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
  3112. MSU, "region.barrier");
  3113. BasicBlock *RegionExitBB =
  3114. SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
  3115. DT, LI, MSU, "region.exit");
  3116. BasicBlock *RegionStartBB =
  3117. SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
  3118. assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
  3119. "Expected a different CFG");
  3120. BasicBlock *RegionCheckTidBB = SplitBlock(
  3121. ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
  3122. // Register basic blocks with the Attributor.
  3123. A.registerManifestAddedBasicBlock(*RegionEndBB);
  3124. A.registerManifestAddedBasicBlock(*RegionBarrierBB);
  3125. A.registerManifestAddedBasicBlock(*RegionExitBB);
  3126. A.registerManifestAddedBasicBlock(*RegionStartBB);
  3127. A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
  3128. bool HasBroadcastValues = false;
  3129. // Find escaping outputs from the guarded region to outside users and
  3130. // broadcast their values to them.
  3131. for (Instruction &I : *RegionStartBB) {
  3132. SmallPtrSet<Instruction *, 4> OutsideUsers;
  3133. for (User *Usr : I.users()) {
  3134. Instruction &UsrI = *cast<Instruction>(Usr);
  3135. if (UsrI.getParent() != RegionStartBB)
  3136. OutsideUsers.insert(&UsrI);
  3137. }
  3138. if (OutsideUsers.empty())
  3139. continue;
  3140. HasBroadcastValues = true;
  3141. // Emit a global variable in shared memory to store the broadcasted
  3142. // value.
  3143. auto *SharedMem = new GlobalVariable(
  3144. M, I.getType(), /* IsConstant */ false,
  3145. GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
  3146. sanitizeForGlobalName(
  3147. (I.getName() + ".guarded.output.alloc").str()),
  3148. nullptr, GlobalValue::NotThreadLocal,
  3149. static_cast<unsigned>(AddressSpace::Shared));
  3150. // Emit a store instruction to update the value.
  3151. new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
  3152. LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
  3153. I.getName() + ".guarded.output.load",
  3154. RegionBarrierBB->getTerminator());
  3155. // Emit a load instruction and replace uses of the output value.
  3156. for (Instruction *UsrI : OutsideUsers)
  3157. UsrI->replaceUsesOfWith(&I, LoadI);
  3158. }
  3159. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3160. // Go to tid check BB in ParentBB.
  3161. const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
  3162. ParentBB->getTerminator()->eraseFromParent();
  3163. OpenMPIRBuilder::LocationDescription Loc(
  3164. InsertPointTy(ParentBB, ParentBB->end()), DL);
  3165. OMPInfoCache.OMPBuilder.updateToLocation(Loc);
  3166. uint32_t SrcLocStrSize;
  3167. auto *SrcLocStr =
  3168. OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  3169. Value *Ident =
  3170. OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  3171. BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
  3172. // Add check for Tid in RegionCheckTidBB
  3173. RegionCheckTidBB->getTerminator()->eraseFromParent();
  3174. OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
  3175. InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
  3176. OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
  3177. FunctionCallee HardwareTidFn =
  3178. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3179. M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
  3180. CallInst *Tid =
  3181. OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
  3182. Tid->setDebugLoc(DL);
  3183. OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);
  3184. Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
  3185. OMPInfoCache.OMPBuilder.Builder
  3186. .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
  3187. ->setDebugLoc(DL);
  3188. // First barrier for synchronization, ensures main thread has updated
  3189. // values.
  3190. FunctionCallee BarrierFn =
  3191. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3192. M, OMPRTL___kmpc_barrier_simple_spmd);
  3193. OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
  3194. RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
  3195. CallInst *Barrier =
  3196. OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});
  3197. Barrier->setDebugLoc(DL);
  3198. OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
  3199. // Second barrier ensures workers have read broadcast values.
  3200. if (HasBroadcastValues) {
  3201. CallInst *Barrier = CallInst::Create(BarrierFn, {Ident, Tid}, "",
  3202. RegionBarrierBB->getTerminator());
  3203. Barrier->setDebugLoc(DL);
  3204. OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
  3205. }
  3206. };
  3207. auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
  3208. SmallPtrSet<BasicBlock *, 8> Visited;
  3209. for (Instruction *GuardedI : SPMDCompatibilityTracker) {
  3210. BasicBlock *BB = GuardedI->getParent();
  3211. if (!Visited.insert(BB).second)
  3212. continue;
  3213. SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
  3214. Instruction *LastEffect = nullptr;
  3215. BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
  3216. while (++IP != IPEnd) {
  3217. if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
  3218. continue;
  3219. Instruction *I = &*IP;
  3220. if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
  3221. continue;
  3222. if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
  3223. LastEffect = nullptr;
  3224. continue;
  3225. }
  3226. if (LastEffect)
  3227. Reorders.push_back({I, LastEffect});
  3228. LastEffect = &*IP;
  3229. }
  3230. for (auto &Reorder : Reorders)
  3231. Reorder.first->moveBefore(Reorder.second);
  3232. }
  3233. SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
  3234. for (Instruction *GuardedI : SPMDCompatibilityTracker) {
  3235. BasicBlock *BB = GuardedI->getParent();
  3236. auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
  3237. IRPosition::function(*GuardedI->getFunction()), nullptr,
  3238. DepClassTy::NONE);
  3239. assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
  3240. auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
  3241. // Continue if instruction is already guarded.
  3242. if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
  3243. continue;
  3244. Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
  3245. for (Instruction &I : *BB) {
  3246. // If instruction I needs to be guarded update the guarded region
  3247. // bounds.
  3248. if (SPMDCompatibilityTracker.contains(&I)) {
  3249. CalleeAAFunction.getGuardedInstructions().insert(&I);
  3250. if (GuardedRegionStart)
  3251. GuardedRegionEnd = &I;
  3252. else
  3253. GuardedRegionStart = GuardedRegionEnd = &I;
  3254. continue;
  3255. }
  3256. // Instruction I does not need guarding, store
  3257. // any region found and reset bounds.
  3258. if (GuardedRegionStart) {
  3259. GuardedRegions.push_back(
  3260. std::make_pair(GuardedRegionStart, GuardedRegionEnd));
  3261. GuardedRegionStart = nullptr;
  3262. GuardedRegionEnd = nullptr;
  3263. }
  3264. }
  3265. }
  3266. for (auto &GR : GuardedRegions)
  3267. CreateGuardedRegion(GR.first, GR.second);
  3268. }
  3269. void forceSingleThreadPerWorkgroupHelper(Attributor &A) {
  3270. // Only allow 1 thread per workgroup to continue executing the user code.
  3271. //
  3272. // InitCB = __kmpc_target_init(...)
  3273. // ThreadIdInBlock = __kmpc_get_hardware_thread_id_in_block();
  3274. // if (ThreadIdInBlock != 0) return;
  3275. // UserCode:
  3276. // // user code
  3277. //
  3278. auto &Ctx = getAnchorValue().getContext();
  3279. Function *Kernel = getAssociatedFunction();
  3280. assert(Kernel && "Expected an associated function!");
  3281. // Create block for user code to branch to from initial block.
  3282. BasicBlock *InitBB = KernelInitCB->getParent();
  3283. BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
  3284. KernelInitCB->getNextNode(), "main.thread.user_code");
  3285. BasicBlock *ReturnBB =
  3286. BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);
  3287. // Register blocks with attributor:
  3288. A.registerManifestAddedBasicBlock(*InitBB);
  3289. A.registerManifestAddedBasicBlock(*UserCodeBB);
  3290. A.registerManifestAddedBasicBlock(*ReturnBB);
  3291. // Debug location:
  3292. const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
  3293. ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
  3294. InitBB->getTerminator()->eraseFromParent();
  3295. // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
  3296. Module &M = *Kernel->getParent();
  3297. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3298. FunctionCallee ThreadIdInBlockFn =
  3299. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3300. M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
  3301. // Get thread ID in block.
  3302. CallInst *ThreadIdInBlock =
  3303. CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
  3304. OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
  3305. ThreadIdInBlock->setDebugLoc(DLoc);
  3306. // Eliminate all threads in the block with ID not equal to 0:
  3307. Instruction *IsMainThread =
  3308. ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
  3309. ConstantInt::get(ThreadIdInBlock->getType(), 0),
  3310. "thread.is_main", InitBB);
  3311. IsMainThread->setDebugLoc(DLoc);
  3312. BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
  3313. }
  3314. bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
  3315. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3316. // We cannot change to SPMD mode if the runtime functions aren't availible.
  3317. if (!OMPInfoCache.runtimeFnsAvailable(
  3318. {OMPRTL___kmpc_get_hardware_thread_id_in_block,
  3319. OMPRTL___kmpc_barrier_simple_spmd}))
  3320. return false;
  3321. if (!SPMDCompatibilityTracker.isAssumed()) {
  3322. for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
  3323. if (!NonCompatibleI)
  3324. continue;
  3325. // Skip diagnostics on calls to known OpenMP runtime functions for now.
  3326. if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
  3327. if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
  3328. continue;
  3329. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  3330. ORA << "Value has potential side effects preventing SPMD-mode "
  3331. "execution";
  3332. if (isa<CallBase>(NonCompatibleI)) {
  3333. ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
  3334. "the called function to override";
  3335. }
  3336. return ORA << ".";
  3337. };
  3338. A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
  3339. Remark);
  3340. LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
  3341. << *NonCompatibleI << "\n");
  3342. }
  3343. return false;
  3344. }
  3345. // Get the actual kernel, could be the caller of the anchor scope if we have
  3346. // a debug wrapper.
  3347. Function *Kernel = getAnchorScope();
  3348. if (Kernel->hasLocalLinkage()) {
  3349. assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");
  3350. auto *CB = cast<CallBase>(Kernel->user_back());
  3351. Kernel = CB->getCaller();
  3352. }
  3353. assert(OMPInfoCache.Kernels.count(Kernel) && "Expected kernel function!");
  3354. // Check if the kernel is already in SPMD mode, if so, return success.
  3355. GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
  3356. (Kernel->getName() + "_exec_mode").str());
  3357. assert(ExecMode && "Kernel without exec mode?");
  3358. assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
  3359. // Set the global exec mode flag to indicate SPMD-Generic mode.
  3360. assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
  3361. "ExecMode is not an integer!");
  3362. const int8_t ExecModeVal =
  3363. cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
  3364. if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)
  3365. return true;
  3366. // We will now unconditionally modify the IR, indicate a change.
  3367. Changed = ChangeStatus::CHANGED;
  3368. // Do not use instruction guards when no parallel is present inside
  3369. // the target region.
  3370. if (mayContainParallelRegion())
  3371. insertInstructionGuardsHelper(A);
  3372. else
  3373. forceSingleThreadPerWorkgroupHelper(A);
  3374. // Adjust the global exec mode flag that tells the runtime what mode this
  3375. // kernel is executed in.
  3376. assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
  3377. "Initially non-SPMD kernel has SPMD exec mode!");
  3378. ExecMode->setInitializer(
  3379. ConstantInt::get(ExecMode->getInitializer()->getType(),
  3380. ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
  3381. // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
  3382. const int InitModeArgNo = 1;
  3383. const int DeinitModeArgNo = 1;
  3384. const int InitUseStateMachineArgNo = 2;
  3385. auto &Ctx = getAnchorValue().getContext();
  3386. A.changeUseAfterManifest(
  3387. KernelInitCB->getArgOperandUse(InitModeArgNo),
  3388. *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
  3389. OMP_TGT_EXEC_MODE_SPMD));
  3390. A.changeUseAfterManifest(
  3391. KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
  3392. *ConstantInt::getBool(Ctx, false));
  3393. A.changeUseAfterManifest(
  3394. KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
  3395. *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
  3396. OMP_TGT_EXEC_MODE_SPMD));
  3397. ++NumOpenMPTargetRegionKernelsSPMD;
  3398. auto Remark = [&](OptimizationRemark OR) {
  3399. return OR << "Transformed generic-mode kernel to SPMD-mode.";
  3400. };
  3401. A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
  3402. return true;
  3403. };
  3404. ChangeStatus buildCustomStateMachine(Attributor &A) {
  3405. // If we have disabled state machine rewrites, don't make a custom one
  3406. if (DisableOpenMPOptStateMachineRewrite)
  3407. return ChangeStatus::UNCHANGED;
  3408. // Don't rewrite the state machine if we are not in a valid state.
  3409. if (!ReachedKnownParallelRegions.isValidState())
  3410. return ChangeStatus::UNCHANGED;
  3411. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3412. if (!OMPInfoCache.runtimeFnsAvailable(
  3413. {OMPRTL___kmpc_get_hardware_num_threads_in_block,
  3414. OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,
  3415. OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))
  3416. return ChangeStatus::UNCHANGED;
  3417. const int InitModeArgNo = 1;
  3418. const int InitUseStateMachineArgNo = 2;
  3419. // Check if the current configuration is non-SPMD and generic state machine.
  3420. // If we already have SPMD mode or a custom state machine we do not need to
  3421. // go any further. If it is anything but a constant something is weird and
  3422. // we give up.
  3423. ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
  3424. KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
  3425. ConstantInt *Mode =
  3426. dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
  3427. // If we are stuck with generic mode, try to create a custom device (=GPU)
  3428. // state machine which is specialized for the parallel regions that are
  3429. // reachable by the kernel.
  3430. if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
  3431. (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
  3432. return ChangeStatus::UNCHANGED;
  3433. // If not SPMD mode, indicate we use a custom state machine now.
  3434. auto &Ctx = getAnchorValue().getContext();
  3435. auto *FalseVal = ConstantInt::getBool(Ctx, false);
  3436. A.changeUseAfterManifest(
  3437. KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
  3438. // If we don't actually need a state machine we are done here. This can
  3439. // happen if there simply are no parallel regions. In the resulting kernel
  3440. // all worker threads will simply exit right away, leaving the main thread
  3441. // to do the work alone.
  3442. if (!mayContainParallelRegion()) {
  3443. ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
  3444. auto Remark = [&](OptimizationRemark OR) {
  3445. return OR << "Removing unused state machine from generic-mode kernel.";
  3446. };
  3447. A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
  3448. return ChangeStatus::CHANGED;
  3449. }
  3450. // Keep track in the statistics of our new shiny custom state machine.
  3451. if (ReachedUnknownParallelRegions.empty()) {
  3452. ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
  3453. auto Remark = [&](OptimizationRemark OR) {
  3454. return OR << "Rewriting generic-mode kernel with a customized state "
  3455. "machine.";
  3456. };
  3457. A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
  3458. } else {
  3459. ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
  3460. auto Remark = [&](OptimizationRemarkAnalysis OR) {
  3461. return OR << "Generic-mode kernel is executed with a customized state "
  3462. "machine that requires a fallback.";
  3463. };
  3464. A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
  3465. // Tell the user why we ended up with a fallback.
  3466. for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
  3467. if (!UnknownParallelRegionCB)
  3468. continue;
  3469. auto Remark = [&](OptimizationRemarkAnalysis ORA) {
  3470. return ORA << "Call may contain unknown parallel regions. Use "
  3471. << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
  3472. "override.";
  3473. };
  3474. A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
  3475. "OMP133", Remark);
  3476. }
  3477. }
  3478. // Create all the blocks:
  3479. //
  3480. // InitCB = __kmpc_target_init(...)
  3481. // BlockHwSize =
  3482. // __kmpc_get_hardware_num_threads_in_block();
  3483. // WarpSize = __kmpc_get_warp_size();
  3484. // BlockSize = BlockHwSize - WarpSize;
  3485. // IsWorkerCheckBB: bool IsWorker = InitCB != -1;
  3486. // if (IsWorker) {
  3487. // if (InitCB >= BlockSize) return;
  3488. // SMBeginBB: __kmpc_barrier_simple_generic(...);
  3489. // void *WorkFn;
  3490. // bool Active = __kmpc_kernel_parallel(&WorkFn);
  3491. // if (!WorkFn) return;
  3492. // SMIsActiveCheckBB: if (Active) {
  3493. // SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)
  3494. // ParFn0(...);
  3495. // SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)
  3496. // ParFn1(...);
  3497. // ...
  3498. // SMIfCascadeCurrentBB: else
  3499. // ((WorkFnTy*)WorkFn)(...);
  3500. // SMEndParallelBB: __kmpc_kernel_end_parallel(...);
  3501. // }
  3502. // SMDoneBB: __kmpc_barrier_simple_generic(...);
  3503. // goto SMBeginBB;
  3504. // }
  3505. // UserCodeEntryBB: // user code
  3506. // __kmpc_target_deinit(...)
  3507. //
  3508. Function *Kernel = getAssociatedFunction();
  3509. assert(Kernel && "Expected an associated function!");
  3510. BasicBlock *InitBB = KernelInitCB->getParent();
  3511. BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
  3512. KernelInitCB->getNextNode(), "thread.user_code.check");
  3513. BasicBlock *IsWorkerCheckBB =
  3514. BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);
  3515. BasicBlock *StateMachineBeginBB = BasicBlock::Create(
  3516. Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
  3517. BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
  3518. Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
  3519. BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
  3520. Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
  3521. BasicBlock *StateMachineIfCascadeCurrentBB =
  3522. BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
  3523. Kernel, UserCodeEntryBB);
  3524. BasicBlock *StateMachineEndParallelBB =
  3525. BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
  3526. Kernel, UserCodeEntryBB);
  3527. BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
  3528. Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
  3529. A.registerManifestAddedBasicBlock(*InitBB);
  3530. A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
  3531. A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);
  3532. A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
  3533. A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
  3534. A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
  3535. A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
  3536. A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
  3537. A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
  3538. const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
  3539. ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
  3540. InitBB->getTerminator()->eraseFromParent();
  3541. Instruction *IsWorker =
  3542. ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
  3543. ConstantInt::get(KernelInitCB->getType(), -1),
  3544. "thread.is_worker", InitBB);
  3545. IsWorker->setDebugLoc(DLoc);
  3546. BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);
  3547. Module &M = *Kernel->getParent();
  3548. FunctionCallee BlockHwSizeFn =
  3549. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3550. M, OMPRTL___kmpc_get_hardware_num_threads_in_block);
  3551. FunctionCallee WarpSizeFn =
  3552. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3553. M, OMPRTL___kmpc_get_warp_size);
  3554. CallInst *BlockHwSize =
  3555. CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);
  3556. OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);
  3557. BlockHwSize->setDebugLoc(DLoc);
  3558. CallInst *WarpSize =
  3559. CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);
  3560. OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);
  3561. WarpSize->setDebugLoc(DLoc);
  3562. Instruction *BlockSize = BinaryOperator::CreateSub(
  3563. BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);
  3564. BlockSize->setDebugLoc(DLoc);
  3565. Instruction *IsMainOrWorker = ICmpInst::Create(
  3566. ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,
  3567. "thread.is_main_or_worker", IsWorkerCheckBB);
  3568. IsMainOrWorker->setDebugLoc(DLoc);
  3569. BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,
  3570. IsMainOrWorker, IsWorkerCheckBB);
  3571. // Create local storage for the work function pointer.
  3572. const DataLayout &DL = M.getDataLayout();
  3573. Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
  3574. Instruction *WorkFnAI =
  3575. new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
  3576. "worker.work_fn.addr", &Kernel->getEntryBlock().front());
  3577. WorkFnAI->setDebugLoc(DLoc);
  3578. OMPInfoCache.OMPBuilder.updateToLocation(
  3579. OpenMPIRBuilder::LocationDescription(
  3580. IRBuilder<>::InsertPoint(StateMachineBeginBB,
  3581. StateMachineBeginBB->end()),
  3582. DLoc));
  3583. Value *Ident = KernelInitCB->getArgOperand(0);
  3584. Value *GTid = KernelInitCB;
  3585. FunctionCallee BarrierFn =
  3586. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3587. M, OMPRTL___kmpc_barrier_simple_generic);
  3588. CallInst *Barrier =
  3589. CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);
  3590. OMPInfoCache.setCallingConvention(BarrierFn, Barrier);
  3591. Barrier->setDebugLoc(DLoc);
  3592. if (WorkFnAI->getType()->getPointerAddressSpace() !=
  3593. (unsigned int)AddressSpace::Generic) {
  3594. WorkFnAI = new AddrSpaceCastInst(
  3595. WorkFnAI,
  3596. PointerType::getWithSamePointeeType(
  3597. cast<PointerType>(WorkFnAI->getType()),
  3598. (unsigned int)AddressSpace::Generic),
  3599. WorkFnAI->getName() + ".generic", StateMachineBeginBB);
  3600. WorkFnAI->setDebugLoc(DLoc);
  3601. }
  3602. FunctionCallee KernelParallelFn =
  3603. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3604. M, OMPRTL___kmpc_kernel_parallel);
  3605. CallInst *IsActiveWorker = CallInst::Create(
  3606. KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
  3607. OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);
  3608. IsActiveWorker->setDebugLoc(DLoc);
  3609. Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
  3610. StateMachineBeginBB);
  3611. WorkFn->setDebugLoc(DLoc);
  3612. FunctionType *ParallelRegionFnTy = FunctionType::get(
  3613. Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
  3614. false);
  3615. Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
  3616. WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
  3617. StateMachineBeginBB);
  3618. Instruction *IsDone =
  3619. ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
  3620. Constant::getNullValue(VoidPtrTy), "worker.is_done",
  3621. StateMachineBeginBB);
  3622. IsDone->setDebugLoc(DLoc);
  3623. BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
  3624. IsDone, StateMachineBeginBB)
  3625. ->setDebugLoc(DLoc);
  3626. BranchInst::Create(StateMachineIfCascadeCurrentBB,
  3627. StateMachineDoneBarrierBB, IsActiveWorker,
  3628. StateMachineIsActiveCheckBB)
  3629. ->setDebugLoc(DLoc);
  3630. Value *ZeroArg =
  3631. Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
  3632. // Now that we have most of the CFG skeleton it is time for the if-cascade
  3633. // that checks the function pointer we got from the runtime against the
  3634. // parallel regions we expect, if there are any.
  3635. for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {
  3636. auto *ParallelRegion = ReachedKnownParallelRegions[I];
  3637. BasicBlock *PRExecuteBB = BasicBlock::Create(
  3638. Ctx, "worker_state_machine.parallel_region.execute", Kernel,
  3639. StateMachineEndParallelBB);
  3640. CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
  3641. ->setDebugLoc(DLoc);
  3642. BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
  3643. ->setDebugLoc(DLoc);
  3644. BasicBlock *PRNextBB =
  3645. BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
  3646. Kernel, StateMachineEndParallelBB);
  3647. // Check if we need to compare the pointer at all or if we can just
  3648. // call the parallel region function.
  3649. Value *IsPR;
  3650. if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {
  3651. Instruction *CmpI = ICmpInst::Create(
  3652. ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
  3653. "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
  3654. CmpI->setDebugLoc(DLoc);
  3655. IsPR = CmpI;
  3656. } else {
  3657. IsPR = ConstantInt::getTrue(Ctx);
  3658. }
  3659. BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
  3660. StateMachineIfCascadeCurrentBB)
  3661. ->setDebugLoc(DLoc);
  3662. StateMachineIfCascadeCurrentBB = PRNextBB;
  3663. }
  3664. // At the end of the if-cascade we place the indirect function pointer call
  3665. // in case we might need it, that is if there can be parallel regions we
  3666. // have not handled in the if-cascade above.
  3667. if (!ReachedUnknownParallelRegions.empty()) {
  3668. StateMachineIfCascadeCurrentBB->setName(
  3669. "worker_state_machine.parallel_region.fallback.execute");
  3670. CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
  3671. StateMachineIfCascadeCurrentBB)
  3672. ->setDebugLoc(DLoc);
  3673. }
  3674. BranchInst::Create(StateMachineEndParallelBB,
  3675. StateMachineIfCascadeCurrentBB)
  3676. ->setDebugLoc(DLoc);
  3677. FunctionCallee EndParallelFn =
  3678. OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
  3679. M, OMPRTL___kmpc_kernel_end_parallel);
  3680. CallInst *EndParallel =
  3681. CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);
  3682. OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);
  3683. EndParallel->setDebugLoc(DLoc);
  3684. BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
  3685. ->setDebugLoc(DLoc);
  3686. CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
  3687. ->setDebugLoc(DLoc);
  3688. BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
  3689. ->setDebugLoc(DLoc);
  3690. return ChangeStatus::CHANGED;
  3691. }
  3692. /// Fixpoint iteration update function. Will be called every time a dependence
  3693. /// changed its state (and in the beginning).
  3694. ChangeStatus updateImpl(Attributor &A) override {
  3695. KernelInfoState StateBefore = getState();
  3696. // Callback to check a read/write instruction.
  3697. auto CheckRWInst = [&](Instruction &I) {
  3698. // We handle calls later.
  3699. if (isa<CallBase>(I))
  3700. return true;
  3701. // We only care about write effects.
  3702. if (!I.mayWriteToMemory())
  3703. return true;
  3704. if (auto *SI = dyn_cast<StoreInst>(&I)) {
  3705. const auto &UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(
  3706. *this, IRPosition::value(*SI->getPointerOperand()),
  3707. DepClassTy::OPTIONAL);
  3708. auto &HS = A.getAAFor<AAHeapToStack>(
  3709. *this, IRPosition::function(*I.getFunction()),
  3710. DepClassTy::OPTIONAL);
  3711. if (UnderlyingObjsAA.forallUnderlyingObjects([&](Value &Obj) {
  3712. if (AA::isAssumedThreadLocalObject(A, Obj, *this))
  3713. return true;
  3714. // Check for AAHeapToStack moved objects which must not be
  3715. // guarded.
  3716. auto *CB = dyn_cast<CallBase>(&Obj);
  3717. return CB && HS.isAssumedHeapToStack(*CB);
  3718. }))
  3719. return true;
  3720. }
  3721. // Insert instruction that needs guarding.
  3722. SPMDCompatibilityTracker.insert(&I);
  3723. return true;
  3724. };
  3725. bool UsedAssumedInformationInCheckRWInst = false;
  3726. if (!SPMDCompatibilityTracker.isAtFixpoint())
  3727. if (!A.checkForAllReadWriteInstructions(
  3728. CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
  3729. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  3730. bool UsedAssumedInformationFromReachingKernels = false;
  3731. if (!IsKernelEntry) {
  3732. updateParallelLevels(A);
  3733. bool AllReachingKernelsKnown = true;
  3734. updateReachingKernelEntries(A, AllReachingKernelsKnown);
  3735. UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;
  3736. if (!SPMDCompatibilityTracker.empty()) {
  3737. if (!ParallelLevels.isValidState())
  3738. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  3739. else if (!ReachingKernelEntries.isValidState())
  3740. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  3741. else {
  3742. // Check if all reaching kernels agree on the mode as we can otherwise
  3743. // not guard instructions. We might not be sure about the mode so we
  3744. // we cannot fix the internal spmd-zation state either.
  3745. int SPMD = 0, Generic = 0;
  3746. for (auto *Kernel : ReachingKernelEntries) {
  3747. auto &CBAA = A.getAAFor<AAKernelInfo>(
  3748. *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
  3749. if (CBAA.SPMDCompatibilityTracker.isValidState() &&
  3750. CBAA.SPMDCompatibilityTracker.isAssumed())
  3751. ++SPMD;
  3752. else
  3753. ++Generic;
  3754. if (!CBAA.SPMDCompatibilityTracker.isAtFixpoint())
  3755. UsedAssumedInformationFromReachingKernels = true;
  3756. }
  3757. if (SPMD != 0 && Generic != 0)
  3758. SPMDCompatibilityTracker.indicatePessimisticFixpoint();
  3759. }
  3760. }
  3761. }
  3762. // Callback to check a call instruction.
  3763. bool AllParallelRegionStatesWereFixed = true;
  3764. bool AllSPMDStatesWereFixed = true;
  3765. auto CheckCallInst = [&](Instruction &I) {
  3766. auto &CB = cast<CallBase>(I);
  3767. auto &CBAA = A.getAAFor<AAKernelInfo>(
  3768. *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
  3769. getState() ^= CBAA.getState();
  3770. AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
  3771. AllParallelRegionStatesWereFixed &=
  3772. CBAA.ReachedKnownParallelRegions.isAtFixpoint();
  3773. AllParallelRegionStatesWereFixed &=
  3774. CBAA.ReachedUnknownParallelRegions.isAtFixpoint();
  3775. return true;
  3776. };
  3777. bool UsedAssumedInformationInCheckCallInst = false;
  3778. if (!A.checkForAllCallLikeInstructions(
  3779. CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
  3780. LLVM_DEBUG(dbgs() << TAG
  3781. << "Failed to visit all call-like instructions!\n";);
  3782. return indicatePessimisticFixpoint();
  3783. }
  3784. // If we haven't used any assumed information for the reached parallel
  3785. // region states we can fix it.
  3786. if (!UsedAssumedInformationInCheckCallInst &&
  3787. AllParallelRegionStatesWereFixed) {
  3788. ReachedKnownParallelRegions.indicateOptimisticFixpoint();
  3789. ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
  3790. }
  3791. // If we haven't used any assumed information for the SPMD state we can fix
  3792. // it.
  3793. if (!UsedAssumedInformationInCheckRWInst &&
  3794. !UsedAssumedInformationInCheckCallInst &&
  3795. !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
  3796. SPMDCompatibilityTracker.indicateOptimisticFixpoint();
  3797. return StateBefore == getState() ? ChangeStatus::UNCHANGED
  3798. : ChangeStatus::CHANGED;
  3799. }
  3800. private:
  3801. /// Update info regarding reaching kernels.
  3802. void updateReachingKernelEntries(Attributor &A,
  3803. bool &AllReachingKernelsKnown) {
  3804. auto PredCallSite = [&](AbstractCallSite ACS) {
  3805. Function *Caller = ACS.getInstruction()->getFunction();
  3806. assert(Caller && "Caller is nullptr");
  3807. auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
  3808. IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
  3809. if (CAA.ReachingKernelEntries.isValidState()) {
  3810. ReachingKernelEntries ^= CAA.ReachingKernelEntries;
  3811. return true;
  3812. }
  3813. // We lost track of the caller of the associated function, any kernel
  3814. // could reach now.
  3815. ReachingKernelEntries.indicatePessimisticFixpoint();
  3816. return true;
  3817. };
  3818. if (!A.checkForAllCallSites(PredCallSite, *this,
  3819. true /* RequireAllCallSites */,
  3820. AllReachingKernelsKnown))
  3821. ReachingKernelEntries.indicatePessimisticFixpoint();
  3822. }
  3823. /// Update info regarding parallel levels.
  3824. void updateParallelLevels(Attributor &A) {
  3825. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  3826. OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
  3827. OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
  3828. auto PredCallSite = [&](AbstractCallSite ACS) {
  3829. Function *Caller = ACS.getInstruction()->getFunction();
  3830. assert(Caller && "Caller is nullptr");
  3831. auto &CAA =
  3832. A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
  3833. if (CAA.ParallelLevels.isValidState()) {
  3834. // Any function that is called by `__kmpc_parallel_51` will not be
  3835. // folded as the parallel level in the function is updated. In order to
  3836. // get it right, all the analysis would depend on the implentation. That
  3837. // said, if in the future any change to the implementation, the analysis
  3838. // could be wrong. As a consequence, we are just conservative here.
  3839. if (Caller == Parallel51RFI.Declaration) {
  3840. ParallelLevels.indicatePessimisticFixpoint();
  3841. return true;
  3842. }
  3843. ParallelLevels ^= CAA.ParallelLevels;
  3844. return true;
  3845. }
  3846. // We lost track of the caller of the associated function, any kernel
  3847. // could reach now.
  3848. ParallelLevels.indicatePessimisticFixpoint();
  3849. return true;
  3850. };
  3851. bool AllCallSitesKnown = true;
  3852. if (!A.checkForAllCallSites(PredCallSite, *this,
  3853. true /* RequireAllCallSites */,
  3854. AllCallSitesKnown))
  3855. ParallelLevels.indicatePessimisticFixpoint();
  3856. }
  3857. };
  /// The call site kernel info abstract attribute, basically, what can we say
  /// about a call site with regards to the KernelInfoState. For now this simply
  /// forwards the information from the callee.
  ///
  /// Calls are handled in three ways:
  ///  - Known OpenMP runtime calls are modeled directly in initialize().
  ///  - Calls to other IPO-amendable callees copy the callee's AAKernelInfo
  ///    state in updateImpl().
  ///  - Calls to `__kmpc_alloc_shared` / `__kmpc_free_shared` are resolved in
  ///    updateImpl() using AAHeapToStack and AAHeapToShared.
  struct AAKernelInfoCallSite : AAKernelInfo {
    AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
        : AAKernelInfo(IRP, A) {}

    /// See AbstractAttribute::initialize(...).
    ///
    /// Most call kinds reach their final state here via
    /// indicateOptimisticFixpoint(). Two cases return early without a
    /// fixpoint and are finished in updateImpl():
    ///  - calls to IPO-amendable callees that are not runtime functions;
    ///  - `__kmpc_alloc_shared` / `__kmpc_free_shared` calls.
    void initialize(Attributor &A) override {
      AAKernelInfo::initialize(A);

      CallBase &CB = cast<CallBase>(getAssociatedValue());
      Function *Callee = getAssociatedFunction();

      auto &AssumptionAA = A.getAAFor<AAAssumptionInfo>(
          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);

      // Check for SPMD-mode assumptions.
      if (AssumptionAA.hasAssumption("ompx_spmd_amenable")) {
        SPMDCompatibilityTracker.indicateOptimisticFixpoint();
        indicateOptimisticFixpoint();
      }

      // First weed out calls we do not care about, that is readonly/readnone
      // calls and intrinsics. Neither of these can reach a parallel region or
      // anything else we are looking for. (The "omp_no_openmp" /
      // "omp_no_parallelism" assumptions are only consulted below, for
      // unknown callees.)
      if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
        indicateOptimisticFixpoint();
        return;
      }

      // Next we check if we know the callee. If it is a known OpenMP function
      // we will handle them explicitly in the switch below. If it is not, we
      // will use an AAKernelInfo object on the callee to gather information and
      // merge that into the current state. The latter happens in the updateImpl.
      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        // Unknown callees (indirect calls) or callees that are not
        // IPO-amendable are not analyzable, we give up.
        if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {

          // Unknown callees might contain parallel regions, except if they have
          // an appropriate assumption attached.
          if (!(AssumptionAA.hasAssumption("omp_no_openmp") ||
                AssumptionAA.hasAssumption("omp_no_parallelism")))
            ReachedUnknownParallelRegions.insert(&CB);

          // If SPMDCompatibilityTracker is not fixed, we need to give up on the
          // idea we can run something unknown in SPMD-mode.
          if (!SPMDCompatibilityTracker.isAtFixpoint()) {
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
            SPMDCompatibilityTracker.insert(&CB);
          }

          // We have updated the state for this unknown call properly, there won't
          // be any change so we indicate a fixpoint.
          indicateOptimisticFixpoint();
        }
        // If the callee is known and can be used in IPO, we will update the state
        // based on the callee state in updateImpl.
        return;
      }

      // Argument index of a __kmpc_parallel_51 call whose (pointer-cast
      // stripped) value is recorded as the parallel region function below.
      const unsigned int WrapperFunctionArgNo = 6;
      RuntimeFunction RF = It->getSecond();
      switch (RF) {
      // All the functions we know are compatible with SPMD mode.
      case OMPRTL___kmpc_is_spmd_exec_mode:
      case OMPRTL___kmpc_distribute_static_fini:
      case OMPRTL___kmpc_for_static_fini:
      case OMPRTL___kmpc_global_thread_num:
      case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      case OMPRTL___kmpc_get_hardware_num_blocks:
      case OMPRTL___kmpc_single:
      case OMPRTL___kmpc_end_single:
      case OMPRTL___kmpc_master:
      case OMPRTL___kmpc_end_master:
      case OMPRTL___kmpc_barrier:
      case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:
      case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:
      case OMPRTL___kmpc_nvptx_end_reduce_nowait:
        break;
      case OMPRTL___kmpc_distribute_static_init_4:
      case OMPRTL___kmpc_distribute_static_init_4u:
      case OMPRTL___kmpc_distribute_static_init_8:
      case OMPRTL___kmpc_distribute_static_init_8u:
      case OMPRTL___kmpc_for_static_init_4:
      case OMPRTL___kmpc_for_static_init_4u:
      case OMPRTL___kmpc_for_static_init_8:
      case OMPRTL___kmpc_for_static_init_8u: {
        // Check the schedule and allow static schedule in SPMD mode.
        // A non-constant schedule argument is treated as value 0.
        unsigned ScheduleArgOpNo = 2;
        auto *ScheduleTypeCI =
            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
        unsigned ScheduleTypeVal =
            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
        switch (OMPScheduleType(ScheduleTypeVal)) {
        case OMPScheduleType::UnorderedStatic:
        case OMPScheduleType::UnorderedStaticChunked:
        case OMPScheduleType::OrderedDistribute:
        case OMPScheduleType::OrderedDistributeChunked:
          break;
        default:
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
          SPMDCompatibilityTracker.insert(&CB);
          break;
        };
      } break;
      case OMPRTL___kmpc_target_init:
        // Remember the kernel init/deinit calls of this kernel.
        KernelInitCB = &CB;
        break;
      case OMPRTL___kmpc_target_deinit:
        KernelDeinitCB = &CB;
        break;
      case OMPRTL___kmpc_parallel_51:
        if (auto *ParallelRegion = dyn_cast<Function>(
                CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
          ReachedKnownParallelRegions.insert(ParallelRegion);
          // Check nested parallelism: the region itself is invalid or reaches
          // (known or unknown) parallel regions.
          auto &FnAA = A.getAAFor<AAKernelInfo>(
              *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
          NestedParallelism |= !FnAA.getState().isValidState() ||
                               !FnAA.ReachedKnownParallelRegions.empty() ||
                               !FnAA.ReachedUnknownParallelRegions.empty();
          break;
        }
        // The condition above should usually get the parallel region function
        // pointer and record it. In the off chance it doesn't we assume the
        // worst.
        ReachedUnknownParallelRegions.insert(&CB);
        break;
      case OMPRTL___kmpc_omp_task:
        // We do not look into tasks right now, just give up.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        ReachedUnknownParallelRegions.insert(&CB);
        break;
      case OMPRTL___kmpc_alloc_shared:
      case OMPRTL___kmpc_free_shared:
        // Return without setting a fixpoint, to be resolved in updateImpl.
        return;
      default:
        // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
        // generally. However, they do not hide parallel regions.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      }
      // All other OpenMP runtime calls will not reach parallel regions so they
      // can be safely ignored for now. Since it is a known OpenMP runtime call we
      // have now modeled all effects and there is no need for any update.
      indicateOptimisticFixpoint();
    }

    /// Reached only for call sites initialize() left without a fixpoint.
    ///
    /// For a callee that is not a runtime function, this call site's state
    /// is replaced by the callee's AAKernelInfo state. For
    /// `__kmpc_alloc_shared` / `__kmpc_free_shared`, the call is recorded as
    /// SPMD-incompatible unless AAHeapToStack or AAHeapToShared assume it will
    /// be removed.
    ChangeStatus updateImpl(Attributor &A) override {
      // TODO: Once we have call site specific value information we can provide
      //       call site specific liveness information and then it makes
      //       sense to specialize attributes for call sites arguments instead of
      //       redirecting requests to the callee argument.
      Function *F = getAssociatedFunction();

      auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

      // If F is not a runtime function, propagate the AAKernelInfo of the callee.
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        const IRPosition &FnPos = IRPosition::function(*F);
        auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
        if (getState() == FnAA.getState())
          return ChangeStatus::UNCHANGED;
        getState() = FnAA.getState();
        return ChangeStatus::CHANGED;
      }

      // F is a runtime function that allocates or frees memory, check
      // AAHeapToStack and AAHeapToShared.
      KernelInfoState StateBefore = getState();
      assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
              It->getSecond() == OMPRTL___kmpc_free_shared) &&
             "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");

      CallBase &CB = cast<CallBase>(getAssociatedValue());

      auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
      auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
          *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);

      RuntimeFunction RF = It->getSecond();

      switch (RF) {
      // If neither HeapToStack nor HeapToShared assume the call is removed,
      // assume SPMD incompatibility.
      case OMPRTL___kmpc_alloc_shared:
        if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
            !HeapToSharedAA.isAssumedHeapToShared(CB))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      case OMPRTL___kmpc_free_shared:
        if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
            !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
          SPMDCompatibilityTracker.insert(&CB);
        break;
      default:
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
      }

      return StateBefore == getState() ? ChangeStatus::UNCHANGED
                                       : ChangeStatus::CHANGED;
    }
  };
  4051. struct AAFoldRuntimeCall
  4052. : public StateWrapper<BooleanState, AbstractAttribute> {
  4053. using Base = StateWrapper<BooleanState, AbstractAttribute>;
  4054. AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
  4055. /// Statistics are tracked as part of manifest for now.
  4056. void trackStatistics() const override {}
  4057. /// Create an abstract attribute biew for the position \p IRP.
  4058. static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
  4059. Attributor &A);
  4060. /// See AbstractAttribute::getName()
  4061. const std::string getName() const override { return "AAFoldRuntimeCall"; }
  4062. /// See AbstractAttribute::getIdAddr()
  4063. const char *getIdAddr() const override { return &ID; }
  4064. /// This function should return true if the type of the \p AA is
  4065. /// AAFoldRuntimeCall
  4066. static bool classof(const AbstractAttribute *AA) {
  4067. return (AA->getIdAddr() == &ID);
  4068. }
  4069. static const char ID;
  4070. };
  4071. struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  4072. AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
  4073. : AAFoldRuntimeCall(IRP, A) {}
  4074. /// See AbstractAttribute::getAsStr()
  4075. const std::string getAsStr() const override {
  4076. if (!isValidState())
  4077. return "<invalid>";
  4078. std::string Str("simplified value: ");
  4079. if (!SimplifiedValue)
  4080. return Str + std::string("none");
  4081. if (!*SimplifiedValue)
  4082. return Str + std::string("nullptr");
  4083. if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
  4084. return Str + std::to_string(CI->getSExtValue());
  4085. return Str + std::string("unknown");
  4086. }
  4087. void initialize(Attributor &A) override {
  4088. if (DisableOpenMPOptFolding)
  4089. indicatePessimisticFixpoint();
  4090. Function *Callee = getAssociatedFunction();
  4091. auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
  4092. const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
  4093. assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
  4094. "Expected a known OpenMP runtime function");
  4095. RFKind = It->getSecond();
  4096. CallBase &CB = cast<CallBase>(getAssociatedValue());
  4097. A.registerSimplificationCallback(
  4098. IRPosition::callsite_returned(CB),
  4099. [&](const IRPosition &IRP, const AbstractAttribute *AA,
  4100. bool &UsedAssumedInformation) -> std::optional<Value *> {
  4101. assert((isValidState() ||
  4102. (SimplifiedValue && *SimplifiedValue == nullptr)) &&
  4103. "Unexpected invalid state!");
  4104. if (!isAtFixpoint()) {
  4105. UsedAssumedInformation = true;
  4106. if (AA)
  4107. A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
  4108. }
  4109. return SimplifiedValue;
  4110. });
  4111. }
  4112. ChangeStatus updateImpl(Attributor &A) override {
  4113. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  4114. switch (RFKind) {
  4115. case OMPRTL___kmpc_is_spmd_exec_mode:
  4116. Changed |= foldIsSPMDExecMode(A);
  4117. break;
  4118. case OMPRTL___kmpc_parallel_level:
  4119. Changed |= foldParallelLevel(A);
  4120. break;
  4121. case OMPRTL___kmpc_get_hardware_num_threads_in_block:
  4122. Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
  4123. break;
  4124. case OMPRTL___kmpc_get_hardware_num_blocks:
  4125. Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
  4126. break;
  4127. default:
  4128. llvm_unreachable("Unhandled OpenMP runtime function!");
  4129. }
  4130. return Changed;
  4131. }
  4132. ChangeStatus manifest(Attributor &A) override {
  4133. ChangeStatus Changed = ChangeStatus::UNCHANGED;
  4134. if (SimplifiedValue && *SimplifiedValue) {
  4135. Instruction &I = *getCtxI();
  4136. A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
  4137. A.deleteAfterManifest(I);
  4138. CallBase *CB = dyn_cast<CallBase>(&I);
  4139. auto Remark = [&](OptimizationRemark OR) {
  4140. if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
  4141. return OR << "Replacing OpenMP runtime call "
  4142. << CB->getCalledFunction()->getName() << " with "
  4143. << ore::NV("FoldedValue", C->getZExtValue()) << ".";
  4144. return OR << "Replacing OpenMP runtime call "
  4145. << CB->getCalledFunction()->getName() << ".";
  4146. };
  4147. if (CB && EnableVerboseRemarks)
  4148. A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
  4149. LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
  4150. << **SimplifiedValue << "\n");
  4151. Changed = ChangeStatus::CHANGED;
  4152. }
  4153. return Changed;
  4154. }
  4155. ChangeStatus indicatePessimisticFixpoint() override {
  4156. SimplifiedValue = nullptr;
  4157. return AAFoldRuntimeCall::indicatePessimisticFixpoint();
  4158. }
  4159. private:
  4160. /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
  4161. ChangeStatus foldIsSPMDExecMode(Attributor &A) {
  4162. std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
  4163. unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
  4164. unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
  4165. auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
  4166. *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
  4167. if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
  4168. return indicatePessimisticFixpoint();
  4169. for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
  4170. auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
  4171. DepClassTy::REQUIRED);
  4172. if (!AA.isValidState()) {
  4173. SimplifiedValue = nullptr;
  4174. return indicatePessimisticFixpoint();
  4175. }
  4176. if (AA.SPMDCompatibilityTracker.isAssumed()) {
  4177. if (AA.SPMDCompatibilityTracker.isAtFixpoint())
  4178. ++KnownSPMDCount;
  4179. else
  4180. ++AssumedSPMDCount;
  4181. } else {
  4182. if (AA.SPMDCompatibilityTracker.isAtFixpoint())
  4183. ++KnownNonSPMDCount;
  4184. else
  4185. ++AssumedNonSPMDCount;
  4186. }
  4187. }
  4188. if ((AssumedSPMDCount + KnownSPMDCount) &&
  4189. (AssumedNonSPMDCount + KnownNonSPMDCount))
  4190. return indicatePessimisticFixpoint();
  4191. auto &Ctx = getAnchorValue().getContext();
  4192. if (KnownSPMDCount || AssumedSPMDCount) {
  4193. assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
  4194. "Expected only SPMD kernels!");
  4195. // All reaching kernels are in SPMD mode. Update all function calls to
  4196. // __kmpc_is_spmd_exec_mode to 1.
  4197. SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
  4198. } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
  4199. assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
  4200. "Expected only non-SPMD kernels!");
  4201. // All reaching kernels are in non-SPMD mode. Update all function
  4202. // calls to __kmpc_is_spmd_exec_mode to 0.
  4203. SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
  4204. } else {
  4205. // We have empty reaching kernels, therefore we cannot tell if the
  4206. // associated call site can be folded. At this moment, SimplifiedValue
  4207. // must be none.
  4208. assert(!SimplifiedValue && "SimplifiedValue should be none");
  4209. }
  4210. return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
  4211. : ChangeStatus::CHANGED;
  4212. }
  4213. /// Fold __kmpc_parallel_level into a constant if possible.
  4214. ChangeStatus foldParallelLevel(Attributor &A) {
  4215. std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
  4216. auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
  4217. *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
  4218. if (!CallerKernelInfoAA.ParallelLevels.isValidState())
  4219. return indicatePessimisticFixpoint();
  4220. if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
  4221. return indicatePessimisticFixpoint();
  4222. if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
  4223. assert(!SimplifiedValue &&
  4224. "SimplifiedValue should keep none at this point");
  4225. return ChangeStatus::UNCHANGED;
  4226. }
  4227. unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
  4228. unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
  4229. for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
  4230. auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
  4231. DepClassTy::REQUIRED);
  4232. if (!AA.SPMDCompatibilityTracker.isValidState())
  4233. return indicatePessimisticFixpoint();
  4234. if (AA.SPMDCompatibilityTracker.isAssumed()) {
  4235. if (AA.SPMDCompatibilityTracker.isAtFixpoint())
  4236. ++KnownSPMDCount;
  4237. else
  4238. ++AssumedSPMDCount;
  4239. } else {
  4240. if (AA.SPMDCompatibilityTracker.isAtFixpoint())
  4241. ++KnownNonSPMDCount;
  4242. else
  4243. ++AssumedNonSPMDCount;
  4244. }
  4245. }
  4246. if ((AssumedSPMDCount + KnownSPMDCount) &&
  4247. (AssumedNonSPMDCount + KnownNonSPMDCount))
  4248. return indicatePessimisticFixpoint();
  4249. auto &Ctx = getAnchorValue().getContext();
  4250. // If the caller can only be reached by SPMD kernel entries, the parallel
  4251. // level is 1. Similarly, if the caller can only be reached by non-SPMD
  4252. // kernel entries, it is 0.
  4253. if (AssumedSPMDCount || KnownSPMDCount) {
  4254. assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
  4255. "Expected only SPMD kernels!");
  4256. SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
  4257. } else {
  4258. assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
  4259. "Expected only non-SPMD kernels!");
  4260. SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
  4261. }
  4262. return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
  4263. : ChangeStatus::CHANGED;
  4264. }
  4265. ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
  4266. // Specialize only if all the calls agree with the attribute constant value
  4267. int32_t CurrentAttrValue = -1;
  4268. std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;
  4269. auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
  4270. *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
  4271. if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
  4272. return indicatePessimisticFixpoint();
  4273. // Iterate over the kernels that reach this function
  4274. for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
  4275. int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);
  4276. if (NextAttrVal == -1 ||
  4277. (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
  4278. return indicatePessimisticFixpoint();
  4279. CurrentAttrValue = NextAttrVal;
  4280. }
  4281. if (CurrentAttrValue != -1) {
  4282. auto &Ctx = getAnchorValue().getContext();
  4283. SimplifiedValue =
  4284. ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
  4285. }
  4286. return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
  4287. : ChangeStatus::CHANGED;
  4288. }
  4289. /// An optional value the associated value is assumed to fold to. That is, we
  4290. /// assume the associated value (which is a call) can be replaced by this
  4291. /// simplified value.
  4292. std::optional<Value *> SimplifiedValue;
  4293. /// The runtime function kind of the callee of the associated call site.
  4294. RuntimeFunction RFKind;
  4295. };
  4296. } // namespace
  4297. /// Register folding callsite
  4298. void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
  4299. auto &RFI = OMPInfoCache.RFIs[RF];
  4300. RFI.foreachUse(SCC, [&](Use &U, Function &F) {
  4301. CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
  4302. if (!CI)
  4303. return false;
  4304. A.getOrCreateAAFor<AAFoldRuntimeCall>(
  4305. IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
  4306. DepClassTy::NONE, /* ForceUpdate */ false,
  4307. /* UpdateAfterInit */ false);
  4308. return false;
  4309. });
  4310. }
  4311. void OpenMPOpt::registerAAs(bool IsModulePass) {
  4312. if (SCC.empty())
  4313. return;
  4314. if (IsModulePass) {
  4315. // Ensure we create the AAKernelInfo AAs first and without triggering an
  4316. // update. This will make sure we register all value simplification
  4317. // callbacks before any other AA has the chance to create an AAValueSimplify
  4318. // or similar.
  4319. auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
  4320. A.getOrCreateAAFor<AAKernelInfo>(
  4321. IRPosition::function(Kernel), /* QueryingAA */ nullptr,
  4322. DepClassTy::NONE, /* ForceUpdate */ false,
  4323. /* UpdateAfterInit */ false);
  4324. return false;
  4325. };
  4326. OMPInformationCache::RuntimeFunctionInfo &InitRFI =
  4327. OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
  4328. InitRFI.foreachUse(SCC, CreateKernelInfoCB);
  4329. registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
  4330. registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
  4331. registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
  4332. registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
  4333. }
  4334. // Create CallSite AA for all Getters.
  4335. if (DeduceICVValues) {
  4336. for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
  4337. auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
  4338. auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
  4339. auto CreateAA = [&](Use &U, Function &Caller) {
  4340. CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
  4341. if (!CI)
  4342. return false;
  4343. auto &CB = cast<CallBase>(*CI);
  4344. IRPosition CBPos = IRPosition::callsite_function(CB);
  4345. A.getOrCreateAAFor<AAICVTracker>(CBPos);
  4346. return false;
  4347. };
  4348. GetterRFI.foreachUse(SCC, CreateAA);
  4349. }
  4350. }
  4351. // Create an ExecutionDomain AA for every function and a HeapToStack AA for
  4352. // every function if there is a device kernel.
  4353. if (!isOpenMPDevice(M))
  4354. return;
  4355. for (auto *F : SCC) {
  4356. if (F->isDeclaration())
  4357. continue;
  4358. // We look at internal functions only on-demand but if any use is not a
  4359. // direct call or outside the current set of analyzed functions, we have
  4360. // to do it eagerly.
  4361. if (F->hasLocalLinkage()) {
  4362. if (llvm::all_of(F->uses(), [this](const Use &U) {
  4363. const auto *CB = dyn_cast<CallBase>(U.getUser());
  4364. return CB && CB->isCallee(&U) &&
  4365. A.isRunOn(const_cast<Function *>(CB->getCaller()));
  4366. }))
  4367. continue;
  4368. }
  4369. registerAAsForFunction(A, *F);
  4370. }
  4371. }
  4372. void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
  4373. if (!DisableOpenMPOptDeglobalization)
  4374. A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
  4375. A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
  4376. if (!DisableOpenMPOptDeglobalization)
  4377. A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
  4378. for (auto &I : instructions(F)) {
  4379. if (auto *LI = dyn_cast<LoadInst>(&I)) {
  4380. bool UsedAssumedInformation = false;
  4381. A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
  4382. UsedAssumedInformation, AA::Interprocedural);
  4383. continue;
  4384. }
  4385. if (auto *SI = dyn_cast<StoreInst>(&I)) {
  4386. A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
  4387. continue;
  4388. }
  4389. if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
  4390. if (II->getIntrinsicID() == Intrinsic::assume) {
  4391. A.getOrCreateAAFor<AAPotentialValues>(
  4392. IRPosition::value(*II->getArgOperand(0)));
  4393. continue;
  4394. }
  4395. }
  4396. }
  4397. }
  // Out-of-line definitions of the static ID members of the OpenMP-specific
  // abstract attributes. Attribute kinds are identified by the address of
  // these objects. For example, AAFoldRuntimeCall::getIdAddr() returns &ID,
  // and its classof() compares an AA's getIdAddr() against &ID. The stored
  // value (0) is presumably irrelevant.
  const char AAICVTracker::ID = 0;
  const char AAKernelInfo::ID = 0;
  const char AAExecutionDomain::ID = 0;
  const char AAHeapToShared::ID = 0;
  const char AAFoldRuntimeCall::ID = 0;
  4403. AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
  4404. Attributor &A) {
  4405. AAICVTracker *AA = nullptr;
  4406. switch (IRP.getPositionKind()) {
  4407. case IRPosition::IRP_INVALID:
  4408. case IRPosition::IRP_FLOAT:
  4409. case IRPosition::IRP_ARGUMENT:
  4410. case IRPosition::IRP_CALL_SITE_ARGUMENT:
  4411. llvm_unreachable("ICVTracker can only be created for function position!");
  4412. case IRPosition::IRP_RETURNED:
  4413. AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
  4414. break;
  4415. case IRPosition::IRP_CALL_SITE_RETURNED:
  4416. AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
  4417. break;
  4418. case IRPosition::IRP_CALL_SITE:
  4419. AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
  4420. break;
  4421. case IRPosition::IRP_FUNCTION:
  4422. AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
  4423. break;
  4424. }
  4425. return *AA;
  4426. }
  4427. AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
  4428. Attributor &A) {
  4429. AAExecutionDomainFunction *AA = nullptr;
  4430. switch (IRP.getPositionKind()) {
  4431. case IRPosition::IRP_INVALID:
  4432. case IRPosition::IRP_FLOAT:
  4433. case IRPosition::IRP_ARGUMENT:
  4434. case IRPosition::IRP_CALL_SITE_ARGUMENT:
  4435. case IRPosition::IRP_RETURNED:
  4436. case IRPosition::IRP_CALL_SITE_RETURNED:
  4437. case IRPosition::IRP_CALL_SITE:
  4438. llvm_unreachable(
  4439. "AAExecutionDomain can only be created for function position!");
  4440. case IRPosition::IRP_FUNCTION:
  4441. AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
  4442. break;
  4443. }
  4444. return *AA;
  4445. }
  4446. AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
  4447. Attributor &A) {
  4448. AAHeapToSharedFunction *AA = nullptr;
  4449. switch (IRP.getPositionKind()) {
  4450. case IRPosition::IRP_INVALID:
  4451. case IRPosition::IRP_FLOAT:
  4452. case IRPosition::IRP_ARGUMENT:
  4453. case IRPosition::IRP_CALL_SITE_ARGUMENT:
  4454. case IRPosition::IRP_RETURNED:
  4455. case IRPosition::IRP_CALL_SITE_RETURNED:
  4456. case IRPosition::IRP_CALL_SITE:
  4457. llvm_unreachable(
  4458. "AAHeapToShared can only be created for function position!");
  4459. case IRPosition::IRP_FUNCTION:
  4460. AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
  4461. break;
  4462. }
  4463. return *AA;
  4464. }
  4465. AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
  4466. Attributor &A) {
  4467. AAKernelInfo *AA = nullptr;
  4468. switch (IRP.getPositionKind()) {
  4469. case IRPosition::IRP_INVALID:
  4470. case IRPosition::IRP_FLOAT:
  4471. case IRPosition::IRP_ARGUMENT:
  4472. case IRPosition::IRP_RETURNED:
  4473. case IRPosition::IRP_CALL_SITE_RETURNED:
  4474. case IRPosition::IRP_CALL_SITE_ARGUMENT:
  4475. llvm_unreachable("KernelInfo can only be created for function position!");
  4476. case IRPosition::IRP_CALL_SITE:
  4477. AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
  4478. break;
  4479. case IRPosition::IRP_FUNCTION:
  4480. AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
  4481. break;
  4482. }
  4483. return *AA;
  4484. }
  4485. AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
  4486. Attributor &A) {
  4487. AAFoldRuntimeCall *AA = nullptr;
  4488. switch (IRP.getPositionKind()) {
  4489. case IRPosition::IRP_INVALID:
  4490. case IRPosition::IRP_FLOAT:
  4491. case IRPosition::IRP_ARGUMENT:
  4492. case IRPosition::IRP_RETURNED:
  4493. case IRPosition::IRP_FUNCTION:
  4494. case IRPosition::IRP_CALL_SITE:
  4495. case IRPosition::IRP_CALL_SITE_ARGUMENT:
  4496. llvm_unreachable("KernelInfo can only be created for call site position!");
  4497. case IRPosition::IRP_CALL_SITE_RETURNED:
  4498. AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
  4499. break;
  4500. }
  4501. return *AA;
  4502. }
  4503. PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
  4504. if (!containsOpenMP(M))
  4505. return PreservedAnalyses::all();
  4506. if (DisableOpenMPOptimizations)
  4507. return PreservedAnalyses::all();
  4508. FunctionAnalysisManager &FAM =
  4509. AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  4510. KernelSet Kernels = getDeviceKernels(M);
  4511. if (PrintModuleBeforeOptimizations)
  4512. LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);
  4513. auto IsCalled = [&](Function &F) {
  4514. if (Kernels.contains(&F))
  4515. return true;
  4516. for (const User *U : F.users())
  4517. if (!isa<BlockAddress>(U))
  4518. return true;
  4519. return false;
  4520. };
  4521. auto EmitRemark = [&](Function &F) {
  4522. auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
  4523. ORE.emit([&]() {
  4524. OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
  4525. return ORA << "Could not internalize function. "
  4526. << "Some optimizations may not be possible. [OMP140]";
  4527. });
  4528. };
  4529. // Create internal copies of each function if this is a kernel Module. This
  4530. // allows iterprocedural passes to see every call edge.
  4531. DenseMap<Function *, Function *> InternalizedMap;
  4532. if (isOpenMPDevice(M)) {
  4533. SmallPtrSet<Function *, 16> InternalizeFns;
  4534. for (Function &F : M)
  4535. if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
  4536. !DisableInternalization) {
  4537. if (Attributor::isInternalizable(F)) {
  4538. InternalizeFns.insert(&F);
  4539. } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
  4540. EmitRemark(F);
  4541. }
  4542. }
  4543. Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
  4544. }
  4545. // Look at every function in the Module unless it was internalized.
  4546. SetVector<Function *> Functions;
  4547. SmallVector<Function *, 16> SCC;
  4548. for (Function &F : M)
  4549. if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
  4550. SCC.push_back(&F);
  4551. Functions.insert(&F);
  4552. }
  4553. if (SCC.empty())
  4554. return PreservedAnalyses::all();
  4555. AnalysisGetter AG(FAM);
  4556. auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
  4557. return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  4558. };
  4559. BumpPtrAllocator Allocator;
  4560. CallGraphUpdater CGUpdater;
  4561. bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
  4562. LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
  4563. OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, Kernels,
  4564. PostLink);
  4565. unsigned MaxFixpointIterations =
  4566. (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
  4567. AttributorConfig AC(CGUpdater);
  4568. AC.DefaultInitializeLiveInternals = false;
  4569. AC.IsModulePass = true;
  4570. AC.RewriteSignatures = false;
  4571. AC.MaxFixpointIterations = MaxFixpointIterations;
  4572. AC.OREGetter = OREGetter;
  4573. AC.PassName = DEBUG_TYPE;
  4574. AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
  4575. Attributor A(Functions, InfoCache, AC);
  4576. OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  4577. bool Changed = OMPOpt.run(true);
  4578. // Optionally inline device functions for potentially better performance.
  4579. if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
  4580. for (Function &F : M)
  4581. if (!F.isDeclaration() && !Kernels.contains(&F) &&
  4582. !F.hasFnAttribute(Attribute::NoInline))
  4583. F.addFnAttr(Attribute::AlwaysInline);
  4584. if (PrintModuleAfterOptimizations)
  4585. LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
  4586. if (Changed)
  4587. return PreservedAnalyses::none();
  4588. return PreservedAnalyses::all();
  4589. }
  4590. PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
  4591. CGSCCAnalysisManager &AM,
  4592. LazyCallGraph &CG,
  4593. CGSCCUpdateResult &UR) {
  4594. if (!containsOpenMP(*C.begin()->getFunction().getParent()))
  4595. return PreservedAnalyses::all();
  4596. if (DisableOpenMPOptimizations)
  4597. return PreservedAnalyses::all();
  4598. SmallVector<Function *, 16> SCC;
  4599. // If there are kernels in the module, we have to run on all SCC's.
  4600. for (LazyCallGraph::Node &N : C) {
  4601. Function *Fn = &N.getFunction();
  4602. SCC.push_back(Fn);
  4603. }
  4604. if (SCC.empty())
  4605. return PreservedAnalyses::all();
  4606. Module &M = *C.begin()->getFunction().getParent();
  4607. if (PrintModuleBeforeOptimizations)
  4608. LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);
  4609. KernelSet Kernels = getDeviceKernels(M);
  4610. FunctionAnalysisManager &FAM =
  4611. AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
  4612. AnalysisGetter AG(FAM);
  4613. auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
  4614. return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  4615. };
  4616. BumpPtrAllocator Allocator;
  4617. CallGraphUpdater CGUpdater;
  4618. CGUpdater.initialize(CG, C, AM, UR);
  4619. bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
  4620. LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
  4621. SetVector<Function *> Functions(SCC.begin(), SCC.end());
  4622. OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
  4623. /*CGSCC*/ &Functions, Kernels, PostLink);
  4624. unsigned MaxFixpointIterations =
  4625. (isOpenMPDevice(M)) ? SetFixpointIterations : 32;
  4626. AttributorConfig AC(CGUpdater);
  4627. AC.DefaultInitializeLiveInternals = false;
  4628. AC.IsModulePass = false;
  4629. AC.RewriteSignatures = false;
  4630. AC.MaxFixpointIterations = MaxFixpointIterations;
  4631. AC.OREGetter = OREGetter;
  4632. AC.PassName = DEBUG_TYPE;
  4633. AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;
  4634. Attributor A(Functions, InfoCache, AC);
  4635. OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  4636. bool Changed = OMPOpt.run(false);
  4637. if (PrintModuleAfterOptimizations)
  4638. LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
  4639. if (Changed)
  4640. return PreservedAnalyses::none();
  4641. return PreservedAnalyses::all();
  4642. }
  4643. KernelSet llvm::omp::getDeviceKernels(Module &M) {
  4644. // TODO: Create a more cross-platform way of determining device kernels.
  4645. NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
  4646. KernelSet Kernels;
  4647. if (!MD)
  4648. return Kernels;
  4649. for (auto *Op : MD->operands()) {
  4650. if (Op->getNumOperands() < 2)
  4651. continue;
  4652. MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
  4653. if (!KindID || KindID->getString() != "kernel")
  4654. continue;
  4655. Function *KernelFn =
  4656. mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
  4657. if (!KernelFn)
  4658. continue;
  4659. ++NumOpenMPTargetRegionKernels;
  4660. Kernels.insert(KernelFn);
  4661. }
  4662. return Kernels;
  4663. }
  4664. bool llvm::omp::containsOpenMP(Module &M) {
  4665. Metadata *MD = M.getModuleFlag("openmp");
  4666. if (!MD)
  4667. return false;
  4668. return true;
  4669. }
  4670. bool llvm::omp::isOpenMPDevice(Module &M) {
  4671. Metadata *MD = M.getModuleFlag("openmp-device");
  4672. if (!MD)
  4673. return false;
  4674. return true;
  4675. }