From b504a2346ce29ff3f63a185cc9c45c32cd03bf7b Mon Sep 17 00:00:00 2001
From: Abhinav Arora
Date: Wed, 11 Oct 2017 10:41:45 -0700
Subject: [PATCH] Adding the Thresholded Relu Op (#4685)

* Adding thresholded_relu op

* Adding test for thresholded relu op
---
 paddle/operators/activation_op.cc             | 21 ++++++++++++++
 paddle/operators/activation_op.h              | 29 ++++++++++++++++++-
 .../v2/framework/tests/test_activation_op.py  | 21 ++++++++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index ced14a89231..cba57ba57f5 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -321,6 +321,23 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ThresholdedReluOpMaker(framework::OpProto *proto,
+                         framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of ThresholdedRelu operator");
+    AddOutput("Y", "Output of ThresholdedRelu operator");
+    AddComment(
+        "ThresholdedRelu activation operator, "
+        "thresholded_relu = x for x > threshold, "
+        "thresholded_relu = 0 otherwise.");
+    AddAttr<AttrType>("threshold", "The threshold location of activation")
+        .SetDefault(static_cast<AttrType>(1.0));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -392,6 +409,10 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
 REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker,
             hard_shrink_grad, ops::ActivationOpGrad);
 
+REGISTER_OP(thresholded_relu, ops::ActivationOp,
+            ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
+            ops::ActivationOpGrad);
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_CPU_KERNEL(                                               \
       act_type,                                                         \
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index f88c9c48eb9..502c33be103 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -590,6 +590,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
+  }
+};
+
+template <typename T>
+struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -615,4 +641,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);    \
   __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
   __macro(elu, ELUFunctor, ELUGradFunctor);                       \
-  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
+  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
+  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py
index a28c4431e1a..3acd00e3521 100644
--- a/python/paddle/v2/framework/tests/test_activation_op.py
+++ b/python/paddle/v2/framework/tests/test_activation_op.py
@@ -363,5 +363,26 @@ class TestSoftsign(OpTest):
         self.check_grad(['X'], 'Y', max_relative_error=0.007)
 
 
+class TestThresholdedRelu(OpTest):
+    def setUp(self):
+        self.op_type = "thresholded_relu"
+        threshold = 0.25
+        self.relative_error = 0.005
+        X = np.random.uniform(-1, 1, [11, 17]).astype("float32")
+
+        # Same reason as TestAbs
+        X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2
+
+        self.inputs = {'X': X}
+        self.attrs = {'threshold': threshold}
+        self.outputs = {'Y': (X > threshold) * X}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
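
Note (illustration only, not part of the patch): the element-wise rules implemented by ThresholdedReluFunctor and ThresholdedReluGradFunctor above can be sketched in NumPy as follows. The helper names thresholded_relu and thresholded_relu_grad are made up for this sketch; the default threshold of 1.0 matches the operator's SetDefault value.

import numpy as np

def thresholded_relu(x, threshold=1.0):
    # Forward: y = x where x > threshold, else 0
    # (mirrors (x > threshold).cast<T>() * x in ThresholdedReluFunctor).
    return (x > threshold) * x

def thresholded_relu_grad(x, dy, threshold=1.0):
    # Backward: dx = dy where x > threshold, else 0
    # (mirrors dy * (x > threshold).cast<T>() in ThresholdedReluGradFunctor).
    return dy * (x > threshold)

x = np.array([-0.5, 0.3, 1.5, 2.0], dtype=np.float32)
print(thresholded_relu(x))                        # -> [0., 0., 1.5, 2.0]
print(thresholded_relu_grad(x, np.ones_like(x)))  # -> [0., 0., 1.0, 1.0]

This is also the behaviour the Python test checks: TestThresholdedRelu compares the op output against (X > threshold) * X, after nudging inputs away from the threshold so the non-differentiable point does not affect the gradient check.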