diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 5f2ecc2673a5bf87d07e99faf60d585fb02621d0..66e9d2c40138c26975f07cb544e54de6f00d6b09 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -97,6 +97,17 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TanhShrinkOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of TanhShrink operator"); + AddOutput("Y", "Output of TanhShrink operator"); + AddComment("TanhShrink activation operator, tanhshrink(x) = x - tanh(x)"); + } +}; + class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { public: SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -235,6 +246,9 @@ REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, ops::ActivationOpGrad); +REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, + tanh_shrink_grad, ops::ActivationOpGrad); + REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, ops::ActivationOpGrad); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index dae66cc77d9103a6a2b13e69b3f014f8de313209..245060174224c5e24f75adf4ddc9a6db29101d74 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -146,6 +146,24 @@ struct TanhGradFunctor : public BaseActivationFunctor { } }; +// tanhshrink(x) = x - tanh(x) +// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct TanhShrinkFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) const { + y.device(d) = x - x.tanh(); + } +}; + +template +struct TanhShrinkGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + dx.device(d) = dy * (x.tanh() * x.tanh()); + } +}; + // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { @@ -407,4 +425,5 @@ struct STanhGradFunctor : public BaseActivationFunctor { __macro(pow, PowFunctor, PowGradFunctor); \ __macro(stanh, STanhFunctor, STanhGradFunctor); \ __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ - __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor) + __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ + __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor) diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index f232996a55da86399140ea2aad27428926c1a055..701e1a1aeec2746643fbd5432dadfd6bc46f358f 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -48,6 +48,21 @@ class TestTanh(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) +class TestTanhShrink(OpTest): + def setUp(self): + self.op_type = "tanh_shrink" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32") + } + self.outputs = {'Y': self.inputs['X'] - np.tanh(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.008) + + class TestSqrt(OpTest): def setUp(self): self.op_type = "sqrt"