From 561f901393817d52af3d31f5cbe2ed40c52ccbbe Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Fri, 13 Jan 2023 15:54:46 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Support=20elementwise=20r?= =?UTF-8?q?elated=20VJP=20with=20primitives=20(#49784)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * support elementwise base func * fix compiling error and add test * remove additional param * support vjp for div using comp * remove additional change * fix dy2st error with magic num * fix dy magic num * another magic * another magic * add more test * fix windows problem * another magic * fix windows compile * invoke ci * add skip rename strategy * support add vjp * fix test_tanh * support add with new axis cal * fix resnet and some test * add composite log * support sub vjp --- paddle/fluid/operators/CMakeLists.txt | 4 +- .../elementwise/elementwise_add_op.cc | 37 ++++- .../elementwise/elementwise_div_op.cc | 30 ++++- .../elementwise/elementwise_sub_op.cc | 28 +++- .../generator/templates/operator_utils.c.j2 | 1 + .../manual/backward/composite_backward_api.h | 109 +++++++++++++++ .../api/manual/prim_api/eager_prim_api.cc | 24 +++- .../fluid/prim/api/manual/prim_api/prim_api.h | 21 ++- .../api/manual/prim_api/static_prim_api.cc | 98 ++++++++++++++ .../prim/api/manual/utils/eager_utils.cc | 5 +- .../prim/api/manual/utils/static_utils.cc | 18 +++ paddle/fluid/prim/api/manual/utils/utils.h | 44 +++++- paddle/fluid/prim/tests/CMakeLists.txt | 4 +- paddle/fluid/prim/utils/static/desc_tensor.h | 3 + paddle/phi/api/yaml/legacy_backward.yaml | 3 + paddle/phi/core/extended_tensor.h | 2 +- .../unittests/dygraph_to_static/test_bert.py | 13 ++ .../dygraph_to_static/test_resnet.py | 16 +++ .../dygraph_to_static/test_resnet_amp.py | 15 +++ .../test_resnet_pure_fp16.py | 18 +++ .../dygraph_to_static/test_resnet_v2.py | 15 +++ .../prim/prim/vjp/eager/CMakeLists.txt | 5 +- .../vjp/eager/test_comp_eager_add_grad.py | 101 ++++++++++++++ .../vjp/eager/test_comp_eager_div_grad.py | 101 ++++++++++++++ .../vjp/eager/test_comp_eager_sub_grad.py | 101 ++++++++++++++ ...d_comp.py => test_comp_eager_tanh_grad.py} | 0 .../prim/prim/vjp/static/CMakeLists.txt | 6 +- .../prim/vjp/static/test_comp_add_grad.py | 127 ++++++++++++++++++ .../vjp/static/test_comp_add_tanh_grad.py | 125 +++++++++++++++++ .../prim/vjp/static/test_comp_div_grad.py | 125 +++++++++++++++++ .../prim/vjp/static/test_comp_sub_grad.py | 125 +++++++++++++++++ ...nh_grad_comp.py => test_comp_tanh_grad.py} | 2 +- ...st_comp_get_grad_op_desc_prim_disabled.py} | 0 ...est_comp_get_grad_op_desc_prim_enabled.py} | 0 34 files changed, 1310 insertions(+), 16 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py rename python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/{test_eager_tanh_grad_comp.py => test_comp_eager_tanh_grad.py} (100%) create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py create mode 100644 python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py 
rename python/paddle/fluid/tests/unittests/prim/prim/vjp/static/{test_tanh_grad_comp.py => test_comp_tanh_grad.py} (97%) rename python/paddle/fluid/tests/unittests/prim/{test_get_grad_op_desc_prim_disabled.py => test_comp_get_grad_op_desc_prim_disabled.py} (100%) rename python/paddle/fluid/tests/unittests/prim/{test_get_grad_op_desc_prim_enabled.py => test_comp_get_grad_op_desc_prim_enabled.py} (100%) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 6e3bdda7fc..a287298ef0 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -148,7 +148,7 @@ cc_library(ops_extra_info SRCS ops_extra_info.cc DEPS attribute cudnn_workspace_ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor generator static_prim_api) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc static_prim_api static_utils static_global_utils prim_utils) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc_functor matrix_inverse matrix_solve) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper ps_gpu_wrapper) @@ -216,7 +216,7 @@ endif() set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") add_subdirectory(benchmark) -cc_test(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_op) +cc_test_old(op_debug_string_test SRCS op_debug_string_test.cc DEPS elementwise_add_op ${COMMON_OP_DEPS}) if (WITH_ASCEND_CL) cc_test(transpose_op_npu_test SRCS transpose_op_npu_test.cc DEPS op_registry transpose_op scope device_context enforce executor) endif() diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index b4164846aa..11e0fa7dd1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -15,7 +15,9 @@ limitations under the License. 
*/ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" - +#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { namespace framework { class OpDesc; @@ -49,6 +51,29 @@ class ElementwiseAddOpMaker : public ElementwiseOpMaker { } }; +class ElementwiseAddGradCompositeOpMaker + : public prim::GradCompositeOpMakerBase { + using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; + + public: + void Apply() override { + paddle::experimental::Tensor x = this->GetSingleForwardInput("X"); + paddle::experimental::Tensor y = this->GetSingleForwardInput("Y"); + paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::experimental::Tensor dx = this->GetSingleInputGrad("X"); + auto dx_ptr = this->GetOutputPtr(&dx); + std::string dx_name = this->GetOutputName(dx); + paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y"); + auto dy_ptr = this->GetOutputPtr(&dy); + std::string dy_name = this->GetOutputName(dy); + int axis = static_cast(this->Attr("axis")); + VLOG(3) << "Runing add_grad composite func"; + prim::add_grad(x, y, out_grad, axis, dx_ptr, dy_ptr); + this->RecoverOutputName(dx, dx_name); + this->RecoverOutputName(dy, dy_name); + } +}; + template class ElementwiseAddDoubleGradMaker : public framework::SingleGradOpMaker { public: @@ -91,9 +116,17 @@ class ElementwiseAddTripleGradMaker : public framework::SingleGradOpMaker { } // namespace paddle REGISTER_ELEMWISE_GRAD_MAKER(elementwise_add, Add); -REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add); +REGISTER_OPERATOR(elementwise_add, + ::paddle::operators::ElementwiseOp, + ::paddle::operators::ElementwiseAddOpMaker, + ::paddle::operators::ElementwiseOpInferVarType, + elementwise_addGradMaker<::paddle::framework::OpDesc>, + elementwise_addGradMaker<::paddle::imperative::OpBase>, + ::paddle::operators::ElementwiseAddGradCompositeOpMaker, + ::paddle::operators::ElementwiseOpInplaceInferer); namespace ops = paddle::operators; + REGISTER_OPERATOR( elementwise_add_grad, ops::ElementwiseOpGrad, diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index f7a9b993c0..3d62792d85 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -19,7 +19,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" - +#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { namespace operators { @@ -65,6 +67,31 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker { } }; +class ElementwiseDivGradCompositeOpMaker + : public prim::GradCompositeOpMakerBase { + using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; + + public: + void Apply() override { + paddle::experimental::Tensor x = this->GetSingleForwardInput("X"); + paddle::experimental::Tensor y = this->GetSingleForwardInput("Y"); + paddle::experimental::Tensor out = this->GetSingleForwardOutput("Out"); + paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::experimental::Tensor dx = this->GetSingleInputGrad("X"); + auto dx_ptr = this->GetOutputPtr(&dx); + std::string dx_name = this->GetOutputName(dx); + paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y"); + auto dy_ptr = this->GetOutputPtr(&dy); + std::string dy_name = this->GetOutputName(dy); + int axis = static_cast(this->Attr("axis")); + VLOG(3) << "Runing div_grad composite func"; + prim::divide_grad( + x, y, out, out_grad, axis, dx_ptr, dy_ptr); + this->RecoverOutputName(dx, dx_name); + this->RecoverOutputName(dy, dy_name); + } +}; + template class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker { public: @@ -96,6 +123,7 @@ REGISTER_OPERATOR(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseDivGradCompositeOpMaker, ops::ElementwiseDivGradOpMaker, ops::ElementwiseDivGradOpMaker); diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index c73192ae79..be839f123a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -15,7 +15,9 @@ limitations under the License. 
*/ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" - +#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" +#include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { namespace framework { class OpDesc; @@ -52,6 +54,29 @@ class ElementwiseSubOpMaker : public ElementwiseOpMaker { } }; +class ElementwiseSubGradCompositeOpMaker + : public prim::GradCompositeOpMakerBase { + using prim::GradCompositeOpMakerBase::GradCompositeOpMakerBase; + + public: + void Apply() override { + paddle::experimental::Tensor x = this->GetSingleForwardInput("X"); + paddle::experimental::Tensor y = this->GetSingleForwardInput("Y"); + paddle::experimental::Tensor out_grad = this->GetSingleOutputGrad("Out"); + paddle::experimental::Tensor dx = this->GetSingleInputGrad("X"); + auto dx_ptr = this->GetOutputPtr(&dx); + std::string dx_name = this->GetOutputName(dx); + paddle::experimental::Tensor dy = this->GetSingleInputGrad("Y"); + auto dy_ptr = this->GetOutputPtr(&dy); + std::string dy_name = this->GetOutputName(dy); + int axis = static_cast(this->Attr("axis")); + VLOG(3) << "Runing sub_grad composite func"; + prim::subtract_grad(x, y, out_grad, axis, dx_ptr, dy_ptr); + this->RecoverOutputName(dx, dx_name); + this->RecoverOutputName(dy, dy_name); + } +}; + template class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker { public: @@ -84,6 +109,7 @@ REGISTER_OPERATOR(elementwise_sub, ::paddle::operators::ElementwiseOpInferVarType, elementwise_subGradMaker<::paddle::framework::OpDesc>, elementwise_subGradMaker<::paddle::imperative::OpBase>, + ::paddle::operators::ElementwiseSubGradCompositeOpMaker, ::paddle::operators::ElementwiseOpInplaceInferer); REGISTER_OPERATOR( diff --git a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 index 742ebcc249..37b5075235 100644 --- a/paddle/fluid/operators/generator/templates/operator_utils.c.j2 +++ b/paddle/fluid/operators/generator/templates/operator_utils.c.j2 @@ -635,6 +635,7 @@ class {{op_name | to_composite_grad_opmaker_name}} : public prim::GradCompositeO {%- endmacro %} {% macro call_composite_backward_api(composite_op_dict) %} + VLOG(3) << "Runing {{composite_op_dict["composite"]["func_name"]}} composite func"; prim::{{composite_op_dict["composite"]["func_name"]}}({{composite_op_dict["composite"]["func_args"]}}); {%- endmacro %} diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h index 93afa69940..fa0d3d640d 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/manual/backward/composite_backward_api.h @@ -27,5 +27,114 @@ void tanh_grad(const Tensor& out, const Tensor& grad_out, Tensor* grad_x) { auto grad_x_tmp = multiply(grad_out, tmp); grad_x->set_impl(grad_x_tmp.impl()); } +template +void subtract_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); + if (phi::product(x.dims()) > phi::product(y.dims())) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto dy_reduce_res = + sum(scale_out_grad, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } else { 
+ by_pass(scale_out_grad, dy); + } + } + if (dx) { + if (phi::product(y.dims()) > phi::product(x.dims())) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); + auto dx_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } else { + by_pass(out_grad, dx); + } + } +} + +template +void add_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + if (phi::product(x.dims()) > phi::product(y.dims())) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto dy_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } else { + by_pass(out_grad, dy); + } + } + if (dx) { + if (phi::product(y.dims()) > phi::product(x.dims())) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); + auto dx_reduce_res = + sum(out_grad, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } else { + by_pass(out_grad, dx); + } + } +} + +template +void divide_grad(const Tensor& x, + const Tensor& y, + const Tensor& out, + const Tensor& out_grad, + int axis, + Tensor* dx, + Tensor* dy) { + if (dy) { + // dy = -(x/y^2) * dout + auto tmp0 = pow(y, 2.0); + auto tmp1 = divide(x, tmp0); + auto tmp2 = scale(tmp1, -1.0, 0.0, true); + auto dy_res = multiply(tmp2, out_grad); + if (phi::product(x.dims()) > phi::product(y.dims())) { + // Maybe need reduce here + phi::DDim reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto dy_reduce_res = + sum(dy_res, phi::vectorize(reduce_dim), y.dtype(), false); + auto dy_tmp = reshape(dy_reduce_res, phi::vectorize(y.dims())); + dy->set_impl(dy_tmp.impl()); + } else { + dy->set_impl(dy_res.impl()); + } + } // indicate we will compute dy + if (dx) { + // dx = (1/y) * dout + auto one_tensor = full(phi::vectorize(y.dims()), 1.0); + auto tmp0 = divide(one_tensor, y); + auto dx_res = multiply(tmp0, out_grad); + if (phi::product(y.dims()) > phi::product(x.dims())) { + // Maybe need reduce here + auto reduce_dim = get_reduce_dims(y.dims(), x.dims()); + auto dx_reduce_res = + sum(dx_res, phi::vectorize(reduce_dim), x.dtype(), false); + auto dx_tmp = reshape(dx_reduce_res, phi::vectorize(x.dims())); + dx->set_impl(dx_tmp.impl()); + } else { + dx->set_impl(dx_res.impl()); + } + } // indicate we will compute dx +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc index e9b666974a..be123ecde7 100644 --- a/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc +++ b/paddle/fluid/prim/api/manual/prim_api/eager_prim_api.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" - +#include "paddle/phi/capi/include/wrapper_base.h" namespace paddle { namespace prim { template <> @@ -35,5 +35,27 @@ template <> Tensor multiply(const Tensor& x, const Tensor& y) { return ::multiply_ad_func(x, y); } + +template <> +Tensor divide(const Tensor& x, const Tensor& y) { + return ::divide_ad_func(x, y); +} + +template <> +Tensor 
full(paddle::experimental::IntArray shape, + paddle::experimental::Scalar value, + paddle::experimental::DataType dtype, + paddle::platform::Place place) { + return ::full_ad_func(shape, value, dtype, place); +} +template <> +Tensor sum(Tensor x, IntArray axis, DataType dtype, bool keepdim) { + return ::sum_ad_func(x, axis, dtype, keepdim); +} + +template <> +Tensor reshape(Tensor x, IntArray shape) { + return ::reshape_ad_func(x, shape); +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/prim_api.h b/paddle/fluid/prim/api/manual/prim_api/prim_api.h index 5809c15730..8de90919c5 100644 --- a/paddle/fluid/prim/api/manual/prim_api/prim_api.h +++ b/paddle/fluid/prim/api/manual/prim_api/prim_api.h @@ -13,12 +13,14 @@ // limitations under the License. #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/optional.h" namespace paddle { namespace prim { using Tensor = paddle::experimental::Tensor; - +using IntArray = paddle::experimental::IntArray; +using Scalar = paddle::experimental::Scalar; template Tensor pow(const Tensor& x, const paddle::experimental::Scalar& y); @@ -31,5 +33,22 @@ Tensor scale(const Tensor& X, template Tensor multiply(const Tensor& x, const Tensor& y); +template +Tensor divide(const Tensor& x, const Tensor& y); + +template +Tensor full(IntArray shape, + Scalar value, + DataType dtype = DataType::FLOAT32, + Place place = CPUPlace()); + +template +Tensor sum(Tensor x, + IntArray axis = {}, + DataType dtype = DataType::UNDEFINED, + bool keepdim = false); + +template +Tensor reshape(Tensor x, IntArray shape); } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc index 80e07e9177..bd06b1f503 100644 --- a/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/manual/prim_api/static_prim_api.cc @@ -30,6 +30,9 @@ #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" namespace paddle { namespace prim { @@ -91,5 +94,100 @@ Tensor multiply(const Tensor& x, const Tensor& y) { return out; } +template <> +Tensor divide(const Tensor& x, const Tensor& y) { + // Grad infershape + Tensor out = empty({}, phi::DataType::FLOAT32, paddle::Place()); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("elementwise_div"); + op->SetInput("X", + {std::static_pointer_cast(x.impl())->Name()}); + op->SetInput("Y", + {std::static_pointer_cast(y.impl())->Name()}); + op->SetOutput( + "Out", {std::static_pointer_cast(out.impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + op->InferShape(*block); + return out; +} + +template <> +Tensor full(paddle::experimental::IntArray shape, + paddle::experimental::Scalar value, + paddle::experimental::DataType dtype, + paddle::platform::Place place) { + // Grad infershape + Tensor out = empty({}, dtype, place); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("fill_constant"); + op->SetAttr("shape", shape.GetData()); + PADDLE_ENFORCE_EQ( + ((dtype == paddle::experimental::DataType::FLOAT32) || + (dtype 
== paddle::experimental::DataType::FLOAT16)), + true, + phi::errors::InvalidArgument( + "We only support float32/float16 for full, but we got data type: %s", + phi::DataTypeToString(dtype))); + op->SetAttr("value", value.to()); + op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype)); + op->SetOutput( + "Out", {std::static_pointer_cast(out.impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + op->InferShape(*block); + return out; +} +template <> +Tensor sum(Tensor x, + paddle::experimental::IntArray axis, + paddle::experimental::DataType dtype, + bool keepdim) { + // Grad infershape + Tensor out = empty({}, dtype, paddle::Place()); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("reduce_sum"); + op->SetInput("X", + {std::static_pointer_cast(x.impl())->Name()}); + std::vector res; + for (auto value : axis.GetData()) { + res.push_back(static_cast(value)); + } + op->SetAttr("dim", res); + op->SetAttr("keep_dim", keepdim); + op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype)); + op->SetOutput( + "Out", {std::static_pointer_cast(out.impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + // TODO(jiabin): This may have runtime shape skip infershape for now. + return out; +} + +template <> +Tensor reshape(Tensor x, paddle::experimental::IntArray shape) { + // Grad infershape + Tensor out = empty({}, x.dtype(), paddle::Place()); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("reshape"); + op->SetInput("X", + {std::static_pointer_cast(x.impl())->Name()}); + std::vector res; + for (auto value : shape.GetData()) { + // TODO(jiabin): This cast is not safe for now, find a way to handle this. + res.push_back(static_cast(value)); + } + op->SetAttr("shape", res); + op->SetOutput( + "Out", {std::static_pointer_cast(out.impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + // TODO(jiabin): This may have runtime shape skip infershape for now. 
+ return out; +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/utils/eager_utils.cc b/paddle/fluid/prim/api/manual/utils/eager_utils.cc index 99c2ae4821..96d0b4ea1f 100644 --- a/paddle/fluid/prim/api/manual/utils/eager_utils.cc +++ b/paddle/fluid/prim/api/manual/utils/eager_utils.cc @@ -38,6 +38,9 @@ Tensor empty_like(const paddle::experimental::Tensor& x, } return empty_like_ad_func(x, dtype, place); } - +template <> +void by_pass(const paddle::experimental::Tensor& x, Tensor* out) { + out->set_impl(x.impl()); +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/utils/static_utils.cc b/paddle/fluid/prim/api/manual/utils/static_utils.cc index 98495cf1e3..c90cfdb34f 100644 --- a/paddle/fluid/prim/api/manual/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual/utils/static_utils.cc @@ -47,5 +47,23 @@ Tensor empty_like(const Tensor& x, paddle::experimental::IntArray(x.shape()), x.dtype(), paddle::Place()); } +template <> +void by_pass(const paddle::experimental::Tensor& x, + paddle::experimental::Tensor* out) { + Tensor new_out = + empty({}, phi::DataType::FLOAT32, paddle::Place()); + framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock(); + framework::OpDesc* op = block->AppendOp(); + op->SetType("assign"); + op->SetInput("X", + {std::static_pointer_cast(x.impl())->Name()}); + op->SetOutput( + "Out", {std::static_pointer_cast(out->impl())->Name()}); + op->CheckAttrs(); + op->InferVarType(block); + op->InferShape(*block); + out->set_impl(new_out.impl()); +} + } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual/utils/utils.h index e69fee8634..2a77d0cffd 100644 --- a/paddle/fluid/prim/api/manual/utils/utils.h +++ b/paddle/fluid/prim/api/manual/utils/utils.h @@ -19,6 +19,8 @@ #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/core/ddim.h" +using IntArray = paddle::experimental::IntArray; namespace paddle { namespace prim { // We put some api like utils here @@ -31,6 +33,46 @@ template paddle::experimental::Tensor empty_like(const paddle::experimental::Tensor& x, paddle::experimental::DataType dtype, const paddle::Place& place); - +template +void by_pass(const paddle::experimental::Tensor& x, + paddle::experimental::Tensor* out); +// These method don't need to be specified +static phi::DDim get_reduce_dims(const phi::DDim& x_dims, + const phi::DDim& y_dims) { + std::vector result; + PADDLE_ENFORCE_GE(phi::product(x_dims), + phi::product(y_dims), + phi::errors::InvalidArgument( + "Only x_dims >= y_dims is accepted for " + "get_reduce_dims, but we got x_dims: %s, y_dims: %s", + x_dims, + y_dims)); + int bat = x_dims.size() - y_dims.size(); + for (int i = 0; i < bat; ++i) { + result.push_back(i); + } + for (int i = 0; i < y_dims.size(); ++i) { + if (y_dims[i] == 1) { + result.push_back(i + bat); + } else { + PADDLE_ENFORCE_EQ( + y_dims[i], + x_dims[i + bat], + platform::errors::InvalidArgument( + "ReduceDims dimension mismatch. Operands could " + "not be broadcast together with the shape of x_dims = [%s] and " + "the shape of y_dims = [%s]. 
Received [%d] in X is not equal to " + "[%d] in Y at i:%d.", + x_dims, + y_dims, + x_dims[i + bat], + y_dims[i], + i)); + } + } + auto res_dims = phi::make_ddim(result); + VLOG(4) << "Reduce Dims is: " << res_dims; + return res_dims; +} } // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/tests/CMakeLists.txt b/paddle/fluid/prim/tests/CMakeLists.txt index 3b2e24358c..92b34352a7 100644 --- a/paddle/fluid/prim/tests/CMakeLists.txt +++ b/paddle/fluid/prim/tests/CMakeLists.txt @@ -19,7 +19,7 @@ set(prim_generated_deps final_dygraph_function final_dygraph_node dygraph_function dygraph_node) cc_test_old( - test_static_prim + test_comp_static SRCS test_static_prim.cc DEPS @@ -37,7 +37,7 @@ cc_test_old( if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_test_old( - test_eager_prim + test_comp_eager SRCS test_eager_prim.cc DEPS diff --git a/paddle/fluid/prim/utils/static/desc_tensor.h b/paddle/fluid/prim/utils/static/desc_tensor.h index 746228ae3d..0b43ae9c60 100644 --- a/paddle/fluid/prim/utils/static/desc_tensor.h +++ b/paddle/fluid/prim/utils/static/desc_tensor.h @@ -45,6 +45,8 @@ class DescTensor : public phi::ExtendedTensor, framework::VarDesc* get_ptr() { return desc_ptr_; } + const phi::Place& place() const override { return place_; } + // TODO(jiabin): override more operators here. private: @@ -55,6 +57,7 @@ class DescTensor : public phi::ExtendedTensor, // we can inherient from ExtendedTensor Rmove this when we make VarDesc's as // same as Tensor, or make Tensor's dims more lightly. mutable phi::DDim dims_; + phi::Place place_; }; } // namespace prim diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index c57e7d434e..2050653904 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -42,6 +42,7 @@ kernel : func : add_grad no_need_buffer : x, y + composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis) backward : add_double_grad inplace : (out_grad -> x_grad) @@ -375,6 +376,7 @@ param : [x, y] kernel : func : divide_grad + composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1) backward : divide_double_grad - backward_op : dropout_grad @@ -1325,6 +1327,7 @@ kernel : func : subtract_grad no_need_buffer : x, y + composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis) backward : subtract_double_grad inplace : (out_grad -> x_grad) diff --git a/paddle/phi/core/extended_tensor.h b/paddle/phi/core/extended_tensor.h index 66c4987fb4..d02dbabde1 100644 --- a/paddle/phi/core/extended_tensor.h +++ b/paddle/phi/core/extended_tensor.h @@ -14,10 +14,10 @@ limitations under the License. */ #pragma once +#include "paddle/phi/common/place.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" - namespace phi { /// \brief The ExtendedTensor is a interface for custom designed class. 
diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py index 88054e689f..4ac7a3dfe4 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_bert.py @@ -24,6 +24,7 @@ from predictor_utils import PredictorTools import paddle import paddle.fluid as fluid +from paddle.fluid import core from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX place = ( @@ -234,6 +235,18 @@ class TestBert(unittest.TestCase): self.verify_predict() + def test_train_composite(self): + core.set_prim_enabled(True) + static_loss, static_ppl = self.train_static( + self.bert_config, self.data_reader + ) + core.set_prim_enabled(False) + dygraph_loss, dygraph_ppl = self.train_dygraph( + self.bert_config, self.data_reader + ) + np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05) + np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05) + def verify_predict(self): for data in self.data_reader.data_generator()(): dygraph_pred_res = self.predict_dygraph(self.bert_config, data) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py index 7972904d80..40919edbce 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet.py @@ -23,6 +23,7 @@ from predictor_utils import PredictorTools import paddle import paddle.fluid as fluid +from paddle.fluid import core from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.nn import BatchNorm @@ -425,6 +426,21 @@ class TestResnet(unittest.TestCase): ) self.verify_predict() + def test_resnet_composite(self): + core.set_prim_enabled(True) + static_loss = self.train(to_static=True) + core.set_prim_enabled(False) + dygraph_loss = self.train(to_static=True) + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-05, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) + core.set_prim_enabled(False) + def test_in_static_mode_mkldnn(self): fluid.set_flags({'FLAGS_use_mkldnn': True}) try: diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py index 8b91b41895..8e6872c079 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_amp.py @@ -20,6 +20,7 @@ from test_resnet import SEED, ResNet, optimizer_setting import paddle import paddle.fluid as fluid +from paddle.fluid import core # NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout. 
batch_size = 2 @@ -128,6 +129,20 @@ class TestResnet(unittest.TestCase): ), ) + def test_resnet_composite(self): + core.set_prim_enabled(True) + static_loss = self.train(to_static=True) + core.set_prim_enabled(False) + dygraph_loss = self.train(to_static=False) + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-05, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py index 1f9c1d1104..6213f6fae2 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_pure_fp16.py @@ -20,6 +20,7 @@ from test_resnet import SEED, ResNet, optimizer_setting import paddle import paddle.fluid as fluid +from paddle.fluid import core # NOTE: Reduce batch_size from 8 to 2 to avoid unittest timeout. batch_size = 2 @@ -134,6 +135,23 @@ class TestResnet(unittest.TestCase): ), ) + def test_resnet_composite(self): + if fluid.is_compiled_with_cuda(): + core.set_prim_enabled(True) + static_loss = self.train(to_static=True) + core.set_prim_enabled(False) + dygraph_loss = self.train(to_static=False) + # NOTE: In pure fp16 training, loss is not stable, so we enlarge atol here. + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-05, + atol=0.001, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py index 45004d42e2..1b4d01114f 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py @@ -22,6 +22,7 @@ import numpy as np from predictor_utils import PredictorTools import paddle +from paddle.fluid import core SEED = 2020 IMAGENET1000 = 1281167 @@ -424,6 +425,20 @@ class TestResnet(unittest.TestCase): ) self.verify_predict() + def test_resnet_composite(self): + core.set_prim_enabled(True) + static_loss = self.train(to_static=True) + core.set_prim_enabled(False) + dygraph_loss = self.train(to_static=False) + np.testing.assert_allclose( + static_loss, + dygraph_loss, + rtol=1e-05, + err_msg='static_loss: {} \n dygraph_loss: {}'.format( + static_loss, dygraph_loss + ), + ) + def test_in_static_mode_mkldnn(self): paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) try: diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt index 8a6063c08b..fc8d0234d5 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/CMakeLists.txt @@ -9,4 +9,7 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() -set_tests_properties(test_eager_tanh_grad_comp PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_eager_tanh_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_eager_div_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_eager_add_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_eager_sub_grad PROPERTIES TIMEOUT 60) diff --git 
a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py new file mode 100644 index 0000000000..6894e9058c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_add_grad.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + +core.set_prim_enabled(True) + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestTanhGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.add(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + def desired(primal0, primal1): + core.set_prim_enabled(False) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.add(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py new file mode 100644 index 0000000000..5452d2bfcb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_div_grad.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + +core.set_prim_enabled(True) + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestTanhGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.divide(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + def desired(primal0, primal1): + core.set_prim_enabled(False) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.divide(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py new file mode 100644 index 0000000000..3c273afbe3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_sub_grad.py @@ -0,0 +1,101 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + +core.set_prim_enabled(True) + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestTanhGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.subtract(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + def desired(primal0, primal1): + core.set_prim_enabled(False) + paddle.disable_static() + x = paddle.to_tensor(primal0, dtype='float32', stop_gradient=False) + y = paddle.to_tensor(primal1, dtype='float32', stop_gradient=False) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.subtract(x, y) + res = paddle.grad(out, [x, y], create_graph=True, retain_graph=True) + return res[0].numpy(), res[1].numpy() + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_eager_tanh_grad_comp.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py similarity index 100% rename from python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_eager_tanh_grad_comp.py rename to python/paddle/fluid/tests/unittests/prim/prim/vjp/eager/test_comp_eager_tanh_grad.py diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt index e6094bb8af..58375c1696 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/CMakeLists.txt @@ -9,4 +9,8 @@ foreach(TEST_OP ${TEST_OPS}) py_test_modules(${TEST_OP} MODULES ${TEST_OP} ENVS ${GC_ENVS}) endforeach() -set_tests_properties(test_tanh_grad_comp PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_tanh_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_div_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_add_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_sub_grad PROPERTIES TIMEOUT 60) +set_tests_properties(test_comp_add_tanh_grad PROPERTIES TIMEOUT 60) diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py new file mode 100644 index 0000000000..783560c2e0 --- /dev/null +++ 
b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_grad.py @@ -0,0 +1,127 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestDivGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + y = paddle.static.data('primal1', primal1.shape, primal1.dtype) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.add(x, y) + out = paddle.tanh(z) + res = paddle.static.gradients([out], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'primal1': primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + def desired(primal0, primal1): + core.set_prim_enabled(False) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data( + 'primal0', self.primal0.shape, self.primal0.dtype + ) + y = paddle.static.data( + 'primal1', self.primal1.shape, self.primal1.dtype + ) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.add(x, y) + out = paddle.tanh(z) + res = paddle.static.gradients([out], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': self.primal0, + 'primal1': self.primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py new file mode 100644 index 0000000000..1e75348451 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_add_tanh_grad.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestDivGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + y = paddle.static.data('primal1', primal1.shape, primal1.dtype) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.add(x, y) + res = paddle.static.gradients([z], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'primal1': primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + def desired(primal0, primal1): + core.set_prim_enabled(False) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data( + 'primal0', self.primal0.shape, self.primal0.dtype + ) + y = paddle.static.data( + 'primal1', self.primal1.shape, self.primal1.dtype + ) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.add(x, y) + res = paddle.static.gradients([z], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': self.primal0, + 'primal1': self.primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py new file mode 100644 index 0000000000..94b0742207 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_div_grad.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestDivGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + y = paddle.static.data('primal1', primal1.shape, primal1.dtype) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.divide(x, y) + res = paddle.static.gradients([z], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'primal1': primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + def desired(primal0, primal1): + core.set_prim_enabled(False) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data( + 'primal0', self.primal0.shape, self.primal0.dtype + ) + y = paddle.static.data( + 'primal1', self.primal1.shape, self.primal1.dtype + ) + x.stop_gradient = False + y.stop_gradient = False + z = paddle.divide(x, y) + res = paddle.static.gradients([z], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': self.primal0, + 'primal1': self.primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py new file mode 100644 index 0000000000..3f46b5315f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_sub_grad.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import parameterized as param + +import paddle +from paddle.fluid import core + + +@param.parameterized_class( + ('primal0', 'primal1', 'dtype'), + [ + ( + np.random.rand(2, 3, 4), + np.random.rand(2, 3, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(3, 1, 4), + np.float32, + ), + ( + np.random.rand(2, 3, 3, 4), + np.random.rand(2, 3, 1, 4), + np.float32, + ), + ], +) +class TestDivGradComp(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.primal0 = cls.primal0.astype(cls.dtype) + cls.primal1 = cls.primal1.astype(cls.dtype) + + def setUp(self): + paddle.enable_static() + + def tearDown(self): + paddle.disable_static() + + def test_tanh_grad_comp(self): + def actual(primal0, primal1): + core.set_prim_enabled(True) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data('primal0', primal0.shape, primal0.dtype) + y = paddle.static.data('primal1', primal1.shape, primal1.dtype) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.subtract(x, y) + res = paddle.static.gradients([out], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': primal0, + 'primal1': primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + def desired(primal0, primal1): + core.set_prim_enabled(False) + mp, sp = paddle.static.Program(), paddle.static.Program() + with paddle.static.program_guard(mp, sp): + x = paddle.static.data( + 'primal0', self.primal0.shape, self.primal0.dtype + ) + y = paddle.static.data( + 'primal1', self.primal1.shape, self.primal1.dtype + ) + x.stop_gradient = False + y.stop_gradient = False + out = paddle.subtract(x, y) + res = paddle.static.gradients([out], [x, y]) + exe = paddle.static.Executor() + exe.run(sp) + out = exe.run( + program=mp, + feed={ + 'primal0': self.primal0, + 'primal1': self.primal1, + }, + fetch_list=[res[0].name, res[1].name], + ) + return out[0], out[1] + + dx, dy = actual(self.primal0, self.primal1) + + ddx, ddy = desired(self.primal0, self.primal1) + + np.testing.assert_allclose( + actual=dx, + desired=ddx, + rtol=1e-6, + atol=0, + ) + np.testing.assert_allclose( + actual=dy, + desired=ddy, + rtol=1e-6, + atol=0, + ) + core.set_prim_enabled(False) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_tanh_grad_comp.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py similarity index 97% rename from python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_tanh_grad_comp.py rename to python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py index 3cfa6b876f..445b371b0a 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_tanh_grad_comp.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_tanh_grad.py @@ -60,7 +60,7 @@ class TestTanhGradComp(unittest.TestCase): return exe.run( program=mp, feed={'primal': primal, 'cotangent': cotangent}, - 
fetch_list=mp.blocks[0].ops[-1].output('Out')[0], + fetch_list=[x_cotangent[0].name], )[0] def desired(primal, cotangent): diff --git a/python/paddle/fluid/tests/unittests/prim/test_get_grad_op_desc_prim_disabled.py b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py similarity index 100% rename from python/paddle/fluid/tests/unittests/prim/test_get_grad_op_desc_prim_disabled.py rename to python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_disabled.py diff --git a/python/paddle/fluid/tests/unittests/prim/test_get_grad_op_desc_prim_enabled.py b/python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py similarity index 100% rename from python/paddle/fluid/tests/unittests/prim/test_get_grad_op_desc_prim_enabled.py rename to python/paddle/fluid/tests/unittests/prim/test_comp_get_grad_op_desc_prim_enabled.py -- GitLab