未验证 提交 8daccc9e 编写于 作者: C ceci3 提交者: GitHub

Fix batch norm double grad compute (#27549)

* fix bn double grad, test=develop

* update, test=develop
上级 c143326d
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/update_ps_heter github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chenwhql/saveload/remove_save_load_config github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/iclementine/rnn_fix github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/mapingshuo/doc_2.0 github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0925 github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wzzju/fix_err_info github/fork/xiemoyuan/op_error_message github/fork/yaoxuefeng6/fix_doc github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/is_compile_with_cuda github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhiqiu/dev/refine_initializer improve_sccache incubate/infrt inplace_addto make_flag_adding_easier move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/2.0 release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0
无相关合并请求
......@@ -839,6 +839,7 @@ void BatchNormDoubleGradMaker<T>::Apply(GradOpPtr<T> op) const {
op->SetInput("SavedMean", this->Input("SavedMean"));
op->SetInput("SavedVariance", this->Input("SavedVariance"));
if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) {
op->SetInput("Mean", this->Input("Mean"));
op->SetInput("Variance", this->Input("Variance"));
}
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
......@@ -868,14 +869,19 @@ void BatchNormDoubleGradOp::InferShape(
"BatchNormDoubleGrad");
}
OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX", "BatchNormDoubleGrad");
OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad");
// check output
OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad");
const auto x_dims = ctx->GetInputDim("X");
const int C = x_dims[1];
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
const int C =
((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
? x_dims[1]
: x_dims[x_dims.size() - 1]);
if (ctx->HasOutput("DX")) {
ctx->SetOutputDim("DX", x_dims);
}
......@@ -957,7 +963,9 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
Tensor inv_var_tensor;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_variance = ctx.Input<Tensor>("Variance");
mean_data = running_mean->data<T>();
inv_var_tensor.Resize({C});
T *running_inv_var_data = inv_var_tensor.mutable_data<T>(ctx.GetPlace());
......@@ -1077,12 +1085,12 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w))))
// axis=(n,h,w)))
if (ddX) {
dx_arr +=
......@@ -1176,7 +1184,8 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
C, sample_size);
ddy_arr.setZero();
if (use_global_stats) {
// math: ddy = r * ddx * inv_var
// math: ddy = r * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
if (ddX) {
ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data;
}
......@@ -1196,25 +1205,29 @@ class BatchNormDoubleGradKernel<platform::CPUDeviceContext, T>
.replicate(1, sample_size) /
sample_size);
}
if (ddScale && ddBias) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
}
if (ddScale) {
ConstEigenVectorArrayMap<T> ddscale_arr(ddScale->data<T>(), C);
Tensor ddscale_tile;
ddscale_tile.Resize({C, sample_size});
EigenArrayMap<T> ddscale_tile_data(
ddscale_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddscale_tile_data = ddscale_arr.replicate(1, sample_size);
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
}
ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
Tensor ddbias_tile;
ddbias_tile.Resize({C, sample_size});
EigenArrayMap<T> ddbias_tile_data(
ddbias_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
if (ddBias) {
ConstEigenVectorArrayMap<T> ddbias_arr(ddBias->data<T>(), C);
Tensor ddbias_tile;
ddbias_tile.Resize({C, sample_size});
EigenArrayMap<T> ddbias_tile_data(
ddbias_tile.mutable_data<T>(ctx.GetPlace()), C, sample_size);
ddbias_tile_data = ddbias_arr.replicate(1, sample_size);
ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data;
ddy_arr += ddbias_tile_data;
}
ddy_arr += ddbias_tile_data;
}
if (data_layout == DataLayout::kNCHW) {
VLOG(3) << "Transform batchnorm output from NHWC to NCHW";
TransToChannelFirst<paddle::platform::CPUDeviceContext, T>(
......
......@@ -520,11 +520,11 @@ class InstanceNormDoubleGradKernel<platform::CPUDeviceContext, T>
// (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW *
// np.sum(dy,
// axis=(h,w)) * (x - mean) *
// (np.mean(ddx, axis=(h,w)) - ddx) + ddr * (dy * inv_var - inv_var
// *
// (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var *
// np.mean(dy, axis=(h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(h,w))))
// axis=(h,w)))
Tensor x_sub_mean_mul_invstd;
x_sub_mean_mul_invstd.Resize({sample_size, NxC});
......
......@@ -40,12 +40,12 @@ using DataLayout = framework::DataLayout;
// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW *
// np.sum(dy,
// axis=(n,h,w)) * (x - mean) *
// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var -
// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var -
// inv_var
// *
// np.mean(dy, axis=(n,h,w)) -
// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean),
// axis=(n,h,w))))
// axis=(n,h,w)))
template <typename T, int BlockDim, framework::DataLayout layout>
__global__ void DoubleGradComputeDX(const T *x, const T *mean,
......@@ -138,7 +138,7 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean,
? (j / sample_size * C + i) * sample_size + j % sample_size
: j * outer_size + i;
dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val -
(x[index] - mean_val) * var_val *
(x[index] - mean_val) * var_val * var_val *
dy_mul_x_sub_mean_sum_val * var_val / inner_size) *
ddscale[i];
}
......@@ -326,19 +326,57 @@ __global__ void DoubleGradComputeDScaleWithGlobal(
}
// math: dx = ddscale * dy * inv_var
// math: ddy = scale * ddx * inv_var
template <typename T, framework::DataLayout layout>
__global__ void DoubleGradComputeDataWithGlobal(
const T *dy, const T *scale, const T *variance, const double epsilon,
const int C, const int sample_size, const int num, T *dx) {
__global__ void DoubleGradComputeDXWithGlobal(const T *dy, const T *ddscale,
const T *variance,
const double epsilon, const int C,
const int sample_size,
const int num, T *dx) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (scale != nullptr) {
if (ddscale != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
dx[i] = dy[i] * scale[c] * inv_var;
dx[i] = dy[i] * ddscale[c] * inv_var;
}
}
}
// math: ddy = scale * ddx * inv_var + ddbias +
// ddscale * (x - mean) * inv_var
template <typename T, framework::DataLayout layout>
__global__ void DoubleGradComputeDDYWithGlobal(
const T *ddx, const T *scale, const T *mean, const T *variance, const T *x,
const T *ddbias, const T *ddscale, const double epsilon, const int C,
const int sample_size, const int num, T *ddy) {
int gid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (ddx != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
ddy[i] += ddx[i] * scale[c] * inv_var;
}
}
__syncthreads();
if (ddscale != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
T inv_var = 1.0 / sqrt(variance[c] + epsilon);
ddy[i] += (x[i] - mean[c]) * inv_var * ddscale[c];
}
}
__syncthreads();
if (ddbias != nullptr) {
for (int i = gid; i < num; i += stride) {
const int c =
layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C;
ddy[i] += ddbias[c];
}
}
}
......@@ -383,8 +421,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
const T *mean_data, *variance_data;
if (use_global_stats) {
const auto *running_mean = ctx.Input<Tensor>("Mean");
const auto *running_var = ctx.Input<Tensor>("Variance");
const auto *running_mean_data = running_mean->template data<T>();
const auto *running_var_data = running_var->template data<T>();
mean_data = running_mean_data;
variance_data = running_var_data;
} else {
const T *smean_data = Saved_mean->data<T>();
......@@ -398,12 +439,12 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant(dev_ctx, dX, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDataWithGlobal<
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
} else {
DoubleGradComputeDataWithGlobal<
DoubleGradComputeDXWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num,
dx_data);
......@@ -456,15 +497,15 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx,
set_constant(dev_ctx, ddY, static_cast<T>(0));
if (use_global_stats) {
if (data_layout == DataLayout::kNHWC) {
DoubleGradComputeDataWithGlobal<
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNHWC><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
ddy_data);
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
} else {
DoubleGradComputeDataWithGlobal<
DoubleGradComputeDDYWithGlobal<
T, DataLayout::kNCHW><<<grid1, block, 0, dev_ctx.stream()>>>(
ddx_data, scale_data, variance_data, epsilon, C, sample_size, num,
ddy_data);
ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data,
ddscale_data, epsilon, C, sample_size, num, ddy_data);
}
} else {
if (data_layout == DataLayout::kNHWC) {
......
......@@ -130,5 +130,41 @@ class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck):
self.shape = [2, 2, 3, 4, 5]
class TestBatchNormDoubleGradCheckCase5(TestBatchNormDoubleGradCheck):
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
np.random.seed()
dtype = "float32"
eps = 0.005
atol = 2e-4
chn = self.shape[1] if self.data_layout == 'NCHW' else self.shape[
-1]
x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x')
z = fluid.layers.batch_norm(
input=x,
data_layout=self.data_layout,
use_global_stats=self.use_global_stats)
x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype)
w, b = prog.global_block().all_parameters()[1:3]
w_arr = np.ones(chn).astype(dtype)
b_arr = np.zeros(chn).astype(dtype)
gradient_checker.double_grad_check(
[x, w, b],
z,
x_init=[x_arr, w_arr, b_arr],
atol=atol,
place=place,
eps=eps)
class TestBatchNormDoubleGradCheckCase6(TestBatchNormDoubleGradCheckCase5):
def init_test(self):
self.data_layout = 'NCHW'
self.use_global_stats = True
self.shape = [2, 3, 4, 5]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部