提交 cf5c2436 编写于 作者: M MRXLT

fix bert demo

上级 6312429d
...@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() { ...@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() {
_api.thrd_initialize(); _api.thrd_initialize();
} }
int PredictorClient::predict( int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
const std::vector<std::vector<float>>& float_feed, const std::vector<std::string> &float_feed_name,
const std::vector<std::string>& float_feed_name, const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::vector<int64_t>>& int_feed, const std::vector<std::string> &int_feed_name,
const std::vector<std::string>& int_feed_name, const std::vector<std::string> &fetch_name,
const std::vector<std::string>& fetch_name, PredictorRes &predict_res) { // NOLINT
PredictorRes & predict_res) { // NOLINT
predict_res._int64_map.clear(); predict_res._int64_map.clear();
predict_res._float_map.clear(); predict_res._float_map.clear();
Timer timeline; Timer timeline;
...@@ -218,6 +217,7 @@ int PredictorClient::predict( ...@@ -218,6 +217,7 @@ int PredictorClient::predict(
VLOG(2) << "fetch name: " << name; VLOG(2) << "fetch name: " << name;
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
int len = res.insts(0).tensor_array(idx).int64_data_size(); int len = res.insts(0).tensor_array(idx).int64_data_size();
VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
predict_res._int64_map[name].resize(1); predict_res._int64_map[name].resize(1);
predict_res._int64_map[name][0].resize(len); predict_res._int64_map[name][0].resize(len);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
...@@ -226,6 +226,7 @@ int PredictorClient::predict( ...@@ -226,6 +226,7 @@ int PredictorClient::predict(
} }
} else if (_fetch_name_to_type[name] == 1) { } else if (_fetch_name_to_type[name] == 1) {
int len = res.insts(0).tensor_array(idx).float_data_size(); int len = res.insts(0).tensor_array(idx).float_data_size();
VLOG(2) << "fetch tensor : " << name << " type: float32 len : " << len;
predict_res._float_map[name].resize(1); predict_res._float_map[name].resize(1);
predict_res._float_map[name][0].resize(len); predict_res._float_map[name][0].resize(len);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
...@@ -244,7 +245,7 @@ int PredictorClient::predict( ...@@ -244,7 +245,7 @@ int PredictorClient::predict(
<< "prepro_1:" << preprocess_end << " " << "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " " << "client_infer_0:" << client_infer_start << " "
<< "client_infer_1:" << client_infer_end << " "; << "client_infer_1:" << client_infer_end << " ";
if (FLAGS_profile_server) { if (FLAGS_profile_server) {
int op_num = res.profile_time_size() / 2; int op_num = res.profile_time_size() / 2;
for (int i = 0; i < op_num; ++i) { for (int i = 0; i < op_num; ++i) {
...@@ -252,10 +253,10 @@ int PredictorClient::predict( ...@@ -252,10 +253,10 @@ int PredictorClient::predict(
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " "; oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
} }
} }
oss << "postpro_0:" << postprocess_start << " "; oss << "postpro_0:" << postprocess_start << " ";
oss << "postpro_1:" << postprocess_end; oss << "postpro_1:" << postprocess_end;
fprintf(stderr, "%s\n", oss.str().c_str()); fprintf(stderr, "%s\n", oss.str().c_str());
} }
return 0; return 0;
...@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
} }
VLOG(2) << "batch [" << bi << "] " VLOG(2) << "batch [" << bi << "] "
<< "itn feed value prepared"; << "int feed value prepared";
} }
int64_t preprocess_end = timeline.TimeStampUS(); int64_t preprocess_end = timeline.TimeStampUS();
......
...@@ -120,7 +120,6 @@ class BertService(): ...@@ -120,7 +120,6 @@ class BertService():
def test(): def test():
bc = BertService( bc = BertService(
model_name='bert_chinese_L-12_H-768_A-12', model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20, max_seq_len=20,
...@@ -130,16 +129,25 @@ def test(): ...@@ -130,16 +129,25 @@ def test():
config_file = './serving_client_conf/serving_client_conf.prototxt' config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"] fetch = ["pooled_output"]
bc.load_client(config_file, server_addr) bc.load_client(config_file, server_addr)
batch_size = 4 batch_size = 1
batch = [] batch = []
for line in sys.stdin: for line in sys.stdin:
if len(batch) < batch_size: if batch_size == 1:
batch.append([line.strip()]) result = bc.run_general([[line.strip()]], fetch)
print(result)
else: else:
result = bc.run_batch_general(batch, fetch) if len(batch) < batch_size:
batch = [] batch.append([line.strip()])
for r in result: else:
print(r) result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if len(batch) > 0:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op) ...@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op)
server = Server() server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4) server.set_num_threads(4)
server.set_local_bin(
"~/github/Serving/build_server/core/general-server/serving")
server.load_model_config(sys.argv[1]) server.load_model_config(sys.argv[1])
port = int(sys.argv[2]) port = int(sys.argv[2])
......
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册