未验证 提交 fbd83812 编写于 作者: Z Zhang Ting 提交者: GitHub

Fix HardSwish inf (#35386)

* fix hard_swish inf

* skip_check_grad for mkldnn op

* 'fix code style'

* fix unittest
上级 212b51ef
......@@ -3363,7 +3363,8 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType zero = static_cast<MPType>(0.0f);
float threshold;
float scale;
float offset;
......@@ -3377,19 +3378,19 @@ struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < t ? temp_max : t;
return temp_min * x / static_cast<T>(scale);
const MPType x_t = static_cast<MPType>(x);
const MPType temp_max = std::max(x_t + static_cast<MPType>(offset), zero);
const MPType temp_min = std::min(temp_max, static_cast<MPType>(threshold));
return static_cast<T>(temp_min * x_t / static_cast<MPType>(scale));
}
};
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
T two = static_cast<T>(2.0f);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
const MPType zero = static_cast<MPType>(0.0f);
const MPType one = static_cast<MPType>(1.0f);
const MPType two = static_cast<MPType>(2.0f);
float threshold;
float scale;
float offset;
......@@ -3403,11 +3404,17 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T o = static_cast<T>(offset);
T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
const MPType dout_t = static_cast<MPType>(dout);
const MPType x_t = static_cast<MPType>(x);
const MPType offset_t = static_cast<MPType>(offset);
const MPType scale_t = static_cast<MPType>(scale);
const MPType temp1 = static_cast<MPType>(x_t + offset_t > zero);
const MPType temp2 =
static_cast<MPType>(x_t + offset_t < static_cast<MPType>(threshold));
return static_cast<T>(
dout_t *
(temp1 * temp2 * (two * x_t + offset_t) / scale_t + one - temp2));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
......
......@@ -18,7 +18,7 @@ import unittest
import numpy as np
from scipy.special import expit, erf
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16, skip_check_grad_ci
from paddle.fluid.tests.unittests.test_activation_op import TestActivation, TestRelu, TestTanh, TestSqrt, TestAbs, TestLeakyRelu, TestSwish, TestHardSwish, TestRelu6, TestSigmoid
from paddle.fluid.tests.unittests.test_gelu_op import gelu
from mkldnn_op_test import check_if_mkldnn_primitives_exist_in_bwd
......@@ -128,6 +128,7 @@ class TestMKLDNNSwishDim2(TestSwish):
self.dtype = np.float32
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim2(TestHardSwish):
def setUp(self):
......@@ -138,6 +139,9 @@ class TestMKLDNNHardSwishDim2(TestHardSwish):
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNSigmoidDim2(TestSigmoid):
......@@ -317,6 +321,7 @@ def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
scale).astype(x.dtype)
@skip_check_grad_ci(reason="not implemented yet")
class TestMKLDNNHardSwishDim4(TestHardSwish):
def setUp(self):
......@@ -338,6 +343,9 @@ class TestMKLDNNHardSwishDim4(TestHardSwish):
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
pass
class TestMKLDNNMish(TestActivation):
......
......@@ -1815,8 +1815,12 @@ class TestRelu6API(unittest.TestCase):
def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
x_dtype = x.dtype
if x_dtype == 'float16':
x_dtype = 'float16'
x = x.astype('float32')
return (x * np.minimum(np.maximum(x + offset, 0.), threshold) /
scale).astype(x.dtype)
scale).astype(x_dtype)
class TestHardSwish(TestActivation):
......@@ -1825,7 +1829,6 @@ class TestHardSwish(TestActivation):
self.op_type = 'hard_swish'
self.init_dtype()
self.python_api = paddle.nn.functional.hardswish
skip_check_grad_ci(reason="not implemented yet")
np.random.seed(1024)
x = np.random.uniform(-6, 6, [10, 12]).astype(self.dtype)
......@@ -1842,10 +1845,6 @@ class TestHardSwish(TestActivation):
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
return # not implemented yet
self.check_grad(['X'], 'Out', check_eager=True)
def test_check_output(self):
......@@ -1873,11 +1872,11 @@ class TestHardswishAPI(unittest.TestCase):
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
x = paddle.to_tensor([11648., 11448.])
out1 = F.hardswish(x)
m = paddle.nn.Hardswish()
out2 = m(x)
out_ref = ref_hardswish(self.x_np)
out_ref = [11648., 11448.]
for r in [out1, out2]:
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册