Unverified commit 376dbb82, authored by zhiboniu, committed by GitHub

【AMP OP&Test】add fp16 and bf16 test (#51286)

* add fp16 and bf16 test

* update
Parent 93867e20
@@ -51,8 +51,8 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
if (x_scale != static_cast<T>(0)) x_scale_inv = static_cast<T>(1.0) / x_scale;
AccT d_mean_data = static_cast<AccT>(0);
AccT d_var_data = static_cast<AccT>(0);
T d_scale_data = static_cast<T>(0);
T d_bias_data = static_cast<T>(0);
AccT d_scale_data = static_cast<AccT>(0);
AccT d_bias_data = static_cast<AccT>(0);
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
AccT val, dval;
@@ -67,8 +67,8 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
d_mean_data += dval * static_cast<AccT>(x_scale);
val = val * static_cast<AccT>(x_scale_inv);
d_bias_data += static_cast<T>(dval);
d_scale_data += static_cast<T>(val * dval);
d_bias_data += dval;
d_scale_data += val * dval;
}
CudaAtomicAddWithWarp(&(d_mean[bid * groups + gid]),
static_cast<AccT>(d_mean_data));
@@ -77,16 +77,16 @@ __global__ void GroupNormBackwardGetMeanAndVar(const T* x,
if (flags & kHasScale) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_scale[ccid]), d_scale_data);
phi::CudaAtomicAdd(&(d_scale[ccid]), static_cast<T>(d_scale_data));
#else
CudaAtomicAddWithWarp(&(d_scale[ccid]), d_scale_data);
CudaAtomicAddWithWarp(&(d_scale[ccid]), static_cast<T>(d_scale_data));
#endif
}
if (flags & kHasBias) {
#if CUDA_VERSION >= 11070
phi::CudaAtomicAdd(&(d_bias[ccid]), d_bias_data);
phi::CudaAtomicAdd(&(d_bias[ccid]), static_cast<T>(d_bias_data));
#else
CudaAtomicAddWithWarp(&(d_bias[ccid]), d_bias_data);
CudaAtomicAddWithWarp(&(d_bias[ccid]), static_cast<T>(d_bias_data));
#endif
}
}
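The change above keeps the per-thread partial sums d_scale_data and d_bias_data in AccT (float when T is float16 or bfloat16) and only casts back to T at the final atomic add. A minimal numpy sketch of why accumulating in the element type is avoided, illustrative only and not the kernel code:

```python
import numpy as np

# Sum many small contributions, as each thread does in the loop above.
# Accumulating directly in float16 stalls once the running sum is large enough
# that each new contribution falls below its rounding step, while accumulating
# in float32 (AccT) and casting once at the end keeps the error at a single
# float16 rounding step.
rng = np.random.default_rng(0)
vals = rng.random(10000).astype(np.float16)

acc_t = np.float16(0)
for v in vals:                 # accumulate in the element type T
    acc_t = np.float16(acc_t + v)

acc_acct = np.float32(0)
for v in vals:                 # accumulate in AccT
    acc_acct += np.float32(v)

reference = vals.astype(np.float64).sum()
print(abs(acc_t - reference))                 # off by thousands
print(abs(np.float16(acc_acct) - reference))  # a few units at most
```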
@@ -128,7 +128,7 @@ __global__ void GroupNormBackward(const T* x,
: static_cast<AccT>(1);
AccT x_bias =
(flags & kHasBias) ? static_cast<AccT>(bias[ccid]) : static_cast<AccT>(0);
AccT x_scale_inv = static_cast<T>(0);
AccT x_scale_inv = static_cast<AccT>(0);
if (x_scale != static_cast<AccT>(0))
x_scale_inv = static_cast<AccT>(1.0) / x_scale;
@@ -220,7 +220,7 @@ __global__ void GetBackwardParamsCUDAKernel(int imsize,
sum1 += static_cast<AccT>(ds[index]) * scale_v;
sum2 += static_cast<AccT>(db[index]) * scale_v;
const AccT scale_c =
scale == nullptr ? static_cast<AccT>(0) : static_cast<T>(scale[c]);
scale == nullptr ? static_cast<AccT>(0) : static_cast<AccT>(scale[c]);
p1[index] = static_cast<AccT>(scale_c) * var_inv;
}
@@ -402,7 +402,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
p1_data,
p2_data,
p3_data);
GetXGradientCUDAKernel<T>
GetXGradientCUDAKernel<T, AccT>
<<<grid, threads, 0, dev_ctx.stream()>>>(imsize,
C,
group_size,
@@ -424,7 +424,7 @@ void GroupNormGradKernel(const Context& dev_ctx,
DenseTensor temp_var;
temp_var.Resize(var.dims());
dev_ctx.template Alloc<T>(&temp_var);
dev_ctx.template Alloc<AccT>(&temp_var);
set_zero_AccT(dev_ctx, &temp_var, static_cast<AccT>(0));
auto* temp_var_data = temp_var.data<AccT>();
@@ -483,4 +483,5 @@ PD_REGISTER_KERNEL(group_norm_grad,
phi::GroupNormGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
@@ -20,6 +20,10 @@
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/gpu/group_norm_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename AccT>
@@ -124,7 +128,7 @@ void GroupNormKernel(const Context& dev_ctx,
DenseTensor* y,
DenseTensor* mean,
DenseTensor* var) {
using AccT = typename kps::details::MPTypeTrait<T>::Type;
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
const DataLayout data_layout = phi::StringToDataLayout(data_layout_str);
const auto scale_ptr = scale.get_ptr();
const auto bias_ptr = bias.get_ptr();
@@ -342,4 +346,5 @@ PD_REGISTER_KERNEL(group_norm,
phi::GroupNormKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
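Both kernels now compute intermediates in AccT = phi::dtype::MPTypeTrait<T>::Type, which, as used here, pairs float16 and bfloat16 with float while leaving float and double unchanged, and both are now registered for bfloat16 and float16. A rough Python counterpart of that accumulation-type choice, written as an illustrative assumption rather than a rendering of the C++ trait:

```python
import numpy as np

# Accumulation dtype per element dtype -- an illustrative stand-in for what
# MPTypeTrait provides on the C++ side (assumed mapping, not the trait itself).
ACC_DTYPE = {
    np.float16: np.float32,
    np.float32: np.float32,
    np.float64: np.float64,
}

def normalize(x, epsilon=1e-5):
    """Normalize over the last axis, doing the arithmetic in the wider dtype."""
    acc = ACC_DTYPE[x.dtype.type]
    xa = x.astype(acc)
    mean = xa.mean(axis=-1, keepdims=True)
    var = xa.var(axis=-1, keepdims=True)
    return ((xa - mean) / np.sqrt(var + epsilon)).astype(x.dtype)

x = np.random.rand(2, 8).astype(np.float16)
y = normalize(x)
print(y.dtype)  # float16 in and out, float32 arithmetic inside
```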
@@ -15,7 +15,7 @@
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
from testsuite import create_op
import paddle
@@ -94,8 +94,8 @@ class TestGroupNormOp(OpTest):
        self.attrs['data_layout'] = self.data_format

    def test_check_output(self):
        atol = 0.0
        inplace_atol = 0.0
        atol = 0
        inplace_atol = 0
        place = core.CPUPlace()
        self.check_output_with_place(place, atol=atol)
@@ -161,16 +161,133 @@ class TestGroupNormOp(OpTest):
        pass


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_float16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or float16 is not supported",
)
class TestGroupNormFP16OP(TestGroupNormOp):
    def test_check_output(self):
        atol = 1e-3
        inplace_atol = 1e-3
        place = core.CUDAPlace(0)
        # group_norm uses AtomicAdd on CUDAPlace, which does not guarantee the
        # order in which multiple threads write the same address, so the result
        # of group_norm is non-deterministic for floating-point types.
        # When inplace_atol is not None, the inplace check uses numpy.allclose
        # to compare the inplace result instead of numpy.array_equal; the
        # tolerance is relaxed to 1e-3 here because of float16 rounding error.
        # Reference: https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html
        self.check_output_with_place(place, atol=atol, inplace_atol=inplace_atol)

    def test_check_grad(self):
        if self.compare_between_place:
            return
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, set(['X', 'Scale', 'Bias']), 'Y')

    def init_test_case(self):
        self.dtype = np.float16


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or bfloat16 is not supported",
)
class TestGroupNormBF16Op(OpTest):
    def setUp(self):
        self.op_type = "group_norm"
        self.data_format = "NCHW"
        self.dtype = np.uint16
        self.shape = (2, 100, 3, 5)
        self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
        self.compare_between_place = False
        self.init_test_case()

        input = np.random.random(self.shape).astype(np.float32)
        if self.data_format == "NHWC":
            input = np.transpose(input, (0, 2, 3, 1))
        scale = np.random.random([self.shape[1]]).astype(np.float32)
        bias = np.random.random([self.shape[1]]).astype(np.float32)
        output, mean, var = group_norm_naive(
            input,
            scale,
            bias,
            self.attrs['epsilon'],
            self.attrs['groups'],
            self.data_format,
        )

        self.inputs = {
            'X': convert_float_to_uint16(input),
            'Scale': convert_float_to_uint16(scale),
            'Bias': convert_float_to_uint16(bias),
        }
        self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
        self.attrs['data_layout'] = self.data_format

    def test_check_output(self):
        atol = 1e-2
        inplace_atol = 1e-2
        place = core.CUDAPlace(0)
        # group_norm uses AtomicAdd on CUDAPlace, which does not guarantee the
        # order in which multiple threads write the same address, so the result
        # of group_norm is non-deterministic for floating-point types.
        # When inplace_atol is not None, the inplace check uses numpy.allclose
        # to compare the inplace result instead of numpy.array_equal; the
        # tolerance is relaxed to 1e-2 here because of bfloat16 rounding error.
        # Reference: https://docs.scipy.org/doc/numpy/reference/generated/numpy.allclose.html
        self.check_output_with_place(place, atol=atol, inplace_atol=inplace_atol)

    def test_check_grad(self):
        if self.compare_between_place:
            return
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, set(['X', 'Scale', 'Bias']), 'Y')

    def init_test_case(self):
        pass


class TestGroupNormOp1(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 1


class TestGroupNormFP16Op1(TestGroupNormFP16OP):
    def init_test_case(self):
        self.attrs['groups'] = 1
        self.dtype = np.float16


class TestGroupNormBF16Op1(TestGroupNormBF16Op):
    def init_test_case(self):
        self.attrs['groups'] = 1


class TestGroupNormOp2(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 4


class TestGroupNormFP16Op2(TestGroupNormFP16OP):
    def init_test_case(self):
        self.attrs['groups'] = 4
        self.dtype = np.float16


class TestGroupNormBF16Op2(TestGroupNormBF16Op):
    def init_test_case(self):
        self.attrs['groups'] = 4


class TestGroupNormOpBigEps1(TestGroupNormOp):
    def init_test_case(self):
        self.attrs['groups'] = 1
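In TestGroupNormBF16Op above, the reference data is generated in float32 and handed to the operator through convert_float_to_uint16, because OpTest carries bfloat16 tensors as uint16 bit patterns (hence self.dtype = np.uint16). A minimal, truncating sketch of that representation; the real helper in op_test may round rather than truncate, so treat this as illustrative only:

```python
import numpy as np

def float32_to_bf16_bits(x):
    """Keep the high 16 bits of each float32 -- a truncating bfloat16 cast."""
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_bits_to_float32(b):
    """Re-expand uint16 bfloat16 bit patterns to float32 for checking."""
    b = np.ascontiguousarray(b, dtype=np.uint16)
    return (b.astype(np.uint32) << 16).view(np.float32)

x = np.random.random((2, 100, 3, 5)).astype(np.float32)
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
rel_err = np.max(np.abs(roundtrip - x) / np.maximum(np.abs(x), 1e-12))
# bfloat16 keeps only a 7-bit mantissa (roughly 2-3 decimal digits), which is
# why the BF16 tests check with looser tolerances (1e-2) than the FP16 tests.
print(rel_err)  # below 1/128
```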
@@ -244,6 +361,8 @@ class TestGroupNormOpLargeData_With_NHWC(TestGroupNormOp):
class TestGroupNormAPI_With_NHWC(unittest.TestCase):
    paddle.enable_static()

    def test_case1(self):
        data1 = fluid.data(name='data1', shape=[None, 3, 3, 4], dtype='float64')
        out1 = paddle.static.nn.group_norm(
......
@@ -692,7 +692,10 @@ def group_norm(
    helper = LayerHelper('group_norm', **locals())
    dtype = helper.input_dtype()
    check_variable_and_dtype(
        input, 'input', ['float32', 'float64'], 'group_norm'
        input,
        'input',
        ['float16', 'uint16', 'float32', 'float64'],
        'group_norm',
    )
    # create input and parameters
    inputs = {'X': input}
......
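With 'float16' and 'uint16' (the bfloat16 tag) added to the whitelist above, a reduced-precision input now passes check_variable_and_dtype when the static graph is built. A small usage sketch in the spirit of TestGroupNormAPI_With_NHWC; the keyword arguments are assumed from the public paddle.static.nn.group_norm signature, and only graph construction is shown:

```python
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NHWC input with 4 channels, declared directly as float16.
    data = paddle.static.data(
        name='x_fp16', shape=[None, 3, 3, 4], dtype='float16'
    )
    # groups must divide the channel count; 2 divides 4.
    out = paddle.static.nn.group_norm(input=data, groups=2, data_layout='NHWC')
print(out.dtype)  # the output variable keeps the reduced-precision dtype
```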