Tokio 源码分析「一、事件驱动 IO」

2020.11.08
SF-Zhou

Tokio 是 Rust 语言下的一个异步运行时,基于非阻塞 IO 和事件驱动。Tokio 中应用了协程及 Work Stealing 策略,实现类似 Goroutine 的 M:N 线程效果。本系列逐步分析 Tokio 的代码实现,版本为 v0.3.3,本篇关注事件驱动 IO 封装库 Mio

1. Mio 概览

Mio 是对系统事件驱动 IO API 的封装,Linux 环境下封装的是 epoll,其他系统环境下封装的是 kqueue 和 IOCP,对外提供一致的 API 接口。本文分析 Mio v0.7.5 的源码,先看 src 目录下的文件结构:

src
├── event          # 对事件的封装
│   ├── event.rs
│   ├── events.rs
│   ├── mod.rs
│   └── source.rs
├── interest.rs    # 感兴趣的事件类型
├── io_source.rs   # 感兴趣的 IO 对象基类
├── lib.rs
├── macros
│   └── mod.rs
├── net            # 不同类型网络的接口
│   ├── mod.rs
│   ├── tcp
│   ├── udp.rs
│   └── uds
├── poll.rs        # 对外统一的 Poll 接口
├── sys            # 不同系统下的底层实现
│   ├── mod.rs
│   ├── shell
│   ├── unix
│   └── windows
├── token.rs       # 用以区分 Poll 得到的 event
└── waker.rs       # 跨线程唤醒 Poll

再来看官方提供的样例:

// You can run this example from the root of the mio repo:
// cargo run --example tcp_server --features="os-poll tcp"
use mio::event::Event;
use mio::net::{TcpListener, TcpStream};
use mio::{Events, Interest, Poll, Registry, Token};
use std::collections::HashMap;
use std::io::{self, Read, Write};
use std::str::from_utf8;

// 使用 Token 用以区分不同的 Socket 连接
const SERVER: Token = Token(0);

const DATA: &[u8] = b"Hello world!\n";

fn main() -> io::Result<()> {
    env_logger::init();

    // 创建 Poll 实例
    let mut poll = Poll::new()?;
    // 创建一个长度为 128 的空事件列表
    let mut events = Events::with_capacity(128);

    // 创建 TcpListener 实例 server,监听 9000 端口
    let addr = "127.0.0.1:9000".parse().unwrap();
    let mut server = TcpListener::bind(addr)?;

    // 将 server 注册到 poll 对象中,监听可读事件
    poll.registry()
        .register(&mut server, SERVER, Interest::READABLE)?;

    // 存储 <Token, TcpStream> 映射
    let mut connections = HashMap::new();
    // 自增的唯一 Token 
    let mut unique_token = Token(SERVER.0 + 1);

    println!("You can connect to the server using `nc`:");
    println!(" $ nc 127.0.0.1 9000");
    println!("You'll see our welcome message and anything you type we'll be printed here.");

    loop {
        // 无限循环,poll 等待事件,None 表示不超时
        poll.poll(&mut events, None)?;

        // 迭代事件
        for event in events.iter() {
            // 根据事件对应的 token 区分事件并做相应处理
            match event.token() {
                // 如果是 SERVER,则说明是 CLIENT 请求连接
                SERVER => loop {
                    // accept 获取连接 TcpStream 对象及其地址
                    let (mut connection, address) = match server.accept() {
                        Ok((connection, address)) => (connection, address),
                        Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                            // `WouldBlock` 表示误报,实际上并没有连接
                            break;
                        }
                        Err(e) => {
                            return Err(e);
                        }
                    };

                    println!("Accepted connection from: {}", address);

                    // 使用唯一的 token 注册
                    let token = next(&mut unique_token);
                    poll.registry().register(
                        &mut connection,
                        token,
                        Interest::READABLE.add(Interest::WRITABLE),
                    )?;

                    connections.insert(token, connection);
                },
                // 如果是其他连接的事件
                token => {
                    // 从映射中获取对应的连接 connection
                    let done = if let Some(connection) = connections.get_mut(&token) {
                        // 调用函数处理读写
                        handle_connection_event(poll.registry(), connection, event)?
                    } else {
                        false
                    };
                    if done {
                        // 及时删除无效连接
                        connections.remove(&token);
                    }
                }
            }
        }
    }
}

fn next(current: &mut Token) -> Token {
    let next = current.0;
    current.0 += 1;
    Token(next)
}

/// 如果连接结束返回 true
fn handle_connection_event(
    registry: &Registry,
    connection: &mut TcpStream,
    event: &Event,
) -> io::Result<bool> {
    if event.is_writable() {
        // 如果是可写事件,则写入 DATA
        match connection.write(DATA) {
            // 这里期望一次写完,如果没有写完会返回错误
            Ok(n) if n < DATA.len() => return Err(io::ErrorKind::WriteZero.into()),
            Ok(_) => {
                // 完整写完后,重新注册该连接,只关注可读事件
                registry.reregister(connection, event.token(), Interest::READABLE)?
            }
            // WouldBlock 表示仍没有准备好,直接跳过
            Err(ref err) if would_block(err) => {}
            // 中断则直接递归再重试一次
            Err(ref err) if interrupted(err) => {
                return handle_connection_event(registry, connection, event)
            }
            // 其他错误直接返回
            Err(err) => return Err(err),
        }
    }

    if event.is_readable() {
        // 如果是可读事件
        let mut connection_closed = false;
        let mut received_data = vec![0; 4096];
        let mut bytes_read = 0;
        // 循环读取
        loop {
            match connection.read(&mut received_data[bytes_read..]) {
                Ok(0) => {
                    // 返回 0 表示连接已关闭
                    connection_closed = true;
                    break;
                }
                Ok(n) => {
                    // 正常读取到 n 字节
                    bytes_read += n;
                    if bytes_read == received_data.len() {
                        received_data.resize(received_data.len() + 1024, 0);
                    }
                }
                // 错误处理与上方一致
                Err(ref err) if would_block(err) => break,
                Err(ref err) if interrupted(err) => continue,
                Err(err) => return Err(err),
            }
        }

      	// 打印读取到的字节流
        if bytes_read != 0 {
            let received_data = &received_data[..bytes_read];
            if let Ok(str_buf) = from_utf8(received_data) {
                println!("Received data: {}", str_buf.trim_end());
            } else {
                println!("Received (none UTF-8) data: {:?}", received_data);
            }
        }

        // 如果连接关闭,返回 true
        if connection_closed {
            println!("Connection closed");
            return Ok(true);
        }
    }

    Ok(false)
}

fn would_block(err: &io::Error) -> bool {
    err.kind() == io::ErrorKind::WouldBlock
}

fn interrupted(err: &io::Error) -> bool {
    err.kind() == io::ErrorKind::Interrupted
}

如果之前接触过 epoll,会发现调用接口和使用模式是类似的。Mio 仅在系统 API 层面上做了一致的封装,依靠 Rust 的零成本抽象并不会带来额外的开销。

2. Epoll 封装

Mio 对外的 Poll 接口是在 poll.rs 中定义的。Poll 对象中包含一个 Registry 对象,Registry 对象中包含一个 sys::Selector

pub struct Poll {
    registry: Registry,
}

/// Registers I/O resources.
pub struct Registry {
    selector: sys::Selector,
}

impl Poll {
    pub fn registry(&self) -> &Registry {
        &self.registry
    }

  	/// 等待事件
    pub fn poll(&mut self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        self.registry.selector.select(events.sys(), timeout)
    }
}

impl Registry {
    /// 注册事件源 event::Source
    pub fn register<S>(&self, source: &mut S, token: Token, interests: Interest) -> io::Result<()>
    where
        S: event::Source + ?Sized,
    {
        trace!(
            "registering event source with poller: token={:?}, interests={:?}",
            token,
            interests
        );
        source.register(self, token, interests)
    }

    /// 注册唤醒器
    #[cfg(debug_assertions)]
    pub(crate) fn register_waker(&self) {
        if self.selector.register_waker() {
            panic!("Only a single `Waker` can be active per `Poll` instance");
        }
    }
}

这里出现了一些新的类,sys::Selector / Events / Token / Interest / event::Source。首先看 token.rs,它的实现非常简单,一个 usize 的结构体封装:

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Token(pub usize);

Token 对象会在 register 时指定,当 event::Source 发生感兴趣的事件时,返回的事件 event 中会包含对应的 token 信息。接着看 interest.rs

#[derive(Copy, PartialEq, Eq, Clone, PartialOrd, Ord)]
pub struct Interest(NonZeroU8);

// 定义可读/可写的值
const READABLE: u8 = 0b0_001;
const WRITABLE: u8 = 0b0_010;

impl Interest {
    /// 创建可读/可写对应的 Interest 常量
    pub const READABLE: Interest = Interest(unsafe { NonZeroU8::new_unchecked(READABLE) });
    pub const WRITABLE: Interest = Interest(unsafe { NonZeroU8::new_unchecked(WRITABLE) });

    // Interest 的加法(按位或)
    #[allow(clippy::should_implement_trait)]
    pub const fn add(self, other: Interest) -> Interest {
        Interest(unsafe { NonZeroU8::new_unchecked(self.0.get() | other.0.get()) })
    }

    // Interest 的减法
    pub fn remove(self, other: Interest) -> Option<Interest> {
        NonZeroU8::new(self.0.get() & !other.0.get()).map(Interest)
    }

    // 判断是否可读/可写
    pub const fn is_readable(self) -> bool {
        (self.0.get() & READABLE) != 0
    }
    pub const fn is_writable(self) -> bool {
        (self.0.get() & WRITABLE) != 0
    }
}

接下来继续看 event.rsEvent 是对内部结构体的透明封装:

use crate::{sys, Token};

// 透明封装,保持一致的内存布局
#[repr(transparent)]
pub struct Event {
    inner: sys::Event,
}

impl Event {
    /// 返回事件对应的 Token
    pub fn token(&self) -> Token {
        sys::event::token(&self.inner)
    }

    /// 在 sys::event 之上的接口封装
    pub fn is_readable(&self) -> bool {
        sys::event::is_readable(&self.inner)
    }
    pub fn is_writable(&self) -> bool {
        sys::event::is_writable(&self.inner)
    }
    pub fn is_error(&self) -> bool {
        sys::event::is_error(&self.inner)
    }

  	...
  
    /// 将 sys::Event 转为 Event 对象
    pub(crate) fn from_sys_event_ref(sys_event: &sys::Event) -> &Event {
        unsafe {
            // 由于内存布局一致,直接强转
            &*(sys_event as *const sys::Event as *const Event)
        }
    }
}

继续看 events.rs 的实现。与 Event 类似,同样是对内部结构体的封装。这里学习一下迭代器的实现,其中 'a 是生命周期标志。

use crate::event::Event;
use crate::sys;

pub struct Events {
    inner: sys::Events,
}

#[derive(Debug, Clone)]
pub struct Iter<'a> {
    inner: &'a Events,
    pos: usize,
}

impl Events {
    /// 创建一个指定容量的事件列表
    pub fn with_capacity(capacity: usize) -> Events {
        Events {
            inner: sys::Events::with_capacity(capacity),
        }
    }

    /// 返回迭代器
    pub fn iter(&self) -> Iter<'_> {
        Iter {
            inner: self,
            pos: 0,
        }
    }

    /// 清理事件列表,可供下次 Poll 使用
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Returns the inner `sys::Events`.
    pub(crate) fn sys(&mut self) -> &mut sys::Events {
        &mut self.inner
    }
}

impl<'a> IntoIterator for &'a Events {
    type Item = &'a Event;
    type IntoIter = Iter<'a>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = &'a Event;

    // 迭代器的下一个,就是事件列表中的下一项
    fn next(&mut self) -> Option<Self::Item> {
        let ret = self
            .inner
            .inner
            .get(self.pos)
            .map(Event::from_sys_event_ref);
        self.pos += 1;
        ret
    }
}

现在可以继续看一下内部究竟是如何实现的。首先看 Linux 系统下对 epoll 的封装 epoll.rs

#[derive(Debug)]
pub struct Selector {
    #[cfg(debug_assertions)]
    id: usize,
    ep: RawFd,
    #[cfg(debug_assertions)]
    has_waker: AtomicBool,
}

impl Selector {
    pub fn new() -> io::Result<Selector> {
        syscall!(epoll_create1(flag)).map(|ep| Selector {
            #[cfg(debug_assertions)]
            id: NEXT_ID.fetch_add(1, Ordering::Relaxed),
            ep,
            #[cfg(debug_assertions)]
            has_waker: AtomicBool::new(false),
        })
    }

    pub fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
        #[cfg(target_pointer_width = "32")]
        const MAX_SAFE_TIMEOUT: u128 = 1789569;
        #[cfg(not(target_pointer_width = "32"))]
        const MAX_SAFE_TIMEOUT: u128 = libc::c_int::max_value() as u128;

        let timeout = timeout
            .map(|to| cmp::min(to.as_millis(), MAX_SAFE_TIMEOUT) as libc::c_int)
            .unwrap_or(-1);

        events.clear();
        syscall!(epoll_wait(
            self.ep,
            events.as_mut_ptr(),
            events.capacity() as i32,
            timeout,
        ))
        .map(|n_events| {
            // This is safe because `epoll_wait` ensures that `n_events` are
            // assigned.
            unsafe { events.set_len(n_events as usize) };
        })
    }

    pub fn register(&self, fd: RawFd, token: Token, interests: Interest) -> io::Result<()> {
        // 注册时将 Interest 和 Token 转为 epoll_event 对象
        let mut event = libc::epoll_event {
            events: interests_to_epoll(interests),
            u64: usize::from(token) as u64,
        };

        syscall!(epoll_ctl(self.ep, libc::EPOLL_CTL_ADD, fd, &mut event)).map(|_| ())
    }
}

这里的 Selector 就是对 epoll 调用的封装。接着看事件的实现,先前提到的 sys.Event 直接使用了 libc::epoll_eventsys.Events 则是 Vec<sys.Event>

fn interests_to_epoll(interests: Interest) -> u32 {
    let mut kind = EPOLLET;

    if interests.is_readable() {
        kind = kind | EPOLLIN | EPOLLRDHUP;
    }

    if interests.is_writable() {
        kind |= EPOLLOUT;
    }

    kind as u32
}

pub type Event = libc::epoll_event;
pub type Events = Vec<Event>;

pub mod event {
    use std::fmt;

    use crate::sys::Event;
    use crate::Token;

    pub fn token(event: &Event) -> Token {
        Token(event.u64 as usize)
    }

    pub fn is_readable(event: &Event) -> bool {
        (event.events as libc::c_int & libc::EPOLLIN) != 0
            || (event.events as libc::c_int & libc::EPOLLPRI) != 0
    }

    pub fn is_writable(event: &Event) -> bool {
        (event.events as libc::c_int & libc::EPOLLOUT) != 0
    }

    pub fn is_error(event: &Event) -> bool {
        (event.events as libc::c_int & libc::EPOLLERR) != 0
    }
}

3. Socket 封装

Poll 的逻辑看完了,接着来看 Socket 相关的封装,这里主要看 TCP 的实现。src/net/tcp 下方按照功能划分了三个类 TcpSocket/ TcpListener/ TcpStream,首先看 socket.rs

/// 非阻塞 TCP 连接,用以创建 `TcpListener` 或者 `TcpStream`
#[derive(Debug)]
pub struct TcpSocket {
    sys: sys::tcp::TcpSocket,
}

impl TcpSocket {
    /// Create a new IPv4 TCP socket.
    pub fn new_v4() -> io::Result<TcpSocket> {
        sys::tcp::new_v4_socket().map(|sys| TcpSocket {
            sys
        })
    }

    /// Connect the socket to `addr`.
    pub fn connect(self, addr: SocketAddr) -> io::Result<TcpStream> {
        let stream = sys::tcp::connect(self.sys, addr)?;

        // Don't close the socket
        mem::forget(self);
        Ok(TcpStream::from_std(stream))
    }

    /// Listen for inbound connections, converting the socket to a
    /// `TcpListener`.
    pub fn listen(self, backlog: u32) -> io::Result<TcpListener> {
        let listener = sys::tcp::listen(self.sys, backlog)?;

        // Don't close the socket
        mem::forget(self);
        Ok(TcpListener::from_std(listener))
    }
}

impl Drop for TcpSocket {
    fn drop(&mut self) {
        sys::tcp::close(self.sys);
    }
}

这里需要注意 mem::forget 的使用,connectlisten 的内部实现中会将 self.sys 的所有权转移到 TcpStreamTpcListener 上,以保证在 TpcSocket 析构时不会关闭对应的连接。接着看 TcpListenerTcpStream 的实现:

// 对标准库中的 TcpListener 的封装
pub struct TcpListener {
    inner: IoSource<net::TcpListener>,
}

impl TcpListener {
    /// 绑定指定地址,返回 TcpListener 对象
    pub fn bind(addr: SocketAddr) -> io::Result<TcpListener> {
        let socket = TcpSocket::new_for_addr(addr)?;
        #[cfg(not(windows))]
        socket.set_reuseaddr(true)?;
        socket.bind(addr)?;
        socket.listen(1024)
    }

    /// 将 net::TcpListener 封装为 TcpListener
    pub fn from_std(listener: net::TcpListener) -> TcpListener {
        TcpListener {
            inner: IoSource::new(listener),
        }
    }

    // 接收来自 Client 的连接,返回 TcpStream 对象
    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        self.inner.do_io(|inner| {
            sys::tcp::accept(inner).map(|(stream, addr)| (TcpStream::from_std(stream), addr))
        })
    }
}

// 实现 event::Source 接口
impl event::Source for TcpListener {
    fn register(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        self.inner.register(registry, token, interests)
    }

    fn reregister(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        self.inner.reregister(registry, token, interests)
    }

    fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
        self.inner.deregister(registry)
    }
}

#[cfg(unix)]
impl FromRawFd for TcpListener {
    // 从 fd 转为 net::TcpListener 再转为 TcpListener
    unsafe fn from_raw_fd(fd: RawFd) -> TcpListener {
        TcpListener::from_std(FromRawFd::from_raw_fd(fd))
    }
}


/// A non-blocking TCP stream between a local socket and a remote socket.
pub struct TcpStream {
    inner: IoSource<net::TcpStream>,
}

impl TcpStream {
    /// 连接指定地址,返回 TcpStream 对象
    pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
        let socket = TcpSocket::new_for_addr(addr)?;
        socket.connect(addr)
    }

  	/// 将标准库中的 net::TcpStream 封装为 TcpStream
  	pub fn from_std(stream: net::TcpStream) -> TcpStream {
        TcpStream {
            inner: IoSource::new(stream),
        }
    }
}

/// 实现读取接口
impl Read for TcpStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.do_io(|inner| (&*inner).read(buf))
    }

    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
        self.inner.do_io(|inner| (&*inner).read_vectored(bufs))
    }
}

...

IoSourceevent::Source 的定义位于 io_source.rs

pub struct IoSource<T> {
    state: IoSourceState,
    inner: T,
    #[cfg(debug_assertions)]
    selector_id: SelectorId,
}

impl<T> IoSource<T> {
    /// Create a new `IoSource`.
    pub fn new(io: T) -> IoSource<T> {
        IoSource {
            state: IoSourceState::new(),
            inner: io,
            #[cfg(debug_assertions)]
            selector_id: SelectorId::new(),
        }
    }

    /// Windows 下需要多做一次转换,其他系统等价 f(&self.inner)
    pub fn do_io<F, R>(&self, f: F) -> io::Result<R>
    where
        F: FnOnce(&T) -> io::Result<R>,
    {
        self.state.do_io(f, &self.inner)
    }

    /// Returns the I/O source, dropping the state.
    pub fn into_inner(self) -> T {
        self.inner
    }
}

#[cfg(unix)]
impl<T> event::Source for IoSource<T>
where
    T: AsRawFd,
{
    fn register(
        &mut self,
        registry: &Registry,
        token: Token,
        interests: Interest,
    ) -> io::Result<()> {
        #[cfg(debug_assertions)]
        self.selector_id.associate(registry)?;
        // 通过 RawFd 注册
        poll::selector(registry).register(self.inner.as_raw_fd(), token, interests)
    }
    ...
}

Linux TCP 网络的具体实现放在 src/sys/unixtcp.rs,可以看到是对 Socket API 的简单封装:

pub type TcpSocket = libc::c_int;

pub(crate) fn bind(socket: TcpSocket, addr: SocketAddr) -> io::Result<()> {
    let (raw_addr, raw_addr_length) = socket_addr(&addr);
    syscall!(bind(socket, raw_addr, raw_addr_length))?;
    Ok(())
}

pub(crate) fn connect(socket: TcpSocket, addr: SocketAddr) -> io::Result<net::TcpStream> {
    let (raw_addr, raw_addr_length) = socket_addr(&addr);

    match syscall!(connect(socket, raw_addr, raw_addr_length)) {
        Err(err) if err.raw_os_error() != Some(libc::EINPROGRESS) => {
            Err(err)
        }
        _ => {
            Ok(unsafe { net::TcpStream::from_raw_fd(socket) })
        }
    }
}

pub(crate) fn listen(socket: TcpSocket, backlog: u32) -> io::Result<net::TcpListener> {
    use std::convert::TryInto;

    let backlog = backlog.try_into().unwrap_or(i32::max_value());
    syscall!(listen(socket, backlog))?;
    Ok(unsafe { net::TcpListener::from_raw_fd(socket) })
}

pub fn accept(listener: &net::TcpListener) -> io::Result<(net::TcpStream, SocketAddr)> {
    let mut addr: MaybeUninit<libc::sockaddr_storage> = MaybeUninit::uninit();
    let mut length = size_of::<libc::sockaddr_storage>() as libc::socklen_t;

    let stream = {
        syscall!(accept4(
            listener.as_raw_fd(),
            addr.as_mut_ptr() as *mut _,
            &mut length,
            libc::SOCK_CLOEXEC | libc::SOCK_NONBLOCK,
        ))
        .map(|socket| unsafe { net::TcpStream::from_raw_fd(socket) })
    }?;

    // This is safe because `accept` calls above ensures the address
    // initialised.
    unsafe { to_socket_addr(addr.as_ptr()) }.map(|addr| (stream, addr))
}

4. 总结

最近工作比较忙,很久没看代码、写博客,以至于朋友都来吐槽我不“勤劳”。Mio 的封装实现还是比较简单的,比较适合笔者这样的 Rust 初学者来学习。希望接下来把 Tokio 分模块学习完,过年的时候用 C++ 20 或者 Rust 实现一套 1:N 的协程运行时。