await_callback.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  #include <concepts>
  #include <coroutine>
  #include <exception>
  #include <utility>
  4. namespace NActors {
  5. namespace NDetail {
  6. template<class TAwaitable>
  7. decltype(auto) GetAwaiter(TAwaitable&& awaitable) {
  8. if constexpr (requires { ((TAwaitable&&) awaitable).operator co_await(); }) {
  9. return ((TAwaitable&&) awaitable).operator co_await();
  10. } else if constexpr (requires { operator co_await((TAwaitable&&) awaitable); }) {
  11. return operator co_await((TAwaitable&&) awaitable);
  12. } else {
  13. return ((TAwaitable&&) awaitable);
  14. }
  15. }
  16. template<class TAwaitable>
  17. using TAwaitResult = decltype(GetAwaiter(std::declval<TAwaitable>()).await_resume());
  18. template<class TCallback, class TResult>
  19. class TCallbackResult {
  20. public:
  21. TCallbackResult(TCallback& callback)
  22. : Callback(callback)
  23. {}
  24. template<class TRealResult>
  25. void return_value(TRealResult&& result) noexcept {
  26. Callback(std::forward<TRealResult>(result));
  27. }
  28. private:
  29. TCallback& Callback;
  30. };
  31. template<class TCallback>
  32. class TCallbackResult<TCallback, void> {
  33. public:
  34. TCallbackResult(TCallback& callback)
  35. : Callback(callback)
  36. {}
  37. void return_void() noexcept {
  38. Callback();
  39. }
  40. private:
  41. TCallback& Callback;
  42. };
  43. template<class TAwaitable, class TCallback>
  44. class TAwaitThenCallbackPromise
  45. : public TCallbackResult<TCallback, TAwaitResult<TAwaitable>>
  46. {
  47. public:
  48. using THandle = std::coroutine_handle<TAwaitThenCallbackPromise<TAwaitable, TCallback>>;
  49. TAwaitThenCallbackPromise(TAwaitable&, TCallback& callback)
  50. : TCallbackResult<TCallback, TAwaitResult<TAwaitable>>(callback)
  51. {}
  52. THandle get_return_object() noexcept {
  53. return THandle::from_promise(*this);
  54. }
  55. static auto initial_suspend() noexcept { return std::suspend_never{}; }
  56. static auto final_suspend() noexcept { return std::suspend_never{}; }
  57. void unhandled_exception() noexcept {
  58. std::terminate();
  59. }
  60. };
  61. template<class TAwaitable, class TCallback>
  62. class TAwaitThenCallback {
  63. public:
  64. using promise_type = TAwaitThenCallbackPromise<TAwaitable, TCallback>;
  65. using THandle = typename promise_type::THandle;
  66. TAwaitThenCallback(THandle) noexcept {}
  67. };
  68. } // namespace NDetail
  69. /**
  70. * Awaits the awaitable and calls callback with the result.
  71. *
  72. * Note: program terminates if awaitable or callback throw an exception.
  73. */
  74. template<class TAwaitable, class TCallback>
  75. NDetail::TAwaitThenCallback<TAwaitable, TCallback> AwaitThenCallback(TAwaitable awaitable, TCallback) {
  76. // Note: underlying promise takes callback argument address and calls it when we return
  77. co_return co_await std::move(awaitable);
  78. }
  79. } // namespace NActors