未验证 提交 2c6bd4ad 编写于 作者: S Sławomir Siwek 提交者: GitHub

hard_swish grad (#46857)

上级 2bcbf8b0
...@@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx, ...@@ -243,7 +243,7 @@ void HardSwishGradKernel(const Context& dev_ctx,
float offset, float offset,
DenseTensor* dx) { DenseTensor* dx) {
HardSwishOneDNNGradFunctor<T> functor; HardSwishOneDNNGradFunctor<T> functor;
functor(dev_ctx, x, dout, threshold, 0, dx); functor(dev_ctx, x, dout, 0, 0, dx);
} }
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -128,19 +128,11 @@ class TestMKLDNNSwishDim2(TestSwish): ...@@ -128,19 +128,11 @@ class TestMKLDNNSwishDim2(TestSwish):
self.dtype = np.float32 self.dtype = np.float32
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim2(TestHardSwish): class TestMKLDNNHardSwishDim2(TestHardSwish):
def setUp(self): def setUp(self):
super(TestMKLDNNHardSwishDim2, self).setUp() super(TestMKLDNNHardSwishDim2, self).setUp()
self.attrs = {"use_mkldnn": True}
self.attrs["use_mkldnn"] = True
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNSigmoidDim2(TestSigmoid): class TestMKLDNNSigmoidDim2(TestSigmoid):
...@@ -317,11 +309,14 @@ class TestMKLDNNSwishDim4(TestSwish): ...@@ -317,11 +309,14 @@ class TestMKLDNNSwishDim4(TestSwish):
def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0): def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
x_dtype = x.dtype
if x_dtype == 'float16':
x_dtype = 'float16'
x = x.astype('float32')
return (x * np.minimum(np.maximum(x + offset, 0.), threshold) / return (x * np.minimum(np.maximum(x + offset, 0.), threshold) /
scale).astype(x.dtype) scale).astype(x_dtype)
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim4(TestHardSwish): class TestMKLDNNHardSwishDim4(TestHardSwish):
def setUp(self): def setUp(self):
...@@ -343,9 +338,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish): ...@@ -343,9 +338,6 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNMish(TestActivation): class TestMKLDNNMish(TestActivation):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册