提交 0be0c9c0 编写于 作者: M MRXLT

Fix BERT demo

上级 2e75c394
......@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() {
_api.thrd_initialize();
}
int PredictorClient::predict(
const std::vector<std::vector<float>>& float_feed,
const std::vector<std::string>& float_feed_name,
const std::vector<std::vector<int64_t>>& int_feed,
const std::vector<std::string>& int_feed_name,
const std::vector<std::string>& fetch_name,
PredictorRes & predict_res) { // NOLINT
int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
const std::vector<std::string> &float_feed_name,
const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::string> &int_feed_name,
const std::vector<std::string> &fetch_name,
PredictorRes &predict_res) { // NOLINT
predict_res._int64_map.clear();
predict_res._float_map.clear();
Timer timeline;
......@@ -218,6 +217,7 @@ int PredictorClient::predict(
VLOG(2) << "fetch name: " << name;
if (_fetch_name_to_type[name] == 0) {
int len = res.insts(0).tensor_array(idx).int64_data_size();
VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
predict_res._int64_map[name].resize(1);
predict_res._int64_map[name][0].resize(len);
for (int i = 0; i < len; ++i) {
......@@ -226,6 +226,7 @@ int PredictorClient::predict(
}
} else if (_fetch_name_to_type[name] == 1) {
int len = res.insts(0).tensor_array(idx).float_data_size();
VLOG(2) << "fetch tensor : " << name << " type: float32 len : " << len;
predict_res._float_map[name].resize(1);
predict_res._float_map[name][0].resize(len);
for (int i = 0; i < len; ++i) {
......@@ -244,7 +245,7 @@ int PredictorClient::predict(
<< "prepro_1:" << preprocess_end << " "
<< "client_infer_0:" << client_infer_start << " "
<< "client_infer_1:" << client_infer_end << " ";
if (FLAGS_profile_server) {
int op_num = res.profile_time_size() / 2;
for (int i = 0; i < op_num; ++i) {
......@@ -252,10 +253,10 @@ int PredictorClient::predict(
oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
}
}
oss << "postpro_0:" << postprocess_start << " ";
oss << "postpro_1:" << postprocess_end;
fprintf(stderr, "%s\n", oss.str().c_str());
}
return 0;
......@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
}
VLOG(2) << "batch [" << bi << "] "
<< "itn feed value prepared";
<< "int feed value prepared";
}
int64_t preprocess_end = timeline.TimeStampUS();
......
......@@ -120,7 +120,6 @@ class BertService():
def test():
bc = BertService(
model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20,
......@@ -130,16 +129,25 @@ def test():
config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"]
bc.load_client(config_file, server_addr)
batch_size = 4
batch_size = 1
batch = []
for line in sys.stdin:
if len(batch) < batch_size:
batch.append([line.strip()])
if batch_size == 1:
result = bc.run_general([[line.strip()]], fetch)
print(result)
else:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if len(batch) < batch_size:
batch.append([line.strip()])
else:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if len(batch) > 0:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if __name__ == '__main__':
......
......@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op)
server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4)
server.set_local_bin(
"~/github/Serving/build_server/core/general-server/serving")
server.load_model_config(sys.argv[1])
port = int(sys.argv[2])
......
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册