未验证 提交 bd4dc3be 编写于 作者: L Lijunhui 提交者: GitHub

solve unexecuted UT (#40397)

上级 befa78ea
......@@ -247,6 +247,7 @@ PreparedOp PrepareImpl(const NameVarMap<VarType>& ins,
#endif
#ifdef PADDLE_WITH_XPU_KP
expected_kernel_key.place_ = platform::XPUPlace();
bool use_xpu_kp_kernel_rt =
FLAGS_run_kp_kernel &&
paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
......
......@@ -111,6 +111,22 @@ bool is_in_xpu_kpwhite_list(const std::string& op_name) {
}
#endif
#ifdef PADDLE_WITH_XPU_KP
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version) {
std::vector<vartype::Type> res;
auto& ops = version == phi::backends::xpu::XPUVersion::XPU1 ? get_kl1_ops()
: get_kp_ops();
if (ops.find(op_name) != ops.end()) {
XPUKernelSet& type_set = ops[op_name];
for (auto& item : type_set) {
res.push_back(item.data_type_);
}
}
return res;
}
#endif
std::vector<vartype::Type> get_xpu_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version) {
std::vector<vartype::Type> res;
......
......@@ -31,6 +31,8 @@ bool is_in_xpu_black_list(const std::string& op_name);
bool is_xpu_kp_support_op(const std::string& op_name,
const pOpKernelType& type);
bool is_in_xpu_kpwhite_list(const std::string& op_name);
std::vector<vartype::Type> get_xpu_kp_op_support_type(
const std::string& op_name, phi::backends::xpu::XPUVersion version);
#endif
std::vector<vartype::Type> get_xpu_op_support_type(
......
......@@ -1957,10 +1957,17 @@ All parameter, weight, gradient are variables in Paddle.
m.def("get_xpu_device_count", platform::GetXPUDeviceCount);
m.def("get_xpu_device_version",
[](int device_id) { return platform::get_xpu_version(device_id); });
#ifdef PADDLE_WITH_XPU_KP
m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_kp_op_support_type(op_name, version);
});
#else
m.def("get_xpu_device_op_support_types",
[](const std::string &op_name, phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_op_support_type(op_name, version);
});
#endif
m.def("get_xpu_device_op_list", [](phi::backends::xpu::XPUVersion version) {
return platform::get_xpu_op_list(version);
});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册