未验证 提交 af3a0675 编写于 作者: H Huang Jiyi 提交者: GitHub

Add output defs for logical_xxx kernel (#51331)

* add output defs

* add output defs for kps
上级 2e3c6803
......@@ -90,10 +90,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"layer_norm_grad",
"less_equal",
"less_than",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"lstsq",
"lu",
"matrix_nms",
......
......@@ -64,7 +64,9 @@ void LogicalNotKernel(const Context& dev_ctx,
int64_t, \
int, \
int8_t, \
int16_t) {}
int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
......
......@@ -64,10 +64,18 @@ void LogicalNotKernel(const Context& dev_ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {}
PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {}
PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {}
PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {}
PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#else
#define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \
......@@ -80,7 +88,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {}
int64_t, \
int, \
int8_t, \
int16_t) {}
int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册