未验证 提交 47ee6c1a 编写于 作者: M MRXLT 提交者: GitHub

Merge pull request #97 from MRXLT/bert

add demo for bert
......@@ -17,6 +17,7 @@ set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SERVING_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_SERVING_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(PADDLE_SERVING_INSTALL_DIR ${CMAKE_BINARY_DIR}/output)
SET(CMAKE_INSTALL_RPATH "\$ORIGIN" "${CMAKE_INSTALL_RPATH}")
include(system)
......
......@@ -61,6 +61,12 @@ add_executable(ctr_prediction
${CMAKE_CURRENT_LIST_DIR}/src/ctr_prediction.cpp)
target_link_libraries(ctr_prediction -Wl,--whole-archive sdk-cpp
-Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz)
add_executable(bert
${CMAKE_CURRENT_LIST_DIR}/src/bert_service.cpp)
target_link_libraries(bert -Wl,--whole-archive sdk-cpp
-Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz)
# install
install(TARGETS ximage
RUNTIME DESTINATION
......@@ -128,3 +134,11 @@ install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/ctr_prediction/)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/ctr_prediction DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/ctr_prediction/data)
install(TARGETS bert
RUNTIME DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/bert/bin)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/bert/)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/bert DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/bert/data)
......@@ -139,3 +139,17 @@ predictors {
}
}
}
predictors {
name: "bert_service"
service_name: "baidu.paddle_serving.predictor.bert_service.BertService"
endpoint_router: "WeightedRandomRender"
weighted_random_render_conf {
variant_weight_list: "50"
}
variants {
tag: "var1"
naming_conf {
cluster: "list://127.0.0.1:8010"
}
}
}
此差异已折叠。
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <thread> //NOLINT
#include <vector>
#include "sdk-cpp/bert_service.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::Response;
using baidu::paddle_serving::predictor::bert_service::BertResInstance;
using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
extern int batch_size = 1;
extern int max_seq_len = 128;
extern int layer_num = 12;
extern int emb_size = 768;
extern int thread_num = 1;
extern int max_turn = 1000;
std::atomic<int> g_concurrency(0);
std::vector<std::vector<int>> response_time;
std::vector<std::vector<int>> infer_time;
char* data_filename = "./data/bert/demo_wiki_data";
std::vector<std::string> split(const std::string& str,
const std::string& pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
BertReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
// avoid out of boundary
int cur_index = data_index + i;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], ";");
std::vector<std::string> token_list = split(feature_list[0], " ");
std::vector<std::string> seg_list = split(feature_list[1], " ");
std::vector<std::string> pos_list = split(feature_list[2], " ");
for (int fi = 0; fi < max_seq_len; fi++) {
if (std::stoi(token_list[fi]) != 0) {
ins->add_token_ids(std::stoi(token_list[fi]));
ins->add_sentence_type_ids(std::stoi(seg_list[fi]));
ins->add_position_ids(std::stoi(pos_list[fi]));
ins->add_input_masks(1.0);
} else {
ins->add_token_ids(0);
ins->add_sentence_type_ids(0);
ins->add_position_ids(0);
ins->add_input_masks(0.0);
}
}
ins->set_max_seq_len(max_seq_len);
ins->set_emb_size(emb_size);
}
return 0;
}
void print_res(const Request& req,
const Response& res,
std::string route_tag,
uint64_t elapse_ms) {
for (uint32_t ri = 0; ri < res.instances_size(); ri++) {
const BertResInstance& res_ins = res.instances(ri);
std::ostringstream oss;
oss << "[";
for (uint32_t bi = 0; bi < res_ins.instances_size(); bi++) {
const EmbeddingValues& emb_ins = res_ins.instances(bi);
oss << "[";
for (uint32_t ei = 0; ei < emb_ins.values_size(); ei++) {
oss << emb_ins.values(ei) << " ";
}
oss << "],";
}
oss << "]\n";
LOG(INFO) << "Receive : " << oss.str();
}
LOG(INFO) << "Succ call predictor[ctr_prediction_service], the tag is: "
<< route_tag << ", elapse_ms: " << elapse_ms;
}
void thread_worker(PredictorApi* api,
int thread_id,
int batch_size,
int server_concurrency,
const std::vector<std::string>& data_list) {
Request req;
Response res;
api->thrd_initialize();
std::string line;
int turns = 0;
while (turns < max_turn) {
timeval start;
gettimeofday(&start, NULL);
api->thrd_clear();
Predictor* predictor = api->fetch_predictor("bert_service");
if (!predictor) {
LOG(ERROR) << "Failed fetch predictor: bert_service";
return;
}
req.Clear();
res.Clear();
while (g_concurrency.load() >= server_concurrency) {
}
g_concurrency++;
LOG(INFO) << "Current concurrency " << g_concurrency.load();
int data_index = turns * batch_size;
if (create_req(&req, data_list, data_index, batch_size) != 0) {
return;
}
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
return;
}
timeval end;
gettimeofday(&end, NULL);
uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
(start.tv_sec * 1000 + start.tv_usec / 1000);
response_time[thread_id].push_back(elapse_ms);
print_res(req, res, predictor->tag(), elapse_ms);
g_concurrency--;
LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
turns++;
}
api->thrd_finalize();
}
void calc_time(int server_concurrency, int batch_size) {
std::vector<int> time_list;
for (auto a : response_time) {
time_list.insert(time_list.end(), a.begin(), a.end());
}
LOG(INFO) << "Total request : " << (time_list.size());
LOG(INFO) << "Batch size : " << batch_size;
LOG(INFO) << "Max concurrency : " << server_concurrency;
LOG(INFO) << "Max Seq Len : " << max_seq_len;
float total_time = 0;
float max_time = 0;
float min_time = 1000000;
for (int i = 0; i < time_list.size(); ++i) {
total_time += time_list[i];
if (time_list[i] > max_time) max_time = time_list[i];
if (time_list[i] < min_time) min_time = time_list[i];
}
float mean_time = total_time / (time_list.size());
float var_time;
for (int i = 0; i < time_list.size(); ++i) {
var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
}
var_time = var_time / time_list.size();
LOG(INFO) << "Total time : " << total_time / server_concurrency
<< " Variance : " << var_time << " Max time : " << max_time
<< " Min time : " << min_time;
float qps = 0.0;
if (total_time > 0)
qps = (time_list.size() * 1000) / (total_time / server_concurrency);
LOG(INFO) << "QPS: " << qps << "/s";
LOG(INFO) << "Latency statistics: ";
sort(time_list.begin(), time_list.end());
int percent_pos_50 = time_list.size() * 0.5;
int percent_pos_80 = time_list.size() * 0.8;
int percent_pos_90 = time_list.size() * 0.9;
int percent_pos_99 = time_list.size() * 0.99;
int percent_pos_999 = time_list.size() * 0.999;
if (time_list.size() != 0) {
LOG(INFO) << "Mean time : " << mean_time;
LOG(INFO) << "50 percent ms: " << time_list[percent_pos_50];
LOG(INFO) << "80 percent ms: " << time_list[percent_pos_80];
LOG(INFO) << "90 percent ms: " << time_list[percent_pos_90];
LOG(INFO) << "99 percent ms: " << time_list[percent_pos_99];
LOG(INFO) << "99.9 percent ms: " << time_list[percent_pos_999];
} else {
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
LOG(INFO) << "N/A";
}
}
int main(int argc, char** argv) {
PredictorApi api;
if (argc > 1) {
thread_num = std::stoi(argv[1]);
batch_size = std::stoi(argv[2]);
max_seq_len = std::stoi(argv[3]);
}
response_time.resize(thread_num);
int server_concurrency = thread_num;
// log set
#ifdef BCLOUD
logging::LoggingSettings settings;
settings.logging_dest = logging::LOG_TO_FILE;
std::string log_filename(argv[0]);
log_filename = log_filename.substr(log_filename.find_last_of('/') + 1);
settings.log_file = (std::string("./log/") + log_filename + ".log").c_str();
settings.delete_old = logging::DELETE_OLD_LOG_FILE;
logging::InitLogging(settings);
logging::ComlogSinkOptions cso;
cso.process_name = log_filename;
cso.enable_wf_device = true;
logging::ComlogSink::GetInstance()->Setup(&cso);
#else
struct stat st_buf;
int ret = 0;
if ((ret = stat("./log", &st_buf)) != 0) {
mkdir("./log", 0777);
ret = stat("./log", &st_buf);
if (ret != 0) {
LOG(WARNING) << "Log path ./log not exist, and create fail";
return -1;
}
}
FLAGS_log_dir = "./log";
google::InitGoogleLogging(strdup(argv[0]));
FLAGS_logbufsecs = 0;
FLAGS_logbuflevel = -1;
#endif
// predictor conf
if (api.create("./conf", "predictors.prototxt") != 0) {
LOG(ERROR) << "Failed create predictors api!";
return -1;
}
// read data
std::ifstream data_file(data_filename);
if (!data_file) {
std::cout << "read file error \n" << std::endl;
return -1;
}
std::vector<std::string> data_list;
std::string line;
while (getline(data_file, line)) {
data_list.push_back(line);
}
// create threads
std::vector<std::thread*> thread_pool;
for (int i = 0; i < server_concurrency; ++i) {
thread_pool.push_back(new std::thread(thread_worker,
&api,
i,
batch_size,
server_concurrency,
std::ref(data_list)));
}
for (int i = 0; i < server_concurrency; ++i) {
thread_pool[i]->join();
delete thread_pool[i];
}
calc_time(server_concurrency, batch_size);
api.destroy();
return 0;
}
......@@ -24,6 +24,18 @@ if (NOT EXISTS
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid)
endif()
if (NOT EXISTS
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12)
execute_process(COMMAND wget --no-check-certificate
https://paddle-serving.bj.bcebos.com/data/bert/bert_cased_L-12_H-768_A-12.tar.gz
--output-document
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12.tar.gz)
execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf
"${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid/bert_cased_L-12_H-768_A-12.tar.gz"
WORKING_DIRECTORY
${CMAKE_CURRENT_LIST_DIR}/data/model/paddle/fluid)
endif()
include_directories(SYSTEM ${CMAKE_CURRENT_LIST_DIR}/../kvdb/include)
include(op/CMakeLists.txt)
......@@ -78,7 +90,7 @@ if (${WITH_MKL})
install(FILES
${CMAKE_BINARY_DIR}/third_party/install/Paddle/third_party/install/mklml/lib/libmklml_intel.so
${CMAKE_BINARY_DIR}/third_party/install/Paddle/third_party/install/mklml/lib/libiomp5.so
${CMAKE_BINARY_DIR}/third_party/install/Paddle/third_party/install/mkldnn/lib/libmkldnn.so
${CMAKE_BINARY_DIR}/third_party/install/Paddle/third_party/install/mkldnn/lib/libmkldnn.so.0
DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/serving/bin)
endif()
......@@ -6,7 +6,7 @@ engines {
model_data_path: "./data/model/paddle/fluid/SE_ResNeXt50_32x4d"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
enable_batch_align: 0
enable_memory_optimization: true
static_optimization: false
force_update_static_cache: false
......@@ -35,3 +35,15 @@ engines {
sparse_param_service_type: REMOTE
sparse_param_service_table_name: "test_dict"
}
engines {
name: "bert"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/bert_cased_L-12_H-768_A-12"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
enable_memory_optimization: true
}
......@@ -21,7 +21,6 @@ services {
name: "BuiltinFluidService"
workflows: "workflow5"
}
services {
name: "TextClassificationService"
workflows: "workflow6"
......@@ -36,3 +35,7 @@ services {
name: "CTRPredictionService"
workflows: "workflow8"
}
services {
name: "BertService"
workflows: "workflow9"
}
......@@ -83,3 +83,12 @@ workflows {
type: "CTRPredictionOp"
}
}
workflows {
name: "workflow9"
workflow_type: "Sequence"
nodes {
name: "bert_service_op"
type: "BertServiceOp"
}
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "demo-serving/op/bert_service_op.h"
#include <cstdio>
#include <string>
#include "predictor/framework/infer.h"
#include "predictor/framework/memory.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::bert_service::BertResInstance;
using baidu::paddle_serving::predictor::bert_service::Response;
using baidu::paddle_serving::predictor::bert_service::BertReqInstance;
using baidu::paddle_serving::predictor::bert_service::Request;
using baidu::paddle_serving::predictor::bert_service::EmbeddingValues;
int BertServiceOp::inference() {
timeval op_start;
gettimeofday(&op_start, NULL);
const Request *req = dynamic_cast<const Request *>(get_request_message());
TensorVector *in = butil::get_object<TensorVector>();
Response *res = mutable_data<Response>();
uint32_t batch_size = req->instances_size();
if (batch_size <= 0) {
LOG(WARNING) << "No instances need to inference!";
return 0;
}
const int64_t MAX_SEQ_LEN = req->instances(0).max_seq_len();
const int64_t EMB_SIZE = req->instances(0).emb_size();
paddle::PaddleTensor src_ids;
paddle::PaddleTensor pos_ids;
paddle::PaddleTensor seg_ids;
paddle::PaddleTensor input_masks;
src_ids.name = std::string("src_ids");
pos_ids.name = std::string("pos_ids");
seg_ids.name = std::string("sent_ids");
input_masks.name = std::string("input_mask");
src_ids.dtype = paddle::PaddleDType::INT64;
src_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
src_ids.data.Resize(batch_size * MAX_SEQ_LEN * sizeof(int64_t));
pos_ids.dtype = paddle::PaddleDType::INT64;
pos_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
pos_ids.data.Resize(batch_size * MAX_SEQ_LEN * sizeof(int64_t));
seg_ids.dtype = paddle::PaddleDType::INT64;
seg_ids.shape = {batch_size, MAX_SEQ_LEN, 1};
seg_ids.data.Resize(batch_size * MAX_SEQ_LEN * sizeof(int64_t));
input_masks.dtype = paddle::PaddleDType::FLOAT32;
input_masks.shape = {batch_size, MAX_SEQ_LEN, 1};
input_masks.data.Resize(batch_size * MAX_SEQ_LEN * sizeof(float));
std::vector<std::vector<size_t>> lod_set;
lod_set.resize(1);
for (uint32_t i = 0; i < batch_size; i++) {
lod_set[0].push_back(i * MAX_SEQ_LEN);
}
// src_ids.lod = lod_set;
// pos_ids.lod = lod_set;
// seg_ids.lod = lod_set;
// input_masks.lod = lod_set;
uint32_t index = 0;
for (uint32_t i = 0; i < batch_size; i++) {
int64_t *src_data = static_cast<int64_t *>(src_ids.data.data()) + index;
int64_t *pos_data = static_cast<int64_t *>(pos_ids.data.data()) + index;
int64_t *seg_data = static_cast<int64_t *>(seg_ids.data.data()) + index;
float *input_masks_data =
static_cast<float *>(input_masks.data.data()) + index;
const BertReqInstance &req_instance = req->instances(i);
memcpy(src_data,
req_instance.token_ids().data(),
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(pos_data,
req_instance.position_ids().data(),
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(seg_data,
req_instance.sentence_type_ids().data(),
sizeof(int64_t) * MAX_SEQ_LEN);
memcpy(input_masks_data,
req_instance.input_masks().data(),
sizeof(float) * MAX_SEQ_LEN);
index += MAX_SEQ_LEN;
}
in->push_back(src_ids);
in->push_back(pos_ids);
in->push_back(seg_ids);
in->push_back(input_masks);
TensorVector *out = butil::get_object<TensorVector>();
if (!out) {
LOG(ERROR) << "Failed get tls output object";
return -1;
}
#if 0 // print request
std::ostringstream oss;
for (int j = 0; j < 3; j++) {
int64_t* example = reinterpret_cast<int64_t*>((*in)[j].data.data());
for (uint32_t i = 0; i < MAX_SEQ_LEN; i++) {
oss << *(example + i) << " ";
}
oss << ";";
}
float* example = reinterpret_cast<float*>((*in)[3].data.data());
for (int i = 0; i < MAX_SEQ_LEN; i++) {
oss << *(example + i) << " ";
}
LOG(INFO) << "msg: " << oss.str();
#endif
timeval infer_start;
gettimeofday(&infer_start, NULL);
if (predictor::InferManager::instance().infer(
BERT_MODEL_NAME, in, out, batch_size)) {
LOG(ERROR) << "Failed do infer in fluid model: " << BERT_MODEL_NAME;
return -1;
}
timeval infer_end;
gettimeofday(&infer_end, NULL);
uint64_t infer_time =
(infer_end.tv_sec * 1000 + infer_end.tv_usec / 1000 -
(infer_start.tv_sec * 1000 + infer_start.tv_usec / 1000));
LOG(INFO) << "batch_size : " << out->at(0).shape[0]
<< " emb_size : " << out->at(0).shape[1];
float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < 1; si++) {
EmbeddingValues *emb_instance = res_instance->add_instances();
for (uint32_t ei = 0; ei < EMB_SIZE; ei++) {
uint32_t index = bi * EMB_SIZE + ei;
emb_instance->add_values(out_data[index]);
}
}
}
timeval op_end;
gettimeofday(&op_end, NULL);
uint64_t op_time = (op_end.tv_sec * 1000 + op_end.tv_usec / 1000 -
(op_start.tv_sec * 1000 + op_start.tv_usec / 1000));
res->set_op_time(op_time);
res->set_infer_time(infer_time);
for (size_t i = 0; i < in->size(); ++i) {
(*in)[i].shape.clear();
}
in->clear();
butil::return_object<TensorVector>(in);
for (size_t i = 0; i < out->size(); ++i) {
(*out)[i].shape.clear();
}
out->clear();
butil::return_object<TensorVector>(out);
return 0;
}
DEFINE_OP(BertServiceOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle_inference_api.h" // NOLINT
#endif
#include "demo-serving/bert_service.pb.h"
#include <sys/time.h>
namespace baidu {
namespace paddle_serving {
namespace serving {
static const char* BERT_MODEL_NAME = "bert";
class BertServiceOp
: public baidu::paddle_serving::predictor::OpWithChannel<
baidu::paddle_serving::predictor::bert_service::Response> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(BertServiceOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -7,6 +7,7 @@ LIST(APPEND protofiles
${CMAKE_CURRENT_LIST_DIR}/int64tensor_service.proto
${CMAKE_CURRENT_LIST_DIR}/text_classification.proto
${CMAKE_CURRENT_LIST_DIR}/ctr_prediction.proto
${CMAKE_CURRENT_LIST_DIR}/bert_service.proto
)
PROTOBUF_GENERATE_SERVING_CPP(TRUE PROTO_SRCS PROTO_HDRS ${protofiles})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "pds_option.proto";
import "builtin_format.proto";
package baidu.paddle_serving.predictor.bert_service;
option cc_generic_services = true;
message BertReqInstance {
repeated int64 token_ids = 1;
repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3;
repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
};
message Request { repeated BertReqInstance instances = 1; };
message EmbeddingValues { repeated float values = 1; };
message BertResInstance { repeated EmbeddingValues instances = 1; };
message Response {
repeated BertResInstance instances = 1;
optional int64 op_time = 2;
optional int64 infer_time = 3;
};
service BertService {
rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response);
option (pds.options).generate_impl = true;
};
......@@ -17,7 +17,9 @@
#include <pthread.h>
#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "configure/include/configure_parser.h"
#include "configure/inferencer_configure.pb.h"
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "pds_option.proto";
import "builtin_format.proto";
package baidu.paddle_serving.predictor.bert_service;
option cc_generic_services = true;
message BertReqInstance {
repeated int64 token_ids = 1;
repeated int64 sentence_type_ids = 2;
repeated int64 position_ids = 3;
repeated float input_masks = 4;
optional int64 max_seq_len = 5;
optional int64 emb_size = 6;
};
message Request { repeated BertReqInstance instances = 1; };
message EmbeddingValues { repeated float values = 1; };
message BertResInstance { repeated EmbeddingValues instances = 1; };
message Response {
repeated BertResInstance instances = 1;
optional int64 op_time = 2;
optional int64 infer_time = 3;
};
service BertService {
rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response);
option (pds.options).generate_stub = true;
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册