未验证 提交 9f7b027d 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix activation grad op desc maker (#16715)

test=develop
上级 9bd44b94
......@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
}
};
// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
void operator()(const char* op_type, OpInfo* info) const {}
};
} // namespace details
} // namespace framework
......
abs
acos
asin
atan
attention_lstm
brelu
conv_shift
cos
cos_sim
dequantize
elu
fc
flatten
fsp
......@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
gelu
gru
hard_shrink
hierarchical_sigmoid
leaky_relu
log
logsigmoid
lrn
lstm_unit
lstmp
......@@ -38,7 +26,6 @@ modified_huber_loss
nce
pool2d
pool3d
pow
prelu
quantize
rank_loss
......@@ -50,20 +37,10 @@ reduce_sum
requantize
reshape
rnn_memory_helper
round
sequence_softmax
sin
softplus
softshrink
softsign
spp
square
squeeze
stanh
swish
tanh_shrink
tensor_array_to_tensor
thresholded_relu
transpose
unpool
unsqueeze
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
......@@ -82,6 +85,8 @@ template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -105,6 +112,8 @@ template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -116,6 +125,8 @@ template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename Functor>
......@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
const framework::Tensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
framework::Tensor* dX = nullptr;
ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
......
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
......@@ -27,6 +29,25 @@ namespace operators {
using paddle::framework::Tensor;
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
std::unique_ptr<std::unordered_set<std::string>> ret(
new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
bwd_functor) \
if (CanInplaceAct<bwd_functor<float>>()) { \
ret->insert(#op_type); \
}
FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
return ret;
}
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
......@@ -50,26 +71,32 @@ using paddle::framework::Tensor;
} \
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
class OP_NAME##GradMaker \
: public ::paddle::framework::SingleGradOpDescMaker { \
public: \
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto* op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
OutputGrad("Out")); \
\
op->SetAttrMap(Attrs()); \
\
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
op->SetInput("X", Input("X"));
}
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", Output("Out"));
}
return op;
}
};
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
......@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("Out", framework::GradVarName("X"));
ctx->ShareLoD("Out", framework::GradVarName("X"));
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
};
......@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(Sigmoid, sigmoid); \
__macro(Relu, relu); \
__macro(Exp, exp); \
__macro(Tanh, tanh); \
__macro(Ceil, ceil); \
__macro(Floor, floor); \
__macro(Sqrt, sqrt); \
__macro(SoftRelu, soft_relu); \
__macro(Relu6, relu6); \
__macro(Reciprocal, reciprocal); \
__macro(HardSigmoid, hard_sigmoid);
#define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(LogSigmoid, logsigmoid); \
__macro(SoftShrink, softshrink); \
__macro(Abs, abs); \
__macro(Cos, cos); \
__macro(Acos, acos); \
__macro(Sin, sin); \
__macro(Asin, asin); \
__macro(Atan, atan); \
__macro(Round, round); \
__macro(Log, log); \
__macro(Square, square); \
__macro(Gelu, gelu); \
__macro(BRelu, brelu); \
__macro(Pow, pow); \
__macro(STanh, stanh); \
__macro(Softplus, softplus); \
__macro(Softsign, softsign); \
__macro(LeakyRelu, leaky_relu); \
__macro(TanhShrink, tanh_shrink); \
__macro(ELU, elu); \
__macro(HardShrink, hard_shrink); \
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker, \
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR( \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
......@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -15,7 +15,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
......@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);
......@@ -12,6 +12,7 @@ limitations under the License. */
#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
......@@ -35,21 +36,29 @@ limitations under the License. */
namespace paddle {
namespace operators {
/* Use ugly global variable, for the using in python layer side
Please refer to the layer_helper.py and get the details.
*/
static std::unordered_set<std::string> InplaceOpSet = {
"sigmoid", "exp", "relu", "tanh", "sqrt", "ceil",
"floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid"};
enum ActBwdOpFwdDeps {
kNoDeps = 0x00, // Do not need any forward input/output
kDepX = 0x01, // Only need forward input X
kDepOut = 0x02, // Only need forward output Out
// Never add kDepXOut, because Out can be always calculated
// by forward input X in backward part.
// FIXME(zjl): but in MKLDNN abs, X and Out are all needed...
// Developers should not rely on this enum value!
kDepXOut = 0x03
};
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet();
static bool IsInplace(const std::string& op) {
bool inplace = InplaceOpSet.count(op);
static auto InplaceOpSet = GetInplaceOpSet();
bool inplace = InplaceOpSet->count(op);
// for op_grad
const int kGradSuffixLen = 4;
if (op.size() > kGradSuffixLen &&
op.compare(op.size() - kGradSuffixLen - 1, kGradSuffixLen, "grad")) {
inplace =
InplaceOpSet.count(op.substr(0, op.size() - (kGradSuffixLen + 1)));
InplaceOpSet->count(op.substr(0, op.size() - (kGradSuffixLen + 1)));
}
return inplace;
}
......@@ -85,16 +94,21 @@ inline void ExtractActivationTensor(const framework::ExecutionContext& context,
context.op().Output("Out"));
}
template <ActBwdOpFwdDeps kDepValue>
inline void ExtractActivationGradTensor(
const framework::ExecutionContext& context, const framework::Tensor** X,
const framework::Tensor** Out, const framework::Tensor** dOut,
framework::Tensor** dX) {
auto out_var = context.InputVar("Out");
auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
context.op().Input("Out"));
const framework::Variable* out_var = nullptr;
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
out_var = context.InputVar("Out");
PADDLE_ENFORCE(out_var != nullptr,
"Cannot get input Variable Out, variable name = %s",
context.op().Input("Out"));
}
PADDLE_ENFORCE(out_grad_var != nullptr,
"Cannot get input Variable %s, variable name = %s",
framework::GradVarName("Out"),
......@@ -105,23 +119,36 @@ inline void ExtractActivationGradTensor(
context.op().Output(framework::GradVarName("X")));
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
*Out = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
*dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
*out_grad_var);
*dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
x_grad_var);
if (out_var) {
*Out =
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
} else {
*Out = *dOut; // fake out
}
} else {
*Out = context.Input<framework::Tensor>("Out");
*dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
*dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
if (out_var) {
*Out = &(out_var->Get<framework::LoDTensor>());
} else {
*Out = *dOut; // fake out
}
}
PADDLE_ENFORCE(*dX != nullptr,
"Cannot get output tensor %s, variable name = %s",
framework::GradVarName("X"),
context.op().Output(framework::GradVarName("X")));
bool inplace = IsInplace(context.op().Type());
if (!inplace) {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
auto x_var = context.InputVar("X");
PADDLE_ENFORCE(x_var != nullptr,
"Cannot get input tensor X, variable name = %s",
......@@ -172,7 +199,8 @@ class ActivationGradKernel
const framework::Tensor *X, *Out, *dOut;
framework::Tensor* dX = nullptr;
X = Out = dOut = nullptr;
ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX);
dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
......@@ -222,6 +250,8 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out * (static_cast<T>(1) - out);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// Originally: logsigmoid(x) = -log (1 + exp(-x))
......@@ -258,6 +288,8 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) =
dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// exp(x) = e^x
......@@ -276,6 +308,8 @@ struct ExpGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// relu(x) = max(x, 0)
......@@ -294,6 +328,8 @@ struct ReluGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
......@@ -338,6 +374,8 @@ struct GeluGradFunctor : BaseActivationFunctor<T> {
(-static_cast<T>(0.5) * x.square()).exp();
dx.device(d) = dout * (first + second);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
......@@ -356,6 +394,8 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) - out * out);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// tanhshrink(x) = x - tanh(x)
......@@ -375,6 +415,8 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x.tanh() * x.tanh());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// tanhshrink(x) = x - tanh(x)
......@@ -409,6 +451,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
......@@ -443,6 +487,8 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sqrt(x) = x^(1/2)
......@@ -461,6 +507,8 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0.5) * dout / out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// ceil(x) = ceiling(x)
......@@ -479,6 +527,8 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) / out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
};
// floor(x) = flooring(x)
......@@ -522,6 +572,8 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = -dout * x.unaryExpr(Sine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cosine(x) = cos(x)
......@@ -541,6 +593,8 @@ struct SinGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.unaryExpr(Cosine<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
......@@ -582,6 +636,8 @@ struct AcosGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) =
-dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -614,6 +670,8 @@ struct AsinGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) =
dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -645,6 +703,8 @@ struct AtanGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// round(x) = [x]
......@@ -672,6 +732,8 @@ struct AbsGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.sign();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; }
};
// reciprocal(x) = 1 / x
......@@ -690,6 +752,8 @@ struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// log(x) = natural logarithm of x
......@@ -708,6 +772,8 @@ struct LogGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / x);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// square(x) = x^2
......@@ -726,6 +792,8 @@ struct SquareGradFunctor : public BaseActivationFunctor<T> {
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -760,6 +828,8 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// relu6(x) = min(max(0, x), 6)
......@@ -792,6 +862,8 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
.template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// softplus(x) = log(1 + exp(x))
......@@ -821,6 +893,8 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) =
dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// softsign(x) = x / (1 + |x|)
......@@ -842,6 +916,8 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -872,6 +948,8 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -901,6 +979,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -928,9 +1008,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * (out + static_cast<T>(alpha)) *
dout * static_cast<T>(alpha) * x.exp() *
(x < static_cast<T>(0)).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
......@@ -958,6 +1040,8 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor) - static_cast<T>(1));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -991,6 +1075,8 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -1020,6 +1106,8 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
auto th = static_cast<T>(threshold);
dx.device(d) = dout * (x > th).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -1053,6 +1141,8 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
.template cast<T>() *
static_cast<T>(slope);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -1077,49 +1167,54 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto out = x * temp1;
auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
} // namespace operators
} // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(gelu, GeluFunctor, GeluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, FloorFunctor, ZeroGradFunctor); \
__macro(cos, CosFunctor, CosGradFunctor); \
__macro(acos, AcosFunctor, AcosGradFunctor); \
__macro(sin, SinFunctor, SinGradFunctor); \
__macro(asin, AsinFunctor, AsinGradFunctor); \
__macro(round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
__macro(swish, SwishFunctor, SwishGradFunctor); \
__macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(exp, Exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, Relu, ReluFunctor, ReluGradFunctor); \
__macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, Abs, AbsFunctor, AbsGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(cos, Cos, CosFunctor, CosGradFunctor); \
__macro(acos, Acos, AcosFunctor, AcosGradFunctor); \
__macro(sin, Sin, SinFunctor, SinGradFunctor); \
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, Log, LogFunctor, LogGradFunctor); \
__macro(square, Square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, Pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyRelu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELU, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
HardSigmoidGradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \
ThresholdedReluGradFunctor);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册