diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 2317bfdd7c0d5ee94e91e081da47177625f5bfd8..bae49fb381a475dd8227d1dc855a6db28c9cd273 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -247,6 +247,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif #ifdef PADDLE_WITH_XPU_KP + expected_kernel_key.place_ = platform::XPUPlace(); bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc index b20e8ac9785cafea7e4f85fbfb9570d3cde5d1f5..073851433620130a3c3c6d256a4d6ca3b3f74555 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -111,6 +111,22 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) { } #endif +#ifdef PADDLE_WITH_XPU_KP +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version) { + std::vector res; + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 ? get_kl1_ops() + : get_kp_ops(); + if (ops.find(op_name) != ops.end()) { + XPUKernelSet& type_set = ops[op_name]; + for (auto& item : type_set) { + res.push_back(item.data_type_); + } + } + return res; +} +#endif + std::vector get_xpu_op_support_type( const std::string& op_name, phi::backends::xpu::XPUVersion version) { std::vector res; diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.h b/paddle/fluid/platform/device/xpu/xpu_op_list.h index 455a38e36fe0ad756021eb5ac23c012f65cc0c6a..60926dd9a5660ee13be7d61eb453740207994029 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.h @@ -31,6 +31,8 @@ bool is_in_xpu_black_list(const std::string& op_name); bool is_xpu_kp_support_op(const std::string& op_name, const pOpKernelType& type); bool is_in_xpu_kpwhite_list(const std::string& op_name); +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version); #endif std::vector get_xpu_op_support_type( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 566e38b7a21edb94ce3accf93b642cad690e9a5e..1c5b30fe087f3636a6a10579651d2c6a77a42343 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1957,10 +1957,17 @@ All parameter, weight, gradient are variables in Paddle. m.def("get_xpu_device_count", platform::GetXPUDeviceCount); m.def("get_xpu_device_version", [](int device_id) { return platform::get_xpu_version(device_id); }); +#ifdef PADDLE_WITH_XPU_KP + m.def("get_xpu_device_op_support_types", + [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { + return platform::get_xpu_kp_op_support_type(op_name, version); + }); +#else m.def("get_xpu_device_op_support_types", [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { return platform::get_xpu_op_support_type(op_name, version); }); +#endif m.def("get_xpu_device_op_list", [](phi::backends::xpu::XPUVersion version) { return platform::get_xpu_op_list(version); });