diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 80f316aed7c8b814739e997c8808539b5322117f..6b3b65857341ba7fe811730c4f77fde86f91d181 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pybind/inference_api.h" +#include #include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" @@ -51,6 +53,81 @@ void BindAnalysisPredictor(py::module *m); #ifdef PADDLE_WITH_MKLDNN void BindMkldnnQuantizerConfig(py::module *m); #endif + +template +PaddleBuf PaddleBufCreate(py::array_t data) { + PaddleBuf buf(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); + return buf; +} + +template +void PaddleBufReset(PaddleBuf &buf, py::array_t data) { // NOLINT + buf.Resize(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); +} + +template +PaddleDType PaddleTensorGetDType(); + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::INT32; +} + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::INT64; +} + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::FLOAT32; +} + +template +PaddleTensor PaddleTensorCreate( + py::array_t data, const std::string name = "", + const std::vector> &lod = {}, bool copy = false) { + PaddleTensor tensor; + + if (copy) { + PaddleBuf buf(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); + tensor.data = std::move(buf); + } else { + tensor.data = PaddleBuf(data.mutable_data(), data.size() * sizeof(T)); + } + + tensor.dtype = PaddleTensorGetDType(); + tensor.name = name; + tensor.lod = lod; + tensor.shape.resize(data.ndim()); + std::copy_n(data.shape(), data.ndim(), tensor.shape.begin()); + + return tensor; +} + +py::array PaddleTensorGetData(PaddleTensor &tensor) { // NOLINT + py::dtype dt; + switch (tensor.dtype) { + case PaddleDType::INT32: + dt = py::dtype::of(); + break; + case PaddleDType::INT64: + dt = py::dtype::of(); + break; + case PaddleDType::FLOAT32: + dt = py::dtype::of(); + break; + default: + LOG(FATAL) << "unsupported dtype"; + } + return py::array(dt, {tensor.shape}, tensor.data.data()); +} } // namespace void BindInferenceApi(py::module *m) { @@ -89,23 +166,39 @@ void BindPaddleBuf(py::module *m) { std::memcpy(buf.data(), static_cast(data.data()), buf.length()); return buf; })) - .def(py::init([](std::vector &data) { - auto buf = PaddleBuf(data.size() * sizeof(int64_t)); - std::memcpy(buf.data(), static_cast(data.data()), buf.length()); - return buf; - })) + .def(py::init(&PaddleBufCreate)) + .def(py::init(&PaddleBufCreate)) + .def(py::init(&PaddleBufCreate)) .def("resize", &PaddleBuf::Resize) .def("reset", [](PaddleBuf &self, std::vector &data) { self.Resize(data.size() * sizeof(float)); std::memcpy(self.data(), data.data(), self.length()); }) - .def("reset", - [](PaddleBuf &self, std::vector &data) { - self.Resize(data.size() * sizeof(int64_t)); - std::memcpy(self.data(), data.data(), self.length()); - }) + .def("reset", &PaddleBufReset) + .def("reset", &PaddleBufReset) + .def("reset", &PaddleBufReset) .def("empty", &PaddleBuf::empty) + .def("tolist", + [](PaddleBuf &self, const std::string &dtype) -> py::list { + py::list l; + if (dtype == "int32") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(int32_t); + l = py::cast(std::vector(data, data + size)); + } else if (dtype == "int64") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(int64_t); + l = py::cast(std::vector(data, data + size)); + } else if (dtype == "float32") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(float); + l = py::cast(std::vector(data, data + size)); + } else { + LOG(FATAL) << "unsupported dtype"; + } + return l; + }) .def("float_data", [](PaddleBuf &self) -> std::vector { auto *data = static_cast(self.data()); @@ -127,6 +220,19 @@ void BindPaddleBuf(py::module *m) { void BindPaddleTensor(py::module *m) { py::class_(*m, "PaddleTensor") .def(py::init<>()) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = false) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = false) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = false) + .def("as_ndarray", &PaddleTensorGetData) .def_readwrite("name", &PaddleTensor::name) .def_readwrite("shape", &PaddleTensor::shape) .def_readwrite("data", &PaddleTensor::data)