From 5df0991f2b28e16ba37001022af6976d6161d840 Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Wed, 18 Sep 2019 10:29:44 +0800 Subject: [PATCH] Cherry-pick #18817 and #19353. Python inference api update and add unittest (#19831) * python inference enable_memory_optim(#18817) python inference API support enable_memory_optim * Python infer api update and add unit test (#19353) * python inference api supports numpy and add unit test, fix unit test fail in test_slim_int8_googlenet and test_slim_int8_mobilenet --- paddle/fluid/pybind/inference_api.cc | 153 +++++++++++++++--- .../mkldnn_post_training_strategy.py | 17 +- .../tests/unittests/test_inference_api.py | 74 +++++++++ 3 files changed, 212 insertions(+), 32 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_inference_api.py diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index ae7fcad7847..812fa9db1af 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pybind/inference_api.h" +#include #include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include "paddle/fluid/inference/api/analysis_predictor.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" @@ -37,20 +39,97 @@ using paddle::NativeConfig; using paddle::NativePaddlePredictor; using paddle::AnalysisPredictor; -static void BindPaddleDType(py::module *m); -static void BindPaddleBuf(py::module *m); -static void BindPaddleTensor(py::module *m); -static void BindPaddlePlace(py::module *m); -static void BindPaddlePredictor(py::module *m); -static void BindNativeConfig(py::module *m); -static void BindNativePredictor(py::module *m); -static void BindAnalysisConfig(py::module *m); -static void BindAnalysisPredictor(py::module *m); +namespace { +void BindPaddleDType(py::module *m); +void BindPaddleBuf(py::module *m); +void BindPaddleTensor(py::module *m); +void BindPaddlePlace(py::module *m); +void BindPaddlePredictor(py::module *m); +void BindNativeConfig(py::module *m); +void BindNativePredictor(py::module *m); +void BindAnalysisConfig(py::module *m); +void BindAnalysisPredictor(py::module *m); #ifdef PADDLE_WITH_MKLDNN -static void BindMkldnnQuantizerConfig(py::module *m); +void BindMkldnnQuantizerConfig(py::module *m); #endif +template +PaddleBuf PaddleBufCreate(py::array_t data) { + PaddleBuf buf(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); + return buf; +} + +template +void PaddleBufReset(PaddleBuf &buf, py::array_t data) { // NOLINT + buf.Resize(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); +} + +template +PaddleDType PaddleTensorGetDType(); + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::INT32; +} + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::INT64; +} + +template <> +PaddleDType PaddleTensorGetDType() { + return PaddleDType::FLOAT32; +} + +template +PaddleTensor PaddleTensorCreate( + py::array_t data, const std::string name = "", + const std::vector> &lod = {}, bool copy = true) { + PaddleTensor tensor; + + if (copy) { + PaddleBuf buf(data.size() * sizeof(T)); + std::copy_n(static_cast(data.mutable_data()), data.size(), + static_cast(buf.data())); + tensor.data = std::move(buf); + } else { + tensor.data = PaddleBuf(data.mutable_data(), data.size() * sizeof(T)); + } + + tensor.dtype = PaddleTensorGetDType(); + tensor.name = name; + tensor.lod = lod; + tensor.shape.resize(data.ndim()); + std::copy_n(data.shape(), data.ndim(), tensor.shape.begin()); + + return tensor; +} + +py::array PaddleTensorGetData(PaddleTensor &tensor) { // NOLINT + py::dtype dt; + switch (tensor.dtype) { + case PaddleDType::INT32: + dt = py::dtype::of(); + break; + case PaddleDType::INT64: + dt = py::dtype::of(); + break; + case PaddleDType::FLOAT32: + dt = py::dtype::of(); + break; + default: + LOG(FATAL) << "unsupported dtype"; + } + return py::array(dt, {tensor.shape}, tensor.data.data()); +} +} // namespace + void BindInferenceApi(py::module *m) { BindPaddleDType(m); BindPaddleBuf(m); @@ -71,6 +150,7 @@ void BindInferenceApi(py::module *m) { m->def("paddle_dtype_size", &paddle::PaddleDtypeSize); } +namespace { void BindPaddleDType(py::module *m) { py::enum_(*m, "PaddleDType") .value("FLOAT32", PaddleDType::FLOAT32) @@ -86,23 +166,39 @@ void BindPaddleBuf(py::module *m) { std::memcpy(buf.data(), static_cast(data.data()), buf.length()); return buf; })) - .def(py::init([](std::vector &data) { - auto buf = PaddleBuf(data.size() * sizeof(int64_t)); - std::memcpy(buf.data(), static_cast(data.data()), buf.length()); - return buf; - })) + .def(py::init(&PaddleBufCreate)) + .def(py::init(&PaddleBufCreate)) + .def(py::init(&PaddleBufCreate)) .def("resize", &PaddleBuf::Resize) .def("reset", [](PaddleBuf &self, std::vector &data) { self.Resize(data.size() * sizeof(float)); std::memcpy(self.data(), data.data(), self.length()); }) - .def("reset", - [](PaddleBuf &self, std::vector &data) { - self.Resize(data.size() * sizeof(int64_t)); - std::memcpy(self.data(), data.data(), self.length()); - }) + .def("reset", &PaddleBufReset) + .def("reset", &PaddleBufReset) + .def("reset", &PaddleBufReset) .def("empty", &PaddleBuf::empty) + .def("tolist", + [](PaddleBuf &self, const std::string &dtype) -> py::list { + py::list l; + if (dtype == "int32") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(int32_t); + l = py::cast(std::vector(data, data + size)); + } else if (dtype == "int64") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(int64_t); + l = py::cast(std::vector(data, data + size)); + } else if (dtype == "float32") { + auto *data = static_cast(self.data()); + auto size = self.length() / sizeof(float); + l = py::cast(std::vector(data, data + size)); + } else { + LOG(FATAL) << "unsupported dtype"; + } + return l; + }) .def("float_data", [](PaddleBuf &self) -> std::vector { auto *data = static_cast(self.data()); @@ -124,6 +220,19 @@ void BindPaddleBuf(py::module *m) { void BindPaddleTensor(py::module *m) { py::class_(*m, "PaddleTensor") .def(py::init<>()) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = true) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = true) + .def(py::init(&PaddleTensorCreate), py::arg("data"), + py::arg("name") = "", + py::arg("lod") = std::vector>(), + py::arg("copy") = true) + .def("as_ndarray", &PaddleTensorGetData) .def_readwrite("name", &PaddleTensor::name) .def_readwrite("shape", &PaddleTensor::shape) .def_readwrite("data", &PaddleTensor::data) @@ -227,6 +336,8 @@ void BindAnalysisConfig(py::module *m) { .def("switch_ir_optim", &AnalysisConfig::SwitchIrOptim, py::arg("x") = true) .def("ir_optim", &AnalysisConfig::ir_optim) + .def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim) + .def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir) .def("switch_use_feed_fetch_ops", &AnalysisConfig::SwitchUseFeedFetchOps, py::arg("x") = true) .def("use_feed_fetch_ops_enabled", @@ -312,6 +423,6 @@ void BindAnalysisPredictor(py::module *m) { .def("SaveOptimModel", &AnalysisPredictor::SaveOptimModel, py::arg("dir")); } - +} // namespace } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/contrib/slim/quantization/mkldnn_post_training_strategy.py b/python/paddle/fluid/contrib/slim/quantization/mkldnn_post_training_strategy.py index dcaabfadedf..1b34983001e 100644 --- a/python/paddle/fluid/contrib/slim/quantization/mkldnn_post_training_strategy.py +++ b/python/paddle/fluid/contrib/slim/quantization/mkldnn_post_training_strategy.py @@ -86,21 +86,16 @@ class MKLDNNPostTrainingQuantStrategy(Strategy): # TODO (Intel) Remove limits that MKLDNNPostTrainingQuantStrategy # only support image classification num_images = len(data) - images = core.PaddleTensor() - images.name = "x" - images.shape = [num_images, ] + list(data[0][0].shape) - images.dtype = core.PaddleDType.FLOAT32 image_data = [img.tolist() for (img, _) in data] - image_data = np.array(image_data).astype("float32") + image_data = np.array(image_data).astype("float32").reshape( + [num_images, ] + list(data[0][0].shape)) image_data = image_data.ravel() - images.data = core.PaddleBuf(image_data.tolist()) + images = core.PaddleTensor(image_data, "x") + images.shape = [num_images, ] + list(data[0][0].shape) - labels = core.PaddleTensor() - labels.name = "y" - labels.shape = [num_images, 1] - labels.dtype = core.PaddleDType.INT64 label_data = [label for (_, label) in data] - labels.data = core.PaddleBuf(label_data) + labels = core.PaddleTensor( + np.array(label_data).astype("int64").reshape([num_images, 1]), "y") warmup_data = [images, labels] diff --git a/python/paddle/fluid/tests/unittests/test_inference_api.py b/python/paddle/fluid/tests/unittests/test_inference_api.py new file mode 100644 index 00000000000..c6491b719a3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_inference_api.py @@ -0,0 +1,74 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os, shutil +import unittest +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.core import PaddleTensor +from paddle.fluid.core import PaddleDType + + +class TestInferenceApi(unittest.TestCase): + def test_inference_api(self): + tensor32 = np.random.randint(10, 20, size=[20, 2]).astype('int32') + paddletensor32 = PaddleTensor(tensor32) + value32 = np.array(paddletensor32.data.int32_data()).reshape(*[20, 2]) + dtype32 = paddletensor32.dtype + self.assertEqual(value32.all(), tensor32.all()) + self.assertEqual(dtype32, PaddleDType.INT32) + self.assertEqual( + type(paddletensor32.data.tolist('int32')), type(tensor32.tolist())) + self.assertEqual( + paddletensor32.data.tolist('int32'), tensor32.ravel().tolist()) + self.assertEqual(type(paddletensor32.as_ndarray()), type(tensor32)) + paddletensor32.data.reset(tensor32) + self.assertEqual(paddletensor32.as_ndarray().all(), tensor32.all()) + + tensor64 = np.random.randint(10, 20, size=[20, 2]).astype('int64') + paddletensor64 = PaddleTensor(tensor64) + value64 = np.array(paddletensor64.data.int64_data()).reshape(*[20, 2]) + dtype64 = paddletensor64.dtype + self.assertEqual(value64.all(), tensor64.all()) + self.assertEqual(dtype64, PaddleDType.INT64) + self.assertEqual( + type(paddletensor64.data.tolist('int64')), type(tensor64.tolist())) + self.assertEqual( + paddletensor64.data.tolist('int64'), tensor64.ravel().tolist()) + self.assertEqual(type(paddletensor64.as_ndarray()), type(tensor64)) + paddletensor64.data.reset(tensor64) + self.assertEqual(paddletensor64.as_ndarray().all(), tensor64.all()) + + tensor_float = np.random.randn(20, 2).astype('float32') + paddletensor_float = PaddleTensor(tensor_float) + value_float = np.array(paddletensor_float.data.float_data()).reshape( + *[20, 2]) + dtype_float = paddletensor_float.dtype + self.assertEqual(value_float.all(), tensor_float.all()) + self.assertEqual(dtype_float, PaddleDType.FLOAT32) + self.assertEqual( + type(paddletensor_float.data.tolist('float32')), + type(tensor_float.tolist())) + self.assertEqual( + paddletensor_float.data.tolist('float32'), + tensor_float.ravel().tolist()) + self.assertEqual( + type(paddletensor_float.as_ndarray()), type(tensor_float)) + paddletensor_float.data.reset(tensor_float) + self.assertEqual(paddletensor_float.as_ndarray().all(), + tensor_float.all()) + + +if __name__ == '__main__': + unittest.main() -- GitLab