Unverified · Commit 619a241b, authored by lvmengsi, committed by GitHub

Fix OpTest of bn (#19062)

* fix bn
Parent 5920d69d
@@ -496,6 +496,21 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     int scale_coefff = use_global_stats ? 1 : N * sample_size;
     const auto scale_inv_var_nhw = scale_arr * inv_var_arr / scale_coefff;
 
+    Tensor dy_sum;
+    dy_sum.Resize({C});
+    dy_sum.mutable_data<T>(ctx.GetPlace());
+    EigenVectorArrayMap<T> dy_sum_arr(dy_sum.mutable_data<T>(ctx.GetPlace()),
+                                      C);
+
+    Tensor dy_mul_x_sub_mean_mul_invstd_sum;
+    dy_mul_x_sub_mean_mul_invstd_sum.Resize({C});
+    dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace());
+    EigenVectorArrayMap<T> dy_mul_x_sub_mean_mul_invstd_sum_arr(
+        dy_mul_x_sub_mean_mul_invstd_sum.mutable_data<T>(ctx.GetPlace()), C);
+
+    dy_sum_arr.setZero();
+    dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero();
+
     switch (data_layout) {
       case DataLayout::kNCHW: {
         ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
@@ -504,23 +519,27 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
                                     sample_size, N * C);
         d_x_arr.setZero();
 
+        for (int nc = 0; nc < N * C; ++nc) {
+          int c = nc % C;
+          dy_sum_arr(c) += d_y_arr.col(nc).sum();
+          dy_mul_x_sub_mean_mul_invstd_sum_arr(c) +=
+              ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) *
+               d_y_arr.col(nc))
+                  .sum();
+        }
+
         if (d_scale && d_bias) {
-          for (int nc = 0; nc < N * C; ++nc) {
-            int c = nc % C;
-            d_bias_arr(c) += d_y_arr.col(nc).sum();
-            d_scale_arr(c) += ((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) *
-                               d_y_arr.col(nc))
-                                  .sum();
-          }
+          d_bias_arr = dy_sum_arr;
+          d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
+
         if (!use_global_stats) {
           for (int nc = 0; nc < N * C; ++nc) {
             int c = nc % C;
             d_x_arr.col(nc) +=
                 scale_inv_var_nhw(c) *
-                (d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) -
-                 (x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) *
-                     inv_var_arr(c));
+                (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) -
+                 (x_arr.col(nc) - mean_arr[c]) *
+                     dy_mul_x_sub_mean_mul_invstd_sum_arr(c) * inv_var_arr(c));
           }
         } else {
           for (int nc = 0; nc < N * C; ++nc) {
@@ -537,27 +556,24 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
                                     N * sample_size);
         d_x_arr.setZero();
 
-        const auto d_y_row_sum = d_y_arr.rowwise().sum();
-        const auto x_minus_mean = x_arr.colwise() - mean_arr;
-        const auto d_y_mul_x_minus_mean_row_sum =
-            (d_y_arr * x_minus_mean).rowwise().sum();
-        const auto inv_var_sqr = inv_var_arr * inv_var_arr;
+        for (int nhw = 0; nhw < N * sample_size; ++nhw) {
+          dy_sum_arr += d_y_arr.col(nhw);
+          dy_mul_x_sub_mean_mul_invstd_sum_arr +=
+              (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
+        }
+
         if (d_scale && d_bias) {
-          for (int nhw = 0; nhw < N * sample_size; ++nhw) {
-            d_bias_arr += d_y_arr.col(nhw);
-            d_scale_arr +=
-                (x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
-          }
+          d_bias_arr = dy_sum_arr;
+          d_scale_arr = dy_mul_x_sub_mean_mul_invstd_sum_arr;
         }
+
         if (!use_global_stats) {
           for (int nhw = 0; nhw < N * sample_size; ++nhw) {
             d_x_arr.col(nhw) +=
                 scale_inv_var_nhw *
-                (d_y_arr.col(nhw) * N * sample_size - d_y_row_sum -
-                 x_minus_mean.col(nhw) * inv_var_sqr *
-                     d_y_mul_x_minus_mean_row_sum);
+                (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr -
+                 (x_arr.col(nhw) - mean_arr) *
+                     dy_mul_x_sub_mean_mul_invstd_sum_arr * inv_var_arr);
           }
         } else {
           for (int nhw = 0; nhw < N * sample_size; ++nhw) {
...
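For readers checking the CPU refactor above: the two per-channel accumulators are the standard reductions of the batch-norm backward pass. A sketch in my own notation (not part of the patch), writing sigma_c^{-1} for the saved inverse standard deviation `inv_var_arr(c)`, m = N * sample_size, and sums over all elements i belonging to channel c:

```latex
s^{(1)}_c = \sum_{i \in c} \frac{\partial L}{\partial y_i}, \qquad
s^{(2)}_c = \sum_{i \in c} \frac{\partial L}{\partial y_i}\,(x_i - \mu_c)\,\sigma_c^{-1},
\qquad
\frac{\partial L}{\partial \beta_c} = s^{(1)}_c, \quad
\frac{\partial L}{\partial \gamma_c} = s^{(2)}_c,
\qquad
\frac{\partial L}{\partial x_i} = \frac{\gamma_c\,\sigma_c^{-1}}{m}
    \Bigl(m\,\frac{\partial L}{\partial y_i} - s^{(1)}_c
          - (x_i - \mu_c)\,\sigma_c^{-1}\,s^{(2)}_c\Bigr)
```

Accumulating s1 and s2 in `dy_sum_arr` and `dy_mul_x_sub_mean_mul_invstd_sum_arr` unconditionally, rather than inside the `if (d_scale && d_bias)` branch, is what lets d_x be computed correctly even when the scale and bias gradients are not requested, which appears to be the failure the OpTest exposed.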
@@ -234,6 +234,63 @@ static __global__ void KeBNBackwardData(const T *dy,
   }
 }
 
+template <typename T, int BlockDim, framework::DataLayout layout>
+static __global__ void BNBackwardData(const T *dy,
+                                      const BatchNormParamType<T> *scale,
+                                      const BatchNormParamType<T> *mean,
+                                      const T *x,
+                                      const BatchNormParamType<T> *variance,
+                                      const int C, const int N, const int HxW,
+                                      T *dx) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage dy_storage;
+  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
+  __shared__ BatchNormParamType<T> dy_sum_val;
+  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;
+  for (int i = blockIdx.x; i < outer_size; i += gridDim.x) {
+    BatchNormParamType<T> inv_var_i = variance[i];
+    BatchNormParamType<T> mean_i = mean[i];
+    BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> dy_x_sub_mean_sum =
+        static_cast<BatchNormParamType<T>>(0);
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      BatchNormParamType<T> dy_i =
+          static_cast<BatchNormParamType<T>>(dy[index]);
+      dy_sum += dy_i;
+      dy_x_sub_mean_sum +=
+          dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_i);
+    }
+    dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
+    dy_x_sub_mean_sum = BlockReduce(dy_x_sub_mean_storage)
+                            .Reduce(dy_x_sub_mean_sum, cub::Sum());
+    if (threadIdx.x == 0) {
+      dy_sum_val = dy_sum;
+      dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
+    }
+    __syncthreads();
+
+    for (int j = threadIdx.x; j < inner_size; j += blockDim.x) {
+      const int index = layout == framework::DataLayout::kNCHW
+                            ? (j / HxW * C + i) * HxW + j % HxW
+                            : j * outer_size + i;
+      dx[index] =
+          (static_cast<BatchNormParamType<T>>(dy[index]) -
+           dy_sum_val / static_cast<BatchNormParamType<T>>(inner_size) -
+           (static_cast<BatchNormParamType<T>>(x[index]) - mean_i) *
+               dy_x_sub_mean_sum_val * inv_var_i * inv_var_i / inner_size) *
+          scale[i] * inv_var_i;
+    }
+  }
+}
+
 template <typename T>
 class BatchNormGradKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -282,6 +339,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     }
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    const int num = x->numel();
+    const int block = 512;
+    int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
+    const int max_blocks = std::max(max_threads / block, 1);
+    int grid1 = (num + block - 1) / block;
+    int grid2 = std::min(C, max_blocks);
+
     if (!use_global_stats) {
       if ((N * H * W * D) == 1) {
         framework::TensorCopy(*d_y, ctx.GetPlace(), d_x);
@@ -325,21 +389,43 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
       const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
-      const void *saved_mean_data =
+      const auto *saved_mean_data =
           saved_mean->template data<BatchNormParamType<T>>();
-      const void *saved_var_data =
+      const auto *saved_var_data =
           saved_var->template data<BatchNormParamType<T>>();
-      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
-          dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
-          CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
-          CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
-          data_desc_, d_y->template data<T>(), data_desc_,
-          d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<BatchNormParamType<T>>(),
-          d_scale->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          d_bias->template mutable_data<BatchNormParamType<T>>(
-              ctx.GetPlace()),
-          epsilon, saved_mean_data, saved_var_data));
+
+      if (d_scale && d_bias) {
+        CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
+            dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
+            CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
+            CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
+            data_desc_, d_y->template data<T>(), data_desc_,
+            d_x->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
+            scale->template data<BatchNormParamType<T>>(),
+            d_scale->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            d_bias->template mutable_data<BatchNormParamType<T>>(
+                ctx.GetPlace()),
+            epsilon, saved_mean_data, saved_var_data));
+      } else {
+        if (data_layout == framework::DataLayout::kNCHW) {
+          if (d_x) {
+            BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
+                grid2, block, 0, dev_ctx.stream()>>>(
+                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
+                d_x->data<T>());
+          }
+        } else {
+          if (d_x) {
+            BNBackwardData<T, block, framework::DataLayout::kNHWC><<<
+                grid2, block, 0, dev_ctx.stream()>>>(
+                d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+                saved_mean_data, x->data<T>(), saved_var_data, C, N, H * W * D,
+                d_x->data<T>());
+          }
+        }
+      }
+
       // clean when exit.
       CUDNN_ENFORCE(
@@ -355,13 +441,6 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
       const auto *running_var_data =
          running_var->template data<BatchNormParamType<T>>();
 
       const int num = x->numel();
-      const int block = 512;
-      int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
-      const int max_blocks = std::max(max_threads / block, 1);
-      int grid1 = (num + block - 1) / block;
-      int grid2 = std::min(C, max_blocks);
-
       if (data_layout == framework::DataLayout::kNCHW) {
         if (d_x) {
           KeBNBackwardData<T, framework::DataLayout::kNCHW><<<
...
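The new BNBackwardData kernel handles the branch where only d_x is requested, fusing the per-channel reduction and the elementwise update into one launch with blocks striding over channels (grid2 = min(C, max_blocks)). As a cross-check, here is a minimal NumPy sketch of the same dx formula for NCHW input; names and shapes are illustrative, not part of the patch, and it assumes the saved statistics are the mini-batch mean and inverse standard deviation:

```python
import numpy as np

def bn_backward_dx(dy, x, scale, mean, inv_std):
    """Reference dx for batch-norm backward, NCHW layout.

    dy, x:   (N, C, H, W) arrays
    scale:   (C,) gamma
    mean:    (C,) saved mini-batch mean
    inv_std: (C,) saved 1 / sqrt(var + eps)
    """
    N, C, H, W = x.shape
    m = N * H * W                 # elements reduced per channel
    axes = (0, 2, 3)              # reduce over everything except C
    bc = (1, C, 1, 1)             # broadcast shape for per-channel vectors
    mu = mean.reshape(bc)
    istd = inv_std.reshape(bc)
    dy_sum = dy.sum(axis=axes).reshape(bc)
    dy_x_sub_mean_sum = (dy * (x - mu)).sum(axis=axes).reshape(bc)
    # dx = gamma * istd / m * (m*dy - sum(dy) - (x-mu)*istd^2 * sum(dy*(x-mu)))
    return (scale.reshape(bc) * istd / m *
            (m * dy - dy_sum - (x - mu) * istd * istd * dy_x_sub_mean_sum))
```

A quick usage check with hypothetical inputs: `dx = bn_backward_dx(dy, x, gamma, mu, 1.0 / np.sqrt(var + 1e-5))` should agree with the kernel output up to floating-point reduction order.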
@@ -170,6 +170,58 @@ class InstanceNormKernel<platform::CUDADeviceContext, T>
   }
 };
 
+template <typename T, int BlockDim>
+static __global__ void GradComputeDX(const T *dy,
+                                     const BatchNormParamType<T> *scale,
+                                     const BatchNormParamType<T> *mean,
+                                     const T *x,
+                                     const BatchNormParamType<T> *variance,
+                                     const int C, const int sample_size,
+                                     T *dx) {
+  int beg_idx = blockIdx.x * sample_size + threadIdx.x;
+  int end_idx = (blockIdx.x + 1) * sample_size;
+  int ncid = blockIdx.x;
+  int c = ncid % C;
+
+  BatchNormParamType<T> mean_val = mean[ncid];
+  BatchNormParamType<T> inv_var_val = variance[ncid];
+
+  typedef cub::BlockReduce<BatchNormParamType<T>, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage dy_storage;
+  __shared__ typename BlockReduce::TempStorage dy_x_sub_mean_storage;
+  __shared__ BatchNormParamType<T> dy_sum_val;
+  __shared__ BatchNormParamType<T> dy_x_sub_mean_sum_val;
+
+  BatchNormParamType<T> dy_sum = static_cast<BatchNormParamType<T>>(0);
+  BatchNormParamType<T> dy_x_sub_mean_sum =
+      static_cast<BatchNormParamType<T>>(0);
+
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    BatchNormParamType<T> dy_i = static_cast<BatchNormParamType<T>>(dy[i]);
+    dy_sum += dy_i;
+    dy_x_sub_mean_sum +=
+        dy_i * (static_cast<BatchNormParamType<T>>(x[i]) - mean_val);
+  }
+
+  dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum());
+  dy_x_sub_mean_sum =
+      BlockReduce(dy_x_sub_mean_storage).Reduce(dy_x_sub_mean_sum, cub::Sum());
+
+  if (threadIdx.x == 0) {
+    dy_sum_val = dy_sum;
+    dy_x_sub_mean_sum_val = dy_x_sub_mean_sum;
+  }
+  __syncthreads();
+
+  for (int i = beg_idx; i < end_idx; i += BlockDim) {
+    dx[i] =
+        (static_cast<BatchNormParamType<T>>(dy[i]) -
+         dy_sum_val / static_cast<BatchNormParamType<T>>(sample_size) -
+         (static_cast<BatchNormParamType<T>>(x[i]) - mean_val) *
+             dy_x_sub_mean_sum_val * inv_var_val * inv_var_val / sample_size) *
+        scale[c] * inv_var_val;
+  }
+}
+
 template <typename T>
 class InstanceNormGradKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
@@ -258,21 +310,31 @@ class InstanceNormGradKernel<platform::CUDADeviceContext, T>
     const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
     const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
-    const void *saved_mean_data =
+    const auto *saved_mean_data =
         saved_mean->template data<BatchNormParamType<T>>();
-    const void *saved_var_data =
+    const auto *saved_var_data =
         saved_var->template data<BatchNormParamType<T>>();
-    CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
-        dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL,
-        CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
-        CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
-        x_tmp.template data<T>(), data_desc_, d_y_tmp.template data<T>(),
-        data_desc_, d_x->template mutable_data<T>(ctx.GetPlace()),
-        in_param_desc_, scale_tmp.template data<BatchNormParamType<T>>(),
-        d_scale_tmp.template mutable_data<BatchNormParamType<T>>(
-            ctx.GetPlace()),
-        d_bias_tmp.template mutable_data<BatchNormParamType<T>>(
-            ctx.GetPlace()),
-        epsilon, saved_mean_data, saved_var_data));
+
+    if (d_scale && d_bias) {
+      CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
+          dev_ctx.cudnn_handle(), CUDNN_BATCHNORM_SPATIAL,
+          CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
+          CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(), data_desc_,
+          x_tmp.template data<T>(), data_desc_, d_y_tmp.template data<T>(),
+          data_desc_, d_x->template mutable_data<T>(ctx.GetPlace()),
+          in_param_desc_, scale_tmp.template data<BatchNormParamType<T>>(),
+          d_scale_tmp.template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          d_bias_tmp.template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          epsilon, saved_mean_data, saved_var_data));
+    } else {
+      if (d_x) {
+        GradComputeDX<T, block><<<grid, block, 0, dev_ctx.stream()>>>(
+            d_y->data<T>(), scale->data<BatchNormParamType<T>>(),
+            saved_mean_data, x->data<T>(), saved_var_data, C, H * W * D,
+            d_x->data<T>());
+      }
+    }
+
     if (d_scale && d_bias) {
       add_param<T, block, false><<<grid1, block, 0, dev_ctx.stream()>>>(
...
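GradComputeDX is the instance-norm analogue of BNBackwardData: one CUDA block per (n, c) instance, with the reductions taken over that instance's sample_size = H*W*D elements and the saved mean/inverse variance indexed by ncid rather than by channel. In the notation of the batch-norm sketch above, only the reduction set changes:

```latex
\frac{\partial L}{\partial x_i} = \frac{\gamma_c\,\sigma_{n,c}^{-1}}{HWD}
    \Bigl(HWD\,\frac{\partial L}{\partial y_i}
          - \sum_{j \in (n,c)} \frac{\partial L}{\partial y_j}
          - (x_i - \mu_{n,c})\,\sigma_{n,c}^{-2}
            \sum_{j \in (n,c)} \frac{\partial L}{\partial y_j}\,(x_j - \mu_{n,c})\Bigr)
```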
@@ -338,14 +338,14 @@ class TestBatchNormOpTraining(unittest.TestCase):
         return y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad
 
     def set_mean_variance(self, scale_shape, x, data_layout):
-        mean = np.zeros(scale_shape).astype(np.float32)
-        variance = np.ones(scale_shape).astype(np.float32)
+        mean, variance = _cal_mean_variance(x, self.epsilon, data_layout)
+        mean_pre = np.zeros(scale_shape).astype(np.float32)
+        variance_pre = np.ones(scale_shape).astype(np.float32)
         # computing global mean/variance for one step
         if self.use_global_stats:
             mom = self.momentum
-            x_mean, x_var = _cal_mean_variance(x, self.epsilon, data_layout)
-            mean = x_mean * (1. - mom) + mom * mean
-            variance = x_var * (1. - mom) + mom * variance
+            mean = mean * (1. - mom) + mom * mean_pre
+            variance = variance * (1. - mom) + mom * variance_pre
         return mean, variance
 
     def test_forward_backward(self):
@@ -442,6 +442,10 @@ class TestBatchNormOpTraining(unittest.TestCase):
                 fetch_list=self.fetch_list)
 
             for id, name in enumerate(self.fetch_list):
+                if name == 'variance':
+                    self.__assert_close(
+                        var_dict[name], out[id], name, atol=1e-3)
+                    continue
                 self.__assert_close(var_dict[name], out[id], name)
             print("op test forward passed: ", str(place), data_layout)
@@ -458,6 +462,13 @@ class TestBatchNormOpTraining(unittest.TestCase):
         pass
 
 
+class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
+    def init_test_case(self):
+        self.use_global_stats = False
+        self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
+        self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
+
+
 class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
     def init_test_case(self):
         self.use_global_stats = True
...
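The reworked set_mean_variance now always derives the current batch statistics via _cal_mean_variance and, when use_global_stats is set, blends them into the previous running statistics with momentum. The helper itself is defined elsewhere in the test file; a plausible sketch of what it computes, offered only as a hypothetical reconstruction, reducing over every axis except the channel axis:

```python
import numpy as np

def _cal_mean_variance(x, epsilon, data_layout):
    # Hypothetical sketch of the test helper: per-channel mini-batch
    # statistics. epsilon is accepted to match the call site, but the
    # raw variance is returned without it.
    axis = (0, 2, 3) if data_layout == "NCHW" else (0, 1, 2)
    mean = np.mean(x, axis=axis)
    variance = np.var(x, axis=axis)
    return mean, variance
```

The new TestBatchNormOpTrainingCase1 then drops scale@GRAD and bias@GRAD from the fetch list via no_grad_set, exercising exactly the d_x-only kernel paths added above.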
@@ -184,5 +184,12 @@ class TestInstanceNormOpTraining(unittest.TestCase):
         test_with_place(place, [2, 3, 4, 5])
 
 
+class TestInstanceNormOpTrainingCase1(TestInstanceNormOpTraining):
+    def init_test_case(self):
+        self.use_global_stats = False
+        self.no_grad_set = set(['scale@GRAD', 'bias@GRAD'])
+        self.fetch_list = ['y', 'saved_mean', 'saved_variance', 'x@GRAD']
+
+
 if __name__ == '__main__':
     unittest.main()