使用bRPC和ONNX Runtime把Bert模型服务化

2022-11-20 00:15:58 浏览数 (2)

在上文《实践演练Pytorch Bert模型转ONNX模型及预测》中,我们将Bert的Pytorch模型转换成ONNX模型,并使用onnxruntime-gpu完成了python版的ONNX模型预测。今天我们来把预测搬到C 上,模拟一下模型的部署。

对于C 版本模型预测服务,只需要按部就班完成如下三步即可:

  1. 完成C 版本的中文切词,并向量化的过程
  2. 使用ONNX Runtime的C API,完成模型的推理预测过程
  3. 编写一个简单的bRPC服务,把前面两步集成进去即可。

C 中文文本向量化

FullTokenizer

所谓“他山之石,可以攻玉”。把中文文本向量化,这一步基本是琐细但是模式化的操作,网上肯定有现成的代码,比如这个Github gist:https://gist.github.com/luistung/ace4888cf5fd1bad07844021cb2c7ecf

我们不必把过多精力放到这个实现的源码上,这多少有点本末倒置。我们主要关注它如何使用就可以,看一下main函数:

代码语言:c 复制
int main() {
    auto tokenizer = FullTokenizer("data/chinese_L-12_H-768_A-12/vocab.txt");
    std::string line;
    while (std::getline(std::cin, line)) {
        auto tokens = tokenizer.tokenize(line);
        auto ids = tokenizer.convertTokensToIds(tokens);
        std::cout << "#" << convertFromUnicode(boost::join(tokens, L" ")) << "#" << "t";
        for (size_t i = 0; i < ids.size(); i  ) {
            if (i!=0) std::cout << " ";
            std::cout << ids[i];
        }
        std::cout << std::endl;
    }
    return 0;
}

这是一个可以循环从标准输入获取文本,然后转换成向量输出的程序。我们其实只需要关注如下API即可。

构造函数

代码语言:c 复制
    auto tokenizer = FullTokenizer("data/chinese_L-12_H-768_A-12/vocab.txt");

需要传入一个词汇文件(vocab)作为输入。我们那个Bert项目,也是有自己的词汇文件的。

将文本切词成token

代码语言:c 复制
    auto tokens = tokenizer.tokenize(line);

把token向量化

代码语言:c 复制
    auto ids = tokenizer.convertTokensToIds(tokens);

安装依赖

这是一个完整的可运行的代码片段,编译并运行它需要两个依赖:boost和utf8proc

boost

boost对于C 程序员来说应该都不陌生,这里有个建议,安装boost的时候,不要下载Github上的boost项目的release包,因为里面缺乏submodule。直接去boost的官网 https://www.boost.org/ 下载。

编译方法:

代码语言:shell复制
./bootstrap.sh # 生成可执行文件b2
./b2 headers # 生成一个boost目录,可以复制到其他地方,都是header-only的库

utf8proc

这是一个处理UTF-8字符的C语言库,在Github上:https://github.com/JuliaStrings/utf8proc

编译采用cmake,比较简单:

代码语言:shell复制
mkdir build
cmake -S . -B build
cmake --build build

测试

C 版本向量化

现在我们使用我们Bert项目中词汇文件,来初始化一个FulTokenizer,看看它的向量化结果是否符合预期:

代码语言:c 复制
    auto tokenizer = FullTokenizer("/home/guodongxiaren/vocab.txt");
    auto tokens = tokenizer.tokenize("李稻葵:过去2年抗疫为每人增寿10天");
    auto ids = tokenizer.convertTokensToIds(tokens);
    for (size_t i = 0; i < ids.size(); i  ) {
        if (i != 0) std::cout << " ";
        std::cout << ids[i];
    }
    std::cout<<std::endl;

输出:

代码语言:txt复制
3330 4940 5878 131 6814 1343 123 2399 2834 4554 711 3680 782 1872 2195 8108 1921

Python版本向量化

我们使用之前文章中提到的python脚本中的向量化函数,来对同一段文本进行一下向量化。看一下结果:

代码语言:txt复制
101 3330 4940 5878 131 6814 1343 123 2399 2834 4554 711 3680 782 1872 2195 8108 1921

对比

可以看出python版本第一位多一个101,后面的数字基本相同。这个101就是Bert模型中[CLS]标记对应的向量化后的数字。别忘了,python版本,都是要拼接这个前缀的:

代码语言:python代码运行次数:0复制
    token = config.tokenizer.tokenize(text)
    token = ['[CLS]']   token
    mask = []
    token_ids = config.tokenizer.convert_tokens_to_ids(token)

[CLS]标记在Bert模型中表示的是这是一个分类(Classify)任务。因为Bert模型除了分类,还能执行其他任务。

当然如果你把[CLS]传入我们C 版本的向量化函数,结果是不符合预期,比如输入改成:

代码语言:C 复制
    auto tokens = tokenizer.tokenize("李稻葵:过去2年抗疫为每人增寿10天");
代码语言:txt复制
138 12847 8118 140 3330 4940 5878 131 6814 1343 123 2399 2834 4554 711 3680 782 1872 2195 8108 1921

结果是前面多了4个数字,并不会只是添加一个101,因为[CLS]是在Bert模型中的tokenizer会特殊处理的字符串,和普通文本的向量化方式不同。

解决方案是我们只需要可以把C 向量化结果,手工拼上一个101就可以了。

ONNX Runtime C

ONNX Runtime(以下简称ORT)的C 版本API文档:https://onnxruntime.ai/docs/api/c/namespace_ort.html

Ort::Session初始化

Ort::Session对应ORT的python API中 InferenceSession。

Ort::Session的构造函数有多个重载版本,最常用的是:

代码语言:C 复制
Ort::Session::Session(Env& env,               
                      const char * model_path,      // ONNX模型的路径
                      const SessionOptions & options
                      )

比如可以这构造Session:

代码语言:C 复制
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
    Ort::SessionOptions session_options;

    OrtCUDAProviderOptions cuda_options; 
    session_options.AppendExecutionProvider_CUDA(cuda_options);
    
    const char* model_path = "/home/guodongxiaren/model.onnx";

    Ort::Session session(env, model_path, session_options);

session_options是用于设置一些额外的配置参数,比如上面设置成CUDA执行。它还有其他的设置,这里不展开。我们只需要实现一个最简单代码即可。

Ort::Value的构建

Ort::Value是模型输入的类型,也就是ORT C API中表示Tensor(张量)的类型。它通过CreateTensor函数构造,CreateTensor函数也有多个重载,这里介绍其中一个。

CreateTensor() API

代码语言:c 复制
template<typename T >
Value Ort::Value::CreateTensor(const OrtMemoryInfo* info,
                               T*                   p_data,
		               size_t               p_data_element_count,
			       const int64_t*       shape,
                               size_t               shape_len
)

函数参数

描述

info

p_data指向数据存储的内存类型,CPU或者GPU

p_data

核心数据(的地址)

p_data_element_count

p_data指向数据的字节数

shape

p_data形状(的地址)

shape_len

shape参数的维度

模板参数T

模板参数T表示Tensor中数据类型。对于我们这里就是int64_t

p_data 与 p_data_element_count

p_data表示的就是核心的数据,是一段连续存储,可以使用vector来存储,通过data()函数获取其数据的指针。

p_data_element_count 表示的就是这段连续存储中有多少个元素了。

shape 与 shape_len

shape参数用来表示Tensor的形状。因为不管数学意义上的Tensor的形状如何,在ORT C API中p_data都是使用一度连续存储的空间表示,不会像python中一样套上层层的括号表达维度。比如数学意义上的一个2维矩阵:[[1,2,3],[4,5,6]],在这里只需要传入{1,2,3,4,5,6} 然后通过shape参数:{2, 3}表示这是2*3的矩阵。

通过上一篇文章,我们知道我们模型的输入是ids和mask两个Tensor,每个形状都是一个1*32。所以可以这样表示这个shape:

代码语言:c 复制
std::vector<int64_t> shape = {1, 32};

shape.data()即可以获得一个int64_t*的指针,因为我们这里维度是固定的,所以直接用int64的数组也以。

shape_len表示的就是shape中有几个元素(shape的维度),即2。

info

表示的是p_data是存储在CPU还是GPU(CUDA)上。这里我们用CPU来存储输入的Tensor数据即可,因为代码会比较简练:

代码语言:c 复制
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

如果是GPU存储,则需要调用CUDA的API,稍微繁琐一点。

即使这里是CPU也不影响我们模型在GPU上跑推理的。

合并Tensor

假设我们已经得到了存储模型输入参数ids和mask向量的两个vector对象:input_tensor_values和mask_tensor_values,我们可以先这样获得表示各自Tensor的Ort::Value对象:

代码语言:c 复制
Ort::Value input_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, input_tensor_values.data(), input_tensor_values.size(), input_node_dims.data(), 2);

Ort::Value mask_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, mask_tensor_values.data(), mask_tensor_values.size(), input_node_dims.data(), 2);

接下来,将两个Tensor合并成一个。

代码语言:txt复制
std::vector<Ort::Value> ort_inputs;
ort_inputs.push_back(std::move(input_tensor));
ort_inputs.push_back(std::move(mask_tensor));

Ort::Session::Run() 与推理预测

Session的Run函数就是执行模型推理的过程。

参数梗概

参数如下:

代码语言:c 复制
std::vector<Value> Ort::Session::Run(const RunOptions& 	run_options,
			          const char* const* 	input_names,
			          const Value* 		input_values,
			          size_t 		input_count,
			          const char* const* 	output_names,
			          size_t 		output_count 
)	

参数

描述

run_options

可忽略

input_names

模型输入的名称

input_values

模型输入

input_count

输入的个数

output_names

输出的名称

output_count

输出的个数

调用示例

代码语言:c 复制
std::vector<const char*> input_node_names = {"ids", "mask"};
std::vector<const char*> output_node_names = {"output"};
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, 
                                  input_node_names.data(), 
                                  ort_inputs.data(),
                                  ort_inputs.size(), 
                                  output_node_names.data(), 
                                  1);

Run()的返回值是std::vector<Ort::Value>类型,因为模型可能有多输出,所以是vector表示,但是对于我们的模型来说它的输出只有一个Tensor,所以返回值outout_tensors的size必为1。不放心的话,也可以额外检查一下。

所以outout_tensors0就是输出向量了,和python一样,表示的是输入Tensor对于每种分类下的概率,我们选取概率最高的那个,就表示最终预测的分类结果了。

自定义argmax

C 本身没有argmax的函数,但是利用STL,很容易写出一个:

代码语言:c 复制
template <typename T>
int argmax(T begin, T end) {
    return std::distance(begin, std::max_element(begin, end));
}

最终结果

代码语言:c 复制
    const float* output = output_tensors[0].GetTensorData<float>();

    return argmax(output, output 10);封装Model类有了前面的铺垫,我们把文本向量化和ORT的预测功能整合成一个Model类,提供一个更简单便捷的使用方式。类声明

class Model {

public:

代码语言:txt复制
Model(const std::string& model_path, const std::string& vocab_path);
代码语言:txt复制
~Model() {delete tokenizer_; delete ses_;}
代码语言:txt复制
// 执行文本预测,返回预测的分类名称
代码语言:txt复制
std::string predict(const std::string& text, float* score=nullptr);
代码语言:txt复制
// 执行文本预测,返回预测的分类ID
代码语言:txt复制
int infer(const std::string& text, float* score=nullptr);

protected:

代码语言:txt复制
// 将文本向量化,返回ids和mask两个向量
代码语言:txt复制
std::vector<std::vector<int64_t>> build_input(const std::string& text);

private:

代码语言:txt复制
FullTokenizer* tokenizer_ = nullptr;
代码语言:txt复制
Ort::Session* ses_ = nullptr;
代码语言:txt复制
Ort::Env env_; // 注意

};

代码语言:txt复制
## 构造函数
```c  

Model::Model(const std::string& model_path,

代码语言:txt复制
         const std::string& vocab_path)
代码语言:txt复制
             :env_(ORT_LOGGING_LEVEL_WARNING, "test") {
代码语言:txt复制
tokenizer_ = new FullTokenizer(vocab_path);
代码语言:txt复制
Ort::SessionOptions session_options;
代码语言:txt复制
OrtCUDAProviderOptions cuda_options;
代码语言:txt复制
session_options.AppendExecutionProvider_CUDA(cuda_options);
代码语言:txt复制
ses_ = new Ort::Session(env_, model_path.c_str(), session_options);

}

代码语言:txt复制
### Ort::Env与coredump

通过前面的例子,Ort::Env参数应该只是构造Ort::Session时的临时变量,这里为什么要弄成Model类的成员变量呢?作为Model构造函数中的局部变量不行吗?在我的1.31的ORT版本上还真不行。因为如果env是一个局部变量,在后面infer函数中执行Session::Run()的时候,会coredump。
回看Ort::Session的构造参数定义:
```C  

Ort::Session::Session(Env& env,

代码语言:txt复制
                  const char * model_path,      // ONNX模型的路径
代码语言:txt复制
                  const SessionOptions & options
代码语言:txt复制
                  )
代码语言:txt复制
Env是一个非常量的引用,也就是env如果定义成一个局部变量,那么在Model构造函数结束之后,env引用就失效了,出现引用悬空(可以理解成野指针)。但是在Session::Run()执行的时候,内部还会使用到env中的数据,从而出现非法的内存访问。

其实这属于API设计上的一个BUG,最近看到ORT的Github上已经做了修复。参见这个Pull Request:
[Make 'env' argument to Session const](https://github.com/microsoft/onnxruntime/pull/13362)
```diff

struct Session : detail::SessionImpl<OrtSession> {

  • explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used这里改成常量引用,是可以延长局部变量的声明周期的。 这个PR是2022.10.19 Merge到主干的,并没有包含到当前(2022.11)最新的版本(1.13.1)中。虽然这个版本是2022.10.25发布release的。下个版本应该能体现,到时候就不用再特殊处理Env参数了。 ### build_input() 用来封装文本向量化,以及最终返回ids和mask两个向量的过程。中间包含补`101`,padding的操作。 ```c std::vector<std::vector<int64_t>> Model::build_input(const std::string& text) { auto tokens = tokenizer_->tokenize(text); auto token_ids = tokenizer_->convertTokensToIds(tokens);
  • Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
  • Session(Env& env, const ORTCHAR_T model_path, const SessionOptions& options, OrtPrepackedWeightsContainer prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
  • Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
  • Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
  • explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
  • Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
  • Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
  • OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
  • Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
  • Session(const Env& env, const void model_data, size_t model_data_length, const SessionOptions& options, OrtPrepackedWeightsContainer prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
代码语言:txt复制
std::vector<std::vector<int64_t>> res;
代码语言:txt复制
std::vector<int64_t> input(32);
代码语言:txt复制
std::vector<int64_t> mask(32);
代码语言:txt复制
input[0] = 101; // Bert模型的[CLS]标记的
代码语言:txt复制
mask[0] = 1;
代码语言:txt复制
for (int i = 0; i < token_ids.size() && i < 31;   i) {
代码语言:txt复制
    input[i 1] = token_ids[i];
代码语言:txt复制
    mask[i 1] = token_ids[i] > 0;
代码语言:txt复制
}
代码语言:txt复制
res.push_back(std::move(input));
代码语言:txt复制
res.push_back(std::move(mask));
代码语言:txt复制
return res;

}

代码语言:txt复制
### infer()和predict()
infer用来执行推理,返回文本最接近的分类。
```c  

int Model::infer(const std::string& text, float* score) {

代码语言:txt复制
auto& session = *ses_;
代码语言:txt复制
// 调用前面的build_input
代码语言:txt复制
auto res = build_input(text);
代码语言:txt复制
std::vector<int64_t> shape = {1, 32};
代码语言:txt复制
auto& input_tensor_values = res[0];
代码语言:txt复制
auto& mask_tensor_values = res[1];
代码语言:txt复制
const static auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
代码语言:txt复制
Ort::Value input_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, input_tensor_values.data(),
代码语言:txt复制
                                                        input_tensor_values.size(), shape.data(), 2);
代码语言:txt复制
Ort::Value mask_tensor = Ort::Value::CreateTensor<int64_t>(memory_info, mask_tensor_values.data(),
代码语言:txt复制
                                                        mask_tensor_values.size(), shape.data(), 2);
代码语言:txt复制
std::vector<Ort::Value> ort_inputs;
代码语言:txt复制
ort_inputs.push_back(std::move(input_tensor));
代码语言:txt复制
ort_inputs.push_back(std::move(mask_tensor));
代码语言:txt复制
const static std::vector<const char*> input_node_names = {"ids", "mask"};
代码语言:txt复制
const static std::vector<const char*> output_node_names = {"output"};
代码语言:txt复制
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_node_names.data(), ort_inputs.data(),
代码语言:txt复制
                                ort_inputs.size(), output_node_names.data(), 1);
代码语言:txt复制
if (output_tensors.size() != output_node_names.size()) {
代码语言:txt复制
    return -1;
代码语言:txt复制
}
代码语言:txt复制
const float* output = output_tensors[0].GetTensorData<float>();
代码语言:txt复制
int idx = argmax(output, output 10);
代码语言:txt复制
if (score != nullptr) {
代码语言:txt复制
    *score = output[idx];
代码语言:txt复制
}
代码语言:txt复制
return idx;

}

代码语言:txt复制
predict()函数比infer()函数更进一步,用来返回分类的名称。
首先我们还是借用python中分类名:
```c  

const static std::vector<std::string> kNames = {

代码语言:txt复制
"finance",
代码语言:txt复制
"realty",
代码语言:txt复制
"stocks",
代码语言:txt复制
"education",
代码语言:txt复制
"science",
代码语言:txt复制
"society",
代码语言:txt复制
"politics",
代码语言:txt复制
"sports",
代码语言:txt复制
"game",
代码语言:txt复制
"entertainment"

};

代码语言:txt复制
然后:
```c  

std::string Model::predict(const std::string& text, float* score) {

代码语言:txt复制
int idx = infer(text, score);
代码语言:txt复制
return (idx >= 0 && idx < kNames.size()) ? kNames[idx] : "Unknown";

}}

代码语言:txt复制
# bRPC
终于到了bRPC服务化的环节了,其实这部分已经比较简单了。直接用官方example中的echo_server改改就可以了。关于bRPC的基础,可以参考我之前的这两篇文章:
- [bRPC最新安装上手指南](https://mp.weixin.qq.com/s/UYbTfQRY9JsonOyCBqxyHA)
- [通过echo_server带你入门bRPC!](https://mp.weixin.qq.com/s/nmLruEd_nUkC7Dj5EHyWLw)

## 定义接口proto
```proto

syntax="proto2";

package guodongxiaren;

option cc_generic_services = true;

message NewsClassifyRequest {

代码语言:txt复制
required string title = 1;

};

message NewsClassifyResponse {

代码语言:txt复制
required string result = 1;
代码语言:txt复制
optional float score = 2;

};

service InferService {

代码语言:txt复制
rpc NewsClassify(NewsClassifyRequest) returns (NewsClassifyResponse);

};

代码语言:txt复制
## Server代码
```c  

#include <gflags/gflags.h>

#include <butil/logging.h>

#include <brpc/server.h>

#include "infer.pb.h"

#include "util/model.h"

DEFINE_int32(port, 8000, "TCP Port of this server");

DEFINE_string(listen_addr, "", "Server listen address, may be IPV4/IPV6/UDS."

代码语言:txt复制
        " If this is set, the flag port will be ignored");

DEFINE_int32(idle_timeout_s, -1, "Connection will be closed if there is no "

代码语言:txt复制
         "read/write operations during the last `idle_timeout_s'");

DEFINE_int32(logoff_ms, 2000, "Maximum duration of server's LOGOFF state "

代码语言:txt复制
         "(waiting for client to close connection before server stops)");

namespace guodongxiaren {

class InferServiceImpl : public InferService {

public:

代码语言:txt复制
InferServiceImpl() {}
代码语言:txt复制
virtual ~InferServiceImpl() { delete model; }
代码语言:txt复制
// 接口
代码语言:txt复制
virtual void NewsClassify(google::protobuf::RpcController* cntl_base,
代码语言:txt复制
                  const NewsClassifyRequest* request,
代码语言:txt复制
                  NewsClassifyResponse* response,
代码语言:txt复制
                  google::protobuf::Closure* done) {
代码语言:txt复制
    brpc::ClosureGuard done_guard(done);
代码语言:txt复制
    brpc::Controller* cntl =
代码语言:txt复制
        static_cast<brpc::Controller*>(cntl_base);
代码语言:txt复制
    float score = 0.0f;
代码语言:txt复制
    auto result = model->predict(request->title(), &score);
代码语言:txt复制
    LOG(INFO) << " " << request->title()
代码语言:txt复制
              << " is " << result
代码语言:txt复制
              << " score: " << score;
代码语言:txt复制
    response->set_result(result);
代码语言:txt复制
    response->set_score(score);
代码语言:txt复制
}
代码语言:txt复制
// 初始化函数
代码语言:txt复制
int Init(const std::string& model_path, const std::string& vocab_path) {
代码语言:txt复制
    model = new Model(model_path, vocab_path);
代码语言:txt复制
}
代码语言:txt复制
Model* model = nullptr;

};

} // namespace guodongxiaren

int main(int argc, char* argv[]) {

代码语言:txt复制
gflags::ParseCommandLineFlags(&argc, &argv, true);
代码语言:txt复制
brpc::Server server;
代码语言:txt复制
guodongxiaren::InferServiceImpl service_impl;
代码语言:txt复制
// 初始化
代码语言:txt复制
const char* vocab_path = "/home/guodongxiaren/vocab.txt";
代码语言:txt复制
const char* model_path = "/home/guodongxiaren/model.onnx";
代码语言:txt复制
service_impl.Init(model_path, vocab_path);
代码语言:txt复制
if (server.AddService(&service_impl, 
代码语言:txt复制
                      brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
代码语言:txt复制
    LOG(ERROR) << "Fail to add service";
代码语言:txt复制
    return -1;
代码语言:txt复制
}
代码语言:txt复制
butil::EndPoint point;
代码语言:txt复制
if (!FLAGS_listen_addr.empty()) {
代码语言:txt复制
    if (butil::str2endpoint(FLAGS_listen_addr.c_str(), &point) < 0) {
代码语言:txt复制
        LOG(ERROR) << "Invalid listen address:" << FLAGS_listen_addr;
代码语言:txt复制
        return -1;
代码语言:txt复制
    }
代码语言:txt复制
} else {
代码语言:txt复制
    point = butil::EndPoint(butil::IP_ANY, FLAGS_port);
代码语言:txt复制
}
代码语言:txt复制
brpc::ServerOptions options;
代码语言:txt复制
options.idle_timeout_sec = FLAGS_idle_timeout_s;
代码语言:txt复制
if (server.Start(point, &options) != 0) {
代码语言:txt复制
    LOG(ERROR) << "Fail to start InferServer";
代码语言:txt复制
    return -1;
代码语言:txt复制
}
代码语言:txt复制
server.RunUntilAskedToQuit();
代码语言:txt复制
return 0;

}

代码语言:txt复制
## 测试
bRPC支持单端口多协议,一个bRPC服务默认除了可以提供protobuf类型的请求外,也只支持HTTP JSON请求。所以我们可以直接使用curl来测试:

curl -d '{"title": "衡水中学:破除超限、内卷等现象"}' 127.0.0.1:8000/InferService/NewsClassify

代码语言:txt复制
输出:

{"result":"education","score":9.031564712524414}

代码语言:txt复制

0 人点赞