diff --git a/paddle/phi/kernels/onednn/activation_grad_kernel.cc b/paddle/phi/kernels/onednn/activation_grad_kernel.cc index 3e59a2a0df2e0cf5a299f5026de1101a96dcf17f..1ebe9f20c63ea2afde7786c04479e0f6476b04c2 100644 --- a/paddle/phi/kernels/onednn/activation_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/activation_grad_kernel.cc @@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx, float offset, DenseTensor* dx) { HardSwishOneDNNGradFunctor functor; - functor(dev_ctx, x, dout, threshold, 0, dx); + functor(dev_ctx, x, dout, 0, 0, dx); } template diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py index 35195ea1a532d027eca5519ae694d9606190183f..2aaba4521dd67739e1ced02e9d2c6f52d536ad7b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_activation_mkldnn_op.py @@ -126,19 +126,11 @@ class TestMKLDNNSwishDim2(TestSwish): self.dtype = np.float32 -@skip_check_grad_ci(reason="not implemented yet") class TestMKLDNNHardSwishDim2(TestHardSwish): def setUp(self): super(TestMKLDNNHardSwishDim2, self).setUp() - - self.attrs["use_mkldnn"] = True - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_grad(self): - pass + self.attrs = {"use_mkldnn": True} class TestMKLDNNSigmoidDim2(TestSigmoid): @@ -315,11 +307,14 @@ class TestMKLDNNSwishDim4(TestSwish): def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0): + x_dtype = x.dtype + if x_dtype == 'float16': + x_dtype = 'float16' + x = x.astype('float32') return (x * np.minimum(np.maximum(x + offset, 0.), threshold) / - scale).astype(x.dtype) + scale).astype(x_dtype) -@skip_check_grad_ci(reason="not implemented yet") class TestMKLDNNHardSwishDim4(TestHardSwish): def setUp(self): @@ -341,9 +336,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish): def init_dtype(self): self.dtype = np.float32 - def test_check_grad(self): - pass - class TestMKLDNNMish(TestActivation):