From 6074c50afa05950ed65b9ba7f4f080f9cc331f94 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 7 Nov 2022 10:23:17 +0800 Subject: [PATCH] [cusotm device] add python inference api, test=develop (#46460) --- paddle/fluid/inference/api/paddle_analysis_config.h | 2 +- paddle/fluid/pybind/inference_api.cc | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 0ed5380e67..f9587c0995 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -367,7 +367,7 @@ struct PD_INFER_DECL AnalysisConfig { /// /// \param device_id device_id the custom device to use (default is 0). /// - void EnableCustomDevice(const std::string& device_type, int device_id); + void EnableCustomDevice(const std::string& device_type, int device_id = 0); /// /// \brief Turn on ONNXRuntime. /// diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 9b99cad869..60f1bfd921 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -662,6 +662,10 @@ void BindAnalysisConfig(py::module *m) { .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId, py::arg("device_id") = 0) + .def("enable_custom_device", + &AnalysisConfig::EnableCustomDevice, + py::arg("device_type"), + py::arg("device_id") = 0) .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0) .def("enable_ipu", &AnalysisConfig::EnableIpu, -- GitLab