diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h
index 0e042d1880acd15151d3476d6061fb759f92e11b..5a941c54d614ef3859e9bb3fe293ae343a7a3f27 100644
--- a/core/general-client/include/general_model.h
+++ b/core/general-client/include/general_model.h
@@ -84,19 +84,14 @@ class PredictorClient {
               PredictorRes& predict_res,  // NOLINT
               const int& pid);
 
-  std::vector<std::vector<float>> predict(
-      const std::vector<std::vector<float>>& float_feed,
-      const std::vector<std::string>& float_feed_name,
-      const std::vector<std::vector<int64_t>>& int_feed,
-      const std::vector<std::string>& int_feed_name,
-      const std::vector<std::string>& fetch_name);
-
-  std::vector<std::vector<std::vector<float>>> batch_predict(
+  int batch_predict(
       const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
       const std::vector<std::string>& float_feed_name,
       const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
       const std::vector<std::string>& int_feed_name,
-      const std::vector<std::string>& fetch_name);
+      const std::vector<std::string>& fetch_name,
+      PredictorRes& predict_res_batch,  // NOLINT
+      const int& pid);
 
  private:
   PredictorApi _api;
diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
index 9fafe0abc59ce543cfc1783003296d68102d0c8d..d1ad58d462025c205eb669b3aa50864051d414ba 100644
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -264,26 +264,22 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   return 0;
 }
 
-std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
+int PredictorClient::batch_predict(
     const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
     const std::vector<std::string> &float_feed_name,
     const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
     const std::vector<std::string> &int_feed_name,
-    const std::vector<std::string> &fetch_name) {
+    const std::vector<std::string> &fetch_name,
+    PredictorRes &predict_res_batch,
+    const int &pid) {
   int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
 
-  std::vector<std::vector<std::vector<float>>> fetch_result_batch;
-  if (fetch_name.size() == 0) {
-    return fetch_result_batch;
-  }
+  predict_res_batch._int64_map.clear();
+  predict_res_batch._float_map.clear();
   Timer timeline;
   int64_t preprocess_start = timeline.TimeStampUS();
-  fetch_result_batch.resize(batch_size);
   int fetch_name_num = fetch_name.size();
-  for (int bi = 0; bi < batch_size; bi++) {
-    fetch_result_batch[bi].resize(fetch_name_num);
-  }
 
   _api.thrd_clear();
   _predictor = _api.fetch_predictor("general_model");
@@ -371,20 +367,36 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   } else {
     client_infer_end = timeline.TimeStampUS();
     postprocess_start = client_infer_end;
-
+    for (auto &name : fetch_name) {
+      predict_res_batch._int64_map[name].resize(batch_size);
+      predict_res_batch._float_map[name].resize(batch_size);
+    }
     for (int bi = 0; bi < batch_size; bi++) {
       for (auto &name : fetch_name) {
         int idx = _fetch_name_to_idx[name];
         int len = res.insts(bi).tensor_array(idx).data_size();
-        VLOG(2) << "fetch name: " << name;
-        VLOG(2) << "tensor data size: " << len;
-        fetch_result_batch[bi][idx].resize(len);
-        VLOG(2)
-            << "fetch name " << name << " index " << idx << " first data "
-            << *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
-        /*
-        TBA
-        */
+        if (_fetch_name_to_type[name] == 0) {
+          int len = res.insts(bi).tensor_array(idx).int64_data_size();
+          VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
+          predict_res_batch._int64_map[name][bi].resize(len);
+          VLOG(2) << "fetch name " << name << " index " << idx << " first data "
+                  << res.insts(bi).tensor_array(idx).int64_data(0);
+          for (int i = 0; i < len; ++i) {
+            predict_res_batch._int64_map[name][bi][i] =
+                res.insts(bi).tensor_array(idx).int64_data(i);
+          }
+        } else if (_fetch_name_to_type[name] == 1) {
+          int len = res.insts(bi).tensor_array(idx).float_data_size();
+          VLOG(2) << "fetch tensor : " << name
+                  << " type: float32 len : " << len;
+          predict_res_batch._float_map[name][bi].resize(len);
+          VLOG(2) << "fetch name " << name << " index " << idx << " first data "
+                  << res.insts(bi).tensor_array(idx).float_data(0);
+          for (int i = 0; i < len; ++i) {
+            predict_res_batch._float_map[name][bi][i] =
+                res.insts(bi).tensor_array(idx).float_data(i);
+          }
+        }
       }
     }
     postprocess_end = timeline.TimeStampUS();
@@ -393,6 +405,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   if (FLAGS_profile_client) {
     std::ostringstream oss;
     oss << "PROFILE\t"
+        << "pid:" << pid << "\t"
         << "prepro_0:" << preprocess_start << " "
         << "prepro_1:" << preprocess_end << " "
         << "client_infer_0:" << client_infer_start << " "
@@ -411,7 +424,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
     fprintf(stderr, "%s\n", oss.str().c_str());
   }
 
-  return fetch_result_batch;
+  return 0;
 }
 
 }  // namespace general_model
diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp
index 287fa2dcb55344a6e71bf8e76171de5f94e89de5..0d0ca7bd8fab7785e96ad11e43b109bf118d5c52 100644
--- a/core/general-client/src/pybind_general_model.cpp
+++ b/core/general-client/src/pybind_general_model.cpp
@@ -90,12 +90,16 @@ PYBIND11_MODULE(serving_client, m) {
              const std::vector<std::vector<std::vector<int64_t>>>
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
-             const std::vector<std::string> &fetch_name) {
+             const std::vector<std::string> &fetch_name,
+             PredictorRes &predict_res_batch,
+             const int &pid) {
            return self.batch_predict(float_feed_batch,
                                      float_feed_name,
                                      int_feed_batch,
                                      int_feed_name,
-                                     fetch_name);
+                                     fetch_name,
+                                     predict_res_batch,
+                                     pid);
          });
 }
diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py
index bea80f84bc9d29cabe4f31af612c694980b71d09..c77034b5ce9c811d3a5b0c42701b191870ffe45c 100644
--- a/python/paddle_serving_client/__init__.py
+++ b/python/paddle_serving_client/__init__.py
@@ -199,6 +199,7 @@ class Client(object):
         float_feed_names = []
         fetch_names = []
         counter = 0
+        batch_size = len(feed_batch)
         for feed in feed_batch:
             int_slot = []
             float_slot = []
@@ -221,15 +222,21 @@ class Client(object):
             if key in self.fetch_names_:
                 fetch_names.append(key)
 
-        result_batch = self.client_handle_.batch_predict(
+        result_batch = self.result_handle_
+        res = self.client_handle_.batch_predict(
             float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
-            fetch_names)
+            fetch_names, result_batch, self.pid)
 
         result_map_batch = []
-        for result in result_batch:
+        for index in range(batch_size):
             result_map = {}
             for i, name in enumerate(fetch_names):
-                result_map[name] = result[i]
+                if self.fetch_names_to_type_[name] == int_type:
+                    result_map[name] = result_batch.get_int64_by_name(name)[
+                        index]
+                elif self.fetch_names_to_type_[name] == float_type:
+                    result_map[name] = result_batch.get_float_by_name(name)[
+                        index]
             result_map_batch.append(result_map)
         return result_map_batch
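
For context on the change above: batch_predict used to return the fetch results by value as a nested std::vector, which pybind11 converted into nested Python lists. The patch switches to an int status code plus a PredictorRes out-parameter (predict_res_batch) that the Python Client reads back through typed accessors, so int64 and float32 fetch variables each travel through their own map. A minimal usage sketch of the resulting Python-side call path follows; the public method name batch_predict, the config path, the endpoint, and the feed/fetch variable names are assumptions for illustration, not fixed by this diff. Only the result_handle_ / get_int64_by_name / get_float_by_name plumbing is taken from the patch.

    from paddle_serving_client import Client

    client = Client()
    client.load_client_config("serving_client_conf.prototxt")  # assumed config path
    client.connect(["127.0.0.1:9292"])  # assumed endpoint

    # Two feed instances in one batch; "words" and "prediction" are
    # placeholder variable names.
    feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]

    # Internally this now calls client_handle_.batch_predict(...,
    # result_handle_, pid) and reads each fetch variable back per instance
    # via get_int64_by_name / get_float_by_name, instead of receiving a
    # nested list of floats.
    result_map_batch = client.batch_predict(feed_batch=feed_batch,
                                            fetch=["prediction"])
    for result_map in result_map_batch:
        print(result_map["prediction"])

Filling a shared PredictorRes and returning a status code avoids an extra copy of the nested results at the C++/Python boundary and, unlike the old std::vector<std::vector<std::vector<float>>> return, lets a single call carry both int64 and float32 outputs.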