[C++] 基于 C++20 协程编写 gRPC 客户端与服务端

基于 C++20 协程编写 gRPC 客户端与服务端

完整代码

gRPC 的 C++ 异步接口十分不友好,尤其是对于需要支持并发的服务端而言:官方示例需要手写一个小型状态机来处理请求。能否将 gRPC 与 C++20 协程结合,写出简单易懂的代码呢?

实现细节请阅读博客,在此仅展示最终代码。

协议:

syntax = "proto3";

package sample;

// Demo service that exercises each of the four gRPC call shapes once.
service SampleService {
  // Unary: one request in, one response out.
  rpc Echo(EchoRequest) returns (EchoResponse);

  // Server streaming: the server pushes a sequence of numbers to the client.
  rpc GetNumbers(GetNumbersRequest) returns (stream Number);

  // Client streaming: the client sends a sequence of numbers and receives one aggregate result.
  rpc Sum(stream Number) returns (SumResponse);

  // Bidirectional streaming: both sides send and receive messages independently.
  rpc Chat(stream ChatMessage) returns (stream ChatMessage);
}

// Payload for Echo.
message EchoRequest {
  string message = 1;
}

// Echo result; the sample server fills `timestamp` from std::time(nullptr).
message EchoResponse {
  string message = 1;
  int64 timestamp = 2;
}

// Asks the server for `count` numbers starting at `value`.
message GetNumbersRequest {
  int32 value = 1;
  int32 count = 2;
}

// A single streamed integer.
message Number {
  int32 value = 1;
}

// Aggregate returned by Sum: the total of the values received and how many there were.
message SumResponse {
  int32 total = 1;
  int32 count = 2;
}

// One chat message; the sample server fills `timestamp` from std::time(nullptr).
message ChatMessage {
  string user = 1;
  string content = 2;
  int64 timestamp = 3;
}

客户端:

/// Coroutine-based client for sample::SampleService.
///
/// Every RPC method takes an optional grpc::ClientContext; by default a fresh context is created
/// for each call. Each method forwards to the inherited `call` helper, together with the matching
/// method of the stub's callback (`Stub::async`) API.
class Client final : public GenericClient<sample::SampleService> {
    // Short name for the stub's callback-API method table used by every RPC below.
    using AsyncStub = sample::SampleService::Stub::async;

public:
    using GenericClient::GenericClient;

    /// Creates a client connected to `address` over an insecure (plaintext) channel.
    static Client make(const std::string &address) {
        auto channel = grpc::CreateChannel(address, grpc::InsecureChannelCredentials());
        return Client{sample::SampleService::NewStub(std::move(channel))};
    }

    /// Unary Echo: sends `request` and resolves to the server's EchoResponse.
    asyncio::task::Task<sample::EchoResponse>
    echo(
        sample::EchoRequest request,
        std::unique_ptr<grpc::ClientContext> context = std::make_unique<grpc::ClientContext>()
    ) {
        co_return co_await call(&AsyncStub::Echo, std::move(context), std::move(request));
    }

    /// Server-streaming GetNumbers: every Number the server streams back is forwarded into `sender`.
    asyncio::task::Task<void>
    getNumbers(
        sample::GetNumbersRequest request,
        asyncio::Sender<sample::Number> sender,
        std::unique_ptr<grpc::ClientContext> context = std::make_unique<grpc::ClientContext>()
    ) {
        co_await call(&AsyncStub::GetNumbers, std::move(context), std::move(request), std::move(sender));
    }

    /// Client-streaming Sum: Numbers read from `receiver` become the request stream, and the call
    /// resolves to the server's single SumResponse.
    asyncio::task::Task<sample::SumResponse> sum(
        asyncio::Receiver<sample::Number> receiver,
        std::unique_ptr<grpc::ClientContext> context = std::make_unique<grpc::ClientContext>()
    ) {
        co_return co_await call(&AsyncStub::Sum, std::move(context), std::move(receiver));
    }

    /// Bidirectional Chat: messages read from `receiver` are sent to the server, and replies from the
    /// server are forwarded into `sender`.
    asyncio::task::Task<void>
    chat(
        asyncio::Receiver<sample::ChatMessage> receiver,
        asyncio::Sender<sample::ChatMessage> sender,
        std::unique_ptr<grpc::ClientContext> context = std::make_unique<grpc::ClientContext>()
    ) {
        co_await call(&AsyncStub::Chat, std::move(context), std::move(receiver), std::move(sender));
    }
};

/// Client entry point: runs the unary, streaming, and bidirectional demos concurrently against a
/// server on localhost:50051 and prints each result. `argc`/`argv` are not used.
asyncio::task::Task<void> asyncMain(const int argc, char *argv[]) {
    // One client (and therefore one channel) is shared by all three demos.
    auto client = Client::make("localhost:50051");

    co_await all(
        // Unary RPC: send one message and print what the server echoes back.
        asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
            sample::EchoRequest request;
            request.set_message("Hello gRPC!");

            const auto response = co_await client.echo(request);
            fmt::print("Echo: {}\n", response.message());
        }),

        // Server streaming chained into client streaming: getNumbers writes every streamed Number
        // into the channel, and sum uses the other end of that same channel as its request stream.
        asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
            auto [numberSender, numberReceiver] = asyncio::channel<sample::Number>();

            sample::GetNumbersRequest request;
            request.set_value(1);
            request.set_count(5);

            const auto results = co_await all(
                client.getNumbers(request, std::move(numberSender)),
                client.sum(std::move(numberReceiver))
            );

            const auto &summary = std::get<sample::SumResponse>(results);
            fmt::print("Sum: {}, count: {}\n", summary.total(), summary.count());
        }),

        // Bidirectional streaming: one channel carries server replies back to us, and the other
        // carries our outgoing messages to the server.
        asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
            auto [replySender, replyReceiver] = asyncio::channel<sample::ChatMessage>();
            auto [outgoingSender, outgoingReceiver] = asyncio::channel<sample::ChatMessage>();

            co_await all(
                client.chat(std::move(outgoingReceiver), std::move(replySender)),
                // Writer: send one message, then close the channel (presumably this tells chat
                // that there is nothing more to send).
                asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
                    sample::ChatMessage outgoing;
                    outgoing.set_content("Hello server!");
                    co_await asyncio::error::guard(outgoingSender.send(std::move(outgoing)));
                    outgoingSender.close();
                }),
                // Reader: wait for the first reply and print it.
                asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
                    const auto reply = co_await asyncio::error::guard(replyReceiver.receive());
                    fmt::print("Chat reply: {}\n", reply.content());
                })
            );
        })
    );
}

服务端:

/// Coroutine-based server for sample::SampleService.
///
/// Each RPC is a static coroutine handler. dispatch() binds each handler to its AsyncService
/// `Request*` method through the inherited `handle` helper.
class Server final : public GenericServer<sample::SampleService> {
public:
    using GenericServer::GenericServer;

    /// Builds and starts a server listening on `address` with insecure (plaintext) credentials.
    ///
    /// @param address listen address, e.g. "0.0.0.0:50051".
    /// @throws std::runtime_error if grpc::ServerBuilder::BuildAndStart() fails and returns nullptr
    ///         (for example, when the address cannot be bound).
    static Server make(const std::string &address) {
        auto service = std::make_unique<sample::SampleService::AsyncService>();

        grpc::ServerBuilder builder;

        builder.AddListeningPort(address, grpc::InsecureServerCredentials());
        builder.RegisterService(service.get());

        auto completionQueue = builder.AddCompletionQueue();
        auto server = builder.BuildAndStart();

        // BuildAndStart() signals failure by returning nullptr. Without this check, a null
        // grpc::Server would be handed to GenericServer and only fail later, far from the cause.
        if (!server)
            throw std::runtime_error{"failed to build and start gRPC server on " + address};

        return {std::move(server), std::move(service), std::move(completionQueue)};
    }

private:
    /// Adds two int32 values with two's-complement wrap-around instead of undefined behaviour.
    ///
    /// The operands in getNumbers/sum come straight from client messages, so a plain signed `+` could
    /// overflow. The addition is done in uint32, where wrapping is well-defined, and the result is
    /// converted back (a modular conversion since C++20).
    static std::int32_t wrappingAdd(const std::int32_t lhs, const std::int32_t rhs) {
        return static_cast<std::int32_t>(static_cast<std::uint32_t>(lhs) + static_cast<std::uint32_t>(rhs));
    }

    // Unary: returns the response directly. According to the original design notes, errors raised
    // here are converted into a gRPC error status by GenericServer (that logic is not visible here).
    static asyncio::task::Task<sample::EchoResponse> echo(sample::EchoRequest request) {
        sample::EchoResponse response;
        // `request` is owned by this coroutine, so its string payload can be moved rather than copied.
        response.set_message(std::move(*request.mutable_message()));
        response.set_timestamp(std::time(nullptr));
        co_return response;
    }

    // Server streaming: writes `count` numbers starting at `value`, one per Writer::write call.
    // A count <= 0 yields an empty stream. Values wrap around instead of overflowing near INT32_MAX.
    static asyncio::task::Task<void>
    getNumbers(sample::GetNumbersRequest request, Writer<sample::Number> writer) {
        for (std::int32_t i = 0; i < request.count(); ++i) {
            sample::Number number;
            number.set_value(wrappingAdd(request.value(), i));
            co_await writer.write(number);
        }
    }

    // Client streaming: reads until Reader::read yields an empty result, then returns the running
    // total and the message count. Both counters wrap instead of overflowing on hostile input.
    static asyncio::task::Task<sample::SumResponse> sum(Reader<sample::Number> reader) {
        std::int32_t total{0};
        std::int32_t count{0};
        while (const auto number = co_await reader.read()) {
            total = wrappingAdd(total, number->value());
            count = wrappingAdd(count, 1);
        }
        sample::SumResponse response;
        response.set_total(total);
        response.set_count(count);
        co_return response;
    }

    // Bidirectional streaming: for each message read, write back one "Echo: <content>" reply
    // attributed to "Server".
    static asyncio::task::Task<void> chat(Stream<sample::ChatMessage, sample::ChatMessage> stream) {
        while (const auto message = co_await stream.read()) {
            sample::ChatMessage response;
            response.set_user("Server");
            response.set_timestamp(std::time(nullptr));
            response.set_content(fmt::format("Echo: {}", message->content()));
            co_await stream.write(response);
        }
    }

    // Binds each AsyncService request method to its handler and runs all RPC accept loops concurrently.
    asyncio::task::Task<void> dispatch() override {
        co_await all(
            handle(&sample::SampleService::AsyncService::RequestEcho, echo),
            handle(&sample::SampleService::AsyncService::RequestGetNumbers, getNumbers),
            handle(&sample::SampleService::AsyncService::RequestSum, sum),
            handle(&sample::SampleService::AsyncService::RequestChat, chat)
        );
    }
};

/// Server entry point: serves on 0.0.0.0:50051 until SIGINT arrives, then shuts the server down.
///
/// Two spawned tasks race: (1) the server, and (2) a wait for SIGINT. When SIGINT wins, `race` is
/// assumed to cancel the server task. Its Cancellable wrapper then sets `event`, which wakes the
/// helper task that calls server.shutdown(), so that server.run() can return.
/// NOTE(review): this relies on asyncio's race/Cancellable semantics, which are not visible here.
/// `argc`/`argv` are not used.
asyncio::task::Task<void> asyncMain(const int argc, char *argv[]) {
    auto server = Server::make("0.0.0.0:50051");
    auto signal = asyncio::Signal::make();

    co_await race(
        asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
            // Set by the cancellation callback below; awaited by the shutdown helper task.
            asyncio::sync::Event event;

            co_await asyncio::task::Cancellable{
                all(
                    server.run(),
                    asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
                        co_await asyncio::error::guard(event.wait());
                        co_await server.shutdown(); // ask the gRPC server to shut down
                    })
                ),
                // Cancellation hook (presumably invoked when the wrapped task is cancelled; confirm
                // in asyncio). It requests a shutdown instead of tearing the task down directly.
                [&]() -> std::expected<void, std::error_code> {
                    event.set(); // kick off the shutdown sequence
                    return {};
                }
            };
        }),
        asyncio::task::spawn([&]() -> asyncio::task::Task<void> {
            // Completes once SIGINT has been received.
            co_await asyncio::error::guard(signal.on(SIGINT));
        })
    );
}

正常运行:

正常运行截图

连接失败:

连接失败截图

错误信息友好,包含原始的 gRPC 错误信息,以及协程调用栈。