Rust RPC Async 接口探索

2024.06.17
SF-Zhou

最近在做 Rust RDMA 的通信库,RDMA 建立通信时需要两端交换信息,常规的方式是依赖 Socket 或者 RMDA CM。我打算用前者,另外因为 RDMA 通信库最终暴露给用户的接口类似于 RPC 框架,所以打算自己实现一套 Rust RPC ,同时支持 TCP 和 RDMA。在设计 RPC 接口时遇到了一些问题,这里整理记录下来。

1. 背景

团队原先项目使用的技术栈是 C++ 和 Folly Coroutines,我之前在上面搭建了一套 RPC 框架,有如下特性:

  1. 基于 C++ 反射的二进制序列化
  2. 基于宏的 RPC 接口定义,接口协程化
  3. 单个 Server 支持多个 Service

对于新的这套 RPC 框架,我仍然希望保留这些特性,使用上大概的设想是:

// 1. 定义 service,依赖过程宏生成辅助代码
#[rpc::service]
pub trait Demo {
  async fn echo(&self, req: String) -> Result<String>;
  async fn plus_one(&self, req: u32) -> Result<u32>;
}

// 2. 实现 service
struct DemoImpl { .. }
impl DemoService for DemoImpl {
  async fn echo(&self, _ctx: Context, req: String) -> Result<String> {
    Ok(req)
  }

  async fn plus_one(&self, _ctx: Context, req: u32) -> Result<u32> {
    Ok(req + 1)
  }
}

// 3. server 加入 service
let server = Server::new();
server.add_service(DemoImpl::new());

// 4. client 访问
let client_ctx = ClientContext::new();
let client = DemoClient::new(client_ctx);
let result = client.echo("hello".to_string());

async fn 在 2023 年底已经可以在 trait 中直接使用了,参考文献 1,想象中一切都是很美好的。

2. 探索

定义 Service 接口时,我们需要依赖过程宏生成 server 端 dispatch 的代码,即根据请求数据判断具体调用哪个接口。假设我们的请求数据是一段字节流 bytes,该段字节流存储了接口名的字符串以及请求 req 的序列化结果,我们需要生成的代码应该类似于:

use derse::{DownwardBytes, Serialization};
use std::io::Result;

#[allow(async_fn_in_trait)]
pub trait Demo {
    async fn echo(&self, req: String) -> Result<String>;
    async fn plus_one(&self, req: u32) -> Result<u32>;

    // generated by proc macro.
    async fn invoke(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
        let mut buf = bytes.as_ref();
        let name = String::deserialize_from(&mut buf).unwrap();
        match name.as_str() {
            "echo" => {
                let req = Serialization::deserialize_from(&mut buf).unwrap();
                let rsp = self.echo(req).await?;
                let out = rsp.serialize::<DownwardBytes>().unwrap();
                Ok(Vec::from(out.as_slice()))
            }
            "plus_one" => {
                let req = Serialization::deserialize_from(&mut buf).unwrap();
                let rsp = self.plus_one(req).await?;
                let out = rsp.serialize::<DownwardBytes>().unwrap();
                Ok(Vec::from(out.as_slice()))
            }
            _ => panic!("method name is invalid"),
        }
    }
}

struct DemoImpl;
impl Demo for DemoImpl {
    async fn echo(&self, req: String) -> Result<String> {
        Ok(req)
    }

    async fn plus_one(&self, req: u32) -> Result<u32> {
        Ok(req + 1)
    }
}

#[tokio::main]
async fn main() {
    // 1. invoke directly.
    let service = DemoImpl;
    let result = service.echo("hello".to_string()).await;
    assert_eq!(result.unwrap(), "hello");

    // 2. invoke by bytes.
    let mut bytes = "hello".serialize::<DownwardBytes>().unwrap();
    "echo".serialize_to(&mut bytes).unwrap();
    let result = service.invoke(Vec::from(bytes.as_slice())).await.unwrap();
    let string = <&str>::deserialize(result.as_slice()).unwrap();
    assert_eq!(string, "hello");
}

derse 是我自己做的一套简单的二进制序列化工具。上面的错误处理比较粗犷,看个意思就行。核心内容是生成一个统一的接口函数 async fn invoke(&self, req: Vec<u8>) -> Result<Vec<u8>>。接下来只需要将 service impl 对象加入 server 中,server 收到新的二进制消息时调用 impl.invoke(bytes) 方法就可以。

如何存储该对象呢?直觉的做法是定义一个 Service 的 trait,包含 invoke 方法,然后使用过程宏为接口类 impl Service trait,最后保存 Box<dyn Service> 对象。大概如下:

// 1. RPC 框架内定义 Service trait
#[allow(async_fn_in_trait)]
pub trait Service {
    async fn invoke(&self, bytes: Vec<u8>) -> Result<Vec<u8>>;
}

// 2. 过程宏生成桥接代码
impl<T: Demo> Service for T {
    async fn invoke(&self, bytes: Vec<u8>) -> Result<Vec<u8>> {
        Demo::invoke(self, bytes).await
    }
}

// 3. 保存 dyn Service 对象
let service: Box<dyn Service> = Box::new(DemoImpl);

但实际上这样做有两个问题:

  1. impl Service for T 违反了孤儿原则,参考文献 2。这个倒是可以使用文中提到的 NewType Pattern 来解决;
  2. 无法转为 dyn Service 对象,会提示 "the trait Service cannot be made into an object consider moving invoke to another trait"

翻阅文档可以得知,目前包含 async fn 或者返回 impl Trait 的 trait 无法转为 trait object。该特性称之为 Dyn async trait,目前官方还在做。原因是什么呢?本质上 async fn 是一个语法糖:

async fn foo() -> i32 {
    42
}

// 去糖后类似于如下代码
use std::pin::Pin;
use std::task::{Context, Poll};

// 定义一个匿名类型来表示状态机
struct FooFuture {
    state: i32,
}

impl std::future::Future for FooFuture {
    type Output = i32;

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        // 状态机的具体实现
        Poll::Ready(42)
    }
}

// 返回一个 Future
fn foo() -> impl std::future::Future<Output = i32> {
    FooFuture { state: 0 }
}
  1. FooFuture 结构体是由编译器生成的状态机,管理异步操作的状态和调度。
  2. impl Future for FooFuture 为状态机实现了 Future 特征。
  3. foo 函数返回一个 impl Future<Output = i32>,即 FooFuture 实例。

因为返回的对象类型是不确定的,所以 dyn Service 无法确定返回的类型,也就无法直接转为 dyn Service 对象了。但我们可以在定义 Service trait 时就进行返回类型的去糖:

use derse::{DownwardBytes, Serialization};
use std::{io::Result, pin::Pin, sync::Arc};

#[allow(async_fn_in_trait)]
pub trait Demo {
    async fn echo(self: Arc<Self>, req: String) -> Result<String>;
    async fn plus_one(self: Arc<Self>, req: u32) -> Result<u32>;

    // generated by proc macro.
    async fn invoke(self: Arc<Self>, bytes: Vec<u8>) -> Result<Vec<u8>> {
        let mut buf = bytes.as_ref();
        let name = String::deserialize_from(&mut buf).unwrap();
        match name.as_str() {
            "echo" => {
                let req = Serialization::deserialize_from(&mut buf).unwrap();
                let rsp = self.echo(req).await?;
                let out = rsp.serialize::<DownwardBytes>().unwrap();
                Ok(Vec::from(out.as_slice()))
            }
            "plus_one" => {
                let req = Serialization::deserialize_from(&mut buf).unwrap();
                let rsp = self.plus_one(req).await?;
                let out = rsp.serialize::<DownwardBytes>().unwrap();
                Ok(Vec::from(out.as_slice()))
            }
            _ => panic!("parse error"),
        }
    }
}

#[allow(async_fn_in_trait)]
pub trait Service {
    fn invoke(
        self: Arc<Self>,
        bytes: Vec<u8>,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>>>>;
}

impl<T: Demo + 'static> Service for T {
    fn invoke(
        self: Arc<Self>,
        bytes: Vec<u8>,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>>>> {
        Box::pin(Demo::invoke(self, bytes))
    }
}

struct DemoImpl;
impl Demo for DemoImpl {
    async fn echo(self: Arc<Self>, req: String) -> Result<String> {
        Ok(req)
    }

    async fn plus_one(self: Arc<Self>, req: u32) -> Result<u32> {
        Ok(req + 1)
    }
}

#[tokio::main]
async fn main() {
    let service: Arc<dyn Service> = Arc::new(DemoImpl);
    let mut bytes = "hello".serialize::<DownwardBytes>().unwrap();
    "echo".serialize_to(&mut bytes).unwrap();
    let result = service
        .clone()
        .invoke(Vec::from(bytes.as_slice()))
        .await
        .unwrap();
    let string = <&str>::deserialize(result.as_slice()).unwrap();
    assert_eq!(string, "hello");

    // let result = tokio::spawn(service.clone().invoke(Vec::from(bytes.as_slice())));
}

将返回类型定义为 Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>>>>,上述代码就可以跑起来了,一切看起来都美好起来了,是吗?

当我们尝试 tokio::spawn(service.clone().invoke(Vec::from(bytes.as_slice()))) 时,会提示:"dyn Future<Output = Result<Vec<u8>, std::io::Error>> cannot be sent between threads safely, the trait Send is not implemented for dyn Future<Output = Result<Vec<u8>, std::io::Error>>"。所以我们尝试给 Service::invoke 的返回值加上 Send 约束:

#[allow(async_fn_in_trait)]
pub trait Service {
    fn invoke(
        self: Arc<Self>,
        bytes: Vec<u8>,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>;
}

impl<T: Demo + 'static> Service for T {
    fn invoke(
        self: Arc<Self>,
        bytes: Vec<u8>,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>> {
        Box::pin(Demo::invoke(self, bytes))
    }
}

// `impl Future<Output = Result<Vec<u8>, std::io::Error>>` cannot be sent between threads safely.
// the trait `Send` is not implemented for `impl Future<Output = Result<Vec<u8>, std::io::Error>>` required for the cast from `Pin<Box<impl Future<Output = Result<Vec<u8>, std::io::Error>>>>` to `Pin<Box<(dyn Future<Output = Result<Vec<u8>, std::io::Error>> + Send + 'static)>>`

提示 Demo::invoke 返回的 Future 没有实现 Send。async fn 是否 Send 取决于实现它的实现,实现 invoke 函数时还无法知道 Demo::echoDemo::plus_one 是否是 Send 的,参考文献 3。所以我们还需要在定义 async fn 接口时给它加上 Send 约束,另外使用闭包进行类型擦除:

use derse::{DownwardBytes, Serialization};
use std::{io::Result, pin::Pin, sync::Arc};

#[allow(async_fn_in_trait)]
pub trait Demo {
    fn echo(
        self: Arc<Self>,
        req: String,
    ) -> impl std::future::Future<Output = Result<String>> + Send;
    fn plus_one(self: Arc<Self>, req: u32)
        -> impl std::future::Future<Output = Result<u32>> + Send;

    // generated by proc macro.
    fn export(
        self: Arc<Self>,
    ) -> impl Fn(Vec<u8>) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<u8>>> + Send>>
    where
        Self: Send + Sync + 'static,
    {
        move |bytes| {
            let clone = self.clone();
            Box::pin(async move {
                let mut buf = bytes.as_ref();
                let name = String::deserialize_from(&mut buf).unwrap();
                match name.as_str() {
                    "echo" => {
                        let req = Serialization::deserialize_from(&mut buf).unwrap();
                        let rsp = clone.echo(req).await?;
                        let out = rsp.serialize::<DownwardBytes>().unwrap();
                        Ok(Vec::from(out.as_slice()))
                    }
                    "plus_one" => {
                        let req = Serialization::deserialize_from(&mut buf).unwrap();
                        let rsp = clone.plus_one(req).await?;
                        let out = rsp.serialize::<DownwardBytes>().unwrap();
                        Ok(Vec::from(out.as_slice()))
                    }
                    _ => panic!("parse error"),
                }
            })
        }
    }
}

struct DemoImpl;
impl Demo for DemoImpl {
    async fn echo(self: Arc<Self>, req: String) -> Result<String> {
        Ok(req)
    }

    async fn plus_one(self: Arc<Self>, req: u32) -> Result<u32> {
        Ok(req + 1)
    }
}

#[tokio::main]
async fn main() {
    let service = Arc::new(DemoImpl).export();

    let mut bytes = "hello".serialize::<DownwardBytes>().unwrap();
    "echo".serialize_to(&mut bytes).unwrap();
    let result = tokio::spawn(service(Vec::from(bytes.as_slice())))
        .await
        .unwrap()
        .unwrap();
    let string = <&str>::deserialize(result.as_slice()).unwrap();
    assert_eq!(string, "hello");

    let mut bytes = 233i32.serialize::<DownwardBytes>().unwrap();
    "plus_one".serialize_to(&mut bytes).unwrap();
    let result = tokio::spawn(service(Vec::from(bytes.as_slice())))
        .await
        .unwrap()
        .unwrap();
    let string = i32::deserialize(result.as_slice()).unwrap();
    assert_eq!(string, 234);
}

这样就基本实现了我们需要的特性了,service 也被抽象为统一的函数闭包类型,可以方便地进行存储,我称之为当前的版本答案:

pub type Service = Box<dyn Fn(Vec<u8>) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>>;

References

  1. Announcing async fn and return-position impl Trait in traits
  2. Coherence and Orphan Rules in Rust
  3. How to force an async function return type to be Send?