未验证 提交 bd4dc3be 编写于 作者: L Lijunhui 提交者: GitHub

solve unexecuted UT (#40397)

上级 befa78ea
...@@ -247,6 +247,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins, ...@@ -247,6 +247,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif #endif
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
expected_kernel_key.place_ = platform::XPUPlace();
bool use_xpu_kp_kernel_rt = bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel && FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key); paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
......
...@@ -111,6 +111,22 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) { ...@@ -111,6 +111,22 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) {
} }
#endif #endif
#ifdef PADDLE_WITH_XPU_KP
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version) {
std::vector<vartype::Type> res;
auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 ? get_kl1_ops()
: get_kp_ops();
if (ops.find(op_name) != ops.end()) {
XPUKernelSet& type_set = ops[op_name];
for (auto& item : type_set) {
res.push_back(item.data_type_);
}
}
return res;
}
#endif
std::vector<vartype::Type> get_xpu_op_support_type( std::vector<vartype::Type> get_xpu_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version) { const std::string& op_name, phi::backends::xpu::XPUVersion version) {
std::vector<vartype::Type> res; std::vector<vartype::Type> res;
......
...@@ -31,6 +31,8 @@ bool is_in_xpu_black_list(const std::string& op_name); ...@@ -31,6 +31,8 @@ bool is_in_xpu_black_list(const std::string& op_name);
bool is_xpu_kp_support_op(const std::string& op_name, bool is_xpu_kp_support_op(const std::string& op_name,
const pOpKernelType& type); const pOpKernelType& type);
bool is_in_xpu_kpwhite_list(const std::string& op_name); bool is_in_xpu_kpwhite_list(const std::string& op_name);
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version);
#endif #endif
std::vector<vartype::Type> get_xpu_op_support_type( std::vector<vartype::Type> get_xpu_op_support_type(
......
...@@ -1957,10 +1957,17 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1957,10 +1957,17 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_xpu_device_count", platform::GetXPUDeviceCount); m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version", m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); }); [](int device_id) { return platform::get_xpu_version(device_id); });
#ifdef PADDLE_WITH_XPU_KP
m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_kp_op_support_type(op_name, version);
});
#else
m.def("get_xpu_device_op_support_types", m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, phi::backends::xpu::XPUVersion version) { [](const std::string &op_name, phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_op_support_type(op_name, version); return platform::get_xpu_op_support_type(op_name, version);
}); });
#endif
m.def("get_xpu_device_op_list", [](phi::backends::xpu::XPUVersion version) { m.def("get_xpu_device_op_list", [](phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_op_list(version); return platform::get_xpu_op_list(version);
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册