diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h
index 94cbaf76e942aae3ec299f6d0186418527c7ce96..018f515d81e93fc6b97c28e4872269c9da33ddbb 100644
--- a/paddle/fluid/operators/activation_op.h
+++ b/paddle/fluid/operators/activation_op.h
@@ -388,9 +388,9 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
   }
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
-    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
-    out.device(d) = x * (temp1 + temp2);
+    auto temp1 = x < static_cast<T>(threshold * -1.f);
+    auto temp2 = x > static_cast<T>(threshold);
+    out.device(d) = x * (temp1 + temp2 > 0).template cast<T>();
   }
 };
 
@@ -405,9 +405,9 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
-    auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
-    dx.device(d) = dout * (temp1 + temp2).template cast<T>();
+    auto temp1 = x < static_cast<T>(threshold * -1.f);
+    auto temp2 = x > static_cast<T>(threshold);
+    dx.device(d) = dout * (temp1 + temp2 > 0).template cast<T>();
   }
 
   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py
index db32976f046d36557f201621b651e57904f5a974..9704eb432cae9970761c648a38f7158ce32ba974 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_op.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_op.py
@@ -453,20 +453,29 @@ class TestHardShrink(TestActivation):
         self.op_type = "hard_shrink"
         self.init_dtype()
 
-        threshold = 0.5
+        self.threshold = 0.5
+        self.set_attrs()
         x = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype) * 10
-        out = ref_hardshrink(x, threshold)
+        out = ref_hardshrink(x, self.threshold)
 
-        self.attrs = {'threshold': threshold}
+        self.attrs = {'threshold': self.threshold}
         self.inputs = {'X': x}
         self.outputs = {'Out': out}
 
+    def set_attrs(self):
+        pass
+
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
         self.check_grad(['X'], 'Out')
 
 
+class TestHardShrink_threshold_negative(TestHardShrink):
+    def set_attrs(self):
+        self.threshold = -0.1
+
+
 class TestHardShrinkAPI(unittest.TestCase):
     # test paddle.nn.Hardshrink, paddle.nn.functional.hardshrink
     def setUp(self):
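
Note: as a sanity check on the semantics this change targets, the fixed mask logic can be sketched in plain NumPy. This is an illustrative standalone snippet, not code from the patch; the name demo_hardshrink is made up here and is distinct from the test file's ref_hardshrink. Hardshrink keeps x where x < -threshold or x > threshold, and with a negative threshold those two half-lines overlap, so the pre-fix expression x * (temp1 + temp2) doubled values in the overlap region, while the new (temp1 + temp2 > 0) caps the multiplier at 1.

import numpy as np


def demo_hardshrink(x, threshold):
    # Mirrors the fixed Eigen expression: keep x where x < -threshold or
    # x > threshold, zero elsewhere. Summing the two comparison masks and
    # then testing "> 0" keeps the multiplier at 1 even when both
    # comparisons hold at once, which happens for a negative threshold
    # (the regions overlap on (threshold, -threshold)).
    temp1 = (x < -threshold).astype(x.dtype)
    temp2 = (x > threshold).astype(x.dtype)
    return x * ((temp1 + temp2) > 0).astype(x.dtype)


x = np.array([-0.3, -0.05, 0.0, 0.05, 0.3], dtype=np.float32)
print(demo_hardshrink(x, 0.5))   # all zeros: every element lies inside (-0.5, 0.5)
print(demo_hardshrink(x, -0.1))  # x unchanged: every element passes exactly once

# The pre-fix expression multiplied by (temp1 + temp2) directly, so with
# threshold = -0.1 every value in the overlap (-0.1, 0.1) picked up a
# factor of 2, e.g. -0.05 became -0.1:
buggy = x * ((x < 0.1).astype(np.float32) + (x > -0.1).astype(np.float32))
print(buggy)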