Unverified commit c8746024, authored by Wilber, committed by GitHub

python support pass external stream (#46152)

* python support pass external stream

* fix compile error
Parent ed7bc2bd
@@ -36,6 +36,10 @@
 #include "paddle/fluid/inference/api/paddle_pass_builder.h"
 #include "paddle/fluid/inference/utils/io_utils.h"
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/phi/core/cuda_stream.h"
+#endif
 #ifdef PADDLE_WITH_ONNXRUNTIME
 #include "paddle/fluid/inference/api/onnxruntime_predictor.h"
 #endif
@@ -542,7 +546,13 @@ void BindPaddlePredictor(py::module *m) {
       .def("get_input_names", &PaddlePredictor::GetInputNames)
       .def("get_output_names", &PaddlePredictor::GetOutputNames)
       .def("zero_copy_run", &PaddlePredictor::ZeroCopyRun)
-      .def("clone", &PaddlePredictor::Clone)
+      .def("clone", [](PaddlePredictor &self) { return self.Clone(nullptr); })
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      .def("clone",
+           [](PaddlePredictor &self, phi::CUDAStream &stream) {
+             return self.Clone(stream.raw_stream());
+           })
+#endif
       .def("get_serialized_program", &PaddlePredictor::GetSerializedProgram);
   auto config = py::class_<PaddlePredictor::Config>(paddle_predictor, "Config");
@@ -583,7 +593,13 @@ void BindNativePredictor(py::module *m) {
       .def("get_input_tensor", &NativePaddlePredictor::GetInputTensor)
       .def("get_output_tensor", &NativePaddlePredictor::GetOutputTensor)
       .def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun)
-      .def("clone", &NativePaddlePredictor::Clone)
+      .def("clone", [](NativePaddlePredictor &self) { return self.Clone(nullptr); })
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      .def("clone",
+           [](NativePaddlePredictor &self, phi::CUDAStream &stream) {
+             return self.Clone(stream.raw_stream());
+           })
+#endif
       .def("scope",
            &NativePaddlePredictor::scope,
            py::return_value_policy::reference);
@@ -626,6 +642,12 @@ void BindAnalysisConfig(py::module *m) {
            &AnalysisConfig::EnableUseGpu,
            py::arg("memory_pool_init_size_mb"),
            py::arg("device_id") = 0)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      .def("set_exec_stream",
+           [](AnalysisConfig &self, phi::CUDAStream &stream) {
+             self.SetExecStream(stream.raw_stream());
+           })
+#endif
       .def("enable_xpu",
            &AnalysisConfig::EnableXpu,
            py::arg("l3_workspace_size") = 16 * 1024 * 1024,
@@ -874,7 +896,13 @@ void BindAnalysisPredictor(py::module *m) {
       .def("analysis_argument",
            &AnalysisPredictor::analysis_argument,
            py::return_value_policy::reference)
-      .def("clone", &AnalysisPredictor::Clone)
+      .def("clone", [](AnalysisPredictor &self) { return self.Clone(nullptr); })
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      .def("clone",
+           [](AnalysisPredictor &self, phi::CUDAStream &stream) {
+             return self.Clone(stream.raw_stream());
+           })
+#endif
       .def("scope",
            &AnalysisPredictor::scope,
            py::return_value_policy::reference)
@@ -901,7 +929,13 @@ void BindPaddleInferPredictor(py::module *m) {
 #endif
         self.Run();
       })
-      .def("clone", &paddle_infer::Predictor::Clone)
+      .def("clone", [](paddle_infer::Predictor &self) { return self.Clone(nullptr); })
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+      .def("clone",
+           [](paddle_infer::Predictor &self, phi::CUDAStream &stream) {
+             return self.Clone(stream.raw_stream());
+           })
+#endif
       .def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory)
       .def("clear_intermediate_tensor",
            &paddle_infer::Predictor::ClearIntermediateTensor);