task_ut.cpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. #include "task.h"
  2. #include "task_group.h"
  3. #include "await_callback.h"
  4. #include <library/cpp/testing/unittest/registar.h>
  5. using namespace NActors;
  6. Y_UNIT_TEST_SUITE(Task) {
  7. TTask<void> SimpleReturnVoid() {
  8. co_return;
  9. }
  10. TTask<int> SimpleReturn42() {
  11. co_return 42;
  12. }
  13. Y_UNIT_TEST(SimpleVoidCoroutine) {
  14. bool finished = false;
  15. AwaitThenCallback(SimpleReturnVoid(), [&]() {
  16. finished = true;
  17. });
  18. UNIT_ASSERT(finished);
  19. }
  20. Y_UNIT_TEST(SimpleIntCoroutine) {
  21. std::optional<int> result;
  22. AwaitThenCallback(SimpleReturn42(), [&](int value) {
  23. result = value;
  24. });
  25. UNIT_ASSERT(result);
  26. UNIT_ASSERT_VALUES_EQUAL(*result, 42);
  27. }
  28. Y_UNIT_TEST(SimpleVoidWhenDone) {
  29. std::optional<TTaskResult<void>> result;
  30. AwaitThenCallback(SimpleReturnVoid().WhenDone(), [&](auto value) {
  31. result = std::move(value);
  32. });
  33. UNIT_ASSERT(result);
  34. result->Value();
  35. }
  36. Y_UNIT_TEST(SimpleIntWhenDone) {
  37. std::optional<TTaskResult<int>> result;
  38. AwaitThenCallback(SimpleReturn42().WhenDone(), [&](auto value) {
  39. result = std::move(value);
  40. });
  41. UNIT_ASSERT(result);
  42. UNIT_ASSERT_VALUES_EQUAL(result->Value(), 42);
  43. }
  44. template<class TCallback>
  45. TTask<int> CallTwice(TCallback callback) {
  46. int a = co_await callback();
  47. int b = co_await callback();
  48. co_return a + b;
  49. }
  50. Y_UNIT_TEST(NestedAwait) {
  51. auto task = CallTwice([]{
  52. return SimpleReturn42();
  53. });
  54. UNIT_ASSERT(task);
  55. std::optional<int> result;
  56. AwaitThenCallback(std::move(task), [&](int value) {
  57. result = value;
  58. });
  59. UNIT_ASSERT(result);
  60. UNIT_ASSERT_VALUES_EQUAL(*result, 84);
  61. }
  62. template<class T>
  63. struct TPauseState {
  64. std::coroutine_handle<> Next;
  65. std::optional<T> NextResult;
  66. ~TPauseState() {
  67. while (Next) {
  68. NextResult.reset();
  69. std::exchange(Next, {}).resume();
  70. }
  71. }
  72. struct TAwaiter {
  73. TPauseState* State;
  74. bool await_ready() const noexcept { return false; }
  75. void await_suspend(std::coroutine_handle<> c) const noexcept {
  76. State->Next = c;
  77. }
  78. T await_resume() const {
  79. if (!State->NextResult) {
  80. throw TTaskCancelled();
  81. } else {
  82. T result = std::move(*State->NextResult);
  83. State->NextResult.reset();
  84. return result;
  85. }
  86. }
  87. };
  88. auto Wait() {
  89. return TAwaiter{ this };
  90. }
  91. explicit operator bool() const {
  92. return bool(Next);
  93. }
  94. void Resume(T result) {
  95. Y_ABORT_UNLESS(Next && !Next.done());
  96. NextResult = result;
  97. std::exchange(Next, {}).resume();
  98. }
  99. void Cancel() {
  100. Y_ABORT_UNLESS(Next && !Next.done());
  101. NextResult.reset();
  102. std::exchange(Next, {}).resume();
  103. }
  104. };
  105. Y_UNIT_TEST(PauseResume) {
  106. TPauseState<int> state;
  107. auto task = CallTwice([&]{
  108. return state.Wait();
  109. });
  110. std::optional<int> result;
  111. AwaitThenCallback(std::move(task), [&](int value) {
  112. result = value;
  113. });
  114. UNIT_ASSERT(!result);
  115. UNIT_ASSERT(state);
  116. state.Resume(11);
  117. UNIT_ASSERT(!result);
  118. UNIT_ASSERT(state);
  119. state.Resume(22);
  120. UNIT_ASSERT(result);
  121. UNIT_ASSERT_VALUES_EQUAL(*result, 33);
  122. }
  123. Y_UNIT_TEST(PauseCancel) {
  124. TPauseState<int> state;
  125. auto task = CallTwice([&]{
  126. return state.Wait();
  127. });
  128. std::optional<int> result;
  129. AwaitThenCallback(std::move(task).WhenDone(), [&](TTaskResult<int>&& value) {
  130. try {
  131. result = value.Value();
  132. } catch (TTaskCancelled&) {
  133. // nothing
  134. }
  135. });
  136. UNIT_ASSERT(!result);
  137. UNIT_ASSERT(state);
  138. state.Resume(11);
  139. UNIT_ASSERT(!result);
  140. UNIT_ASSERT(state);
  141. state.Cancel();
  142. UNIT_ASSERT(!result);
  143. }
  144. Y_UNIT_TEST(GroupWithTwoSubTasks) {
  145. TPauseState<int> state1;
  146. TPauseState<int> state2;
  147. std::vector<int> results;
  148. auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> {
  149. TTaskGroup<int> group;
  150. group.AddTask(state1.Wait());
  151. group.AddTask(state2.Wait());
  152. int a = co_await group;
  153. results.push_back(a);
  154. int b = co_await group;
  155. results.push_back(b);
  156. co_return a + b;
  157. }(state1, state2, results);
  158. std::optional<int> result;
  159. AwaitThenCallback(std::move(task), [&](int value) {
  160. result = value;
  161. });
  162. // We must be waiting for both states
  163. UNIT_ASSERT(state1);
  164. UNIT_ASSERT(state2);
  165. state2.Resume(22);
  166. UNIT_ASSERT_VALUES_EQUAL(results.size(), 1u);
  167. UNIT_ASSERT_VALUES_EQUAL(results.at(0), 22);
  168. UNIT_ASSERT(!result);
  169. state1.Resume(11);
  170. UNIT_ASSERT_VALUES_EQUAL(results.size(), 2u);
  171. UNIT_ASSERT_VALUES_EQUAL(results.at(1), 11);
  172. UNIT_ASSERT(result);
  173. UNIT_ASSERT_VALUES_EQUAL(*result, 33);
  174. }
  175. Y_UNIT_TEST(GroupWithTwoSubTasksDetached) {
  176. TPauseState<int> state1;
  177. TPauseState<int> state2;
  178. std::vector<int> results;
  179. auto task = [](auto& state1, auto& state2, auto& results) -> TTask<int> {
  180. TTaskGroup<int> group;
  181. group.AddTask(state1.Wait());
  182. group.AddTask(state2.Wait());
  183. int a = co_await group;
  184. results.push_back(a);
  185. co_return a;
  186. }(state1, state2, results);
  187. std::optional<int> result;
  188. AwaitThenCallback(std::move(task), [&](int value) {
  189. result = value;
  190. });
  191. // We must be waiting for both states
  192. UNIT_ASSERT(state1);
  193. UNIT_ASSERT(state2);
  194. state2.Resume(22);
  195. UNIT_ASSERT_VALUES_EQUAL(results.size(), 1u);
  196. UNIT_ASSERT_VALUES_EQUAL(results.at(0), 22);
  197. UNIT_ASSERT(result);
  198. UNIT_ASSERT_VALUES_EQUAL(*result, 22);
  199. }
  200. Y_UNIT_TEST(GroupWithTwoSubTasksOneCancelled) {
  201. TPauseState<int> state1;
  202. TPauseState<int> state2;
  203. std::vector<int> results;
  204. auto task = [](auto& state1, auto& state2, auto& results) -> TTask<void> {
  205. TTaskGroup<int> group;
  206. group.AddTask(state1.Wait());
  207. group.AddTask(state2.Wait());
  208. for (int i = 0; i < 2; ++i) {
  209. try {
  210. results.push_back(co_await group);
  211. } catch (TTaskCancelled&) {
  212. results.push_back(-1);
  213. }
  214. }
  215. }(state1, state2, results);
  216. bool finished = false;
  217. AwaitThenCallback(std::move(task), [&]() {
  218. finished = true;
  219. });
  220. UNIT_ASSERT(state1);
  221. UNIT_ASSERT(state2);
  222. state2.Cancel();
  223. UNIT_ASSERT_VALUES_EQUAL(results.size(), 1u);
  224. UNIT_ASSERT_VALUES_EQUAL(results.at(0), -1);
  225. UNIT_ASSERT(!finished);
  226. state1.Resume(11);
  227. UNIT_ASSERT_VALUES_EQUAL(results.size(), 2u);
  228. UNIT_ASSERT_VALUES_EQUAL(results.at(1), 11);
  229. UNIT_ASSERT(finished);
  230. }
  231. } // Y_UNIT_TEST_SUITE(Task)