Robin Hood Hashing 源码分析

2021.05.21
SF-Zhou

从 C++11 开始,STL 会提供哈希表 std::unordered_map 的实现,用起来确实很方便,不过性能上就差强人意了。robin_hood::unordered_map 作为 std::unordered_map 的替代品,提供了与标准库中一致的接口,同时带来 2 到 3 倍的性能提升,着实让人心动。笔者年前尝试使用该哈希表,但由于其内部的 Bug 导致低概率的抛出异常,不得已又退回使用标准库。今年 3 月底的时候其作者修复了该 Bug,笔者也第一时间测试使用,并上线到现网环境,截止目前无任何故障。安全起见,笔者分析了该哈希表的具体实现,分析的代码版本为 3.11.1,目前也没有发现潜在的安全隐患。依笔者之见,Robin Hood 高性能的秘诀是开放寻址、平坦化和限制冲突。

1. 开放寻址

目前主流的 STL 实现均使用闭式寻址(Closed Addressing),当发生冲突时,需要使用额外的数据结构处理冲突。例如 GCC 中使用的是链表,查询时会先对 key 进行哈希确定桶的位置,再比对桶对应的链表中的元素。闭式寻址的优势是删除简单,相同负载系数下对比开放寻址性能更好。但冲突剧烈时,查询的复杂度也会从 O(1)\mathcal O(1) 退化到 O(n)\mathcal O(n),此时也依赖 Rehash 减少冲突。

Bucket Collision Chain
0
1
2 ②②②
3
4
5
6
7 ⑦⑦

而 Robin Hood 中使用的是开放寻址(Open Addressing),发生冲突时会尝试找下一个空桶的位置,每个桶至多存放一个元素,这也就限制了其负载系数至多为 1。其优势是有更好的缓存局部性,负载系数低时性能优异,劣势是删除时复杂度更高,负载系数高时冲突剧烈。单纯使用开放寻址无法应对复杂的现实 需求,为了提高性能还需要额外的优化策略。

Bucket Open Addressing
0
1
2
3
4
5
6
7

2. 平坦化

平坦化(Flatten)是指将哈希表中的元素直接存储在哈希桶数组中。非平坦化的实现会在哈希桶数组中存放元素的指针,查询时先读桶中的数据,再访问对应的元素,会产生一次间接寻址。平坦化则可以减少一次寻址操作,确定桶的位置后就可以直接访问元素。其优势自然是获得更好的性能和缓存局部性,劣势是需要使用更多的内存空间,以 80% 的负载系数为例,Rehash 后 60% 的内存空间存放的是空桶。另外平坦化要求键值对支持移动构造和移动复制,Robin Hood 对符合该条件并且键值对总大小小于 6 个 size_t 的会启用平坦化的实现:

template <typename Key, typename T, typename Hash = hash<Key>,
          typename KeyEqual = std::equal_to<Key>, size_t MaxLoadFactor100 = 80>
using unordered_map = detail::Table<
    sizeof(robin_hood::pair<Key, T>) <= sizeof(size_t) * 6 &&
        std::is_nothrow_move_constructible<robin_hood::pair<Key, T>>::value &&
        std::is_nothrow_move_assignable<robin_hood::pair<Key, T>>::value,
    MaxLoadFactor100, Key, T, Hash, KeyEqual>;

对于不符合条件的键值对,Robin Hood 中也提供了非平坦化实现。简单压测可以发现,相同的键值对类型平坦化相较于非平坦化可以提升一倍多的性能。

3. 限制冲突

Robin Hood 中使用了 uint8_t 类型的 Info 字段记录 key 的目标桶与实际存放桶之间的距离,使用该字段实现:

  1. 检查桶是否为空桶;
  2. 限制目标桶与实际桶之间的距离小于 256,使查询的复杂度收敛;
  3. 保证桶中实际存放的键值对的顺序始终与键值对目标桶的顺序一致。
Bucket Info=0 Info=1 Info=2 Info=3 Info=4
0
1
2
3
4
5
6
7
8 (Buffer)
9 (Buffer)

如上表所示,使用 Info=0 表示空桶,非空桶时 Info 记录键值对与哈希目标桶的距离,当超过限制时执行扩容。插入时根据距离判断键值对的目标桶位置并以此排序执行插入,删除时也根据该距离判断是否需要将键值对前移。申请哈希数组时至多会额外申请 0xFF 个空间存储尾端冲突的键值对。

执行插入的代码如下:

template <typename... Args>
std::pair<iterator, bool> emplace(Args&&... args) {
  ROBIN_HOOD_TRACE(this)
  // 构造键值对节点
  Node n{*this, std::forward<Args>(args)...};
  // 查询插入位置
  auto idxAndState = insertKeyPrepareEmptySpot(getFirstConst(n));
  switch (idxAndState.second) {
    case InsertionState::key_found:
      n.destroy(*this);
      break;

    case InsertionState::new_node:
      ::new (static_cast<void*>(&mKeyVals[idxAndState.first]))
        Node(*this, std::move(n));
      break;

    case InsertionState::overwrite_node:
      mKeyVals[idxAndState.first] = std::move(n);
      break;

    case InsertionState::overflow_error:
      n.destroy(*this);
      throwOverflowError();
      break;
  }

  // 返回迭代器
  return std::make_pair(
    iterator(mKeyVals + idxAndState.first, mInfo + idxAndState.first),
    InsertionState::key_found != idxAndState.second);
}

template <typename OtherKey>
std::pair<size_t, InsertionState> insertKeyPrepareEmptySpot(OtherKey&& key) {
  for (int i = 0; i < 256; ++i) {
    size_t idx{};
    InfoType info{};
    // 查询哈希后的位置,计算 info 值
    keyToIdx(key, &idx, &info);
    // 跳过目标桶非当前位置的节点
    nextWhileLess(&info, &idx);

    // while we potentially have a match
    while (info == mInfo[idx]) {
      // 若找到相同的 key,则提前返回
      if (WKeyEqual::operator()(key, mKeyVals[idx].getFirst())) {
        // key already exists, do NOT insert.
        // see http://en.cppreference.com/w/cpp/container/unordered_map/insert
        return std::make_pair(idx, InsertionState::key_found);
      }
      // info 一致但 key 不一致,则继续寻找下个节点
      next(&info, &idx);
    }

    // unlikely that this evaluates to true
    if (ROBIN_HOOD_UNLIKELY(mNumElements >= mMaxNumElementsAllowed)) {
      // 元素数量超过允许的值后,执行 Rehash 扩容
      if (!increase_size()) {
        return std::make_pair(size_t(0), InsertionState::overflow_error);
      }
      continue;
    }

    // key not found, so we are now exactly where we want to insert it.
    // 当前位置 info > mInfo[idx],准备在该位置插入
    auto const insertion_idx = idx;
    auto const insertion_info = info;
    if (ROBIN_HOOD_UNLIKELY(insertion_info + mInfoInc > 0xFF)) {
      // 如果 info 的值即将超过 0xFF,那么下一次插入前先执行扩容
      mMaxNumElementsAllowed = 0;
    }

    // find an empty spot
    // 在插入位置继续寻找下一个空桶的位置
    while (0 != mInfo[idx]) {
      next(&info, &idx);
    }

    // 如果插入位置与空桶位置不一致
    if (idx != insertion_idx) {
      // 则将插入位置到空桶前的所有元素向后移动,空出插入位置来
      shiftUp(idx, insertion_idx);
    }
    // put at empty spot
    // 在插入位置更新 info
    mInfo[insertion_idx] = static_cast<uint8_t>(insertion_info);
    ++mNumElements;
    // 返回结果
    return std::make_pair(
      insertion_idx, idx == insertion_idx ? InsertionState::new_node
      : InsertionState::overwrite_node);
  }

  // enough attempts failed, so finally give up.
  return std::make_pair(size_t(0), InsertionState::overflow_error);
}

// highly performance relevant code.
// Lower bits are used for indexing into the array (2^n size)
// The upper 1-5 bits need to be a reasonable good hash, to save comparisons.
template <typename HashKey>
void keyToIdx(HashKey&& key, size_t* idx, InfoType* info) const {
  // In addition to whatever hash is used, add another mul & shift so we get
  // better hashing. This serves as a bad hash prevention, if the given data
  // is badly mixed.
  auto h = static_cast<uint64_t>(WHash::operator()(key));

  // 执行完用户提供的哈希后,再执行一次可变参数的哈希
  h *= mHashMultiplier;
  h ^= h >> 33U;

  // the lower InitialInfoNumBits are reserved for info.
  // 计算 info 和插入位置 idx
  *info = mInfoInc + static_cast<InfoType>((h & InfoMask) >> mInfoHashShift);
  *idx = (static_cast<size_t>(h) >> InitialInfoNumBits) & mMask;
}

// forwards the index by one, wrapping around at the end
void next(InfoType* info, size_t* idx) const noexcept {
  // 跳到下一个位置
  *idx = *idx + 1;
  // 距离需要叠加上对应的系数
  *info += mInfoInc;
}

void nextWhileLess(InfoType* info, size_t* idx) const noexcept {
  // unrolling this by hand did not bring any speedups.
  while (*info < mInfo[*idx]) {
    next(info, idx);
  }
}

执行删除的代码:

// Erases element at pos, returns iterator to the next element.
iterator erase(iterator pos) {
  ROBIN_HOOD_TRACE(this)
  // we assume that pos always points to a valid entry, and not end().
  auto const idx = static_cast<size_t>(pos.mKeyVals - mKeyVals);

  shiftDown(idx);
  --mNumElements;

  if (*pos.mInfo) {
    // we've backward shifted, return this again
    return pos;
  }

  // no backward shift, return next element
  return ++pos;
}

void shiftDown(size_t idx) noexcept(
  std::is_nothrow_move_assignable<Node>::value) {
  // until we find one that is either empty or has zero offset.
  // TODO(martinus) we don't need to move everything, just the last one for
  // the same bucket.
  // 析构需要删除的键值对
  mKeyVals[idx].destroy(*this);

  // until we find one that is either empty or has zero offset.
  // 根据距离判断是否需要前移
  while (mInfo[idx + 1] >= 2 * mInfoInc) {
    ROBIN_HOOD_COUNT(shiftDown)
    // 距离减一
    mInfo[idx] = static_cast<uint8_t>(mInfo[idx + 1] - mInfoInc);
    mKeyVals[idx] = std::move(mKeyVals[idx + 1]);
    ++idx;
  }

  mInfo[idx] = 0;
  // don't destroy, we've moved it
  // mKeyVals[idx].destroy(*this);
  mKeyVals[idx].~Node();
}

执行扩容的代码如下:

// True if resize was possible, false otherwise
bool increase_size() {
  // nothing allocated yet? just allocate InitialNumElements
  if (0 == mMask) {
    initData(InitialNumElements);
    return true;
  }

  auto const maxNumElementsAllowed = calcMaxNumElementsAllowed(mMask + 1);
  if (mNumElements < maxNumElementsAllowed && try_increase_info()) {
    return true;
  }

  ROBIN_HOOD_LOG("mNumElements="
                 << mNumElements << ", maxNumElementsAllowed="
                 << maxNumElementsAllowed << ", load="
                 << (static_cast<double>(mNumElements) * 100.0 /
                     (static_cast<double>(mMask) + 1)))

  nextHashMultiplier();
  if (mNumElements * 2 < calcMaxNumElementsAllowed(mMask + 1)) {
    // we have to resize, even though there would still be plenty of space
    // left! Try to rehash instead. Delete freed memory so we don't steadyily
    // increase mem in case we have to rehash a few times
    rehashPowerOfTwo(mMask + 1, true);
  } else {
    // Each resize use a different hash so we don't so easily overflow.
    // Make sure we only have odd numbers, so that the multiplication is
    // reversible!
    rehashPowerOfTwo((mMask + 1) * 2, false);
  }
  return true;
}

void nextHashMultiplier() {
  // adding an *even* number, so that the multiplier will always stay odd.
  // This is necessary so that the hash stays a mixing function (and thus
  // doesn't have any information loss).
  // 修改哈希常数避免始终陷入高冲突状态
  mHashMultiplier += UINT64_C(0xc4ceb9fe1a85ec54);
}

// reserves space for at least the specified number of elements.
// only works if numBuckets if power of two
// True on success, false otherwise
void rehashPowerOfTwo(size_t numBuckets, bool forceFree) {
  ROBIN_HOOD_TRACE(this)

  Node* const oldKeyVals = mKeyVals;
  uint8_t const* const oldInfo = mInfo;

  const size_t oldMaxElementsWithBuffer =
    calcNumElementsWithBuffer(mMask + 1);

  // resize operation: move stuff
  initData(numBuckets);
  if (oldMaxElementsWithBuffer > 1) {
    for (size_t i = 0; i < oldMaxElementsWithBuffer; ++i) {
      if (oldInfo[i] != 0) {
        // might throw an exception, which is really bad since we are in the
        // middle of moving stuff.
        insert_move(std::move(oldKeyVals[i]));
        // destroy the node but DON'T destroy the data.
        oldKeyVals[i].~Node();
      }
    }

    // this check is not necessary as it's guarded by the previous if, but it
    // helps silence g++'s overeager "attempt to free a non-heap object 'map'
    // [-Werror=free-nonheap-object]" warning.
    if (oldKeyVals != reinterpret_cast_no_cast_align_warning<Node*>(&mMask)) {
      // don't destroy old data: put it into the pool instead
      if (forceFree) {
        std::free(oldKeyVals);
      } else {
        DataPool::addOrFree(oldKeyVals,
                            calcNumBytesTotal(oldMaxElementsWithBuffer));
      }
    }
  }
}

实际上 info 字段有限的位中还存储了部分哈希的信息用于加速找 key,随着扩容其位数会逐渐降低,这里就不再详述了。

4. 使用建议

并不存在某一种哈希表可以适用所有场景,不过大部分场景下 Robin Hood 的性能都不错,推荐尝试。笔者之前优化的项目中哈希表占用整个服务 8% 左右的 CPU,替换为 Robin Hood 后哈希表部分的 CPU 占用降低到 3%。

References

  1. "Hash table", Wikipedia
  2. "martinus/robin-hood-hashing", GitHub