// Listing metadata: 221 lines, 6.8 KiB, C++.
#include <cassert>
#include <concepts>
#include <coroutine>
#include <cstdint>
#include <exception>
#include <memory>
#include <type_traits>
#include <utility>
|
|
namespace co_context {
|
|
/// Forward declaration of the coroutine type that the promise types in
/// `detail` hand out from `get_return_object()`.
template <class T>
class task;
|
|
namespace detail {
|
|
/// Forward declaration so task_final_awaiter can constrain await_suspend()
/// to promises derived from task_promise_base<T>.
template <class T>
class task_promise_base;
|
|
/**
|
|
* @brief When current task<> finishes, resume its parent.
|
|
*/
|
|
template<typename T>
|
|
struct task_final_awaiter {
|
|
static constexpr bool await_ready() noexcept { return false; }
|
|
|
|
template<std::derived_from<task_promise_base<T>> Promise>
|
|
std::coroutine_handle<>
|
|
await_suspend(std::coroutine_handle<Promise> current) noexcept {
|
|
return current.promise().parent_coro;
|
|
}
|
|
|
|
// Won't be resumed anyway
|
|
constexpr void await_resume() const noexcept {}
|
|
};
|
|
|
|
/**
|
|
* @brief When current task<> finishes, resume its parent.
|
|
*/
|
|
template<>
|
|
struct task_final_awaiter<void> {
|
|
static constexpr bool await_ready() noexcept { return false; }
|
|
|
|
template<std::derived_from<task_promise_base<void>> Promise>
|
|
std::coroutine_handle<>
|
|
await_suspend(std::coroutine_handle<Promise> current) noexcept {
|
|
auto& promise = current.promise();
|
|
|
|
std::coroutine_handle<> continuation = promise.parent_coro;
|
|
|
|
if (promise.is_detached_flag == Promise::is_detached) {
|
|
current.destroy();
|
|
}
|
|
|
|
return continuation;
|
|
}
|
|
|
|
// Won't be resumed anyway
|
|
constexpr void await_resume() const noexcept {}
|
|
};
|
|
|
|
/**
|
|
* @brief Define the behavior of all tasks.
|
|
*
|
|
* final_suspend: yes, and return to parent
|
|
*/
|
|
template<typename T>
|
|
class task_promise_base {
|
|
friend struct task_final_awaiter<T>;
|
|
|
|
public:
|
|
task_promise_base() noexcept = default;
|
|
|
|
inline constexpr std::suspend_always initial_suspend() noexcept {
|
|
return {};
|
|
}
|
|
|
|
inline constexpr task_final_awaiter<T> final_suspend() noexcept {
|
|
return {};
|
|
}
|
|
|
|
inline void set_parent(std::coroutine_handle<> continuation) noexcept {
|
|
parent_coro = continuation;
|
|
}
|
|
|
|
task_promise_base(const task_promise_base&) = delete;
|
|
task_promise_base(task_promise_base&&) = delete;
|
|
task_promise_base& operator=(const task_promise_base&) = delete;
|
|
task_promise_base& operator=(task_promise_base&&) = delete;
|
|
|
|
private:
|
|
std::coroutine_handle<> parent_coro{ std::noop_coroutine() };
|
|
};
|
|
|
|
/**
|
|
* @brief task<> with a return value
|
|
*
|
|
* @tparam T the type of the final result
|
|
*/
|
|
template<typename T>
|
|
class task_promise final : public task_promise_base<T> {
|
|
public:
|
|
task_promise() noexcept : state(value_state::mono) {};
|
|
|
|
~task_promise() {
|
|
switch (state) {
|
|
[[likely]] case value_state::value:
|
|
value.~T();
|
|
break;
|
|
case value_state::exception:
|
|
exception_ptr.~exception_ptr();
|
|
break;
|
|
default: break;
|
|
}
|
|
};
|
|
|
|
task<T> get_return_object() noexcept;
|
|
|
|
void unhandled_exception() noexcept {
|
|
exception_ptr = std::current_exception();
|
|
state = value_state::exception;
|
|
}
|
|
|
|
template<typename Value>
|
|
requires std::convertible_to<Value&&, T>
|
|
void return_value(Value&& result
|
|
) noexcept(std::is_nothrow_constructible_v<T, Value&&>) {
|
|
std::construct_at(
|
|
std::addressof(value), std::forward<Value>(result)
|
|
);
|
|
state = value_state::value;
|
|
}
|
|
|
|
// get the lvalue ref
|
|
T& result()& {
|
|
if (state == value_state::exception) [[unlikely]] {
|
|
std::rethrow_exception(exception_ptr);
|
|
}
|
|
assert(state == value_state::value);
|
|
return value;
|
|
}
|
|
|
|
// get the prvalue
|
|
T&& result()&& {
|
|
if (state == value_state::exception) [[unlikely]] {
|
|
std::rethrow_exception(exception_ptr);
|
|
}
|
|
assert(state == value_state::value);
|
|
return std::move(value);
|
|
}
|
|
|
|
private:
|
|
union {
|
|
T value;
|
|
std::exception_ptr exception_ptr;
|
|
};
|
|
enum class value_state : uint8_t { mono, value, exception } state;
|
|
};
|
|
|
|
template<>
|
|
class task_promise<void> final : public task_promise_base<void> {
|
|
friend struct task_final_awaiter<void>;
|
|
friend class task<void>;
|
|
|
|
public:
|
|
task_promise() noexcept : is_detached_flag(0) {};
|
|
|
|
~task_promise() noexcept {
|
|
if (is_detached_flag != is_detached) {
|
|
exception_ptr.~exception_ptr();
|
|
}
|
|
}
|
|
|
|
task<void> get_return_object() noexcept;
|
|
|
|
constexpr void return_void() noexcept {}
|
|
|
|
void unhandled_exception() {
|
|
if (is_detached_flag == is_detached) {
|
|
std::rethrow_exception(std::current_exception());
|
|
}
|
|
else {
|
|
exception_ptr = std::current_exception();
|
|
}
|
|
}
|
|
|
|
void result() const {
|
|
if (this->exception_ptr) [[unlikely]] {
|
|
std::rethrow_exception(this->exception_ptr);
|
|
}
|
|
}
|
|
|
|
private:
|
|
inline static constexpr uintptr_t is_detached = -1ULL;
|
|
|
|
union {
|
|
uintptr_t is_detached_flag; // set to `is_detached` if is detached.
|
|
std::exception_ptr exception_ptr;
|
|
};
|
|
}; // namespace co_context
|
|
|
|
template<typename T>
|
|
class task_promise<T&> final : public task_promise_base<T&> {
|
|
public:
|
|
task_promise() noexcept = default;
|
|
|
|
task<T&> get_return_object() noexcept;
|
|
|
|
void unhandled_exception() noexcept {
|
|
this->exception_ptr = std::current_exception();
|
|
}
|
|
|
|
void return_value(T& result) noexcept {
|
|
value = std::addressof(result);
|
|
}
|
|
|
|
T& result() {
|
|
if (exception_ptr) [[unlikely]] {
|
|
std::rethrow_exception(exception_ptr);
|
|
}
|
|
return *value;
|
|
}
|
|
|
|
private:
|
|
T* value = nullptr;
|
|
std::exception_ptr exception_ptr;
|
|
};
|
|
}
|
|
} |