zengine-old/test/co_context/include/co_context/detail/task_promise.h
2024-02-18 21:33:25 +08:00

221 lines
6.8 KiB
C++

#include <cassert>
#include <concepts>
#include <coroutine>
#include <cstdint>
#include <exception>
#include <memory>
#include <type_traits>
#include <utility>
namespace co_context {
// Forward declaration: the promise types below name task<T> as the return
// type of get_return_object() before task itself is defined.
template<typename T>
class task;
namespace detail {
// Forward declaration: task_final_awaiter constrains await_suspend() on
// task_promise_base<T> before that class template is defined.
template<typename T>
class task_promise_base;
/**
* @brief When current task<> finishes, resume its parent.
*/
template<typename T>
struct task_final_awaiter {
static constexpr bool await_ready() noexcept { return false; }
template<std::derived_from<task_promise_base<T>> Promise>
std::coroutine_handle<>
await_suspend(std::coroutine_handle<Promise> current) noexcept {
return current.promise().parent_coro;
}
// Won't be resumed anyway
constexpr void await_resume() const noexcept {}
};
/**
* @brief When current task<> finishes, resume its parent.
*/
template<>
struct task_final_awaiter<void> {
static constexpr bool await_ready() noexcept { return false; }
template<std::derived_from<task_promise_base<void>> Promise>
std::coroutine_handle<>
await_suspend(std::coroutine_handle<Promise> current) noexcept {
auto& promise = current.promise();
std::coroutine_handle<> continuation = promise.parent_coro;
if (promise.is_detached_flag == Promise::is_detached) {
current.destroy();
}
return continuation;
}
// Won't be resumed anyway
constexpr void await_resume() const noexcept {}
};
/**
 * @brief Promise behaviour shared by every task<> flavour.
 *
 * - initial_suspend: always suspends, so the coroutine does not start running
 *   its body when it is created.
 * - final_suspend: suspends and hands control to the parent coroutine through
 *   task_final_awaiter<T>.
 */
template<typename T>
class task_promise_base {
    // The final awaiter reads `parent_coro` directly.
    friend struct task_final_awaiter<T>;

  public:
    task_promise_base() noexcept = default;

    // A promise is neither copyable nor movable.
    task_promise_base(const task_promise_base &) = delete;
    task_promise_base(task_promise_base &&) = delete;
    task_promise_base &operator=(const task_promise_base &) = delete;
    task_promise_base &operator=(task_promise_base &&) = delete;

    constexpr std::suspend_always initial_suspend() noexcept {
        return std::suspend_always{};
    }

    constexpr task_final_awaiter<T> final_suspend() noexcept {
        return task_final_awaiter<T>{};
    }

    // Records the coroutine to continue with once this task has finished.
    void set_parent(std::coroutine_handle<> continuation) noexcept {
        this->parent_coro = continuation;
    }

  private:
    // Defaults to the no-op coroutine, so finishing without a parent set
    // resumes nothing.
    std::coroutine_handle<> parent_coro{std::noop_coroutine()};
};
/**
 * @brief Promise of a task<> that produces a value.
 *
 * The result slot is a tagged union: `state` records whether nothing, a `T`,
 * or a captured exception is currently stored. Every transition explicitly
 * constructs or destroys the corresponding union member, so no member is ever
 * assigned to or read from before its lifetime has started.
 *
 * @tparam T the type of the final result
 */
template<typename T>
class task_promise final : public task_promise_base<T> {
  public:
    // Starts empty: neither union member is alive yet.
    task_promise() noexcept : state(value_state::mono) {}

    // Destroys whichever union member is currently alive, if any.
    ~task_promise() {
        switch (state) {
            [[likely]] case value_state::value:
                std::destroy_at(std::addressof(value));
                break;
            case value_state::exception:
                std::destroy_at(std::addressof(exception_ptr));
                break;
            default:
                break;
        }
    }

    task<T> get_return_object() noexcept;

    /**
     * @brief Captures the in-flight exception so that result() can rethrow it.
     */
    void unhandled_exception() noexcept {
        // Defensive: if return_value() already stored a value before the
        // exception escaped, destroy it so it is neither leaked nor left
        // alive next to the exception in the same union.
        if (state == value_state::value) {
            std::destroy_at(std::addressof(value));
            state = value_state::mono;
        }
        // `exception_ptr` is an unconstructed union member at this point, so
        // it must be constructed in place. Assigning to it would run the
        // assignment operator on indeterminate storage.
        std::construct_at(
            std::addressof(exception_ptr), std::current_exception()
        );
        state = value_state::exception;
    }

    /**
     * @brief Stores the coroutine's result by constructing `T` in place.
     *
     * @param result the operand of `co_return`, perfectly forwarded into `T`.
     */
    template<typename Value>
        requires std::convertible_to<Value &&, T>
    void return_value(Value &&result
    ) noexcept(std::is_nothrow_constructible_v<T, Value &&>) {
        std::construct_at(std::addressof(value), std::forward<Value>(result));
        // Only flag the value as alive after construction has succeeded.
        state = value_state::value;
    }

    /**
     * @brief Lvalue overload: returns a reference to the stored value.
     *
     * @throws the captured exception, if the coroutine ended with one.
     */
    T &result() & {
        if (state == value_state::exception) [[unlikely]] {
            std::rethrow_exception(exception_ptr);
        }
        assert(state == value_state::value);
        return value;
    }

    /**
     * @brief Rvalue overload: returns an rvalue reference to the stored value
     *        so that the caller can move from it.
     *
     * @throws the captured exception, if the coroutine ended with one.
     */
    T &&result() && {
        if (state == value_state::exception) [[unlikely]] {
            std::rethrow_exception(exception_ptr);
        }
        assert(state == value_state::value);
        return std::move(value);
    }

  private:
    // mono: nothing stored; value: `value` is alive; exception:
    // `exception_ptr` is alive.
    enum class value_state : std::uint8_t { mono, value, exception };

    union {
        T value;
        std::exception_ptr exception_ptr;
    };

    value_state state;
};
/**
 * @brief Promise of a task<void>, which can additionally be flagged detached.
 *
 * `is_detached_flag` and `exception_ptr` share storage in one anonymous
 * union. The flag is compared with the `is_detached` sentinel to tell a
 * detached task apart from one that may hold a captured exception.
 *
 * NOTE(review): the code initialises the union through `is_detached_flag` and
 * later assigns, reads and destroys it through `exception_ptr`. This
 * presumably relies on std::exception_ptr being pointer-sized with an
 * all-zero null representation; confirm for every supported toolchain.
 */
template<>
class task_promise<void> final : public task_promise_base<void> {
// The final awaiter reads `is_detached_flag` and `is_detached`.
friend struct task_final_awaiter<void>;
// task<void> is defined elsewhere; presumably it sets `is_detached_flag` when
// a task is detached (verify against its definition).
friend class task<void>;
public:
// Starts with a zero flag, i.e. not equal to the `is_detached` sentinel.
task_promise() noexcept : is_detached_flag(0) {};
// Unless the flag equals `is_detached`, the storage is treated as an
// exception_ptr and destroyed as one.
~task_promise() noexcept {
if (is_detached_flag != is_detached) {
exception_ptr.~exception_ptr();
}
}
task<void> get_return_object() noexcept;
// Nothing to store for a void result.
constexpr void return_void() noexcept {}
// Not noexcept: when the flag equals `is_detached`, the current exception is
// rethrown out of this function; otherwise it is stored for result().
void unhandled_exception() {
if (is_detached_flag == is_detached) {
std::rethrow_exception(std::current_exception());
}
else {
exception_ptr = std::current_exception();
}
}
// Rethrows the stored exception if there is one; otherwise does nothing.
void result() const {
if (this->exception_ptr) [[unlikely]] {
std::rethrow_exception(this->exception_ptr);
}
}
private:
// Sentinel flag value (all bits set) that marks the task as detached.
inline static constexpr uintptr_t is_detached = -1ULL;
union {
uintptr_t is_detached_flag; // equals `is_detached` when the task is detached
std::exception_ptr exception_ptr; // exception captured by unhandled_exception()
};
}; // class task_promise<void>
/**
 * @brief Promise of a task<> that returns an lvalue reference.
 *
 * It stores the address of the referred-to object, not a copy, together with
 * an exception pointer for the failure case.
 *
 * @tparam T the referred-to type of the final result
 */
template<typename T>
class task_promise<T &> final : public task_promise_base<T &> {
  public:
    task_promise() noexcept = default;

    task<T &> get_return_object() noexcept;

    // Captures the in-flight exception so that result() can rethrow it.
    void unhandled_exception() noexcept {
        this->exception_ptr = std::current_exception();
    }

    // Remembers where the returned object lives; nothing is copied.
    void return_value(T &result) noexcept { value = std::addressof(result); }

    /**
     * @brief Returns the reference produced by the coroutine.
     *
     * @throws the captured exception, if the coroutine ended with one.
     */
    T &result() {
        if (exception_ptr) [[unlikely]] {
            std::rethrow_exception(exception_ptr);
        }
        // Same debug check as the primary template's result(): a result must
        // have been stored before it is read.
        assert(value != nullptr);
        return *value;
    }

  private:
    T *value = nullptr;  // address passed to return_value(), or null
    std::exception_ptr exception_ptr;
};
}
}