Commit 6ea67de5 authored by Dong Daxiang, committed by GitHub

Merge pull request #256 from MRXLT/general-server-batch

fix batch predict
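
This commit reworks the client-side batch path: PredictorClient::batch_predict now returns an int status code and writes its results into a PredictorRes out-parameter that keeps int64 and float32 fetch results in separate per-name maps, replacing the old interface that returned a std::vector<std::vector<std::vector<float>>> and coerced every fetch variable to float. The leftover vector-returning predict overload is removed, a pid argument is threaded through so client-side profiling logs can tag the calling process, and the pybind11 binding plus the Python Client are updated to match.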
@@ -84,19 +84,14 @@ class PredictorClient {
               PredictorRes& predict_res,  // NOLINT
               const int& pid);
 
-  std::vector<std::vector<float>> predict(
-      const std::vector<std::vector<float>>& float_feed,
-      const std::vector<std::string>& float_feed_name,
-      const std::vector<std::vector<int64_t>>& int_feed,
-      const std::vector<std::string>& int_feed_name,
-      const std::vector<std::string>& fetch_name);
-
-  std::vector<std::vector<std::vector<float>>> batch_predict(
+  int batch_predict(
       const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
       const std::vector<std::string>& float_feed_name,
       const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
       const std::vector<std::string>& int_feed_name,
-      const std::vector<std::string>& fetch_name);
+      const std::vector<std::string>& fetch_name,
+      PredictorRes& predict_res_batch,  // NOLINT
+      const int& pid);
 
  private:
   PredictorApi _api;
......
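
For orientation, here is a minimal sketch of the result holder the new signature writes into. Only the _int64_map/_float_map fields and the two getters used by the Python client further down are grounded in this diff; the container choice (std::map) and everything else are assumptions, since the real definition lives elsewhere in the repo.

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Sketch only: field and getter names are taken from this commit's code,
// the rest is assumed.
class PredictorRes {
 public:
  // fetch name -> [batch index][element index]
  std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
  std::map<std::string, std::vector<std::vector<float>>> _float_map;

  const std::vector<std::vector<int64_t>>& get_int64_by_name(
      const std::string& name) {
    return _int64_map[name];
  }
  const std::vector<std::vector<float>>& get_float_by_name(
      const std::string& name) {
    return _float_map[name];
  }
};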
@@ -264,26 +264,22 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   return 0;
 }
 
-std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
+int PredictorClient::batch_predict(
     const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
     const std::vector<std::string> &float_feed_name,
     const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
     const std::vector<std::string> &int_feed_name,
-    const std::vector<std::string> &fetch_name) {
+    const std::vector<std::string> &fetch_name,
+    PredictorRes &predict_res_batch,
+    const int &pid) {
   int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
-  std::vector<std::vector<std::vector<float>>> fetch_result_batch;
-  if (fetch_name.size() == 0) {
-    return fetch_result_batch;
-  }
+  predict_res_batch._int64_map.clear();
+  predict_res_batch._float_map.clear();
   Timer timeline;
   int64_t preprocess_start = timeline.TimeStampUS();
-  fetch_result_batch.resize(batch_size);
   int fetch_name_num = fetch_name.size();
-  for (int bi = 0; bi < batch_size; bi++) {
-    fetch_result_batch[bi].resize(fetch_name_num);
-  }
 
   _api.thrd_clear();
   _predictor = _api.fetch_predictor("general_model");
@@ -371,20 +367,36 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   } else {
     client_infer_end = timeline.TimeStampUS();
     postprocess_start = client_infer_end;
+    for (auto &name : fetch_name) {
+      predict_res_batch._int64_map[name].resize(batch_size);
+      predict_res_batch._float_map[name].resize(batch_size);
+    }
     for (int bi = 0; bi < batch_size; bi++) {
       for (auto &name : fetch_name) {
         int idx = _fetch_name_to_idx[name];
-        int len = res.insts(bi).tensor_array(idx).data_size();
-        VLOG(2) << "fetch name: " << name;
-        VLOG(2) << "tensor data size: " << len;
-        fetch_result_batch[bi][idx].resize(len);
-        VLOG(2)
-            << "fetch name " << name << " index " << idx << " first data "
-            << *(const float *)res.insts(bi).tensor_array(idx).data(0).c_str();
-        /*
-        TBA
-        */
+        if (_fetch_name_to_type[name] == 0) {
+          int len = res.insts(bi).tensor_array(idx).int64_data_size();
+          VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
+          predict_res_batch._int64_map[name][bi].resize(len);
+          VLOG(2) << "fetch name " << name << " index " << idx
+                  << " first data "
+                  << res.insts(bi).tensor_array(idx).int64_data(0);
+          for (int i = 0; i < len; ++i) {
+            predict_res_batch._int64_map[name][bi][i] =
+                res.insts(bi).tensor_array(idx).int64_data(i);
+          }
+        } else if (_fetch_name_to_type[name] == 1) {
+          int len = res.insts(bi).tensor_array(idx).float_data_size();
+          VLOG(2) << "fetch tensor : " << name
+                  << " type: float32 len : " << len;
+          predict_res_batch._float_map[name][bi].resize(len);
+          VLOG(2) << "fetch name " << name << " index " << idx
+                  << " first data "
+                  << res.insts(bi).tensor_array(idx).float_data(0);
+          for (int i = 0; i < len; ++i) {
+            predict_res_batch._float_map[name][bi][i] =
+                res.insts(bi).tensor_array(idx).float_data(i);
+          }
+        }
       }
     }
     postprocess_end = timeline.TimeStampUS();
@@ -393,6 +405,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
   if (FLAGS_profile_client) {
     std::ostringstream oss;
     oss << "PROFILE\t"
+        << "pid:" << pid << "\t"
         << "prepro_0:" << preprocess_start << " "
        << "prepro_1:" << preprocess_end << " "
        << "client_infer_0:" << client_infer_start << " "
@@ -411,7 +424,7 @@ std::vector<std::vector<std::vector<float>>> PredictorClient::batch_predict(
     fprintf(stderr, "%s\n", oss.str().c_str());
   }
 
-  return fetch_result_batch;
+  return 0;
 }
 
 }  // namespace general_model
......
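
A hypothetical caller showing the new calling convention: results come back through the PredictorRes out-parameter and the int return value is a status code (0 on success). The client object, the feed/fetch names ("words", "prediction"), and the pid value are illustrative assumptions, not taken from this commit.

// Assumes a connected PredictorClient named `client`.
PredictorRes res;
std::vector<std::vector<std::vector<float>>> float_feed_batch;  // no float inputs
std::vector<std::vector<std::vector<int64_t>>> int_feed_batch = {
    {{8, 233, 52}},   // batch item 0: one int64 feed slot
    {{550, 1000}}};   // batch item 1
int ret = client.batch_predict(float_feed_batch,
                               /*float_feed_name=*/{},
                               int_feed_batch,
                               /*int_feed_name=*/{"words"},
                               /*fetch_name=*/{"prediction"},
                               res,
                               /*pid=*/0);
if (ret == 0) {
  // res._float_map["prediction"][bi] holds batch item bi's float32 values.
}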
@@ -90,12 +90,16 @@ PYBIND11_MODULE(serving_client, m) {
              const std::vector<std::vector<std::vector<int64_t>>>
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
-             const std::vector<std::string> &fetch_name) {
+             const std::vector<std::string> &fetch_name,
+             PredictorRes &predict_res_batch,
+             const int &pid) {
            return self.batch_predict(float_feed_batch,
                                      float_feed_name,
                                      int_feed_batch,
                                      int_feed_name,
-                                     fetch_name);
+                                     fetch_name,
+                                     predict_res_batch,
+                                     pid);
          });
 }
......
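
Passing predict_res_batch by mutable reference only works if PredictorRes is itself exposed to Python, so the caller can construct one and read it back after the call. A hedged sketch of what that registration plausibly looks like; the actual py::class_ binding is not part of this hunk and may differ.

// Sketch: assumes <pybind11/pybind11.h> and <pybind11/stl.h> are included
// and this sits inside the PYBIND11_MODULE(serving_client, m) body.
py::class_<PredictorRes>(m, "PredictorRes")
    .def(py::init<>())
    .def("get_int64_by_name",
         [](PredictorRes &self, std::string &name) {
           return self.get_int64_by_name(name);  // copied into a Python list
         })
    .def("get_float_by_name",
         [](PredictorRes &self, std::string &name) {
           return self.get_float_by_name(name);
         });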
@@ -199,6 +199,7 @@ class Client(object):
         float_feed_names = []
         fetch_names = []
         counter = 0
+        batch_size = len(feed_batch)
         for feed in feed_batch:
             int_slot = []
             float_slot = []
@@ -221,15 +222,21 @@ class Client(object):
             if key in self.fetch_names_:
                 fetch_names.append(key)
 
-        result_batch = self.client_handle_.batch_predict(
+        result_batch = self.result_handle_
+        res = self.client_handle_.batch_predict(
             float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
-            fetch_names)
+            fetch_names, result_batch, self.pid)
 
         result_map_batch = []
-        for result in result_batch:
+        for index in range(batch_size):
             result_map = {}
             for i, name in enumerate(fetch_names):
-                result_map[name] = result[i]
+                if self.fetch_names_to_type_[name] == int_type:
+                    result_map[name] = result_batch.get_int64_by_name(name)[
+                        index]
+                elif self.fetch_names_to_type_[name] == float_type:
+                    result_map[name] = result_batch.get_float_by_name(name)[
+                        index]
             result_map_batch.append(result_map)
 
         return result_map_batch
......
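
Net effect on the Python side: Client.batch_predict still returns one dict per batch item, but each value is now read back from the shared PredictorRes handle with its native type — fetch variables typed int64 come through get_int64_by_name and float32 ones through get_float_by_name, indexed by batch position — instead of every result being a flat float list. This code path relies on int_type, float_type, self.pid, self.result_handle_, and self.fetch_names_to_type_ being initialized elsewhere in the class, presumably when the client config is loaded.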