Commit 1397e17f authored by kavyasrinet, committed by GitHub

Implemented the hardShrink activation (#4653)

* Implemented the hardShrink activation

* Fixing the unit test
Parent 6604d7cd
@@ -137,6 +137,24 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};
template <typename AttrType>
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
HardShrinkOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
    AddComment(
        "HardShrink activation operator,\n"
        "hard_shrink(x) = x, if x > threshold\n"
        "hard_shrink(x) = x, if x < -threshold\n"
        "hard_shrink(x) = 0, otherwise");
AddAttr<AttrType>("threshold", "The value of threshold for HardShrink")
.SetDefault(static_cast<AttrType>(0.5));
}
};
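As a quick hand check of the definition above: with the default threshold of 0.5, an input vector [-1.0, -0.3, 0.2, 0.8] maps to [-1.0, 0, 0, 0.8], since the two entries inside [-0.5, 0.5] are zeroed while the rest pass through unchanged. The same values are reused in the NumPy sketches further down.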
class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -357,6 +375,9 @@ REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
            ops::ActivationOpGrad);
REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
hard_shrink_grad, ops::ActivationOpGrad);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
  REGISTER_OP_CPU_KERNEL(                                               \
      act_type,                                                         \
...
@@ -199,6 +199,39 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  }
};
// hard_shrink(x) = x, if x > threshold or x < -threshold
//                = 0, otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2);
}
};
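For readers less familiar with Eigen, the forward pass above can be mirrored in NumPy. This is a minimal sketch; the helper name hard_shrink_forward is ours and not part of the operator:

import numpy as np

def hard_shrink_forward(x, threshold=0.5):
    # Keep x where x < -threshold or x > threshold, zero it elsewhere;
    # this mirrors y = x * (temp1 + temp2) in the functor above.
    mask = (x < -threshold) | (x > threshold)
    return x * mask.astype(x.dtype)

print(hard_shrink_forward(np.array([-1.0, -0.3, 0.2, 0.8], dtype=np.float32)))
# [-1.   0.   0.   0.8]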
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
auto temp2 = (x > threshold).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
}
};
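The backward pass uses the same masks: the incoming gradient dy passes through where |x| > threshold and is zeroed inside [-threshold, threshold]. A minimal NumPy sketch of that rule (again, the helper name hard_shrink_grad is ours):

import numpy as np

def hard_shrink_grad(x, dy, threshold=0.5):
    # dx = dy where |x| > threshold, 0 otherwise, matching
    # dx = dy * (temp1 + temp2) in the functor above.
    mask = (x < -threshold) | (x > threshold)
    return dy * mask.astype(x.dtype)

x = np.array([-1.0, -0.3, 0.2, 0.8], dtype=np.float32)
print(hard_shrink_grad(x, np.ones_like(x)))
# [1. 0. 0. 1.]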
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
template <typename T>
@@ -351,8 +384,6 @@ template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
  float threshold;
  // NOTE: `GetAttrs` deliberately hides `BaseActivationFunctor<T>::GetAttrs`
  // rather than relying on polymorphism, for speed.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }
@@ -555,4 +586,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                 \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);    \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
  __macro(elu, ELUFunctor, ELUGradFunctor);                       \
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
@@ -78,6 +78,26 @@ class TestTanhShrink(OpTest):
        self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestHardShrink(OpTest):
    def setUp(self):
        self.op_type = "hard_shrink"
        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
        threshold = 0.5

        self.inputs = {'X': x}
        self.attrs = {'threshold': threshold}

        t = np.copy(x)
        t[(t >= -threshold) & (t <= threshold)] = 0
        self.outputs = {'Y': t}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.005)
class TestSoftShrink(OpTest):
    def setUp(self):
        self.op_type = "softshrink"
...