Commit 8970af96 authored by MRXLT

reduce memory copy

Parent 4c0f6820
@@ -59,6 +59,18 @@ class PredictorRes {
   std::map<std::string, std::vector<std::vector<float>>> _float_map;
 };
 
+class PredictorResBatch {
+ public:
+  PredictorResBatch() {}
+  ~PredictorResBatch() {}
+
+ public:
+  const PredictorRes& at(const int index) { return _predictres_vector[index]; }
+
+ public:
+  std::vector<PredictorRes> _predictres_vector;
+};
+
 class PredictorClient {
  public:
   PredictorClient() {}
@@ -91,6 +103,15 @@ class PredictorClient {
       const std::vector<std::string>& int_feed_name,
       const std::vector<std::string>& fetch_name);
 
+  int batch_predict(
+      const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
+      const std::vector<std::string>& float_feed_name,
+      const std::vector<std::vector<std::vector<int64_t>>>& int_feed_batch,
+      const std::vector<std::string>& int_feed_name,
+      const std::vector<std::string>& fetch_name,
+      PredictorResBatch& predict_res,  // NOLINT
+      const int& pid);
+
   std::vector<PredictorRes> batch_predict(
       const std::vector<std::vector<std::vector<float>>>& float_feed_batch,
       const std::vector<std::string>& float_feed_name,
......
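The interface change above is the core of this commit: instead of building a `std::vector<PredictorRes>` inside `batch_predict` and returning it by value, the new overload returns an `int` status code and writes results into a caller-owned `PredictorResBatch`. Through the `serving_client` pybind11 module shown later in this diff, the call pattern looks roughly like the sketch below; the import path, feed data, feed/fetch names, and pid are placeholders, not values from this commit.

```python
# A minimal sketch of the new out-parameter call pattern, assuming the
# serving_client extension is importable from the paddle_serving_client
# package (as the relative import in the Python client below suggests).
from paddle_serving_client.serving_client import PredictorClient, PredictorResBatch

client = PredictorClient()
# client.init(config_path) and endpoint setup omitted; unchanged by this commit.

# Placeholder feeds: one sample, one float slot, one int64 slot.
float_feed_batch = [[[0.1, 0.2, 0.3]]]
int_feed_batch = [[[1, 2, 3]]]

result_batch = PredictorResBatch()      # caller-owned result buffer
status = client.batch_predict(
    float_feed_batch, ["float_slot"],   # hypothetical float feed name
    int_feed_batch, ["int_slot"],       # hypothetical int feed name
    ["score"],                          # hypothetical fetch name
    result_batch,                       # filled in place by the C++ side
    0)                                  # pid, as the Python client passes self.pid
```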
@@ -264,22 +264,20 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   return 0;
 }
 
-std::vector<PredictorRes> PredictorClient::batch_predict(
+int PredictorClient::batch_predict(
     const std::vector<std::vector<std::vector<float>>> &float_feed_batch,
     const std::vector<std::string> &float_feed_name,
    const std::vector<std::vector<std::vector<int64_t>>> &int_feed_batch,
    const std::vector<std::string> &int_feed_name,
    const std::vector<std::string> &fetch_name,
+    PredictorResBatch &predict_res_batch,
    const int &pid) {
  int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
-  std::vector<std::vector<std::vector<float>>> fetch_result_batch;
-  std::vector<PredictorRes> predict_res_batch;
 
  Timer timeline;
  int64_t preprocess_start = timeline.TimeStampUS();
-  predict_res_batch.resize(batch_size);
+  predict_res_batch._predictres_vector.resize(batch_size);
  int fetch_name_num = fetch_name.size();
 
  _api.thrd_clear();
@@ -370,8 +368,8 @@ std::vector<PredictorRes> PredictorClient::batch_predict(
   postprocess_start = client_infer_end;
 
   for (int bi = 0; bi < batch_size; bi++) {
-    predict_res_batch[bi]._int64_map.clear();
-    predict_res_batch[bi]._float_map.clear();
+    predict_res_batch._predictres_vector[bi]._int64_map.clear();
+    predict_res_batch._predictres_vector[bi]._float_map.clear();
     for (auto &name : fetch_name) {
       int idx = _fetch_name_to_idx[name];
@@ -379,24 +377,26 @@ std::vector<PredictorRes> PredictorClient::batch_predict(
       if (_fetch_name_to_type[name] == 0) {
         int len = res.insts(bi).tensor_array(idx).int64_data_size();
         VLOG(2) << "fetch tensor : " << name << " type: int64 len : " << len;
-        predict_res_batch[bi]._int64_map[name].resize(1);
-        predict_res_batch[bi]._int64_map[name][0].resize(len);
+        predict_res_batch._predictres_vector[bi]._int64_map[name].resize(1);
+        predict_res_batch._predictres_vector[bi]._int64_map[name]
+            [0].resize(len);
         VLOG(2) << "fetch name " << name << " index " << idx << " first data "
                 << res.insts(bi).tensor_array(idx).int64_data(0);
         for (int i = 0; i < len; ++i) {
-          predict_res_batch[bi]._int64_map[name][0][i] =
+          predict_res_batch._predictres_vector[bi]._int64_map[name][0][i] =
               res.insts(bi).tensor_array(idx).int64_data(i);
         }
       } else if (_fetch_name_to_type[name] == 1) {
         int len = res.insts(bi).tensor_array(idx).float_data_size();
         VLOG(2) << "fetch tensor : " << name
                 << " type: float32 len : " << len;
-        predict_res_batch[bi]._float_map[name].resize(1);
-        predict_res_batch[bi]._float_map[name][0].resize(len);
+        predict_res_batch._predictres_vector[bi]._float_map[name].resize(1);
+        predict_res_batch._predictres_vector[bi]._float_map[name]
+            [0].resize(len);
         VLOG(2) << "fetch name " << name << " index " << idx << " first data "
                 << res.insts(bi).tensor_array(idx).float_data(0);
         for (int i = 0; i < len; ++i) {
-          predict_res_batch[bi]._float_map[name][0][i] =
+          predict_res_batch._predictres_vector[bi]._float_map[name][0][i] =
               res.insts(bi).tensor_array(idx).float_data(i);
         }
       }
@@ -427,7 +427,7 @@ std::vector<PredictorRes> PredictorClient::batch_predict(
     fprintf(stderr, "%s\n", oss.str().c_str());
   }
 
-  return predict_res_batch;
+  return 0;
 }
 
 }  // namespace general_model
......
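This is where the copy reduction happens. The old definition accumulated results in a function-local `std::vector<PredictorRes>` and returned it by value, so pybind11 converted that vector into a Python list on every call, copying each `PredictorRes` and all of its fetched tensors into Python-owned objects. The new definition deserializes the RPC response directly into the `PredictorResBatch` the caller passed in, and only an `int` status crosses the binding. Reading the results back out looks roughly like this sketch, continuing the placeholder names from the sketch after the header diff:

```python
# PredictorResBatch.at(bi) returns the results for sample bi;
# get_int64_by_name / get_float_by_name are the PredictorRes accessors
# used by the Python client at the end of this diff.
batch_size = len(float_feed_batch)
for bi in range(batch_size):
    result = result_batch.at(bi)                   # per-sample view, no copy intended
    scores = result.get_float_by_name("score")[0]  # "score" is the placeholder fetch name
    print(bi, scores)
```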
@@ -41,6 +41,12 @@ PYBIND11_MODULE(serving_client, m) {
           },
           py::return_value_policy::reference);
 
+  py::class_<PredictorResBatch>(m, "PredictorResBatch", py::buffer_protocol())
+      .def(py::init())
+      .def("at",
+           [](PredictorResBatch &self, int index) { return self.at(index); },
+           py::return_value_policy::reference);
+
   py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
       .def(py::init())
       .def("init_gflags",
@@ -91,12 +97,14 @@ PYBIND11_MODULE(serving_client, m) {
                  &int_feed_batch,
              const std::vector<std::string> &int_feed_name,
              const std::vector<std::string> &fetch_name,
+              PredictorResBatch &predict_res_batch,
              const int &pid) {
             return self.batch_predict(float_feed_batch,
                                       float_feed_name,
                                       int_feed_batch,
                                       int_feed_name,
                                       fetch_name,
+                                       predict_res_batch,
                                       pid);
           });
 }
......
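Binding `at` with `py::return_value_policy::reference` is what keeps the Python side copy-free, and it implies a lifetime rule: the `PredictorRes` handed back borrows memory owned by the `PredictorResBatch`, and pybind11 does not keep the batch alive on its behalf (that would take something like `py::keep_alive<0, 1>()`). One editorial observation, not a change made by this commit: because the lambda has no explicit `-> const PredictorRes &` return type, C++ return-type deduction makes it return `PredictorRes` by value, which copies once before the reference policy can take effect; binding `&PredictorResBatch::at` directly, or adding the explicit reference return type, would avoid that. The safe usage pattern from Python is sketched below.

```python
# Lifetime rule implied by return_value_policy::reference (a sketch,
# reusing result_batch from the earlier sketches).
res0 = result_batch.at(0)   # res0 is meant to be a view into result_batch
value = res0.get_float_by_name("score")[0]
# Keep result_batch alive for as long as res0 is in use; dropping the
# batch first would leave a reference-policy view dangling.
```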
@@ -89,6 +89,7 @@ class Client(object):
     def load_client_config(self, path):
         from .serving_client import PredictorClient
         from .serving_client import PredictorRes
+        from .serving_client import PredictorResBatch
         model_conf = m_config.GeneralModelConfig()
         f = open(path, 'r')
         model_conf = google.protobuf.text_format.Merge(
@@ -99,6 +100,7 @@ class Client(object):
         # get feed shapes, feed types
         # map feed names to index
         self.result_handle_ = PredictorRes()
+        self.result_batch_handle_ = PredictorResBatch()
         self.client_handle_ = PredictorClient()
         self.client_handle_.init(path)
         read_env_flags = ["profile_client", "profile_server"]
@@ -180,6 +182,7 @@ class Client(object):
         float_feed_names = []
         fetch_names = []
         counter = 0
+        batch_size = len(feed_batch)
         for feed in feed_batch:
             int_slot = []
             float_slot = []
@@ -202,19 +205,21 @@ class Client(object):
             if key in self.fetch_names_:
                 fetch_names.append(key)
 
-        result_batch = self.client_handle_.batch_predict(
+        result_batch = self.result_batch_handle_
+        res = self.client_handle_.batch_predict(
             float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
-            fetch_names, self.pid)
+            fetch_names, result_batch, self.pid)
 
         result_map_batch = []
-        for result in result_batch:
+        for index in range(batch_size):
+            result = result_batch.at(index)
             result_map = {}
             for i, name in enumerate(fetch_names):
                 if self.fetch_names_to_type_[name] == int_type:
                     result_map[name] = result.get_int64_by_name(name)[0]
                 elif self.fetch_names_to_type_[name] == float_type:
                     result_map[name] = result.get_float_by_name(name)[0]
-            result_map_batch.appenf(result_map)
+            result_map_batch.append(result_map)
 
         return result_map_batch
......
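At the user level nothing changes in how `Client.batch_predict` is called; the out-parameter shuffle stays internal, and callers still receive one result dict per input sample, as the loop above builds (note the commit also fixes the `appenf` typo in that loop). A minimal end-to-end sketch, assuming a served model whose client config defines a feed var `words` and a fetch var `prediction` (both hypothetical), placeholder config path and endpoint, and the `batch_predict(feed_batch=..., fetch=...)` keyword signature of this client version:

```python
from paddle_serving_client import Client  # assumed package name for this client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.connect(["127.0.0.1:9292"])                         # placeholder endpoint

# Two samples in a single RPC.
feed_batch = [{"words": [1, 2, 3]}, {"words": [4, 5, 6]}]
results = client.batch_predict(feed_batch=feed_batch, fetch=["prediction"])

for result_map in results:  # one dict per sample
    print(result_map["prediction"])
```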