Unverified commit 6bd5fd75, authored by Vvsmile and committed by GitHub

[AMP OP&Test] Add fp16 and bf16 test to activation (#52521)

* adjust default tolerance of output and grad

* fix a bug in the grad of OpTest

* fix the type of setting the default value in OpTest, both forward and backward

* add default

* fix test_sum_op

* adjust tolerance

* fix the tolerance of eager

* add bf16 and fp16 to the activation tests

* remove some fixes

* fix activation

* fix fp16

* fix gelu

* fix the activation tests

* add bfloat16 specialization to singrad and cosgrad

* fix bugs

* fix bugs

* add unittest

* add skip

* add fp/bf to rrelu/rrelu_grad

* git add rrelu

* fix bugs
Parent 70eaf9de
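The registrations below add bfloat16 to GPU kernels that previously stopped at float16. A rough usage sketch of what this enables (not part of the commit; it assumes a CUDA build of Paddle that includes this change, and runs rrelu in inference mode):

import paddle
import paddle.nn.functional as F

if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    # cast to bfloat16; the rrelu GPU kernel now has a bfloat16 registration
    x = paddle.randn([2, 3, 4, 5]).astype('bfloat16')
    y = F.rrelu(x, lower=1.0 / 8.0, upper=1.0 / 3.0, training=False)
    print(y.dtype)  # paddle.bfloat16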
@@ -70,6 +70,13 @@ struct Sine<dtype::float16> {
   }
 };
 
+template <>
+struct Sine<dtype::bfloat16> {
+  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
+    return dtype::bfloat16(sin(static_cast<float>(val)));
+  }
+};
+
 template <typename T>
 struct Cosine {
   HOSTDEVICE T operator()(const T& val) const { return cos(val); }
@@ -82,6 +89,13 @@ struct Cosine<dtype::float16> {
   }
 };
 
+template <>
+struct Cosine<dtype::bfloat16> {
+  HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16& val) const {
+    return dtype::bfloat16(cos(static_cast<float>(val)));
+  }
+};
+
 // sine'(x) = cos(x)
 template <typename T>
 struct SinGradFunctor : public BaseActivationFunctor<T> {
@@ -2664,10 +2678,12 @@ struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
 
 template <typename T>
 struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
-  T one = static_cast<T>(1.0f);
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType one = static_cast<MPType>(1.0f);
 
-  // reciprocal(x) = 1 / x
-  __device__ __forceinline__ T operator()(const T x) const { return one / x; }
+  __device__ __forceinline__ T operator()(const T x) const {
+    return static_cast<T>(one / static_cast<MPType>(x));
+  }
 };
 
 template <typename T>
......
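The bfloat16 specializations of Sine/Cosine above, and the MPType change in CudaReciprocalFunctor, follow the same recipe: widen to float, do the math, narrow the result back to the 16-bit type. A standalone NumPy illustration of that narrowing (not part of the diff; the helper names are invented, and bfloat16 is emulated through its raw bits, i.e. the high 16 bits of a float32, which is the same representation the test helper convert_float_to_uint16 works with):

import numpy as np

def float32_to_bf16_bits(x):
    """Round a float32 array to bfloat16 and return the raw uint16 payload."""
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = ((u >> 16) & 1) + np.uint32(0x7FFF)  # round half to even
    return ((u + bias) >> 16).astype(np.uint16)

def bf16_bits_to_float32(b):
    """Widen a uint16 bfloat16 payload back to float32."""
    return (np.asarray(b, dtype=np.uint32) << 16).view(np.float32)

x = np.float32(0.5)
wide = np.sin(x)  # computed in float32, like sin(static_cast<float>(val))
narrow = bf16_bits_to_float32(float32_to_bf16_bits(wide))
print(wide, narrow)  # equal up to bfloat16's ~8 bits of mantissa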
@@ -425,7 +425,8 @@ PD_REGISTER_KERNEL(sin_double_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(sin_triple_grad,
                    GPU,
@@ -435,7 +436,8 @@ PD_REGISTER_KERNEL(sin_triple_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(cos_double_grad,
                    GPU,
@@ -445,7 +447,8 @@ PD_REGISTER_KERNEL(cos_double_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(cos_triple_grad,
                    GPU,
@@ -455,7 +458,8 @@ PD_REGISTER_KERNEL(cos_triple_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
@@ -473,7 +477,8 @@ PD_REGISTER_KERNEL(log_double_grad,
                    phi::LogDoubleGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardswish_grad, HardSwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
......
@@ -83,4 +83,5 @@ PD_REGISTER_KERNEL(rrelu_grad,
                    phi::RReluGradKernel,
                    float,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    double) {}

@@ -110,4 +110,5 @@ PD_REGISTER_KERNEL(rrelu,
                    phi::RReluKernel,
                    float,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    double) {}
@@ -885,7 +885,9 @@ class OpTest(unittest.TestCase):
                 np_dyg,
                 rtol=1e-05,
                 equal_nan=False,
-                err_msg='Output ('
+                err_msg='Operator ('
+                + self.op_type
+                + ') Output ('
                 + name
                 + ') has diff at '
                 + str(place)
@@ -1137,7 +1139,9 @@ class OpTest(unittest.TestCase):
                 actual_out,
                 rtol=1e-05,
                 atol=inplace_atol,
-                err_msg='Output ('
+                err_msg='Operator ('
+                + self.op_type
+                + ') Output ('
                 + name
                 + ') has diff at '
                 + str(place)
@@ -1626,7 +1630,9 @@ class OpTest(unittest.TestCase):
                     rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                     equal_nan=equal_nan,
                     err_msg=(
-                        "Output ("
+                        "Operator ("
+                        + self.op_type
+                        + ") Output ("
                         + name
                         + ") has diff at "
                         + str(place)
@@ -1643,7 +1649,9 @@ class OpTest(unittest.TestCase):
                         rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                         equal_nan=equal_nan,
                     ),
-                    "Output ("
+                    "Operator ("
+                    + self.op_type
+                    + ") Output ("
                     + name
                     + ") has diff at "
                     + str(place)
@@ -1815,7 +1823,9 @@ class OpTest(unittest.TestCase):
                     rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                     equal_nan=equal_nan,
                     err_msg=(
-                        "Output ("
+                        "Operator ("
+                        + self.op_type
+                        + ") Output ("
                         + name
                         + ") has diff at "
                         + str(place)
@@ -1832,7 +1842,9 @@ class OpTest(unittest.TestCase):
                     rtol=self.rtol if hasattr(self, 'rtol') else rtol,
                     equal_nan=equal_nan,
                 ),
-                "Output ("
+                "Operator ("
+                + self.op_type
+                + ") Output ("
                 + name
                 + ") has diff at "
                 + str(place)
@@ -1882,7 +1894,9 @@ class OpTest(unittest.TestCase):
                 .get_tensor()
                 .recursive_sequence_lengths(),
                 expect[1],
-                "Output ("
+                "Operator ("
+                + self.op_type
+                + ") Output ("
                 + name
                 + ") has different lod at "
                 + str(place)
......
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 import paddle.nn.functional as F
@@ -327,7 +327,7 @@ class RReluTest(OpTest):
         ]  # python out sig is customized output signature.
 
     def init_params(self):
-        self.dtype = "float64"
+        self.init_dtype()
         self.x_shape = [2, 3, 4, 5]
 
         x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
@@ -337,12 +337,19 @@ class RReluTest(OpTest):
 
         self.inputs = {'X': x_np}
         self.outputs = {'Out': out_np, 'Noise': noise_np}
+        self.convert_input_output()
         self.attrs = {
            'lower': self.lower,
            "upper": self.upper,
            "is_test": self.is_test,
        }
 
+    def init_dtype(self):
+        self.dtype = "float64"
+
+    def convert_input_output(self):
+        pass
+
     def test_check_output(self):
         self.check_output(no_check_set=['Noise'])
@@ -363,5 +370,67 @@ class RReluTrainingTest(RReluTest):
         ]  # python out sig is customized output signature.
 
 
+class RReluTestFP16OP(RReluTest):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class RReluTestBF16OP(RReluTest):
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def convert_input_output(self):
+        self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
+        self.outputs = {
+            'Out': convert_float_to_uint16(self.outputs['Out']),
+            'Noise': convert_float_to_uint16(self.outputs['Noise']),
+        }
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, no_check_set=['Noise'])
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
+class RReluTrainingTestFP16OP(RReluTrainingTest):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA and do not support bfloat16",
+)
+class RReluTrainingTestBF16OP(RReluTrainingTest):
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def convert_input_output(self):
+        self.inputs = {'X': convert_float_to_uint16(self.inputs['X'])}
+        self.outputs = {
+            'Out': convert_float_to_uint16(self.outputs['Out']),
+            'Noise': convert_float_to_uint16(self.outputs['Noise']),
+        }
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, no_check_set=['Noise'])
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(place, ['X'], 'Out')
+
+
 if __name__ == "__main__":
     unittest.main()
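To run only the new low-precision RRelu cases, the standard unittest loader is enough. A hedged sketch (the module name test_rrelu_op is assumed here, it is not spelled out in this diff; the BF16 cases skip themselves on machines without CUDA or bfloat16 support):

import unittest

# Assumed module name; adjust to wherever this test file actually lives.
from test_rrelu_op import RReluTestBF16OP, RReluTestFP16OP

loader = unittest.defaultTestLoader
suite = unittest.TestSuite()
suite.addTests(loader.loadTestsFromTestCase(RReluTestFP16OP))
suite.addTests(loader.loadTestsFromTestCase(RReluTestBF16OP))
unittest.TextTestRunner(verbosity=2).run(suite)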