未验证 提交 7c98abd9 编写于 作者: Q qizhaoaoe 提交者: GitHub

【AMP OP&Test】instance_norm fp16 and bf16 support. (#52241)

* add fp16 and bf16 support for instance_norm

* fix /= operator which not support bf16

* fix instance_norm_grad kernel and unittests.

* fix fp32 unittests.

* fix instance_norm_kernel and unittests.

* fix instance_norm_grad_kernel and unittest threshold.

* add fp16/bf16 for instance_norm_grad_grad op.

* add bf16 dtype check.

* fix conflicts.

* fix cpu support for fp32 op and fix type in instance_norm_grad_kernel.

* fix type in instance_norm_kernel.

* fix bf16 outputs in unittests and refine codes.

* fix dx computation.

* delete unuseful params and head including.

* add fp16/bf16 for static graph.

* fix device condiction for instance_norm op.

* fix instance_norm_grad_grad and bf16 op tests.

* fix op_test to support grad of bf16 can be compared with fp32.

* remove updates.

* add self-defined grad.
上级 de44b3ac
......@@ -33,6 +33,7 @@ void InstanceNormKernel(const Context &dev_ctx,
DenseTensor *y,
DenseTensor *saved_mean,
DenseTensor *saved_variance) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
double epsilon = static_cast<double>(epsilon_f);
auto &x_dims = x.dims();
PADDLE_ENFORCE_GE(x_dims.size(),
......@@ -113,10 +114,10 @@ void InstanceNormKernel(const Context &dev_ctx,
DenseTensor scale_tmp;
scale_tmp.Resize({NxC});
dev_ctx.template Alloc<T>(&scale_tmp);
dev_ctx.template Alloc<AccT>(&scale_tmp);
DenseTensor bias_tmp;
bias_tmp.Resize({NxC});
dev_ctx.template Alloc<T>(&bias_tmp);
dev_ctx.template Alloc<AccT>(&bias_tmp);
const int n = x.numel();
const int block = 512;
......@@ -124,24 +125,25 @@ void InstanceNormKernel(const Context &dev_ctx,
const int max_blocks = std::max(max_threads / block, 1);
const int grid = std::min((NxC + block - 1) / block, max_blocks);
phi::funcs::SetConstant<GPUContext, T> set_constant;
phi::funcs::SetConstant<GPUContext, AccT> set_constant;
if (scale_ptr) {
repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
scale_ptr->data<T>(), scale_tmp.data<T>(), N, C);
repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
} else {
set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
}
if (bias_ptr) {
repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
bias_ptr->data<T>(), bias_tmp.data<T>(), N, C);
repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
bias_ptr->data<AccT>(), bias_tmp.data<AccT>(), N, C);
} else {
set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
set_constant(dev_ctx, &bias_tmp, static_cast<AccT>(0));
}
auto handle = dev_ctx.cudnn_handle();
DenseTensor saved_mean_tmp, saved_variance_tmp;
phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
if (saved_mean) {
dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
......@@ -156,7 +158,6 @@ void InstanceNormKernel(const Context &dev_ctx,
saved_variance_tmp = phi::Full<BatchNormParamType<T>>(
dev_ctx, {NxC}, static_cast<BatchNormParamType<T>>(0));
}
auto *saved_mean_data = saved_mean
? saved_mean->data<BatchNormParamType<T>>()
: saved_mean_tmp.data<BatchNormParamType<T>>();
......@@ -225,9 +226,27 @@ void InstanceNormKernel(const Context &dev_ctx,
#ifdef PADDLE_WITH_HIP
// MIOPEN do not support double
PD_REGISTER_KERNEL(
instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
PD_REGISTER_KERNEL(instance_norm,
GPU,
ALL_LAYOUT,
phi::InstanceNormKernel,
float,
phi::dtype::float16) {}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(instance_norm,
GPU,
ALL_LAYOUT,
phi::InstanceNormKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(
instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float, double) {}
PD_REGISTER_KERNEL(instance_norm,
GPU,
ALL_LAYOUT,
phi::InstanceNormKernel,
float,
double,
phi::dtype::float16) {}
#endif
......@@ -27,6 +27,7 @@ namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace phi {
......@@ -51,22 +52,23 @@ static __global__ void add_param(const T *input,
T *output,
const int repeat_num,
const int C) {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ou_storage;
for (int i = blockIdx.x; i < C; i += gridDim.x) {
T ou = static_cast<T>(0);
MPType ou = static_cast<MPType>(0);
for (int j = threadIdx.x; j < repeat_num; j += blockDim.x) {
const int index = j * C + i;
ou += static_cast<T>(input[index]);
ou = ou + static_cast<MPType>(input[index]);
}
ou = BlockReduce(ou_storage).Reduce(ou, cub::Sum());
if (threadIdx.x == 0) {
output[i] = ou;
output[i] = static_cast<T>(ou);
}
__syncthreads();
if (AVG) {
output[i] /= repeat_num;
output[i] = static_cast<T>(static_cast<MPType>(output[i]) / repeat_num);
}
}
}
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
......@@ -121,5 +122,202 @@ class TestInstanceNorm(unittest.TestCase):
np.testing.assert_allclose(y1, y2, rtol=1e-05)
def instance_norm_warpper(
input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
):
if data_format == "AnyLayout":
data_format = "NCDHW"
return paddle._C_ops.instance_norm(
input, weight, bias, epsilon, momentum, data_format
)
def _reference_instance_norm(x, scale, bias, epsilon):
N, C, H, W = x.shape
mean = np.mean(x, axis=(2, 3), keepdims=True)
variance = np.var(x, axis=(2, 3), keepdims=True)
std = np.sqrt(variance) + epsilon
x_norm = (x - mean) / std
scale = scale.reshape([1, C, 1, 1])
bias = bias.reshape([1, C, 1, 1])
x_norm = scale * x_norm + bias
return x_norm, mean.reshape(N * C), std.reshape(N * C)
def _reference_instance_norm_grad(x, scale, mean, var):
n, c, h, w = x.shape
d_y = np.ones(x.shape) / (np.prod(x.shape))
d_bias = np.ones((c,)) / c
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
var_inv = var_tile
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
d_x = (
scale_tile
* var_inv
* (
d_y
- np.mean(d_y, axis=(2, 3), keepdims=True)
- (x - mean_tile)
* var_inv
* np.mean(
d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
)
)
)
return d_x, d_scale, d_bias
class TestInstanceNormFP32OP(OpTest):
def setUp(self):
'''Test instance_norm op with default value'''
self.op_type = "instance_norm"
self.__class__.op_type = self.op_type
self.python_api = instance_norm_warpper
self.data_format = "NCHW"
self.eps = 1e-5
self.init_dtype()
self.init_shape()
self.init_value()
self.set_err_thre()
self.inputs = {'X': self.value, 'Scale': self.scale, 'Bias': self.bias}
self.attrs = {
'epsilon': self.eps,
'momentum': 0.9,
'data_format': self.data_format,
}
y, mean, variance = _reference_instance_norm(
self.value, self.scale, self.bias, self.eps
)
self.python_out_sig = ['Y']
self.outputs = {
'Y': y,
'SavedMean': mean,
'SavedVariance': 1.0 / variance,
}
def test_check_output(self):
self.check_output(atol=self.atol)
def test_check_grad(self):
self.check_grad(
['X', 'Scale', 'Bias'],
'Y',
)
def init_dtype(self):
self.dtype = np.float32
def init_shape(self):
self.shape = [4, 100, 4, 4]
def init_value(self):
np.random.seed(0)
self.value = np.random.random(self.shape).astype(self.dtype)
self.scale = np.random.random([self.shape[1]]).astype(np.float32)
self.bias = np.random.random([self.shape[1]]).astype(np.float32)
def set_err_thre(self):
self.atol = 1e-3
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_float16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the float16",
)
class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
def init_dtype(self):
self.dtype = np.float16
def set_err_thre(self):
self.atol = 0.03125
self.max_relative_error = 8e-3
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=self.atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X', 'Scale', 'Bias'],
'Y',
max_relative_error=self.max_relative_error,
)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestInstanceNormBF16OP(OpTest):
def setUp(self):
self.op_type = "instance_norm"
self.__class__.op_type = self.op_type
self.python_api = instance_norm_warpper
self.eps = 1e-5
self.data_format = "NCHW"
self.dtype = np.uint16
self.init_shape()
self.init_value()
y, mean, variance = _reference_instance_norm(
self.value, self.scale, self.bias, self.eps
)
var_inv = 1.0 / variance
self.user_defined_grads = _reference_instance_norm_grad(
self.value, self.scale, mean, var_inv
)
self.python_out_sig = ['Y']
self.outputs = {
'Y': convert_float_to_uint16(y),
'SavedMean': mean,
'SavedVariance': var_inv,
}
self.inputs = {
'X': convert_float_to_uint16(self.value),
'Scale': self.scale,
'Bias': self.bias,
}
self.attrs = {
'epsilon': self.eps,
'momentum': 0.9,
'data_format': self.data_format,
}
def init_value(self):
np.random.seed(0)
self.value = np.random.random(self.shape).astype(np.float32)
self.scale = np.random.random([self.shape[1]]).astype(np.float32)
self.bias = np.random.random([self.shape[1]]).astype(np.float32)
def init_shape(self):
self.shape = [4, 100, 4, 4]
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place,
['X', 'Scale', 'Bias'],
'Y',
user_defined_grads=self.user_defined_grads,
)
if __name__ == '__main__':
unittest.main()
......@@ -14,6 +14,7 @@
# For op in NO_FP64_CHECK_GRAD_OP_LIST, the op test requires check_grad with fp64 precision
NO_FP64_CHECK_GRAD_OP_LIST = [
'instance_norm',
'affine_grid',
'clip',
'conv2d',
......
......@@ -426,7 +426,10 @@ def instance_norm(
return out
else:
check_variable_and_dtype(
x, 'input', ['float32', 'float64'], "InstanceNorm"
x,
'input',
['float32', 'float64', 'float16', 'uint16'],
"InstanceNorm",
)
attrs = {
......
......@@ -306,7 +306,10 @@ def instance_norm(
hidden2 = paddle.static.nn.instance_norm(hidden1)
"""
check_variable_and_dtype(
input, 'input', ['float32', 'float64'], 'instance_norm'
input,
'input',
['uint16', 'float16', 'float32', 'float64'],
'instance_norm',
)
if param_attr is False:
assert (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册