diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
index 1dcf261f96769f9c33ddcbe792eda9a6057c8502..b0bc5c6d914a3de2797ce069855d7c62ed58d3c3 100644
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -132,13 +132,12 @@ int PredictorClient::create_predictor() {
   _api.thrd_initialize();
 }
 
-int PredictorClient::predict(
-    const std::vector<std::vector<float>>& float_feed,
-    const std::vector<std::string>& float_feed_name,
-    const std::vector<std::vector<int64_t>>& int_feed,
-    const std::vector<std::string>& int_feed_name,
-    const std::vector<std::string>& fetch_name,
-    PredictorRes & predict_res) {  // NOLINT
+int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
+                             const std::vector<std::string> &float_feed_name,
+                             const std::vector<std::vector<int64_t>> &int_feed,
+                             const std::vector<std::string> &int_feed_name,
+                             const std::vector<std::string> &fetch_name,
+                             PredictorRes &predict_res) {  // NOLINT
   predict_res._int64_map.clear();
   predict_res._float_map.clear();
   Timer timeline;
@@ -218,6 +217,7 @@ int PredictorClient::predict(
       VLOG(2) << "fetch name: " << name;
       if (_fetch_name_to_type[name] == 0) {
         int len = res.insts(0).tensor_array(idx).int64_data_size();
+        VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
         predict_res._int64_map[name].resize(1);
         predict_res._int64_map[name][0].resize(len);
         for (int i = 0; i < len; ++i) {
@@ -226,6 +226,7 @@ int PredictorClient::predict(
         }
       } else if (_fetch_name_to_type[name] == 1) {
         int len = res.insts(0).tensor_array(idx).float_data_size();
+        VLOG(2) << "fetch tensor : " << name << " type: float32 len : " << len;
         predict_res._float_map[name].resize(1);
         predict_res._float_map[name][0].resize(len);
         for (int i = 0; i < len; ++i) {
@@ -244,7 +245,7 @@ int PredictorClient::predict(
         << "prepro_1:" << preprocess_end << " "
         << "client_infer_0:" << client_infer_start << " "
         << "client_infer_1:" << client_infer_end << " ";
-
+
     if (FLAGS_profile_server) {
       int op_num = res.profile_time_size() / 2;
       for (int i = 0; i < op_num; ++i) {
@@ -252,10 +253,10 @@ int PredictorClient::predict(
         oss << "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " ";
       }
     }
-
+
     oss << "postpro_0:" << postprocess_start << " ";
     oss << "postpro_1:" << postprocess_end;
-
+
     fprintf(stderr, "%s\n", oss.str().c_str());
   }
   return 0;
@@ -342,7 +343,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
     }
 
     VLOG(2) << "batch [" << bi << "] "
-            << "itn feed value prepared";
+            << "int feed value prepared";
   }
 
   int64_t preprocess_end = timeline.TimeStampUS();
diff --git a/python/examples/bert/bert_client.py b/python/examples/bert/bert_client.py
index 452a18024b69ee330d63bb0dc1d5e8a450e51883..a73340a2244eccc27e2a8b880503dbd2ead2f0ee 100644
--- a/python/examples/bert/bert_client.py
+++ b/python/examples/bert/bert_client.py
@@ -120,7 +120,6 @@ class BertService():
 
 
 def test():
-
     bc = BertService(
         model_name='bert_chinese_L-12_H-768_A-12',
         max_seq_len=20,
@@ -130,16 +129,25 @@ def test():
     config_file = './serving_client_conf/serving_client_conf.prototxt'
     fetch = ["pooled_output"]
     bc.load_client(config_file, server_addr)
-    batch_size = 4
+    batch_size = 1
     batch = []
     for line in sys.stdin:
-        if len(batch) < batch_size:
-            batch.append([line.strip()])
+        if batch_size == 1:
+            result = bc.run_general([[line.strip()]], fetch)
+            print(result)
         else:
-            result = bc.run_batch_general(batch, fetch)
-            batch = []
-            for r in result:
-                print(r)
+            if len(batch) < batch_size:
+                batch.append([line.strip()])
+            else:
+                result = bc.run_batch_general(batch, fetch)
+                batch = []
+                for r in result:
+                    print(r)
+    if len(batch) > 0:
+        result = bc.run_batch_general(batch, fetch)
+        batch = []
+        for r in result:
+            print(r)
 
 
 if __name__ == '__main__':
diff --git a/python/examples/bert/bert_server.py b/python/examples/bert/bert_server.py
index 52b74b4622cfa3add6ad41678339924e3f9c3b0c..35d38be0cac50b899b58085c7f103f32537859c4 100644
--- a/python/examples/bert/bert_server.py
+++ b/python/examples/bert/bert_server.py
@@ -31,8 +31,6 @@ op_seq_maker.add_op(general_response_op)
 server = Server()
 server.set_op_sequence(op_seq_maker.get_op_sequence())
 server.set_num_threads(4)
-server.set_local_bin(
-    "~/github/Serving/build_server/core/general-server/serving")
 
 server.load_model_config(sys.argv[1])
 port = int(sys.argv[2])
diff --git a/python/examples/bert/get_data.sh b/python/examples/bert/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..37174e725e22d4ae1ea000723a9e8f1a026b017d
--- /dev/null
+++ b/python/examples/bert/get_data.sh
@@ -0,0 +1 @@
+wget https://paddle-serving.bj.bcebos.com/bert_example/data-c.txt --no-check-certificate