Folly Coroutines Cancellation 的实现

2022.07.04
SF-Zhou

最近在使用 Folly 的协程做 RPC 框架,学习一下它的协程 Cancellation 实现。先举个例子,假设 RPC 框架中使用 co_await 监听端口上的新连接,要如何实现优雅退出?

folly::coro::ServerSocket ss(AsyncServerSocket::newSocket(&evb), std::nullopt, 16);
while (true) {
  auto cs = co_await ss.accept();
  // ...
}

Folly 中提供了 CancellationToken 来实现 co_await 动作的取消,上方代码可以改写为:

folly::CancellationSource cs;

try {
  while (true) {
    auto cs = co_await co_withCancellation(cs.getToken(), ss.accept());
    // ...
  }
} catch (folly::OperationCancelled &) {
  // be cancelled.
}

// Later...
cs.requestCancellation();  // 而后 co_await co_withCancellation 抛出 OperationCancelled 异常

1. CancellationCallback

folly::coro::ServerSocket::accept 支持 Cancel 是因为其内部有对应的埋点:

class TaskPromiseBase {
 public:
  auto await_transform(co_current_cancellation_token_t) noexcept {
    // co_await co_current_cancellation_token 时拿到 cancelToken_
    return ready_awaitable<const folly::CancellationToken&>{cancelToken_};
  }
};

// Registers `callable` with the cancellation state referenced by `ct`.
// If tryAddCallback() succeeds, the callback takes over the state pointer
// released from `ct`; otherwise state_ stays null and the destructor has
// nothing to deregister.
template <
    typename Callable,
    std::enable_if_t<
        std::is_constructible<CancellationCallback::VoidFunction, Callable>::
            value,
        int>>
inline CancellationCallback::CancellationCallback(
    CancellationToken&& ct, Callable&& callable)
    : next_(nullptr),
      prevNext_(nullptr),
      state_(nullptr),
      callback_(static_cast<Callable&&>(callable)),
      destructorHasRunInsideCallback_(nullptr),
      callbackCompleted_(false) {
  if (ct.state_ == nullptr) {
    // The token has no associated state, so there is nothing to register with.
    return;
  }

  const bool registered = ct.state_->tryAddCallback(this, false);
  if (registered) {
    state_ = ct.state_.release();
  }
}

// Deregisters this callback from its cancellation state, if it holds one.
inline CancellationCallback::~CancellationCallback() {
  if (state_ == nullptr) {
    // No state is held, so there is nothing to remove.
    return;
  }
  state_->removeCallback(this);
}

// Accepts one connection as a coroutine.
// Flow: check for cancellation, register an accept callback plus a
// cancellation callback (both wake the same baton), suspend on the baton,
// then check for cancellation again before reporting an error or returning
// the new Transport.
Task<std::unique_ptr<Transport>> ServerSocket::accept() {
  VLOG(5) << "accept() called";
  co_await folly::coro::co_safe_point;

  Baton baton;
  AcceptCallback cb(baton, socket_);  // posts the baton when a new connection arrives
  socket_->addAcceptCallback(&cb, nullptr);
  socket_->startAccepting();
  auto cancelToken = co_await folly::coro::co_current_cancellation_token;
  // Build a CancellationCallback; its lambda runs when cancellation is
  // requested on cancelToken.
  CancellationCallback cancellationCallback{cancelToken, [&baton, this] {
                                              this->socket_->stopAccepting();
                                              // Cancellation also wakes the baton.
                                              baton.post();
                                            }};

  co_await baton;
  co_await folly::coro::co_safe_point;  // completes with OperationCancelled if cancelled
  if (cb.error) {
    co_yield co_error(std::move(cb.error));
  }
  co_return std::make_unique<Transport>(
      socket_->getEventBase(),
      AsyncSocket::newSocket(
          socket_->getEventBase(), NetworkSocket::fromFd(cb.acceptFd)));
}

当 Cancel 发生时,cancellationCallback 对象注册的回调函数会被执行,唤醒 baton,继而调用 co_await folly::coro::co_safe_point。该操作会检查当前协程是否已经被取消,若是则抛出 OperationCancelled 异常。

// Tag type that converts implicitly to a co_error holding an
// OperationCancelled exception.
class co_cancelled_t final {
 public:
  /* implicit */ operator co_error() const {
    return co_error(OperationCancelled{});
  }
};

// Constant instance of the tag; do_safe_point() passes it to
// promise.yield_value() when cancellation has been requested.
FOLLY_INLINE_VARIABLE constexpr co_cancelled_t co_cancelled{};

class TaskPromiseBase {
 protected:
  // Safe-point check: when cancellation has not been requested the result is
  // an empty ready_awaitable; otherwise the result is whatever
  // promise.yield_value(co_cancelled) produces, i.e. the coroutine yields an
  // OperationCancelled error.
  template <typename Promise>
  variant_awaitable<FinalAwaiter, ready_awaitable<>> do_safe_point(
      Promise& promise) noexcept {
    if (!cancelToken_.isCancellationRequested()) {
      return ready_awaitable<>{};
    }
    return promise.yield_value(co_cancelled);
  }

 public:
  // Handles `co_await co_safe_point` by running the safe-point check on this
  // promise.
  auto await_transform(co_safe_point_t) noexcept {
    return do_safe_point(*this);
  }
};

// Verifies co_safe_point: it is a no-op before cancellation is requested and
// ends the task with OperationCancelled afterwards. `step` records how far
// the inner task progressed.
TEST_F(TaskTest, SafePoint) {
  folly::coro::blockingWait([]() -> folly::coro::Task<void> {
    enum class step_type {
      init,
      before_continue_sp,
      after_continue_sp,
      before_cancel_sp,
      after_cancel_sp,
    };
    step_type step = step_type::init;

    folly::CancellationSource cancelSrc;
    auto makeTask = [&]() -> folly::coro::Task<void> {
      step = step_type::before_continue_sp;
      co_await folly::coro::co_safe_point;  // not cancelled yet: passes straight through
      step = step_type::after_continue_sp;

      cancelSrc.requestCancellation();  // trigger cancellation

      step = step_type::before_cancel_sp;
      co_await folly::coro::co_safe_point;  // cancellation detected: raises the error
      step = step_type::after_cancel_sp;
    };

    // co_awaitTry captures the outcome so the exception can be inspected
    // through result.value() instead of propagating out of the co_await.
    auto result = co_await folly::coro::co_awaitTry( //
        folly::coro::co_withCancellation(cancelSrc.getToken(), makeTask()));
    EXPECT_THROW(result.value(), folly::OperationCancelled);
    // The task must have stopped at the second safe point.
    EXPECT_EQ(step_type::before_cancel_sp, step);
  }());
}

2. co_withCancellation

Folly 的协程对象默认会在 co_await 时透传 cancelToken_ 对象,因此在 Cancel 时可以对深层协程调用进行取消,并自底向上传递 OperationCancelled 异常。

class TaskPromiseBase {
 public:
  // Installs `cancelToken` as this promise's token. Only the first call takes
  // effect: once hasCancelTokenOverride_ is set, later calls are ignored.
  void setCancelToken(folly::CancellationToken&& cancelToken) noexcept {
    if (hasCancelTokenOverride_) {
      return;
    }
    cancelToken_ = std::move(cancelToken);
    hasCancelTokenOverride_ = true;
  }

  // Generic co_await hook. Advances the bypass-exception state
  // (REQUESTED -> ACTIVE, anything else -> INACTIVE), then wraps the
  // awaitable so that this promise's cancelToken_ and executor are passed on
  // to it.
  template <typename Awaitable>
  auto await_transform(Awaitable&& awaitable) {
    if (bypassExceptionThrowing_ == BypassExceptionThrowing::REQUESTED) {
      bypassExceptionThrowing_ = BypassExceptionThrowing::ACTIVE;
    } else {
      bypassExceptionThrowing_ = BypassExceptionThrowing::INACTIVE;
    }

    // A co_await inside a cancellable coroutine keeps forwarding cancelToken_.
    return folly::coro::co_withAsyncStack(folly::coro::co_viaIfAsync(
        executor_.get_alias(),
        folly::coro::co_withCancellation(
            cancelToken_, static_cast<Awaitable&&>(awaitable))));
  }
};

template <typename T>
class FOLLY_NODISCARD Task {
 public:
  // Attaches `cancelToken` to the task's promise via setCancelToken() and
  // hands the same task back to the caller.
  friend Task co_withCancellation(
      folly::CancellationToken cancelToken, Task&& task) noexcept {
    DCHECK(task.coro_);
    auto& promise = task.coro_.promise();
    promise.setCancelToken(std::move(cancelToken));
    // `task` is a named rvalue reference, so the explicit move is required.
    return std::move(task);
  }
};

3. CancellationToken

CancellationToken 是一个可以传递给函数或操作的对象,允许调用者稍后请求取消操作。该对象可以通过 CancellationSource.getToken() 来获取,支持复制。从同一个原始的 CancellationSource 对象获取的 CancellationToken 对象使用引用计数指向相同的底层状态 CancellationState,在 CancellationSource.requestCancellation() 发生时会被一起取消。

class CancellationState;

struct CancellationStateTokenDeleter {
  void operator()(CancellationState*) noexcept {
    state->removeTokenReference();
  }
};
using CancellationStateTokenPtr =
    std::unique_ptr<CancellationState, CancellationStateTokenDeleter>;

class CancellationToken {
 public:
  // True only when the token has a state and that state reports it can still
  // be cancelled; a token without state can never be cancelled.
  bool canBeCancelled() const noexcept {
    if (state_ == nullptr) {
      return false;
    }
    return state_->canBeCancelled();
  }

 private:
  friend class CancellationCallback;
  friend class CancellationSource;

  // Token-counted pointer to the shared state (may be null).
  detail::CancellationStateTokenPtr state_;
};

CancellationSource 对象可以构造 CancellationToken 对象,并且可以通过调用 requestCancellation 取消关联了 CancellationToken 对象的操作。

struct CancellationStateSourceDeleter {
  void operator()(CancellationState*) noexcept {
    state->removeSourceReference();
  }
};
using CancellationStateSourcePtr =
    std::unique_ptr<CancellationState, CancellationStateSourceDeleter>;

class CancellationSource {
 public:
  // Construct to a new, independent cancellation source.
  CancellationSource() : state_(detail::CancellationState::create()) {}

  // Returns a token that shares this source's state, or an empty token when
  // the source holds no state.
  CancellationToken getToken() const noexcept {
    if (state_ == nullptr) {
      return CancellationToken{};
    }
    return CancellationToken{state_->addTokenReference()};
  }

  // Forwards to the shared state's requestCancellation(); returns false when
  // the source holds no state.
  bool requestCancellation() const noexcept {
    return state_ != nullptr && state_->requestCancellation();
  }

 private:
  // Source-counted pointer to the shared state (may be null).
  detail::CancellationStateSourcePtr state_;
};

CancellationState 的实现原理并不复杂,核心原理是引用计数和 CAS。

// Shared state behind CancellationSource, CancellationToken and
// CancellationCallback. One 64-bit atomic packs the cancellation flag, a lock
// bit, the merging flag and two reference counts (layout documented next to
// state_). head_ is the head of the intrusive list of registered callbacks.
class CancellationState {
 public:
  FOLLY_NODISCARD static CancellationStateSourcePtr create();

 protected:
  // Constructed initially with a CancellationSource reference count of 1.
  CancellationState() noexcept;
  // Constructed initially with a CancellationToken reference count of 1.
  explicit CancellationState(FixedMergingCancellationStateTag) noexcept;

  virtual ~CancellationState();

  // The deleters call the protected remove*Reference() methods below.
  friend struct CancellationStateTokenDeleter;
  friend struct CancellationStateSourceDeleter;

  void removeTokenReference() noexcept;
  void removeSourceReference() noexcept;

 public:
  FOLLY_NODISCARD CancellationStateTokenPtr addTokenReference() noexcept;

  FOLLY_NODISCARD CancellationStateSourcePtr addSourceReference() noexcept;

  // Registers `callback`; returns true only if it was added to the list.
  bool tryAddCallback(
      CancellationCallback* callback,
      bool incrementRefCountIfSuccessful) noexcept;

  void removeCallback(CancellationCallback* callback) noexcept;

  bool isCancellationRequested() const noexcept;
  bool canBeCancelled() const noexcept;

  // Request cancellation.
  // Return 'true' if cancellation had already been requested.
  // Return 'false' if this was the first thread to request
  // cancellation.
  bool requestCancellation() noexcept;

 private:
  // The lock is the kLockedFlag bit inside state_.
  void lock() noexcept;
  void unlock() noexcept;
  void unlockAndIncrementTokenCount() noexcept;
  void unlockAndDecrementTokenCount() noexcept;
  bool tryLockAndCancelUnlessCancelled() noexcept;

  template <typename Predicate>
  bool tryLock(Predicate predicate) noexcept;

  // Helpers that decode a raw state_ value.
  static bool canBeCancelled(std::uint64_t state) noexcept;
  static bool isCancellationRequested(std::uint64_t state) noexcept;
  static bool isLocked(std::uint64_t state) noexcept;

  static constexpr std::uint64_t kCancellationRequestedFlag = 1;
  static constexpr std::uint64_t kLockedFlag = 2;
  static constexpr std::uint64_t kMergingFlag = 4;
  // Adding this value to state_ bumps the token count (bits 3 and up) by one.
  static constexpr std::uint64_t kTokenReferenceCountIncrement = 8;
  // Adding this value to state_ bumps the source count (bits 34 and up) by one.
  static constexpr std::uint64_t kSourceReferenceCountIncrement =
      std::uint64_t(1) << 34u;
  // Selects bits 3-33.
  static constexpr std::uint64_t kTokenReferenceCountMask =
      (kSourceReferenceCountIncrement - 1u) -
      (kTokenReferenceCountIncrement - 1u);
  // Selects bits 34-63.
  static constexpr std::uint64_t kSourceReferenceCountMask =
      std::numeric_limits<std::uint64_t>::max() -
      (kSourceReferenceCountIncrement - 1u);

  // Bit 0 - Cancellation Requested
  // Bit 1 - Locked Flag
  // Bit 2 - MergingCancellationState Flag
  // Bits 3-33  - Token reference count (max ~2 billion)
  // Bits 34-63 - Source reference count (max ~1 billion)
  std::atomic<std::uint64_t> state_;
  // Head of the doubly linked list of registered callbacks.
  CancellationCallback* head_{nullptr};
  // Id of the thread that is running requestCancellation()'s callbacks.
  std::thread::id signallingThreadId_;
};

// Starts with a source reference count of 1 and every other bit clear.
inline CancellationState::CancellationState() noexcept
    : state_(kSourceReferenceCountIncrement) {}

// Adds one token reference and returns an owning token pointer to this state.
inline CancellationStateTokenPtr
CancellationState::addTokenReference() noexcept {
  state_.fetch_add(kTokenReferenceCountIncrement, std::memory_order_relaxed);
  return CancellationStateTokenPtr{this};
}

// Drops one token reference and destroys the state when it was the last
// reference of any kind.
inline void CancellationState::removeTokenReference() noexcept {
  const auto oldState = state_.fetch_sub(
      kTokenReferenceCountIncrement, std::memory_order_acq_rel);
  DCHECK(
      (oldState & kTokenReferenceCountMask) >= kTokenReferenceCountIncrement);
  // A previous value below two token increments means the token count was
  // exactly 1 and the source count was 0, so nothing references the state now.
  if (oldState < (2 * kTokenReferenceCountIncrement)) {
    delete this;
  }
}

// Adds one source reference and returns an owning source pointer to this
// state.
inline CancellationStateSourcePtr
CancellationState::addSourceReference() noexcept {
  state_.fetch_add(kSourceReferenceCountIncrement, std::memory_order_relaxed);
  return CancellationStateSourcePtr{this};
}

// Drops one source reference and destroys the state when it was the last
// reference of any kind.
inline void CancellationState::removeSourceReference() noexcept {
  const auto oldState = state_.fetch_sub(
      kSourceReferenceCountIncrement, std::memory_order_acq_rel);
  DCHECK(
      (oldState & kSourceReferenceCountMask) >= kSourceReferenceCountIncrement);
  // A previous value below (one source + one token increment) means the
  // source count was exactly 1 and the token count was 0.
  if (oldState <
      (kSourceReferenceCountIncrement + kTokenReferenceCountIncrement)) {
    delete this;
  }
}

CancellationCallback 对象会调用 tryAddCallback 接口增加 Cancel 时的回调。

// Tries to register `callback` at the head of the callback list.
// Returns true if the callback was added. Returns false if cancellation was
// already requested (the callback is invoked inline first) or if the state
// can never be cancelled. When incrementRefCountIfSuccessful is true, the
// token count is incremented as part of releasing the lock.
bool CancellationState::tryAddCallback(
    CancellationCallback* callback,
    bool incrementRefCountIfSuccessful) noexcept {
  // Try to acquire the lock, but abandon trying to acquire the lock if
  // cancellation has already been requested (we can just immediately invoke
  // the callback) or if cancellation can never be requested (we can just
  // skip registration).
  if (!tryLock([callback](std::uint64_t oldState) noexcept {
        if (isCancellationRequested(oldState)) {
          callback->invokeCallback();
          return false;
        }
        return canBeCancelled(oldState);
      })) {
    return false;
  }

  // We've acquired the lock and cancellation has not yet been requested.
  // Push this callback onto the head of the list.
  if (head_ != nullptr) {
    head_->prevNext_ = &callback->next_;
  }
  callback->next_ = head_;
  // prevNext_ points at whichever pointer refers to this node (here: head_).
  callback->prevNext_ = &head_;
  head_ = callback;

  if (incrementRefCountIfSuccessful) {
    // Combine multiple atomic operations into a single atomic operation.
    unlockAndIncrementTokenCount();
  } else {
    unlock();
  }

  // Successfully added the callback.
  return true;
}

// Spins until the lock bit is acquired or `predicate` rejects the observed
// state. `predicate` is re-evaluated on every iteration with the latest
// state value. Returns true with the lock held, or false (lock not held) as
// soon as the predicate returns false.
template <typename Predicate>
bool CancellationState::tryLock(Predicate predicate) noexcept {
  folly::detail::Sleeper sleeper;
  std::uint64_t oldState = state_.load(std::memory_order_acquire);
  while (true) {
    if (!predicate(oldState)) {
      return false;
    } else if (isLocked(oldState)) {
      // Someone else holds the lock: back off, then re-read the state.
      sleeper.wait();
      oldState = state_.load(std::memory_order_acquire);
    } else if (state_.compare_exchange_weak(
                   oldState,
                   oldState | kLockedFlag,
                   std::memory_order_acquire)) {
      return true;
    }
    // On CAS failure compare_exchange_weak has refreshed oldState, so the
    // loop simply retries with the new value.
  }
}

调用 requestCancellation 接口时,从链表中依次取出 CancellationCallback 对象,调用回调函数。这里需要处理与并发的 removeCallback 之间可能存在的竞争关系。

// Runs the stored callback.
inline void CancellationCallback::invokeCallback() noexcept {
  // Invoke within a noexcept context so that we std::terminate() if it throws.
  callback_();
}

// Marks the state as cancelled and runs every registered callback on the
// calling thread, one at a time, with the lock released during each call.
// Returns true if cancellation had already been requested, false if this
// call was the one that requested it.
bool CancellationState::requestCancellation() noexcept {
  if (!tryLockAndCancelUnlessCancelled()) {
    // Was already marked as cancelled
    return true;
  }

  // This thread marked as cancelled and acquired the lock

  // Recorded so removeCallback() can tell whether it is running on the thread
  // that executes the callbacks.
  signallingThreadId_ = std::this_thread::get_id();

  while (head_ != nullptr) {
    // Dequeue the first item on the queue.
    CancellationCallback* callback = head_;
    head_ = callback->next_;
    const bool anyMore = head_ != nullptr;
    if (anyMore) {
      head_->prevNext_ = &head_;
    }
    // Mark this item as removed from the list.
    callback->prevNext_ = nullptr;

    // Don't hold the lock while executing the callback
    // as we don't want to block other threads from
    // deregistering callbacks.
    unlock();

    // TRICKY: Need to store a flag on the stack here that the callback
    // can use to signal that the destructor was executed inline
    // during the call.
    // If the destructor was executed inline then it's not safe to
    // dereference 'callback' after 'invokeCallback()' returns.
    // If the destructor runs on some other thread then the other
    // thread will block waiting for this thread to signal that the
    // callback has finished executing.
    bool destructorHasRunInsideCallback = false;
    callback->destructorHasRunInsideCallback_ = &destructorHasRunInsideCallback;

    callback->invokeCallback();

    if (!destructorHasRunInsideCallback) {
      // The callback object is still alive: clear the stack pointer and
      // publish completion for a removeCallback() waiting on another thread.
      callback->destructorHasRunInsideCallback_ = nullptr;
      callback->callbackCompleted_.store(true, std::memory_order_release);
    }

    if (!anyMore) {
      // This was the last item in the queue when we dequeued it.
      // No more items should be added to the queue after we have
      // marked the state as cancelled, only removed from the queue.
      // Avoid acquiring/releasing the lock in this case.
      return false;
    }

    lock();
  }

  unlock();

  return false;
}

// Deregisters `callback`.
// If it is still in the list it is unlinked and the token count is
// decremented while unlocking. Otherwise the callback has run or is running:
// on the signalling thread the in-flight requestCancellation() is told that
// the object is going away; on any other thread this call waits until the
// callback has completed. That second path ends by dropping a token
// reference.
void CancellationState::removeCallback(
    CancellationCallback* callback) noexcept {
  DCHECK(callback != nullptr);

  lock();

  if (callback->prevNext_ != nullptr) {
    // Still registered in the list => not yet executed.
    // Just remove it from the list.
    *callback->prevNext_ = callback->next_;
    if (callback->next_ != nullptr) {
      callback->next_->prevNext_ = callback->prevNext_;
    }

    unlockAndDecrementTokenCount();
    return;
  }

  unlock();

  // Callback has either already executed or is executing concurrently on
  // another thread.

  if (signallingThreadId_ == std::this_thread::get_id()) {
    // Callback executed on this thread or is still currently executing
    // and is deregistering itself from within the callback.
    if (callback->destructorHasRunInsideCallback_ != nullptr) {
      // Currently inside the callback, let the requestCancellation() method
      // know the object is about to be destructed and that it should
      // not try to access the object when the callback returns.
      *callback->destructorHasRunInsideCallback_ = true;
    }
  } else {
    // Callback is currently executing on another thread, block until it
    // finishes executing.
    folly::detail::Sleeper sleeper;
    while (!callback->callbackCompleted_.load(std::memory_order_acquire)) {
      sleeper.wait();
    }
  }

  removeTokenReference();
}

References

  1. folly::coro, Meta