Commit 0d017d91 authored by zhouxiao-coder, committed by GitHub

Merge pull request #4395 from zhouxiao-coder/elu-activation

ELU activation
@@ -201,6 +201,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};
template <typename AttrType>
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input of the ELU operator; it must not be empty. "
             "The input is flattened and treated as a 1D array.");
    AddOutput("Y",
              "(Tensor) The output of the ELU operator. It has the same shape "
              "as the input.");
    AddAttr<AttrType>(
        "alpha", "(float, default 1.0) Alpha value in the ELU formulation.")
        .SetDefault(static_cast<AttrType>(1.));
    AddComment(R"DOC(
ELU activation operator. It applies the following element-wise computation to
the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)).
See https://arxiv.org/abs/1511.07289 for more details.)DOC");
  }
};
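
For reference, the element-wise formula documented above can be checked with a few lines of NumPy. This is an illustrative sketch only, not part of the patch; the sample values are arbitrary.

import numpy as np

def elu(x, alpha=1.0):
    # f(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))

x = np.array([-2.0, -0.5, 0.0, 1.5])
print(elu(x))  # approximately [-0.8647 -0.3935  0.      1.5   ]
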
template <typename AttrType>
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
@@ -289,6 +310,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
            soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker<float>, elu_grad,
            ops::ActivationOpGrad);
REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
            ops::ActivationOpGrad);
...
@@ -384,6 +384,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) =
        x.cwiseMax(static_cast<T>(0)) +
        (alpha * (x.exp() - static_cast<T>(1))).cwiseMin(static_cast<T>(0));
  }
};
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    dx.device(d) =
        dy * (x > static_cast<T>(0)).template cast<T>() +
        dy * (y + alpha) * (x < static_cast<T>(0)).template cast<T>();
  }
};
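
ELUGradFunctor reuses the forward output y so the exponential is not recomputed: for x > 0 the derivative of f is 1, and for x < 0 it is alpha * exp(x), which equals y + alpha because y = alpha * (exp(x) - 1) in that region. A minimal NumPy sketch of the same rule (illustration only; the helper name elu_grad is not from the patch):

import numpy as np

def elu_grad(x, y, dy, alpha=1.0):
    # dx = dy * 1            where x > 0
    # dx = dy * (y + alpha)  where x < 0, since y + alpha = alpha * exp(x)
    return (dy * (x > 0).astype(dy.dtype) +
            dy * (y + alpha) * (x < 0).astype(dy.dtype))

x = np.array([-1.0, 2.0])
y = np.maximum(0, x) + np.minimum(0, np.exp(x) - 1)
print(elu_grad(x, y, np.ones_like(x)))  # approximately [0.3679 1.    ]
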
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
  float factor;
@@ -440,21 +469,22 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
}  // namespace operators
}  // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro)                          \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);           \
  __macro(exp, ExpFunctor, ExpGradFunctor);                       \
  __macro(relu, ReluFunctor, ReluGradFunctor);                    \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                    \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                    \
  __macro(abs, AbsFunctor, AbsGradFunctor);                       \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);  \
  __macro(log, LogFunctor, LogGradFunctor);                       \
  __macro(square, SquareFunctor, SquareGradFunctor);              \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                 \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);       \
  __macro(pow, PowFunctor, PowGradFunctor);                       \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                 \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);        \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);    \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                 \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
  __macro(elu, ELUFunctor, ELUGradFunctor)
@@ -181,6 +181,26 @@ class TestSoftRelu(OpTest):
        self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestELU(OpTest):
    def setUp(self):
        self.op_type = "elu"
        x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
        alpha = 1.
        # Note: unlike other ReLU variants, the standard ELU (alpha = 1) is
        # differentiable at x = 0, so tweaks such as
        # x[np.abs(x) < 0.005] = 0.02 are not needed here.
        self.inputs = {'X': x}
        self.attrs = {'alpha': alpha}
        self.outputs = {
            'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.02)
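
The check_grad call above compares the operator's analytic gradient against a numerical estimate. The same comparison can be sketched in plain NumPy with a centered finite difference (illustration only, not part of the test file; the 0.02 tolerance mirrors the one used above):

import numpy as np

def elu(x, alpha=1.0):
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))

x = np.random.uniform(-3, 3, [4, 4])
y = elu(x)
dy = np.ones_like(x)
# analytic gradient, mirroring ELUGradFunctor (alpha = 1)
dx = dy * (x > 0) + dy * (y + 1.0) * (x < 0)
# centered finite difference, the kind of estimate check_grad relies on
eps = 1e-4
dx_num = (elu(x + eps) - elu(x - eps)) / (2 * eps)
assert np.allclose(dx, dx_num, rtol=0.02)
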
class TestReciprocal(OpTest):
    def setUp(self):
        self.op_type = "reciprocal"
...