diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index bb76904bff0154772958496b6608f9230ff918fc..9a1d724c73962e37f71102afd65c49bbc14088cb 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -496,6 +496,21 @@ class BatchNormGradKernel int scale_coefff = use_global_stats ? 1 : N * sample_size; const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff; + Tensor dy_sum; + dy_sum.Resize({C}); + dy_sum.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap dy_sum_arr(dy_sum.mutable_data(ctx.GetPlace()), + C); + + Tensor dy_mul_x_sub_mean_mul_invstd_sum; + dy_mul_x_sub_mean_mul_invstd_sum.Resize({C}); + dy_mul_x_sub_mean_mul_invstd_sum.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap dy_mul_x_sub_mean_mul_invstd_sum_arr( + dy_mul_x_sub_mean_mul_invstd_sum.mutable_data(ctx.GetPlace()), C); + + dy_sum_arr.setZero(); + dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero(); + switch (data_layout) { case DataLayout::kNCHW: { ConstEigenArrayMap x_arr(x->data(), sample_size, N * C); @@ -504,23 +519,27 @@ class BatchNormGradKernel sample_size, N * C); d_x_arr.setZero(); + for (int nc = 0; nc < N * C; ++nc) { + int c = nc % C; + dy_sum_arr(c) += d_y_arr.col(nc).sum(); + dy_mul_x_sub_mean_mul_invstd_sum_arr(c) += + ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc)) + .sum(); + } + if (d_scale && d_bias) { - for (int nc = 0; nc < N * C; ++nc) { - int c = nc % C; - d_bias_arr(c) += d_y_arr.col(nc).sum(); - d_scale_arr(c) += ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * - d_y_arr.col(nc)) - .sum(); - } + d_bias_arr = dy_sum_arr; + d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; } + if (!use_global_stats) { for (int nc = 0; nc < N * C; ++nc) { int c = nc % C; d_x_arr.col(nc) += scale_inv_var_nhw(c) * - (d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) - - (x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * - inv_var_arr(c)); + (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) - + (x_arr.col(nc) - mean_arr[c]) * + dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c)); } } else { for (int nc = 0; nc < N * C; ++nc) { @@ -537,27 +556,24 @@ class BatchNormGradKernel N * sample_size); d_x_arr.setZero(); - const auto d_y_row_sum = d_y_arr.rowwise().sum(); - const auto x_minus_mean = x_arr.colwise() - mean_arr; - const auto d_y_mul_x_minus_mean_row_sum = - (d_y_arr * x_minus_mean).rowwise().sum(); - const auto inv_var_sqr = inv_var_arr * inv_var_arr; + for (int nhw = 0; nhw < N * sample_size; ++nhw) { + dy_sum_arr += d_y_arr.col(nhw); + dy_mul_x_sub_mean_mul_invstd_sum_arr += + (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw); + } if (d_scale && d_bias) { - for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_bias_arr += d_y_arr.col(nhw); - d_scale_arr += - (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw); - } + d_bias_arr = dy_sum_arr; + d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr; } if (!use_global_stats) { for (int nhw = 0; nhw < N * sample_size; ++nhw) { d_x_arr.col(nhw) += scale_inv_var_nhw * - (d_y_arr.col(nhw) * N * sample_size - d_y_row_sum - - x_minus_mean.col(nhw) * inv_var_sqr * - d_y_mul_x_minus_mean_row_sum); + (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr - + (x_arr.col(nhw) - mean_arr) * + dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr); } } else { for (int nhw = 0; nhw < N * sample_size; ++nhw) { diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 49ff7069ba075fa156fa2f875684d0786af8e82b..95d7f23b2c0ac6e46cf85bef4340eb4180dc3dba 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -234,6 +234,63 @@ static __global__ void KeBNBackwardData(const T *dy, } } +template +static __global__ void BNBackwardData(const T *dy, + const BatchNormParamType *scale, + const BatchNormParamType *mean, + const T *x, + const BatchNormParamType *variance, + const int C, const int N, const int HxW, + T *dx) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage dy_storage; + __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage; + __shared__ BatchNormParamType dy_sum_val; + __shared__ BatchNormParamType dy_x_sub_mean_sum_val; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType inv_var_i = variance[i]; + BatchNormParamType mean_i = mean[i]; + BatchNormParamType dy_sum = static_cast>(0); + BatchNormParamType dy_x_sub_mean_sum = + static_cast>(0); + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == framework::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + BatchNormParamType dy_i = + static_cast>(dy[index]); + dy_sum += dy_i; + dy_x_sub_mean_sum += + dy_i * (static_cast>(x[index]) - mean_i); + } + + dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum()); + dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage) + .Reduce(dy_x_sub_mean_sum, cub::Sum()); + + if (threadIdx.x == 0) { + dy_sum_val = dy_sum; + dy_x_sub_mean_sum_val = dy_x_sub_mean_sum; + } + __syncthreads(); + + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == framework::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + dx[index] = + (static_cast>(dy[index]) - + dy_sum_val / static_cast>(inner_size) - + (static_cast>(x[index]) - mean_i) * + dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) * + scale[i] * inv_var_i; + } + } +} + template class BatchNormGradKernel : public framework::OpKernel { @@ -282,6 +339,13 @@ class BatchNormGradKernel } auto &dev_ctx = ctx.template device_context(); + const int num = x->numel(); + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid1 = (num + block - 1) / block; + int grid2 = std::min(C, max_blocks); + if (!use_global_stats) { if ((N * H * W * D) == 1) { framework::TensorCopy(*d_y, ctx.GetPlace(), d_x); @@ -325,21 +389,43 @@ class BatchNormGradKernel const auto *saved_mean = ctx.Input("SavedMean"); const auto *saved_var = ctx.Input("SavedVariance"); - const void *saved_mean_data = + const auto *saved_mean_data = saved_mean->template data>(); - const void *saved_var_data = + const auto *saved_var_data = saved_var->template data>(); - CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( - dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), - CudnnDataType::kZero(), CudnnDataType::kOne(), - CudnnDataType::kZero(), data_desc_, x->template data(), - data_desc_, d_y->template data(), data_desc_, - d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - d_scale->template mutable_data>(ctx.GetPlace()), - d_bias->template mutable_data>(ctx.GetPlace()), - epsilon, saved_mean_data, saved_var_data)); + if (d_scale && d_bias) { + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), + CudnnDataType::kZero(), CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, x->template data(), + data_desc_, d_y->template data(), data_desc_, + d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, + scale->template data>(), + d_scale->template mutable_data>( + ctx.GetPlace()), + d_bias->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean_data, saved_var_data)); + } else { + if (data_layout == framework::DataLayout::kNCHW) { + if (d_x) { + BNBackwardData<<< + grid2, block, 0, dev_ctx.stream()>>>( + d_y->data(), scale->data>(), + saved_mean_data, x->data(), saved_var_data, C, N, H * W * D, + d_x->data()); + } + } else { + if (d_x) { + BNBackwardData<<< + grid2, block, 0, dev_ctx.stream()>>>( + d_y->data(), scale->data>(), + saved_mean_data, x->data(), saved_var_data, C, N, H * W * D, + d_x->data()); + } + } + } // clean when exit. CUDNN_ENFORCE( @@ -355,13 +441,6 @@ class BatchNormGradKernel const auto *running_var_data = running_var->template data>(); - const int num = x->numel(); - const int block = 512; - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - int grid1 = (num + block - 1) / block; - int grid2 = std::min(C, max_blocks); - if (data_layout == framework::DataLayout::kNCHW) { if (d_x) { KeBNBackwardData<<< diff --git a/paddle/fluid/operators/instance_norm_op.cu b/paddle/fluid/operators/instance_norm_op.cu index 20954342371fa6ecace76fdfc5726638ab9ce78e..3f0799fbdbd40c29a6098a4ffcd93b4bd31fb70f 100644 --- a/paddle/fluid/operators/instance_norm_op.cu +++ b/paddle/fluid/operators/instance_norm_op.cu @@ -170,6 +170,58 @@ class InstanceNormKernel } }; +template +static __global__ void GradComputeDX(const T *dy, + const BatchNormParamType *scale, + const BatchNormParamType *mean, + const T *x, + const BatchNormParamType *variance, + const int C, const int sample_size, + T *dx) { + int beg_idx = blockIdx.x * sample_size + threadIdx.x; + int end_idx = (blockIdx.x + 1) * sample_size; + int ncid = blockIdx.x; + int c = ncid % C; + + BatchNormParamType mean_val = mean[ncid]; + BatchNormParamType inv_var_val = variance[ncid]; + + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage dy_storage; + __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage; + __shared__ BatchNormParamType dy_sum_val; + __shared__ BatchNormParamType dy_x_sub_mean_sum_val; + + BatchNormParamType dy_sum = static_cast>(0); + BatchNormParamType dy_x_sub_mean_sum = + static_cast>(0); + + for (int i = beg_idx; i < end_idx; i += BlockDim) { + BatchNormParamType dy_i = static_cast>(dy[i]); + dy_sum += dy_i; + dy_x_sub_mean_sum += + dy_i * (static_cast>(x[i]) - mean_val); + } + dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum()); + dy_x_sub_mean_sum = + BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum()); + + if (threadIdx.x == 0) { + dy_sum_val = dy_sum; + dy_x_sub_mean_sum_val = dy_x_sub_mean_sum; + } + __syncthreads(); + + for (int i = beg_idx; i < end_idx; i += BlockDim) { + dx[i] = + (static_cast>(dy[i]) - + dy_sum_val / static_cast>(sample_size) - + (static_cast>(x[i]) - mean_val) * + dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) * + scale[c] * inv_var_val; + } +} + template class InstanceNormGradKernel : public framework::OpKernel { @@ -258,21 +310,31 @@ class InstanceNormGradKernel const auto *saved_mean = ctx.Input("SavedMean"); const auto *saved_var = ctx.Input("SavedVariance"); - const void *saved_mean_data = + const auto *saved_mean_data = saved_mean->template data>(); - const void *saved_var_data = + const auto *saved_var_data = saved_var->template data>(); - CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( - dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, - CudnnDataType::kOne(), CudnnDataType::kZero(), - CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, - x_tmp.template data(), data_desc_, d_y_tmp.template data(), - data_desc_, d_x->template mutable_data(ctx.GetPlace()), - in_param_desc_, scale_tmp.template data>(), - d_scale_tmp.template mutable_data>( - ctx.GetPlace()), - d_bias_tmp.template mutable_data>(ctx.GetPlace()), - epsilon, saved_mean_data, saved_var_data)); + if (d_scale && d_bias) { + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL, + CudnnDataType::kOne(), CudnnDataType::kZero(), + CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, + x_tmp.template data(), data_desc_, d_y_tmp.template data(), + data_desc_, d_x->template mutable_data(ctx.GetPlace()), + in_param_desc_, scale_tmp.template data>(), + d_scale_tmp.template mutable_data>( + ctx.GetPlace()), + d_bias_tmp.template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean_data, saved_var_data)); + } else { + if (d_x) { + GradComputeDX<<>>( + d_y->data(), scale->data>(), + saved_mean_data, x->data(), saved_var_data, C, H * W * D, + d_x->data()); + } + } if (d_scale && d_bias) { add_param<<>>( diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 2869a6ba53bfb9120ae68d67d10eb5080be5f07b..ec96e5f79ca39998dad8d2222cecd573d477ce5b 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -338,14 +338,14 @@ class TestBatchNormOpTraining(unittest.TestCase): return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad def set_mean_variance(self, scale_shape, x, data_layout): - mean = np.zeros(scale_shape).astype(np.float32) - variance = np.ones(scale_shape).astype(np.float32) + mean, variance = _cal_mean_variance(x, self.epsilon, data_layout) + mean_pre = np.zeros(scale_shape).astype(np.float32) + variance_pre = np.ones(scale_shape).astype(np.float32) # computing global mean/variance for one step if self.use_global_stats: mom = self.momentum - x_mean, x_var = _cal_mean_variance(x, self.epsilon, data_layout) - mean = x_mean * (1. - mom) + mom * mean - variance = x_var * (1. - mom) + mom * variance + mean = mean * (1. - mom) + mom * mean_pre + variance = variance * (1. - mom) + mom * variance_pre return mean, variance def test_forward_backward(self): @@ -442,6 +442,10 @@ class TestBatchNormOpTraining(unittest.TestCase): fetch_list=self.fetch_list) for id, name in enumerate(self.fetch_list): + if name == 'variance': + self.__assert_close( + var_dict[name], out[id], name, atol=1e-3) + continue self.__assert_close(var_dict[name], out[id], name) print("op test forward passed: ", str(place), data_layout) @@ -458,6 +462,13 @@ class TestBatchNormOpTraining(unittest.TestCase): pass +class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining): + def init_test_case(self): + self.use_global_stats = False + self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) + self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD'] + + class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining): def init_test_case(self): self.use_global_stats = True diff --git a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py index 7c2a4d212faad42b3bb41490732a7ad1d082e302..ccdf12849c7ee61fad5916aea0403b760a0302db 100644 --- a/python/paddle/fluid/tests/unittests/test_instance_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_instance_norm_op.py @@ -184,5 +184,12 @@ class TestInstanceNormOpTraining(unittest.TestCase): test_with_place(place, [2, 3, 4, 5]) +class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining): + def init_test_case(self): + self.use_global_stats = False + self.no_grad_set = set(['scale@GRAD', 'bias@GRAD']) + self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD'] + + if __name__ == '__main__': unittest.main()