diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index f9c7be9cd4c2758567a05f2b38ce82b9bb22b5f7..ac540c75511ef29c840e258ada2171cb9ee0b262 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -114,6 +114,14 @@ void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked, Update(); } +void AnalysisConfig::SetXpuDeviceId(int device_id) { + PADDLE_ENFORCE_EQ(use_xpu_, true, + platform::errors::PreconditionNotMet( + "Should call EnableXpu before SetXpuDeviceId.")); + xpu_device_id_ = device_id; + Update(); +} + void AnalysisConfig::EnableNpu(int device_id) { #ifdef PADDLE_WITH_ASCEND_CL use_npu_ = true; diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 87af94a88d4b5fa68a6b962dd3722289da9377a4..513f3669a19ce9d760d53d5b1d2c94c4b0c55703 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -323,6 +323,17 @@ TEST(AnalysisPredictor, bf16_pass_strategy) { passStrategy.EnableMkldnnBfloat16(); } +#ifdef PADDLE_WITH_XPU +TEST(AnalysisPredictor, set_xpu_device_id) { + AnalysisConfig config; + config.EnableXpu(); + config.SetXpuDeviceId(0); + ASSERT_EQ(config.xpu_device_id(), 0); + config.SetXpuDeviceId(1); + ASSERT_EQ(config.xpu_device_id(), 1); +} +#endif + } // namespace paddle namespace paddle_infer { diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index dbdd0983b53088f8604d73a1f7699df24781209b..a64377f80f8aad7db5c2a5c4b91160354dde81e7 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -203,6 +203,12 @@ struct PD_INFER_DECL AnalysisConfig { const std::string& precision = "int16", bool adaptive_seqlen = false); /// + /// \brief Set XPU device id. + /// + /// \param device_id the XPU card to use (default is 0). + /// + void SetXpuDeviceId(int device_id = 0); + /// /// \brief Turn on NPU. /// /// \param device_id device_id the NPU card to use (default is 0). diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 87986aebe049f29e3abbd5afd9d7c7fae5aeb591..b1a91cd302187ed6c43d099c21277dc6b5d89214 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -523,6 +523,8 @@ void BindAnalysisConfig(py::module *m) { py::arg("locked") = false, py::arg("autotune") = true, py::arg("autotune_file") = "", py::arg("precision") = "int16", py::arg("adaptive_seqlen") = false) + .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId, + py::arg("device_id") = 0) .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0) .def("disable_gpu", &AnalysisConfig::DisableGpu) .def("use_gpu", &AnalysisConfig::use_gpu)