Unverified commit 9f7b027d, authored by Zeng Jinle, committed by GitHub

fix activation grad op desc maker (#16715)

test=develop
Parent commit 9bd44b94
@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
  }
};

// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
  void operator()(const char* op_type, OpInfo* info) const {}
};

}  // namespace details
}  // namespace framework
...
-abs
-acos
-asin
-atan
attention_lstm
-brelu
conv_shift
-cos
cos_sim
dequantize
-elu
fc
flatten
fsp
@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
-gelu
gru
-hard_shrink
hierarchical_sigmoid
-leaky_relu
-log
-logsigmoid
lrn
lstm_unit
lstmp
@@ -38,7 +26,6 @@ modified_huber_loss
nce
pool2d
pool3d
-pow
prelu
quantize
rank_loss
@@ -50,20 +37,10 @@ reduce_sum
requantize
reshape
rnn_memory_helper
-round
sequence_softmax
-sin
-softplus
-softshrink
-softsign
spp
-square
squeeze
-stanh
-swish
-tanh_shrink
tensor_array_to_tensor
-thresholded_relu
transpose
unpool
unsqueeze
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
@@ -82,6 +85,8 @@ template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
@@ -105,6 +112,8 @@ template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
@@ -116,6 +125,8 @@ template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
  explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
      : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename Functor>
@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& context) const override {
    static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
    const framework::Tensor *X, *Out, *dOut;
    X = Out = dOut = nullptr;
    framework::Tensor* dX = nullptr;
-    ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
+    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
+                                                    &dX);
    dX->mutable_data<T>(context.GetPlace());
    auto& dev_ctx = context.template device_context<CUDADeviceContext>();
    Functor functor(dev_ctx);
...
@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
@@ -27,6 +29,25 @@ namespace operators {
using paddle::framework::Tensor;

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
  std::unique_ptr<std::unordered_set<std::string>> ret(
      new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
                                   bwd_functor)                     \
  if (CanInplaceAct<bwd_functor<float>>()) {                        \
    ret->insert(#op_type);                                          \
  }
  FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
  return ret;
}
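For reference, a sketch of one expansion of the helper macro above, taking the (relu, Relu, ReluFunctor, ReluGradFunctor) entry from FOR_EACH_ACTIVATION_OP (the second argument is intentionally unused here):

// INSERT_INTO_INPLACE_OP_SET(relu, Relu, ReluFunctor, ReluGradFunctor) becomes:
if (CanInplaceAct<ReluGradFunctor<float>>()) {  // true: ReluGradFunctor::FwdDeps() == kDepOut
  ret->insert("relu");
}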
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT)    \
  class OP_NAME##OpMaker                                     \
      : public ::paddle::framework::OpProtoAndCheckerMaker { \
@@ -50,27 +71,33 @@ using paddle::framework::Tensor;
    }                                                        \
  }
-#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE)              \
-  class OP_NAME##GradMaker                                                   \
-      : public ::paddle::framework::SingleGradOpDescMaker {                  \
-   public:                                                                   \
-    using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
-                                                                             \
-   protected:                                                                \
-    std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {    \
-      auto* op = new ::paddle::framework::OpDesc();                          \
-      op->SetType(#KERNEL_TYPE "_grad");                                     \
-      op->SetInput("Out", Output("Out"));                                    \
-      op->SetInput(::paddle::framework::GradVarName("Out"),                  \
-                   OutputGrad("Out"));                                       \
-                                                                             \
-      op->SetAttrMap(Attrs());                                               \
-                                                                             \
-      op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X"));  \
-      return std::unique_ptr<::paddle::framework::OpDesc>(op);               \
-    }                                                                        \
-  }

+template <ActBwdOpFwdDeps kDepValue>
+class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType(ForwardOpType() + "_grad");
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
+      op->SetInput("X", Input("X"));
+    }
+
+    if (static_cast<int>(kDepValue) &
+        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
+      op->SetInput("Out", Output("Out"));
+    }
+
+    return op;
+  }
+};
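As a concrete illustration of what this maker generates (my summary, not part of the patch): for relu, whose ReluGradFunctor reports kDepOut, versus gelu, whose GeluGradFunctor reports kDepX, the emitted grad op descriptions differ roughly as follows (GradVarName appends the usual @GRAD suffix):

// relu (kDepOut): relu_grad  inputs = {Out, Out@GRAD}, outputs = {X@GRAD}
// gelu (kDepX)  : gelu_grad  inputs = {X,   Out@GRAD}, outputs = {X@GRAD}
// In both cases the forward op's attributes are copied via SetAttrMap(Attrs()).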
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                      const framework::OperatorWithKernel& oper,
                                      const std::string& name) {
@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
-    ctx->ShareDim("Out", framework::GradVarName("X"));
-    ctx->ShareLoD("Out", framework::GradVarName("X"));
+    auto out_grad_name = framework::GradVarName("Out");
+    ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
+    ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this, "Out");
+    return GetKernelType(ctx, *this, framework::GradVarName("Out"));
  }
};
@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);

-REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
-REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

-#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
-  __macro(Sigmoid, sigmoid);                 \
-  __macro(Relu, relu);                       \
-  __macro(Exp, exp);                         \
-  __macro(Tanh, tanh);                       \
-  __macro(Ceil, ceil);                       \
-  __macro(Floor, floor);                     \
-  __macro(Sqrt, sqrt);                       \
-  __macro(SoftRelu, soft_relu);              \
-  __macro(Relu6, relu6);                     \
-  __macro(Reciprocal, reciprocal);           \
-  __macro(HardSigmoid, hard_sigmoid);
-
-#define FOR_EACH_OP_FUNCTOR(__macro) \
-  __macro(LogSigmoid, logsigmoid);   \
-  __macro(SoftShrink, softshrink);   \
-  __macro(Abs, abs);                 \
-  __macro(Cos, cos);                 \
-  __macro(Acos, acos);               \
-  __macro(Sin, sin);                 \
-  __macro(Asin, asin);               \
-  __macro(Atan, atan);               \
-  __macro(Round, round);             \
-  __macro(Log, log);                 \
-  __macro(Square, square);           \
-  __macro(Gelu, gelu);               \
-  __macro(BRelu, brelu);             \
-  __macro(Pow, pow);                 \
-  __macro(STanh, stanh);             \
-  __macro(Softplus, softplus);       \
-  __macro(Softsign, softsign);       \
-  __macro(LeakyRelu, leaky_relu);    \
-  __macro(TanhShrink, tanh_shrink);  \
-  __macro(ELU, elu);                 \
-  __macro(HardShrink, hard_shrink);  \
-  __macro(Swish, swish);             \
-  __macro(ThresholdedRelu, thresholded_relu);
-
-#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                   \
-  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,            \
-                    ::paddle::operators::OP_NAME##OpMaker,                     \
-                    ::paddle::operators::ActivationOpInferVarType,             \
-                    ::paddle::operators::OP_NAME##GradMaker,                   \
-                    ::paddle::framework::SingleOpInplaceInToOut);              \
-  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
-                    ::paddle::framework::SingleOpInplaceInToOut)
-
-#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                    \
-  REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,     \
-                    ::paddle::operators::OP_NAME##OpMaker,              \
-                    ::paddle::operators::ActivationOpInferVarType,      \
-                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
-  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
-
-#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \

+#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor)  \
+  REGISTER_OPERATOR(                                                         \
+      KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker,                 \
+      ops::ActivationOpInferVarType,                                         \
+      ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>,   \
+      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),       \
+                       ::paddle::framework::SingleOpInplaceInToOut,          \
+                       void>::type);                                         \
+  REGISTER_OPERATOR(                                                         \
+      KERNEL_TYPE##_grad, ops::ActivationOpGrad,                             \
+      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),       \
+                       ::paddle::framework::SingleOpInplaceInToOut,          \
+                       void>::type)
+
+#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor,        \
+                                       grad_functor)                      \
  REGISTER_OP_CPU_KERNEL(                                                  \
      act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext,  \
                                      ops::functor<float>>,                \
@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
      ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,        \
                                ops::grad_functor<double>>);

-FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
-FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
-FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
+FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
+FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
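To make the new registration concrete, here is roughly what REGISTER_ACTIVATION_OP(relu, Relu, ReluFunctor, ReluGradFunctor) expands to (a sketch, whitespace elided and the std::conditional already resolved). Because ReluGradFunctor<float>::FwdDeps() is kDepOut, CanInplaceAct is true and SingleOpInplaceInToOut is selected; for a kDepX op such as gelu the conditional yields void, which the fake OpInfoFiller added at the top of this commit turns into a no-op:

REGISTER_OPERATOR(relu, ops::ActivationOp, ops::ReluOpMaker,
                  ops::ActivationOpInferVarType,
                  ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
                  ::paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ::paddle::framework::SingleOpInplaceInToOut);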
@@ -15,7 +15,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;

-#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor)   \
+#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor,        \
+                                        grad_functor)                      \
  REGISTER_OP_CUDA_KERNEL(                                                  \
      act_type,                                                             \
      ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>,  \
@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
      ops::ActivationGradKernel<plat::CUDADeviceContext,                    \
                                ops::grad_functor<plat::float16>>);

-FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
+FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
@@ -12,6 +12,7 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
@@ -35,21 +36,29 @@ limitations under the License. */
namespace paddle {
namespace operators {
-/* Use ugly global variable, for the using in python layer side
-   Please refer to the layer_helper.py and get the details.
- */
-static std::unordered_set<std::string> InplaceOpSet = {
-    "sigmoid", "exp", "relu", "tanh", "sqrt", "ceil",
-    "floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid"};

+enum ActBwdOpFwdDeps {
+  kNoDeps = 0x00,   // Do not need any forward input/output
+  kDepX = 0x01,     // Only need forward input X
+  kDepOut = 0x02,   // Only need forward output Out
+
+  // Never add kDepXOut, because Out can be always calculated
+  // by forward input X in backward part.
+  // FIXME(zjl): but in MKLDNN abs, X and Out are all needed...
+  // Developers should not rely on this enum value!
+  kDepXOut = 0x03
+};
+
+std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet();

static bool IsInplace(const std::string& op) {
-  bool inplace = InplaceOpSet.count(op);
+  static auto InplaceOpSet = GetInplaceOpSet();
+  bool inplace = InplaceOpSet->count(op);
  // for op_grad
  const int kGradSuffixLen = 4;
  if (op.size() > kGradSuffixLen &&
      op.compare(op.size() - kGradSuffixLen - 1, kGradSuffixLen, "grad")) {
    inplace =
-        InplaceOpSet.count(op.substr(0, op.size() - (kGradSuffixLen + 1)));
+        InplaceOpSet->count(op.substr(0, op.size() - (kGradSuffixLen + 1)));
  }
  return inplace;
}
@@ -85,16 +94,21 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context,
                     context.op().Output("Out"));
}

+template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationGradTensor(
    const framework::ExecutionContext& context, const framework::Tensor** X,
    const framework::Tensor** Out, const framework::Tensor** dOut,
    framework::Tensor** dX) {
-  auto out_var = context.InputVar("Out");
  auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
  auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
+  const framework::Variable* out_var = nullptr;
+  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
+    out_var = context.InputVar("Out");
    PADDLE_ENFORCE(out_var != nullptr,
                   "Cannot get input Variable Out, variable name = %s",
                   context.op().Input("Out"));
+  }
  PADDLE_ENFORCE(out_grad_var != nullptr,
                 "Cannot get input Variable %s, variable name = %s",
                 framework::GradVarName("Out"),
@@ -105,23 +119,36 @@ inline void ExtractActivationGradTensor(
                 context.op().Output(framework::GradVarName("X")));

  if (CanBeUsedBySelectedRows.count(context.op().Type())) {
-    *Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
    *dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
        *out_grad_var);
    *dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
        x_grad_var);
+    if (out_var) {
+      *Out =
+          paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
+    } else {
+      *Out = *dOut;  // fake out
+    }
  } else {
    *Out = context.Input<framework::Tensor>("Out");
    *dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
    *dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    if (out_var) {
+      *Out = &(out_var->Get<framework::LoDTensor>());
+    } else {
+      *Out = *dOut;  // fake out
+    }
  }

  PADDLE_ENFORCE(*dX != nullptr,
                 "Cannot get output tensor %s, variable name = %s",
                 framework::GradVarName("X"),
                 context.op().Output(framework::GradVarName("X")));
-  bool inplace = IsInplace(context.op().Type());
-  if (!inplace) {
+
+  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    auto x_var = context.InputVar("X");
    PADDLE_ENFORCE(x_var != nullptr,
                   "Cannot get input tensor X, variable name = %s",
@@ -172,7 +199,8 @@ class ActivationGradKernel
    const framework::Tensor *X, *Out, *dOut;
    framework::Tensor* dX = nullptr;
    X = Out = dOut = nullptr;
-    ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
+    ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
+                                                    &dX);
    dX->mutable_data<T>(context.GetPlace());
    auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
    auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
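A note on the "fake out" seen in ExtractActivationGradTensor above (my reading, not part of the patch): when kDepValue excludes kDepOut, "Out" is no longer an input of the grad op, so *Out is aliased to *dOut purely to keep this shared kernel code well-defined; functors whose FwdDeps() is kDepX never read it. A minimal sketch of such a functor, with a hypothetical name:

// Depends on X only; the "out" argument may legitimately be the fake alias.
template <typename T>
struct SquareLikeGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut, typename dX>
  void operator()(Device d, X x, Out /*possibly fake*/, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;  // never touches Out
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};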
@@ -222,6 +250,8 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out * (static_cast<T>(1) - out);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// Originally: logsigmoid(x) = -log (1 + exp(-x))
@@ -258,6 +288,8 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) =
        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// exp(x) = e^x
@@ -276,6 +308,8 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * out;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// relu(x) = max(x, 0)
@@ -294,6 +328,8 @@ struct ReluGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
@@ -338,6 +374,8 @@ struct GeluGradFunctor : BaseActivationFunctor<T> {
        (-static_cast<T>(0.5) * x.square()).exp();
    dx.device(d) = dout * (first + second);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
@@ -356,6 +394,8 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) - out * out);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// tanhshrink(x) = x - tanh(x)
@@ -375,6 +415,8 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x.tanh() * x.tanh());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// tanhshrink(x) = x - tanh(x)
@@ -409,6 +451,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
@@ -443,6 +487,8 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
    auto temp2 = (x < -lambdaT).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sqrt(x) = x^(1/2)
@@ -461,6 +507,8 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(0.5) * dout / out;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// ceil(x) = ceiling(x)
@@ -479,6 +527,8 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = static_cast<T>(0) / out;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};
// floor(x) = flooring(x)
@@ -522,6 +572,8 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = -dout * x.unaryExpr(Sine<T>());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosine(x) = cos(x)
@@ -541,6 +593,8 @@ struct SinGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.unaryExpr(Cosine<T>());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
@@ -582,6 +636,8 @@ struct AcosGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) =
        -dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -614,6 +670,8 @@ struct AsinGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) =
        dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -645,6 +703,8 @@ struct AtanGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// round(x) = [x]
@@ -672,6 +732,8 @@ struct AbsGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * x.sign();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; }
};
// reciprocal(x) = 1 / x
@@ -690,6 +752,8 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(-1) * out * out;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// log(x) = natural logarithm of x
@@ -708,6 +772,8 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (static_cast<T>(1) / x);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// square(x) = x^2
@@ -726,6 +792,8 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * static_cast<T>(2) * x;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -760,6 +828,8 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
        ((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
            .template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// relu6(x) = min(max(0, x), 6)
@@ -792,6 +862,8 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
        ((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
            .template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// softplus(x) = log(1 + exp(x))
@@ -821,6 +893,8 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) =
        dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softsign(x) = x / (1 + |x|)
@@ -842,6 +916,8 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) =
        dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -872,6 +948,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
    auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
    dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
@@ -901,6 +979,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
    auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -928,9 +1008,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
                  typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
-                   dout * (out + static_cast<T>(alpha)) *
+                   dout * static_cast<T>(alpha) * x.exp() *
                       (x < static_cast<T>(0)).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
@@ -958,6 +1040,8 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
    dx.device(d) = dout * static_cast<T>(factor) *
                   x.pow(static_cast<T>(factor) - static_cast<T>(1));
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -991,6 +1075,8 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
    auto temp = (a * x).tanh() * (a * x).tanh();
    dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -1020,6 +1106,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
    auto th = static_cast<T>(threshold);
    dx.device(d) = dout * (x > th).template cast<T>();
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
@@ -1053,6 +1141,8 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
        .template cast<T>() *
        static_cast<T>(slope);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
@@ -1077,49 +1167,54 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
-  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+  void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto out = x * temp1;
    auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
    dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
  }
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

}  // namespace operators
}  // namespace paddle
-#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
-  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
-  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
-  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
-  __macro(relu, ReluFunctor, ReluGradFunctor);                       \
-  __macro(gelu, GeluFunctor, GeluGradFunctor);                       \
-  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
-  __macro(atan, AtanFunctor, AtanGradFunctor);                       \
-  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
-  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
-  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
-  __macro(ceil, CeilFunctor, ZeroGradFunctor);                       \
-  __macro(floor, FloorFunctor, ZeroGradFunctor);                     \
-  __macro(cos, CosFunctor, CosGradFunctor);                          \
-  __macro(acos, AcosFunctor, AcosGradFunctor);                       \
-  __macro(sin, SinFunctor, SinGradFunctor);                          \
-  __macro(asin, AsinFunctor, AsinGradFunctor);                       \
-  __macro(round, RoundFunctor, ZeroGradFunctor);                     \
-  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
-  __macro(log, LogFunctor, LogGradFunctor);                          \
-  __macro(square, SquareFunctor, SquareGradFunctor);                 \
-  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
-  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
-  __macro(pow, PowFunctor, PowGradFunctor);                          \
-  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
-  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
-  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
-  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
-  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
-  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
-  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
-  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
-  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
-  __macro(swish, SwishFunctor, SwishGradFunctor);                    \
-  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);

+#define FOR_EACH_ACTIVATION_OP(__macro)                                       \
+  __macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
+  __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);  \
+  __macro(exp, Exp, ExpFunctor, ExpGradFunctor);                              \
+  __macro(relu, Relu, ReluFunctor, ReluGradFunctor);                          \
+  __macro(gelu, Gelu, GeluFunctor, GeluGradFunctor);                          \
+  __macro(tanh, Tanh, TanhFunctor, TanhGradFunctor);                          \
+  __macro(atan, Atan, AtanFunctor, AtanGradFunctor);                          \
+  __macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor);  \
+  __macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);                          \
+  __macro(abs, Abs, AbsFunctor, AbsGradFunctor);                              \
+  __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor);                          \
+  __macro(floor, Floor, FloorFunctor, ZeroGradFunctor);                       \
+  __macro(cos, Cos, CosFunctor, CosGradFunctor);                              \
+  __macro(acos, Acos, AcosFunctor, AcosGradFunctor);                          \
+  __macro(sin, Sin, SinFunctor, SinGradFunctor);                              \
+  __macro(asin, Asin, AsinFunctor, AsinGradFunctor);                          \
+  __macro(round, Round, RoundFunctor, ZeroGradFunctor);                       \
+  __macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);  \
+  __macro(log, Log, LogFunctor, LogGradFunctor);                              \
+  __macro(square, Square, SquareFunctor, SquareGradFunctor);                  \
+  __macro(brelu, BRelu, BReluFunctor, BReluGradFunctor);                      \
+  __macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor);         \
+  __macro(pow, Pow, PowFunctor, PowGradFunctor);                              \
+  __macro(stanh, STanh, STanhFunctor, STanhGradFunctor);                      \
+  __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor);          \
+  __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor);          \
+  __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor);                      \
+  __macro(leaky_relu, LeakyRelu, LeakyReluFunctor, LeakyReluGradFunctor);     \
+  __macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
+  __macro(elu, ELU, ELUFunctor, ELUGradFunctor);                              \
+  __macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
+  __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor,                      \
+          HardSigmoidGradFunctor);                                            \
+  __macro(swish, Swish, SwishFunctor, SwishGradFunctor);                      \
+  __macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor,          \
+          ThresholdedReluGradFunctor);
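A small compile-time sanity check one could write against this table (a sketch, not in the patch), using the FwdDeps() constants declared by the grad functors above:

static_assert(paddle::operators::ReluGradFunctor<float>::FwdDeps() ==
                  paddle::operators::kDepOut,
              "relu grad only needs Out, so relu can be registered inplace");
static_assert(paddle::operators::GeluGradFunctor<float>::FwdDeps() ==
                  paddle::operators::kDepX,
              "gelu grad needs X, so gelu is not inplace-able");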