提交 355ce782 编写于 作者: B barrierye

release gil in func batch_predict && get lod and shape from c++ without copy

上级 573983e3
...@@ -78,12 +78,18 @@ class ModelRes { ...@@ -78,12 +78,18 @@ class ModelRes {
std::vector<float>&& get_float_by_name_with_rv(const std::string& name) { std::vector<float>&& get_float_by_name_with_rv(const std::string& name) {
return std::move(_float_value_map[name]); return std::move(_float_value_map[name]);
} }
const std::vector<int>& get_shape(const std::string& name) { const std::vector<int>& get_shape_by_name(const std::string& name) {
return _shape_map[name]; return _shape_map[name];
} }
const std::vector<int>& get_lod(const std::string& name) { std::vector<int>&& get_shape_by_name_with_rv(const std::string& name) {
return std::move(_shape_map[name]);
}
const std::vector<int>& get_lod_by_name(const std::string& name) {
return _lod_map[name]; return _lod_map[name];
} }
std::vector<int>&& get_lod_by_name_with_rv(const std::string& name) {
return std::move(_lod_map[name]);
}
void set_engine_name(const std::string& engine_name) { void set_engine_name(const std::string& engine_name) {
_engine_name = engine_name; _engine_name = engine_name;
} }
...@@ -139,13 +145,21 @@ class PredictorRes { ...@@ -139,13 +145,21 @@ class PredictorRes {
const std::string& name) { const std::string& name) {
return std::move(_models[model_idx].get_float_by_name_with_rv(name)); return std::move(_models[model_idx].get_float_by_name_with_rv(name));
} }
const std::vector<int>& get_shape(const int model_idx, const std::vector<int>& get_shape_by_name(const int model_idx,
const std::string& name) { const std::string& name) {
return _models[model_idx].get_shape(name); return _models[model_idx].get_shape_by_name(name);
}
const std::vector<int>&& get_shape_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_shape_by_name_with_rv(name));
} }
const std::vector<int>& get_lod(const int model_idx, const std::vector<int>& get_lod_by_name(const int model_idx,
const std::string& name) { const std::string& name) {
return _models[model_idx].get_lod(name); return _models[model_idx].get_lod_by_name(name);
}
const std::vector<int>&& get_lod_by_name_with_rv(const int model_idx,
const std::string& name) {
return std::move(_models[model_idx].get_lod_by_name_with_rv(name));
} }
void add_model_res(ModelRes&& res) { void add_model_res(ModelRes&& res) {
_engine_names.push_back(res.engine_name()); _engine_names.push_back(res.engine_name());
......
...@@ -51,14 +51,22 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -51,14 +51,22 @@ PYBIND11_MODULE(serving_client, m) {
}) })
.def("get_shape", .def("get_shape",
[](PredictorRes &self, int model_idx, std::string &name) { [](PredictorRes &self, int model_idx, std::string &name) {
return self.get_shape(model_idx, name); std::vector<int> *ptr = new std::vector<int>(
}, std::move(self.get_shape_by_name_with_rv(model_idx, name)));
py::return_value_policy::reference) auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<int> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("get_lod", .def("get_lod",
[](PredictorRes &self, int model_idx, std::string &name) { [](PredictorRes &self, int model_idx, std::string &name) {
return self.get_lod(model_idx, name); std::vector<int> *ptr = new std::vector<int>(
}, std::move(self.get_lod_by_name_with_rv(model_idx, name)));
py::return_value_policy::reference) auto capsule = py::capsule(ptr, [](void *p) {
delete reinterpret_cast<std::vector<int> *>(p);
});
return py::array(ptr->size(), ptr->data(), capsule);
})
.def("variant_tag", [](PredictorRes &self) { return self.variant_tag(); }) .def("variant_tag", [](PredictorRes &self) { return self.variant_tag(); })
.def("get_engine_names", .def("get_engine_names",
[](PredictorRes &self) { return self.get_engine_names(); }); [](PredictorRes &self) { return self.get_engine_names(); });
...@@ -109,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -109,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) {
fetch_name, fetch_name,
predict_res_batch, predict_res_batch,
pid); pid);
}) },
py::call_guard<py::gil_scoped_release>())
.def("numpy_predict", .def("numpy_predict",
[](PredictorClient &self, [](PredictorClient &self,
const std::vector<std::vector<py::array_t<float>>> const std::vector<std::vector<py::array_t<float>>>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册