From 4ad504e7c7f9bcdf4c634ff8f927d9b4988e3d05 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 21 Aug 2020 09:43:34 +0800 Subject: [PATCH] hardshrink: support threshold < 0 (#26403) --- paddle/fluid/operators/activation_op.h | 12 ++++++------ .../fluid/tests/unittests/test_activation_op.py | 15 ++++++++++++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 94cbaf76e9..018f515d81 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -388,9 +388,9 @@ struct HardShrinkFunctor : public BaseActivationFunctor { } template void operator()(Device d, X x, Out out) const { - auto temp1 = (x < static_cast(threshold * -1)).template cast(); - auto temp2 = (x > static_cast(threshold)).template cast(); - out.device(d) = x * (temp1 + temp2); + auto temp1 = x < static_cast(threshold * -1.f); + auto temp2 = x > static_cast(threshold); + out.device(d) = x * (temp1 + temp2 > 0).template cast(); } }; @@ -405,9 +405,9 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { template void operator()(Device d, X x, Out out, dOut dout, dX dx) const { - auto temp1 = (x < static_cast(threshold * -1)).template cast(); - auto temp2 = (x > static_cast(threshold)).template cast(); - dx.device(d) = dout * (temp1 + temp2).template cast(); + auto temp1 = x < static_cast(threshold * -1.f); + auto temp2 = x > static_cast(threshold); + dx.device(d) = dout * (temp1 + temp2 > 0).template cast(); } static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; } diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index db32976f04..9704eb432c 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -453,20 +453,29 @@ class TestHardShrink(TestActivation): self.op_type = "hard_shrink" self.init_dtype() - threshold = 0.5 + self.threshold = 0.5 + self.set_attrs() x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) * 10 - out = ref_hardshrink(x, threshold) + out = ref_hardshrink(x, self.threshold) - self.attrs = {'threshold': threshold} + self.attrs = {'threshold': self.threshold} self.inputs = {'X': x} self.outputs = {'Out': out} + def set_attrs(self): + pass + def test_check_grad(self): if self.dtype == np.float16: return self.check_grad(['X'], 'Out') +class TestHardShrink_threshold_negative(TestHardShrink): + def set_attrs(self): + self.threshold = -0.1 + + class TestHardShrinkAPI(unittest.TestCase): # test paddle.nn.Hardshrink, paddle.nn.functional.hardshrink def setUp(self): -- GitLab