Parallel.h 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. #pragma once
  2. #ifdef __GNUC__
  3. #pragma GCC diagnostic push
  4. #pragma GCC diagnostic ignored "-Wunused-parameter"
  5. #endif
  6. //===- llvm/Support/Parallel.h - Parallel algorithms ----------------------===//
  7. //
  8. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  9. // See https://llvm.org/LICENSE.txt for license information.
  10. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  11. //
  12. //===----------------------------------------------------------------------===//
  13. #ifndef LLVM_SUPPORT_PARALLEL_H
  14. #define LLVM_SUPPORT_PARALLEL_H
  #include "llvm/ADT/STLExtras.h"
  #include "llvm/Config/llvm-config.h"
  #include "llvm/Support/Error.h"
  #include "llvm/Support/MathExtras.h"
  #include "llvm/Support/Threading.h"
  #include <algorithm>
  #include <cassert>
  #include <condition_variable>
  #include <cstddef>
  #include <cstdint>
  #include <functional>
  #include <iterator>
  #include <mutex>
  #include <utility>
  #include <vector>
  24. namespace llvm {
  25. namespace parallel {
  26. // Strategy for the default executor used by the parallel routines provided by
  27. // this file. It defaults to using all hardware threads and should be
  28. // initialized before the first use of parallel routines.
  29. extern ThreadPoolStrategy strategy;
  // getThreadIndex(): reads a per-thread index maintained by the parallel
  // runtime (presumably the calling worker's slot in the executor -- see the
  // out-of-line definitions). Without LLVM_ENABLE_THREADS it is always 0.
  #if !LLVM_ENABLE_THREADS
  inline unsigned getThreadIndex() { return 0; }
  #elif defined(_WIN32)
  // Windows Native TLS cannot expose a thread_local variable directly to a
  // different DLL, so only an out-of-line accessor is declared here.
  unsigned getThreadIndex();
  #else
  // Don't access this directly, use the getThreadIndex wrapper.
  extern thread_local unsigned threadIndex;
  inline unsigned getThreadIndex() { return threadIndex; }
  #endif
  43. namespace detail {
  /// A mutex-protected counter that callers can block on until it reaches zero.
  ///
  /// inc() and dec() adjust the outstanding count; the dec() that brings it to
  /// zero wakes every thread waiting in sync(). The destructor asserts that the
  /// count is back at zero.
  class Latch {
    uint32_t Pending;
    mutable std::mutex Guard;
    mutable std::condition_variable Drained;

  public:
    explicit Latch(uint32_t Count = 0) : Pending(Count) {}

    ~Latch() {
      // At minimum, sync() must have been called (the count has drained).
      assert(Pending == 0);
    }

    /// Record one more outstanding unit of work.
    void inc() {
      std::unique_lock<std::mutex> Lock(Guard);
      Pending += 1;
    }

    /// Retire one unit of work and wake all waiters once none remain.
    void dec() {
      std::unique_lock<std::mutex> Lock(Guard);
      Pending -= 1;
      if (Pending == 0)
        Drained.notify_all();
    }

    /// Block the caller until the count is zero.
    void sync() const {
      std::unique_lock<std::mutex> Lock(Guard);
      while (Pending != 0)
        Drained.wait(Lock);
    }
  };
  68. } // namespace detail
  69. class TaskGroup {
  70. detail::Latch L;
  71. bool Parallel;
  72. public:
  73. TaskGroup();
  74. ~TaskGroup();
  75. // Spawn a task, but does not wait for it to finish.
  76. void spawn(std::function<void()> f);
  77. // Similar to spawn, but execute the task immediately when ThreadsRequested ==
  78. // 1. The difference is to give the following pattern a more intuitive order
  79. // when single threading is requested.
  80. //
  81. // for (size_t begin = 0, i = 0, taskSize = 0;;) {
  82. // taskSize += ...
  83. // bool done = ++i == end;
  84. // if (done || taskSize >= taskSizeLimit) {
  85. // tg.execute([=] { fn(begin, i); });
  86. // if (done)
  87. // break;
  88. // begin = i;
  89. // taskSize = 0;
  90. // }
  91. // }
  92. void execute(std::function<void()> f);
  93. void sync() const { L.sync(); }
  94. };
  95. namespace detail {
  96. #if LLVM_ENABLE_THREADS
  // parallel_quick_sort() sorts any range shorter than this sequentially.
  constexpr ptrdiff_t MinParallelSize = 1024;
  /// Pick the median of the first element, the middle element, and the last
  /// element (End - 1, i.e. inclusive) of the non-empty range [Start, End)
  /// under \p Comp, and return an iterator to it.
  template <class RandomAccessIterator, class Comparator>
  RandomAccessIterator medianOf3(RandomAccessIterator Start,
                                 RandomAccessIterator End,
                                 const Comparator &Comp) {
    RandomAccessIterator Mid = Start + std::distance(Start, End) / 2;
    RandomAccessIterator Last = End - 1;
    if (Comp(*Start, *Last)) {
      // *Start sorts before *Last: Last is the median unless *Mid sorts
      // before it, in which case it is the larger of *Start and *Mid.
      if (!Comp(*Mid, *Last))
        return Last;
      return Comp(*Start, *Mid) ? Mid : Start;
    }
    // *Last does not sort after *Start: Start is the median unless *Mid
    // sorts before it, in which case it is the larger of *Last and *Mid.
    if (!Comp(*Mid, *Start))
      return Start;
    return Comp(*Last, *Mid) ? Mid : Last;
  }
  110. template <class RandomAccessIterator, class Comparator>
  111. void parallel_quick_sort(RandomAccessIterator Start, RandomAccessIterator End,
  112. const Comparator &Comp, TaskGroup &TG, size_t Depth) {
  113. // Do a sequential sort for small inputs.
  114. if (std::distance(Start, End) < detail::MinParallelSize || Depth == 0) {
  115. llvm::sort(Start, End, Comp);
  116. return;
  117. }
  118. // Partition.
  119. auto Pivot = medianOf3(Start, End, Comp);
  120. // Move Pivot to End.
  121. std::swap(*(End - 1), *Pivot);
  122. Pivot = std::partition(Start, End - 1, [&Comp, End](decltype(*Start) V) {
  123. return Comp(V, *(End - 1));
  124. });
  125. // Move Pivot to middle of partition.
  126. std::swap(*Pivot, *(End - 1));
  127. // Recurse.
  128. TG.spawn([=, &Comp, &TG] {
  129. parallel_quick_sort(Start, Pivot, Comp, TG, Depth - 1);
  130. });
  131. parallel_quick_sort(Pivot + 1, End, Comp, TG, Depth - 1);
  132. }
  133. template <class RandomAccessIterator, class Comparator>
  134. void parallel_sort(RandomAccessIterator Start, RandomAccessIterator End,
  135. const Comparator &Comp) {
  136. TaskGroup TG;
  137. parallel_quick_sort(Start, End, Comp, TG,
  138. llvm::Log2_64(std::distance(Start, End)) + 1);
  139. }
  // Upper bound on the number of tasks parallel_transform_reduce() spawns.
  // Spawning through a TaskGroup has relatively high overhead, so large inputs
  // are cut into at most this many chunks. The value is arbitrary and does not
  // yet take the number of available cores into account.
  enum { MaxTasksPerGroup = 1024 };
  145. template <class IterTy, class ResultTy, class ReduceFuncTy,
  146. class TransformFuncTy>
  147. ResultTy parallel_transform_reduce(IterTy Begin, IterTy End, ResultTy Init,
  148. ReduceFuncTy Reduce,
  149. TransformFuncTy Transform) {
  150. // Limit the number of tasks to MaxTasksPerGroup to limit job scheduling
  151. // overhead on large inputs.
  152. size_t NumInputs = std::distance(Begin, End);
  153. if (NumInputs == 0)
  154. return std::move(Init);
  155. size_t NumTasks = std::min(static_cast<size_t>(MaxTasksPerGroup), NumInputs);
  156. std::vector<ResultTy> Results(NumTasks, Init);
  157. {
  158. // Each task processes either TaskSize or TaskSize+1 inputs. Any inputs
  159. // remaining after dividing them equally amongst tasks are distributed as
  160. // one extra input over the first tasks.
  161. TaskGroup TG;
  162. size_t TaskSize = NumInputs / NumTasks;
  163. size_t RemainingInputs = NumInputs % NumTasks;
  164. IterTy TBegin = Begin;
  165. for (size_t TaskId = 0; TaskId < NumTasks; ++TaskId) {
  166. IterTy TEnd = TBegin + TaskSize + (TaskId < RemainingInputs ? 1 : 0);
  167. TG.spawn([=, &Transform, &Reduce, &Results] {
  168. // Reduce the result of transformation eagerly within each task.
  169. ResultTy R = Init;
  170. for (IterTy It = TBegin; It != TEnd; ++It)
  171. R = Reduce(R, Transform(*It));
  172. Results[TaskId] = R;
  173. });
  174. TBegin = TEnd;
  175. }
  176. assert(TBegin == End);
  177. }
  178. // Do a final reduction. There are at most 1024 tasks, so this only adds
  179. // constant single-threaded overhead for large inputs. Hopefully most
  180. // reductions are cheaper than the transformation.
  181. ResultTy FinalResult = std::move(Results.front());
  182. for (ResultTy &PartialResult :
  183. MutableArrayRef(Results.data() + 1, Results.size() - 1))
  184. FinalResult = Reduce(FinalResult, std::move(PartialResult));
  185. return std::move(FinalResult);
  186. }
  187. #endif
  188. } // namespace detail
  189. } // namespace parallel
  190. template <class RandomAccessIterator,
  191. class Comparator = std::less<
  192. typename std::iterator_traits<RandomAccessIterator>::value_type>>
  193. void parallelSort(RandomAccessIterator Start, RandomAccessIterator End,
  194. const Comparator &Comp = Comparator()) {
  195. #if LLVM_ENABLE_THREADS
  196. if (parallel::strategy.ThreadsRequested != 1) {
  197. parallel::detail::parallel_sort(Start, End, Comp);
  198. return;
  199. }
  200. #endif
  201. llvm::sort(Start, End, Comp);
  202. }
  203. void parallelFor(size_t Begin, size_t End, function_ref<void(size_t)> Fn);
  204. template <class IterTy, class FuncTy>
  205. void parallelForEach(IterTy Begin, IterTy End, FuncTy Fn) {
  206. parallelFor(0, End - Begin, [&](size_t I) { Fn(Begin[I]); });
  207. }
  /// Apply \p Transform to every element of [Begin, End) and fold the results
  /// with \p Reduce, starting from \p Init.
  ///
  /// When built with LLVM_ENABLE_THREADS and parallel::strategy.ThreadsRequested
  /// != 1, the range is split into chunks that are reduced concurrently. Each
  /// chunk is seeded with a copy of \p Init, so \p Init should be an identity
  /// for \p Reduce, and \p Reduce should be associative. Otherwise the range is
  /// folded sequentially, in order, on the calling thread. \p Reduce must accept
  /// an rvalue accumulator as its first argument.
  template <class IterTy, class ResultTy, class ReduceFuncTy,
            class TransformFuncTy>
  ResultTy parallelTransformReduce(IterTy Begin, IterTy End, ResultTy Init,
                                   ReduceFuncTy Reduce,
                                   TransformFuncTy Transform) {
  #if LLVM_ENABLE_THREADS
    if (parallel::strategy.ThreadsRequested != 1) {
      // These are by-value locals that are not used again here; hand them
      // over instead of copying them a second time.
      return parallel::detail::parallel_transform_reduce(
          Begin, End, std::move(Init), std::move(Reduce), std::move(Transform));
    }
  #endif
    for (IterTy I = Begin; I != End; ++I)
      Init = Reduce(std::move(Init), Transform(*I));
    return Init; // By-value parameters are implicitly moved on return.
  }
  223. // Range wrappers.
  /// Range overload of parallelSort(): sorts the elements of \p R in place with
  /// \p Comp by forwarding std::begin(R) / std::end(R) to the iterator overload.
  ///
  /// The default comparator is std::less of the range's element value_type,
  /// the same default the iterator overload uses. It is derived through
  /// std::declval<RangeTy &>() instead of value-initializing RangeTy. That keeps
  /// it well-formed when RangeTy is deduced as an lvalue reference (e.g.
  /// `parallelSort(Vec)`) and for ranges without a default constructor.
  template <class RangeTy,
            class Comparator = std::less<typename std::iterator_traits<
                decltype(std::begin(std::declval<RangeTy &>()))>::value_type>>
  void parallelSort(RangeTy &&R, const Comparator &Comp = Comparator()) {
    parallelSort(std::begin(R), std::end(R), Comp);
  }
  /// Range overload of parallelForEach(): invokes \p Fn on every element of
  /// \p R.
  template <class RangeTy, class FuncTy>
  void parallelForEach(RangeTy &&R, FuncTy Fn) {
    auto First = std::begin(R);
    auto Last = std::end(R);
    parallelForEach(First, Last, Fn);
  }
  /// Range overload of parallelTransformReduce(): transforms every element of
  /// \p R and folds the results with \p Reduce, starting from \p Init.
  template <class RangeTy, class ResultTy, class ReduceFuncTy,
            class TransformFuncTy>
  ResultTy parallelTransformReduce(RangeTy &&R, ResultTy Init,
                                   ReduceFuncTy Reduce,
                                   TransformFuncTy Transform) {
    auto First = std::begin(R);
    auto Last = std::end(R);
    return parallelTransformReduce(First, Last, Init, Reduce, Transform);
  }
  241. // Parallel for-each, but with error handling.
  242. template <class RangeTy, class FuncTy>
  243. Error parallelForEachError(RangeTy &&R, FuncTy Fn) {
  244. // The transform_reduce algorithm requires that the initial value be copyable.
  245. // Error objects are uncopyable. We only need to copy initial success values,
  246. // so work around this mismatch via the C API. The C API represents success
  247. // values with a null pointer. The joinErrors discards null values and joins
  248. // multiple errors into an ErrorList.
  249. return unwrap(parallelTransformReduce(
  250. std::begin(R), std::end(R), wrap(Error::success()),
  251. [](LLVMErrorRef Lhs, LLVMErrorRef Rhs) {
  252. return wrap(joinErrors(unwrap(Lhs), unwrap(Rhs)));
  253. },
  254. [&Fn](auto &&V) { return wrap(Fn(V)); }));
  255. }
  256. } // namespace llvm
  257. #endif // LLVM_SUPPORT_PARALLEL_H
  258. #ifdef __GNUC__
  259. #pragma GCC diagnostic pop
  260. #endif