最近在使用 Folly 的协程做 RPC 框架,学习一下它的协程 Cancellation 实现。先举个例子,假设 RPC 框架中使用 `co_await` 监听端口上的新连接,要如何实现优雅退出?
folly::coro::ServerSocket ss(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
while (true) {
auto cs = co_await ss.accept();
// ...
}
Folly 中提供了 `CancellationToken` 来实现 `co_await` 动作的取消,上方代码可以改写为:
// Accept loop that can be stopped from outside: every accept() is awaited with
// a token obtained from `cs`, so requesting cancellation on `cs` ends the loop
// through an OperationCancelled exception.
folly::CancellationSource cs;
try {
  while (true) {
    // The accepted connection is named `conn`, not `cs`: writing
    // `auto cs = ... cs.getToken() ...` would make `cs` in the initializer
    // refer to the variable being declared (ill-formed with `auto`) and would
    // hide the CancellationSource above.
    auto conn = co_await co_withCancellation(cs.getToken(), ss.accept());
    // ...
  }
} catch (const folly::OperationCancelled&) {
  // The pending accept was cancelled; the loop has been left.
}
// Later...
cs.requestCancellation(); // the pending co_await co_withCancellation then throws OperationCancelled
`folly::coro::ServerSocket::accept` 支持 Cancel 是因为其内部有对应的埋点:
class TaskPromiseBase {
public:
auto await_transform(co_current_cancellation_token_t) noexcept {
// co_await co_current_cancellation_token 时拿到 cancelToken_
return ready_awaitable<const folly::CancellationToken&>{cancelToken_};
}
};
// Builds a callback from an rvalue token and tries to register it with the
// token's shared state.
//
// The overload is enabled only when `Callable` can construct
// CancellationCallback::VoidFunction. Registration is attempted with
// incrementRefCountIfSuccessful == false; on success the state pointer is
// released out of `ct` and stored in state_, so the reference held by the
// token is taken over instead of a new one being added. When the token has no
// state, or registration fails, state_ stays null and the destructor has
// nothing to deregister.
template <
    typename Callable,
    std::enable_if_t<
        std::is_constructible<CancellationCallback::VoidFunction, Callable>::
            value,
        int>>
inline CancellationCallback::CancellationCallback(
    CancellationToken&& ct, Callable&& callable)
    : next_(nullptr),
      prevNext_(nullptr),
      state_(nullptr),
      callback_(static_cast<Callable&&>(callable)),
      destructorHasRunInsideCallback_(nullptr),
      callbackCompleted_(false) {
  if (ct.state_ == nullptr) {
    return;
  }
  if (!ct.state_->tryAddCallback(this, false)) {
    return;
  }
  state_ = ct.state_.release();
}
// Deregisters this callback from the shared state, if it was ever registered
// (state_ is non-null only after a successful registration).
inline CancellationCallback::~CancellationCallback() {
  if (state_ == nullptr) {
    return;
  }
  state_->removeCallback(this);
}
// Coroutine that waits for one incoming connection and returns it wrapped in a
// Transport. It supports cancellation: a CancellationCallback registered on the
// current cancellation token stops accepting and posts the baton, and the
// safe point after the baton then ends the coroutine with OperationCancelled.
// An error recorded in cb.error is propagated through co_error.
Task<std::unique_ptr<Transport>> ServerSocket::accept() {
VLOG(5) << "accept() called";
// Stop right away if cancellation was already requested.
co_await folly::coro::co_safe_point;
Baton baton;
AcceptCallback cb(baton, socket_); // posts baton when a new connection arrives
socket_->addAcceptCallback(&cb, nullptr);
socket_->startAccepting();
auto cancelToken = co_await folly::coro::co_current_cancellation_token;
// Build a CancellationCallback; its callback runs when cancellation happens.
CancellationCallback cancellationCallback{cancelToken, [&baton, this] {
this->socket_->stopAccepting();
// Cancellation also posts baton so the co_await below resumes.
baton.post();
}};
co_await baton;
co_await folly::coro::co_safe_point; // yields OperationCancelled if cancelled
if (cb.error) {
co_yield co_error(std::move(cb.error));
}
co_return std::make_unique<Transport>(
socket_->getEventBase(),
AsyncSocket::newSocket(
socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}
当 Cancel 发生时,`cancellationCallback` 对象注册的回调函数会被执行,唤醒 baton,继而调用 `co_await folly::coro::co_safe_point`。该操作会检查当前协程是否已经被取消,若是则抛出 `OperationCancelled` 异常。
// Tag type whose values convert implicitly to a co_error carrying an
// OperationCancelled exception.
class co_cancelled_t final {
public:
// Implicit on purpose, so the tag can be used wherever a co_error is expected.
/* implicit */ operator co_error() const {
return co_error(OperationCancelled{});
}
};
// Constant instance of the tag type.
FOLLY_INLINE_VARIABLE constexpr co_cancelled_t co_cancelled{};
class TaskPromiseBase {
 protected:
  // Safe-point check: when cancellation has not been requested the result is
  // a ready awaitable; otherwise co_cancelled is yielded through the
  // promise, which carries an OperationCancelled error.
  template <typename Promise>
  variant_awaitable<FinalAwaiter, ready_awaitable<>> do_safe_point(
      Promise& promise) noexcept {
    if (!cancelToken_.isCancellationRequested()) {
      return ready_awaitable<>{};
    }
    return promise.yield_value(co_cancelled);
  }

 public:
  // Handles `co_await co_safe_point` by running the safe-point check on this
  // promise.
  auto await_transform(co_safe_point_t) noexcept {
    return do_safe_point(*this);
  }
};
// Verifies co_safe_point: a safe point reached before cancellation lets the
// task continue, while a safe point reached after cancellation ends the task
// with OperationCancelled, so the step after it is never recorded.
TEST_F(TaskTest, SafePoint) {
folly::coro::blockingWait([]() -> folly::coro::Task<void> {
// Records how far makeTask got.
enum class step_type {
init,
before_continue_sp,
after_continue_sp,
before_cancel_sp,
after_cancel_sp,
};
step_type step = step_type::init;
folly::CancellationSource cancelSrc;
auto makeTask = [&]() -> folly::coro::Task<void> {
step = step_type::before_continue_sp;
co_await folly::coro::co_safe_point; // not cancelled yet: passes straight through
step = step_type::after_continue_sp;
cancelSrc.requestCancellation(); // trigger cancellation
step = step_type::before_cancel_sp;
co_await folly::coro::co_safe_point; // cancellation detected: ends with OperationCancelled
step = step_type::after_cancel_sp;
};
// co_awaitTry captures the outcome so the exception can be inspected.
auto result = co_await folly::coro::co_awaitTry( //
folly::coro::co_withCancellation(cancelSrc.getToken(), makeTask()));
EXPECT_THROW(result.value(), folly::OperationCancelled);
// The task must have stopped at the second safe point.
EXPECT_EQ(step_type::before_cancel_sp, step);
}());
}
Folly 的协程对象默认会在 `co_await` 时透传 `cancelToken_` 对象,因此在 Cancel 时可以对深层协程调用进行取消,自底向上传递 `OperationCancelled`。
class TaskPromiseBase {
 public:
  // Installs `cancelToken` as this promise's token and marks it as
  // overridden. Calls made while the override flag is set are ignored.
  void setCancelToken(folly::CancellationToken&& cancelToken) noexcept {
    if (hasCancelTokenOverride_) {
      return;
    }
    cancelToken_ = std::move(cancelToken);
    hasCancelTokenOverride_ = true;
  }

  // Generic `co_await` hook for awaitables inside a Task.
  template <typename Awaitable>
  auto await_transform(Awaitable&& awaitable) {
    // A REQUESTED bypass becomes ACTIVE for this await; any other value is
    // reset to INACTIVE.
    if (bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED) {
      bypassExceptionThrowing_ = BypassExceptionThrowing::ACTIVE;
    } else {
      bypassExceptionThrowing_ = BypassExceptionThrowing::INACTIVE;
    }
    // The awaited operation is wrapped, innermost first, with this promise's
    // cancelToken_ (so the token keeps propagating to nested awaits), then
    // with the promise's executor, then with co_withAsyncStack.
    return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync(
        executor_.get_alias(),
        folly::coro::co_withCancellation(
            cancelToken_, static_cast<Awaitable&&>(awaitable))));
  }
};
template <typename T>
class FOLLY_NODISCARD Task {
 public:
  // Attaches `cancelToken` to the promise of `task` via setCancelToken() and
  // returns the same task by move. The task must own a coroutine (DCHECK).
  friend Task co_withCancellation(
      folly::CancellationToken cancelToken, Task&& task) noexcept {
    DCHECK(task.coro_);
    auto& promise = task.coro_.promise();
    promise.setCancelToken(std::move(cancelToken));
    return std::move(task);
  }
};
`CancellationToken` 是一个可以传递给函数或操作的对象,允许调用者稍后请求取消操作。该对象可以通过 `CancellationSource.getToken()` 来获取,支持复制。从同一个原始的 `CancellationSource` 对象获取的 `CancellationToken` 对象使用引用计数指向相同的底层状态 `CancellationState`,在 `CancellationSource.requestCancellation()` 发生时会被一起取消。
class CancellationState;
struct CancellationStateTokenDeleter {
void operator()(CancellationState*) noexcept {
state->removeTokenReference();
}
};
using CancellationStateTokenPtr =
std::unique_ptr<CancellationState, CancellationStateTokenDeleter>;
class CancellationToken {
 public:
  // Reports whether cancellation can be requested through this token. A token
  // without shared state can never be cancelled; otherwise the answer comes
  // from the shared state.
  bool canBeCancelled() const noexcept {
    if (state_ == nullptr) {
      return false;
    }
    return state_->canBeCancelled();
  }

 private:
  friend class CancellationCallback;
  friend class CancellationSource;

  // Token reference on the shared state; null for an empty token.
  detail::CancellationStateTokenPtr state_;
};
`CancellationSource` 对象可以构造 `CancellationToken` 对象,并且可以通过调用 `requestCancellation` 取消关联了 `CancellationToken` 对象的操作。
struct CancellationStateSourceDeleter {
void operator()(CancellationState*) noexcept {
state->removeSourceReference();
}
};
using CancellationStateSourcePtr =
std::unique_ptr<CancellationState, CancellationStateSourceDeleter>;
class CancellationSource {
 public:
  // Creates a new, independent cancellation source with its own state.
  CancellationSource() : state_(detail::CancellationState::create()) {}

  // Returns a token that shares this source's state, or an empty token when
  // the source has no state.
  CancellationToken getToken() const noexcept {
    if (state_ == nullptr) {
      return CancellationToken{};
    }
    return CancellationToken{state_->addTokenReference()};
  }

  // Forwards the request to the shared state and returns its result; returns
  // false when the source has no state.
  bool requestCancellation() const noexcept {
    return state_ != nullptr && state_->requestCancellation();
  }

 private:
  // Source reference on the shared state.
  detail::CancellationStateSourcePtr state_;
};
`CancellationState` 的实现原理并不复杂,核心原理是引用计数和 CAS。
// Shared state behind CancellationSource and CancellationToken objects.
// One 64-bit atomic word (state_) packs the cancellation flag, a lock bit, the
// merging flag and both reference counts; head_ is the head of the list of
// registered callbacks.
class CancellationState {
public:
// Factory returning a pointer that holds a source reference (definition not
// shown here).
FOLLY_NODISCARD static CancellationStateSourcePtr create();
protected:
// Constructed initially with a CancellationSource reference count of 1.
CancellationState() noexcept;
// Constructed initially with a CancellationToken reference count of 1.
explicit CancellationState(FixedMergingCancellationStateTag) noexcept;
virtual ~CancellationState();
// The deleters are the callers of the reference-dropping functions below.
friend struct CancellationStateTokenDeleter;
friend struct CancellationStateSourceDeleter;
void removeTokenReference() noexcept;
void removeSourceReference() noexcept;
public:
// Each call adds one reference of the given kind and returns a pointer that
// owns it.
FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept;
FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept;
// Registers a callback; returns false when it was not added to the list.
bool tryAddCallback(
CancellationCallback* callback,
bool incrementRefCountIfSuccessful) noexcept;
// Deregisters a callback previously added with tryAddCallback().
void removeCallback(CancellationCallback* callback) noexcept;
bool isCancellationRequested() const noexcept;
bool canBeCancelled() const noexcept;
// Request cancellation.
// Return 'true' if cancellation had already been requested.
// Return 'false' if this was the first thread to request
// cancellation.
bool requestCancellation() noexcept;
private:
// Helpers operating on the lock bit of state_ (definitions not shown here).
void lock() noexcept;
void unlock() noexcept;
void unlockAndIncrementTokenCount() noexcept;
void unlockAndDecrementTokenCount() noexcept;
bool tryLockAndCancelUnlessCancelled() noexcept;
template <typename Predicate>
bool tryLock(Predicate predicate) noexcept;
// Predicates over a snapshot of the state word.
static bool canBeCancelled(std::uint64_t state) noexcept;
static bool isCancellationRequested(std::uint64_t state) noexcept;
static bool isLocked(std::uint64_t state) noexcept;
// Flag bits and per-reference increments; see the layout table below.
static constexpr std::uint64_t kCancellationRequestedFlag = 1;
static constexpr std::uint64_t kLockedFlag = 2;
static constexpr std::uint64_t kMergingFlag = 4;
static constexpr std::uint64_t kTokenReferenceCountIncrement = 8;
static constexpr std::uint64_t kSourceReferenceCountIncrement =
std::uint64_t(1) << 34u;
// Mask selecting bits 3-33 (the token count).
static constexpr std::uint64_t kTokenReferenceCountMask =
(kSourceReferenceCountIncrement - 1u) -
(kTokenReferenceCountIncrement - 1u);
// Mask selecting bits 34-63 (the source count).
static constexpr std::uint64_t kSourceReferenceCountMask =
std::numeric_limits<std::uint64_t>::max() -
(kSourceReferenceCountIncrement - 1u);
// Bit 0 - Cancellation Requested
// Bit 1 - Locked Flag
// Bit 2 - MergingCancellationState Flag
// Bits 3-33 - Token reference count (max ~2 billion)
// Bits 34-63 - Source reference count (max ~1 billion)
std::atomic<std::uint64_t> state_;
// Head of the doubly linked list of registered callbacks.
CancellationCallback* head_{nullptr};
// Thread that is running requestCancellation(); set there before callbacks
// are invoked and compared against in removeCallback().
std::thread::id signallingThreadId_;
};
// Starts with a state word equal to one source-reference increment: source
// count 1, token count 0, no flags set.
inline CancellationState::CancellationState() noexcept
: state_(kSourceReferenceCountIncrement) {}
// Adds one token reference (relaxed atomic add on state_) and returns a
// pointer to this state that owns the new reference.
inline CancellationStateTokenPtr
CancellationState::addTokenReference() noexcept {
state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed);
return CancellationStateTokenPtr{this};
}
// Drops one token reference and deletes the state when it was the last
// reference of any kind.
inline void CancellationState::removeTokenReference() noexcept {
const auto oldState = state_.fetch_sub(
kTokenReferenceCountIncrement, std::memory_order_acq_rel);
// There must have been at least one token reference to remove.
DCHECK(
(oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement);
// A previous value below two token increments means no source-count bits were
// set and the token count was exactly one, so nothing references the state now.
if (oldState < (2 * kTokenReferenceCountIncrement)) {
delete this;
}
}
// Adds one source reference (relaxed atomic add on state_) and returns a
// pointer to this state that owns the new reference.
inline CancellationStateSourcePtr
CancellationState::addSourceReference() noexcept {
state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed);
return CancellationStateSourcePtr{this};
}
// Drops one source reference and deletes the state when it was the last
// reference of any kind.
inline void CancellationState::removeSourceReference() noexcept {
const auto oldState = state_.fetch_sub(
kSourceReferenceCountIncrement, std::memory_order_acq_rel);
// There must have been at least one source reference to remove.
DCHECK(
(oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement);
// A previous value below one source increment plus one token increment means
// the source count was exactly one and the token count was zero.
if (oldState <
(kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) {
delete this;
}
}
`CancellationCallback` 对象会调用 `tryAddCallback` 接口增加 Cancel 时的回调。
// Registers `callback` to run when cancellation is requested.
//
// Returns true when the callback was pushed onto the list. Returns false when
// it was not: either cancellation had already been requested, in which case the
// callback has been invoked inline before returning, or the state can never be
// cancelled. When `incrementRefCountIfSuccessful` is true, a token reference is
// added as part of releasing the lock.
bool CancellationState::tryAddCallback(
CancellationCallback* callback,
bool incrementRefCountIfSuccessful) noexcept {
// Try to acquire the lock, but abandon trying to acquire the lock if
// cancellation has already been requested (we can just immediately invoke
// the callback) or if cancellation can never be requested (we can just
// skip registration).
if (!tryLock([callback](std::uint64_t oldState) noexcept {
if (isCancellationRequested(oldState)) {
callback->invokeCallback();
return false;
}
return canBeCancelled(oldState);
})) {
return false;
}
// We've acquired the lock and cancellation has not yet been requested.
// Push this callback onto the head of the list.
if (head_ != nullptr) {
head_->prevNext_ = &callback->next_;
}
callback->next_ = head_;
// prevNext_ points at whichever pointer refers to this node; for the first
// node that is head_ itself.
callback->prevNext_ = &head_;
head_ = callback;
if (incrementRefCountIfSuccessful) {
// Combine multiple atomic operations into a single atomic operation.
unlockAndIncrementTokenCount();
} else {
unlock();
}
// Successfully added the callback.
return true;
}
// Acquires the lock bit of state_ as long as `predicate` accepts the observed
// state word.
//
// The predicate is re-evaluated on every freshly observed value; as soon as it
// returns false the function gives up and returns false without holding the
// lock. Returns true once the lock bit has been set by this call.
template <typename Predicate>
bool CancellationState::tryLock(Predicate predicate) noexcept {
folly::detail::Sleeper sleeper;
std::uint64_t oldState = state_.load(std::memory_order_acquire);
while (true) {
if (!predicate(oldState)) {
return false;
} else if (isLocked(oldState)) {
// Someone else holds the lock: back off, then re-read the state.
sleeper.wait();
oldState = state_.load(std::memory_order_acquire);
} else if (state_.compare_exchange_weak(
oldState,
oldState | kLockedFlag,
std::memory_order_acquire)) {
return true;
}
// A failed compare_exchange_weak has refreshed oldState; loop and retry.
}
}
调用 `requestCancellation` 接口时,从链表中依次取出 `CancellationCallback` 对象,调用回调函数。这里需要处理可能存在的 `removeCallback` 时的竞争关系。
// Runs the stored callback_.
inline void CancellationCallback::invokeCallback() noexcept {
// Invoke within a noexcept context so that we std::terminate() if it throws.
callback_();
}
// Marks the state as cancelled and runs every registered callback.
//
// Returns true when cancellation had already been requested, in which case
// nothing is run. Otherwise the calling thread dequeues the callbacks one at a
// time, invokes each with the lock released, and returns false.
bool CancellationState::requestCancellation() noexcept {
if (!tryLockAndCancelUnlessCancelled()) {
// Was already marked as cancelled
return true;
}
// This thread marked as cancelled and acquired the lock
// Recorded so removeCallback() can tell whether it runs on this thread.
signallingThreadId_ = std::this_thread::get_id();
while (head_ != nullptr) {
// Dequeue the first item on the queue.
CancellationCallback* callback = head_;
head_ = callback->next_;
const bool anyMore = head_ != nullptr;
if (anyMore) {
head_->prevNext_ = &head_;
}
// Mark this item as removed from the list.
callback->prevNext_ = nullptr;
// Don't hold the lock while executing the callback
// as we don't want to block other threads from
// deregistering callbacks.
unlock();
// TRICKY: Need to store a flag on the stack here that the callback
// can use to signal that the destructor was executed inline
// during the call.
// If the destructor was executed inline then it's not safe to
// dereference 'callback' after 'invokeCallback()' returns.
// If the destructor runs on some other thread then the other
// thread will block waiting for this thread to signal that the
// callback has finished executing.
bool destructorHasRunInsideCallback = false;
callback->destructorHasRunInsideCallback_ = &destructorHasRunInsideCallback;
callback->invokeCallback();
if (!destructorHasRunInsideCallback) {
// The callback object still exists: clear the stack pointer and publish
// completion for a removeCallback() waiting on another thread.
callback->destructorHasRunInsideCallback_ = nullptr;
callback->callbackCompleted_.store(true, std::memory_order_release);
}
if (!anyMore) {
// This was the last item in the queue when we dequeued it.
// No more items should be added to the queue after we have
// marked the state as cancelled, only removed from the queue.
// Avoid acquiring/releasing the lock in this case.
return false;
}
// Re-acquire the lock before touching the list again.
lock();
}
unlock();
return false;
}
// Deregisters `callback` and drops one token reference.
//
// If the callback is still in the list it is unlinked under the lock. If it
// has already been dequeued by requestCancellation(), this either notifies
// that call that the callback object is being destroyed (same thread), or
// waits until the callback has finished running (other thread).
void CancellationState::removeCallback(
CancellationCallback* callback) noexcept {
DCHECK(callback != nullptr);
lock();
if (callback->prevNext_ != nullptr) {
// Still registered in the list => not yet executed.
// Just remove it from the list.
*callback->prevNext_ = callback->next_;
if (callback->next_ != nullptr) {
callback->next_->prevNext_ = callback->prevNext_;
}
// Releases the lock and drops the token reference in one step.
unlockAndDecrementTokenCount();
return;
}
unlock();
// Callback has either already executed or is executing concurrently on
// another thread.
if (signallingThreadId_ == std::this_thread::get_id()) {
// Callback executed on this thread or is still currently executing
// and is deregistering itself from within the callback.
if (callback->destructorHasRunInsideCallback_ != nullptr) {
// Currently inside the callback, let the requestCancellation() method
// know the object is about to be destructed and that it should
// not try to access the object when the callback returns.
*callback->destructorHasRunInsideCallback_ = true;
}
} else {
// Callback is currently executing on another thread, block until it
// finishes executing.
folly::detail::Sleeper sleeper;
while (!callback->callbackCompleted_.load(std::memory_order_acquire)) {
sleeper.wait();
}
}
// Drop the token reference on the paths that did not unlink from the list.
removeTokenReference();
}