diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
index ced14a8923140ec6b08e3e6725a5780b61033daf..cba57ba57f5e03c7861897e177cc09aa513e5395 100644
--- a/paddle/operators/activation_op.cc
+++ b/paddle/operators/activation_op.cc
@@ -321,6 +321,23 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+template <typename AttrType>
+class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ThresholdedReluOpMaker(framework::OpProto *proto,
+                         framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "Input of ThresholdedRelu operator");
+    AddOutput("Y", "Output of ThresholdedRelu operator");
+    AddComment(
+        "ThresholdedRelu activation operator, "
+        "thresholded_relu = x for x > threshold, "
+        "thresholded_relu = 0 otherwise.");
+    AddAttr<AttrType>("threshold", "The threshold location of activation")
+        .SetDefault(static_cast<AttrType>(1.0));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -392,6 +409,10 @@ REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
 REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
             hard_shrink_grad, ops::ActivationOpGrad);
 
+REGISTER_OP(thresholded_relu, ops::ActivationOp,
+            ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
+            ops::ActivationOpGrad);
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
   REGISTER_OP_CPU_KERNEL(                                               \
       act_type,                                                         \
diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h
index f88c9c48eb9fcb779de5a99a45a832e582d76ab0..502c33be103c465c14f128be38ac62d029f1bfb9 100644
--- a/paddle/operators/activation_op.h
+++ b/paddle/operators/activation_op.h
@@ -590,6 +590,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+template <typename T>
+struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) const {
+    y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
+  }
+};
+
+template <typename T>
+struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
+    dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -615,4 +641,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);              \
   __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);           \
   __macro(elu, ELUFunctor, ELUGradFunctor);                                 \
-  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
+  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);           \
+  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor)
diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py
index a28c4431e1ae9230750247c0ed16c9aff37364fa..3acd00e35213981fce60504876af1861961ebe12 100644
--- a/python/paddle/v2/framework/tests/test_activation_op.py
+++ b/python/paddle/v2/framework/tests/test_activation_op.py
@@ -363,5 +363,26 @@ class TestSoftsign(OpTest):
         self.check_grad(['X'], 'Y', max_relative_error=0.007)
 
 
+class TestThresholdedRelu(OpTest):
+    def setUp(self):
+        self.op_type = "thresholded_relu"
+        threshold = 0.25
+        self.relative_error = 0.005
+        X = np.random.uniform(-1, 1, [11, 17]).astype("float32")
+
+        # Same reason as TestAbs
+        X[np.abs(X - threshold) < self.relative_error] = threshold + 0.2
+
+        self.inputs = {'X': X}
+        self.attrs = {'threshold': threshold}
+        self.outputs = {'Y': (X > threshold) * X}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
+
+
 if __name__ == "__main__":
     unittest.main()