未验证 提交 ab3b87a6 编写于 作者: I Infinity_lee 提交者: GitHub

add output defs for atan2 kernel (#51312)

* fix atan2

* fix

* fix

* fix

* fix error

* fix error

* fix
上级 7dbd8b8c
......@@ -55,7 +55,6 @@ static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"adamw",
"any_raw",
"arg_sort",
"atan2",
"clip_by_norm",
"eig_grad",
"eigh",
......
......@@ -166,6 +166,8 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
if (x.dtype() == DataType::INT32 || x.dtype() == DataType::INT64 ||
y.dtype() == DataType::INT32 || y.dtype() == DataType::INT64) {
out->set_dtype(DataType::FLOAT64);
} else {
out->set_dtype(x.dtype());
}
}
......
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(atan2,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册