未验证 提交 79e75bc5 编写于 作者: L Li Min 提交者: GitHub

remove the initialization of saved_mean and saved_variance for batch_norm op (#33851)

上级 24783c84
...@@ -382,8 +382,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -382,8 +382,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
} }
// Run training mode. // Run training mode.
// obtain running mean and running inv var, and see if we need to // obtain running mean and running inv var, and there is no need
// initialize them. // to initialize them.
auto *mean_out = ctx.Output<Tensor>("MeanOut"); auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut"); auto *variance_out = ctx.Output<Tensor>("VarianceOut");
...@@ -394,10 +394,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -394,10 +394,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
auto *saved_variance = ctx.Output<Tensor>("SavedVariance"); auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace()); saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace()); saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
functor;
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
if ((N * H * W * D) == 1) { if ((N * H * W * D) == 1) {
// Only 1 element in normalization dimension, // Only 1 element in normalization dimension,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册