diff --git a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
index 7ffb36d1129975f5c28bd41f12b044838ac83f17..a121f9fb95b0636324da841f2af4fc599fd9148f 100644
--- a/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/instance_norm_grad_kernel.cu
@@ -13,7 +13,6 @@
 // limitations under the License.
 
 #include "paddle/phi/kernels/instance_norm_grad_kernel.h"
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/layout.h"
 #include "paddle/phi/core/kernel_registry.h"
@@ -62,12 +61,12 @@ static __global__ void GradComputeDX(const T *dy,
   }
   __syncthreads();
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    dx[i] =
+    dx[i] = static_cast<T>(
         (static_cast<BatchNormParamType<T>>(dy[i]) -
          dy_sum_val / static_cast<BatchNormParamType<T>>(sample_size) -
          (static_cast<BatchNormParamType<T>>(x[i]) - mean_val) *
             dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) *
-        scale[c] * inv_var_val;
+        scale[c] * inv_var_val);
   }
 }
@@ -78,14 +77,14 @@ static __device__ __forceinline__ double real_sqrt(double x) {
   return 1. / sqrt(x);
 }
 
-template <typename T, int BlockDim>
+template <typename T, typename AccT, int BlockDim>
 __global__ void DoubleGradComputeDX(const T *x,
-                                    const T *mean,
-                                    const T *variance,
+                                    const AccT *mean,
+                                    const AccT *variance,
                                     const T *ddx,
                                     const T *dy,
-                                    const T *scale,
-                                    const T *ddscale,
+                                    const AccT *scale,
+                                    const AccT *ddscale,
                                     int C,
                                     int sample_size,
                                     const double epsilon,
@@ -95,30 +94,30 @@ __global__ void DoubleGradComputeDX(const T *x,
   int ncid = blockIdx.x;
   int c = ncid % C;
 
-  T mean_val = mean[ncid];
-  T var_val = variance[ncid];
+  AccT mean_val = mean[ncid];
+  AccT var_val = variance[ncid];
 
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage dy_storage;
   __shared__ typename BlockReduce::TempStorage ddx_storage;
   __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage;
   __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
   __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
-  __shared__ T dy_sum_val;
-  __shared__ T ddx_sum_val;
-  __shared__ T dy_mul_ddx_sum_val;
-  __shared__ T dy_mul_x_sub_mean_sum_val;
-  __shared__ T ddx_mul_x_sub_mean_sum_val;
-
-  T dy_sum = 0;
-  T ddx_sum = 0;
-  T dy_mul_ddx_sum = 0;
-  T dy_mul_x_sub_mean_sum = 0;
-  T ddx_mul_x_sub_mean_sum = 0;
+  __shared__ AccT dy_sum_val;
+  __shared__ AccT ddx_sum_val;
+  __shared__ AccT dy_mul_ddx_sum_val;
+  __shared__ AccT dy_mul_x_sub_mean_sum_val;
+  __shared__ AccT ddx_mul_x_sub_mean_sum_val;
+
+  AccT dy_sum = 0;
+  AccT ddx_sum = 0;
+  AccT dy_mul_ddx_sum = 0;
+  AccT dy_mul_x_sub_mean_sum = 0;
+  AccT ddx_mul_x_sub_mean_sum = 0;
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T ddx_i = ddx[i];
-    T dy_i = dy[i];
-    T tmp = x[i] - mean_val;
+    AccT ddx_i = static_cast<AccT>(ddx[i]);
+    AccT dy_i = static_cast<AccT>(dy[i]);
+    AccT tmp = static_cast<AccT>(x[i]) - mean_val;
 
     dy_sum += dy_i;
     ddx_sum += ddx_i;
@@ -148,37 +147,44 @@ __global__ void DoubleGradComputeDX(const T *x,
 
   if (ddx != nullptr) {
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      dx[i] +=
-          ((x[i] - mean_val) * var_val * var_val * var_val / sample_size *
+      AccT tmp = static_cast<AccT>(dx[i]);
+      tmp +=
+          ((static_cast<AccT>(x[i]) - mean_val) * var_val * var_val * var_val /
+               sample_size *
               (ddx_sum_val * dy_sum_val / sample_size - dy_mul_ddx_sum_val +
                3. * dy_mul_x_sub_mean_sum_val * var_val *
                    ddx_mul_x_sub_mean_sum_val * var_val / sample_size) +
           ddx_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
-              var_val * (dy_sum_val / sample_size - dy[i]) +
+              var_val * (dy_sum_val / sample_size - static_cast<AccT>(dy[i])) +
           dy_mul_x_sub_mean_sum_val * var_val / sample_size * var_val *
-              var_val * (ddx_sum_val / sample_size - ddx[i])) *
+              var_val *
+              (ddx_sum_val / sample_size - static_cast<AccT>(ddx[i]))) *
          scale[c];
+      dx[i] = static_cast<T>(tmp);
     }
   }
   __syncthreads();
   if (ddscale != nullptr) {
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      dx[i] += (dy[i] * var_val - dy_sum_val / sample_size * var_val -
-                (x[i] - mean_val) * var_val * dy_mul_x_sub_mean_sum_val *
-                    var_val / sample_size) *
-               ddscale[c];
+      AccT tmp = static_cast<AccT>(dx[i]);
+      tmp += (static_cast<AccT>(dy[i]) * var_val -
+              dy_sum_val / sample_size * var_val -
+              (static_cast<AccT>(x[i]) - mean_val) * var_val *
+                  dy_mul_x_sub_mean_sum_val * var_val / sample_size) *
+             ddscale[c];
+      dx[i] = static_cast<T>(tmp);
     }
   }
 }
 
-template <typename T, int BlockDim>
+template <typename T, typename AccT, int BlockDim>
 __global__ void DoubleGradComputeDDY(const T *x,
-                                     const T *mean,
-                                     const T *variance,
-                                     const T *ddscale,
-                                     const T *ddbias,
+                                     const AccT *mean,
+                                     const AccT *variance,
+                                     const AccT *ddscale,
+                                     const AccT *ddbias,
                                      const T *ddx,
-                                     const T *scale,
+                                     const AccT *scale,
                                      int C,
                                      int sample_size,
                                      const double epsilon,
@@ -187,20 +193,20 @@ __global__ void DoubleGradComputeDDY(const T *x,
   int end_idx = (blockIdx.x + 1) * sample_size;
   int ncid = blockIdx.x;
   int c = ncid % C;
-  T mean_val = mean[ncid];
-  T var_val = variance[ncid];
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  AccT mean_val = mean[ncid];
+  AccT var_val = variance[ncid];
+  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage ddx_storage;
   __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage;
-  __shared__ T ddx_sum_val;
-  __shared__ T ddx_mul_x_sub_mean_sum_val;
+  __shared__ AccT ddx_sum_val;
+  __shared__ AccT ddx_mul_x_sub_mean_sum_val;
 
-  T ddx_sum = 0;
-  T ddx_mul_x_sub_mean_sum = 0;
+  AccT ddx_sum = 0;
+  AccT ddx_mul_x_sub_mean_sum = 0;
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T ddx_i = ddx[i];
+    AccT ddx_i = static_cast<AccT>(ddx[i]);
     ddx_sum += ddx_i;
-    ddx_mul_x_sub_mean_sum += (ddx_i * (x[i] - mean_val));
+    ddx_mul_x_sub_mean_sum += (ddx_i * (static_cast<AccT>(x[i]) - mean_val));
   }
   ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum());
   ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage)
@@ -212,55 +218,59 @@ __global__ void DoubleGradComputeDDY(const T *x,
   __syncthreads();
   if (ddx != nullptr) {
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      ddy[i] += scale[c] * var_val *
-                (ddx[i] - ddx_sum_val / sample_size -
-                 (x[i] - mean_val) * var_val * ddx_mul_x_sub_mean_sum_val *
-                     var_val / sample_size);
+      AccT tmp = static_cast<AccT>(ddy[i]);
+      tmp += scale[c] * var_val *
+             (static_cast<AccT>(ddx[i]) - ddx_sum_val / sample_size -
+              (static_cast<AccT>(x[i]) - mean_val) * var_val *
+                  ddx_mul_x_sub_mean_sum_val * var_val / sample_size);
+      ddy[i] = static_cast<T>(tmp);
     }
   }
   __syncthreads();
   if (ddscale != nullptr) {
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      ddy[i] += (x[i] - mean_val) * var_val * ddscale[c];
+      AccT tmp = static_cast<AccT>(ddy[i]);
+      tmp += (static_cast<AccT>(x[i]) - mean_val) * var_val * ddscale[c];
+      ddy[i] = static_cast<T>(tmp);
     }
   }
   __syncthreads();
   if (ddbias != nullptr) {
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      ddy[i] += ddbias[c];
+      ddy[i] = static_cast<T>(static_cast<AccT>(ddy[i]) + ddbias[c]);
     }
   }
 }
 
-template <typename T, int BlockDim>
+template <typename T, typename AccT, int BlockDim>
 __global__ void DoubleGradComputeDScale(const T *x,
-                                        const T *mean,
-                                        const T *variance,
+                                        const AccT *mean,
+                                        const AccT *variance,
                                         const T *ddx,
                                         const T *dy,
                                         int C,
                                         int sample_size,
                                         const double epsilon,
-                                        T *dscale) {
+                                        AccT *dscale) {
   int beg_idx = blockIdx.x * sample_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * sample_size;
   int ncid = blockIdx.x;
   int c = ncid % C;
-  T mean_val = mean[ncid];
-  T var_val = variance[ncid];
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  AccT mean_val = mean[ncid];
+  AccT var_val = variance[ncid];
+  typedef cub::BlockReduce<AccT, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage dy_storage;
   __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage;
   __shared__ typename BlockReduce::TempStorage dscale_tmp_storage;
-  __shared__ T dy_sum_val;
-  __shared__ T dy_mul_x_sub_mean_sum_val;
+  __shared__ AccT dy_sum_val;
+  __shared__ AccT dy_mul_x_sub_mean_sum_val;
 
-  T dy_sum = 0;
-  T dy_mul_x_sub_mean_sum = 0;
+  AccT dy_sum = 0;
+  AccT dy_mul_x_sub_mean_sum = 0;
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T dy_i = dy[i];
+    AccT dy_i = static_cast<AccT>(dy[i]);
     dy_sum += dy_i;
-    dy_mul_x_sub_mean_sum += (dy_i * (x[i] - mean_val));
+    dy_mul_x_sub_mean_sum += (dy_i * (static_cast<AccT>(x[i]) - mean_val));
   }
   dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
   dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage)
@@ -272,12 +282,13 @@ __global__ void DoubleGradComputeDScale(const T *x,
   }
   __syncthreads();
   if (ddx != nullptr) {
-    T dscale_tmp = 0;
+    AccT dscale_tmp = 0;
     for (int i = beg_idx; i < end_idx; i += BlockDim) {
-      dscale_tmp += ddx[i] * var_val *
-                    (dy[i] - dy_sum_val / sample_size -
-                     dy_mul_x_sub_mean_sum_val * (x[i] - mean_val) * var_val *
-                         var_val / sample_size);
+      dscale_tmp +=
+          static_cast<AccT>(ddx[i]) * var_val *
+          (static_cast<AccT>(dy[i]) - dy_sum_val / sample_size -
+           dy_mul_x_sub_mean_sum_val * (static_cast<AccT>(x[i]) - mean_val) *
+               var_val * var_val / sample_size);
     }
     dscale_tmp = BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum());
     if (threadIdx.x == 0) {
@@ -298,6 +309,7 @@ void InstanceNormGradKernel(const Context &dev_ctx,
                             DenseTensor *d_x,
                             DenseTensor *d_scale,
                             DenseTensor *d_bias) {
+  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   double epsilon = static_cast<double>(epsilon_f);
 
   const auto *scale_ptr = scale.get_ptr();
@@ -313,8 +325,8 @@ void InstanceNormGradKernel(const Context &dev_ctx,
   dev_ctx.template Alloc<T>(d_x);
   if (d_scale && d_bias) {
-    dev_ctx.template Alloc<T>(d_scale);
-    dev_ctx.template Alloc<T>(d_bias);
+    dev_ctx.template Alloc<AccT>(d_scale);
+    dev_ctx.template Alloc<AccT>(d_bias);
   }
   if (scale_ptr) {
     PADDLE_ENFORCE_EQ(
@@ -339,7 +351,7 @@ void InstanceNormGradKernel(const Context &dev_ctx,
                           scale_ptr->dims()));
   }
 
-  phi::funcs::SetConstant<GPUContext, T> set_constant;
+  phi::funcs::SetConstant<GPUContext, AccT> set_constant;
 
   const int n = x.numel();
   const int block = 512;
@@ -350,23 +362,21 @@ void InstanceNormGradKernel(const Context &dev_ctx,
   DenseTensor scale_tmp;
   scale_tmp.Resize({NxC});
-  dev_ctx.template Alloc<T>(&scale_tmp);
+  dev_ctx.template Alloc<AccT>(&scale_tmp);
   DenseTensor d_scale_tmp;
   d_scale_tmp.Resize({NxC});
-  dev_ctx.template Alloc<T>(&d_scale_tmp);
+  dev_ctx.template Alloc<AccT>(&d_scale_tmp);
   DenseTensor d_bias_tmp;
   d_bias_tmp.Resize({NxC});
-  dev_ctx.template Alloc<T>(&d_bias_tmp);
-
+  dev_ctx.template Alloc<AccT>(&d_bias_tmp);
   if (scale_ptr) {
-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        scale_ptr->data<T>(), scale_tmp.data<T>(), N, C);
+    repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
+        scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
   } else {
-    set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+    set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
   }
-
   std::vector<int> dims;
   std::vector<int> strides;
   dims = {1, NxC, H, W, D};
@@ -424,11 +434,11 @@ void InstanceNormGradKernel(const Context &dev_ctx,
   PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDeriveBNTensorDescriptor(
       in_param_desc_, data_desc_, CUDNN_BATCHNORM_SPATIAL));
 #endif
-
   const auto *saved_mean_data =
       saved_mean.template data<BatchNormParamType<T>>();
   const auto *saved_var_data =
       saved_variance.template data<BatchNormParamType<T>>();
+
   if (d_scale && d_bias) {
 #ifdef PADDLE_WITH_HIP
     PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenBatchNormalizationBackward(
@@ -486,12 +496,11 @@ void InstanceNormGradKernel(const Context &dev_ctx,
                                                      d_x->data<T>());
     }
   }
-
   if (d_scale && d_bias) {
-    add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
-        d_scale_tmp.data<T>(), d_scale->data<T>(), N, C);
-    add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
-        d_bias_tmp.data<T>(), d_bias->data<T>(), N, C);
+    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
+        d_scale_tmp.data<AccT>(), d_scale->data<AccT>(), N, C);
+    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
+        d_bias_tmp.data<AccT>(), d_bias->data<AccT>(), N, C);
   }
 
 #ifdef PADDLE_WITH_HIP
@@ -521,6 +530,7 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
                                   DenseTensor *dx,
                                   DenseTensor *dscale,
                                   DenseTensor *ddy) {
+  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   const auto *Scale = scale.get_ptr();
   const auto *ddX = ddx.get_ptr();
   const auto *ddScale = ddscale.get_ptr();
@@ -529,11 +539,15 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
   const T *x_data = x.data<T>();
   const T *dy_data = dy.data<T>();
   const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data<T>());
-  const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data<T>());
-  const T *ddbias_data = (ddScale == nullptr ? nullptr : ddBias->data<T>());
-  const T *mean_data = saved_mean.data<T>();
-  const T *variance_data = saved_variance.data<T>();
+  const AccT *ddscale_data =
+      (ddScale == nullptr ? nullptr : ddScale->data<AccT>());
+  const AccT *ddbias_data =
+      (ddScale == nullptr ? nullptr : ddBias->data<AccT>());
+  const AccT *mean_data = saved_mean.data<AccT>();
+  const AccT *variance_data = saved_variance.data<AccT>();
   phi::funcs::SetConstant<GPUContext, T> set_zero;
+  phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT;
+
   auto &x_dims = x.dims();
   int N, C, H, W, D;
   funcs::ExtractNCWHD(x_dims, DataLayout::kNCHW, &N, &C, &H, &W, &D);
@@ -544,10 +558,10 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
   DenseTensor scale_tmp;
   if (!Scale) {
     scale_tmp.Resize({C});
-    dev_ctx.template Alloc<T>(&scale_tmp);
-    set_zero(dev_ctx, &scale_tmp, static_cast<T>(1));
+    dev_ctx.template Alloc<AccT>(&scale_tmp);
+    set_zero_AccT(dev_ctx, &scale_tmp, static_cast<AccT>(1));
   }
-  const T *scale_data = Scale ? Scale->data<T>() : scale_tmp.data<T>();
+  const AccT *scale_data =
+      Scale ? Scale->data<AccT>() : scale_tmp.data<AccT>();
   const int block = 512;
   int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
   const int max_blocks = std::max(max_threads / block, 1);
@@ -557,7 +571,7 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
   if (dx) {
     T *dx_data = dev_ctx.template Alloc<T>(dx);
     set_zero(dev_ctx, dx, static_cast<T>(0));
-    DoubleGradComputeDX<T, block>
+    DoubleGradComputeDX<T, AccT, block>
         <<<NxC, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
@@ -573,13 +587,13 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
   if (dscale) {
     DenseTensor dscale_tmp;
     dscale_tmp.Resize({NxC});
-    dev_ctx.template Alloc<T>(&dscale_tmp);
-    set_zero(dev_ctx, &dscale_tmp, static_cast<T>(0));
-    T *dscale_tmp_data = dscale_tmp.data<T>();
+    dev_ctx.template Alloc<AccT>(&dscale_tmp);
+    set_zero_AccT(dev_ctx, &dscale_tmp, static_cast<AccT>(0));
+    AccT *dscale_tmp_data = dscale_tmp.data<AccT>();
 
-    T *dscale_data = dev_ctx.template Alloc<T>(dscale);
-    set_zero(dev_ctx, dscale, static_cast<T>(0));
-    DoubleGradComputeDScale<T, block>
+    AccT *dscale_data = dev_ctx.template Alloc<AccT>(dscale);
+    set_zero_AccT(dev_ctx, dscale, static_cast<AccT>(0));
+    DoubleGradComputeDScale<T, AccT, block>
         <<<NxC, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
@@ -589,13 +603,13 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
                                               sample_size,
                                               epsilon,
                                               dscale_tmp_data);
-    add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
-        dscale_tmp.data<T>(), dscale->data<T>(), N, C);
+    add_param<AccT, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
+        dscale_tmp.data<AccT>(), dscale->data<AccT>(), N, C);
   }
   if (ddy) {
     T *ddy_data = dev_ctx.template Alloc<T>(ddy);
     set_zero(dev_ctx, ddy, static_cast<T>(0));
-    DoubleGradComputeDDY<T, block>
+    DoubleGradComputeDDY<T, AccT, block>
         <<<NxC, block, 0, dev_ctx.stream()>>>(x_data,
                                               mean_data,
                                               variance_data,
@@ -613,24 +627,48 @@ void InstanceNormDoubleGradKernel(const Context &dev_ctx,
 
 #ifdef PADDLE_WITH_HIP
 // MIOPEN do not support double
-PD_REGISTER_KERNEL(
-    instance_norm_grad, GPU, ALL_LAYOUT, phi::InstanceNormGradKernel, float) {}
+PD_REGISTER_KERNEL(instance_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormGradKernel,
+                   float,
+                   phi::dtype::float16) {}
 PD_REGISTER_KERNEL(instance_norm_double_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::InstanceNormDoubleGradKernel,
-                   float) {}
+                   float,
+                   phi::dtype::float16) {}
+#elif CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(instance_norm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+PD_REGISTER_KERNEL(instance_norm_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormDoubleGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
 PD_REGISTER_KERNEL(instance_norm_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::InstanceNormGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
 PD_REGISTER_KERNEL(instance_norm_double_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::InstanceNormDoubleGradKernel,
                    float,
-                   double) {}
+                   double,
+                   phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/gpu/instance_norm_kernel.cu b/paddle/phi/kernels/gpu/instance_norm_kernel.cu
index b842ce61dc3eb55ccc1c0981f574375400fda8a8..d4f421e62ddb92004506809b5a6533a71d97957f 100644
--- a/paddle/phi/kernels/gpu/instance_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/instance_norm_kernel.cu
@@ -33,6 +33,7 @@ void InstanceNormKernel(const Context &dev_ctx,
                         DenseTensor *y,
                         DenseTensor *saved_mean,
                         DenseTensor *saved_variance) {
+  using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   double epsilon = static_cast<double>(epsilon_f);
   auto &x_dims = x.dims();
   PADDLE_ENFORCE_GE(x_dims.size(),
@@ -113,10 +114,10 @@ void InstanceNormKernel(const Context &dev_ctx,
   DenseTensor scale_tmp;
   scale_tmp.Resize({NxC});
-  dev_ctx.template Alloc<T>(&scale_tmp);
+  dev_ctx.template Alloc<AccT>(&scale_tmp);
 
   DenseTensor bias_tmp;
   bias_tmp.Resize({NxC});
-  dev_ctx.template Alloc<T>(&bias_tmp);
+  dev_ctx.template Alloc<AccT>(&bias_tmp);
 
   const int n = x.numel();
   const int block = 512;
@@ -124,24 +125,25 @@ void InstanceNormKernel(const Context &dev_ctx,
   const int max_blocks = std::max(max_threads / block, 1);
   const int grid = std::min((NxC + block - 1) / block, max_blocks);
 
-  phi::funcs::SetConstant<GPUContext, T> set_constant;
+  phi::funcs::SetConstant<GPUContext, AccT> set_constant;
   if (scale_ptr) {
-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        scale_ptr->data<T>(), scale_tmp.data<T>(), N, C);
+    repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
+        scale_ptr->data<AccT>(), scale_tmp.data<AccT>(), N, C);
   } else {
-    set_constant(dev_ctx, &scale_tmp, static_cast<T>(1));
+    set_constant(dev_ctx, &scale_tmp, static_cast<AccT>(1));
   }
   if (bias_ptr) {
-    repeat_param<T><<<grid, block, 0, dev_ctx.stream()>>>(
-        bias_ptr->data<T>(), bias_tmp.data<T>(), N, C);
+    repeat_param<AccT><<<grid, block, 0, dev_ctx.stream()>>>(
+        bias_ptr->data<AccT>(), bias_tmp.data<AccT>(), N, C);
   } else {
-    set_constant(dev_ctx, &bias_tmp, static_cast<T>(0));
+    set_constant(dev_ctx, &bias_tmp, static_cast<AccT>(0));
   }
 
   auto handle = dev_ctx.cudnn_handle();
 
   DenseTensor saved_mean_tmp, saved_variance_tmp;
   phi::funcs::SetConstant<GPUContext, BatchNormParamType<T>> functor;
+
   if (saved_mean) {
     dev_ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
     functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
@@ -156,7 +158,6 @@ void InstanceNormKernel(const Context &dev_ctx,
     saved_variance_tmp = phi::Full<BatchNormParamType<T>>(
        dev_ctx, {NxC}, static_cast<BatchNormParamType<T>>(0));
   }
-
   auto *saved_mean_data = saved_mean
                               ? saved_mean->data<BatchNormParamType<T>>()
                               : saved_mean_tmp.data<BatchNormParamType<T>>();
@@ -225,9 +226,27 @@ void InstanceNormKernel(const Context &dev_ctx,
 
 #ifdef PADDLE_WITH_HIP
 // MIOPEN do not support double
-PD_REGISTER_KERNEL(
-    instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float) {}
+PD_REGISTER_KERNEL(instance_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   phi::dtype::float16) {}
+#elif CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(instance_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #else
-PD_REGISTER_KERNEL(
-    instance_norm, GPU, ALL_LAYOUT, phi::InstanceNormKernel, float, double) {}
+PD_REGISTER_KERNEL(instance_norm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::InstanceNormKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
 #endif
diff --git a/paddle/phi/kernels/gpu/instance_norm_utils.h b/paddle/phi/kernels/gpu/instance_norm_utils.h
index e52fe868c39ec5eeafa0e2493c3d85313f2d0898..865ab91da7b1b319df349e20170654915ab5ae47 100644
--- a/paddle/phi/kernels/gpu/instance_norm_utils.h
+++ b/paddle/phi/kernels/gpu/instance_norm_utils.h
@@ -27,6 +27,7 @@ namespace cub = hipcub;
 #endif
 
 #include "paddle/phi/backends/gpu/gpu_dnn.h"
+#include "paddle/phi/common/amp_type_traits.h"
 
 namespace phi {
 
@@ -51,22 +52,23 @@ static __global__ void add_param(const T *input,
                                  T *output,
                                  const int repeat_num,
                                  const int C) {
-  typedef cub::BlockReduce<T, BlockDim> BlockReduce;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
   __shared__ typename BlockReduce::TempStorage ou_storage;
   for (int i = blockIdx.x; i < C; i += gridDim.x) {
-    T ou = static_cast<T>(0);
+    MPType ou = static_cast<MPType>(0);
     for (int j = threadIdx.x; j < repeat_num; j += blockDim.x) {
       const int index = j * C + i;
-      ou += static_cast<T>(input[index]);
+      ou = ou + static_cast<MPType>(input[index]);
     }
     ou = BlockReduce(ou_storage).Reduce(ou, cub::Sum());
     if (threadIdx.x == 0) {
-      output[i] = ou;
+      output[i] = static_cast<T>(ou);
    }
     __syncthreads();
 
     if (AVG) {
-      output[i] /= repeat_num;
+      output[i] = static_cast<T>(static_cast<MPType>(output[i]) / repeat_num);
    }
  }
}
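Note on the kernel changes above: every reduction over the H*W elements of an (n, c) instance is now accumulated in AccT/MPType (float when T is float16 or bfloat16), and only the final store casts back to T. Below is a small NumPy sketch of the precision issue this avoids; it is an illustration only, not part of the patch, and the shape and seed are arbitrary.

import numpy as np

np.random.seed(0)
x = np.random.random((64, 64)).astype(np.float16)  # one H*W slice of an instance

mean_fp16 = x.mean(dtype=np.float16)     # accumulate the sum in float16
mean_fp32 = x.astype(np.float32).mean()  # accumulate in float32, like AccT

# The float16 accumulation drifts noticeably from the float32 result, which is
# why dy_sum, ddx_sum, and the other partial sums above are kept in AccT.
print(abs(float(mean_fp16) - float(mean_fp32)))

The Python-side changes below add OpTest coverage for the new dtypes and relax the dtype checks in the user-facing APIs.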
diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
index 6dd462fda43fbd367829fd8a6291a481a6ac124f..d214965b2dd6e6a2f726c007ffa565c83ae86722 100644
--- a/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op_v2.py
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
@@ -121,5 +122,202 @@ class TestInstanceNorm(unittest.TestCase):
         np.testing.assert_allclose(y1, y2, rtol=1e-05)
 
 
+def instance_norm_warpper(
+    input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW'
+):
+    if data_format == "AnyLayout":
+        data_format = "NCDHW"
+    return paddle._C_ops.instance_norm(
+        input, weight, bias, epsilon, momentum, data_format
+    )
+
+
+def _reference_instance_norm(x, scale, bias, epsilon):
+    N, C, H, W = x.shape
+    mean = np.mean(x, axis=(2, 3), keepdims=True)
+    variance = np.var(x, axis=(2, 3), keepdims=True)
+    std = np.sqrt(variance) + epsilon
+    x_norm = (x - mean) / std
+    scale = scale.reshape([1, C, 1, 1])
+    bias = bias.reshape([1, C, 1, 1])
+    x_norm = scale * x_norm + bias
+    return x_norm, mean.reshape(N * C), std.reshape(N * C)
+
+
+def _reference_instance_norm_grad(x, scale, mean, var):
+    n, c, h, w = x.shape
+    d_y = np.ones(x.shape) / (np.prod(x.shape))
+    d_bias = np.ones((c,)) / c
+
+    mean_tile = np.reshape(mean, (n, c, 1, 1))
+    mean_tile = np.tile(mean_tile, (1, 1, h, w))
+    var_tile = np.reshape(var, (n, c, 1, 1))
+    var_tile = np.tile(var_tile, (1, 1, h, w))
+
+    d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
+    var_inv = var_tile
+    scale_tile = np.reshape(scale, (1, c, 1, 1))
+    scale_tile = np.tile(scale_tile, (n, 1, h, w))
+
+    d_x = (
+        scale_tile
+        * var_inv
+        * (
+            d_y
+            - np.mean(d_y, axis=(2, 3), keepdims=True)
+            - (x - mean_tile)
+            * var_inv
+            * np.mean(
+                d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True
+            )
+        )
+    )
+
+    return d_x, d_scale, d_bias
+
+
+class TestInstanceNormFP32OP(OpTest):
+    def setUp(self):
+        '''Test instance_norm op with default value'''
+        self.op_type = "instance_norm"
+        self.__class__.op_type = self.op_type
+        self.python_api = instance_norm_warpper
+        self.data_format = "NCHW"
+        self.eps = 1e-5
+        self.init_dtype()
+        self.init_shape()
+        self.init_value()
+        self.set_err_thre()
+        self.inputs = {'X': self.value, 'Scale': self.scale, 'Bias': self.bias}
+        self.attrs = {
+            'epsilon': self.eps,
+            'momentum': 0.9,
+            'data_format': self.data_format,
+        }
+        y, mean, variance = _reference_instance_norm(
+            self.value, self.scale, self.bias, self.eps
+        )
+        self.python_out_sig = ['Y']
+        self.outputs = {
+            'Y': y,
+            'SavedMean': mean,
+            'SavedVariance': 1.0 / variance,
+        }
+
+    def test_check_output(self):
+        self.check_output(atol=self.atol)
+
+    def test_check_grad(self):
+        self.check_grad(
+            ['X', 'Scale', 'Bias'],
+            'Y',
+        )
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_shape(self):
+        self.shape = [4, 100, 4, 4]
+
+    def init_value(self):
+        np.random.seed(0)
+        self.value = np.random.random(self.shape).astype(self.dtype)
+        self.scale = np.random.random([self.shape[1]]).astype(np.float32)
+        self.bias = np.random.random([self.shape[1]]).astype(np.float32)
+
+    def set_err_thre(self):
+        self.atol = 1e-3
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_float16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the float16",
+)
+class TestInstanceNormFP16OP(TestInstanceNormFP32OP):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def set_err_thre(self):
+        self.atol = 0.03125
+        self.max_relative_error = 8e-3
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, atol=self.atol)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X', 'Scale', 'Bias'],
+            'Y',
+            max_relative_error=self.max_relative_error,
+        )
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestInstanceNormBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = "instance_norm"
+        self.__class__.op_type = self.op_type
+        self.python_api = instance_norm_warpper
+        self.eps = 1e-5
+        self.data_format = "NCHW"
+        self.dtype = np.uint16
+        self.init_shape()
+        self.init_value()
+
+        y, mean, variance = _reference_instance_norm(
+            self.value, self.scale, self.bias, self.eps
+        )
+        var_inv = 1.0 / variance
+        self.user_defined_grads = _reference_instance_norm_grad(
+            self.value, self.scale, mean, var_inv
+        )
+        self.python_out_sig = ['Y']
+        self.outputs = {
+            'Y': convert_float_to_uint16(y),
+            'SavedMean': mean,
+            'SavedVariance': var_inv,
+        }
+        self.inputs = {
+            'X': convert_float_to_uint16(self.value),
+            'Scale': self.scale,
+            'Bias': self.bias,
+        }
+        self.attrs = {
+            'epsilon': self.eps,
+            'momentum': 0.9,
+            'data_format': self.data_format,
+        }
+
+    def init_value(self):
+        np.random.seed(0)
+        self.value = np.random.random(self.shape).astype(np.float32)
+        self.scale = np.random.random([self.shape[1]]).astype(np.float32)
+        self.bias = np.random.random([self.shape[1]]).astype(np.float32)
+
+    def init_shape(self):
+        self.shape = [4, 100, 4, 4]
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X', 'Scale', 'Bias'],
+            'Y',
+            user_defined_grads=self.user_defined_grads,
+        )
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
index ced30722cf2792259265c79f9982a6ff3ac0fc8e..d7613f7b284e81eefd762d4ca57429bd7cc08599 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_accuracy_white_list.py
@@ -14,6 +14,7 @@
 
 # For op in NO_FP64_CHECK_GRAD_OP_LIST, the op test requires check_grad with fp64 precision
 NO_FP64_CHECK_GRAD_OP_LIST = [
+    'instance_norm',
     'affine_grid',
     'clip',
     'conv2d',
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index 2f71f137f43b272353494f72b0e53abff4a92cfc..95e1ca2504cd976354725f51da6456cbf3a7a48a 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -426,7 +426,10 @@ def instance_norm(
         return out
     else:
         check_variable_and_dtype(
-            x, 'input', ['float32', 'float64'], "InstanceNorm"
+            x,
+            'input',
+            ['float32', 'float64', 'float16', 'uint16'],
+            "InstanceNorm",
         )
 
         attrs = {
diff --git a/python/paddle/static/nn/common.py b/python/paddle/static/nn/common.py
index 2e0d3b2289c29082fb69f6b7aae07c38d540e593..a3ae723c6273c78077de30be8a0bc097af0435c0 100644
--- a/python/paddle/static/nn/common.py
+++ b/python/paddle/static/nn/common.py
@@ -306,7 +306,10 @@ def instance_norm(
             hidden2 = paddle.static.nn.instance_norm(hidden1)
     """
     check_variable_and_dtype(
-        input, 'input', ['float32', 'float64'], 'instance_norm'
+        input,
+        'input',
+        ['uint16', 'float16', 'float32', 'float64'],
+        'instance_norm',
     )
     if param_attr is False:
         assert (
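For reference, a minimal usage sketch of what the new kernel registrations enable. This is not part of the patch; it assumes a CUDA build of Paddle with float16 supported on the current GPU, and the shapes and printed tolerance are illustrative only.

import paddle

paddle.set_device('gpu')
x = paddle.randn([4, 100, 4, 4]).astype('float16')
weight = paddle.ones([100], dtype='float32')  # scale/bias stay float32, matching AccT
bias = paddle.zeros([100], dtype='float32')

# float16 input now dispatches to the GPU instance_norm kernels added above.
y_fp16 = paddle.nn.functional.instance_norm(x, weight=weight, bias=bias, eps=1e-5)
y_fp32 = paddle.nn.functional.instance_norm(
    x.astype('float32'), weight=weight, bias=bias, eps=1e-5
)
# The float16 path should agree with float32 to roughly the atol used in
# TestInstanceNormFP16OP above.
print(float((y_fp16.astype('float32') - y_fp32).abs().max()))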