Commit f30a1f42 authored by kavyasrinet, committed by GitHub

Adding relu6 activation function (#4607)

Parent d9585f9a
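
For orientation before the diff: relu6 clamps its input to the range [0, threshold], with the threshold defaulting to 6. A minimal NumPy sketch of the forward computation the operator below implements (illustrative only, not part of the commit):

import numpy as np

def relu6(x, threshold=6.0):
    # y = min(max(0, x), threshold), applied elementwise
    return np.minimum(np.maximum(x, 0.0), threshold)

print(relu6(np.array([-2.0, 3.0, 8.0])))  # [0. 3. 6.]
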
@@ -201,6 +201,19 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of Relu6 operator");
+    AddOutput("Y", "Output of Relu6 operator");
+    AddComment("Relu6 activation operator, relu6 = min(max(0, x), 6)");
+    AddAttr<AttrType>("threshold", "The threshold value of Relu6")
+        .SetDefault(static_cast<AttrType>(6));
+  }
+};
+
 template <typename AttrType>
 class PowOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
@@ -276,6 +289,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
 REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
             soft_relu_grad, ops::ActivationOpGrad);
+REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
+            ops::ActivationOpGrad);
 REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
             ops::ActivationOpGrad);
......
@@ -280,6 +280,36 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+// relu6(x) = min(max(0, x), 6)
+template <typename T>
+struct Relu6Functor : public BaseActivationFunctor<T> {
+  float threshold;
+
+  // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
+  // not polymorphism for speed.
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = x.cwiseMax(static_cast<T>(0)).cwiseMin(threshold);
+  }
+};
+
+template <typename T>
+struct Relu6GradFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    dx.device(d) =
+        dy * ((x > static_cast<T>(0)) * (x < threshold)).template cast<T>();
+  }
+};
+
 // softsign(x) = x / (1 + |x|)
 template <typename T>
 struct SoftsignFunctor : public BaseActivationFunctor<T> {
@@ -425,5 +455,6 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   __macro(pow, PowFunctor, PowGradFunctor); \
   __macro(stanh, STanhFunctor, STanhGradFunctor); \
   __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
+  __macro(relu6, Relu6Functor, Relu6GradFunctor); \
   __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
   __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
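
The backward functor above derives its mask from the input x rather than the output y: the upstream gradient is passed through only on the open interval (0, threshold), so the derivative is taken as 0 at the two kink points and everywhere outside them. A small NumPy sketch of that mask logic (illustrative, mirroring Relu6GradFunctor, not part of the commit):

import numpy as np

def relu6_grad(x, dy, threshold=6.0):
    # pass the gradient through only where 0 < x < threshold
    mask = (x > 0.0) & (x < threshold)
    return dy * mask.astype(dy.dtype)

x = np.array([-1.0, 0.0, 3.0, 6.0, 7.5], dtype="float32")
dy = np.ones_like(x)
print(relu6_grad(x, dy))  # [0. 0. 1. 0. 0.]
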
@@ -137,21 +137,26 @@ class TestBRelu(OpTest):
         self.check_grad(['X'], 'Y', max_relative_error=0.02)
 
 
-class TestLeakyRelu(OpTest):
+class TestRelu6(OpTest):
     def setUp(self):
-        self.op_type = "leaky_relu"
-        alpha = 0.02
-        self.attrs = {'alpha': alpha}
-        self.inputs = {'X': np.random.uniform(-3, 3, [4, 4]).astype("float32")}
+        self.op_type = "relu6"
+        x = np.random.uniform(-1, 1, [4, 10]).astype("float32")
+        threshold = 6.0
+        # The same with TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        x[np.abs(x - threshold) < 0.005] = threshold + 0.02
+        self.inputs = {'X': x}
+        self.attrs = {'threshold': threshold}
         self.outputs = {
-            'Y': np.maximum(self.inputs['X'], alpha * self.inputs['X'])
+            'Y': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
         }
 
     def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+        self.check_grad(['X'], 'Y', max_relative_error=0.02)
 
 
 class TestSoftRelu(OpTest):
......
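
A note on the test above: the inputs are nudged away from 0 and from the threshold (the comment "The same with TestAbs" refers to the same trick used in the abs test) because relu6 has kinks at those points, where the analytic gradient jumps between 0 and 1 while a central finite difference lands near 0.5, so a numeric gradient check would fail spuriously. A small illustration of the mismatch at a kink (the 0.005 step is an assumption for illustration, not necessarily the value OpTest uses):

import numpy as np

def relu6(x, threshold=6.0):
    return np.minimum(np.maximum(x, 0.0), threshold)

eps = 0.005
x0 = 0.0  # right at a kink
numeric = (relu6(x0 + eps) - relu6(x0 - eps)) / (2 * eps)
analytic = 1.0 if 0.0 < x0 < 6.0 else 0.0
print(numeric, analytic)  # 0.5 vs 0.0 -- large relative error at the kink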