未验证 提交 a74d7fb6 编写于 作者: H houj04 提交者: GitHub

add set-xpu-device-id function for inference config. (#35572)

上级 c192127b
......@@ -114,6 +114,14 @@ void AnalysisConfig::EnableXpu(int l3_workspace_size, bool locked,
Update();
}
void AnalysisConfig::SetXpuDeviceId(int device_id) {
PADDLE_ENFORCE_EQ(use_xpu_, true,
platform::errors::PreconditionNotMet(
"Should call EnableXpu before SetXpuDeviceId."));
xpu_device_id_ = device_id;
Update();
}
void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL
use_npu_ = true;
......
......@@ -323,6 +323,17 @@ TEST(AnalysisPredictor, bf16_pass_strategy) {
passStrategy.EnableMkldnnBfloat16();
}
#ifdef PADDLE_WITH_XPU
TEST(AnalysisPredictor, set_xpu_device_id) {
AnalysisConfig config;
config.EnableXpu();
config.SetXpuDeviceId(0);
ASSERT_EQ(config.xpu_device_id(), 0);
config.SetXpuDeviceId(1);
ASSERT_EQ(config.xpu_device_id(), 1);
}
#endif
} // namespace paddle
namespace paddle_infer {
......
......@@ -203,6 +203,12 @@ struct PD_INFER_DECL AnalysisConfig {
const std::string& precision = "int16",
bool adaptive_seqlen = false);
///
/// \brief Set XPU device id.
///
/// \param device_id the XPU card to use (default is 0).
///
void SetXpuDeviceId(int device_id = 0);
///
/// \brief Turn on NPU.
///
/// \param device_id device_id the NPU card to use (default is 0).
......
......@@ -523,6 +523,8 @@ void BindAnalysisConfig(py::module *m) {
py::arg("locked") = false, py::arg("autotune") = true,
py::arg("autotune_file") = "", py::arg("precision") = "int16",
py::arg("adaptive_seqlen") = false)
.def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId,
py::arg("device_id") = 0)
.def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
.def("disable_gpu", &AnalysisConfig::DisableGpu)
.def("use_gpu", &AnalysisConfig::use_gpu)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册