diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h
index f7363d46afc4031acff437425181cdb7d3b61e55..3567fbdaef75adf6dbf759056c6b4c6d062d1ca9 100644
--- a/core/general-client/include/general_model.h
+++ b/core/general-client/include/general_model.h
@@ -17,6 +17,7 @@
 #include <fstream>
 #include <iostream>
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <vector>
 
@@ -58,13 +59,12 @@ class PredictorClient {
       const std::vector<std::string>& int_feed_name,
       const std::vector<std::string>& fetch_name);
 
-  std::vector<std::vector<std::vector<float>>> predict_for_batch(
+  std::vector<std::vector<std::vector<float>>> batch_predict(
       const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
       const std::vector<std::string>& float_feed_name,
       const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
       const std::vector<std::string>& int_feed_name,
-      const std::vector<std::string>& fetch_name,
-      const int64_t& batch_size);
+      const std::vector<std::string>& fetch_name);
 
   std::vector<std::vector<float>> predict_with_profile(
       const std::vector<std::vector<float>>& float_feed,
diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp
index 4c22fa48c05adb0b72216f6e5133e4be0927ca5f..a593117db992a76f9a223cc15a768c92601dc879 100644
--- a/core/general-client/src/general_model.cpp
+++ b/core/general-client/src/general_model.cpp
@@ -171,13 +171,13 @@ std::vector<std::vector<float>> PredictorClient::predict(
   return fetch_result;
 }
 
-std::vector<std::vector<std::vector<float>>> PredictorClient::predict_for_batch(
+std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
     const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
    const std::vector<std::string> &float_feed_name,
    const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
    const std::vector<std::string> &int_feed_name,
-    const std::vector<std::string> &fetch_name,
-    const int64_t &batch_size) {
+    const std::vector<std::string> &fetch_name) {
+  int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
   std::vector<std::vector<std::vector<float>>> fetch_result_batch;
   if (fetch_name.size() == 0) {
     return fetch_result_batch;
@@ -229,6 +229,8 @@
         tensor->add_shape(_shape[idx][j]);
       }
       tensor->set_elem_type(0);
+      VLOG(3) << "feed var name " << name << " index " << vec_idx
+              << "first data " << int_feed[vec_idx][0];
       for (int j = 0; j < int_feed[vec_idx].size(); ++j) {
         tensor->add_data(const_cast<char *>(reinterpret_cast<const char *>(
                              &(int_feed[vec_idx][j]))),
@@ -248,10 +250,13 @@
     for (int bi = 0; bi < batch_size; bi++) {
       for (auto &name : fetch_name) {
         int idx = _fetch_name_to_idx[name];
-        int len = res.insts(0).tensor_array(idx).data_size();
+        int len = res.insts(bi).tensor_array(idx).data_size();
         VLOG(3) << "fetch name: " << name;
         VLOG(3) << "tensor data size: " << len;
         fetch_result_batch[bi][idx].resize(len);
+        VLOG(3)
+            << "fetch name " << name << " index " << idx << " first data "
+            << *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
         for (int i = 0; i < len; ++i) {
           fetch_result_batch[bi][idx][i] =
               *(const float *)res.insts(bi).tensor_array(idx).data(i).c_str();
diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp
index ac82d91984027d912774244358105dafec30301c..caa88acbcdc514bdcf94fbea2ee9458105d7bbd7 100644
--- a/core/general-client/src/pybind_general_model.cpp
+++ b/core/general-client/src/pybind_general_model.cpp
@@ -57,7 +57,7 @@ PYBIND11_MODULE(serving_client, m) {
                                fetch_name);
            })
 
-      .def("predict_for_batch",
+      .def("batch_predict",
           [](PredictorClient &self,
              const std::vector<std::vector<std::vector<float>>>
                  &float_feed_batch,
@@ -65,14 +65,12 @@ PYBIND11_MODULE(serving_client, m) {
              const std::vector<std::vector<std::vector<int64_t>>>
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
-             const std::vector<std::string> &fetch_name,
-             const int64_t &batch_size) {
-            return self.predict_for_batch(float_feed_batch,
-                                          float_feed_name,
-                                          int_feed_batch,
-                                          int_feed_name,
-                                          fetch_name,
-                                          batch_size);
+             const std::vector<std::string> &fetch_name) {
+            return self.batch_predict(float_feed_batch,
+                                      float_feed_name,
+                                      int_feed_batch,
+                                      int_feed_name,
+                                      fetch_name);
           });
 }
 
diff --git a/python/examples/imdb/test_client_batch.py b/python/examples/imdb/test_client_batch.py
index 6686df272d62fe62315417470d2ce747a4b3c9ff..913d736d8cfe71ac394c6da1cf3ddbe712bf68bf 100644
--- a/python/examples/imdb/test_client_batch.py
+++ b/python/examples/imdb/test_client_batch.py
@@ -19,7 +19,7 @@ from multiprocessing import Pool
 import time
 
 
-def predict_for_batch(batch_size=4):
+def batch_predict(batch_size=4):
     client = Client()
     client.load_client_config(conf_file)
     client.connect(["127.0.0.1:8010"])
@@ -33,7 +33,7 @@
         fetch = ["acc", "cost", "prediction"]
         feed_batch.append(feed)
         if len(feed_batch) == batch_size:
-            fetch_batch = client.predict_for_batch(
+            fetch_batch = client.batch_predict(
                 feed_batch=feed_batch, fetch=fetch)
             for i in range(batch_size):
                 print("{} {}".format(fetch_batch[i]["prediction"][1],
@@ -47,4 +47,4 @@
 if __name__ == '__main__':
     conf_file = sys.argv[1]
     batch_size = int(sys.argv[2])
-    predict_for_batch(batch_size)
+    batch_predict(batch_size)
diff --git a/python/paddle_serving/serving_client/__init__.py b/python/paddle_serving/serving_client/__init__.py
index 200b99e428768806b5d0fcbe6da608d912218158..e21b5c0bdd74883d050f50275e16b2cbedf712f0 100644
--- a/python/paddle_serving/serving_client/__init__.py
+++ b/python/paddle_serving/serving_client/__init__.py
@@ -154,8 +154,7 @@
 
         return result_map
 
-    def predict_for_batch(self, feed_batch=[], fetch=[]):
-        batch_size = len(feed_batch)
+    def batch_predict(self, feed_batch=[], fetch=[]):
         int_slot_batch = []
         float_slot_batch = []
         int_feed_names = []
@@ -184,9 +183,9 @@
             if key in self.fetch_names_:
                 fetch_names.append(key)
 
-        result_batch = self.client_handle_.predict_for_batch(
+        result_batch = self.client_handle_.batch_predict(
             float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
-            fetch_names, batch_size)
+            fetch_names)
 
         result_map_batch = []
         for result in result_batch:
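
Reviewer note: with this change, `batch_predict` derives the batch size on the C++ side as `std::max(float_feed_batch.size(), int_feed_batch.size())`, so Python callers no longer pass `batch_size` explicitly. A minimal usage sketch of the renamed API follows. It assumes a serving instance already listening on `127.0.0.1:8010` and an IMDB-style model whose client config exposes a feed var `words` and a fetch var `prediction`; the config path and feed values are illustrative, not taken from this diff.

```python
from paddle_serving.serving_client import Client

client = Client()
# Illustrative path: point this at the client-side prototxt
# generated for your model.
client.load_client_config("imdb_client_conf/serving_client_conf.prototxt")
client.connect(["127.0.0.1:8010"])

# One feed dict per sample. batch_size is no longer an argument;
# the C++ client infers it from the length of the feed batch itself.
feed_batch = [
    {"words": [8, 233, 52, 601]},   # assumed feed var name from the IMDB example
    {"words": [7, 89, 21, 5, 66]},
]
fetch_batch = client.batch_predict(feed_batch=feed_batch, fetch=["prediction"])
for i, result in enumerate(fetch_batch):
    print("sample {}: prediction {}".format(i, result["prediction"]))
```

The dropped parameter was redundant: the old Python wrapper already computed it as `len(feed_batch)`, and removing it from the signature means a caller can no longer pass a value inconsistent with the feed batch actually supplied.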