From 05499c71355f0f83391d824870da909bba37988a Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 28 Apr 2023 09:19:06 +0800 Subject: [PATCH] 【Prim】comp_elementwise_double_grad (first part) (#53385) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add mul double grad * add sub_double_grad * add add sub high test * add multiply test * modify other unsqueeze * delete api.yaml * only for make ci run * modify unsqueeze * modify unsqueeze * tmp * modify operants gen --- cmake/external/cinn.cmake | 2 +- .../generator/eager_gen.py | 3 + .../elementwise/elementwise_add_op.cc | 39 +- .../elementwise/elementwise_mul_op.cc | 53 ++- .../elementwise/elementwise_sub_op.cc | 40 ++- paddle/fluid/prim/api/api.yaml | 1 - .../composite_backward_api.h | 143 +++++++- .../fluid/prim/api/manual_prim/utils/utils.h | 18 + paddle/phi/api/yaml/backward.yaml | 13 - paddle/phi/api/yaml/legacy_backward.yaml | 27 +- paddle/phi/api/yaml/op_compat.yaml | 2 +- paddle/phi/api/yaml/tensor_operants.yaml | 1 - python/paddle/fluid/backward.py | 12 +- .../unittests/test_activation_nn_grad.py | 8 + test/prim/prim/vjp/CMakeLists.txt | 2 + test/prim/prim/vjp/test_comp_high_grad.py | 334 ++++++++++++++++++ 16 files changed, 640 insertions(+), 58 deletions(-) create mode 100644 test/prim/prim/vjp/test_comp_high_grad.py diff --git a/cmake/external/cinn.cmake b/cmake/external/cinn.cmake index c41e094a25e..7d494ef516c 100644 --- a/cmake/external/cinn.cmake +++ b/cmake/external/cinn.cmake @@ -20,7 +20,7 @@ if(NOT CINN_GIT_TAG) set(CINN_GIT_TAG develop) endif() -message(STATUS "CINN version: " ${CINN_GIT_TAG}) +message(STATUS "CINN version: " ${CINN_GIT_TAG}) # TODO(zhhsplendid): CINN has lots of warnings during early development. # They will be treated as errors under paddle.
We set no-error now and we will diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index c5db4aee73d..f69b5944c56 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -67,6 +67,9 @@ black_ops_list = [ prim_white_list = [ "matmul_double_grad", "tanh_double_grad", + "add_double_grad", + "multiply_double_grad", + "subtract_double_grad", ] # dict of special api that forward api's output will affect bacward api's output diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index 872977b6161..0b522c1d6b3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -99,6 +99,42 @@ class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker { } }; +class ElementwiseAddCompositeDoubleGradOpMaker + : public prim::CompositeGradOpMakerBase { + using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; + + public: + void Apply() override { + // get input + paddle::Tensor y = this->GetSingleForwardInput("Y"); + paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::optional ddx = + this->GetOptionalSingleOutputGrad(framework::GradVarName("X")); + paddle::optional ddy = + this->GetOptionalSingleOutputGrad(framework::GradVarName("Y")); + // get output + paddle::Tensor grad_out_grad_t = + this->GetSingleInputGrad(framework::GradVarName("Out")); + + // get attr + int axis = static_cast(this->Attr("axis")); + PADDLE_ENFORCE_EQ( + axis, + -1, + phi::errors::InvalidArgument("We only support axis = -1 in composite " + "add_double_grad but we got: ", + axis)); + + paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t); + std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t); + + VLOG(6) << "Running add_double_grad composite func"; + prim::add_double_grad( + y, out_grad, ddx, ddy, axis, grad_out_grad); + this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name); + } +}; + template class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker { public: @@ -139,7 +175,8 @@ REGISTER_OPERATOR( ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer, ops::ElementwiseAddDoubleGradMaker, - ops::ElementwiseAddDoubleGradMaker); + ops::ElementwiseAddDoubleGradMaker, + ops::ElementwiseAddCompositeDoubleGradOpMaker); REGISTER_OPERATOR( elementwise_add_grad_grad, diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index c4a1060497e..0fc1cab7391 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -118,6 +118,56 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker { } }; +class ElementwiseMulCompositeDoubleGradOpMaker + : public prim::CompositeGradOpMakerBase { + using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; + + public: + void Apply() override { + // get input + paddle::Tensor x = this->GetSingleForwardInput("X"); + paddle::Tensor y = this->GetSingleForwardInput("Y"); + paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::optional ddx = + this->GetOptionalSingleOutputGrad(framework::GradVarName("X")); + paddle::optional ddy = + this->GetOptionalSingleOutputGrad(framework::GradVarName("Y")); + 
+ // get attr + int axis = static_cast(this->Attr("axis")); + PADDLE_ENFORCE_EQ( + axis, + -1, + phi::errors::InvalidArgument("We only support axis = -1 in composite " + "multiply_double_grad but we got: ", + axis)); + + // get output + paddle::Tensor x_grad_t = this->GetSingleInputGrad("X"); + paddle::Tensor y_grad_t = this->GetSingleInputGrad("Y"); + paddle::Tensor grad_out_grad_t = + this->GetSingleInputGrad(framework::GradVarName("Out")); + + // get output ptr + paddle::Tensor* x_grad = this->GetOutputPtr(&x_grad_t); + paddle::Tensor* y_grad = this->GetOutputPtr(&y_grad_t); + paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t); + // get output original name + std::string x_grad_name = this->GetOutputName(x_grad_t); + std::string y_grad_name = this->GetOutputName(y_grad_t); + std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t); + + VLOG(6) << "Running multiply_double_grad composite func"; + prim::multiply_double_grad( + x, y, out_grad, ddx, ddy, axis, x_grad, y_grad, grad_out_grad); + + // recover output name + this->RecoverOutputName(x_grad_t, x_grad_name); + this->RecoverOutputName(y_grad_t, y_grad_name); + this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name); + } +}; + template class ElementwiseMulTripleGradMaker : public framework::SingleGradOpMaker { public: @@ -162,7 +212,8 @@ REGISTER_OPERATOR( elementwise_mul_grad, ops::ElementwiseOpGrad, ops::ElementwiseMulDoubleGradMaker, - ops::ElementwiseMulDoubleGradMaker); + ops::ElementwiseMulDoubleGradMaker, + ops::ElementwiseMulCompositeDoubleGradOpMaker); REGISTER_OPERATOR( elementwise_mul_grad_grad, diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index f1df6442006..0c8c9bc6e80 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -102,6 +102,42 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker { } }; +class ElementwiseSubCompositeDoubleGradOpMaker + : public prim::CompositeGradOpMakerBase { + using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase; + + public: + void Apply() override { + // get input + paddle::Tensor y = this->GetSingleForwardInput("Y"); + paddle::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::optional ddx = + this->GetOptionalSingleOutputGrad(framework::GradVarName("X")); + paddle::optional ddy = + this->GetOptionalSingleOutputGrad(framework::GradVarName("Y")); + // get output + paddle::Tensor grad_out_grad_t = + this->GetSingleInputGrad(framework::GradVarName("Out")); + + // get attr + int axis = static_cast(this->Attr("axis")); + PADDLE_ENFORCE_EQ( + axis, + -1, + phi::errors::InvalidArgument("We only support axis = -1 in composite " + "subtract_double_grad but we got: ", + axis)); + + paddle::Tensor* grad_out_grad = this->GetOutputPtr(&grad_out_grad_t); + std::string grad_out_grad_name = this->GetOutputName(grad_out_grad_t); + + VLOG(6) << "Running subtract_double_grad composite func"; + prim::subtract_double_grad( + y, out_grad, ddx, ddy, axis, grad_out_grad); + this->RecoverOutputName(grad_out_grad_t, grad_out_grad_name); + } +}; + } // namespace operators } // namespace paddle @@ -124,7 +160,9 @@ REGISTER_OPERATOR( ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer, ops::ElementwiseSubDoubleGradMaker, - ops::ElementwiseSubDoubleGradMaker); + ops::ElementwiseSubDoubleGradMaker, + ops::ElementwiseSubCompositeDoubleGradOpMaker); + 
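Note: the three composite double-grad makers above only unpack the operator's inputs, outputs and the axis attribute, then delegate to the prim rules added later in this patch in composite_backward_api.h. For add and subtract the rule is simply ddout = ddx + ddy and ddout = ddx - ddy, accumulated onto a zero tensor so that either tangent may be absent. A minimal NumPy sketch of that behaviour (illustrative only, not part of the patch; assumes same-shape operands and axis == -1):

import numpy as np

def elementwise_double_grad_ddout(grad_out, ddx=None, ddy=None, sign=1.0):
    # sign = +1.0 models add_double_grad, sign = -1.0 models subtract_double_grad.
    # None mirrors a missing optional tangent (grad_x_grad / grad_y_grad).
    if ddx is None and ddy is None:
        return None  # the composite rule leaves grad_out_grad unset in this case
    ddout = np.zeros_like(grad_out)
    if ddx is not None:
        ddout = ddout + ddx
    if ddy is not None:
        ddout = ddout + sign * ddy
    return ddout

grad_out = np.ones((2, 3))
ddx = np.random.rand(2, 3)
print(elementwise_double_grad_ddout(grad_out, ddx=ddx))             # add: ddx
print(elementwise_double_grad_ddout(grad_out, ddy=ddx, sign=-1.0))  # subtract: -ddy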
REGISTER_OPERATOR(elementwise_sub_grad_grad, ops::ElementwiseOpDoubleGradWithoutDXDY, ops::ElementwiseDoubleGradOpInplaceInferer, diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml index c626d39cf7b..ec3bd574137 100644 --- a/paddle/fluid/prim/api/api.yaml +++ b/paddle/fluid/prim/api/api.yaml @@ -12,7 +12,6 @@ - bitwise_not - bitwise_or - bitwise_xor -- unsqueeze - exp - scale - matmul diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 4784f2fb617..be825f52362 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -249,6 +249,30 @@ void subtract_grad(const Tensor& x, } } +template +void subtract_double_grad(const Tensor& y, + const Tensor& grad_out, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + int axis, + Tensor* grad_out_grad) { + if (grad_out_grad) { + // ddout = ddx - ddy + if (!grad_x_grad && !grad_y_grad) { + grad_out_grad = nullptr; + } else { + Tensor ddout = full(phi::vectorize(grad_out.dims()), 0.0, y.dtype()); + if (grad_x_grad) { + ddout = ddout + grad_x_grad.get(); + } + if (grad_y_grad) { + ddout = ddout - grad_y_grad.get(); + } + set_output(ddout, grad_out_grad); + } + } +} + template void add_grad(const Tensor& x, const Tensor& y, @@ -291,6 +315,30 @@ void add_grad(const Tensor& x, } } +template +void add_double_grad(const Tensor& y, + const Tensor& grad_out, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + int axis, + Tensor* grad_out_grad) { + if (grad_out_grad) { + // ddout = ddx + ddy + if (!grad_x_grad && !grad_y_grad) { + grad_out_grad = nullptr; + } else { + Tensor ddout = full(phi::vectorize(grad_out.dims()), 0.0, y.dtype()); + if (grad_x_grad) { + ddout = ddout + grad_x_grad.get(); + } + if (grad_y_grad) { + ddout = ddout + grad_y_grad.get(); + } + set_output(ddout, grad_out_grad); + } + } +} + template void sum_grad(const Tensor& x, const Tensor& out_grad, @@ -328,7 +376,8 @@ void sum_grad(const Tensor& x, } } } - auto out_grad_ = unsqueeze(out_grad, axis_); + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_ = reshape(out_grad, out_grad_shape); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); @@ -521,6 +570,75 @@ void multiply_grad(const Tensor& x, } } +template +void multiply_double_grad(const Tensor& x, + const Tensor& y, + const Tensor& grad_out, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + int axis, + Tensor* x_grad, + Tensor* y_grad, + Tensor* grad_out_grad) { + if (x_grad) { + if (grad_y_grad) { + auto dx = grad_y_grad.get() * grad_out; + if (dx.dims() != x.dims()) { + auto axes = get_reduce_dims_from_out(dx.dims(), x.dims()); + if (!axes.size()) { + set_output(dx, x_grad); + } else { + auto dx_reduce = dx.sum(phi::vectorize(axes), dx.dtype(), false); + if (dx_reduce.dims().size() != x.dims().size()) { + dx_reduce = reshape(dx_reduce, x.shape()); + } + set_output(dx_reduce, x_grad); + } + } else { + set_output(dx, x_grad); + } + + } else { + x_grad = nullptr; + } + } + if (y_grad) { + if (grad_x_grad) { + auto dy = grad_x_grad.get() * grad_out; + if (dy.dims() != y.dims()) { + auto axes = get_reduce_dims_from_out(dy.dims(), y.dims()); + if (!axes.size()) { + set_output(dy, y_grad); + } else { + auto dy_reduce = dy.sum(phi::vectorize(axes), dy.dtype(), 
false); + if (dy_reduce.dims().size() != y.dims().size()) { + dy_reduce = reshape(dy_reduce, y.shape()); + } + set_output(dy_reduce, y_grad); + } + } else { + set_output(dy, y_grad); + } + } else { + y_grad = nullptr; + } + } + if (grad_out_grad) { + if (grad_x_grad && grad_y_grad) { + auto ddout = grad_x_grad.get() * y + grad_y_grad.get() * x; + set_output(ddout, grad_out_grad); + } else if (grad_x_grad) { + auto ddout = grad_x_grad.get() * y; + set_output(ddout, grad_out_grad); + } else if (grad_y_grad) { + auto ddout = grad_y_grad.get() * x; + set_output(ddout, grad_out_grad); + } else { + grad_out_grad = nullptr; + } + } +} + template void expand_grad(const Tensor& x, const Tensor& out_grad, @@ -1063,9 +1181,11 @@ void group_norm_grad(const Tensor& x, auto p2 = (d2 * mean - d1) * (inv_std_mul_s * inv_std * inv_std); auto p3 = -p2 * mean - d2 * inv_std_mul_s; - p1 = unsqueeze(p1, std::vector({3})); - p2 = unsqueeze(p2, std::vector({2, 3})); - p3 = unsqueeze(p3, std::vector({2, 3})); + auto first_shape = get_unsqueeze_dims(p1, std::vector({3})); + auto second_shape = get_unsqueeze_dims(p2, std::vector({2, 3})); + p1 = reshape(p1, first_shape); + p2 = reshape(p2, second_shape); + p3 = reshape(p3, second_shape); auto tmp_1 = reshape(out_grad_data, whole_group_shape) * p1; auto tmp_2 = reshape(x_data, whole_group_shape) * p2 + p3; auto x_grad_data = tmp_1 + tmp_2; @@ -1078,10 +1198,11 @@ void group_norm_grad(const Tensor& x, } if (scale_grad) { if (scale_ptr) { + auto third_shape = get_unsqueeze_dims(mean, std::vector({2})); auto tmp1 = (reshape(sum_y_grad_mul_x, shape_group) - reshape(sum_y_grad, shape_group) * - unsqueeze(mean, std::vector({2}))) * - unsqueeze(inv_std, std::vector({2})); + reshape(mean, third_shape)) * + reshape(inv_std, third_shape); auto scale_grad_tmp = reshape(tmp1.sum(std::vector({0}), dtype, false), IntArray(std::vector({C}))); @@ -1291,9 +1412,10 @@ void prod_grad(const Tensor& x, } } } - auto out_grad_ = unsqueeze(out_grad, axis_); + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_ = reshape(out_grad, out_grad_shape); x_grad_tmp = out_grad_.expand(IntArray(x_dim)); - auto out_ = unsqueeze(out, axis_); + auto out_ = reshape(out, out_grad_shape); out_tmp = out_.expand(IntArray(x_dim)); } else { x_grad_tmp = out_grad.expand(IntArray(x_dim)); @@ -1346,8 +1468,9 @@ void max_grad(const Tensor& x, } } } - auto out_grad_ = unsqueeze(out_grad, axis_); - auto out_ = unsqueeze(out, axis_); + auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_); + auto out_grad_ = reshape(out_grad, out_grad_shape); + auto out_ = reshape(out, out_grad_shape); auto out_grad_tmp = out_grad_.expand(IntArray(x_dim)); auto out_tmp = out_.expand(IntArray(x_dim)); auto mask = equal(x, out_tmp); diff --git a/paddle/fluid/prim/api/manual_prim/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h index 338960d4484..d72da6461da 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/utils.h +++ b/paddle/fluid/prim/api/manual_prim/utils/utils.h @@ -114,5 +114,23 @@ static std::vector unsafe_vector_cast(const std::vector& src) { return dst; } +// This function computes unsqueeze dims for reshape to replace unsqueeze. 
+static std::vector get_unsqueeze_dims(const Tensor& origin, + const IntArray& axis) { + auto origin_dims = origin.shape(); + auto total_shape_size = origin_dims.size() + axis.size(); + std::vector result; + int j = 0, k = 0; + for (size_t i = 0; i < total_shape_size; ++i) { + if (axis[j] == int64_t(i)) { + result.push_back(1); + j++; + } else { + result.push_back(origin_dims[k]); + k++; + } + } + return result; +} } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index a2189b7084b..97a0968f171 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1865,7 +1865,6 @@ kernel : func : tanh_double_grad composite : tanh_double_grad(out, grad_out, grad_x_grad, out_grad, grad_out_grad) - backward : tanh_triple_grad inplace : (grad_x_grad -> grad_out_grad) - backward_op : tanh_grad @@ -1892,18 +1891,6 @@ func : tanh_shrink_grad inplace : (out_grad -> x_grad) -- backward_op : tanh_triple_grad - forward : tanh_double_grad (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward) -> Tensor(grad_out_new), Tensor(grad_out_grad) - args : (Tensor out, Tensor grad_out_forward, Tensor grad_x_grad_forward, Tensor grad_out_new_grad, Tensor grad_out_grad_grad) - output : Tensor(out_grad), Tensor(grad_out_forward_grad), Tensor(grad_x_grad_forward_grad) - infer_meta : - func : GeneralTernaryGradInferMeta - param : [out, out, grad_x_grad_forward] - kernel : - func : tanh_triple_grad - inplace : (grad_x_grad_forward -> grad_out_forward_grad) - optional : grad_out_new_grad, grad_out_grad_grad - - backward_op : temporal_shift_grad forward : temporal_shift(Tensor x, int seg_num, float shift_ratio = 0.25f, str data_format = "NCHW") -> Tensor(out) args : (Tensor out_grad, int seg_num, float shift_ratio, str data_format) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 30bfa1f3847..5f971d98853 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -30,8 +30,8 @@ kernel : func : add_double_grad optional : grad_x_grad, grad_y_grad - backward : add_triple_grad inplace : (grad_x_grad -> grad_out_grad) + composite : add_double_grad(y, grad_out, grad_x_grad, grad_y_grad, axis, grad_out_grad) - backward_op : add_grad forward : add (Tensor x, Tensor y) -> Tensor(out) @@ -47,17 +47,6 @@ backward : add_double_grad inplace : (out_grad -> x_grad) -- backward_op : add_triple_grad - forward : add_double_grad (Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, int axis = -1) -> Tensor(grad_grad_out) - args : (Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_grad_out_grad, int axis = -1) - output : Tensor(grad_grad_x_grad), Tensor(grad_grad_y_grad) - infer_meta : - func : GeneralBinaryGradInferMeta - param : [grad_grad_x, grad_grad_y] - kernel : - func : add_triple_grad - inplace : (grad_grad_out_grad -> grad_grad_x_grad) - - backward_op : amax_grad forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out) args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis={}, bool keepdim=false, bool reduce_all=false) @@ -627,8 +616,8 @@ kernel : func : multiply_double_grad optional : grad_x_grad, grad_y_grad - backward : multiply_triple_grad inplace : (grad_x_grad -> grad_out_grad) + composite : multiply_double_grad(x, y, grad_out, grad_x_grad, grad_y_grad, axis, x_grad, y_grad, grad_out_grad) - backward_op : multiply_grad forward : multiply (Tensor x, Tensor y) -> Tensor(out) @@ 
-642,17 +631,6 @@ composite: multiply_grad(x, y, out_grad, axis, x_grad, y_grad) backward : multiply_double_grad -- backward_op : multiply_triple_grad - forward : multiply_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, int aixs = -1) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out) - args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, int axis = -1) - output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad) - infer_meta : - func : GeneralQuinaryGradInferMeta - param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y] - kernel : - func : multiply_triple_grad - optional : fwd_grad_grad_x, fwd_grad_grad_y, grad_x_grad, grad_y_grad, grad_grad_out_grad - - backward_op : norm_grad forward : norm (Tensor x, int axis, float epsilon, bool is_test) -> Tensor(out), Tensor(norm) args : (Tensor x, Tensor norm, Tensor out_grad, int axis, float epsilon, bool is_test) @@ -940,6 +918,7 @@ optional : grad_x_grad, grad_y_grad no_need_buffer : y, grad_out inplace : (grad_x_grad -> grad_out_grad) + composite : subtract_double_grad(y, grad_out, grad_x_grad, grad_y_grad, axis, grad_out_grad) - backward_op : subtract_grad forward : subtract (Tensor x, Tensor y) -> Tensor(out) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 46245aa4c92..c3d326eafbf 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -2274,7 +2274,7 @@ attrs : [bool use_mkldnn = false, bool use_cudnn = false] - op : tanh - backward : tanh_grad, tanh_double_grad (tanh_grad_grad), tanh_triple_grad + backward : tanh_grad, tanh_double_grad (tanh_grad_grad) inputs : x : X outputs : diff --git a/paddle/phi/api/yaml/tensor_operants.yaml b/paddle/phi/api/yaml/tensor_operants.yaml index 8c0b59fcdc4..c29cf1a933f 100644 --- a/paddle/phi/api/yaml/tensor_operants.yaml +++ b/paddle/phi/api/yaml/tensor_operants.yaml @@ -14,7 +14,6 @@ - bitwise_not - bitwise_or - bitwise_xor -- unsqueeze - exp - scale - matmul diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index fa7ae57ff9d..e6c01f83685 100755 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -1492,17 +1492,21 @@ def _append_backward_ops_( or name in input_grad_names_set ) is_append_grad = False + input_grad_names = [] for op_desc in grad_op_desc: - input_grad_names = [ + input_grad_names += [ name for name in op_desc.input_arg_names() if is_grad_name(name) ] + if len(input_grad_names) == 0: + is_append_grad = True + break + + for op_desc in grad_op_desc: + # some code of gradient ops, like increment, are not very # standard, there is no @GRAD in these ops' inputs. 
- if len(input_grad_names) == 0: - is_append_grad = True - break if _some_in_set_(input_grad_names, input_grad_names_set): grad_op_descs.append(op_desc) diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py index 882d5446540..cddeb3daeed 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py @@ -92,12 +92,16 @@ class TestTanhTripleGradCheck(unittest.TestCase): y = paddle.tanh(x) x_arr = np.random.random(shape).astype(dtype) x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) gradient_checker.triple_grad_check( [x], y, x_init=x_arr, place=place, eps=eps ) gradient_checker.triple_grad_check_for_dygraph( self.tanh_wrapper, [x], y, x_init=x_arr, place=place ) + core._set_prim_backward_enabled(False) def test_grad(self): paddle.enable_static() @@ -122,12 +126,16 @@ class TestTanhDoubleGradCheck(unittest.TestCase): y = paddle.tanh(x) x_arr = np.random.uniform(-1, 1, shape).astype(dtype) x_arr[np.abs(x_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) gradient_checker.double_grad_check( [x], y, x_init=x_arr, place=place, eps=eps ) gradient_checker.double_grad_check_for_dygraph( self.tanh_wrapper, [x], y, x_init=x_arr, place=place ) + core._set_prim_backward_enabled(False) def test_grad(self): paddle.enable_static() diff --git a/test/prim/prim/vjp/CMakeLists.txt b/test/prim/prim/vjp/CMakeLists.txt index d71096db0a1..c7cae170629 100644 --- a/test/prim/prim/vjp/CMakeLists.txt +++ b/test/prim/prim/vjp/CMakeLists.txt @@ -8,5 +8,7 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() +set_tests_properties(test_comp_high_grad PROPERTIES TIMEOUT 50) + add_subdirectory(eager) add_subdirectory(static) diff --git a/test/prim/prim/vjp/test_comp_high_grad.py b/test/prim/prim/vjp/test_comp_high_grad.py new file mode 100644 index 00000000000..77fa38684d9 --- /dev/null +++ b/test/prim/prim/vjp/test_comp_high_grad.py @@ -0,0 +1,334 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import sys +import unittest + +sys.path.append("../../../../python/paddle/fluid/tests/unittests") + +import gradient_checker +import numpy as np +import parameterized as param +from decorator_helper import prog_scope + +import paddle +from paddle import fluid +from paddle.fluid import core + + +@param.parameterized_class( + ('shape1', 'shape2'), + [ + ( + [2, 3, 4], + [2, 3, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 1], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 4], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 1], + ), + ], +) +class TestAddHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + cls.shape2 = cls.shape2 + + def add_wrapper(self, x): + return paddle.add(x[0], x[1]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.add(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-2, 2, shape2).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.add_wrapper, [x, y], y=out, x_init=[x_arr, y_arr], place=place + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.add(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-1, 1, shape2).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.add_wrapper, [x, y], y=out, x_init=[x_arr, y_arr], place=place + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + +@param.parameterized_class( + ('shape1', 'shape2'), + [ + ( + [2, 3, 4], + [2, 3, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 1], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 4], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 1], + ), + ], +) +class TestSubtractHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + cls.shape2 = cls.shape2 + + def subtract_wrapper(self, x): + return paddle.subtract(x[0], x[1]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.subtract(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-2, 2, shape2).astype(dtype) + 
x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.subtract_wrapper, + [x, y], + y=out, + x_init=[x_arr, y_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.subtract(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-2, 2, shape2).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + gradient_checker.triple_grad_check_for_dygraph( + self.subtract_wrapper, + [x, y], + y=out, + x_init=[x_arr, y_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + +@param.parameterized_class( + ('shape1', 'shape2'), + [ + ( + [2, 3, 4], + [2, 3, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 4], + ), + ( + [2, 3, 3, 4], + [3, 1, 1], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 4], + ), + ( + [2, 3, 3, 4], + [2, 3, 1, 1], + ), + ], +) +class TestMultiplyHighGradCheck(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.shape1 = cls.shape1 + cls.shape2 = cls.shape2 + + def multiply_wrapper(self, x): + return paddle.multiply(x[0], x[1]) + + @prog_scope() + def func_double(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.multiply(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-2, 2, shape2).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.double_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + gradient_checker.double_grad_check_for_dygraph( + self.multiply_wrapper, + [x, y], + y=out, + x_init=[x_arr, y_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + @prog_scope() + def func_triple(self, place): + shape1 = self.shape1 + shape2 = self.shape2 + eps = 0.0005 + dtype = np.float64 + x = paddle.static.data('x', shape1, dtype=dtype) + y = paddle.static.data('y', shape2, dtype=dtype) + x.persistable = True + y.persistable = True + out = paddle.multiply(x, y) + x_arr = np.random.uniform(-1, 1, shape1).astype(dtype) + y_arr = np.random.uniform(-1, 1, shape2).astype(dtype) + x_arr[np.abs(x_arr) < 0.005] = 0.002 + y_arr[np.abs(y_arr) < 0.005] = 0.002 + from paddle.fluid import core + + core._set_prim_backward_enabled(True) + gradient_checker.triple_grad_check( + [x, y], y=out, x_init=[x_arr, y_arr], place=place, eps=eps + ) + 
gradient_checker.triple_grad_check_for_dygraph( + self.multiply_wrapper, + [x, y], + y=out, + x_init=[x_arr, y_arr], + place=place, + ) + core._set_prim_backward_enabled(False) + + def test_high_grad(self): + paddle.enable_static() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func_double(p) + self.func_triple(p) + + +if __name__ == '__main__': + unittest.main() -- GitLab
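The closed forms implemented by prim::multiply_double_grad above are x_grad = ddy * dout and y_grad = ddx * dout (each reduced back to the operand's shape when broadcasting was involved), and ddout = ddx * y + ddy * x. A small NumPy sketch that sanity-checks the ddout rule against a directional derivative of the first-order backward map (illustrative only, not part of the patch; same-shape operands, axis == -1):

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.standard_normal((2, 3)), rng.standard_normal((2, 3))
dout = rng.standard_normal((2, 3))
ddx, ddy = rng.standard_normal((2, 3)), rng.standard_normal((2, 3))

def multiply_grad(x, y, dout):
    # first-order rule for out = x * y
    return dout * y, dout * x

# closed forms used by the composite multiply_double_grad
x_grad = ddy * dout
y_grad = ddx * dout
ddout = ddx * y + ddy * x

# check ddout: derivative of sum(ddx*dx + ddy*dy) w.r.t. dout in direction v
v = rng.standard_normal((2, 3))
eps = 1e-6

def pairing(d):
    dx, dy = multiply_grad(x, y, d)
    return np.sum(ddx * dx) + np.sum(ddy * dy)

fd = (pairing(dout + eps * v) - pairing(dout - eps * v)) / (2 * eps)
assert np.allclose(fd, np.sum(ddout * v))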