From bd4dc3be34584f9b273ecec07297fb05e1cf4c52 Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Thu, 10 Mar 2022 20:10:16 +0800 Subject: [PATCH] solve unexecuted UT (#40397) --- paddle/fluid/imperative/prepared_operator.cc | 1 + paddle/fluid/platform/device/xpu/xpu_op_list.cc | 16 ++++++++++++++++ paddle/fluid/platform/device/xpu/xpu_op_list.h | 2 ++ paddle/fluid/pybind/pybind.cc | 7 +++++++ 4 files changed, 26 insertions(+) diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 2317bfdd7c..bae49fb381 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -247,6 +247,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif #ifdef PADDLE_WITH_XPU_KP + expected_kernel_key.place_ = platform::XPUPlace(); bool use_xpu_kp_kernel_rt = FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.cc b/paddle/fluid/platform/device/xpu/xpu_op_list.cc index b20e8ac978..0738514336 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.cc +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.cc @@ -111,6 +111,22 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) { } #endif +#ifdef PADDLE_WITH_XPU_KP +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version) { + std::vector res; + auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 ? get_kl1_ops() + : get_kp_ops(); + if (ops.find(op_name) != ops.end()) { + XPUKernelSet& type_set = ops[op_name]; + for (auto& item : type_set) { + res.push_back(item.data_type_); + } + } + return res; +} +#endif + std::vector get_xpu_op_support_type( const std::string& op_name, phi::backends::xpu::XPUVersion version) { std::vector res; diff --git a/paddle/fluid/platform/device/xpu/xpu_op_list.h b/paddle/fluid/platform/device/xpu/xpu_op_list.h index 455a38e36f..60926dd9a5 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_list.h @@ -31,6 +31,8 @@ bool is_in_xpu_black_list(const std::string& op_name); bool is_xpu_kp_support_op(const std::string& op_name, const pOpKernelType& type); bool is_in_xpu_kpwhite_list(const std::string& op_name); +std::vector get_xpu_kp_op_support_type( + const std::string& op_name, phi::backends::xpu::XPUVersion version); #endif std::vector get_xpu_op_support_type( diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 566e38b7a2..1c5b30fe08 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1957,10 +1957,17 @@ All parameter, weight, gradient are variables in Paddle. m.def("get_xpu_device_count", platform::GetXPUDeviceCount); m.def("get_xpu_device_version", [](int device_id) { return platform::get_xpu_version(device_id); }); +#ifdef PADDLE_WITH_XPU_KP + m.def("get_xpu_device_op_support_types", + [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { + return platform::get_xpu_kp_op_support_type(op_name, version); + }); +#else m.def("get_xpu_device_op_support_types", [](const std::string &op_name, phi::backends::xpu::XPUVersion version) { return platform::get_xpu_op_support_type(op_name, version); }); +#endif m.def("get_xpu_device_op_list", [](phi::backends::xpu::XPUVersion version) { return platform::get_xpu_op_list(version); }); -- GitLab