Unverified commit 113c026d, authored by A Abhinav Arora, committed by GitHub

Swish activation operator (#6358)

Parent 3a0a4586
...@@ -506,6 +506,22 @@ It is recommended to use the defaults for this activation.
  }
};
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SwishOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of Swish operator");
    AddOutput("Y", "Output of Swish operator");
    AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
    AddComment(R"DOC(
Swish Activation Operator.

$$y = \frac{x}{1 + e^{- \beta x}}$$

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle
...@@ -592,6 +608,9 @@ REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker,
REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
            hard_sigmoid_grad, ops::ActivationOpGrad);
REGISTER_OP(swish, ops::ActivationOp, ops::SwishOpMaker, swish_grad,
            ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_CPU_KERNEL(                                               \
      act_type,                                                         \
...
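The op maker above exposes a single beta attribute and documents the forward rule $$y = \frac{x}{1 + e^{- \beta x}}$$, which is the same as x multiplied by sigmoid(beta * x). As a rough, framework-independent sketch (the helper name swish_np is made up for illustration and is not part of this patch), the same forward pass can be written in NumPy as:

import numpy as np

def swish_np(x, beta=1.0):
    # Swish forward: y = x / (1 + exp(-beta * x)) == x * sigmoid(beta * x)
    return x / (1.0 + np.exp(-beta * x))

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0], dtype=np.float32)
print(swish_np(x, beta=2.3))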
...@@ -700,6 +700,35 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    // Forward: y = x / (1 + exp(-beta * x)) == x * sigmoid(beta * x)
    y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
  }
};

template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}};
  }

  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    // temp1 = sigmoid(beta * x); dx = dy * (beta * y + temp1 * (1 - beta * y))
    auto temp1 = static_cast<T>(1) /
                 (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
    auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
    dx.device(d) = dy * ((beta * y) + temp2);
  }
};
}  // namespace operators
}  // namespace paddle
...@@ -730,4 +759,5 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
  __macro(elu, ELUFunctor, ELUGradFunctor);                                   \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);             \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor);          \
  __macro(swish, SwishFunctor, SwishGradFunctor);                             \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
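For reference, the gradient expression in SwishGradFunctor follows from differentiating the forward formula. With sigma(z) = 1 / (1 + e^{-z}) and y = x * sigma(beta * x):

$$\frac{\partial y}{\partial x} = \sigma(\beta x) + \beta x \, \sigma(\beta x)\left(1 - \sigma(\beta x)\right) = \beta y + \sigma(\beta x)\left(1 - \beta y\right)$$

which is exactly beta * y + temp1 * (1 - beta * y) with temp1 = sigma(beta * x), as computed in the functor above.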
import unittest
import numpy as np
from op_test import OpTest
from scipy.special import expit


class TestExp(OpTest):
...@@ -455,5 +456,20 @@ class TestHardSigmoid(OpTest):
        self.check_grad(['X'], 'Y', max_relative_error=0.002)
class TestSwish(OpTest):
    def setUp(self):
        self.op_type = "swish"
        X = np.random.uniform(0.1, 1, [11, 17]).astype("float32")
        self.inputs = {'X': X}
        self.attrs = {'beta': 2.3}
        self.outputs = {'Y': X * expit(self.attrs['beta'] * X)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.008)
if __name__ == "__main__":
    unittest.main()
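Outside the OpTest harness, the analytic gradient used by SwishGradFunctor can also be sanity-checked against a central finite difference in plain NumPy. This is only an illustrative sketch; the names swish and swish_grad below are ad hoc and not part of the patch:

import numpy as np

def swish(x, beta):
    return x / (1.0 + np.exp(-beta * x))

def swish_grad(x, beta):
    # Analytic gradient: dy/dx = beta * y + sigmoid(beta * x) * (1 - beta * y)
    sig = 1.0 / (1.0 + np.exp(-beta * x))
    y = x * sig
    return beta * y + sig * (1.0 - beta * y)

x = np.random.uniform(0.1, 1, [11, 17]).astype("float64")
beta, eps = 2.3, 1e-6
numeric = (swish(x + eps, beta) - swish(x - eps, beta)) / (2.0 * eps)
print(np.max(np.abs(numeric - swish_grad(x, beta))))  # should be close to zero (~1e-9 or less)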