From c642aa1747afd03d961074539d9e6207cace5e2f Mon Sep 17 00:00:00 2001
From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com>
Date: Fri, 2 Jun 2023 10:44:06 +0800
Subject: [PATCH] add_triple_grad rules (#54164)

---
 .../generator/eager_gen.py                    |   1 +
 .../elementwise/elementwise_add_op.cc         |  38 ++++
 .../composite_backward_api.h                  | 117 ------------
 .../composite_double_backward_api.h           | 170 ++++++++++++++++++
 paddle/phi/api/yaml/legacy_backward.yaml      |   1 +
 5 files changed, 210 insertions(+), 117 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
index 709372dd98e..ec765cd170e 100644
--- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -67,6 +67,7 @@ black_ops_list = [
 prim_white_list = [
     "matmul_double_grad",
     "subtract_double_grad",
+    "add_triple_grad",
     "silu_double_grad",
 ]
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
index 91ed63b3e2e..f8817faa791 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -154,6 +154,44 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class ElementwiseAddCompositeTripleGradOpMaker
+    : public prim::CompositeGradOpMakerBase {
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ public:
+  void Apply() override {
+    // get input
+    paddle::Tensor ddx = this->GetSingleForwardInput("DDX");
+    paddle::Tensor ddy = this->GetSingleForwardInput("DDY");
+    paddle::Tensor d_ddout = this->GetSingleOutputGrad("DDOut");
+
+    // get output
+    paddle::Tensor grad_grad_x_t =
+        this->GetSingleInputGrad(framework::GradVarName("DDX"));
+    paddle::Tensor grad_grad_y_t =
+        this->GetSingleInputGrad(framework::GradVarName("DDY"));
+    // get attr
+    int axis = static_cast<int>(this->Attr<int>("axis"));
+    PADDLE_ENFORCE_EQ(
+        axis,
+        -1,
+        phi::errors::InvalidArgument("We only support axis = -1 in composite "
+                                     "add_triple_grad but we got: ",
+                                     axis));
+
+    paddle::Tensor* grad_grad_x = this->GetOutputPtr(&grad_grad_x_t);
+    std::string grad_grad_x_name = this->GetOutputName(grad_grad_x_t);
+    paddle::Tensor* grad_grad_y = this->GetOutputPtr(&grad_grad_y_t);
+    std::string grad_grad_y_name = this->GetOutputName(grad_grad_y_t);
+
+    VLOG(6) << "Running add_triple_grad composite func";
+    prim::add_triple_grad<prim::DescTensor>(
+        ddx, ddy, d_ddout, axis, grad_grad_x, grad_grad_y);
+    this->RecoverOutputName(grad_grad_x_t, grad_grad_x_name);
+    this->RecoverOutputName(grad_grad_y_t, grad_grad_y_name);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 737ec0ce6e4..4d2e31ebd4f 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -234,30 +234,6 @@ void subtract_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void subtract_double_grad(const Tensor& y,
-                          const Tensor& grad_out,
-                          const paddle::optional<Tensor>& grad_x_grad,
-                          const paddle::optional<Tensor>& grad_y_grad,
-                          int axis,
-                          Tensor* grad_out_grad) {
-  if (grad_out_grad) {
-    // ddout = ddx - ddy
-    if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
-    } else {
-      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
-      if (grad_x_grad) {
-        ddout = ddout + grad_x_grad.get();
-      }
-      if (grad_y_grad) {
-        ddout = ddout - grad_y_grad.get();
-      }
-      set_output<T>(ddout, grad_out_grad);
-    }
-  }
-}
-
 template <typename T>
 void add_grad(const Tensor& x,
               const Tensor& y,
@@ -300,30 +276,6 @@ void add_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void add_double_grad(const Tensor& y,
-                     const Tensor& grad_out,
-                     const paddle::optional<Tensor>& grad_x_grad,
-                     const paddle::optional<Tensor>& grad_y_grad,
-                     int axis,
-                     Tensor* grad_out_grad) {
-  if (grad_out_grad) {
-    // ddout = ddx + ddy
-    if (!grad_x_grad && !grad_y_grad) {
-      grad_out_grad = nullptr;
-    } else {
-      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
-      if (grad_x_grad) {
-        ddout = ddout + grad_x_grad.get();
-      }
-      if (grad_y_grad) {
-        ddout = ddout + grad_y_grad.get();
-      }
-      set_output<T>(ddout, grad_out_grad);
-    }
-  }
-}
-
 template <typename T>
 void sum_grad(const Tensor& x,
               const Tensor& out_grad,
@@ -555,75 +507,6 @@ void multiply_grad(const Tensor& x,
   }
 }
 
-template <typename T>
-void multiply_double_grad(const Tensor& x,
-                          const Tensor& y,
-                          const Tensor& grad_out,
-                          const paddle::optional<Tensor>& grad_x_grad,
-                          const paddle::optional<Tensor>& grad_y_grad,
-                          int axis,
-                          Tensor* x_grad,
-                          Tensor* y_grad,
-                          Tensor* grad_out_grad) {
-  if (x_grad) {
-    if (grad_y_grad) {
-      auto dx = grad_y_grad.get() * grad_out;
-      if (dx.dims() != x.dims()) {
-        auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
-        if (!axes.size()) {
-          set_output<T>(dx, x_grad);
-        } else {
-          auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
-          if (dx_reduce.dims().size() != x.dims().size()) {
-            dx_reduce = reshape<T>(dx_reduce, x.shape());
-          }
-          set_output<T>(dx_reduce, x_grad);
-        }
-      } else {
-        set_output<T>(dx, x_grad);
-      }
-
-    } else {
-      x_grad = nullptr;
-    }
-  }
-  if (y_grad) {
-    if (grad_x_grad) {
-      auto dy = grad_x_grad.get() * grad_out;
-      if (dy.dims() != y.dims()) {
-        auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
-        if (!axes.size()) {
-          set_output<T>(dy, y_grad);
-        } else {
-          auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
-          if (dy_reduce.dims().size() != y.dims().size()) {
-            dy_reduce = reshape<T>(dy_reduce, y.shape());
-          }
-          set_output<T>(dy_reduce, y_grad);
-        }
-      } else {
-        set_output<T>(dy, y_grad);
-      }
-    } else {
-      y_grad = nullptr;
-    }
-  }
-  if (grad_out_grad) {
-    if (grad_x_grad && grad_y_grad) {
-      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
-    } else if (grad_x_grad) {
-      auto ddout = grad_x_grad.get() * y;
-      set_output<T>(ddout, grad_out_grad);
-    } else if (grad_y_grad) {
-      auto ddout = grad_y_grad.get() * x;
-      set_output<T>(ddout, grad_out_grad);
-    } else {
-      grad_out_grad = nullptr;
-    }
-  }
-}
-
 template <typename T>
 void expand_grad(const Tensor& x,
                  const Tensor& out_grad,
diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
index a97d4ff0bc0..a7479bc7166 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h
@@ -383,5 +383,175 @@ void silu_double_grad(const Tensor& x,
   }
 }
 
+template <typename T>
+void multiply_double_grad(const Tensor& x,
+                          const Tensor& y,
+                          const Tensor& grad_out,
+                          const paddle::optional<Tensor>& grad_x_grad,
+                          const paddle::optional<Tensor>& grad_y_grad,
+                          int axis,
+                          Tensor* x_grad,
+                          Tensor* y_grad,
+                          Tensor* grad_out_grad) {
+  if (x_grad) {
+    if (grad_y_grad) {
+      auto dx = grad_y_grad.get() * grad_out;
+      if (dx.dims() != x.dims()) {
+        auto axes = get_reduce_dims_from_out(dx.dims(), x.dims());
+        if (!axes.size()) {
+          set_output<T>(dx, x_grad);
+        } else {
+          auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false);
+          if (dx_reduce.dims().size() != x.dims().size()) {
+            dx_reduce = reshape<T>(dx_reduce, x.shape());
+          }
+          set_output<T>(dx_reduce, x_grad);
+        }
+      } else {
+        set_output<T>(dx, x_grad);
+      }
+
+    } else {
+      x_grad = nullptr;
+    }
+  }
+  if (y_grad) {
+    if (grad_x_grad) {
+      auto dy = grad_x_grad.get() * grad_out;
+      if (dy.dims() != y.dims()) {
+        auto axes = get_reduce_dims_from_out(dy.dims(), y.dims());
+        if (!axes.size()) {
+          set_output<T>(dy, y_grad);
+        } else {
+          auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), false);
+          if (dy_reduce.dims().size() != y.dims().size()) {
+            dy_reduce = reshape<T>(dy_reduce, y.shape());
+          }
+          set_output<T>(dy_reduce, y_grad);
+        }
+      } else {
+        set_output<T>(dy, y_grad);
+      }
+    } else {
+      y_grad = nullptr;
+    }
+  }
+  if (grad_out_grad) {
+    if (grad_x_grad && grad_y_grad) {
+      auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x;
+      set_output<T>(ddout, grad_out_grad);
+    } else if (grad_x_grad) {
+      auto ddout = grad_x_grad.get() * y;
+      set_output<T>(ddout, grad_out_grad);
+    } else if (grad_y_grad) {
+      auto ddout = grad_y_grad.get() * x;
+      set_output<T>(ddout, grad_out_grad);
+    } else {
+      grad_out_grad = nullptr;
+    }
+  }
+}
+
+template <typename T>
+void add_double_grad(const Tensor& y,
+                     const Tensor& grad_out,
+                     const paddle::optional<Tensor>& grad_x_grad,
+                     const paddle::optional<Tensor>& grad_y_grad,
+                     int axis,
+                     Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    // ddout = ddx + ddy
+    if (!grad_x_grad && !grad_y_grad) {
+      grad_out_grad = nullptr;
+    } else {
+      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
+      if (grad_x_grad) {
+        ddout = ddout + grad_x_grad.get();
+      }
+      if (grad_y_grad) {
+        ddout = ddout + grad_y_grad.get();
+      }
+      set_output<T>(ddout, grad_out_grad);
+    }
+  }
+}
+
+template <typename T>
+void add_triple_grad(const paddle::optional<Tensor>& grad_grad_x,
+                     const paddle::optional<Tensor>& grad_grad_y,
+                     const Tensor& grad_grad_out_grad,
+                     int axis,
+                     Tensor* grad_grad_x_grad,
+                     Tensor* grad_grad_y_grad) {
+  if (grad_grad_y_grad) {
+    if (grad_grad_y) {
+      if (grad_grad_y.get().dims() != grad_grad_out_grad.dims()) {
+        // Maybe need reduce here
+        phi::DDim reduce_dim = get_reduce_dims(grad_grad_y.get().dims(),
+                                               grad_grad_out_grad.dims());
+        if (!reduce_dim.size()) {
+          by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
+        } else {
+          auto dddy_reduce_res = grad_grad_out_grad.sum(
+              phi::vectorize(reduce_dim), grad_grad_y.get().dtype(), false);
+          auto dddy_tmp = reshape<T>(
+              dddy_reduce_res, phi::vectorize(grad_grad_y.get().dims()));
+          set_output<T>(dddy_tmp, grad_grad_y_grad);
+        }
+      } else {
+        by_pass<T>(grad_grad_out_grad, grad_grad_y_grad);
+      }
+    } else {
+      grad_grad_y_grad = nullptr;
+    }
+  }
+  if (grad_grad_x_grad) {
+    if (grad_grad_x) {
+      if (grad_grad_x.get().dims() != grad_grad_out_grad.dims()) {
+        // Maybe need reduce here
+        auto reduce_dim = get_reduce_dims(grad_grad_x.get().dims(),
+                                          grad_grad_out_grad.dims());
+        if (!reduce_dim.size()) {
+          by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
+        } else {
+          auto dddx_reduce_res = grad_grad_out_grad.sum(
+              phi::vectorize(reduce_dim), grad_grad_x.get().dtype(), false);
+          auto dddx_tmp = reshape<T>(
+              dddx_reduce_res, phi::vectorize(grad_grad_x.get().dims()));
+          set_output<T>(dddx_tmp, grad_grad_x_grad);
+        }
+      } else {
+        by_pass<T>(grad_grad_out_grad, grad_grad_x_grad);
+      }
+    } else {
+      grad_grad_x_grad = nullptr;
+    }
+  }
+}
+
+template <typename T>
+void subtract_double_grad(const Tensor& y,
+                          const Tensor& grad_out,
+                          const paddle::optional<Tensor>& grad_x_grad,
+                          const paddle::optional<Tensor>& grad_y_grad,
+                          int axis,
+                          Tensor* grad_out_grad) {
+  if (grad_out_grad) {
+    // ddout = ddx - ddy
+    if (!grad_x_grad && !grad_y_grad) {
+      grad_out_grad = nullptr;
+    } else {
+      Tensor ddout = full<T>(phi::vectorize(grad_out.dims()), 0.0, y.dtype());
+      if (grad_x_grad) {
+        ddout = ddout + grad_x_grad.get();
+      }
+      if (grad_y_grad) {
+        ddout = ddout - grad_y_grad.get();
+      }
+      set_output<T>(ddout, grad_out_grad);
+    }
+  }
+}
+
 }  // namespace prim
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index b312fa2658e..07547bbb897 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -36,6 +36,7 @@
   kernel :
     func : add_triple_grad
   inplace : (grad_grad_out_grad -> grad_grad_x_grad)
+  composite : add_triple_grad (grad_grad_x, grad_grad_y, grad_grad_out_grad, axis, grad_grad_x_grad, grad_grad_y_grad )
 
 - backward_op : amax_grad
   forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
-- 
GitLab
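
Note on the composite rule registered above: because d(x + y)/dx = d(x + y)/dy = 1, add_triple_grad simply forwards grad_grad_out_grad to grad_grad_x_grad and grad_grad_y_grad, summing (and reshaping) over any dimensions that were broadcast in the forward add. The following is a minimal NumPy sketch of those semantics, assuming NumPy-style broadcasting matches the elementwise add here; the function and variable names are illustrative only and not part of the patch.

import numpy as np

def reduce_to_shape(grad, shape):
    # Sum over dimensions introduced or expanded by broadcasting,
    # mirroring the get_reduce_dims + sum + reshape sequence in the rule.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    for i, s in enumerate(shape):
        if s == 1 and grad.shape[i] != 1:
            grad = grad.sum(axis=i, keepdims=True)
    return grad

def add_triple_grad_ref(ddx, ddy, d_ddout):
    # ddout = ddx + ddy, so the gradient w.r.t. ddx and ddy is d_ddout
    # itself, reduced back to each operand's shape when broadcasting occurred.
    d_ddx = reduce_to_shape(d_ddout, ddx.shape) if ddx is not None else None
    d_ddy = reduce_to_shape(d_ddout, ddy.shape) if ddy is not None else None
    return d_ddx, d_ddy

ddx = np.ones((3, 4))
ddy = np.ones((1, 4))          # broadcast along the first axis
d_ddout = np.random.rand(3, 4)
d_ddx, d_ddy = add_triple_grad_ref(ddx, ddy, d_ddout)
assert d_ddx.shape == ddx.shape and d_ddy.shape == ddy.shape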