Unverified commit fbd83812, authored by Zhang Ting, committed by GitHub

Fix HardSwish inf (#35386)

* fix hard_swish inf

* skip_check_grad for mkldnn op

* 'fix code style'

* fix unittest
Parent 212b51ef
@@ -3363,7 +3363,8 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {

 template <typename T>
 struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
-  T zero = static_cast<T>(0.0f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const MPType zero = static_cast<MPType>(0.0f);
   float threshold;
   float scale;
   float offset;
@@ -3377,19 +3378,19 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
   // x * (x + offset) / scale, otherwise
   // threshold = scale = 6, offset = 3 by default
   __device__ __forceinline__ T operator()(const T x) const {
-    T t = static_cast<T>(threshold);
-    T temp = x + static_cast<T>(offset);
-    T temp_max = temp > zero ? temp : zero;
-    T temp_min = temp_max < t ? temp_max : t;
-    return temp_min * x / static_cast<T>(scale);
+    const MPType x_t = static_cast<MPType>(x);
+    const MPType temp_max = std::max(x_t + static_cast<MPType>(offset), zero);
+    const MPType temp_min = std::min(temp_max, static_cast<MPType>(threshold));
+    return static_cast<T>(temp_min * x_t / static_cast<MPType>(scale));
   }
 };

 template <typename T>
 struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
-  T zero = static_cast<T>(0.0f);
-  T one = static_cast<T>(1.0f);
-  T two = static_cast<T>(2.0f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  const MPType zero = static_cast<MPType>(0.0f);
+  const MPType one = static_cast<MPType>(1.0f);
+  const MPType two = static_cast<MPType>(2.0f);
   float threshold;
   float scale;
   float offset;
@@ -3403,11 +3404,17 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
   // dout * (2 * x / scale + offset / scale), otherwise
   // threshold = scale = 6, offset = 3 by default
   __device__ __forceinline__ T operator()(const T dout, const T x) const {
-    T o = static_cast<T>(offset);
-    T s = static_cast<T>(scale);
-    T temp1 = static_cast<T>(x + o > zero);
-    T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
-    return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
+    const MPType dout_t = static_cast<MPType>(dout);
+    const MPType x_t = static_cast<MPType>(x);
+    const MPType offset_t = static_cast<MPType>(offset);
+    const MPType scale_t = static_cast<MPType>(scale);
+    const MPType temp1 = static_cast<MPType>(x_t + offset_t > zero);
+    const MPType temp2 =
+        static_cast<MPType>(x_t + offset_t < static_cast<MPType>(threshold));
+    return static_cast<T>(
+        dout_t *
+        (temp1 * temp2 * (two * x_t + offset_t) / scale_t + one - temp2));
   }

   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
......
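The forward change above widens the arithmetic to MPType (a wider compute type; float for float16 inputs) because the old all-in-`T` order of operations can overflow: with the default `threshold = scale = 6`, the intermediate product `temp_min * x` exceeds the float16 maximum (about 65504) for moderately large `x`, producing inf even though the final result fits. A minimal numpy sketch of the failure mode and of the widened computation (the constants and the value 11648 mirror the diff; this is illustrative, not Paddle kernel code):

```python
import numpy as np

# Sketch: why the old float16 hard_swish produced inf.
# hard_swish(x) = x * min(max(x + offset, 0), threshold) / scale,
# with threshold = scale = 6 and offset = 3 by default.
x = np.float16(11648.0)            # value used by the updated test below
offset = np.float16(3.0)
threshold = np.float16(6.0)
scale = np.float16(6.0)

# Old kernel: every step stays in float16. The intermediate product
# 6 * 11648 = 69888 exceeds the float16 maximum (~65504) -> inf.
temp_min = np.minimum(np.maximum(x + offset, np.float16(0.0)), threshold)
old_out = temp_min * x / scale     # inf

# New kernel: compute in MPType (float32 for float16 inputs), cast once at the end.
x_t = np.float32(x)
new_out = np.float16(
    np.minimum(np.maximum(x_t + 3.0, 0.0), 6.0) * x_t / 6.0)  # 11648.0

print(old_out, new_out)
```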
@@ -18,7 +18,7 @@ import unittest
 import numpy as np
 from scipy.special import expit, erf
 import paddle.fluid.core as core
-from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
+from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16, skip_check_grad_ci
 from paddle.fluid.tests.unittests.test_activation_op import TestActivation, TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish, TestHardSwish, TestRelu6, TestSigmoid
 from paddle.fluid.tests.unittests.test_gelu_op import gelu
 from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
@@ -128,6 +128,7 @@ class TestMKLDNNSwishDim2(TestSwish):
         self.dtype = np.float32


+@skip_check_grad_ci(reason="not implemented yet")
 class TestMKLDNNHardSwishDim2(TestHardSwish):
     def setUp(self):
@@ -138,6 +139,9 @@ class TestMKLDNNHardSwishDim2(TestHardSwish):
     def init_dtype(self):
         self.dtype = np.float32

+    def test_check_grad(self):
+        pass
+

 class TestMKLDNNSigmoidDim2(TestSigmoid):
@@ -317,6 +321,7 @@ def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
             scale).astype(x.dtype)


+@skip_check_grad_ci(reason="not implemented yet")
 class TestMKLDNNHardSwishDim4(TestHardSwish):
     def setUp(self):
@@ -338,6 +343,9 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
     def init_dtype(self):
         self.dtype = np.float32

+    def test_check_grad(self):
+        pass
+

 class TestMKLDNNMish(TestActivation):
......
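The MKLDNN HardSwish tests above skip gradient checking (hence the `skip_check_grad_ci` decorator and the no-op `test_check_grad`), since the MKLDNN grad kernel is not implemented yet. For reference, the gradient formula that `CudaHardSwishGradFunctor` implements, `dout * (temp1 * temp2 * (2x + offset) / scale + 1 - temp2)`, can be cross-checked in numpy against central finite differences; this is an illustrative sketch, not part of the test suite:

```python
import numpy as np

# Sketch: numpy rendering of the hard_swish gradient used by
# CudaHardSwishGradFunctor, checked against a finite-difference estimate.
def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0):
    return x * np.minimum(np.maximum(x + offset, 0.0), threshold) / scale

def hard_swish_grad(dout, x, threshold=6.0, scale=6.0, offset=3.0):
    temp1 = (x + offset > 0.0).astype(x.dtype)        # 0 below -offset
    temp2 = (x + offset < threshold).astype(x.dtype)  # 0 once saturated
    return dout * (temp1 * temp2 * (2.0 * x + offset) / scale + 1.0 - temp2)

x = np.array([-5.0, -1.0, 0.5, 2.0, 7.0])  # points away from the kinks at -3 and 3
eps = 1e-6
numeric = (hard_swish(x + eps) - hard_swish(x - eps)) / (2.0 * eps)
analytic = hard_swish_grad(np.ones_like(x), x)
np.testing.assert_allclose(numeric, analytic, rtol=1e-5, atol=1e-8)
```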
@@ -1815,8 +1815,12 @@ class TestRelu6API(unittest.TestCase):
 def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
+    x_dtype = x.dtype
+    if x_dtype == 'float16':
+        x_dtype = 'float16'
+        x = x.astype('float32')
     return (x * np.minimum(np.maximum(x + offset, 0.), threshold) /
-            scale).astype(x.dtype)
+            scale).astype(x_dtype)


 class TestHardSwish(TestActivation):
@@ -1825,7 +1829,6 @@ class TestHardSwish(TestActivation):
         self.op_type = 'hard_swish'
         self.init_dtype()
         self.python_api = paddle.nn.functional.hardswish
-        skip_check_grad_ci(reason="not implemented yet")

         np.random.seed(1024)
         x = np.random.uniform(-6, 6, [10, 12]).astype(self.dtype)
@@ -1842,10 +1845,6 @@ class TestHardSwish(TestActivation):
         self.outputs = {'Out': out}

     def test_check_grad(self):
-        if self.dtype == np.float16:
-            return
-        return  # not implemented yet
-
         self.check_grad(['X'], 'Out', check_eager=True)

     def test_check_output(self):
@@ -1873,11 +1872,11 @@ class TestHardswishAPI(unittest.TestCase):
     def test_dygraph_api(self):
         paddle.disable_static(self.place)
-        x = paddle.to_tensor(self.x_np)
+        x = paddle.to_tensor([11648., 11448.])
         out1 = F.hardswish(x)
         m = paddle.nn.Hardswish()
         out2 = m(x)
-        out_ref = ref_hardswish(self.x_np)
+        out_ref = [11648., 11448.]
         for r in [out1, out2]:
             np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
         paddle.enable_static()
......
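The updated dygraph test feeds values well inside the saturated region: for x >= threshold - offset (x >= 3 with the defaults), min(max(x + 3, 0), 6) == 6, so hard_swish(x) = x * 6 / 6 = x and the expected output equals the input, which is why `out_ref` is simply `[11648., 11448.]`. A minimal sketch of the same check outside the unittest harness:

```python
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor([11648., 11448.])
# In the saturated region hard_swish(x) == x; with the fixed kernel the
# intermediate x * 6 product no longer overflows in low-precision compute.
print(F.hardswish(x).numpy())  # expected: [11648. 11448.]
```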