未验证 提交 af3a0675 编写于 作者: H Huang Jiyi 提交者: GitHub

Add output defs for logical_xxx kernel (#51331)

* add output defs

* add output defs for kps
上级 2e3c6803
...@@ -90,10 +90,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -90,10 +90,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"layer_norm_grad", "layer_norm_grad",
"less_equal", "less_equal",
"less_than", "less_than",
"logical_and",
"logical_not",
"logical_or",
"logical_xor",
"lstsq", "lstsq",
"lu", "lu",
"matrix_nms", "matrix_nms",
......
...@@ -64,7 +64,9 @@ void LogicalNotKernel(const Context& dev_ctx, ...@@ -64,7 +64,9 @@ void LogicalNotKernel(const Context& dev_ctx,
int64_t, \ int64_t, \
int, \ int, \
int8_t, \ int8_t, \
int16_t) {} int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
REGISTER_LOGICAL_CPU_KERNEL(logical_and, And) REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or) REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
......
...@@ -64,10 +64,18 @@ void LogicalNotKernel(const Context& dev_ctx, ...@@ -64,10 +64,18 @@ void LogicalNotKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {} PD_REGISTER_KERNEL(logical_and, KPS, ALL_LAYOUT, phi::LogicalAndKernel, int) {
PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {} kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {} }
PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {} PD_REGISTER_KERNEL(logical_or, KPS, ALL_LAYOUT, phi::LogicalOrKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(logical_not, KPS, ALL_LAYOUT, phi::LogicalNotKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#else #else
#define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \ #define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \ PD_REGISTER_KERNEL(logical_and, \
...@@ -80,7 +88,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {} ...@@ -80,7 +88,9 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {}
int64_t, \ int64_t, \
int, \ int, \
int8_t, \ int8_t, \
int16_t) {} int16_t) { \
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL); \
}
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And) REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or) REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册