diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc index 9e183abf0287fa595c6905f1a44665eef62ead4a..cc7f71ff3646d7fbabf7bb1aff7b897acb856761 100644 --- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc @@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx, float offset, DenseTensor* dx) { HardSwishOneDNNGradFunctor functor; - functor(dev_ctx, x, dout, threshold, 0, dx); + functor(dev_ctx, x, dout, 0, 0, dx); } template diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py index ac880b02c4557e6615f2a22ad640a5502495a73e..b2e58514dca78bd811ef6bb7e4f21f952c7ffba3 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py @@ -128,19 +128,11 @@ class TestMKLDNNSwishDim2(TestSwish): self.dtype = np.float32 -@skip_check_grad_ci(reason="not implemented yet") class TestMKLDNNHardSwishDim2(TestHardSwish): def setUp(self): super(TestMKLDNNHardSwishDim2, self).setUp() - - self.attrs["use_mkldnn"] = True - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_grad(self): - pass + self.attrs = {"use_mkldnn": True} class TestMKLDNNSigmoidDim2(TestSigmoid): @@ -317,11 +309,14 @@ class TestMKLDNNSwishDim4(TestSwish): def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0): + x_dtype = x.dtype + if x_dtype == 'float16': + x_dtype = 'float16' + x = x.astype('float32') return (x * np.minimum(np.maximum(x + offset, 0.), threshold) / - scale).astype(x.dtype) + scale).astype(x_dtype) -@skip_check_grad_ci(reason="not implemented yet") class TestMKLDNNHardSwishDim4(TestHardSwish): def setUp(self): @@ -343,9 +338,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish): def init_dtype(self): self.dtype = np.float32 - def test_check_grad(self): - pass - class TestMKLDNNMish(TestActivation):