未验证 提交 37dbbbd1 编写于 作者: R Ryan 提交者: GitHub

w/o pre-commit (#51315)

上级 39a1ab69
...@@ -58,8 +58,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -58,8 +58,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"angle", "angle",
"any_raw", "any_raw",
"arg_sort", "arg_sort",
"argmax",
"argmin",
"as_real", "as_real",
"atan2", "atan2",
"auc", "auc",
......
...@@ -195,7 +195,9 @@ PD_REGISTER_KERNEL(argmin, ...@@ -195,7 +195,9 @@ PD_REGISTER_KERNEL(argmin,
int32_t, int32_t,
int64_t, int64_t,
int16_t, int16_t,
uint8_t) {} uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(argmax, PD_REGISTER_KERNEL(argmax,
CPU, CPU,
...@@ -206,4 +208,6 @@ PD_REGISTER_KERNEL(argmax, ...@@ -206,4 +208,6 @@ PD_REGISTER_KERNEL(argmax,
int32_t, int32_t,
int64_t, int64_t,
int16_t, int16_t,
uint8_t) {} uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -266,7 +266,9 @@ PD_REGISTER_KERNEL(argmin, ...@@ -266,7 +266,9 @@ PD_REGISTER_KERNEL(argmin,
int32_t, int32_t,
int64_t, int64_t,
int16_t, int16_t,
uint8_t) {} uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(argmax, PD_REGISTER_KERNEL(argmax,
GPU, GPU,
...@@ -279,4 +281,6 @@ PD_REGISTER_KERNEL(argmax, ...@@ -279,4 +281,6 @@ PD_REGISTER_KERNEL(argmax,
int32_t, int32_t,
int64_t, int64_t,
int16_t, int16_t,
uint8_t) {} uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -74,4 +74,6 @@ void ArgMaxKernel(const Context& dev_ctx, ...@@ -74,4 +74,6 @@ void ArgMaxKernel(const Context& dev_ctx,
XPUAPIErrorMsg[r])); XPUAPIErrorMsg[r]));
} }
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {} PD_REGISTER_KERNEL(argmax, XPU, ALL_LAYOUT, phi::ArgMaxKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册