Unverified commit 707d838b, authored by W Wilber, committed via GitHub

cherry-pick of PR 46152 (#46183)

Parent commit: adab3c59
...@@ -36,6 +36,10 @@ ...@@ -36,6 +36,10 @@
#include "paddle/fluid/inference/api/paddle_pass_builder.h" #include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/io_utils.h" #include "paddle/fluid/inference/utils/io_utils.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/phi/core/cuda_stream.h"
#endif
#ifdef PADDLE_WITH_ONNXRUNTIME #ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h" #include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif #endif
...@@ -542,7 +546,13 @@ void BindPaddlePredictor(py::module *m) { ...@@ -542,7 +546,13 @@ void BindPaddlePredictor(py::module *m) {
.def("get_input_names", &PaddlePredictor::GetInputNames) .def("get_input_names", &PaddlePredictor::GetInputNames)
.def("get_output_names", &PaddlePredictor::GetOutputNames) .def("get_output_names", &PaddlePredictor::GetOutputNames)
.def("zero_copy_run", &PaddlePredictor::ZeroCopyRun) .def("zero_copy_run", &PaddlePredictor::ZeroCopyRun)
.def("clone", &PaddlePredictor::Clone) .def("clone", [](PaddlePredictor &self) { self.Clone(nullptr); })
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("clone",
[](PaddlePredictor &self, phi::CUDAStream &stream) {
self.Clone(stream.raw_stream());
})
#endif
.def("get_serialized_program", &PaddlePredictor::GetSerializedProgram); .def("get_serialized_program", &PaddlePredictor::GetSerializedProgram);
auto config = py::class_<PaddlePredictor::Config>(paddle_predictor, "Config"); auto config = py::class_<PaddlePredictor::Config>(paddle_predictor, "Config");
...@@ -583,7 +593,13 @@ void BindNativePredictor(py::module *m) { ...@@ -583,7 +593,13 @@ void BindNativePredictor(py::module *m) {
.def("get_input_tensor", &NativePaddlePredictor::GetInputTensor) .def("get_input_tensor", &NativePaddlePredictor::GetInputTensor)
.def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor) .def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor)
.def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun) .def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun)
.def("clone", &NativePaddlePredictor::Clone) .def("clone", [](NativePaddlePredictor &self) { self.Clone(nullptr); })
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("clone",
[](NativePaddlePredictor &self, phi::CUDAStream &stream) {
self.Clone(stream.raw_stream());
})
#endif
.def("scope", .def("scope",
&NativePaddlePredictor::scope, &NativePaddlePredictor::scope,
py::return_value_policy::reference); py::return_value_policy::reference);
...@@ -626,6 +642,12 @@ void BindAnalysisConfig(py::module *m) { ...@@ -626,6 +642,12 @@ void BindAnalysisConfig(py::module *m) {
&AnalysisConfig::EnableUseGpu, &AnalysisConfig::EnableUseGpu,
py::arg("memory_pool_init_size_mb"), py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0) py::arg("device_id") = 0)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
self.SetExecStream(stream.raw_stream());
})
#endif
.def("enable_xpu", .def("enable_xpu",
&AnalysisConfig::EnableXpu, &AnalysisConfig::EnableXpu,
py::arg("l3_workspace_size") = 16 * 1024 * 1024, py::arg("l3_workspace_size") = 16 * 1024 * 1024,
...@@ -874,7 +896,13 @@ void BindAnalysisPredictor(py::module *m) { ...@@ -874,7 +896,13 @@ void BindAnalysisPredictor(py::module *m) {
.def("analysis_argument", .def("analysis_argument",
&AnalysisPredictor::analysis_argument, &AnalysisPredictor::analysis_argument,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("clone", &AnalysisPredictor::Clone) .def("clone", [](AnalysisPredictor &self) { self.Clone(nullptr); })
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("clone",
[](AnalysisPredictor &self, phi::CUDAStream &stream) {
self.Clone(stream.raw_stream());
})
#endif
.def("scope", .def("scope",
&AnalysisPredictor::scope, &AnalysisPredictor::scope,
py::return_value_policy::reference) py::return_value_policy::reference)
...@@ -901,7 +929,13 @@ void BindPaddleInferPredictor(py::module *m) { ...@@ -901,7 +929,13 @@ void BindPaddleInferPredictor(py::module *m) {
#endif #endif
self.Run(); self.Run();
}) })
.def("clone", &paddle_infer::Predictor::Clone) .def("clone", [](paddle_infer::Predictor &self) { self.Clone(nullptr); })
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("clone",
[](paddle_infer::Predictor &self, phi::CUDAStream &stream) {
self.Clone(stream.raw_stream());
})
#endif
.def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory) .def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory)
.def("clear_intermediate_tensor", .def("clear_intermediate_tensor",
&paddle_infer::Predictor::ClearIntermediateTensor); &paddle_infer::Predictor::ClearIntermediateTensor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register