未验证 提交 ab3b87a6 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for atan2 kernel (#51312)

* fix atan2

* fix

* fix

* fix

* fix error

* fix error

* fix
上级 7dbd8b8c
...@@ -55,7 +55,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = { ...@@ -55,7 +55,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"adamw", "adamw",
"any_raw", "any_raw",
"arg_sort", "arg_sort",
"atan2",
"clip_by_norm", "clip_by_norm",
"eig_grad", "eig_grad",
"eigh", "eigh",
......
...@@ -166,6 +166,8 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { ...@@ -166,6 +166,8 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 || if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 ||
y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) { y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) {
out->set_dtype(DataType::FLOAT64); out->set_dtype(DataType::FLOAT64);
} else {
out->set_dtype(x.dtype());
} }
} }
......
...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2, ...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2,
double, double,
phi::dtype::float16, phi::dtype::float16,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2, ...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2,
double, double,
phi::dtype::float16, phi::dtype::float16,
int, int,
int64_t) {} int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册