Commit 154a6ed2 authored by Kavya Srinet

Implementing tanhshrink operator

Parent 3e2be065
@@ -97,6 +97,17 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
  }
};

class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  TanhShrinkOpMaker(framework::OpProto *proto,
                    framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "Input of TanhShrink operator");
    AddOutput("Y", "Output of TanhShrink operator");
    AddComment("TanhShrink activation operator, tanhshrink(x) = x - tanh(x)");
  }
};

class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@@ -235,6 +246,9 @@ REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
            ops::ActivationOpGrad);
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
            tanh_shrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
            ops::ActivationOpGrad);
@@ -146,6 +146,24 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
  }
};

// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    y.device(d) = x - x.tanh();
  }
};
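
As a quick sanity check (illustrative only, not part of the commit), a minimal numpy sketch confirming that the x - x.tanh() expression above matches the explicit exp-based definition of tanh given in the comment:

import numpy as np

x = np.linspace(-3.0, 3.0, 7)

# tanhshrink(x) = x - tanh(x), as computed by TanhShrinkFunctor
y_functor = x - np.tanh(x)

# tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)), per the comment above
y_explicit = x - (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

assert np.allclose(y_functor, y_explicit)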

template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    // d/dx (x - tanh(x)) = 1 - (1 - tanh(x)^2) = tanh(x)^2
    dx.device(d) = dy * (x.tanh() * x.tanh());
  }
};
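
The gradient expression follows from d/dx tanh(x) = 1 - tanh(x)^2, so d/dx (x - tanh(x)) = 1 - (1 - tanh(x)^2) = tanh(x)^2, which is what the functor multiplies dy by. A minimal numpy sketch (not part of the commit) verifying this analytic gradient against a central finite difference:

import numpy as np

def tanh_shrink(x):
    return x - np.tanh(x)

x = np.random.uniform(0.1, 1.0, size=100)
eps = 1e-6

# central finite difference of the forward function
numeric = (tanh_shrink(x + eps) - tanh_shrink(x - eps)) / (2.0 * eps)

# analytic gradient used by TanhShrinkGradFunctor: tanh(x)^2
analytic = np.tanh(x) ** 2

assert np.allclose(numeric, analytic, atol=1e-8)
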
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
@@ -407,4 +425,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
  __macro(pow, PowFunctor, PowGradFunctor);                       \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                 \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);        \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);    \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
@@ -48,6 +48,21 @@ class TestTanh(OpTest):
        self.check_grad(['X'], 'Y', max_relative_error=0.007)


class TestTanhShrink(OpTest):
    def setUp(self):
        self.op_type = "tanh_shrink"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32")
        }
        self.outputs = {'Y': self.inputs['X'] - np.tanh(self.inputs['X'])}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.008)


class TestSqrt(OpTest):
    def setUp(self):
        self.op_type = "sqrt"