From 127d2664cdc887316dba4588b4c05bd2f0c8b55e Mon Sep 17 00:00:00 2001
From: wangruting
Date: Wed, 18 Jan 2023 07:02:34 +0000
Subject: [PATCH] merge develop

---
 .../manual/backward/composite_backward_api.h  | 199 ++++++++++++++----
 1 file changed, 154 insertions(+), 45 deletions(-)

diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
index 31e09b34f16..c148fca37bf 100644
--- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #pragma once
+#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h"
 #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h"
 #include "paddle/fluid/prim/api/manual/utils/utils.h"
 #include "paddle/phi/common/int_array.h"
@@ -23,16 +24,17 @@ namespace prim {
 using Tensor = paddle::experimental::Tensor;
 using IntArray =
     paddle::experimental::IntArrayBase<paddle::experimental::Tensor>;
-// using IntArray = paddle::experimental::IntArray;
 // This function should have as same signature as phi, which defined in
 // paddle/phi/api/backward/backward_api.h
 template <typename T>
 void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) {
+  if (!grad_x) return;
   auto tmp = pow<T>(out, 2.0);
   tmp = scale<T>(tmp, -1.0, 1.0, true);
   auto grad_x_tmp = multiply<T>(grad_out, tmp);
-  set_output<T>(grad_x_tmp, grad_x);
+  set_output<T>(grad_x_tmp.impl(), grad_x);
 }
+
 template <typename T>
 void subtract_grad(const Tensor& x,
                    const Tensor& y,
@@ -42,26 +44,33 @@ void subtract_grad(const Tensor& x,
                    Tensor* dy) {
   if (dy) {
     auto scale_out_grad = scale<T>(out_grad, -1.0, 0.0, true);
-    if (phi::product(x.dims()) > phi::product(y.dims())) {
+    if (x.dims() != y.dims()) {
       // Maybe need reduce here
-      phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
-      auto dy_reduce_res =
-          sum<T>(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      set_output<T>(dy_tmp, dy);
-
+      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
+      if (!reduce_dim.size()) {
+        by_pass<T>(scale_out_grad, dy);
+      } else {
+        auto dy_reduce_res = sum<T>(
+            scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
+        set_output<T>(dy_tmp.impl(), dy);
+      }
     } else {
       by_pass<T>(scale_out_grad, dy);
     }
   }
   if (dx) {
-    if (phi::product(y.dims()) > phi::product(x.dims())) {
+    if (y.dims() != x.dims()) {
      // Maybe need reduce here
-      auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
-      auto dx_reduce_res =
-          sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      set_output<T>(dx_tmp, dx);
+      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
+      if (!reduce_dim.size()) {
+        by_pass<T>(out_grad, dx);
+      } else {
+        auto dx_reduce_res =
+            sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
+        set_output<T>(dx_tmp.impl(), dx);
+      }
     } else {
       by_pass<T>(out_grad, dx);
     }
@@ -76,25 +85,34 @@ void add_grad(const Tensor& x,
               Tensor* dx,
               Tensor* dy) {
   if (dy) {
-    if (phi::product(x.dims()) > phi::product(y.dims())) {
+    if (x.dims() != y.dims()) {
       // Maybe need reduce here
-      phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
-      auto dy_reduce_res =
-          sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      set_output<T>(dy_tmp, dy);
+      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
+      if (!reduce_dim.size()) {
+        by_pass<T>(out_grad, dy);
+      } else {
+        auto dy_reduce_res =
+            sum<T>(out_grad, phi::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
+        set_output<T>(dy_tmp.impl(), dy);
+      }
+
     } else {
       by_pass<T>(out_grad, dy);
     }
   }
   if (dx) {
-    if (phi::product(y.dims()) > phi::product(x.dims())) {
+    if (y.dims() != x.dims()) {
       // Maybe need reduce here
-      auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
-      auto dx_reduce_res =
-          sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      set_output<T>(dx_tmp, dx);
+      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
+      if (!reduce_dim.size()) {
+        by_pass<T>(out_grad, dx);
+      } else {
+        auto dx_reduce_res =
+            sum<T>(out_grad, phi::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
+        set_output<T>(dx_tmp.impl(), dx);
+      }
     } else {
       by_pass<T>(out_grad, dx);
     }
@@ -131,11 +149,12 @@ void sum_grad(const Tensor& x,
       axis_ = axis.GetData();
     }
     auto out_grad_ = unsqueeze<T>(out_grad, axis_);
-    x_grad_tmp = expand<T>(out_grad_, x_dim);
+    x_grad_tmp = expand<T>(out_grad_, IntArray(x_dim));
   } else {
-    x_grad_tmp = expand<T>(out_grad, x_dim);
+    x_grad_tmp = expand<T>(out_grad, IntArray(x_dim));
   }
-  set_output<T>(x_grad_tmp, x_grad);
+
+  set_output<T>(x_grad_tmp.impl(), x_grad);
 }
 
 template <typename T>
@@ -152,15 +171,19 @@ void divide_grad(const Tensor& x,
     auto tmp1 = divide<T>(x, tmp0);
     auto tmp2 = scale<T>(tmp1, -1.0, 0.0, true);
     auto dy_res = multiply<T>(tmp2, out_grad);
-    if (phi::product(x.dims()) > phi::product(y.dims())) {
+    if (x.dims() != y.dims()) {
       // Maybe need reduce here
-      phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims());
-      auto dy_reduce_res =
-          sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
-      auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
-      set_output<T>(dy_tmp, dy);
+      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dy_res.impl(), dy);
+      } else {
+        auto dy_reduce_res =
+            sum<T>(dy_res, phi::vectorize(reduce_dim), y.dtype(), false);
+        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
+        set_output<T>(dy_tmp.impl(), dy);
+      }
     } else {
-      set_output<T>(dy_res, dy);
+      set_output<T>(dy_res.impl(), dy);
     }
   }  // indicate we will compute dy
   if (dx) {
@@ -168,15 +191,20 @@ void divide_grad(const Tensor& x,
     auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
     auto tmp0 = divide<T>(one_tensor, y);
     auto dx_res = multiply<T>(tmp0, out_grad);
-    if (phi::product(y.dims()) > phi::product(x.dims())) {
+    if (y.dims() != x.dims()) {
       // Maybe need reduce here
-      auto reduce_dim = get_reduce_dims(y.dims(), x.dims());
-      auto dx_reduce_res =
-          sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
-      auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
-      set_output<T>(dx_tmp, dx);
+      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
+      if (!reduce_dim.size()) {
+        set_output<T>(dx_res.impl(), dx);
+      } else {
+        auto dx_reduce_res =
+            sum<T>(dx_res, phi::vectorize(reduce_dim), x.dtype(), false);
+        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
+        set_output<T>(dx_tmp.impl(), dx);
+      }
+
     } else {
-      set_output<T>(dx_res, dx);
+      set_output<T>(dx_res.impl(), dx);
     }
   }  // indicate we will compute dx
 }
@@ -187,8 +215,89 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
     auto div_x = full<T>(phi::vectorize(out.dims()), 0.5);
     auto tmp = divide<T>(div_x, out);
     auto x_grad_tmp = multiply<T>(out_grad, tmp);
-    set_output<T>(x_grad_tmp, x_grad);
+    set_output<T>(x_grad_tmp.impl(), x_grad);
   }
 }
+
+template <typename T>
+void multiply_grad(const Tensor& x,
+                   const Tensor& y,
+                   const Tensor& out_grad,
+                   int axis,
+                   Tensor* x_grad,
+                   Tensor* y_grad) {
+  if (x_grad) {
+    auto x_grad_unreduce = multiply<T>(out_grad, y);
+    if (x.dims() != y.dims()) {
+      auto axes = get_reduce_dims(x.dims(), y.dims());
+      if (!axes.size()) {
+        set_output<T>(x_grad_unreduce.impl(), x_grad);
+      } else {
+        auto x_grad_reduced = sum<T>(x_grad_unreduce,
+                                     phi::vectorize(axes),
+                                     x_grad_unreduce.dtype(),
+                                     false);
+        if (x_grad_reduced.dims().size() != x.dims().size()) {
+          x_grad_reduced = reshape<T>(x_grad_reduced, x.shape());
+        }
+        set_output<T>(x_grad_reduced.impl(), x_grad);
+      }
+    } else {
+      set_output<T>(x_grad_unreduce.impl(), x_grad);
+    }
+  }
+  if (y_grad) {
+    auto y_grad_unreduce = multiply<T>(out_grad, x);
+    if (y.dims() != x.dims()) {
+      auto axes = get_reduce_dims(y.dims(), x.dims());
+      if (!axes.size()) {
+        set_output<T>(y_grad_unreduce.impl(), y_grad);
+      } else {
+        auto y_grad_reduced = sum<T>(y_grad_unreduce,
+                                     phi::vectorize(axes),
+                                     y_grad_unreduce.dtype(),
+                                     false);
+        if (y_grad_reduced.dims().size() != y.dims().size()) {
+          y_grad_reduced = reshape<T>(y_grad_reduced, y.shape());
+        }
+        set_output<T>(y_grad_reduced.impl(), y_grad);
+      }
+    } else {
+      set_output<T>(y_grad_unreduce.impl(), y_grad);
+    }
+  }
+}
+
+template <typename T>
+void expand_grad(const Tensor& x,
+                 const Tensor& out_grad,
+                 const IntArray& shape,
+                 Tensor* x_grad) {
+  if (x_grad) {
+    auto out_dims = phi::make_ddim(shape.GetData());
+    if (out_dims != x.dims()) {
+      auto axes = get_reduce_dims(x.dims(), out_dims);
+      if (!axes.size()) {
+        by_pass<T>(out_grad, x_grad);
+      } else {
+        auto reduced = sum<T>(out_grad, phi::vectorize(axes), x.dtype(), false);
+        if (reduced.dims().size() != x.dims().size()) {
+          reduced = reshape<T>(reduced, x.shape());
+        }
+        set_output<T>(reduced.impl(), x_grad);
+      }
+    } else {
+      by_pass<T>(out_grad, x_grad);
+    }
+  }
+}
+
+template <typename T>
+void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    set_output<T>(multiply<T>(out_grad, out).impl(), x_grad);
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
-- 
GitLab
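Reviewer note, not part of the patch: every composite gradient touched here follows the same broadcasting rule. When an operand was broadcast up to the output shape, its gradient is the incoming gradient summed over the broadcast axes (which appears to be what get_reduce_dims supplies) and reshaped back to the operand's shape, with a by_pass/set_output fast path when no axis needs reducing. The standalone sketch below only illustrates that rule; broadcast_reduce_axes is a hypothetical helper for this note, not a Paddle API, and the exact contract of get_reduce_dims may differ.

// Minimal sketch of the reduce-axes rule used by the composite backward
// functions above (assumes NumPy-style broadcasting semantics).
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// For a gradient dz flowing back through a broadcast from `in_shape` to
// `out_shape`, sum dz over:
//   * the leading axes that exist only in out_shape, and
//   * every axis where in_shape has extent 1 but out_shape does not,
// then reshape the result back to in_shape.
std::vector<int64_t> broadcast_reduce_axes(const std::vector<int64_t>& in_shape,
                                           const std::vector<int64_t>& out_shape) {
  std::vector<int64_t> axes;
  const int64_t offset =
      static_cast<int64_t>(out_shape.size()) - static_cast<int64_t>(in_shape.size());
  assert(offset >= 0 && "out_shape must have rank >= in_shape");
  for (int64_t i = 0; i < static_cast<int64_t>(out_shape.size()); ++i) {
    if (i < offset) {
      axes.push_back(i);  // axis introduced by broadcasting
    } else if (in_shape[i - offset] == 1 && out_shape[i] != 1) {
      axes.push_back(i);  // axis expanded from extent 1
    }
  }
  return axes;
}

int main() {
  // y of shape [3, 1] broadcast against x of shape [2, 3, 4]:
  // dy = reshape(sum(dout, axes = {0, 2}), [3, 1])
  for (int64_t a : broadcast_reduce_axes({3, 1}, {2, 3, 4})) {
    std::cout << a << ' ';  // prints: 0 2
  }
  std::cout << '\n';

  // Equal shapes: nothing to reduce, matching the by_pass/set_output fast path.
  assert(broadcast_reduce_axes({2, 3, 4}, {2, 3, 4}).empty());
  return 0;
}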