From 9f7b027dce88e1925d0c0cccd41d05bccb54e840 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 9 Apr 2019 21:33:57 -0500 Subject: [PATCH] fix activation grad op desc maker (#16715) test=develop --- paddle/fluid/framework/details/op_registry.h | 6 + .../fluid/op_use_default_grad_op_maker.spec | 23 -- .../fluid/operators/activation_cudnn_op.cu.cc | 16 +- paddle/fluid/operators/activation_op.cc | 161 ++++++-------- paddle/fluid/operators/activation_op.cu | 5 +- paddle/fluid/operators/activation_op.h | 201 +++++++++++++----- 6 files changed, 240 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index a9a4fb08a2..18de595983 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -233,6 +233,12 @@ struct OpInfoFiller { } }; +// A fake OpInfoFiller of void +template <> +struct OpInfoFiller { + void operator()(const char* op_type, OpInfo* info) const {} +}; + } // namespace details } // namespace framework diff --git a/paddle/fluid/op_use_default_grad_op_maker.spec b/paddle/fluid/op_use_default_grad_op_maker.spec index 4e833dd144..21a25ce7d5 100644 --- a/paddle/fluid/op_use_default_grad_op_maker.spec +++ b/paddle/fluid/op_use_default_grad_op_maker.spec @@ -1,14 +1,7 @@ -abs -acos -asin -atan attention_lstm -brelu conv_shift -cos cos_sim dequantize -elu fc flatten fsp @@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu fusion_seqexpand_concat_fc fusion_seqpool_concat fusion_squared_mat_sub -gelu gru -hard_shrink hierarchical_sigmoid -leaky_relu -log -logsigmoid lrn lstm_unit lstmp @@ -38,7 +26,6 @@ modified_huber_loss nce pool2d pool3d -pow prelu quantize rank_loss @@ -50,20 +37,10 @@ reduce_sum requantize reshape rnn_memory_helper -round sequence_softmax -sin -softplus -softshrink -softsign spp -square squeeze -stanh -swish -tanh_shrink tensor_array_to_tensor -thresholded_relu transpose unpool unsqueeze diff --git a/paddle/fluid/operators/activation_cudnn_op.cu.cc b/paddle/fluid/operators/activation_cudnn_op.cu.cc index a382414d5c..f03355eb44 100644 --- a/paddle/fluid/operators/activation_cudnn_op.cu.cc +++ b/paddle/fluid/operators/activation_cudnn_op.cu.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
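The op_registry.h hunk above adds a no-op filler so that `void` can be passed where an optional registry entry (such as an inplace-inference pass) is not wanted. The following is a minimal, self-contained sketch of that idea, assuming a simplified one-parameter OpInfoFiller and a stand-in OpInfo struct; the real Paddle filler also carries an op-info fill-type template parameter.

#include <iostream>

struct OpInfo {};  // simplified stand-in for framework::OpInfo

// Primary template: pretend this writes real registry information.
template <typename T>
struct OpInfoFiller {
  void operator()(const char* op_type, OpInfo* info) const {
    std::cout << "filling registry info for " << op_type << "\n";
  }
};

// Specialization for void: a fake filler that ignores its arguments, so
// `void` can appear in a registration parameter pack and contribute
// nothing to the OpInfo.
template <>
struct OpInfoFiller<void> {
  void operator()(const char* op_type, OpInfo* info) const {}
};

int main() {
  OpInfo info;
  OpInfoFiller<int>()("relu", &info);   // does work in this sketch
  OpInfoFiller<void>()("relu", &info);  // intentionally a no-op
  return 0;
}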
+#include +#include +#include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/platform/cudnn_desc.h" @@ -82,6 +85,8 @@ template struct CudnnReluGradFunctor : public CudnnActivationGradFunctor { explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_RELU) {} + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor { explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) { } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -105,6 +112,8 @@ template struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor { explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {} + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -116,6 +125,8 @@ template struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor { explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx) : CudnnActivationGradFunctor(ctx, 0.0, CUDNN_ACTIVATION_TANH) {} + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -140,10 +151,13 @@ class CudnnActivationGradKernel public: using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { + static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out."); + const framework::Tensor *X, *Out, *dOut; X = Out = dOut = nullptr; framework::Tensor* dX = nullptr; - ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); + ExtractActivationGradTensor(context, &X, &Out, &dOut, + &dX); dX->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); Functor functor(dev_ctx); diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index c87e4b22b3..1e5d63fc11 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -15,7 +15,9 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/activation_op.h" #include #include +#include #include +#include #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/platform/port.h" #ifdef PADDLE_WITH_CUDA @@ -27,6 +29,25 @@ namespace operators { using paddle::framework::Tensor; +template +static constexpr bool CanInplaceAct() { + return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps; +} + +std::unique_ptr> GetInplaceOpSet() { + std::unique_ptr> ret( + new std::unordered_set()); +#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \ + bwd_functor) \ + if (CanInplaceAct>()) { \ + ret->insert(#op_type); \ + } + + FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET); +#undef INSERT_INTO_INPLACE_OP_SET + return ret; +} + #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ class OP_NAME##OpMaker \ : public ::paddle::framework::OpProtoAndCheckerMaker { \ @@ -50,26 +71,32 @@ using paddle::framework::Tensor; } \ } -#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \ - class OP_NAME##GradMaker \ - : public ::paddle::framework::SingleGradOpDescMaker { \ - public: \ - using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \ - \ - protected: \ - std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \ - auto* op = new ::paddle::framework::OpDesc(); \ - op->SetType(#KERNEL_TYPE "_grad"); \ - op->SetInput("Out", Output("Out")); \ - op->SetInput(::paddle::framework::GradVarName("Out"), \ - OutputGrad("Out")); \ - \ - op->SetAttrMap(Attrs()); \ - \ - op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \ - return std::unique_ptr<::paddle::framework::OpDesc>(op); \ - } \ +template +class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr op(new framework::OpDesc()); + op->SetType(ForwardOpType() + "_grad"); + op->SetInput(framework::GradVarName("Out"), OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), InputGrad("X")); + op->SetAttrMap(Attrs()); + + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepX)) { + op->SetInput("X", Input("X")); + } + + if (static_cast(kDepValue) & + static_cast(ActBwdOpFwdDeps::kDepOut)) { + op->SetInput("Out", Output("Out")); + } + + return op; } +}; framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, const framework::OperatorWithKernel& oper, @@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - ctx->ShareDim("Out", framework::GradVarName("X")); - ctx->ShareLoD("Out", framework::GradVarName("X")); + auto out_grad_name = framework::GradVarName("Out"); + ctx->ShareDim(out_grad_name, framework::GradVarName("X")); + ctx->ShareLoD(out_grad_name, framework::GradVarName("X")); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return GetKernelType(ctx, *this, "Out"); + return GetKernelType(ctx, *this, framework::GradVarName("Out")); } }; @@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc); REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc); REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc); REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc); - -REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid); 
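The new ActivationGradOpDescMaker above builds the grad op from the functor's declared dependencies: Out@GRAD and X@GRAD are always wired up, while the forward X and/or Out are added only when the FwdDeps() bitmask asks for them. A minimal sketch of that dependency-driven construction, assuming a simplified stand-in for framework::OpDesc:

#include <iostream>
#include <string>
#include <vector>

enum ActBwdOpFwdDeps { kNoDeps = 0x00, kDepX = 0x01, kDepOut = 0x02 };

struct OpDescSketch {
  std::vector<std::string> inputs;
  void SetInput(const std::string& name) { inputs.push_back(name); }
};

template <ActBwdOpFwdDeps kDepValue>
OpDescSketch MakeActivationGradOp() {
  OpDescSketch op;
  op.SetInput("Out@GRAD");  // the upstream gradient is always required
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    op.SetInput("X");       // functors such as log_grad or square_grad
  }
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    op.SetInput("Out");     // functors such as relu_grad or sigmoid_grad
  }
  return op;
}

int main() {
  // A kDepOut functor keeps only Out alive; X can be freed after forward.
  for (const auto& in : MakeActivationGradOp<kDepOut>().inputs) {
    std::cout << in << "\n";  // prints: Out@GRAD, then Out
  }
  return 0;
}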
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt); -REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6); -REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal); -REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid); } // namespace operators } // namespace paddle namespace ops = paddle::operators; -#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \ - __macro(Sigmoid, sigmoid); \ - __macro(Relu, relu); \ - __macro(Exp, exp); \ - __macro(Tanh, tanh); \ - __macro(Ceil, ceil); \ - __macro(Floor, floor); \ - __macro(Sqrt, sqrt); \ - __macro(SoftRelu, soft_relu); \ - __macro(Relu6, relu6); \ - __macro(Reciprocal, reciprocal); \ - __macro(HardSigmoid, hard_sigmoid); - -#define FOR_EACH_OP_FUNCTOR(__macro) \ - __macro(LogSigmoid, logsigmoid); \ - __macro(SoftShrink, softshrink); \ - __macro(Abs, abs); \ - __macro(Cos, cos); \ - __macro(Acos, acos); \ - __macro(Sin, sin); \ - __macro(Asin, asin); \ - __macro(Atan, atan); \ - __macro(Round, round); \ - __macro(Log, log); \ - __macro(Square, square); \ - __macro(Gelu, gelu); \ - __macro(BRelu, brelu); \ - __macro(Pow, pow); \ - __macro(STanh, stanh); \ - __macro(Softplus, softplus); \ - __macro(Softsign, softsign); \ - __macro(LeakyRelu, leaky_relu); \ - __macro(TanhShrink, tanh_shrink); \ - __macro(ELU, elu); \ - __macro(HardShrink, hard_shrink); \ - __macro(Swish, swish); \ - __macro(ThresholdedRelu, thresholded_relu); - -#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ - REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ - ::paddle::operators::OP_NAME##OpMaker, \ - ::paddle::operators::ActivationOpInferVarType, \ - ::paddle::operators::OP_NAME##GradMaker, \ - ::paddle::framework::SingleOpInplaceInToOut); \ - REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \ - ::paddle::framework::SingleOpInplaceInToOut) - -#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ - REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ - ::paddle::operators::OP_NAME##OpMaker, \ - ::paddle::operators::ActivationOpInferVarType, \ - ::paddle::framework::DefaultGradOpDescMaker); \ - REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) - -#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ +#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \ + REGISTER_OPERATOR( \ + KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \ + ops::ActivationOpInferVarType, \ + ops::ActivationGradOpDescMaker::FwdDeps()>, \ + std::conditional>(), \ + ::paddle::framework::SingleOpInplaceInToOut, \ + void>::type); \ + REGISTER_OPERATOR( \ + KERNEL_TYPE##_grad, ops::ActivationOpGrad, \ + std::conditional>(), \ + ::paddle::framework::SingleOpInplaceInToOut, \ + void>::type) + +#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \ + grad_functor) \ REGISTER_OP_CPU_KERNEL( \ act_type, ops::ActivationKernel>, \ @@ -643,6 +619,5 @@ namespace ops = paddle::operators; ops::ActivationGradKernel>); -FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP); -FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP); -FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); +FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP); 
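The rewritten REGISTER_ACTIVATION_OP above picks the inplace-inference argument at compile time: a grad functor whose FwdDeps() is kDepOut or kNoDeps may safely overwrite its input, everything else gets `void` (swallowed by the no-op OpInfoFiller). The sketch below isolates that CanInplaceAct/std::conditional selection; the two grad functors and SingleOpInplaceInToOutSketch are illustrative stand-ins rather than the actual Paddle types.

#include <type_traits>

enum ActBwdOpFwdDeps { kNoDeps = 0x00, kDepX = 0x01, kDepOut = 0x02 };

struct ReluGradSketch {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
struct LogGradSketch {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename GradFunctor>
constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut ||
         GradFunctor::FwdDeps() == kNoDeps;
}

struct SingleOpInplaceInToOutSketch {};  // stand-in inplace-inference pass

// Mirrors the std::conditional used inside the registration macro.
template <typename GradFunctor>
using InplaceArg =
    typename std::conditional<CanInplaceAct<GradFunctor>(),
                              SingleOpInplaceInToOutSketch, void>::type;

static_assert(std::is_same<InplaceArg<ReluGradSketch>,
                           SingleOpInplaceInToOutSketch>::value,
              "relu grad reads only Out, so the op can run inplace");
static_assert(std::is_same<InplaceArg<LogGradSketch>, void>::value,
              "log grad reads X, so it must not run inplace");

int main() { return 0; }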
+FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL); diff --git a/paddle/fluid/operators/activation_op.cu b/paddle/fluid/operators/activation_op.cu index d3a7ceed46..9c7a8d8971 100644 --- a/paddle/fluid/operators/activation_op.cu +++ b/paddle/fluid/operators/activation_op.cu @@ -15,7 +15,8 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; -#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \ +#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \ + grad_functor) \ REGISTER_OP_CUDA_KERNEL( \ act_type, \ ops::ActivationKernel>, \ @@ -30,4 +31,4 @@ namespace plat = paddle::platform; ops::ActivationGradKernel>); -FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL); +FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index ff7e623f6f..915632a328 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -12,6 +12,7 @@ limitations under the License. */ #pragma once #include #include +#include #include #include #include @@ -35,21 +36,29 @@ limitations under the License. */ namespace paddle { namespace operators { -/* Use ugly global variable, for the using in python layer side - Please refer to the layer_helper.py and get the details. - */ -static std::unordered_set InplaceOpSet = { - "sigmoid", "exp", "relu", "tanh", "sqrt", "ceil", - "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid"}; +enum ActBwdOpFwdDeps { + kNoDeps = 0x00, // Do not need any forward input/output + kDepX = 0x01, // Only need forward input X + kDepOut = 0x02, // Only need forward output Out + + // Never add kDepXOut, because Out can be always calculated + // by forward input X in backward part. + // FIXME(zjl): but in MKLDNN abs, X and Out are all needed... + // Developers should not rely on this enum value! 
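The ActBwdOpFwdDeps enum introduced here records which forward tensors each backward functor reads. The short numeric check below (plain standalone C++, not Paddle code) illustrates why relu, sigmoid and tanh can declare kDepOut: each derivative is expressible from the forward output alone, which is exactly what makes those grad ops candidates for inplace execution.

#include <cassert>
#include <cmath>

int main() {
  double x = 0.7;

  // sigmoid'(x) = out * (1 - out), with out = sigmoid(x)
  double out_sig = 1.0 / (1.0 + std::exp(-x));
  double dsig_from_out = out_sig * (1.0 - out_sig);
  double dsig_from_x = std::exp(-x) / std::pow(1.0 + std::exp(-x), 2.0);
  assert(std::fabs(dsig_from_out - dsig_from_x) < 1e-12);

  // tanh'(x) = 1 - out^2, with out = tanh(x)
  double out_tanh = std::tanh(x);
  double dtanh_from_out = 1.0 - out_tanh * out_tanh;
  double dtanh_from_x = 1.0 / std::pow(std::cosh(x), 2.0);
  assert(std::fabs(dtanh_from_out - dtanh_from_x) < 1e-12);

  // relu'(x) = [out > 0], with out = max(x, 0)
  double out_relu = x > 0.0 ? x : 0.0;
  assert((out_relu > 0.0 ? 1.0 : 0.0) == (x > 0.0 ? 1.0 : 0.0));
  return 0;
}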
+ kDepXOut = 0x03 +}; + +std::unique_ptr> GetInplaceOpSet(); static bool IsInplace(const std::string& op) { - bool inplace = InplaceOpSet.count(op); + static auto InplaceOpSet = GetInplaceOpSet(); + bool inplace = InplaceOpSet->count(op); // for op_grad const int kGradSuffixLen = 4; if (op.size() > kGradSuffixLen && op.compare(op.size() - kGradSuffixLen - 1, kGradSuffixLen, "grad")) { inplace = - InplaceOpSet.count(op.substr(0, op.size() - (kGradSuffixLen + 1))); + InplaceOpSet->count(op.substr(0, op.size() - (kGradSuffixLen + 1))); } return inplace; } @@ -85,16 +94,21 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context, context.op().Output("Out")); } +template inline void ExtractActivationGradTensor( const framework::ExecutionContext& context, const framework::Tensor** X, const framework::Tensor** Out, const framework::Tensor** dOut, framework::Tensor** dX) { - auto out_var = context.InputVar("Out"); auto out_grad_var = context.InputVar(framework::GradVarName("Out")); auto x_grad_var = context.OutputVar(framework::GradVarName("X")); - PADDLE_ENFORCE(out_var != nullptr, - "Cannot get input Variable Out, variable name = %s", - context.op().Input("Out")); + const framework::Variable* out_var = nullptr; + + if (static_cast(kDepValue) & static_cast(kDepOut)) { + out_var = context.InputVar("Out"); + PADDLE_ENFORCE(out_var != nullptr, + "Cannot get input Variable Out, variable name = %s", + context.op().Input("Out")); + } PADDLE_ENFORCE(out_grad_var != nullptr, "Cannot get input Variable %s, variable name = %s", framework::GradVarName("Out"), @@ -105,23 +119,36 @@ inline void ExtractActivationGradTensor( context.op().Output(framework::GradVarName("X"))); if (CanBeUsedBySelectedRows.count(context.op().Type())) { - *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar( *out_grad_var); *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar( x_grad_var); + + if (out_var) { + *Out = + paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var); + } else { + *Out = *dOut; // fake out + } + } else { *Out = context.Input("Out"); *dOut = context.Input(framework::GradVarName("Out")); *dX = context.Output(framework::GradVarName("X")); + + if (out_var) { + *Out = &(out_var->Get()); + } else { + *Out = *dOut; // fake out + } } + PADDLE_ENFORCE(*dX != nullptr, "Cannot get output tensor %s, variable name = %s", framework::GradVarName("X"), context.op().Output(framework::GradVarName("X"))); - bool inplace = IsInplace(context.op().Type()); - if (!inplace) { + if (static_cast(kDepValue) & static_cast(kDepX)) { auto x_var = context.InputVar("X"); PADDLE_ENFORCE(x_var != nullptr, "Cannot get input tensor X, variable name = %s", @@ -172,7 +199,8 @@ class ActivationGradKernel const framework::Tensor *X, *Out, *dOut; framework::Tensor* dX = nullptr; X = Out = dOut = nullptr; - ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX); + ExtractActivationGradTensor(context, &X, &Out, &dOut, + &dX); dX->mutable_data(context.GetPlace()); auto dout = framework::EigenVector::Flatten(detail::Ref(dOut)); auto out = framework::EigenVector::Flatten(detail::Ref(Out)); @@ -222,6 +250,8 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out * (static_cast(1) - out); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // Originally: logsigmoid(x) = -log (1 + exp(-x)) @@ 
-258,6 +288,8 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // exp(x) = e^x @@ -276,6 +308,8 @@ struct ExpGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * out; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // relu(x) = max(x, 0) @@ -294,6 +328,8 @@ struct ReluGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (out > static_cast(0)).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) @@ -338,6 +374,8 @@ struct GeluGradFunctor : BaseActivationFunctor { (-static_cast(0.5) * x.square()).exp(); dx.device(d) = dout * (first + second); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) @@ -356,6 +394,8 @@ struct TanhGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) - out * out); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // tanhshrink(x) = x - tanh(x) @@ -375,6 +415,8 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x.tanh() * x.tanh()); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // tanhshrink(x) = x - tanh(x) @@ -409,6 +451,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { auto temp2 = (x > static_cast(threshold)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 @@ -443,6 +487,8 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor { auto temp2 = (x < -lambdaT).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sqrt(x) = x^(1/2) @@ -461,6 +507,8 @@ struct SqrtGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0.5) * dout / out; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // ceil(x) = ceiling(x) @@ -479,6 +527,8 @@ struct ZeroGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) / out; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; } }; // floor(x) = flooring(x) @@ -522,6 +572,8 @@ struct CosGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = -dout * x.unaryExpr(Sine()); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // cosine(x) = cos(x) @@ -541,6 +593,8 @@ struct SinGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.unaryExpr(Cosine()); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // sine(x) = sin(x) @@ -582,6 +636,8 @@ struct AcosGradFunctor : public BaseActivationFunctor { dx.device(d) = -dout * static_cast(1) / 
(static_cast(1) - x.square()).sqrt(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -614,6 +670,8 @@ struct AsinGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(1) / (static_cast(1) - x.square()).sqrt(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -645,6 +703,8 @@ struct AtanGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(1) / (static_cast(1) + x.square()); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // round(x) = [x] @@ -672,6 +732,8 @@ struct AbsGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * x.sign(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; } }; // reciprocal(x) = 1 / x @@ -690,6 +752,8 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(-1) * out * out; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // log(x) = natural logarithm of x @@ -708,6 +772,8 @@ struct LogGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (static_cast(1) / x); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // square(x) = x^2 @@ -726,6 +792,8 @@ struct SquareGradFunctor : public BaseActivationFunctor { void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * static_cast(2) * x; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -760,6 +828,8 @@ struct BReluGradFunctor : public BaseActivationFunctor { ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // relu6(x) = min(max(0, x), 6) @@ -792,6 +862,8 @@ struct Relu6GradFunctor : public BaseActivationFunctor { ((out > static_cast(0)) * (out < static_cast(threshold))) .template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; // softplus(x) = log(1 + exp(x)) @@ -821,6 +893,8 @@ struct SoftplusGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // softsign(x) = x / (1 + |x|) @@ -842,6 +916,8 @@ struct SoftsignGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -872,6 +948,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { auto temp = ((out > -tmp) * (out < tmp)).template cast().eval(); dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -901,6 +979,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { auto temp2 = (x >= static_cast(0)).template cast().eval(); dx.device(d) = dout * (temp1 + temp2).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -928,9 +1008,11 @@ struct ELUGradFunctor : public BaseActivationFunctor { typename dX> void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = dout * (x > static_cast(0)).template cast() + - 
dout * (out + static_cast(alpha)) * + dout * static_cast(alpha) * x.exp() * (x < static_cast(0)).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; // FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198 @@ -958,6 +1040,8 @@ struct PowGradFunctor : public BaseActivationFunctor { dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor) - static_cast(1)); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -991,6 +1075,8 @@ struct STanhGradFunctor : public BaseActivationFunctor { auto temp = (a * x).tanh() * (a * x).tanh(); dx.device(d) = dout * a * b * (static_cast(1) - temp); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -1020,6 +1106,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor { auto th = static_cast(threshold); dx.device(d) = dout * (x > th).template cast(); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; template @@ -1053,6 +1141,8 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { .template cast() * static_cast(slope); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; } }; template @@ -1077,49 +1167,54 @@ struct SwishGradFunctor : public BaseActivationFunctor { template - void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); + auto out = x * temp1; auto temp2 = temp1 * (static_cast(1) - (static_cast(beta) * out)); dx.device(d) = dout * ((static_cast(beta) * out) + temp2); } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } }; } // namespace operators } // namespace paddle -#define FOR_EACH_KERNEL_FUNCTOR(__macro) \ - __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ - __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ - __macro(exp, ExpFunctor, ExpGradFunctor); \ - __macro(relu, ReluFunctor, ReluGradFunctor); \ - __macro(gelu, GeluFunctor, GeluGradFunctor); \ - __macro(tanh, TanhFunctor, TanhGradFunctor); \ - __macro(atan, AtanFunctor, AtanGradFunctor); \ - __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ - __macro(sqrt, SqrtFunctor, SqrtGradFunctor); \ - __macro(abs, AbsFunctor, AbsGradFunctor); \ - __macro(ceil, CeilFunctor, ZeroGradFunctor); \ - __macro(floor, FloorFunctor, ZeroGradFunctor); \ - __macro(cos, CosFunctor, CosGradFunctor); \ - __macro(acos, AcosFunctor, AcosGradFunctor); \ - __macro(sin, SinFunctor, SinGradFunctor); \ - __macro(asin, AsinFunctor, AsinGradFunctor); \ - __macro(round, RoundFunctor, ZeroGradFunctor); \ - __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ - __macro(log, LogFunctor, LogGradFunctor); \ - __macro(square, SquareFunctor, SquareGradFunctor); \ - __macro(brelu, BReluFunctor, BReluGradFunctor); \ - __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ - __macro(pow, PowFunctor, PowGradFunctor); \ - __macro(stanh, STanhFunctor, STanhGradFunctor); \ - __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ - __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(relu6, Relu6Functor, Relu6GradFunctor); \ - __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ - __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ - __macro(elu, ELUFunctor, ELUGradFunctor); \ - __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \ - __macro(hard_sigmoid, HardSigmoidFunctor, 
HardSigmoidGradFunctor); \ - __macro(swish, SwishFunctor, SwishGradFunctor); \ - __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor); +#define FOR_EACH_ACTIVATION_OP(__macro) \ + __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \ + __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ + __macro(exp, Exp, ExpFunctor, ExpGradFunctor); \ + __macro(relu, Relu, ReluFunctor, ReluGradFunctor); \ + __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \ + __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \ + __macro(atan, Atan, AtanFunctor, AtanGradFunctor); \ + __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \ + __macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); \ + __macro(abs, Abs, AbsFunctor, AbsGradFunctor); \ + __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ + __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ + __macro(cos, Cos, CosFunctor, CosGradFunctor); \ + __macro(acos, Acos, AcosFunctor, AcosGradFunctor); \ + __macro(sin, Sin, SinFunctor, SinGradFunctor); \ + __macro(asin, Asin, AsinFunctor, AsinGradFunctor); \ + __macro(round, Round, RoundFunctor, ZeroGradFunctor); \ + __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \ + __macro(log, Log, LogFunctor, LogGradFunctor); \ + __macro(square, Square, SquareFunctor, SquareGradFunctor); \ + __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \ + __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \ + __macro(pow, Pow, PowFunctor, PowGradFunctor); \ + __macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \ + __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ + __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ + __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ + __macro(leaky_relu, LeakyRelu, LeakyReluFunctor, LeakyReluGradFunctor); \ + __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ + __macro(elu, ELU, ELUFunctor, ELUGradFunctor); \ + __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \ + __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ + HardSigmoidGradFunctor); \ + __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ + __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \ + ThresholdedReluGradFunctor); -- GitLab
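The FOR_EACH_ACTIVATION_OP table above is a classic X-macro: one list of (op_type, OpName, Functor, GradFunctor) rows now drives op registration, CPU/CUDA kernel registration and the inplace set, so adding an activation means editing a single row. A compact standalone sketch of the pattern, using an invented two-row table and client macro (the *_SKETCH and COLLECT_OP_TYPE names are not Paddle identifiers):

#include <iostream>
#include <string>
#include <vector>

// The table: each row lists the pieces a client macro may want.
#define FOR_EACH_ACTIVATION_OP_SKETCH(__macro)         \
  __macro(relu, Relu, ReluFunctor, ReluGradFunctor);   \
  __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor)

int main() {
  std::vector<std::string> registered;

// One client of the table: collect op_type names, the way the real
// macros stamp out REGISTER_OPERATOR / REGISTER_OP_CPU_KERNEL calls.
#define COLLECT_OP_TYPE(op_type, op_name, functor, grad_functor) \
  registered.push_back(#op_type)

  FOR_EACH_ACTIVATION_OP_SKETCH(COLLECT_OP_TYPE);
#undef COLLECT_OP_TYPE

  for (const auto& name : registered) std::cout << name << "\n";
  return 0;
}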