bRPC 中同步的 RPC 发出请求后,会在请求对应的 bthread_id 上等待回复。bthread_id 是一个 64 位的标识符,可以附带一个指针数据,支持加锁、等待、范围检查,并且通过 ResourcePool 的版本规避了 ABA 问题。来看下具体实现:
// Creates a bthread_id that carries `data` and reports errors to `on_error`.
// A null `on_error` is replaced by bthread::default_bthread_id_on_error.
// The result of bthread::id_create_impl is returned unchanged.
int bthread_id_create(bthread_id_t* id, void* data,
                      int (*on_error)(bthread_id_t, void*, int)) {
    if (!on_error) {
        on_error = bthread::default_bthread_id_on_error;
    }
    return bthread::id_create_impl(id, data, on_error, NULL);
}
// Takes an Id object from the resource pool and initializes it as a fresh,
// unlocked bthread_id whose valid version range contains exactly one version.
// Exactly one of `on_error` / `on_error2` is expected to be the active handler.
// Returns 0 and stores the new id in *id on success, ENOMEM when the pool
// yields no object.
static int id_create_impl(bthread_id_t* id, void* data,
                          int (*on_error)(bthread_id_t, void*, int),
                          int (*on_error2)(bthread_id_t, void*, int,
                                           const std::string&)) {
    IdResourceId slot;
    Id* const meta = get_resource(&slot);  // object comes from the resource pool
    if (!meta) {
        return ENOMEM;
    }
    meta->data = data;
    meta->on_error = on_error;
    meta->on_error2 = on_error2;
    CHECK(meta->pending_q.empty());
    uint32_t* const version = meta->butex;
    // Restart numbering at 1 when the stored version is 0 (so a bthread_id_t
    // is never 0) or when adding ID_MAX_RANGE + 2 to it would wrap around
    // (avoiding overflow keeps version comparisons simple).
    if (*version == 0 || *version > *version + ID_MAX_RANGE + 2) {
        *version = 1;
    }
    *meta->join_butex = *version;
    meta->first_ver = *version;
    meta->locked_ver = *version + 1;
    *id = make_id(*version, slot);
    return 0;
}
// Packs a slot and a version into one 64-bit bthread_id_t: the resource id
// occupies the high 32 bits and the version the low 32 bits.
inline bthread_id_t make_id(uint32_t version, IdResourceId slot) {
    const uint64_t high_bits = static_cast<uint64_t>(slot.value) << 32;
    const uint64_t low_bits = static_cast<uint64_t>(version);
    const bthread_id_t packed = {high_bits | low_bits};
    return packed;
}
// The data structure that actually backs a bthread_id_t.
struct BAIDU_CACHELINE_ALIGNMENT Id {
// first_ver ~ locked_ver - 1: unlocked versions
// locked_ver: locked
// unlockable_ver: locked and about to be destroyed
// contended_ver: locked and contended
uint32_t first_ver; // first valid version; first_ver ~ locked_ver - 1 is the valid range
uint32_t locked_ver; // version value that denotes the locked state
internal::FastPthreadMutex mutex; // guards the metadata of this structure
// User data attached when the id is created; handed to the error handlers.
void* data;
// Error handlers; callers invoke on_error when it is non-null, on_error2 otherwise.
int (*on_error)(bthread_id_t, void*, int);
int (*on_error2)(bthread_id_t, void*, int, const std::string&);
// Source location recorded by the most recent lock (or dispatched error).
const char* lock_location;
// Holds the current state/version of the id; lockers wait on it.
uint32_t* butex;
// Joiners wait on this butex until the id is destroyed.
uint32_t* join_butex;
// Errors reported while the id was locked, dispatched later at unlock.
SmallQueue<PendingError, 2> pending_q;
Id() {
// Although value of the butex(as version part of bthread_id_t)
// does not matter, we set it to 0 to make program more deterministic.
butex = bthread::butex_create_checked<uint32_t>();
join_butex = bthread::butex_create_checked<uint32_t>();
*butex = 0;
*join_butex = 0;
}
~Id() {
bthread::butex_destroy(butex);
bthread::butex_destroy(join_butex);
}
inline bool has_version(uint32_t id_ver) const {
// A version inside the range is valid. Each slot has its own
// independent version space.
return id_ver >= first_ver && id_ver < locked_ver;
}
inline uint32_t contended_ver() const { return locked_ver + 1; } // state: locked and contended
inline uint32_t unlockable_ver() const { return locked_ver + 2; } // state: about to be destroyed
inline uint32_t last_ver() const { return unlockable_ver(); } // last version belonging to this id
// also the next "first_ver"
inline uint32_t end_ver() const { return last_ver() + 1; }
};
// Blocks until the bthread_id `id` is destroyed.
// Returns 0 once the version of `id` is no longer valid for its slot,
// EINVAL when the slot cannot be addressed, or the errno of a failed
// butex_wait (other than EWOULDBLOCK / EINTR).
int bthread_id_join(bthread_id_t id) {
const bthread::IdResourceId slot = bthread::get_slot(id);
bthread::Id* const meta = address_resource(slot);
if (!meta) {
// The id is not created yet, this join is definitely wrong.
return EINVAL;
}
const uint32_t id_ver = bthread::get_version(id); // version part of the id (its low 32 bits)
uint32_t* join_butex = meta->join_butex;
while (1) {
meta->mutex.lock();
const bool has_ver = meta->has_version(id_ver); // is this id still valid?
const uint32_t expected_ver = *join_butex;
meta->mutex.unlock();
if (!has_ver) {
break;
}
// Wait on join_butex until woken up. The validity check and the expected
// value were both sampled while holding the mutex.
if (bthread::butex_wait(join_butex, expected_ver, NULL) < 0 &&
errno != EWOULDBLOCK && errno != EINTR) {
return errno;
}
}
return 0;
}
// Unlocks and destroys the bthread_id `id`, waking both lockers and joiners.
// Returns 0 on success, EINVAL when the slot cannot be addressed or the
// version of `id` is not valid, EPERM when the id is not currently locked.
int bthread_id_unlock_and_destroy(bthread_id_t id) {
bthread::Id* const meta = address_resource(bthread::get_slot(id));
if (!meta) {
return EINVAL;
}
uint32_t* butex = meta->butex;
uint32_t* join_butex = meta->join_butex;
const uint32_t id_ver = bthread::get_version(id);
meta->mutex.lock();
if (!meta->has_version(id_ver)) {
meta->mutex.unlock();
LOG(FATAL) << "Invalid bthread_id=" << id.value;
return EINVAL;
}
// A butex equal to first_ver means the id is unlocked.
if (*butex == meta->first_ver) {
meta->mutex.unlock();
LOG(FATAL) << "bthread_id=" << id.value << " is not locked!";
return EPERM;
}
// Move the whole version range past every version this id could have used.
const uint32_t next_ver = meta->end_ver();
*butex = next_ver;
*join_butex = next_ver;
meta->first_ver = next_ver; // from now on a join sees either an invalid version or a changed join_butex
meta->locked_ver = next_ver;
meta->pending_q.clear();
meta->mutex.unlock();
// Notice that butex_wake* returns # of woken-up, not successful or not.
bthread::butex_wake_except(butex, 0);
bthread::butex_wake_all(join_butex); // wake the waiting bthread_id_join callers
return_resource(bthread::get_slot(id)); // release the slot; the old bthread_id is invalid because of its version
return 0;
}
// Reports an error on the bthread_id `id`.
// If the id is unlocked it is locked here and the error handler is invoked
// immediately (with the id still locked); otherwise the error is queued and
// dispatched by a later bthread_id_unlock.
// Returns EINVAL when the slot cannot be addressed or the version is not
// valid, the handler's return value when it is invoked, or 0 when queued.
int bthread_id_error2_verbose(bthread_id_t id, int error_code,
const std::string& error_text,
const char* location) {
bthread::Id* const meta = address_resource(bthread::get_slot(id));
if (!meta) {
return EINVAL;
}
const uint32_t id_ver = bthread::get_version(id);
uint32_t* butex = meta->butex;
meta->mutex.lock();
if (!meta->has_version(id_ver)) {
meta->mutex.unlock();
return EINVAL;
}
if (*butex == meta->first_ver) {
// The id is unlocked: lock it.
*butex = meta->locked_ver;
meta->lock_location = location;
meta->mutex.unlock();
// Invoke the error handler.
if (meta->on_error) {
return meta->on_error(id, meta->data, error_code);
} else {
return meta->on_error2(id, meta->data, error_code, error_text);
}
} else {
// The id is locked: append the error to the pending queue.
bthread::PendingError e;
e.id = id;
e.error_code = error_code;
e.error_text = error_text;
e.location = location;
meta->pending_q.push(e);
meta->mutex.unlock();
return 0;
}
}
// Locks the bthread_id `id`, optionally resetting its valid version range.
// `range` == 0 keeps the current range. A `range` that is negative, larger
// than ID_MAX_RANGE, or that would not grow the current range leaves
// locked_ver unchanged (the first two cases are also logged as FATAL), and
// the lock is still taken.
// On success returns 0 and, when `pdata` is non-NULL, stores the attached
// data in *pdata. Returns EINVAL when the slot cannot be addressed or the
// version is (or becomes) invalid, EPERM when the id is about to be
// destroyed, or the errno of a failed butex_wait (other than EWOULDBLOCK /
// EINTR).
int bthread_id_lock_and_reset_range_verbose(bthread_id_t id, void** pdata,
int range, const char* location) {
bthread::Id* const meta = address_resource(bthread::get_slot(id));
if (!meta) {
return EINVAL;
}
const uint32_t id_ver = bthread::get_version(id);
uint32_t* butex = meta->butex;
bool ever_contended = false;
meta->mutex.lock();
while (meta->has_version(id_ver)) {
// The metadata mutex is held and the version is valid here.
if (*butex == meta->first_ver) {
// contended locker always wakes up the butex at unlock.
// The id is currently unlocked.
meta->lock_location = location;
if (range == 0) {
// fast path
} else if (range < 0 || range > bthread::ID_MAX_RANGE ||
range + meta->first_ver <= meta->locked_ver) {
LOG_IF(FATAL, range < 0)
<< "range must be positive, actually " << range;
LOG_IF(FATAL, range > bthread::ID_MAX_RANGE)
<< "max range is " << bthread::ID_MAX_RANGE << ", actually "
<< range;
} else {
// A range change was requested: move locked_ver accordingly.
meta->locked_ver = meta->first_ver + range;
}
// If this locker was ever contended, store contended_ver() so that the
// matching unlock wakes up the other bthread_id_lock callers.
*butex = (ever_contended ? meta->contended_ver() : meta->locked_ver);
meta->mutex.unlock();
if (pdata) {
*pdata = meta->data;
}
return 0;
} else if (*butex != meta->unlockable_ver()) {
// Neither unlocked nor about to be destroyed, so the id is locked or
// contended: mark it contended and wait.
*butex = meta->contended_ver();
uint32_t expected_ver = *butex;
meta->mutex.unlock();
ever_contended = true;
if (bthread::butex_wait(butex, expected_ver, NULL) < 0 &&
errno != EWOULDBLOCK && errno != EINTR) {
// The wait failed with an unexpected error.
return errno;
}
meta->mutex.lock();
} else { // bthread_id_about_to_destroy was called; the id is about to be destroyed.
meta->mutex.unlock();
return EPERM;
}
}
meta->mutex.unlock();
return EINVAL;
}
// Unlocks the bthread_id `id`.
// If errors were queued while the id was locked, the first one is dispatched
// to the error handler instead and the id stays locked for that handler.
// Returns EINVAL when the slot cannot be addressed or the version is not
// valid, EPERM when the id is not locked, the handler's return value when a
// pending error is dispatched, or 0 after a plain unlock.
int bthread_id_unlock(bthread_id_t id) {
bthread::Id* const meta = address_resource(bthread::get_slot(id));
if (!meta) {
return EINVAL;
}
uint32_t* butex = meta->butex;
// Release fence makes sure all changes made before signal visible to
// woken-up waiters.
// NOTE(review): no explicit fence appears in this function; the comment
// above may be left over from an earlier version - confirm.
const uint32_t id_ver = bthread::get_version(id);
meta->mutex.lock();
if (!meta->has_version(id_ver)) {
meta->mutex.unlock();
LOG(FATAL) << "Invalid bthread_id=" << id.value;
return EINVAL;
}
if (*butex == meta->first_ver) {
meta->mutex.unlock();
LOG(FATAL) << "bthread_id=" << id.value << " is not locked!";
return EPERM;
}
bthread::PendingError front;
if (meta->pending_q.pop(&front)) {
// An error was reported while locked: call the error handler directly.
// The butex is left untouched, so the id remains locked.
meta->lock_location = front.location;
meta->mutex.unlock();
if (meta->on_error) {
return meta->on_error(front.id, meta->data, front.error_code);
} else {
return meta->on_error2(front.id, meta->data, front.error_code,
front.error_text);
}
} else {
// No pending error: unlock and wake up a waiting locker if there is one.
const bool contended = (*butex == meta->contended_ver());
*butex = meta->first_ver;
meta->mutex.unlock();
if (contended) {
// We may wake up already-reused id, but that's OK.
bthread::butex_wake(butex);
}
return 0;
}
}
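如官方文档所述,使用 bthread_id 可以解决以下问题:
- 在发送 RPC 过程中 response 回来了,处理 response 的代码和发送代码产生竞争。
- 设置 timer 后很快触发了,超时处理代码和发送代码产生竞争。
- 重试产生的多个 response 同时回来产生的竞争。
- 通过 correlation_id 在 O(1) 时间内找到对应的 RPC 上下文,而无需建立从 correlation_id 到 RPC 上下文的全局哈希表。
- 取消 RPC。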
具体实现上,每个 Controller 拥有一个 bthread_id 对象 _correlation_id,创建时将 data 绑定为 Controller 自身的指针,错误处理使用 HandleSocketFailed 函数:
// controller.cpp
// Returns the call id of this controller, creating it on first use.
// The id is created with this controller as its data and HandleSocketFailed
// as its error handler. When several threads race to create it, the one whose
// compare-exchange fails cancels its own id and returns the published one.
CallId Controller::call_id() {
// Access _correlation_id.value through an atomic view.
butil::atomic<uint64_t>* target =
(butil::atomic<uint64_t>*)&_correlation_id.value;
uint64_t loaded = target->load(butil::memory_order_relaxed);
if (loaded) {
// Already created: a non-zero value is a published id.
const CallId id = { loaded };
return id;
}
// Optimistic locking.
CallId cid = { 0 };
// The range of this id will be reset in Channel::CallMethod
CHECK_EQ(0, bthread_id_create2(&cid, this, HandleSocketFailed));
if (!target->compare_exchange_strong(loaded, cid.value,
butil::memory_order_relaxed)) {
// Another thread published its id first; `loaded` now holds that value.
bthread_id_cancel(cid);
cid.value = loaded;
}
return cid;
}
调用前会将 correlation_id 对象的范围改为 max_retry() + 2,每个版本对应的解释如代码中的注释所示:
// channel.cpp
// Lock the call id and request a version range of 2 + max_retry(), so that
// the unversioned id and every try/retry each get their own version.
const CallId correlation_id = cntl->call_id();
const int rc = bthread_id_lock_and_reset_range(
correlation_id, NULL, 2 + cntl->max_retry());
// Make versioned correlation_id.
// call_id : unversioned, mainly for ECANCELED and ERPCTIMEDOUT
// call_id + 1 : first try.
// call_id + 2 : retry 1
// ...
// call_id + N + 1 : retry N
// All ids except call_id are versioned. Say if we've sent retry 1 and
// a failed response of first try comes back, it will be ignored.
当调用 Socket 将请求写入后,Controller 会在 correlation_id 上执行等待:
// channel.cpp
// Abridged excerpt: issues the RPC and, for a synchronous call (done == NULL),
// blocks on the correlation_id until the RPC has finished.
void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
google::protobuf::RpcController* controller_base,
const google::protobuf::Message* request,
google::protobuf::Message* response,
google::protobuf::Closure* done) {
...
cntl->IssueRPC(start_send_real_us);
if (done == NULL) {
// MUST wait for response when sending synchronous RPC. It will
// be woken up by callback when RPC finishes (succeeds or still
// fails after retry)
Join(correlation_id); // wait until correlation_id is destroyed
if (cntl->_span) {
cntl->SubmitSpan();
}
cntl->OnRPCEnd(butil::gettimeofday_us());
}
}
当请求成功收到回复时,会按照协议调用对应的处理函数,例如 baidu_std 会调用 ProcessRpcResponse:
// baidu_rpc_protocol.cpp
// Abridged excerpt: handles a response message of the baidu_std protocol by
// locating the Controller through the correlation_id carried in the RPC meta.
void ProcessRpcResponse(InputMessageBase* msg_base) {
const int64_t start_parse_us = butil::cpuwide_time_us();
DestroyingPtr<MostCommonMessage> msg(static_cast<MostCommonMessage*>(msg_base));
RpcMeta meta;
// Parse the RPC meta information.
if (!ParsePbFromIOBuf(&meta, msg->meta)) {
LOG(WARNING) << "Fail to parse from response meta";
return;
}
// Read the correlation_id.
const bthread_id_t cid = { static_cast<uint64_t>(meta.correlation_id()) };
Controller* cntl = NULL;
// Lock the correlation_id; the data attached to the id is returned in cntl.
const int rc = bthread_id_lock(cid, (void**)&cntl);
...
const int saved_error = cntl->ErrorCode();
accessor.OnResponse(cid, saved_error);
}
// controller_private_accessor.h
// Abridged excerpt: helper through which protocol code reaches into a
// Controller (the _cntl member is declared outside this excerpt).
class ControllerPrivateAccessor {
public:
// Forwards a received response for call `id` to the controller, passing the
// error code that was saved before the response was processed.
void OnResponse(CallId id, int saved_error) {
const Controller::CompletionInfo info = { id, true };
_cntl->OnVersionedRPCReturned(info, false, saved_error);
}
}
// controller.cpp
// Abridged excerpt: called when one (versioned) try of the RPC has returned.
void Controller::OnVersionedRPCReturned(const CompletionInfo& info,
bool new_bthread, int saved_error) {
...
// Mark the id as about to be destroyed; later lock attempts on it then fail
// with EPERM (see bthread_id_lock_and_reset_range_verbose).
// NOTE(review): the original note here said "wakes up join", but in the code
// shown joiners are woken by bthread_id_unlock_and_destroy - confirm.
bthread_id_about_to_destroy(info.id);
...
}
当 RPC 过程中发生错误时,比如 Socket 的 KeepWrite 写入失败:
// socket.cpp
// Releases a write request that failed to be written and reports the failure
// on the bthread_id waiting for it, if there is one.
void Socket::ReturnFailedWriteRequest(Socket::WriteRequest* p, int error_code,
const std::string& error_text) {
if (!p->reset_pipelined_count_and_user_message()) {
CancelUnwrittenBytes(p->data.size());
}
p->data.clear(); // data is probably not written.
// Copy the id before the request object is returned.
const bthread_id_t id_wait = p->id_wait;
butil::return_object(p);
if (id_wait != INVALID_BTHREAD_ID) {
// id_wait is the correlation_id described above. bthread_id_error2 locks it
// and invokes its error handler, i.e. the HandleSocketFailed that was
// registered when the id was created.
bthread_id_error2(id_wait, error_code, error_text);
}
}
// controller.cpp
// Error handler registered on a Controller's call id. `data` is the
// Controller, `error_code` / `error_text` describe the failure.
// Before the controller is used by an RPC it only records the error and
// unlocks the id; otherwise it records the error and finishes this try via
// OnVersionedRPCReturned.
int Controller::HandleSocketFailed(bthread_id_t id, void* data, int error_code,
const std::string& error_text) {
// Recover the controller pointer from data.
Controller* cntl = static_cast<Controller*>(data);
if (!cntl->is_used_by_rpc()) {
// Cannot destroy the call_id before RPC otherwise an async RPC
// using the controller cannot be joined and related resources may be
// destroyed before done->Run() running in another bthread.
// The error set will be detected in Channel::CallMethod and fail
// the RPC.
cntl->SetFailed(error_code, "Cancel call_id=%" PRId64
" before CallMethod()", id.value);
return bthread_id_unlock(id);
}
// Remember the error code from before this failure overwrites it.
const int saved_error = cntl->ErrorCode();
if (error_code == ERPCTIMEDOUT) {
cntl->SetFailed(error_code, "Reached timeout=%" PRId64 "ms @%s",
cntl->timeout_ms(),
butil::endpoint2str(cntl->remote_side()).c_str());
} else if (error_code == EBACKUPREQUEST) {
cntl->SetFailed(error_code, "Reached backup timeout=%" PRId64 "ms @%s",
cntl->backup_request_ms(),
butil::endpoint2str(cntl->remote_side()).c_str());
} else if (!error_text.empty()) {
cntl->SetFailed(error_code, "%s", error_text.c_str());
} else {
cntl->SetFailed(error_code, "%s @%s", berror(error_code),
butil::endpoint2str(cntl->remote_side()).c_str());
}
CompletionInfo info = { id, false };
cntl->OnVersionedRPCReturned(info, true, saved_error); // finish this try
return 0;
}
correlation_id 也会用作超时的处理:
// channel.cpp
// Abridged excerpt: registers the timeout timer of the RPC.
void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
google::protobuf::RpcController* controller_base,
const google::protobuf::Message* request,
google::protobuf::Message* response,
google::protobuf::Closure* done) {
...
// Setup timer for RPC timetout
// _deadline_us is for truncating _connect_timeout_ms
cntl->_deadline_us = cntl->timeout_ms() * 1000L + start_send_real_us;
const int rc = bthread_timer_add(
&cntl->_timeout_id,
butil::microseconds_to_timespec(cntl->_deadline_us),
HandleTimeout, (void*)correlation_id.value); // the timer argument is the value of correlation_id
...
}
// Timer callback invoked when the RPC times out. `arg` carries the raw value
// of the correlation_id that was passed to bthread_timer_add.
// Reports ERPCTIMEDOUT on that id; presumably this ends up in the id's error
// handler (HandleSocketFailed for a Controller's call id) - bthread_id_error
// itself is not shown here.
static void HandleTimeout(void* arg) {
    const uint64_t raw_id = reinterpret_cast<uint64_t>(arg);
    bthread_id_t correlation_id = { raw_id };
    bthread_id_error(correlation_id, ERPCTIMEDOUT);
}
// Abridged excerpt: when the RPC ends, the timeout timer registered for it
// is removed.
void Controller::EndRPC(const CompletionInfo& info) {
// A non-zero _timeout_id means a timer was registered for this RPC.
if (_timeout_id != 0) {
bthread_timer_del(_timeout_id);
_timeout_id = 0;
}
...
}