未验证 提交 ff0171e4 编写于 作者: S Sławomir Siwek 提交者: GitHub

Enable hard_swish_grad unit test (#46621)

* enable hard_swish_grad unit test

* remove unused argument
上级 078e8c78
...@@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx, ...@@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx,
float offset, float offset,
DenseTensor* dx) { DenseTensor* dx) {
HardSwishOneDNNGradFunctor<T> functor; HardSwishOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, threshold, 0, dx); functor(dev_ctx, x, dout, 0, 0, dx);
} }
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -126,19 +126,11 @@ class TestMKLDNNSwishDim2(TestSwish): ...@@ -126,19 +126,11 @@ class TestMKLDNNSwishDim2(TestSwish):
self.dtype = np.float32 self.dtype = np.float32
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim2(TestHardSwish): class TestMKLDNNHardSwishDim2(TestHardSwish):
def setUp(self): def setUp(self):
super(TestMKLDNNHardSwishDim2, self).setUp() super(TestMKLDNNHardSwishDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
self.attrs["use_mkldnn"] = True
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNSigmoidDim2(TestSigmoid): class TestMKLDNNSigmoidDim2(TestSigmoid):
...@@ -315,11 +307,14 @@ class TestMKLDNNSwishDim4(TestSwish): ...@@ -315,11 +307,14 @@ class TestMKLDNNSwishDim4(TestSwish):
def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0): def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
x_dtype = x.dtype
if x_dtype == 'float16':
x_dtype = 'float16'
x = x.astype('float32')
return (x * np.minimum(np.maximum(x + offset, 0.), threshold) / return (x * np.minimum(np.maximum(x + offset, 0.), threshold) /
scale).astype(x.dtype) scale).astype(x_dtype)
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim4(TestHardSwish): class TestMKLDNNHardSwishDim4(TestHardSwish):
def setUp(self): def setUp(self):
...@@ -341,9 +336,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish): ...@@ -341,9 +336,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNMish(TestActivation): class TestMKLDNNMish(TestActivation):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册