未验证 提交 1bf2ab48 编写于 作者: R ronnywang 提交者: GitHub

fix atan2 grad (#56067)

上级 826eaf72
......@@ -32,8 +32,12 @@ struct Atan2GradFunctor {
float x1 = static_cast<float>(x1_[idx]);
float x2 = static_cast<float>(x2_[idx]);
float x = x1 * x1 + x2 * x2;
- dx1_[idx] = static_cast<T>(static_cast<float>(dout_[idx]) * x2 / x);
- dx2_[idx] = static_cast<T>(-static_cast<float>(dout_[idx]) * x1 / x);
+ if (dx1_) {
+ dx1_[idx] = static_cast<T>(static_cast<float>(dout_[idx]) * x2 / x);
+ }
+ if (dx2_) {
+ dx2_[idx] = static_cast<T>(-static_cast<float>(dout_[idx]) * x1 / x);
+ }
}
const T* x1_;
......@@ -56,8 +60,12 @@ struct Atan2GradFunctor<double> {
HOSTDEVICE void operator()(int64_t idx) const {
auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx];
- dx1_[idx] = dout_[idx] * x2_[idx] / x;
- dx2_[idx] = -dout_[idx] * x1_[idx] / x;
+ if (dx1_) {
+ dx1_[idx] = dout_[idx] * x2_[idx] / x;
+ }
+ if (dx2_) {
+ dx2_[idx] = -dout_[idx] * x1_[idx] / x;
+ }
}
const double* x1_;
......@@ -81,9 +89,11 @@ void Atan2GradKernel(const Context& ctx,
auto out_grad_data = out_grad.data<T>();
auto* x_grad_data =
- ctx.template Alloc<T>(x_grad, size_t(x.numel() * sizeof(T)));
+ x_grad ? ctx.template Alloc<T>(x_grad, size_t(x.numel() * sizeof(T)))
+ : nullptr;
auto* y_grad_data =
- ctx.template Alloc<T>(y_grad, size_t(y.numel() * sizeof(T)));
+ y_grad ? ctx.template Alloc<T>(y_grad, size_t(y.numel() * sizeof(T)))
+ : nullptr;
phi::funcs::ForRange<Context> for_range(ctx, numel);
phi::Atan2GradFunctor<T> functor(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册