提交 cf5c2436 编写于 作者: M MRXLT

fix bert demo

上级 6312429d
...@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() { ...@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() {
_api.thrd_initialize(); _api.thrd_initialize();
} }
int PredictorClient::predict( int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
const std::vector<std::vector<float>>& float_feed, const std::vector<std::string> &float_feed_name,
const std::vector<std::string>& float_feed_name, const std::vector<std::vector<int64_t>> &int_feed,
const std::vector<std::vector<int64_t>>& int_feed, const std::vector<std::string> &int_feed_name,
const std::vector<std::string>& int_feed_name, const std::vector<std::string> &fetch_name,
const std::vector<std::string>& fetch_name, PredictorRes &predict_res) { // NOLINT
PredictorRes & predict_res) { // NOLINT
predict_res._int64_map.clear(); predict_res._int64_map.clear();
predict_res._float_map.clear(); predict_res._float_map.clear();
Timer timeline; Timer timeline;
...@@ -218,6 +217,7 @@ int PredictorClient::predict( ...@@ -218,6 +217,7 @@ int PredictorClient::predict(
VLOG(2) << "fetch name: " << name; VLOG(2) << "fetch name: " << name;
if (_fetch_name_to_type[name] == 0) { if (_fetch_name_to_type[name] == 0) {
int len = res.insts(0).tensor_array(idx).int64_data_size(); int len = res.insts(0).tensor_array(idx).int64_data_size();
VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
predict_res._int64_map[name].resize(1); predict_res._int64_map[name].resize(1);
predict_res._int64_map[name][0].resize(len); predict_res._int64_map[name][0].resize(len);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
...@@ -226,6 +226,7 @@ int PredictorClient::predict( ...@@ -226,6 +226,7 @@ int PredictorClient::predict(
} }
} else if (_fetch_name_to_type[name] == 1) { } else if (_fetch_name_to_type[name] == 1) {
int len = res.insts(0).tensor_array(idx).float_data_size(); int len = res.insts(0).tensor_array(idx).float_data_size();
VLOG(2) << "fetch tensor : " << name << " type: float32 len : " << len;
predict_res._float_map[name].resize(1); predict_res._float_map[name].resize(1);
predict_res._float_map[name][0].resize(len); predict_res._float_map[name][0].resize(len);
for (int i = 0; i < len; ++i) { for (int i = 0; i < len; ++i) {
...@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict( ...@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
} }
VLOG(2) << "batch [" << bi << "] " VLOG(2) << "batch [" << bi << "] "
<< "itn feed value prepared"; << "int feed value prepared";
} }
int64_t preprocess_end = timeline.TimeStampUS(); int64_t preprocess_end = timeline.TimeStampUS();
......
...@@ -120,7 +120,6 @@ class BertService(): ...@@ -120,7 +120,6 @@ class BertService():
def test(): def test():
bc = BertService( bc = BertService(
model_name='bert_chinese_L-12_H-768_A-12', model_name='bert_chinese_L-12_H-768_A-12',
max_seq_len=20, max_seq_len=20,
...@@ -130,9 +129,13 @@ def test(): ...@@ -130,9 +129,13 @@ def test():
config_file = './serving_client_conf/serving_client_conf.prototxt' config_file = './serving_client_conf/serving_client_conf.prototxt'
fetch = ["pooled_output"] fetch = ["pooled_output"]
bc.load_client(config_file, server_addr) bc.load_client(config_file, server_addr)
batch_size = 4 batch_size = 1
batch = [] batch = []
for line in sys.stdin: for line in sys.stdin:
if batch_size == 1:
result = bc.run_general([[line.strip()]], fetch)
print(result)
else:
if len(batch) < batch_size: if len(batch) < batch_size:
batch.append([line.strip()]) batch.append([line.strip()])
else: else:
...@@ -140,6 +143,11 @@ def test(): ...@@ -140,6 +143,11 @@ def test():
batch = [] batch = []
for r in result: for r in result:
print(r) print(r)
if len(batch) > 0:
result = bc.run_batch_general(batch, fetch)
batch = []
for r in result:
print(r)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op) ...@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op)
server = Server() server = Server()
server.set_op_sequence(op_seq_maker.get_op_sequence()) server.set_op_sequence(op_seq_maker.get_op_sequence())
server.set_num_threads(4) server.set_num_threads(4)
server.set_local_bin(
"~/github/Serving/build_server/core/general-server/serving")
server.load_model_config(sys.argv[1]) server.load_model_config(sys.argv[1])
port = int(sys.argv[2]) port = int(sys.argv[2])
......
wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册