bRPC 源码分析「五、请求处理」

2021.03.19
SF-Zhou

1. Request and Wait

bRPC 中同步的 RPC 发出请求后,会在请求对应的 bthread_id 上等待回复。bthread_id 是一个 64 位的标识符,可以附带一个指针数据,支持加锁、等待、范围检查,并且通过 ResourcePool 的版本规避了 ABA 问题。来看下具体实现:

// Creates a bthread_id whose attached data is `data` and whose error handler
// is `on_error`. When `on_error` is NULL,
// bthread::default_bthread_id_on_error is installed instead.
// Returns whatever bthread::id_create_impl returns.
int bthread_id_create(bthread_id_t* id, void* data,
                      int (*on_error)(bthread_id_t, void*, int)) {
  int (*handler)(bthread_id_t, void*, int) = on_error;
  if (!handler) {
    handler = bthread::default_bthread_id_on_error;
  }
  // No on_error2-style (text-carrying) handler for ids created this way.
  return bthread::id_create_impl(id, data, handler, NULL);
}

// Takes an Id object from the resource pool, initializes it with `data` and
// the two error handlers, and writes the resulting identifier into *id.
// Returns 0 on success, ENOMEM when the pool yields no object.
static int id_create_impl(bthread_id_t* id, void* data,
                          int (*on_error)(bthread_id_t, void*, int),
                          int (*on_error2)(bthread_id_t, void*, int,
                                           const std::string&)) {
  IdResourceId slot;
  Id* const meta = get_resource(&slot);  // fetch an object from the resource pool
  if (!meta) {
    return ENOMEM;
  }

  meta->data = data;
  meta->on_error = on_error;
  meta->on_error2 = on_error2;
  CHECK(meta->pending_q.empty());

  uint32_t* const version_word = meta->butex;
  if (0 == *version_word ||
      *version_word + ID_MAX_RANGE + 2 < *version_word) {
    // Restart at 1: skipping 0 keeps bthread_id_t from ever being 0, and
    // restarting before the counter wraps keeps version comparisons simple.
    *version_word = 1;
  }

  // The current butex value becomes the first valid version; the value right
  // after it is the "locked" marker.
  *meta->join_butex = *version_word;
  meta->first_ver = *version_word;
  meta->locked_ver = *version_word + 1;
  *id = make_id(*version_word, slot);
  return 0;
}

// Packs a version and a resource slot into one 64-bit bthread_id_t:
// the slot value occupies the high 32 bits, the version the low 32 bits.
inline bthread_id_t make_id(uint32_t version, IdResourceId slot) {
  const uint64_t high_half = ((uint64_t)slot.value) << 32;
  const uint64_t low_half = (uint64_t)version;
  const bthread_id_t packed = {high_half | low_half};
  return packed;
}

// The metadata record that a bthread_id_t actually refers to.
struct BAIDU_CACHELINE_ALIGNMENT Id {
  // Meaning of the value stored in *butex:
  //   first_ver ~ locked_ver - 1 : unlocked versions (the valid id range)
  //   locked_ver                 : locked
  //   contended_ver()            : locked and contended
  //   unlockable_ver()           : locked and about to be destroyed
  uint32_t first_ver;   // first valid version of the current range
  uint32_t locked_ver;  // version value that represents the locked state
  internal::FastPthreadMutex mutex;  // protects the metadata of this record
  void* data;
  int (*on_error)(bthread_id_t, void*, int);
  int (*on_error2)(bthread_id_t, void*, int, const std::string&);
  const char* lock_location;
  uint32_t* butex;
  uint32_t* join_butex;
  SmallQueue<PendingError, 2> pending_q;

  Id() {
    // The initial butex value (the version part of bthread_id_t) does not
    // matter; both words start at 0 to keep the program more deterministic.
    butex = bthread::butex_create_checked<uint32_t>();
    join_butex = bthread::butex_create_checked<uint32_t>();
    *join_butex = 0;
    *butex = 0;
  }

  ~Id() {
    bthread::butex_destroy(butex);
    bthread::butex_destroy(join_butex);
  }

  // A version is valid when it lies inside [first_ver, locked_ver).
  bool has_version(uint32_t id_ver) const {
    return first_ver <= id_ver && id_ver < locked_ver;
  }

  // State value: locked while other lockers are waiting.
  uint32_t contended_ver() const {
    return locked_ver + 1;
  }

  // State value: locked and about to be destroyed.
  uint32_t unlockable_ver() const {
    return locked_ver + 2;
  }

  // Highest version value used by the current range.
  uint32_t last_ver() const {
    return unlockable_ver();
  }

  // One past last_ver(); also the next "first_ver".
  uint32_t end_ver() const {
    return last_ver() + 1;
  }
};

// Blocks until the bthread_id `id` is no longer valid (i.e. destroyed).
// Returns 0 once the version is out of the valid range, EINVAL when the slot
// cannot be addressed, or the errno of a failed butex_wait (other than
// EWOULDBLOCK / EINTR, which simply trigger a re-check).
int bthread_id_join(bthread_id_t id) {
  bthread::Id* const meta = address_resource(bthread::get_slot(id));
  if (!meta) {
    // The id is not created yet, this join is definitely wrong.
    return EINVAL;
  }
  // Version stored in the low 32 bits of the id.
  const uint32_t wanted_ver = bthread::get_version(id);
  uint32_t* const join_word = meta->join_butex;
  for (;;) {
    // Sample validity and the join word together under the metadata mutex.
    meta->mutex.lock();
    const bool still_valid = meta->has_version(wanted_ver);
    const uint32_t seen_value = *join_word;
    meta->mutex.unlock();
    if (!still_valid) {
      return 0;
    }
    // Sleep on join_butex until it is woken up or its value changes.
    const int wait_rc = bthread::butex_wait(join_word, seen_value, NULL);
    if (wait_rc < 0 && errno != EWOULDBLOCK && errno != EINTR) {
      return errno;
    }
  }
}

// Unlocks the bthread_id `id` and destroys it.
// Returns 0 on success, EINVAL when the slot cannot be addressed or the
// version is not valid, EPERM when the id is not currently locked.
int bthread_id_unlock_and_destroy(bthread_id_t id) {
  bthread::Id* const meta = address_resource(bthread::get_slot(id));
  if (!meta) {
    return EINVAL;
  }
  uint32_t* const lock_word = meta->butex;
  uint32_t* const join_word = meta->join_butex;
  const uint32_t wanted_ver = bthread::get_version(id);

  meta->mutex.lock();
  if (!meta->has_version(wanted_ver)) {
    meta->mutex.unlock();
    LOG(FATAL) << "Invalid bthread_id=" << id.value;
    return EINVAL;
  }
  if (*lock_word == meta->first_ver) {
    meta->mutex.unlock();
    LOG(FATAL) << "bthread_id=" << id.value << " is not locked!";
    return EPERM;
  }
  // Jump past every version of the current range. From here on a join either
  // sees an invalid version or a changed join_butex value.
  const uint32_t recycled_ver = meta->end_ver();
  *lock_word = recycled_ver;
  *join_word = recycled_ver;
  meta->first_ver = recycled_ver;
  meta->locked_ver = recycled_ver;
  meta->pending_q.clear();
  meta->mutex.unlock();

  // Notice that butex_wake* returns # of woken-up, not successful or not.
  bthread::butex_wake_except(lock_word, 0);
  bthread::butex_wake_all(join_word);  // wake up callers blocked in join
  // Give the slot back; previously issued ids are rejected by the version
  // check from now on.
  return_resource(bthread::get_slot(id));
  return 0;
}

// Reports an error on the bthread_id `id`.
// If the id is unlocked, it is locked here and the error handler is invoked
// directly (its return value is returned). If the id is already locked, the
// error is appended to pending_q and 0 is returned.
// Returns EINVAL when the slot cannot be addressed or the version is invalid.
int bthread_id_error2_verbose(bthread_id_t id, int error_code,
                              const std::string& error_text,
                              const char* location) {
  bthread::Id* const meta = address_resource(bthread::get_slot(id));
  if (!meta) {
    return EINVAL;
  }
  const uint32_t wanted_ver = bthread::get_version(id);
  uint32_t* const lock_word = meta->butex;

  meta->mutex.lock();
  if (!meta->has_version(wanted_ver)) {
    meta->mutex.unlock();
    return EINVAL;
  }

  if (*lock_word != meta->first_ver) {
    // Locked by someone else: queue the error for later delivery.
    bthread::PendingError pending;
    pending.id = id;
    pending.error_code = error_code;
    pending.error_text = error_text;
    pending.location = location;
    meta->pending_q.push(pending);
    meta->mutex.unlock();
    return 0;
  }

  // Unlocked: take the lock, then run the error handler with the id locked.
  *lock_word = meta->locked_ver;
  meta->lock_location = location;
  meta->mutex.unlock();
  if (meta->on_error) {
    return meta->on_error(id, meta->data, error_code);
  }
  return meta->on_error2(id, meta->data, error_code, error_text);
}

// Locks the bthread_id `id`, optionally resetting its valid version range.
//   pdata    - when non-NULL, receives the data pointer attached to the id.
//   range    - 0 keeps the current range; otherwise the new number of
//              versions (locked_ver becomes first_ver + range).
//   location - stored in meta->lock_location when the lock is acquired.
// Returns 0 once locked; EINVAL when the slot cannot be addressed or the
// version of `id` is not (or no longer) valid; EPERM when the butex holds
// unlockable_ver(); or the errno of a failed butex_wait.
int bthread_id_lock_and_reset_range_verbose(bthread_id_t id, void** pdata,
                                            int range, const char* location) {
  bthread::Id* const meta = address_resource(bthread::get_slot(id));
  if (!meta) {
    return EINVAL;
  }
  const uint32_t id_ver = bthread::get_version(id);
  uint32_t* butex = meta->butex;
  bool ever_contended = false;
  meta->mutex.lock();
  while (meta->has_version(id_ver)) {
    // Here the metadata mutex is held and the version is valid.
    if (*butex == meta->first_ver) {
      // contended locker always wakes up the butex at unlock.
      // The id is currently unlocked: acquire it.
      meta->lock_location = location;
      if (range == 0) {
        // fast path
      } else if (range < 0 || range > bthread::ID_MAX_RANGE ||
                 range + meta->first_ver <= meta->locked_ver) {
        // Rejected range: only the first two cases are logged, and the lock
        // is still taken below with the existing range.
        LOG_IF(FATAL, range < 0)
            << "range must be positive, actually " << range;
        LOG_IF(FATAL, range > bthread::ID_MAX_RANGE)
            << "max range is " << bthread::ID_MAX_RANGE << ", actually "
            << range;
      } else {
        // A range change was requested: move locked_ver accordingly.
        meta->locked_ver = meta->first_ver + range;
      }
      // If this caller had to wait before, store contended_ver() so that the
      // unlock wakes up the other bthread_id_lock callers.
      *butex = (ever_contended ? meta->contended_ver() : meta->locked_ver);
      meta->mutex.unlock();
      if (pdata) {
        *pdata = meta->data;
      }
      return 0;
    } else if (*butex != meta->unlockable_ver()) {
      // Neither unlocked nor about to be destroyed, i.e. locked or contended:
      // mark the butex as contended and wait for it to change.
      *butex = meta->contended_ver();
      uint32_t expected_ver = *butex;
      meta->mutex.unlock();
      ever_contended = true;
      if (bthread::butex_wait(butex, expected_ver, NULL) < 0 &&
          errno != EWOULDBLOCK && errno != EINTR) {
        // The wait failed for a reason other than a value change or an
        // interruption.
        return errno;
      }
      meta->mutex.lock();
    } else {  // bthread_id_about_to_destroy was called: the id is about to be destroyed.
      meta->mutex.unlock();
      return EPERM;
    }
  }
  // The version of `id` is not valid (or stopped being valid while waiting).
  meta->mutex.unlock();
  return EINVAL;
}

// Unlocks the bthread_id `id`.
// If an error was queued while the id was locked, the id stays locked and the
// error handler is invoked with the oldest queued error (its return value is
// returned). Otherwise the id becomes unlocked, a waiting locker is woken up
// if the lock was contended, and 0 is returned.
// Returns EINVAL when the slot cannot be addressed or the version is invalid,
// EPERM when the id is not locked.
int bthread_id_unlock(bthread_id_t id) {
  bthread::Id* const meta = address_resource(bthread::get_slot(id));
  if (!meta) {
    return EINVAL;
  }
  uint32_t* const lock_word = meta->butex;
  // Release fence makes sure all changes made before signal visible to
  // woken-up waiters.
  const uint32_t wanted_ver = bthread::get_version(id);

  meta->mutex.lock();
  if (!meta->has_version(wanted_ver)) {
    meta->mutex.unlock();
    LOG(FATAL) << "Invalid bthread_id=" << id.value;
    return EINVAL;
  }
  if (*lock_word == meta->first_ver) {
    meta->mutex.unlock();
    LOG(FATAL) << "bthread_id=" << id.value << " is not locked!";
    return EPERM;
  }

  bthread::PendingError queued;
  if (!meta->pending_q.pop(&queued)) {
    // No queued error: release the lock and wake up a waiting locker.
    const bool was_contended = (*lock_word == meta->contended_ver());
    *lock_word = meta->first_ver;
    meta->mutex.unlock();
    if (was_contended) {
      // We may wake up already-reused id, but that's OK.
      bthread::butex_wake(lock_word);
    }
    return 0;
  }

  // An error arrived while the id was locked: hand it to the error handler
  // without releasing the lock.
  meta->lock_location = queued.location;
  meta->mutex.unlock();
  if (meta->on_error) {
    return meta->on_error(queued.id, meta->data, queued.error_code);
  }
  return meta->on_error2(queued.id, meta->data, queued.error_code,
                         queued.error_text);
}

如官方文档所述,使用 bthread_id 可以解决以下问题:

  1. 在发送 RPC 过程中 response 回来了,处理 response 的代码和发送代码产生竞争。
  2. 设置 timer 后很快触发了,超时处理代码和发送代码产生竞争。
  3. 重试产生的多个 response 同时回来产生的竞争。
  4. 通过 correlation_id 在 O(1) 时间内找到对应的 RPC 上下文,而无需建立从 correlation_id 到 RPC 上下文的全局哈希表。
  5. 取消 RPC。

具体实现上,每个 Controller 拥有一个 bthread_id 对象 _correlation_id,创建时将 data 绑定为 Controller 自身的指针,错误处理使用 HandleSocketFailed 函数:

// controller.cpp
// Returns the call id of this controller, creating it on first use.
// The id is created with `this` as attached data and HandleSocketFailed as
// error handler. If another caller publishes an id first, the id created here
// is cancelled and the published one is returned.
CallId Controller::call_id() {
    butil::atomic<uint64_t>* const published =
        (butil::atomic<uint64_t>*)&_correlation_id.value;
    uint64_t current = published->load(butil::memory_order_relaxed);
    if (current != 0) {
        const CallId existing = { current };
        return existing;
    }

    // Optimistic locking.
    CallId fresh = { 0 };
    // The range of this id will be reset in Channel::CallMethod
    CHECK_EQ(0, bthread_id_create2(&fresh, this, HandleSocketFailed));
    if (published->compare_exchange_strong(current, fresh.value,
                                           butil::memory_order_relaxed)) {
        return fresh;
    }
    // The exchange failed: `current` now holds the value stored by someone
    // else, so drop the id created above and use that one.
    bthread_id_cancel(fresh);
    fresh.value = current;
    return fresh;
}

调用前会将 correlation_id 对象的范围改为 max_retry() + 2,每个版本对应的解释如代码中的注释所示:

// channel.cpp
// Fetch the controller's call id, then lock it while resetting its version
// range to 2 + max_retry(): one unversioned id plus one version per try, as
// listed in the comment below.
const CallId correlation_id = cntl->call_id();
const int rc = bthread_id_lock_and_reset_range(
                    correlation_id, NULL, 2 + cntl->max_retry());

// Make versioned correlation_id.
// call_id         : unversioned, mainly for ECANCELED and ERPCTIMEDOUT
// call_id + 1     : first try.
// call_id + 2     : retry 1
// ...
// call_id + N + 1 : retry N
// All ids except call_id are versioned. Say if we've sent retry 1 and
// a failed response of first try comes back, it will be ignored.

当调用 Socket 将请求写入后,Controller 会在 correlation_id 上执行等待:

// channel.cpp
// Excerpt (`...` marks elided code): issues the RPC and, for a synchronous
// call (done == NULL), joins the correlation_id before finishing the RPC.
void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
                         google::protobuf::RpcController* controller_base,
                         const google::protobuf::Message* request,
                         google::protobuf::Message* response,
                         google::protobuf::Closure* done) {
  ...
  cntl->IssueRPC(start_send_real_us);
  if (done == NULL) {
    // MUST wait for response when sending synchronous RPC. It will
    // be woken up by callback when RPC finishes (succeeds or still
    // fails after retry)
    Join(correlation_id);  // wait until correlation_id is destroyed
    if (cntl->_span) {
      cntl->SubmitSpan();
    }
    cntl->OnRPCEnd(butil::gettimeofday_us());
  }
}

当请求成功收到回复时,会按照协议调用对应的处理函数,例如 baidu_std 会调用 ProcessRpcResponse:

// baidu_rpc_protocol.cpp
// Excerpt (`...` marks elided code): handles a received response message.
// Parses the RPC meta, locks the correlation_id carried in it to obtain the
// Controller, and reports the response through the accessor.
void ProcessRpcResponse(InputMessageBase* msg_base) {
  const int64_t start_parse_us = butil::cpuwide_time_us();
  DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base));
  RpcMeta meta;
  // Parse the RPC meta information.
  if (!ParsePbFromIOBuf(&meta, msg->meta)) {
    LOG(WARNING) << "Fail to parse from response meta";
    return;
  }

  // Read the correlation_id from the meta.
  const bthread_id_t cid = { static_cast<uint64_t>(meta.correlation_id()) };
  Controller* cntl = NULL;
  // Lock the correlation_id; cntl is passed as the output for the data
  // pointer attached to the id.
  const int rc = bthread_id_lock(cid, (void**)&cntl);
  ...
  const int saved_error = cntl->ErrorCode();
  accessor.OnResponse(cid, saved_error);
}

// controller_private_accessor.h
// Excerpt of the accessor class; only OnResponse is shown here (the `_cntl`
// member it uses is declared in the part of the class that is not quoted).
class ControllerPrivateAccessor {
public:
    // Forwards a received response for call `id` to
    // Controller::OnVersionedRPCReturned, without requesting a new bthread.
    // `saved_error` is the controller's error code captured by the caller.
    void OnResponse(CallId id, int saved_error) {
        // NOTE(review): the second field presumably marks "a response was
        // received" (HandleSocketFailed passes false) — confirm against
        // the CompletionInfo definition.
        const Controller::CompletionInfo info = { id, true };
        _cntl->OnVersionedRPCReturned(info, false, saved_error);
    }
};

// controller.cpp
// Excerpt (`...` marks elided code): called when a versioned RPC attempt
// returns, either with a response or with an error.
void Controller::OnVersionedRPCReturned(const CompletionInfo& info,
                                        bool new_bthread, int saved_error) {
  ...
  // Marks the id as about to be destroyed.
  // NOTE(review): the original note here said "wakes up join"; in the code
  // quoted in this article join_butex is woken by
  // bthread_id_unlock_and_destroy — confirm where that is called in the
  // elided part.
  bthread_id_about_to_destroy(info.id);
  ...
}

当 RPC 过程中发生错误时,比如 SocketKeepWrite 写入失败:

// socket.cpp
// Disposes of a write request that failed and reports the failure on the
// bthread_id the request is associated with (if any).
void Socket::ReturnFailedWriteRequest(Socket::WriteRequest* p, int error_code,
                                      const std::string& error_text) {
  const bool reset_done = p->reset_pipelined_count_and_user_message();
  if (!reset_done) {
    CancelUnwrittenBytes(p->data.size());
  }
  p->data.clear();  // data is probably not written.

  // Copy the id out first: the request object is returned before the error
  // is reported.
  const bthread_id_t waiter_id = p->id_wait;
  butil::return_object(p);
  if (waiter_id != INVALID_BTHREAD_ID) {
    // waiter_id is the correlation_id of the RPC. bthread_id_error2 locks it
    // and invokes the error handler registered at creation time, i.e.
    // HandleSocketFailed.
    bthread_id_error2(waiter_id, error_code, error_text);
  }
}

// controller.cpp
// Error handler bound to a Controller's call id.
//   id         - the id the error was reported on.
//   data       - the Controller attached to the id.
//   error_code - the reported error; ERPCTIMEDOUT and EBACKUPREQUEST get
//                dedicated messages.
//   error_text - optional message used for other error codes.
// Returns the result of bthread_id_unlock when the controller is not yet
// used by an RPC, otherwise 0 after finishing the attempt.
int Controller::HandleSocketFailed(bthread_id_t id, void* data, int error_code,
                                   const std::string& error_text) {
  // Recover the controller pointer from `data`.
  Controller* const self = static_cast<Controller*>(data);
  if (!self->is_used_by_rpc()) {
    // Cannot destroy the call_id before RPC otherwise an async RPC
    // using the controller cannot be joined and related resources may be
    // destroyed before done->Run() running in another bthread.
    // The error set will be detected in Channel::CallMethod and fail
    // the RPC.
    self->SetFailed(error_code, "Cancel call_id=%" PRId64
                    " before CallMethod()", id.value);
    return bthread_id_unlock(id);
  }

  // Remember the error that was set before this failure overwrites it.
  const int previous_error = self->ErrorCode();
  if (error_code == ERPCTIMEDOUT) {
    self->SetFailed(error_code, "Reached timeout=%" PRId64 "ms @%s",
                    self->timeout_ms(),
                    butil::endpoint2str(self->remote_side()).c_str());
  } else if (error_code == EBACKUPREQUEST) {
    self->SetFailed(error_code, "Reached backup timeout=%" PRId64 "ms @%s",
                    self->backup_request_ms(),
                    butil::endpoint2str(self->remote_side()).c_str());
  } else if (!error_text.empty()) {
    self->SetFailed(error_code, "%s", error_text.c_str());
  } else {
    self->SetFailed(error_code, "%s @%s", berror(error_code),
                    butil::endpoint2str(self->remote_side()).c_str());
  }

  // Finish this attempt.
  const CompletionInfo completion = { id, false };
  self->OnVersionedRPCReturned(completion, true, previous_error);
  return 0;
}

correlation_id 也会用作超时的处理:

// channel.cpp
// Excerpt (`...` marks elided code): registers the RPC timeout timer, whose
// callback argument is the value of the correlation_id.
void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
                         google::protobuf::RpcController* controller_base,
                         const google::protobuf::Message* request,
                         google::protobuf::Message* response,
                         google::protobuf::Closure* done) {
  ...
  // Setup timer for RPC timeout

  // _deadline_us is for truncating _connect_timeout_ms
  cntl->_deadline_us = cntl->timeout_ms() * 1000L + start_send_real_us;
  const int rc = bthread_timer_add(
    &cntl->_timeout_id,
    butil::microseconds_to_timespec(cntl->_deadline_us),
    HandleTimeout, (void*)correlation_id.value);  // the argument carries the correlation_id value
  ...
}

// Timer callback invoked when the RPC times out. `arg` carries the value of
// the correlation_id the timer was registered with.
static void HandleTimeout(void* arg) {
    const uint64_t raw_id = (uint64_t)arg;
    bthread_id_t correlation_id = { raw_id };
    // Reporting ERPCTIMEDOUT on the id leads to HandleSocketFailed.
    bthread_id_error(correlation_id, ERPCTIMEDOUT);
}

// Excerpt (`...` marks elided code): when the RPC ends, the timeout timer
// registered for it is deleted.
void Controller::EndRPC(const CompletionInfo& info) {
  // A non-zero _timeout_id means a timer was registered; delete it and reset
  // the field to 0.
  if (_timeout_id != 0) {
    bthread_timer_del(_timeout_id);
    _timeout_id = 0;
  }
  ...
}

References

  1. "bRPC bthread_id", incubator-brpc