From 29c28e2f0cc9f796e857aee18a7a1649a182df6a Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Thu, 6 Apr 2023 14:45:41 +0800
Subject: [PATCH] support more custom vjp (#52533)

---
 paddle/fluid/operators/activation_op.cc      | 47 +++++++++++++++++--
 .../composite_backward_api.h                 | 36 ++++++++++++++
 .../utils/static/composite_grad_desc_maker.h | 45 +++++-------------
 paddle/phi/api/yaml/backward.yaml            |  2 +
 4 files changed, 91 insertions(+), 39 deletions(-)

diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index c7cea3122fe..23f1bb7bdd3 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -23,10 +23,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/common_infer_shape_functions.h"
+#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
+#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
+#include "paddle/fluid/prim/utils/static/desc_tensor.h"
 #include "paddle/phi/backends/dynload/port.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/infermeta/backward.h"
-
 DECLARE_bool(use_mkldnn);
 
 namespace paddle {
@@ -80,6 +82,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
     }
   }
 };
+class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
+ public:
+  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
+
+ protected:
+  void Apply() override {
+    paddle::Tensor x = this->GetSingleForwardInput("X");
+    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
+    paddle::Tensor dx = this->GetSingleInputGrad("X");
+    auto* dx_ptr = this->GetOutputPtr(&dx);
+    std::string dx_name = this->GetOutputName(dx);
+    VLOG(6) << "Runing hardswish_grad composite func";
+    prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
+    this->RecoverOutputName(dx, dx_name);
+  }
+};
 
 phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                              const framework::OperatorWithKernel& oper,
@@ -400,6 +418,25 @@ namespace plat = paddle::platform;
                     ops::ActivationOpGrad,                                  \
                     ops::ActivationGradOpInplaceInferer);
 
+#define REGISTER_ACTIVATION_OP_WITH_COMP(                                   \
+    KERNEL_TYPE, OP_NAME, functor, grad_functor)                            \
+  REGISTER_OPERATOR(                                                        \
+      KERNEL_TYPE,                                                          \
+      ops::ActivationOp,                                                    \
+      ops::OP_NAME##OpMaker,                                                \
+      ops::ActivationOpInferVarType,                                        \
+      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
+                                 paddle::framework::OpDesc>,                \
+      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),       \
+                                 paddle::imperative::OpBase>,               \
+      ops::OP_NAME##CompositeGradOpMaker,                                   \
+      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),      \
+                       ops::ActFwdInplaceInferer,                           \
+                       void>::type);                                        \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                     \
+                    ops::ActivationOpGrad,                                  \
+                    ops::ActivationGradOpInplaceInferer);
+
 FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
 
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name) \
@@ -416,10 +453,10 @@ REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
 
 REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
 REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP(hard_swish,
-                       HardSwish,
-                       HardSwishFunctor,
-                       HardSwishGradFunctor);
+REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
+                                 HardSwish,
+                                 HardSwishFunctor,
+                                 HardSwishGradFunctor);
 REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
 
 /* ========================== register checkpoint ===========================*/
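For readers less familiar with these macros: REGISTER_ACTIVATION_OP_WITH_COMP differs from the existing REGISTER_ACTIVATION_OP only in that it passes the hand-written composite grad op maker to REGISTER_OPERATOR. As a rough illustration (hand substitution of the macro parameters, not part of the patch itself), the new hard_swish registration should expand to approximately:

    REGISTER_OPERATOR(
        hard_swish,
        ops::ActivationOp,
        ops::HardSwishOpMaker,
        ops::ActivationOpInferVarType,
        ops::ActivationGradOpMaker<ops::HardSwishGradFunctor<float>::FwdDeps(),
                                   paddle::framework::OpDesc>,
        ops::ActivationGradOpMaker<ops::HardSwishGradFunctor<float>::FwdDeps(),
                                   paddle::imperative::OpBase>,
        ops::HardSwishCompositeGradOpMaker,
        std::conditional<ops::CanInplaceAct<ops::HardSwishGradFunctor<float>>(),
                         ops::ActFwdInplaceInferer,
                         void>::type);
    REGISTER_OPERATOR(hard_swish_grad,
                      ops::ActivationOpGrad,
                      ops::ActivationGradOpInplaceInferer);

The only new argument relative to the plain macro is ops::HardSwishCompositeGradOpMaker, which wires the decomposed backward defined above into the hard_swish registration.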
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 08fa2d0013f..0ac61cd4a3e 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -32,6 +32,42 @@ using Tensor = paddle::Tensor;
 using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
 // This function should have as same signature as phi, which defined in
 // paddle/phi/api/backward/backward_api.h
+template <typename T>
+void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto offset = full<T>(phi::vectorize(x.dims()), 3.0, x.dtype());
+    auto condition = less_equal<T>(x, offset);
+    auto tmp1 = where<T>(condition, out_grad * ((x / 3.0) + 0.5), out_grad);
+    auto res = where<T>(
+        less_than<T>(x, full<T>(phi::vectorize(x.dims()), -3.0, x.dtype())),
+        full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()),
+        tmp1);
+    set_output<T>(res, x_grad);
+  }
+}
+
+template <typename T>
+void leaky_relu_grad(const Tensor& out,
+                     const Tensor& out_grad,
+                     float negative_slope,
+                     Tensor* x_grad) {
+  if (x_grad) {
+    auto condition = greater_than<T>(
+        out, full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
+    auto res = where<T>(condition, out_grad, out_grad * negative_slope);
+    set_output<T>(res, x_grad);
+  }
+}
+
+template <typename T>
+void silu_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
+  if (x_grad) {
+    auto sigmoid = 1.0 / (1.0 + exp<T>(-x));
+    auto res = out_grad * sigmoid * (1.0 + x * (1.0 - sigmoid));
+    set_output<T>(res, x_grad);
+  }
+}
+
 template <typename T>
 void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
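The three new composite rules spell out the usual elementwise derivatives: for hardswish, x_grad is 0 where x < -3, out_grad * (x / 3 + 0.5) where -3 <= x <= 3, and out_grad where x > 3; for leaky_relu, it is out_grad where out > 0 and negative_slope * out_grad elsewhere; for silu, it is out_grad * sigmoid(x) * (1 + x * (1 - sigmoid(x))). A minimal standalone scalar sketch of the same formulas (plain C++ with no Paddle types; the *_ref names are local to this sketch, not Paddle APIs):

    #include <cassert>
    #include <cmath>

    // Scalar references for the three composite rules added above.
    double hardswish_grad_ref(double x, double out_grad) {
      if (x < -3.0) return 0.0;                         // below -3: gradient is zero
      if (x <= 3.0) return out_grad * (x / 3.0 + 0.5);  // inside [-3, 3]
      return out_grad;                                  // above 3: passes out_grad through
    }

    double leaky_relu_grad_ref(double out, double out_grad, double negative_slope) {
      // Branch on the forward output, as the tensor rule does.
      return out > 0.0 ? out_grad : out_grad * negative_slope;
    }

    double silu_grad_ref(double x, double out_grad) {
      double sigmoid = 1.0 / (1.0 + std::exp(-x));
      return out_grad * sigmoid * (1.0 + x * (1.0 - sigmoid));
    }

    int main() {
      // Spot-check a few hand-computed points.
      assert(hardswish_grad_ref(-4.0, 1.0) == 0.0);
      assert(hardswish_grad_ref(0.0, 1.0) == 0.5);
      assert(hardswish_grad_ref(5.0, 2.0) == 2.0);
      assert(leaky_relu_grad_ref(-1.0, 1.0, 0.02) == 0.02);
      assert(std::abs(silu_grad_ref(0.0, 1.0) - 0.5) < 1e-12);  // sigmoid(0) = 0.5
      return 0;
    }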
diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
index bbf8f45b29b..83b18814b19 100644
--- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
+++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h
@@ -376,7 +376,7 @@ class CompositeGradOpMakerBase {
   }
 
   std::vector<framework::VarDesc*> MultiInputGrad(
-      const std::string& name, bool drop_empty_grad = true) const {
+      const std::string& name) const {
     std::vector<std::string> ret_val;
     std::vector<framework::VarDesc*> input_grads;
     auto var_names = this->MultiForwardInputVarName(name);
@@ -393,39 +393,7 @@ class CompositeGradOpMakerBase {
                       return framework::kEmptyVarName;
                     }
                   });
-    if (!drop_empty_grad) {
-      for (const auto& name : ret_val) {
-        if (original_block_->HasVar(name)) {
-          // Copy Var from original block to active block, or create a new one.
-          CopyVarFromOrig(name);
-          input_grads.emplace_back(
-              StaticCompositeContext::Instance().GetBlock()->FindVar(name));
-        } else {
-          input_grads.emplace_back(
-              StaticCompositeContext::Instance().GetBlock()->Var(name));
-        }
-      }
-      return input_grads;
-    }
-    PADDLE_ENFORCE_LE(
-        var_names.size(),
-        1UL,
-        platform::errors::Unavailable(
-            "BUG from operator developer:"
-            " for input argument with a list of variables, "
-            " drop_empty_grad is not allowed because it makes"
-            " the correspondence bewteen a variable and its gradient"
-            " ambiguous."));
-
-    std::vector<std::string> dropped_ret_val;
-    dropped_ret_val.reserve(ret_val.size());
-    std::copy_if(
-        ret_val.begin(),
-        ret_val.end(),
-        std::back_inserter(dropped_ret_val),
-        [](const std::string& str) { return str != framework::kEmptyVarName; });
-    for (const auto& name : dropped_ret_val) {
-      // TODO(jiabin): Will this cause fill zeros error?
+    for (const auto& name : ret_val) {
       if (original_block_->HasVar(name)) {
         // Copy Var from original block to active block, or create a new one.
         CopyVarFromOrig(name);
@@ -437,6 +405,15 @@ class CompositeGradOpMakerBase {
       }
     }
     return input_grads;
+    PADDLE_ENFORCE_LE(
+        var_names.size(),
+        1UL,
+        platform::errors::Unavailable(
+            "BUG from operator developer:"
+            " for input argument with a list of variables, "
+            " drop_empty_grad is not allowed because it makes"
+            " the correspondence bewteen a variable and its gradient"
+            " ambiguous."));
   }
 
   std::vector<framework::VarDesc*> MultiOutputGrad(
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 3311d2979e4..3ebb72f4fdb 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -829,6 +829,7 @@
   kernel :
     func : leaky_relu_grad
   backward : leaky_relu_double_grad
+  composite: leaky_relu_grad(x, out_grad, negative_slope, x_grad)
   inplace : (out_grad -> x_grad)
 
 - backward_op : lerp_grad
@@ -1461,6 +1462,7 @@
     param : [x]
   kernel :
     func : silu_grad
+  composite : silu_grad(x, out_grad, x_grad)
   inplace : (out_grad -> x_grad)
 
 - backward_op : sin_double_grad
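The two composite entries in backward.yaml only record the mapping from each backward op to the prim functions added in composite_backward_api.h; the static-graph grad maker that actually calls them is generated from the YAML. As a rough sketch only, written by analogy with the hand-written HardSwishCompositeGradOpMaker earlier in this patch (the class name and argument names below are illustrative assumptions, not the actual generated code), the generated maker for silu_grad would do something like:

    class SiluCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
     public:
      using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

     protected:
      void Apply() override {
        paddle::Tensor x = this->GetSingleForwardInput("x");
        paddle::Tensor out_grad = this->GetSingleOutputGrad("out");
        paddle::Tensor x_grad = this->GetSingleInputGrad("x");
        auto* x_grad_ptr = this->GetOutputPtr(&x_grad);
        std::string x_grad_name = this->GetOutputName(x_grad);
        // Dispatch to the rule registered via "composite : silu_grad(x, out_grad, x_grad)".
        prim::silu_grad<prim::DescTensor>(x, out_grad, x_grad_ptr);
        this->RecoverOutputName(x_grad, x_grad_name);
      }
    };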