// thread_helper.h
#pragma once

#include <util/thread/pool.h>
#include <util/generic/utility.h>
#include <util/generic/yexception.h>
#include <util/system/info.h>
#include <util/system/condvar.h>
#include <util/system/mutex.h>
#include <util/stream/output.h>

#include <library/cpp/deprecated/atomic/atomic.h>

#include <functional>
#include <cstdlib>
#include <utility>
  12. class TMtpQueueHelper {
  13. public:
  14. TMtpQueueHelper() {
  15. SetThreadCount(NSystemInfo::CachedNumberOfCpus());
  16. }
  17. IThreadPool* Get() {
  18. return q.Get();
  19. }
  20. size_t GetThreadCount() {
  21. return ThreadCount;
  22. }
  23. void SetThreadCount(size_t threads) {
  24. ThreadCount = threads;
  25. q = CreateThreadPool(ThreadCount);
  26. }
  27. static TMtpQueueHelper& Instance();
  28. private:
  29. size_t ThreadCount;
  30. TAutoPtr<IThreadPool> q;
  31. };
  32. namespace NYmp {
  33. inline void SetThreadCount(size_t threads) {
  34. TMtpQueueHelper::Instance().SetThreadCount(threads);
  35. }
  36. inline size_t GetThreadCount() {
  37. return TMtpQueueHelper::Instance().GetThreadCount();
  38. }
  39. template <typename T>
  40. inline void ParallelForStaticChunk(T begin, T end, size_t chunkSize, std::function<void(T)> func) {
  41. chunkSize = Max<size_t>(chunkSize, 1);
  42. size_t threadCount = TMtpQueueHelper::Instance().GetThreadCount();
  43. IThreadPool* queue = TMtpQueueHelper::Instance().Get();
  44. TCondVar cv;
  45. TMutex mutex;
  46. TAtomic counter = threadCount;
  47. std::exception_ptr err;
  48. for (size_t i = 0; i < threadCount; ++i) {
  49. queue->SafeAddFunc([&cv, &counter, &mutex, &func, i, begin, end, chunkSize, threadCount, &err]() {
  50. try {
  51. T currentChunkStart = begin + static_cast<decltype(T() - T())>(i * chunkSize);
  52. while (currentChunkStart < end) {
  53. T currentChunkEnd = Min<T>(end, currentChunkStart + chunkSize);
  54. for (T val = currentChunkStart; val < currentChunkEnd; ++val) {
  55. func(val);
  56. }
  57. currentChunkStart += chunkSize * threadCount;
  58. }
  59. } catch (...) {
  60. with_lock (mutex) {
  61. err = std::current_exception();
  62. }
  63. }
  64. with_lock (mutex) {
  65. if (AtomicDecrement(counter) == 0) {
  66. //last one
  67. cv.Signal();
  68. }
  69. }
  70. });
  71. }
  72. with_lock (mutex) {
  73. while (AtomicGet(counter) > 0) {
  74. cv.WaitI(mutex);
  75. }
  76. }
  77. if (err) {
  78. std::rethrow_exception(err);
  79. }
  80. }
  81. template <typename T>
  82. inline void ParallelForStaticAutoChunk(T begin, T end, std::function<void(T)> func) {
  83. const size_t taskSize = end - begin;
  84. const size_t threadCount = TMtpQueueHelper::Instance().GetThreadCount();
  85. ParallelForStaticChunk(begin, end, (taskSize + threadCount - 1) / threadCount, func);
  86. }
  87. }