PaxosStore Source Code Analysis (Part 1: Network Communication)

2020.02.15
SF-Zhou

PaxosStore is a distributed storage system open-sourced by WeChat, built on the Paxos protocol to provide strongly consistent storage. The code was released in 2017, together with the VLDB paper "PaxosStore: High-availability Storage Made Practical in WeChat". PaxosStore is used widely across WeChat's storage services, so the code is well battle-tested. Hence this new series, "PaxosStore Source Code Analysis", whose goal is to learn something new from an industrial-strength Paxos implementation. As the first installment, this article covers the overall structure of the PaxosStore codebase and the implementation of its network communication layer.

1. PaxosStore Source Code Overview

The PaxosStore project open-sources two Paxos implementations: Certain and PaxosKV. The former is a fairly general PLog + DB design that can plug in various databases; the latter, built around the PaxosLog-as-Value idea, is a dedicated optimization for KV systems. The two are completely independent; this series focuses on the Certain source. The Certain library is roughly 18,000 lines of code, organized as follows:

certain
├── example  # examples
├── include  # headers
├── model    # unclear what this directory means
├── network  # networking foundation
├── src      # the core implementation
├── third    # third-party libraries
└── utils    # utility helpers

2. Network Communication

The network communication layer is independent of the Paxos protocol implementation; it is essentially a foundation library, and its code is easy to get into, which makes it a good starting point. First, network/InetAddr.h:

struct InetAddr_t {
  struct sockaddr_in tAddr;

  bool operator==(const InetAddr_t &tOther) const {
    const struct sockaddr_in &tOtherAddr = tOther.tAddr;

    if (tAddr.sin_addr.s_addr != tOtherAddr.sin_addr.s_addr) {
      return false;
    }

    if (tAddr.sin_port != tOtherAddr.sin_port) {
      return false;
    }

    return true;
  }

  bool operator<(const InetAddr_t &tOther) const {
    const struct sockaddr_in &tOtherAddr = tOther.tAddr;

    if (tAddr.sin_addr.s_addr != tOtherAddr.sin_addr.s_addr) {
      return tAddr.sin_addr.s_addr < tOtherAddr.sin_addr.s_addr;
    }

    if (tAddr.sin_port != tOtherAddr.sin_port) {
      return tAddr.sin_port < tOtherAddr.sin_port;
    }

    return false;  // equal addresses must not compare less-than
  }

  InetAddr_t() { memset(&tAddr, 0, sizeof(tAddr)); }

  InetAddr_t(struct sockaddr_in tSockAddr) { tAddr = tSockAddr; }

  InetAddr_t(const char *sIP, uint16_t iPort) {
    memset(&tAddr, 0, sizeof(tAddr));
    tAddr.sin_family = AF_INET;
    tAddr.sin_addr.s_addr = inet_addr(sIP);
    tAddr.sin_port = htons(iPort);
  }

  InetAddr_t(uint32_t iIP, uint16_t iPort) {
    memset(&tAddr, 0, sizeof(tAddr));
    tAddr.sin_family = AF_INET;
    tAddr.sin_addr.s_addr = iIP;
    tAddr.sin_port = htons(iPort);
  }

  string ToString() const {
    const char *sIP = inet_ntoa(tAddr.sin_addr);
    uint16_t iPort = ntohs(tAddr.sin_port);

    char acBuffer[32];
    snprintf(acBuffer, 32, "%s:%hu", sIP, iPort);

    return acBuffer;
  }

  uint32_t GetNetOrderIP() { return tAddr.sin_addr.s_addr; }
};

The author has reformatted the code with clang-format. The snippet also gives a glimpse of the coding style: camel-case names with Hungarian-style type prefixes; it takes some getting used to. InetAddr_t wraps the socket address structure sockaddr_in and provides conversion between addresses and strings.
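
As a quick illustration, here is a hypothetical usage sketch (the include path is assumed; it is not part of the excerpt above):

#include <cassert>
#include <cstdio>

#include "network/InetAddr.h"  // include path assumed

int main() {
  InetAddr_t tAddr("127.0.0.1", 8080);
  printf("%s\n", tAddr.ToString().c_str());  // prints 127.0.0.1:8080

  InetAddr_t tSame(tAddr.GetNetOrderIP(), 8080);
  assert(tAddr == tSame);  // same IP and port compare equal
  return 0;
}

Next, network/SocketHelper.h: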

// Connection info
struct ConnInfo_t {
  int iFD;  // fd of this connection

  InetAddr_t tLocalAddr;  // local address
  InetAddr_t tPeerAddr;   // peer address

  string ToString() const {
    char acBuffer[128];
    snprintf(acBuffer, 128, "fd %d local %s peer %s", iFD, tLocalAddr.ToString().c_str(),
             tPeerAddr.ToString().c_str());
    return acBuffer;
  }
};

// Create a socket and set it to non-blocking mode
// If ptInetAddr is given, bind the socket to that address
int CreateSocket(const InetAddr_t *ptInetAddr);

// Check whether the socket is still valid
bool CheckIfValid(int iFD);

// Connect iFD to the given address
int Connect(int iFD, const InetAddr_t &tInetAddr);

// Connect to the given address and store the connection info in tConnInfo
int Connect(const InetAddr_t &tInetAddr, ConnInfo_t &tConnInfo);

// Set or clear non-blocking mode
int SetNonBlock(int iFD, bool bNonBlock = true);

// Create a non-blocking pipe
int MakeNonBlockPipe(int &iInFD, int &iOutFD);

Excerpts from the implementations of CreateSocket and Connect:

int CreateSocket(const InetAddr_t *ptInetAddr) {
  int iRet;

  int iFD = socket(AF_INET, SOCK_STREAM, 0);
  if (iFD == -1) {
    CertainLogError("socket ret -1 errno %d", errno);
    return -1;
  }

  int iOptVal = 1;
  iRet = setsockopt(iFD, SOL_SOCKET, SO_REUSEADDR, (char *)&iOptVal, sizeof(int));
  if (iRet == -1) {
    CertainLogError("setsockopt fail fd %d errno %d", iFD, errno);
    return -2;
  }

  // Disable the TCP Nagle algorithm.
  iOptVal = 1;
  iRet = setsockopt(iFD, IPPROTO_TCP, TCP_NODELAY, (char *)&iOptVal, sizeof(iOptVal));
  if (iRet == -1) {
    CertainLogError("setsockopt fail fd %d errno %d", iFD, errno);
    return -3;
  }

  int iFlags = fcntl(iFD, F_GETFL, 0);
  iRet = fcntl(iFD, F_SETFL, iFlags | O_NONBLOCK);
  AssertEqual(iRet, 0);

  if (ptInetAddr == NULL) {
    return iFD;
  }

  struct sockaddr *ptAddr = (struct sockaddr *)&ptInetAddr->tAddr;
  socklen_t tLen = sizeof(*ptAddr);

  iRet = bind(iFD, ptAddr, tLen);
  if (iRet == -1) {
    CertainLogError("bind fail fd %d addr %s errno %d", iFD, ptInetAddr->ToString().c_str(), errno);
    return -4;
  }

  return iFD;
}

int Connect(int iFD, const InetAddr_t &tInetAddr) {
  const struct sockaddr_in *ptAddr = &tInetAddr.tAddr;
  socklen_t tLen = sizeof(*ptAddr);

  int iRet = connect(iFD, (struct sockaddr *)ptAddr, tLen);
  if (iRet == -1) {
    if (errno == EINPROGRESS) {
      return 0;
    }

    CertainLogError("connect fail fd %d addr %s errno %d", iFD, tInetAddr.ToString().c_str(),
                    errno);
    return -1;
  }

  return 0;
}

int Connect(const InetAddr_t &tInetAddr, ConnInfo_t &tConnInfo) {
  int iRet;

  int iFD = CreateSocket(NULL);
  if (iFD == -1) {
    CertainLogError("CreateSocket ret %d", iFD);
    return -1;
  }

  iRet = Connect(iFD, tInetAddr);
  if (iRet != 0) {
    close(iFD);
    CertainLogError("Connect ret %d", iRet);
    return -2;
  }

  tConnInfo.iFD = iFD;
  tConnInfo.tPeerAddr = tInetAddr;

  struct sockaddr tAddr = {0};
  socklen_t tLen = sizeof(tAddr);

  iRet = getsockname(iFD, &tAddr, &tLen);
  if (iRet != 0) {
    close(iFD);
    CertainLogError("getsockname fail ret %d errno %d", iRet, errno);
    return -3;
  }

  if (tAddr.sa_family != AF_INET) {
    close(iFD);
    CertainLogError("not spported sa_family %d", tAddr.sa_family);
    return -4;
  }

  tConnInfo.tLocalAddr.tAddr = *((struct sockaddr_in *)&tAddr);

  return 0;
}

After creating a TCP socket, CreateSocket first sets SO_REUSEADDR, so that a restarting server does not hit the "Address already in use" error; it then enables TCP_NODELAY, so packets go out as soon as there is data in the buffer; finally it switches the fd to O_NONBLOCK, so calls return an error immediately instead of waiting when the buffer is unavailable. Once Connect has initiated the connection, getsockname fetches the local address, and the fd, the peer address, and the local address are stored in a ConnInfo_t.
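
Note that on a non-blocking socket, connect() usually returns -1 with errno == EINPROGRESS, which Connect above treats as success; whether the connection actually completed has to be confirmed later. That confirmation is not part of the excerpt; a minimal sketch of the standard POSIX idiom, run once epoll reports the fd writable:

#include <sys/socket.h>

// Returns 0 iff the in-progress connect() completed successfully.
int CheckConnectResult(int iFD) {
  int iErr = 0;
  socklen_t tLen = sizeof(iErr);
  if (getsockopt(iFD, SOL_SOCKET, SO_ERROR, &iErr, &tLen) == -1) {
    return -1;  // getsockopt itself failed
  }
  return iErr == 0 ? 0 : -1;  // iErr holds the errno of the asynchronous connect
}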

High-performance networking on Linux is inseparable from epoll. Next, network/EpollIO.h:

const int kDefaultFDEvents = EPOLLIN | EPOLLOUT | EPOLLET;
const int kDefaultEventSize = 8096;

class clsFDBase;
// Pure virtual base class for IO handling
// Provides HandleRead/HandleWrite, each taking a pointer to an FD object
class clsIOHandlerBase {
 public:
  virtual ~clsIOHandlerBase() {}
  virtual int HandleRead(clsFDBase *poFD) = 0;
  virtual int HandleWrite(clsFDBase *poFD) = 0;
};

// FD base class, a wrapper over a POSIX fd
// Stores the fd / io_handler / events, etc.
// Tracks the two states readable / writable
class clsFDBase {
 private:
  int m_iFD;
  uint32_t m_iFDID;

  clsIOHandlerBase *m_poIOHandler;
  int m_iEvents;

  bool m_bReadable;
  bool m_bWritable;

 public:
  clsFDBase(int iFD, clsIOHandlerBase *poIOHandler = NULL, uint32_t iFDID = -1,
            int iEvents = kDefaultFDEvents)
      : m_iFD(iFD),
        m_iFDID(iFDID),
        m_poIOHandler(poIOHandler),
        m_iEvents(iEvents),
        m_bReadable(false),
        m_bWritable(false) {}

  virtual ~clsFDBase() {}

  int GetFD() { return m_iFD; }
  uint32_t GetFDID() { return m_iFDID; }

  int GetEvents() { return m_iEvents; }

  // Get/Set methods for the boolean members, generated by a macro
  BOOLEN_IS_SET(Readable);
  BOOLEN_IS_SET(Writable);

  clsIOHandlerBase *GetIOHandler() { return m_poIOHandler; }
  void SetIOHandler(clsIOHandlerBase *poIOHandler) { m_poIOHandler = poIOHandler; }
};

clsFDBase wraps a file descriptor fd; in practice the fd here is always a socket connection. It stores the events of interest iEvents and the IO event handler poIOHandler; when an event fires, clsIOHandlerBase::HandleRead/HandleWrite can be invoked to process it. Note that the default event mask enables edge-triggered mode, EPOLLET.
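
Edge-triggered epoll reports a fd again only after a new state change, so a handler must drain the fd until read() returns EAGAIN, or it will never be woken up for the remaining data. A minimal, hypothetical handler (not part of Certain) illustrating that contract:

#include <unistd.h>

#include <cerrno>

class clsDrainHandler : public clsIOHandlerBase {
 public:
  virtual int HandleRead(clsFDBase *poFD) {
    char acBuf[4096];
    while (true) {
      ssize_t iRet = read(poFD->GetFD(), acBuf, sizeof(acBuf));
      if (iRet > 0) continue;  // a real handler would process acBuf here
      if (iRet == -1 && errno == EINTR) continue;
      if (iRet == -1 && errno == EAGAIN) return 0;  // drained; wait for the next edge
      return -1;  // 0 (closed by peer) or a hard error
    }
  }
  virtual int HandleWrite(clsFDBase *poFD) { return 0; }
};

The component that manages these events is, of course, epoll: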

class clsEpollIO {
 private:
  int m_iEpollFD;

  uint32_t m_iEventSize;
  epoll_event *m_ptEpollEv;

  struct UniqueFDPtr_t {
    uint32_t iFDID;
    clsFDBase *poFD;
  };
  UniqueFDPtr_t *m_atUniqueFDPtrMap;

 public:
  clsEpollIO(uint32_t iEventSize = kDefaultEventSize);
  ~clsEpollIO();

  // Add/remove/modify an FD object in epoll
  int Add(clsFDBase *poFD);
  int Remove(clsFDBase *poFD);
  int RemoveAndCloseFD(clsFDBase *poFD);
  int Modify(clsFDBase *poFD);

  clsFDBase *GetFDBase(int iFD, uint32_t iFDID);

  // Run epoll_wait once, then dispatch the ready fds
  void RunOnce(int iTimeoutMS);
};

The function implementations live in network/EpollIO.cpp; excerpts:

// Add the FD object to epoll, and record the object and its ID, indexed by fd
int clsEpollIO::Add(clsFDBase *poFD) {
  int iFD = poFD->GetFD();
  AssertLess(iFD, CERTAIN_MAX_FD_NUM);
  CertainLogDebug("fd %d", iFD);

  epoll_event ev = {0};
  ev.events = poFD->GetEvents();
  ev.data.ptr = static_cast<void *>(poFD);

  int iRet = epoll_ctl(m_iEpollFD, EPOLL_CTL_ADD, iFD, &ev);
  if (iRet == -1) {
    CertainLogError("epoll_ctl fail fd %d errno %d", iFD, errno);
    return -1;
  }

  AssertEqual(iRet, 0);

  AssertEqual(m_atUniqueFDPtrMap[iFD].poFD, NULL);
  m_atUniqueFDPtrMap[iFD].poFD = poFD;
  m_atUniqueFDPtrMap[iFD].iFDID = poFD->GetFDID();

  return 0;
}

// Run one round: invoke HandleRead / HandleWrite on the FD objects according to the events
void clsEpollIO::RunOnce(int iTimeoutMS) {
  int iNum, iRet;

  while (1) {
    iNum = epoll_wait(m_iEpollFD, m_ptEpollEv, m_iEventSize, iTimeoutMS);
    if (iNum == -1) {
      CertainLogError("epoll_wait fail epoll_fd %d errno %d", m_iEpollFD, errno);

      if (errno != EINTR) {
        // You should probably raise "open files" limit.
        AssertEqual(errno, 0);
        assert(false);
      }

      continue;
    }
    break;
  }

  for (int i = 0; i < iNum; ++i) {
    int iEvents = m_ptEpollEv[i].events;
    clsFDBase *poFD = static_cast<clsFDBase *>(m_ptEpollEv[i].data.ptr);
    clsIOHandlerBase *poHandler = poFD->GetIOHandler();
    Assert(poHandler != NULL);

    if ((iEvents & EPOLLIN) || (iEvents & EPOLLERR) || (iEvents & EPOLLHUP)) {
      poFD->SetReadable(true);
      iRet = poHandler->HandleRead(poFD);
      if (iRet != 0) {
        continue;
      }
    }

    if (iEvents & EPOLLOUT) {
      poFD->SetWritable(true);
      iRet = poHandler->HandleWrite(poFD);
      if (iRet != 0) {
        continue;
      }
    }
  }
}

// Look up the FD object by fd, with an ID check
clsFDBase *clsEpollIO::GetFDBase(int iFD, uint32_t iFDID) {
  if (m_atUniqueFDPtrMap[iFD].iFDID == iFDID) {
    return m_atUniqueFDPtrMap[iFD].poFD;
  }

  return NULL;
}

After RunOnce obtains the ready file descriptors and their events, it fetches the corresponding FD objects and invokes HandleRead or HandleWrite; note that a non-zero return value only causes the remaining handling of that fd to be skipped.
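
To see how the pieces compose, here is a sketch of a driver loop (hypothetical; the real wiring lives in clsIOWorker, covered in a later article, and clsDrainHandler is the sketch from above):

// Assumes the headers above plus a connected non-blocking socket iConnFD.
void EventLoop(int iConnFD, volatile bool &bExiting) {
  clsEpollIO oEpollIO;
  clsDrainHandler oHandler;
  clsFDBase oFD(iConnFD, &oHandler);  // default events: EPOLLIN | EPOLLOUT | EPOLLET
  AssertEqual(oEpollIO.Add(&oFD), 0);

  while (!bExiting) {
    oEpollIO.RunOnce(1000);  // wait up to 1s, then dispatch the handlers
  }
}

Next, network/IOChannel.h: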

#define CERTAIN_IO_BUFFER_SIZE (40 << 20)  // 40MB == 2 * (MAX_WRITEBATCH_SIZE + 1000)

class clsIOChannel : public clsFDBase {
 private:
  ConnInfo_t m_tConnInfo;

  // 0 <= iStart0 <= iEnd0 <= iStart1 <= iEnd1 <= iSize
  // The valid data is [iStart0, iEnd0) and [iStart1, iEnd1).
  struct CBuffer_t {
    char *pcData;
    uint32_t iSize;

    uint32_t iStart0;
    uint32_t iEnd0;

    uint32_t iStart1;
    uint32_t iEnd1;
  };

  void InitCBuffer(CBuffer_t *ptBuffer) {
    ptBuffer->pcData = (char *)malloc(CERTAIN_IO_BUFFER_SIZE);
    ptBuffer->iSize = CERTAIN_IO_BUFFER_SIZE;

    ptBuffer->iStart0 = 0;
    ptBuffer->iEnd0 = 0;

    ptBuffer->iStart1 = ptBuffer->iSize;
    ptBuffer->iEnd1 = ptBuffer->iSize;
  }

  void UpdateBuffer(CBuffer_t *ptBuffer) {
    AssertNotMore(ptBuffer->iStart0, ptBuffer->iEnd0);
    AssertNotMore(ptBuffer->iEnd0, ptBuffer->iStart1);
    AssertNotMore(ptBuffer->iStart1, ptBuffer->iEnd1);
    AssertNotMore(ptBuffer->iEnd1, ptBuffer->iSize);

    if (ptBuffer->iStart1 < ptBuffer->iEnd1) {
      AssertEqual(ptBuffer->iStart0, 0);
      return;
    }

    AssertEqual(ptBuffer->iStart1, ptBuffer->iEnd1);

    ptBuffer->iStart1 = ptBuffer->iSize;
    ptBuffer->iEnd1 = ptBuffer->iSize;

    if (ptBuffer->iStart0 == ptBuffer->iEnd0) {
      ptBuffer->iStart0 = 0;
      ptBuffer->iEnd0 = 0;
      return;
    }

    if (ptBuffer->iStart1 - ptBuffer->iEnd0 <= ptBuffer->iStart0) {
      ptBuffer->iStart1 = ptBuffer->iStart0;
      ptBuffer->iEnd1 = ptBuffer->iEnd0;

      ptBuffer->iStart0 = 0;
      ptBuffer->iEnd0 = 0;
    }
  }

  uint32_t GetSize(CBuffer_t *ptBuffer) {
    uint32_t iSize = (ptBuffer->iEnd1 - ptBuffer->iStart1) + (ptBuffer->iEnd0 - ptBuffer->iStart0);
    return iSize;
  }

  void DestroyCBuffer(CBuffer_t *ptBuffer) {
    free(ptBuffer->pcData), ptBuffer->pcData = NULL;
    memset(ptBuffer, 0, sizeof(CBuffer_t));
  }

  string m_strReadBytes;
  CBuffer_t m_tWriteBuffer;

  uint64_t m_iTotalReadCnt;
  uint64_t m_iTotalWriteCnt;

  uint64_t m_iTimestampUS;

  uint32_t m_iServerID;

  bool m_bBroken;

 public:
  class clsSerializeCBBase {
   public:
    virtual int Call(char *pcBuffer, uint32_t iSize) = 0;
    virtual ~clsSerializeCBBase() {}
  };

  // Constructor: io_handler / conn_info / server_id / fd_id
  clsIOChannel(clsIOHandlerBase *poHandler, const ConnInfo_t &tConnInfo, uint32_t iServerID = -1,
               uint32_t iFDID = -1)
      : clsFDBase(tConnInfo.iFD, poHandler, iFDID),
        m_tConnInfo(tConnInfo),
        m_iTotalReadCnt(0),
        m_iTotalWriteCnt(0),
        m_iTimestampUS(0),
        m_iServerID(iServerID),
        m_bBroken(false) {
    InitCBuffer(&m_tWriteBuffer);
  }

  // Destructor: free the buffer
  virtual ~clsIOChannel() {
    PrintInfo();
    DestroyCBuffer(&m_tWriteBuffer);
  }

  int Read(char *pcBuffer, uint32_t iSize);
  int Write(const char *pcBuffer, uint32_t iSize);

  int FlushWriteBuffer();

  TYPE_GET_SET(uint64_t, TimestampUS, iTimestampUS);

  // Return 0 iff append successfully.
  int AppendReadBytes(const char *pcBuffer, uint32_t iSize);
  int AppendWriteBytes(const char *pcBuffer, uint32_t iSize);
  int AppendWriteBytes(clsSerializeCBBase *poSerializeCB);

  uint32_t GetWriteByteSize() { return GetSize(&m_tWriteBuffer); }

  uint32_t GetServerID() { return m_iServerID; }

  const ConnInfo_t &GetConnInfo() { return m_tConnInfo; }

  bool IsBroken() { return m_bBroken; }

  bool IsBufferBusy() {
    // Check if above 1/8 CERTAIN_IO_BUFFER_SIZE.
    return (GetWriteByteSize() << 3) > CERTAIN_IO_BUFFER_SIZE;
  }

  void PrintInfo();
};

clsIOChannel inherits from clsFDBase and wraps a socket connection. A large part of the declaration defines the write buffer CBuffer_t. In short, data left over by a flush is kept as [iStart1, iEnd1) while newly appended data goes into [iStart0, iEnd0), so every flush must write [iStart1, iEnd1) first.
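
A concrete walk-through of the index arithmetic may help. Assume a 100-byte buffer (the real one is CERTAIN_IO_BUFFER_SIZE):

// init:            [iStart0, iEnd0) = [0, 0)    [iStart1, iEnd1) = [100, 100)
// append 60 bytes: [0, 60)                      [100, 100)
// flush writes 50: [50, 60)                     [100, 100)
// UpdateBuffer:    free tail space (100 - 60 = 40) <= iStart0 (50), so the
//                  leftover is promoted to region 1:
//                  [0, 0)                       [50, 60)
// append 30 bytes: [0, 30)                      [50, 60)

The next flush then writes [50, 60) before [0, 30), which is exactly the iovec order used in FlushWriteBuffer; this two-region scheme avoids ever moving bytes within the buffer. Now the function implementations: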

int clsIOChannel::Read(char *pcBuffer, uint32_t iSize) {
  int iRet;
  int iFD = clsFDBase::GetFD();

  if (m_bBroken) {
    return -1;
  }

  Assert(clsFDBase::IsReadable());

  uint32_t iCurr = m_strReadBytes.size();
  AssertLess(iCurr, iSize);

  memcpy(pcBuffer, m_strReadBytes.c_str(), iCurr);
  m_strReadBytes.clear();

  while (iCurr < iSize) {
    iRet = read(iFD, pcBuffer + iCurr, iSize - iCurr);

    if (iRet > 0) {
      iCurr += iRet;
      m_iTotalReadCnt += iRet;
    } else if (iRet == 0) {
      // read returned 0: connection closed by the peer
      InetAddr_t tPeerAddr = m_tConnInfo.tPeerAddr;
      CertainLogError("closed by peer %s fd %d", tPeerAddr.ToString().c_str(), iFD);

      m_bBroken = true;
    } else if (iRet == -1) {
      CertainLogDebug("read ret -1 fd %d errno %d", iFD, errno);

      // interrupted by a signal
      if (errno == EINTR) {
        continue;
      }
      // temporarily unavailable
      if (errno == EAGAIN) {
        clsFDBase::SetReadable(false);
        break;
      }

      m_bBroken = true;

      CertainLogError("read ret -1 conn: %s errno %d", m_tConnInfo.ToString().c_str(), errno);
    }

    if (m_bBroken) {
      break;
    }
  }

  if (iCurr == 0) {
    Assert(m_bBroken);
    return -2;
  }

  return iCurr;
}

int clsIOChannel::Write(const char *pcBuffer, uint32_t iSize) {
  int iRet;

  iRet = AppendWriteBytes(pcBuffer, iSize);
  if (iRet != 0) {
    CertainLogError("AppendWriteBytes ret %d", iRet);
    return -1;
  }

  iRet = FlushWriteBuffer();
  if (iRet != 0) {
    CertainLogError("FlushWriteBuffer ret %d", iRet);
    return -2;
  }

  return iSize;
}

// The actual write to the socket
int clsIOChannel::FlushWriteBuffer() {
  int iRet;

  if (m_bBroken) {
    return -1;
  }
  Assert(clsFDBase::IsWritable());

  uint32_t iDataSize = GetSize(&m_tWriteBuffer);
  if (iDataSize == 0) {
    return 0;
  }

  int iFD = clsFDBase::GetFD();
  struct iovec atBuffer[2] = {0};

  atBuffer[0].iov_base = m_tWriteBuffer.pcData + m_tWriteBuffer.iStart1;
  atBuffer[0].iov_len = m_tWriteBuffer.iEnd1 - m_tWriteBuffer.iStart1;

  atBuffer[1].iov_base = m_tWriteBuffer.pcData + m_tWriteBuffer.iStart0;
  atBuffer[1].iov_len = m_tWriteBuffer.iEnd0 - m_tWriteBuffer.iStart0;

  int iBufferIdx = 0;

  while (iDataSize > 0) {
    AssertLess(iBufferIdx, 2);
    iRet = writev(iFD, atBuffer + iBufferIdx, 2 - iBufferIdx);
    AssertNotEqual(iRet, 0);

    if (iRet > 0) {
      m_iTotalWriteCnt += iRet;
      iDataSize -= iRet;

      while (iRet > 0) {
        AssertLess(iBufferIdx, 2);

        if (size_t(iRet) < atBuffer[iBufferIdx].iov_len) {
          atBuffer[iBufferIdx].iov_len -= iRet;
          atBuffer[iBufferIdx].iov_base = (char *)atBuffer[iBufferIdx].iov_base + iRet;
          iRet = 0;
        } else {
          iRet -= atBuffer[iBufferIdx].iov_len;
          atBuffer[iBufferIdx].iov_len = 0;
          iBufferIdx++;
        }
      }
    } else if (iRet == -1) {
      CertainLogDebug("read ret -1 fd %d errno %d", iFD, errno);

      if (errno == EINTR) {
        continue;
      }
      if (errno == EAGAIN) {
        clsFDBase::SetWritable(false);
        break;
      }

      m_bBroken = true;
      CertainLogError("write fail fd %d errno %d", iFD, errno);
      break;
    }
  }

  m_tWriteBuffer.iStart0 = m_tWriteBuffer.iEnd0 - atBuffer[1].iov_len;
  m_tWriteBuffer.iStart1 = m_tWriteBuffer.iEnd1 - atBuffer[0].iov_len;
  UpdateBuffer(&m_tWriteBuffer);

  if (m_bBroken) {
    return -2;
  }

  return 0;
}

// Reads are buffered too: stash bytes to be consumed by the next Read
int clsIOChannel::AppendReadBytes(const char *pcBuffer, uint32_t iSize) {
  m_strReadBytes.append(pcBuffer, iSize);
  return 0;
}

// Append to the write buffer
int clsIOChannel::AppendWriteBytes(const char *pcBuffer, uint32_t iSize) {
  UpdateBuffer(&m_tWriteBuffer);

  if (iSize > m_tWriteBuffer.iStart1 - m_tWriteBuffer.iEnd0) {
    return -1;
  }

  memcpy(m_tWriteBuffer.pcData + m_tWriteBuffer.iEnd0, pcBuffer, iSize);
  m_tWriteBuffer.iEnd0 += iSize;

  return 0;
}

// Serialize directly into the write buffer
int clsIOChannel::AppendWriteBytes(clsSerializeCBBase *poCB) {
  UpdateBuffer(&m_tWriteBuffer);

  int iRet = poCB->Call(m_tWriteBuffer.pcData + m_tWriteBuffer.iEnd0,
                        m_tWriteBuffer.iStart1 - m_tWriteBuffer.iEnd0);
  if (iRet < 0) {
    CertainLogError("poCB->Call max_size %u ret %d", m_tWriteBuffer.iStart1 - m_tWriteBuffer.iEnd0,
                    iRet);
    return -1;
  }
  m_tWriteBuffer.iEnd0 += iRet;

  return 0;
}

Read first drains the cached bytes and then calls read(fd) for fresh data; note how the Readable state is checked and maintained. Write first appends the data to the buffer and then calls FlushWriteBuffer to do the real write(fd). In practice these are used together with clsEpollIO and clsIOHandlerBase: call Read in HandleRead and FlushWriteBuffer in HandleWrite, as the clsIOWorker class will show later on.
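
As a preview, a handler wiring clsIOChannel into that pattern might look roughly like this (hypothetical and simplified; the class name and the parsing step are invented here):

class clsChannelHandler : public clsIOHandlerBase {
 public:
  virtual int HandleRead(clsFDBase *poFD) {
    clsIOChannel *poChannel = dynamic_cast<clsIOChannel *>(poFD);
    char acBuf[64 << 10];
    while (poChannel->IsReadable() && !poChannel->IsBroken()) {
      int iRet = poChannel->Read(acBuf, sizeof(acBuf));
      if (iRet < 0) return iRet;  // -1/-2: connection broken
      // ... parse the iRet bytes here; an incomplete message tail can be
      // pushed back with poChannel->AppendReadBytes(...) ...
      if (iRet < (int)sizeof(acBuf)) break;  // short read: EAGAIN already hit
    }
    return 0;
  }
  virtual int HandleWrite(clsFDBase *poFD) {
    clsIOChannel *poChannel = dynamic_cast<clsIOChannel *>(poFD);
    return poChannel->FlushWriteBuffer();  // retry whatever a short write left over
  }
};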

Next, the ConnWorker implementation. Reordering the code a little, start with the definition of clsConnWorker in src/ConnWorker.h:

class clsConnWorker : public clsThreadBase {
 private:
  clsConfigure *m_poConf;
  clsEpollIO *m_poEpollIO;

  clsListenHandler *m_poListenHandler;
  clsNegoHandler *m_poNegoHandler;

  int AddListen(const InetAddr_t &tInetAddr);
  int AddAllListen();

  int RecvNegoMsg(clsNegoContext *poNegoCtx);
  int AcceptOneFD(clsListenContext *poContext);

 public:
  virtual ~clsConnWorker();
  clsConnWorker(clsConfigure *poConf);

  int HandleListen(clsListenContext *poContext);
  int HandleNego(clsNegoContext *poNegoCtx);

  void Run();
};

clsConnWorker derives from clsThreadBase, a thin wrapper over pthread.
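
clsThreadBase itself is not excerpted in this article; for context, a minimal sketch of what such a wrapper typically looks like (the member names and signatures here are guesses, not the real ones):

#include <pthread.h>
#include <stdint.h>

class clsThreadBase {
 private:
  pthread_t m_tThreadID;
  volatile bool m_bExiting;

  static void *ThreadEntry(void *pArg) {
    static_cast<clsThreadBase *>(pArg)->Run();
    return NULL;
  }

 public:
  clsThreadBase() : m_tThreadID(), m_bExiting(false) {}
  virtual ~clsThreadBase() {}

  virtual void Run() = 0;  // overridden by e.g. clsConnWorker

  void Start() { pthread_create(&m_tThreadID, NULL, ThreadEntry, this); }
  void SetExiting() { m_bExiting = true; }
  bool CheckIfExiting(uint64_t) { return m_bExiting; }  // signature guessed
};

When the thread starts, it executes Run: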

void clsConnWorker::Run() {
  uint32_t iLocalServerID = m_poConf->GetLocalServerID();
  SetThreadTitle("conn_%u", iLocalServerID);
  CertainLogInfo("conn_%u run", iLocalServerID);

  int iRet = AddAllListen();
  if (iRet != 0) {
    CertainLogFatal("AddAllListen ret %d", iRet);
    Assert(false);
  }

  while (1) {
    m_poEpollIO->RunOnce(1000);

    if (CheckIfExiting(0)) {
      printf("conn_%u exit\n", iLocalServerID);
      CertainLogInfo("conn_%u exit", iLocalServerID);
      break;
    }
  }
}

Run first calls AddAllListen to start listening on the server's own address:

int clsConnWorker::AddAllListen() {
  int iRet;

  uint32_t iLocalServerID = m_poConf->GetLocalServerID();
  vector<InetAddr_t> vecServerAddr = m_poConf->GetServerAddrs();
  AssertLess(iLocalServerID, vecServerAddr.size());

  iRet = AddListen(vecServerAddr[iLocalServerID]);
  if (iRet != 0) {
    CertainLogError("AddListen ret %d", iRet);
    return -1;
  }

  return 0;
}

int clsConnWorker::AddListen(const InetAddr_t &tLocalAddr) {
  int iRet;
  int iBacklog = 8096;

  int iFD = CreateSocket(&tLocalAddr);
  if (iFD < 0) {
    CertainLogError("CreateSocket ret %d", iFD);
    return -1;
  }

  iRet = listen(iFD, iBacklog);
  if (iRet == -1) {
    CertainLogError("listen ret - 1 errno %d", errno);
    return -2;
  }

  CertainLogInfo("Start listen addr %s fd %d", tLocalAddr.ToString().c_str(), iFD);

  clsListenContext *poContext = new clsListenContext(iFD, m_poListenHandler, tLocalAddr);

  iRet = m_poEpollIO->Add(poContext);
  if (iRet != 0) {
    CertainLogError("m_poEpollIO->Add ret %d", iRet);
    AssertEqual(close(iFD), 0);
    delete poContext, poContext = NULL;

    return -3;
  }

  return 0;
}

AddListen creates a socket and listens on the address configured for this server, then constructs a clsListenContext object and adds it to the epoll object. When that connection becomes readable, a new client is attempting to connect, and the server must accept it. Here are clsListenContext and its handler:

class clsListenContext : public clsFDBase {
 private:
  InetAddr_t m_tLocalAddr;

 public:
  clsListenContext(int iFD, clsIOHandlerBase *poIOHandler, const InetAddr_t &tLocalAddr)
      : clsFDBase(iFD, poIOHandler, -1, (EPOLLIN | EPOLLET)), m_tLocalAddr(tLocalAddr) {}

  virtual ~clsListenContext() {}

  InetAddr_t GetLocalAddr() { return m_tLocalAddr; }
};

class clsListenHandler : public clsIOHandlerBase {
 private:
  clsConnWorker *m_poConnWorker;

 public:
  clsListenHandler(clsConnWorker *poConnWorker) : m_poConnWorker(poConnWorker) {}
  virtual ~clsListenHandler() {}

  virtual int HandleRead(clsFDBase *poFD);
  virtual int HandleWrite(clsFDBase *poFD);
};

int clsListenHandler::HandleRead(clsFDBase *poFD) {
  return m_poConnWorker->HandleListen(dynamic_cast<clsListenContext *>(poFD));
}

int clsListenHandler::HandleWrite(clsFDBase *poFD) { CERTAIN_EPOLLIO_UNREACHABLE; }

// Handle the listen fd: either accept connections or re-listen
int clsConnWorker::HandleListen(clsListenContext *poContext) {
  while (1) {
    int iRet = AcceptOneFD(poContext);
    if (iRet == 0) {
      continue;
    } else if (iRet < 0) {
      // Re-listen when the fd is invalid; check whether this is a bug.

      clsAutoDelete<clsListenContext> oAuto(poContext);
      CertainLogFatal("AcceptOnce ret %d", iRet);

      int iFD = poContext->GetFD();
      iRet = close(iFD);
      AssertEqual(iRet, 0);

      iRet = AddListen(poContext->GetLocalAddr());
      AssertEqual(iRet, 0);

      return -1;
    }

    AssertEqual(iRet, 1);
    break;
  }

  return 0;
}

// Accept a client connection and add a NegoContext to epoll
int clsConnWorker::AcceptOneFD(clsListenContext *poContext) {
  int iRet;
  int iFD, iListenFD = poContext->GetFD();
  ConnInfo_t tConnInfo;

  struct sockaddr_in tSockAddr;
  socklen_t tLen = sizeof(tSockAddr);

  while (1) {
    iFD = accept(iListenFD, (struct sockaddr *)(&tSockAddr), &tLen);
    if (iFD == -1) {
      if (errno == EINTR) {
        continue;
      } else if (errno == EAGAIN) {
        return 1;
      }

      CertainLogError("accept ret -1 errno %d", errno);
      return -1;
    }
    break;
  }

  AssertEqual(SetNonBlock(iFD, true), 0);

  tConnInfo.tLocalAddr = poContext->GetLocalAddr();
  tConnInfo.tPeerAddr = InetAddr_t(tSockAddr);
  tConnInfo.iFD = iFD;

  CertainLogInfo("accept conn %s", tConnInfo.ToString().c_str());

  clsNegoContext *poNegoCtx = new clsNegoContext(m_poNegoHandler, tConnInfo);
  iRet = m_poEpollIO->Add(poNegoCtx);
  if (iRet != 0) {
    CertainLogError("m_poEpollIO->Add ret %d", iRet);
    AssertEqual(close(tConnInfo.iFD), 0);
    delete poNegoCtx, poNegoCtx = NULL;
  }

  return 0;
}

In AcceptOneFD, the server accepts the new connection, switches it to non-blocking mode, and constructs a clsNegoContext that is added to the epoll object. When the client writes data and the connection becomes readable:

// Negotiation context
class clsNegoContext : public clsFDBase {
 private:
  ConnInfo_t m_tConnInfo;
  uint32_t m_iServerID;

 public:
  clsNegoContext(clsIOHandlerBase *poHandler, const ConnInfo_t &tConnInfo)
      : clsFDBase(tConnInfo.iFD, poHandler, -1, (EPOLLIN | EPOLLET)),
        m_tConnInfo(tConnInfo),
        m_iServerID(INVALID_SERVER_ID) {}

  virtual ~clsNegoContext() {}

  const ConnInfo_t &GetConnInfo() { return m_tConnInfo; }
  TYPE_GET_SET(uint32_t, ServerID, iServerID);
};

class clsNegoHandler : public clsIOHandlerBase {
 private:
  clsConnWorker *m_poConnWorker;

 public:
  clsNegoHandler(clsConnWorker *poConnWorker) : m_poConnWorker(poConnWorker) {}
  virtual ~clsNegoHandler() {}

  virtual int HandleRead(clsFDBase *poFD);
  virtual int HandleWrite(clsFDBase *poFD);
};

int clsNegoHandler::HandleRead(clsFDBase *poFD) {
  // Triggered when the client sends data
  return m_poConnWorker->HandleNego(dynamic_cast<clsNegoContext *>(poFD));
}

int clsNegoHandler::HandleWrite(clsFDBase *poFD) { CERTAIN_EPOLLIO_UNREACHABLE; }

// Validate the client connection, remove the NegoContext from epoll, and add the connection to the pool
int clsConnWorker::HandleNego(clsNegoContext *poNegoCtx) {
  int iRet;
  clsAutoDelete<clsNegoContext> oAuto(poNegoCtx);

  iRet = RecvNegoMsg(poNegoCtx);
  if (iRet != 0) {
    CertainLogError("RecvNegoMsg ret %d", iRet);
    m_poEpollIO->RemoveAndCloseFD(dynamic_cast<clsFDBase *>(poNegoCtx));
    return -1;
  }

  m_poEpollIO->Remove(dynamic_cast<clsFDBase *>(poNegoCtx));

  // For check only.
  ConnInfo_t tConnInfo = poNegoCtx->GetConnInfo();
  bool bInnerServer = false;
  vector<InetAddr_t> vecAddr = m_poConf->GetServerAddrs();
  for (uint32_t i = 0; i < vecAddr.size(); ++i) {
    if (vecAddr[i].GetNetOrderIP() == tConnInfo.tPeerAddr.GetNetOrderIP()) {
      if (i != poNegoCtx->GetServerID()) {
        CertainLogFatal("%u -> %u Check if replace machine conn: %s", poNegoCtx->GetServerID(), i,
                        tConnInfo.ToString().c_str());
        poNegoCtx->SetServerID(i);
      }
      bInnerServer = true;
    }
  }
  if (!bInnerServer) {
    CertainLogFatal("conn %s not from Inner Servers, check it", tConnInfo.ToString().c_str());
    close(poNegoCtx->GetFD());
    return -2;
  }

  CertainLogInfo("srvid %u conn %s from Inner Servers", poNegoCtx->GetServerID(),
                 tConnInfo.ToString().c_str());

  iRet = clsConnInfoMng::GetInstance()->PutByOneThread(poNegoCtx->GetServerID(),
                                                       poNegoCtx->GetConnInfo());
  if (iRet != 0) {
    CertainLogError("clsConnInfoMng PutByOneThread ret %d", iRet);
    close(poNegoCtx->GetFD());
  }

  return 0;
}

// Receive the magic number and set the ServerID
int clsConnWorker::RecvNegoMsg(clsNegoContext *poNegoCtx) {
  int iRet;

  const ConnInfo_t &tConnInfo = poNegoCtx->GetConnInfo();
  int iFD = tConnInfo.iFD;

  uint8_t acNego[3];

  while (1) {
    iRet = read(iFD, acNego, 3);
    if (iRet == -1) {
      if (errno == EAGAIN || errno == EINTR) {
        continue;
      }

      CertainLogError("conn %s errno %d", tConnInfo.ToString().c_str(), errno);
      return -1;
    } else if (iRet == 0) {
      CertainLogError("conn %s closed by peer", tConnInfo.ToString().c_str());
      return -2;
    } else if (iRet < 3) {
      CertainLogError("simple close conn %s", tConnInfo.ToString().c_str());
      return -3;
    }

    AssertEqual(iRet, 3);
    break;
  }

  uint16_t hMagicNum = ntohs(*(uint16_t *)acNego);
  if (hMagicNum != RP_MAGIC_NUM) {
    CertainLogFatal("BUG conn %s no RP_MAGIC_NUM found", tConnInfo.ToString().c_str());
    return -4;
  }

  uint32_t iServerID = uint32_t(acNego[2]);
  poNegoCtx->SetServerID(iServerID);

  CertainLogDebug("iServerID %u conn %s", iServerID, tConnInfo.ToString().c_str());

  return 0;
}
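
The client side of this handshake is not excerpted here; conceptually it just sends three bytes right after connecting: the magic number in network byte order followed by its one-byte server ID. A hypothetical sketch, mirroring what RecvNegoMsg expects to read (RP_MAGIC_NUM is defined elsewhere in Certain):

#include <arpa/inet.h>
#include <stdint.h>
#include <string.h>
#include <unistd.h>

int SendNegoMsg(int iFD, uint8_t iLocalServerID) {
  uint8_t acNego[3];
  uint16_t hMagicNum = htons(RP_MAGIC_NUM);
  memcpy(acNego, &hMagicNum, 2);
  acNego[2] = iLocalServerID;
  return write(iFD, acNego, 3) == 3 ? 0 : -1;  // 3 bytes; a short write is unlikely
}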

RecvNegoMsg reads the three bytes sent by the client and validates the magic number; HandleNego then checks that the peer really is one of the inner servers. If both checks pass, the pair <ServerID, ConnInfo> is added to the connection pool clsConnInfoMng:

// Connection pool singleton
class clsConnInfoMng : public clsSingleton<clsConnInfoMng> {
 private:
  clsConfigure *m_poConf;

  uint32_t m_iServerNum;
  uint32_t m_iLocalServerID;

  clsMutex m_oMutex;

  vector<queue<ConnInfo_t> > m_vecIntConnQueue;

  friend class clsSingleton<clsConnInfoMng>;
  clsConnInfoMng() {}

  void RemoveInvalidConn();

 public:
  int Init(clsConfigure *poConf);
  void Destroy();

  int TakeByMultiThread(uint32_t iIOWorkerID, const vector<vector<clsIOChannel *> > vecIntChannel,
                        ConnInfo_t &tConnInfo, uint32_t &iServerID);

  int PutByOneThread(uint32_t iServerID, const ConnInfo_t &tConnInfo);
};

int clsConnInfoMng::Init(clsConfigure *poConf) {
  m_poConf = poConf;

  m_iServerNum = m_poConf->GetServerNum();
  m_iLocalServerID = m_poConf->GetLocalServerID();

  Assert(m_vecIntConnQueue.size() == 0);
  m_vecIntConnQueue.resize(m_iServerNum);

  return 0;
}

void clsConnInfoMng::Destroy() { m_vecIntConnQueue.clear(); }

int clsConnInfoMng::PutByOneThread(uint32_t iServerID, const ConnInfo_t &tConnInfo) {
  if (iServerID == m_iLocalServerID ||
      (iServerID >= m_iServerNum && iServerID != INVALID_SERVER_ID)) {
    CertainLogError("Unexpected server id %u", iServerID);
    return -1;
  }

  CertainLogDebug("iServerID %u conn %s", iServerID, tConnInfo.ToString().c_str());

  clsThreadLock oLock(&m_oMutex);
  AssertLess(iServerID, m_iServerNum);
  m_vecIntConnQueue[iServerID].push(tConnInfo);

  return 0;
}

void clsConnInfoMng::RemoveInvalidConn() {
  uint32_t iInvalidCnt = 0;
  uint32_t iConnNotToken = 0;

  for (uint32_t i = 0; i < m_vecIntConnQueue.size(); ++i) {
    while (!m_vecIntConnQueue[i].empty()) {
      iConnNotToken++;
      ConnInfo_t tConnInfo = m_vecIntConnQueue[i].front();

      if (CheckIfValid(tConnInfo.iFD)) {
        break;
      }

      iInvalidCnt++;
      CertainLogError("Invalid conn: %s", tConnInfo.ToString().c_str());

      int iRet = close(tConnInfo.iFD);
      if (iRet != 0) {
        CertainLogError("close fail fd %d errno %d", tConnInfo.iFD, errno);
      }
      m_vecIntConnQueue[i].pop();
    }
  }

  if (iInvalidCnt > 0) {
    CertainLogError("iInvalidCnt %u", iInvalidCnt);
  } else if (iConnNotToken > 0) {
    CertainLogImpt("iConnNotToken %u", iConnNotToken);
  }
}

int clsConnInfoMng::TakeByMultiThread(uint32_t iIOWorkerID,
                                      const vector<vector<clsIOChannel *> > vecIntChannel,
                                      ConnInfo_t &tConnInfo, uint32_t &iServerID) {
  clsThreadLock oLock(&m_oMutex);

  RemoveInvalidConn();

  for (uint32_t i = 0; i < m_vecIntConnQueue.size(); ++i) {
    if (i == m_iLocalServerID) {
      Assert(m_vecIntConnQueue[i].empty());
      continue;
    }

    if (m_vecIntConnQueue[i].empty()) {
      continue;
    }

    // Take the minimum first, so that every IOWorker ends up with an equal number of IO channels.

    int32_t iMinChannelCnt = INT32_MAX;
    for (uint32_t iWorkerID = 0; iWorkerID < m_poConf->GetIOWorkerNum(); ++iWorkerID) {
      int32_t iTemp = clsIOWorkerRouter::GetInstance()->GetAndAddIntChannelCnt(iWorkerID, i, 0);
      if (iMinChannelCnt > iTemp) {
        iMinChannelCnt = iTemp;
      }
    }

    AssertNotMore(iMinChannelCnt, vecIntChannel[i].size());
    if (vecIntChannel[i].size() == uint32_t(iMinChannelCnt)) {
      tConnInfo = m_vecIntConnQueue[i].front();
      m_vecIntConnQueue[i].pop();

      iServerID = i;

      return 0;
    }
  }

  return -1;
}

Calling TakeByMultiThread afterwards hands out a validated, usable connection over which data can then be read and written, completing the network communication path.

3. Summary

Running the Paxos protocol involves a great deal of message exchange, so network communication performance matters a lot. PaxosStore obtains it from epoll with edge-triggered notifications, wrapped in a modest layer of abstraction for ease of use. The next article will look at how messages are passed between PaxosStore's modules.

References

  1. "Tencent/PaxosStore", GitHub
  2. "PaxosStore: High-availability Storage Made Practical in WeChat", Jianjun Zheng