未验证 提交 63212541 编写于 作者: W Wilber 提交者: GitHub

Refine python inference api (#26958)

上级 eb65877c
......@@ -73,7 +73,7 @@ class PD_INFER_DECL Tensor {
class PD_INFER_DECL Predictor {
public:
Predictor() = default;
Predictor() = delete;
~Predictor() {}
// Use for clone
explicit Predictor(std::unique_ptr<paddle::PaddlePredictor>&& pred)
......
......@@ -60,6 +60,9 @@ void BindAnalysisConfig(py::module *m);
void BindAnalysisPredictor(py::module *m);
void BindZeroCopyTensor(py::module *m);
void BindPaddlePassBuilder(py::module *m);
void BindPaddleInferPredictor(py::module *m);
void BindPaddleInferTensor(py::module *m);
void BindPredictorPool(py::module *m);
#ifdef PADDLE_WITH_MKLDNN
void BindMkldnnQuantizerConfig(py::module *m);
......@@ -139,6 +142,15 @@ void ZeroCopyTensorCreate(ZeroCopyTensor &tensor, // NOLINT
tensor.copy_from_cpu(static_cast<const T *>(data.data()));
}
// Fill a paddle_infer::Tensor from a numpy array: resize the tensor to the
// array's dimensions, then copy the host buffer into it.
template <typename T>
void PaddleInferTensorCreate(paddle_infer::Tensor &tensor,  // NOLINT
                             py::array_t<T> data) {
  // Collect the numpy dims as ints (Reshape takes std::vector<int>).
  std::vector<int> dims(data.ndim());
  for (py::ssize_t i = 0; i < data.ndim(); ++i) {
    dims[i] = static_cast<int>(data.shape(i));
  }
  tensor.Reshape(std::move(dims));
  tensor.CopyFromCpu(static_cast<const T *>(data.data()));
}
size_t PaddleGetDTypeSize(PaddleDType dt) {
size_t size{0};
switch (dt) {
......@@ -183,6 +195,30 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
return array;
}
// Copy a paddle_infer::Tensor back to host memory as a numpy array whose
// dtype and shape mirror the tensor's. Throws for dtypes other than
// INT32/INT64/FLOAT32.
py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) {  // NOLINT
  const auto dims = tensor.shape();
  py::array array(PaddleDTypeToNumpyDType(tensor.type()),
                  py::array::ShapeContainer(dims.begin(), dims.end()));
  void *dst = array.mutable_data();
  switch (tensor.type()) {
    case PaddleDType::INT32:
      tensor.CopyToCpu(static_cast<int32_t *>(dst));
      break;
    case PaddleDType::INT64:
      tensor.CopyToCpu(static_cast<int64_t *>(dst));
      break;
    case PaddleDType::FLOAT32:
      tensor.CopyToCpu(static_cast<float *>(dst));
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64 and "
          "FLOAT32."));
  }
  return array;
}
py::bytes SerializePDTensorToBytes(PaddleTensor &tensor) { // NOLINT
std::stringstream ss;
paddle::inference::SerializePDTensorToStream(&ss, tensor);
......@@ -200,8 +236,11 @@ void BindInferenceApi(py::module *m) {
BindNativePredictor(m);
BindAnalysisConfig(m);
BindAnalysisPredictor(m);
BindPaddleInferPredictor(m);
BindZeroCopyTensor(m);
BindPaddleInferTensor(m);
BindPaddlePassBuilder(m);
BindPredictorPool(m);
#ifdef PADDLE_WITH_MKLDNN
BindMkldnnQuantizerConfig(m);
#endif
......@@ -209,8 +248,17 @@ void BindInferenceApi(py::module *m) {
&paddle::CreatePaddlePredictor<AnalysisConfig>, py::arg("config"));
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<NativeConfig>, py::arg("config"));
// Factory exposed to Python: build a Predictor from a Config. Returning the
// unique_ptr by value transfers ownership to pybind11.
m->def("create_predictor",
       [](const paddle_infer::Config &config)
           -> std::unique_ptr<paddle_infer::Predictor> {
         // make_unique avoids a raw `new`, and returning the expression
         // directly (no `return std::move(local)`) does not pessimize the
         // implicit move of the returned value.
         return std::make_unique<paddle_infer::Predictor>(config);
       });
m->def("paddle_dtype_size", &paddle::PaddleDtypeSize);
m->def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
m->def("get_version", &paddle_infer::GetVersion);
m->def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}
namespace {
......@@ -525,6 +573,19 @@ void BindAnalysisPredictor(py::module *m) {
py::arg("dir"));
}
// Expose paddle_infer::Predictor to Python as `PaddleInferPredictor`.
void BindPaddleInferPredictor(py::module *m) {
  py::class_<paddle_infer::Predictor> predictor(*m, "PaddleInferPredictor");
  // Constructible from a Config; remaining methods forward one-to-one.
  predictor.def(py::init<const paddle_infer::Config &>());
  predictor.def("get_input_names", &paddle_infer::Predictor::GetInputNames);
  predictor.def("get_output_names", &paddle_infer::Predictor::GetOutputNames);
  predictor.def("get_input_handle", &paddle_infer::Predictor::GetInputHandle);
  predictor.def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle);
  predictor.def("run", &paddle_infer::Predictor::Run);
  predictor.def("clone", &paddle_infer::Predictor::Clone);
  predictor.def("clear_intermediate_tensor",
                &paddle_infer::Predictor::ClearIntermediateTensor);
}
void BindZeroCopyTensor(py::module *m) {
py::class_<ZeroCopyTensor>(*m, "ZeroCopyTensor")
.def("reshape", &ZeroCopyTensor::Reshape)
......@@ -538,6 +599,26 @@ void BindZeroCopyTensor(py::module *m) {
.def("type", &ZeroCopyTensor::type);
}
// Expose paddle_infer::Tensor to Python as `PaddleInferTensor`.
void BindPaddleInferTensor(py::module *m) {
  py::class_<paddle_infer::Tensor> tensor(*m, "PaddleInferTensor");
  tensor.def("reshape", &paddle_infer::Tensor::Reshape);
  // NOTE: the copy_from_cpu overloads are registered in this order
  // (int32 -> int64 -> float) on purpose; pybind11 tries overloads in
  // registration order when dispatching on the numpy dtype.
  tensor.def("copy_from_cpu", &PaddleInferTensorCreate<int32_t>);
  tensor.def("copy_from_cpu", &PaddleInferTensorCreate<int64_t>);
  tensor.def("copy_from_cpu", &PaddleInferTensorCreate<float>);
  tensor.def("copy_to_cpu", &PaddleInferTensorToNumpy);
  tensor.def("shape", &paddle_infer::Tensor::shape);
  tensor.def("set_lod", &paddle_infer::Tensor::SetLoD);
  tensor.def("lod", &paddle_infer::Tensor::lod);
  tensor.def("type", &paddle_infer::Tensor::type);
}
// Expose paddle_infer::services::PredictorPool to Python as `PredictorPool`.
void BindPredictorPool(py::module *m) {
  using Pool = paddle_infer::services::PredictorPool;
  py::class_<Pool> pool(*m, "PredictorPool");
  pool.def(py::init<const paddle_infer::Config &, size_t>());
  // NOTE: "retrive" (sic) mirrors the C++ method name `Retrive`; kept as-is
  // since it is the public Python-facing name. The pool keeps ownership of
  // the returned predictor, hence the reference return policy.
  pool.def("retrive", &Pool::Retrive, py::return_value_policy::reference);
}
void BindPaddlePassBuilder(py::module *m) {
py::class_<PaddlePassBuilder>(*m, "PaddlePassBuilder")
.def(py::init<const std::vector<std::string> &>())
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .wrapper import Config, DataType, PlaceType, PrecisionType, Tensor, Predictor
from ..core import create_predictor, get_version, get_num_bytes_of_data_type, PredictorPool
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..core import AnalysisConfig, PaddleDType, PaddlePlace
from ..core import PaddleInferPredictor, PaddleInferTensor
# Public aliases: map the new `paddle.inference` API names onto the
# pybind-exported core types so Python users see the renamed interface.
DataType = PaddleDType  # element dtype enum (e.g. INT32/INT64/FLOAT32)
PlaceType = PaddlePlace
PrecisionType = AnalysisConfig.Precision
Config = AnalysisConfig
Tensor = PaddleInferTensor  # bound C++ paddle_infer::Tensor
Predictor = PaddleInferPredictor  # bound C++ paddle_infer::Predictor
......@@ -156,6 +156,7 @@ packages=['paddle',
'paddle.framework',
'paddle.jit',
'paddle.fluid',
'paddle.fluid.inference',
'paddle.fluid.dygraph',
'paddle.fluid.dygraph.dygraph_to_static',
'paddle.fluid.dygraph.amp',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册