From 39c6765ad62cfc0d1a3b0cd65a877c56aac967e8 Mon Sep 17 00:00:00 2001 From: Xiaoxu Chen Date: Tue, 17 Jan 2023 10:16:17 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Add=20multiply,expand,div?= =?UTF-8?q?=20vjp=20rules=20(#49831)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support elementwise base func * fix compiling error and add test * support vjp for div using comp * remove additional change * fix dy2st error with magic num * fix dy magic num * another magic * another magic * another magic * add skip rename strategy * support add vjp * support add with new axis cal * support sub vjp * [prim] add multiply vjp rules * [prim] add multiply vjp rules * [prim] fix no infershape with composite in _append_backward_ops * [prim] add expand vjp rule * [prim] add exp vjp rule * uncomment infer shape for reshape/sum static prim api * [prim] fix tanh nullptr error * remove some print message * fix magic number in run_program relative tests @JiaBinYang * [prim] add expand,multiply,exp vjp rules * fix only support single direction reduce error * infer reduce dims using out dims Co-authored-by: JiabinYang <360788950@qq.com> --- .../elementwise/elementwise_mul_op.cc | 33 +++- paddle/fluid/operators/expand_v2_op.cc | 21 ++ .../manual/backward/composite_backward_api.h | 186 ++++++++++++++---- .../api/manual/prim_api/eager_prim_api.cc | 10 + .../fluid/prim/api/manual/prim_api/prim_api.h | 6 + .../api/manual/prim_api/static_prim_api.cc | 20 +- paddle/fluid/prim/api/manual/utils/utils.h | 48 ++--- paddle/phi/api/yaml/backward.yaml | 1 + paddle/phi/api/yaml/legacy_backward.yaml | 2 + .../prim/prim/vjp/eager/CMakeLists.txt | 7 - .../vjp/eager/test_comp_eager_exp_grad.py | 77 ++++++++ .../vjp/eager/test_comp_eager_expand_grad.py | 93 +++++++++ .../eager/test_comp_eager_multiply_grad.py | 100 ++++++++++ .../prim/vjp/static/test_comp_add_grad.py | 2 +- .../prim/vjp/static/test_comp_exp_grad.py | 122 ++++++++++++ .../prim/vjp/static/test_comp_expand_grad.py | 112 +++++++++++ .../vjp/static/test_comp_multiply_grad.py | 128 ++++++++++++ .../prim/vjp/static/test_comp_sub_grad.py | 5 +- 18 files changed, 897 insertions(+), 76 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 5048a40ddde..457ea83c0d6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -19,6 +19,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" +#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { namespace operators { @@ -63,6 +66,33 @@ class ElementwiseMulOpGradMaker : public framework::SingleGradOpMaker { } }; +class ElementwiseMulGradCompositeOpMaker + : public prim::GradCompositeOpMakerBase { + using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; + + public: + void Apply() override { + auto x = this->GetSingleForwardInput("X"); + auto y = this->GetSingleForwardInput("Y"); + auto out_grad = this->GetSingleOutputGrad("Out"); + auto x_grad = this->GetSingleInputGrad("X"); + auto x_grad_p = this->GetOutputPtr(&x_grad); + auto x_grad_name = this->GetOutputName(x_grad); + auto y_grad = this->GetSingleInputGrad("Y"); + auto y_grad_p = this->GetOutputPtr(&y_grad); + auto y_grad_name = this->GetOutputName(y_grad); + prim::multiply_grad( + x, + y, + out_grad, + static_cast(this->Attr("axis")), + x_grad_p, + y_grad_p); + this->RecoverOutputName(x_grad, x_grad_name); + this->RecoverOutputName(y_grad, y_grad_name); + } +}; + template class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker { public: @@ -123,7 +153,8 @@ REGISTER_OPERATOR(elementwise_mul, ops::ElementwiseMulOpMaker, ops::ElementwiseOpInferVarType, ops::ElementwiseMulOpGradMaker, - ops::ElementwiseMulOpGradMaker); + ops::ElementwiseMulOpGradMaker, + ops::ElementwiseMulGradCompositeOpMaker); REGISTER_OPERATOR( elementwise_mul_grad, ops::ElementwiseOpGrad, diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index cbd322f3876..7b24c31ff28 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -20,6 +20,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" @@ -190,6 +193,23 @@ class ExpandV2GradOpMaker : public framework::SingleGradOpMaker { } }; +class ExpandV2GradCompositeOpMaker : public prim::GradCompositeOpMakerBase { + using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; + + public: + void Apply() override { + auto x = this->GetSingleForwardInput("X"); + auto out_grad = this->GetSingleOutputGrad("Out"); + auto x_grad = this->GetSingleInputGrad("X"); + auto x_grad_p = this->GetOutputPtr(&x_grad); + auto x_grad_name = this->GetOutputName(x_grad); + auto shape = this->Attr>("shape"); + prim::expand_grad( + x, out_grad, paddle::experimental::IntArray(shape), x_grad_p); + this->RecoverOutputName(x_grad, x_grad_name); + } +}; + template class ExpandV2DoubleGradOpMaker : public framework::SingleGradOpMaker { public: @@ -223,6 +243,7 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(expand_v2, ops::ExpandV2Op, ops::ExpandV2OpMaker, + ops::ExpandV2GradCompositeOpMaker, ops::ExpandV2GradOpMaker, ops::ExpandV2GradOpMaker, ExpandInferShapeFunctor); diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h index 19898e0c562..4ededb74f38 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h @@ -23,16 +23,17 @@ namespace prim { using Tensor = paddle::experimental::Tensor; using IntArray = paddle::experimental::IntArrayBase; -// using IntArray = paddle::experimental::IntArray; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h template void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { + if (!grad_x) return; auto tmp = pow(out, 2.0); tmp = scale(tmp, -1.0, 1.0, true); auto grad_x_tmp = multiply(grad_out, tmp); grad_x->set_impl(grad_x_tmp.impl()); } + template void subtract_grad(const Tensor& x, const Tensor& y, @@ -42,25 +43,33 @@ void subtract_grad(const Tensor& x, Tensor* dy) { if (dy) { auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); - if (phi::product(x.dims()) > phi::product(y.dims())) { + if (x.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); - auto dy_reduce_res = - sum(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false); - auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + by_pass(scale_out_grad, dy); + } else { + auto dy_reduce_res = sum( + scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } } else { by_pass(scale_out_grad, dy); } } if (dx) { - if (phi::product(y.dims()) > phi::product(x.dims())) { + if (y.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); - auto dx_reduce_res = - sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); - auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - 
dx->set_impl(dx_tmp.impl()); + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + by_pass(out_grad, dx); + } else { + auto dx_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } } else { by_pass(out_grad, dx); } @@ -75,25 +84,34 @@ void add_grad(const Tensor& x, Tensor* dx, Tensor* dy) { if (dy) { - if (phi::product(x.dims()) > phi::product(y.dims())) { + if (x.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); - auto dy_reduce_res = - sum(out_grad, phi::vectorize(reduce_dim), y.dtype(), false); - auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + by_pass(out_grad, dy); + } else { + auto dy_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } + } else { by_pass(out_grad, dy); } } if (dx) { - if (phi::product(y.dims()) > phi::product(x.dims())) { + if (y.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); - auto dx_reduce_res = - sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); - auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - dx->set_impl(dx_tmp.impl()); + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + by_pass(out_grad, dx); + } else { + auto dx_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } } else { by_pass(out_grad, dx); } @@ -130,9 +148,9 @@ void sum_grad(const Tensor& x, axis_ = axis.GetData(); } auto out_grad_ = unsqueeze(out_grad, axis_); - x_grad_tmp = expand(out_grad_, x_dim); + x_grad_tmp = expand(out_grad_, IntArray(x_dim)); } else { - x_grad_tmp = expand(out_grad, x_dim); + x_grad_tmp = expand(out_grad, IntArray(x_dim)); } x_grad->set_impl(x_grad_tmp.impl()); @@ -152,13 +170,17 @@ void divide_grad(const Tensor& x, auto tmp1 = divide(x, tmp0); auto tmp2 = scale(tmp1, -1.0, 0.0, true); auto dy_res = multiply(tmp2, out_grad); - if (phi::product(x.dims()) > phi::product(y.dims())) { + if (x.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); - auto dy_reduce_res = - sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); - auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); - dy->set_impl(dy_tmp.impl()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + if (!reduce_dim.size()) { + dy->set_impl(dy_res.impl()); + } else { + auto dy_reduce_res = + sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } } else { dy->set_impl(dy_res.impl()); } @@ -168,13 +190,18 @@ void divide_grad(const Tensor& x, auto one_tensor = full(phi::vectorize(y.dims()), 1.0); auto tmp0 = divide(one_tensor, y); auto dx_res = multiply(tmp0, out_grad); - if (phi::product(y.dims()) > phi::product(x.dims())) { + if (y.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); - auto dx_reduce_res = - sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); - auto 
dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); - dx->set_impl(dx_tmp.impl()); + auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + if (!reduce_dim.size()) { + dx->set_impl(dx_res.impl()); + } else { + auto dx_reduce_res = + sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } + } else { dx->set_impl(dx_res.impl()); } @@ -190,5 +217,86 @@ void sqrt_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { x_grad->set_impl(x_grad_tmp.impl()); } } + +template +void multiply_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis, + Tensor* x_grad, + Tensor* y_grad) { + if (x_grad) { + auto x_grad_unreduce = multiply(out_grad, y); + if (x.dims() != y.dims()) { + auto axes = get_reduce_dims(x.dims(), y.dims()); + if (!axes.size()) { + x_grad->set_impl(x_grad_unreduce.impl()); + } else { + auto x_grad_reduced = sum(x_grad_unreduce, + phi::vectorize(axes), + x_grad_unreduce.dtype(), + false); + if (x_grad_reduced.dims().size() != x.dims().size()) { + x_grad_reduced = reshape(x_grad_reduced, x.shape()); + } + x_grad->set_impl(x_grad_reduced.impl()); + } + } else { + x_grad->set_impl(x_grad_unreduce.impl()); + } + } + if (y_grad) { + auto y_grad_unreduce = multiply(out_grad, x); + if (y.dims() != x.dims()) { + auto axes = get_reduce_dims(y.dims(), x.dims()); + if (!axes.size()) { + y_grad->set_impl(y_grad_unreduce.impl()); + } else { + auto y_grad_reduced = sum(y_grad_unreduce, + phi::vectorize(axes), + y_grad_unreduce.dtype(), + false); + if (y_grad_reduced.dims().size() != y.dims().size()) { + y_grad_reduced = reshape(y_grad_reduced, y.shape()); + } + y_grad->set_impl(y_grad_reduced.impl()); + } + } else { + y_grad->set_impl(y_grad_unreduce.impl()); + } + } +} + +template +void expand_grad(const Tensor& x, + const Tensor& out_grad, + const IntArray& shape, + Tensor* x_grad) { + if (x_grad) { + auto out_dims = phi::make_ddim(shape.GetData()); + if (out_dims != x.dims()) { + auto axes = get_reduce_dims(x.dims(), out_dims); + if (!axes.size()) { + by_pass(out_grad, x_grad); + } else { + auto reduced = sum(out_grad, phi::vectorize(axes), x.dtype(), false); + if (reduced.dims().size() != x.dims().size()) { + reduced = reshape(reduced, x.shape()); + } + x_grad->set_impl(reduced.impl()); + } + } else { + by_pass(out_grad, x_grad); + } + } +} + +template +void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) { + if (x_grad) { + x_grad->set_impl(multiply(out_grad, out).impl()); + } +} + } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc index 7dac02ea5b2..fa6e2f42779 100644 --- a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc +++ b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc @@ -67,5 +67,15 @@ template <> Tensor reshape(Tensor x, IntArray shape) { return ::reshape_ad_func(x, shape); } + +template <> +Tensor exp(const Tensor& x) { + return ::exp_ad_func(x); +} + +template +Tensor expand(const Tensor& x, const IntArray& shape) { + return ::expand_ad_func(x, shape); +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/prim_api.h b/paddle/fluid/prim/api/manual/prim_api/prim_api.h index 5465cdb601e..c7edf80a2f3 100644 --- a/paddle/fluid/prim/api/manual/prim_api/prim_api.h +++ b/paddle/fluid/prim/api/manual/prim_api/prim_api.h @@ -57,5 +57,11 @@ Tensor 
sum(Tensor x, template Tensor reshape(Tensor x, IntArray shape); + +template +Tensor expand(const Tensor& x, const IntArray& shape); + +template +Tensor exp(const Tensor& x); } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc index 0bf14b5955b..62854061ef5 100644 --- a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc @@ -199,7 +199,7 @@ Tensor sum(Tensor x, "Out", {std::static_pointer_cast(out.impl())->Name()}); op->CheckAttrs(); op->InferVarType(block); - // TODO(jiabin): This may have runtime shape skip infershape for now. + // TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. return out; } @@ -222,7 +222,23 @@ Tensor reshape(Tensor x, paddle::experimental::IntArray shape) { "Out", {std::static_pointer_cast(out.impl())->Name()}); op->CheckAttrs(); op->InferVarType(block); - // TODO(jiabin): This may have runtime shape skip infershape for now. + // TODO(jiabin, cxxly): This may have runtime shape skip infershape for now. + return out; +} + +template <> +Tensor exp(const Tensor& x) { + Tensor out = empty({}, phi::DataType::FLOAT32, paddle::Place()); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("exp"); + op->SetInput("X", + {std::static_pointer_cast(x.impl())->Name()}); + op->SetOutput( + "Out", {std::static_pointer_cast(out.impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + op->InferShape(*block); return out; } } // namespace prim diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual/utils/utils.h index 2a77d0cffdc..22127d30d31 100644 --- a/paddle/fluid/prim/api/manual/utils/utils.h +++ b/paddle/fluid/prim/api/manual/utils/utils.h @@ -16,11 +16,12 @@ #include #include #include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/ddim.h" -using IntArray = paddle::experimental::IntArray; + namespace paddle { namespace prim { // We put some api like utils here @@ -36,43 +37,42 @@ paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x, template void by_pass(const paddle::experimental::Tensor& x, paddle::experimental::Tensor* out); + // These method don't need to be specified -static phi::DDim get_reduce_dims(const phi::DDim& x_dims, - const phi::DDim& y_dims) { +static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, + const phi::DDim& in_dims) { std::vector result; - PADDLE_ENFORCE_GE(phi::product(x_dims), - phi::product(y_dims), - phi::errors::InvalidArgument( - "Only x_dims >= y_dims is accepted for " - "get_reduce_dims, but we got x_dims: %s, y_dims: %s", - x_dims, - y_dims)); - int bat = x_dims.size() - y_dims.size(); + int bat = dout_dims.size() - in_dims.size(); for (int i = 0; i < bat; ++i) { result.push_back(i); } - for (int i = 0; i < y_dims.size(); ++i) { - if (y_dims[i] == 1) { + for (int i = 0; i < in_dims.size(); ++i) { + if (in_dims[i] == 1) { result.push_back(i + bat); } else { PADDLE_ENFORCE_EQ( - y_dims[i], - x_dims[i + bat], + in_dims[i], + dout_dims[i + bat], platform::errors::InvalidArgument( "ReduceDims dimension mismatch. 
Operands could " - "not be broadcast together with the shape of x_dims = [%s] and " - "the shape of y_dims = [%s]. Received [%d] in X is not equal to " + "not be broadcast together with the shape of dout = [%s] and " + "the shape of in_dims = [%s]. Received [%d] in X is not equal to " "[%d] in Y at i:%d.", - x_dims, - y_dims, - x_dims[i + bat], - y_dims[i], + dout_dims, + in_dims, + dout_dims[i + bat], + in_dims[i], i)); } } - auto res_dims = phi::make_ddim(result); - VLOG(4) << "Reduce Dims is: " << res_dims; - return res_dims; + return phi::make_ddim(result); } + +static phi::DDim get_reduce_dims(const phi::DDim& x_dims, + const phi::DDim& y_dims) { + auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims); + return get_reduce_dims_from_out(out_dims, x_dims); +} + } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index bedc73f9683..23158d79401 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -431,6 +431,7 @@ kernel : func : exp_grad inplace : (out_grad -> x_grad) + composite : exp_grad(out, out_grad, x_grad) - backward_op : expm1_grad forward : expm1 (Tensor x) -> Tensor(out) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 16e5244ffa9..10ca2aee865 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -475,6 +475,7 @@ func : expand_grad no_need_buffer : x backward : expand_double_grad + composite: expand_grad(x, out_grad, shape, x_grad_p) - backward_op : exponential__grad forward : exponential_ (Tensor x, float lam) -> Tensor(out) @@ -880,6 +881,7 @@ param : [x, y] kernel : func : multiply_grad + composite: multiply_grad(x, y, out_grad, axis, x_grad, y_grad) backward : multiply_double_grad - backward_op : multiply_triple_grad diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt index 7d5fc1006d1..863a484c466 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt @@ -8,10 +8,3 @@ set(GC_ENVS FLAGS_eager_delete_tensor_gb=0.0) foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() - -set_tests_properties(test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60) -set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60) -set_tests_properties(test_comp_eager_sum_grad PROPERTIES TIMEOUT 60) -set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60) -set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60) -set_tests_properties(test_comp_eager_sqrt_grad PROPERTIES TIMEOUT 60) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py new file mode 100644 index 00000000000..e81314ba041 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_exp_grad.py @@ -0,0 +1,77 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import autograd +import autograd.numpy +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal', 'cotangent', 'dtype'), + [ + (np.random.rand(10, 10), np.random.rand(10, 10), np.float32), + ], +) +class TestExpGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + core.set_prim_enabled(True) + cls.primal = cls.primal.astype(cls.dtype) + if cls.cotangent is not None: + cls.cotangent = cls.cotangent.astype(cls.dtype) + + @classmethod + def tearDownClass(cls): + core.set_prim_enabled(False) + + def test_exp_grad_comp(self): + def actual(primal, cotangent): + primal = paddle.to_tensor(primal) + primal.stop_gradient = False + return paddle.grad( + paddle.exp(primal), primal, paddle.to_tensor(cotangent) + )[0] + + def desired(primal, cotangent): + cotangent = ( + np.ones_like(cotangent, dtype=primal.dtype) + if cotangent is None + else cotangent + ) + return autograd.make_vjp(autograd.numpy.exp)(primal)[0](cotangent) + + np.testing.assert_allclose( + actual=actual(self.primal, self.cotangent), + desired=desired(self.primal, self.cotangent), + rtol=1e-6, + atol=0, + ) + + def test_stop_gradients(self): + with self.assertRaises(ValueError): + primal = paddle.to_tensor(self.primal) + primal.stop_gradient = True + return paddle.grad( + paddle.exp(primal), primal, paddle.to_tensor(self.cotangent) + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py new file mode 100644 index 00000000000..c4de565dc50 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_expand_grad.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('name', 'primal', 'cotangent', 'shape', 'dtype'), + ( + ( + 'same_shape', + np.random.rand(10, 10), + np.random.rand(10, 10), + (10, 10), + np.float32, + ), + ( + 'same_rank', + np.random.rand(1, 10), + np.random.rand(10, 10), + (10, 10), + np.float32, + ), + ( + 'same_rank', + np.random.rand(10, 1, 10, 1), + np.random.rand(10, 10, 10, 10), + (10, 10, 10, 10), + np.float32, + ), + ( + 'diff_rank', + np.random.rand(1, 10, 1), + np.random.rand(10, 10, 10, 10), + (10, 10, 10, 10), + np.float32, + ), + ), +) +class TestExpandGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal = cls.primal.astype(cls.dtype) + cls.cotangent = cls.cotangent.astype(cls.dtype) + + @classmethod + def tearDownClass(cls): + core.set_prim_enabled(False) + + def test_comp(self): + def func(primal, cotangent, shape): + primal = paddle.to_tensor(primal) + primal.stop_gradient = False + cotangent = paddle.to_tensor(cotangent) + return paddle.grad(paddle.expand(primal, shape), primal, cotangent)[ + 0 + ] + + def actual(primal, cotangent, shape): + core.set_prim_enabled(True) + return func(primal, cotangent, shape) + + def desired(primal, cotangent, shape): + core.set_prim_enabled(False) + return func(primal, cotangent, shape) + + np.testing.assert_allclose( + actual=actual(self.primal, self.cotangent, self.shape), + desired=desired(self.primal, self.cotangent, self.shape), + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py new file mode 100644 index 00000000000..59daf91ab8b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_multiply_grad.py @@ -0,0 +1,100 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('name', 'primals', 'stop_gradients', 'cotangents', 'dtype'), + ( + ( + 'test_normal_case', + (np.random.rand(2, 3, 4), np.random.rand(2, 3, 4)), + (False, False), + (np.random.rand(2, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_diff_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(3, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_same_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_stop_gradient', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, True), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_reduce_axe_empty', + (np.random.rand(2, 3, 3, 4), np.random.rand(2, 1, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ), +) +class TestMultiplyGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primals = tuple(primal.astype(cls.dtype) for primal in cls.primals) + cls.cotangents = tuple(co.astype(cls.dtype) for co in cls.cotangents) + + def as_tuple(self, x): + return (x,) if isinstance(x, paddle.Tensor) else x + + def vjp(self): + primals, cotangents = self.primals, self.cotangents + primals = tuple(paddle.to_tensor(primal) for primal in primals) + for primal, flag in zip(primals, self.stop_gradients): + primal.stop_gradient = flag + cotangents = tuple(paddle.to_tensor(co) for co in cotangents) + out = self.as_tuple(paddle.multiply(*primals)) + grads = paddle.grad(out, primals, cotangents, allow_unused=True) + return [g for g in grads if g is not None] + + def test_comp(self): + core.set_prim_enabled(True) + actual = self.vjp() + + core.set_prim_enabled(False) + desired = self.vjp() + + for i, j in zip(actual, desired): + np.testing.assert_allclose( + i, + j, + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py index 9c392663be0..a9464470189 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py @@ -51,7 +51,7 @@ from paddle.fluid import core ), ], ) -class TestDivGradComp(unittest.TestCase): +class TestAddGradComp(unittest.TestCase): @classmethod def setUpClass(cls): cls.primal0 = cls.primal0.astype(cls.dtype) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py new file mode 100644 index 00000000000..c1c76631232 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_exp_grad.py @@ -0,0 +1,122 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import autograd +import autograd.numpy +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal', 'cotangent', 'dtype'), + [ + (np.random.rand(10, 10), np.random.rand(10, 10), np.float32), + (np.random.rand(10, 10), None, np.float32), + ], +) +class TestExpGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + core.set_prim_enabled(True) + cls.primal = cls.primal.astype(cls.dtype) + if cls.cotangent is not None: + cls.cotangent = cls.cotangent.astype(cls.dtype) + + @classmethod + def tearDownClass(cls): + core.set_prim_enabled(False) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_exp_grad_comp(self): + def actual(primal, cotangent): + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal', primal.shape, primal.dtype) + x.stop_gradient = False + v = ( + None + if cotangent is None + else paddle.static.data( + 'cotangent', cotangent.shape, cotangent.dtype + ) + ) + y = paddle.exp(x) + x_cotangent = paddle.static.gradients(y, x, v) + exe = paddle.static.Executor() + exe.run(sp) + return exe.run( + program=mp, + feed={'primal': primal, 'cotangent': cotangent}, + fetch_list=x_cotangent, + )[0] + + def desired(primal, cotangent): + cotangent = ( + np.ones_like(cotangent, dtype=primal.dtype) + if cotangent is None + else cotangent + ) + return autograd.make_vjp(autograd.numpy.exp)(primal)[0](cotangent) + + np.testing.assert_allclose( + actual=actual(self.primal, self.cotangent), + desired=desired(self.primal, self.cotangent), + rtol=1e-6, + atol=0, + ) + + def test_stop_gradient(self): + def actual(primal, cotangent): + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal', primal.shape, primal.dtype) + x.stop_gradient = True + v = ( + None + if cotangent is None + else paddle.static.data( + 'cotangent', cotangent.shape, cotangent.dtype + ) + ) + y = paddle.exp(x) + x_cotangent = paddle.static.gradients(y, x, v) + exe = paddle.static.Executor() + exe.run(sp) + return exe.run( + program=mp, + feed={'primal': primal, 'cotangent': cotangent}, + fetch_list=x_cotangent, + ) + + def desired(primal, cotangent): + return [] + + self.assertEqual( + actual(self.primal, self.cotangent), + desired(self.primal, self.cotangent), + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py new file mode 100644 index 00000000000..c322074d34d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_expand_grad.py @@ -0,0 +1,112 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('name', 'primal', 'cotangent', 'shape', 'dtype'), + ( + ( + 'same_shape', + np.random.rand(10, 10), + np.random.rand(10, 10), + (10, 10), + np.float32, + ), + ( + 'same_rank', + np.random.rand(1, 10), + np.random.rand(10, 10), + (10, 10), + np.float32, + ), + ( + 'same_rank', + np.random.rand(10, 1, 10, 1), + np.random.rand(10, 10, 10, 10), + (10, 10, 10, 10), + np.float32, + ), + ( + 'diff_rank', + np.random.rand(1, 10, 1), + np.random.rand(10, 10, 10, 10), + (10, 10, 10, 10), + np.float32, + ), + ( + 'single_direction_broadcast', + np.random.rand(10, 10, 10, 10), + np.random.rand(1, 10, 1), + (10, 10, 10, 10), + np.float32, + ), + ), +) +class TestExpandGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal = cls.primal.astype(cls.dtype) + cls.cotangent = cls.cotangent.astype(cls.dtype) + paddle.enable_static() + + @classmethod + def tearDownClass(cls): + paddle.disable_static() + core.set_prim_enabled(False) + + def test_comp(self): + def func(primal, cotangent, shape): + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal', primal.shape, primal.dtype) + x.stop_gradient = False + v = paddle.static.data( + 'cotangent', cotangent.shape, cotangent.dtype + ) + y = paddle.expand(x, shape) + x_cotangent = paddle.static.gradients(y, x) + exe = paddle.static.Executor() + exe.run(sp) + return exe.run( + program=mp, + feed={'primal': primal, 'cotangent': cotangent}, + fetch_list=x_cotangent, + )[0] + + def actual(primal, cotangent, shape): + core.set_prim_enabled(True) + return func(primal, cotangent, shape) + + def desired(primal, cotangent, shape): + core.set_prim_enabled(False) + return func(primal, cotangent, shape) + + np.testing.assert_allclose( + actual=actual(self.primal, self.cotangent, self.shape), + desired=desired(self.primal, self.cotangent, self.shape), + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py new file mode 100644 index 00000000000..63e8a4f1bbf --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_multiply_grad.py @@ -0,0 +1,128 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core, framework + + +@param.parameterized_class( + ('name', 'primals', 'stop_gradients', 'cotangents', 'dtype'), + ( + ( + 'test_normal_case', + (np.random.rand(2, 3, 4), np.random.rand(2, 3, 4)), + (False, False), + (np.random.rand(2, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_diff_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(3, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_broadcast_same_rank', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, False), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_stop_gradient', + (np.random.rand(2, 3, 1, 4), np.random.rand(2, 1, 3, 4)), + (False, True), + (np.random.rand(2, 3, 3, 4),), + np.float32, + ), + ( + 'test_reduce_axe_empty', + (np.random.rand(2, 3, 3, 4), np.random.rand(2, 1, 3, 4)), + (False, False), + (np.random.rand(2, 1, 3, 1),), + np.float32, + ), + ), +) +class TestMultiplyGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primals = tuple(primal.astype(cls.dtype) for primal in cls.primals) + cls.cotangents = tuple(co.astype(cls.dtype) for co in cls.cotangents) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def as_tuple(self, x): + return (x,) if isinstance(x, framework.Variable) else x + + def vjp(self): + primals, cotangents = self.primals, self.cotangents + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + primals = tuple( + paddle.static.data(f'primal{i}', primal.shape, primal.dtype) + for i, primal in enumerate(primals) + ) + for primal, flag in zip(primals, self.stop_gradients): + primal.stop_gradient = flag + cotangents = tuple( + paddle.static.data(f'cotangent{i}', co.shape, co.dtype) + for i, co in enumerate(cotangents) + ) + out = self.as_tuple(paddle.multiply(*primals)) + grads = paddle.static.gradients(out, primals) + exe = paddle.static.Executor() + exe.run(sp) + return exe.run( + program=mp, + feed={ + **{ + f'primal{i}': primal + for i, primal in enumerate(self.primals) + }, + **{f'cotangent{i}': co for i, co in enumerate(self.cotangents)}, + }, + fetch_list=[g for g in grads if g is not None], + ) + + def test_comp(self): + + core.set_prim_enabled(True) + actual = self.vjp() + + core.set_prim_enabled(False) + desired = self.vjp() + + self.assertEqual(len(actual), len(desired)) + for i, j in zip(actual, desired): + np.testing.assert_allclose( + i, + j, + rtol=1e-6, + atol=0, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py index d7cab193a99..8baf91ba0dd 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py @@ -39,14 +39,15 @@ from paddle.fluid import core np.random.rand(2, 3, 1, 4), np.float32, ), + (np.random.rand(2, 3, 3, 4), np.random.rand(2, 3, 1, 4), np.float32), ( - np.random.rand(2, 3, 3, 4), + np.random.rand(2, 1, 3, 4), np.random.rand(2, 3, 1, 4), np.float32, ), ( np.random.rand(2, 3, 3, 4), - np.random.rand(2, 3, 1, 1), + np.random.rand(2, 1, 1, 4), np.float32, ), ], -- GitLab
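A minimal, self-contained NumPy sketch (illustrative only, not part of this patch or of the Paddle API) of the broadcast-reduction rule that the new multiply, expand, and exp composite vjp rules above share: the unreduced gradient carries the broadcasted output shape, so it is summed over the axes computed by get_reduce_dims and reshaped back to the operand's shape. The helper names below are hypothetical, they only mirror the C++ logic in composite_backward_api.h and utils.h, and the sketch assumes standard NumPy broadcasting (it ignores the legacy elementwise `axis` attribute).

import numpy as np


def get_reduce_dims_from_out(dout_shape, in_shape):
    # Mirrors prim::get_reduce_dims_from_out: reduce over the leading axes the
    # input lacks and over every axis where the input has size 1.
    bat = len(dout_shape) - len(in_shape)
    axes = list(range(bat))
    for i, d in enumerate(in_shape):
        if d == 1:
            axes.append(i + bat)
        else:
            assert d == dout_shape[i + bat], "shapes are not broadcastable"
    return tuple(axes)


def reduce_to_shape(grad, in_shape):
    # Sum over the broadcast axes, then reshape to restore any size-1 dims.
    axes = get_reduce_dims_from_out(grad.shape, tuple(in_shape))
    if axes:
        grad = grad.sum(axis=axes)
    return grad.reshape(in_shape)


def multiply_vjp(x, y, out_grad):
    # d(x*y)/dx = y and d(x*y)/dy = x, each reduced back to its operand's shape.
    dx = reduce_to_shape(out_grad * y, x.shape)
    dy = reduce_to_shape(out_grad * x, y.shape)
    return dx, dy


def expand_vjp(x, out_grad):
    # expand only broadcasts, so its vjp is a pure reduction to x's shape.
    return reduce_to_shape(out_grad, x.shape)


def exp_vjp(out, out_grad):
    # d(exp(x))/dx = exp(x) = out, so the vjp is an elementwise product.
    return out_grad * out


if __name__ == "__main__":
    x = np.random.rand(2, 3, 1, 4)
    y = np.random.rand(2, 1, 3, 4)
    out_grad = np.random.rand(2, 3, 3, 4)
    dx, dy = multiply_vjp(x, y, out_grad)
    assert dx.shape == x.shape and dy.shape == y.shape
    assert expand_vjp(x, out_grad).shape == x.shape

The same reduce-then-reshape pattern is what the patch adds to add_grad, subtract_grad, and divide_grad when the operand shapes differ but produce an empty reduce-axis list (the by_pass / set_impl fast path), which is the "single direction reduce" case exercised by the new test shapes.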