Unverified commit 666efc23, authored by AshburnLee and committed by GitHub

Call new cudnn batch norm API regardless of data type and data layout (#30157)

Parent commit: 5c8455d6
@@ -114,7 +114,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
                   << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 0)
+#if CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -122,7 +122,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     }
 #else
     mode_ = CUDNN_BATCHNORM_SPATIAL;
-#endif
+#endif  // CUDNN_VERSION_MIN(7, 0, 1)
     VLOG(3) << "Setting descriptors.";
     std::vector<int> dims;
@@ -151,7 +151,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     auto handle = dev_ctx.cudnn_handle();

     // Now, depending on whether we are running test or not, we have two paths.
-    if (test_mode || use_global_stats) {
+    // It is training mode when it's not reference AND not using pre-trained
+    // model.
+    bool training = !test_mode && !use_global_stats;
+    if (!training) {
       // only when test we use input to do computation.
       const auto *est_mean = ctx.Input<Tensor>("Mean");
       const auto *est_var = ctx.Input<Tensor>("Variance");
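
Note on the two paths selected by the new `training` flag: when `training` is false the kernel normalizes with the stored `Mean`/`Variance` estimates, and when it is true it normalizes with statistics of the current mini-batch. The following is a minimal NumPy sketch of that math for an NCHW input; it is an illustration only, not the CUDA/cuDNN code, and the argument names simply mirror the op inputs shown above.

import numpy as np

def batch_norm_forward(x, scale, bias, est_mean, est_var, training, epsilon=1e-5):
    # x: (N, C, H, W); scale, bias, est_mean, est_var: (C,)
    if not training:
        # test mode or use_global_stats: normalize with the running estimates.
        mean, var = est_mean, est_var
    else:
        # training mode: normalize with statistics of the current mini-batch.
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
    c = x.shape[1]
    x_hat = (x - mean.reshape(1, c, 1, 1)) / np.sqrt(var.reshape(1, c, 1, 1) + epsilon)
    return scale.reshape(1, c, 1, 1) * x_hat + bias.reshape(1, c, 1, 1)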
@@ -234,72 +237,70 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
       bool called = false;
 #if CUDNN_VERSION_MIN(7, 4, 1)
-      if (compute_format == DataLayout::kNHWC) {
-        called = true;
-        size_t workspace_size = 0;
-        size_t reserve_space_size = 0;
-        void *reserve_space_ptr = nullptr;
-        void *workspace_ptr = nullptr;
-        Tensor workspace_tensor;
-        // Create reserve space and workspace for batch norm.
-        // Create tensor for each batchnorm op, it will be used in the
-        // backward. Thus this tensor shouldn't be temp.
-        auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
-        PADDLE_ENFORCE_NOT_NULL(
-            reserve_space,
-            platform::errors::NotFound(
-                "The argument ReserveSpace of batch_norm op is not found."));
-
-        // --------------- cudnn batchnorm workspace ---------------
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::
-                cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
-                    /*handle=*/handle,
-                    /*mode=*/mode_,
-                    /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
-                    /*xDesc=*/data_desc_,
-                    /*zDesc=*/nullptr,
-                    /*yDesc=*/data_desc_,
-                    /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
-                    /*activationDesc=*/nullptr,
-                    /*sizeInBytes=*/&workspace_size));
-
-        // -------------- cudnn batchnorm reserve space --------------
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::
-                cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-                    /*handle=*/handle,
-                    /*mode=*/mode_,
-                    /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
-                    /*activationDesc=*/nullptr,
-                    /*xDesc=*/data_desc_,
-                    /*sizeInBytes=*/&reserve_space_size));
-
-        reserve_space_ptr = reserve_space->mutable_data(
-            ctx.GetPlace(), transformed_x.type(), reserve_space_size);
-        workspace_ptr = workspace_tensor.mutable_data(
-            ctx.GetPlace(), transformed_x.type(), workspace_size);
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
-                handle, mode_, CUDNN_BATCHNORM_OPS_BN,
-                CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
-                data_desc_, transformed_x.template data<T>(), nullptr,
-                nullptr, data_desc_, transformed_y.template data<T>(),
-                bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
-                bias->template data<BatchNormParamType<T>>(), this_factor,
-                mean_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                variance_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                epsilon,
-                saved_mean->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                saved_variance->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
-                reserve_space_size));
-      }
-#endif
+      called = true;
+      size_t workspace_size = 0;
+      size_t reserve_space_size = 0;
+      void *reserve_space_ptr = nullptr;
+      void *workspace_ptr = nullptr;
+      Tensor workspace_tensor;
+      // Create reserve space and workspace for batch norm.
+      // Create tensor for each batchnorm op, it will be used in the
+      // backward. Thus this tensor shouldn't be temp.
+      auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
+      PADDLE_ENFORCE_NOT_NULL(
+          reserve_space,
+          platform::errors::NotFound(
+              "The argument ReserveSpace of batch_norm op is not found."));
+
+      // --------------- cudnn batchnorm workspace ---------------
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::
+              cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
+                  /*handle=*/handle,
+                  /*mode=*/mode_,
+                  /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
+                  /*xDesc=*/data_desc_,
+                  /*zDesc=*/nullptr,
+                  /*yDesc=*/data_desc_,
+                  /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
+                  /*activationDesc=*/nullptr,
+                  /*sizeInBytes=*/&workspace_size));
+
+      // -------------- cudnn batchnorm reserve space --------------
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::
+              cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
+                  /*handle=*/handle,
+                  /*mode=*/mode_,
+                  /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
+                  /*activationDesc=*/nullptr,
+                  /*xDesc=*/data_desc_,
+                  /*sizeInBytes=*/&reserve_space_size));
+
+      reserve_space_ptr = reserve_space->mutable_data(
+          ctx.GetPlace(), transformed_x.type(), reserve_space_size);
+      workspace_ptr = workspace_tensor.mutable_data(
+          ctx.GetPlace(), transformed_x.type(), workspace_size);
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
+              handle, mode_, CUDNN_BATCHNORM_OPS_BN, CudnnDataType<T>::kOne(),
+              CudnnDataType<T>::kZero(), data_desc_,
+              transformed_x.template data<T>(), nullptr, nullptr, data_desc_,
+              transformed_y.template data<T>(), bn_param_desc_,
+              scale->template data<BatchNormParamType<T>>(),
+              bias->template data<BatchNormParamType<T>>(), this_factor,
+              mean_out->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              variance_out->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              epsilon,
+              saved_mean->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              saved_variance->template mutable_data<BatchNormParamType<T>>(
+                  ctx.GetPlace()),
+              nullptr, workspace_ptr, workspace_size, reserve_space_ptr,
+              reserve_space_size));
+#endif  // CUDNN_VERSION_MIN(7, 4, 1)
     if (!called) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnBatchNormalizationForwardTraining(
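
For orientation, the training call above writes two kinds of statistics: `mean_out`/`variance_out` hold the running estimates, folded in with the exponential-average factor `this_factor`, while `saved_mean`/`saved_variance` receive the per-batch mean and inverse standard deviation that the backward pass reuses. A rough NumPy sketch of those updates follows; it illustrates the semantics only, assuming cuDNN's exponential-average convention, and `this_factor` is computed elsewhere in the kernel.

import numpy as np

def batch_norm_training_stats(x, running_mean, running_var, this_factor, epsilon=1e-5):
    # x: (N, C, H, W); running_mean, running_var: (C,)
    batch_mean = x.mean(axis=(0, 2, 3))
    batch_var = x.var(axis=(0, 2, 3))
    # Running estimates used later at inference time (MeanOut / VarianceOut).
    mean_out = (1.0 - this_factor) * running_mean + this_factor * batch_mean
    variance_out = (1.0 - this_factor) * running_var + this_factor * batch_var
    # Per-batch statistics kept for the backward pass (SavedMean / SavedVariance);
    # the cuDNN call stores the inverse standard deviation rather than the variance.
    saved_mean = batch_mean
    saved_inv_std = 1.0 / np.sqrt(batch_var + epsilon)
    return mean_out, variance_out, saved_mean, saved_inv_std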
@@ -640,7 +641,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
                   << "CUDNN_BN_MIN_EPSILON instead.";
     }
     epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
-#if CUDNN_VERSION_MIN(7, 0, 0)
+#if CUDNN_VERSION_MIN(7, 0, 1)
     if (FLAGS_cudnn_batchnorm_spatial_persistent) {
       mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
     } else {
@@ -648,7 +649,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
     }
 #else
     mode_ = CUDNN_BATCHNORM_SPATIAL;
-#endif
+#endif  // CUDNN_VERSION_MIN(7, 0, 1)
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
@@ -672,74 +673,73 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
           num, transformed_x.data<T>(), grid2, block, stream);
     }
+    // This branch calls CUDNN APIs
     if (d_scale && d_bias) {
       bool called = false;
 #if CUDNN_VERSION_MIN(7, 4, 1)
-      if (compute_format == DataLayout::kNHWC) {
-        called = true;
-        size_t workspace_size = 0;
-        void *workspace_ptr = nullptr;
-        Tensor workspace_tensor;
-        auto reserve_space_size = reserve_space->memory_size();
-        // --------------- cudnn batchnorm workspace ---------------
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::
-                cudnnGetBatchNormalizationBackwardExWorkspaceSize(
-                    /*handle=*/dev_ctx.cudnn_handle(),
-                    /*mode=*/mode_,
-                    /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
-                    /*xDesc=*/data_desc_,
-                    /*yDesc=*/data_desc_,
-                    /*dyDesc=*/data_desc_,
-                    /*dzDesc=*/nullptr,
-                    /*dxDesc=*/data_desc_,
-                    /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
-                    /*activationDesc=*/nullptr,
-                    /*sizeInBytes=*/&workspace_size));
-
-        workspace_ptr = workspace_tensor.mutable_data(
-            ctx.GetPlace(), transformed_x.type(), workspace_size);
-
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::cudnnBatchNormalizationBackwardEx(
-                /*handle=*/dev_ctx.cudnn_handle(),
-                /*mode=*/mode_,
-                /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
-                /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
-                /*betaDataDiff=*/CudnnDataType<T>::kZero(),
-                /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
-                /*betaParamDiff=*/CudnnDataType<T>::kZero(),
-                /*xDesc=*/data_desc_,
-                /*xData=*/transformed_x.template data<T>(),
-                /*yDesc=*/nullptr,
-                /*yData=*/nullptr,
-                /*dyDesc=*/data_desc_,
-                /*dyData=*/transformed_d_y.template data<T>(),
-                /*dzDesc=*/nullptr,
-                /*dzData=*/nullptr,
-                /*dxDesc=*/data_desc_,
-                /*dxData=*/transformed_d_x.template mutable_data<T>(
-                    ctx.GetPlace()),
-                /*dBnScaleBiasDesc=*/bn_param_desc_,
-                /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
-                /*bnBiasData=*/nullptr,
-                /*dBnScaleData=*/d_scale
-                    ->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace()),
-                /*dBnBiasData=*/d_bias
-                    ->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace()),
-                /*epsilon=*/epsilon,
-                /*savedMean=*/saved_mean_data,
-                /*savedInvVariance=*/saved_var_data,
-                /*activationDesc=*/nullptr,
-                /*workspace=*/workspace_ptr,
-                /*workSpaceSizeInBytes=*/workspace_size,
-                /*reserveSpace=*/const_cast<T *>(
-                    reserve_space->template data<T>()),
-                /*reserveSpaceSizeInBytes=*/reserve_space_size));
-      }
-#endif
+      called = true;
+      size_t workspace_size = 0;
+      void *workspace_ptr = nullptr;
+      Tensor workspace_tensor;
+      auto reserve_space_size = reserve_space->memory_size();
+      // --------------- cudnn batchnorm workspace ---------------
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::
+              cudnnGetBatchNormalizationBackwardExWorkspaceSize(
+                  /*handle=*/dev_ctx.cudnn_handle(),
+                  /*mode=*/mode_,
+                  /*bnIps=*/CUDNN_BATCHNORM_OPS_BN,
+                  /*xDesc=*/data_desc_,
+                  /*yDesc=*/data_desc_,
+                  /*dyDesc=*/data_desc_,
+                  /*dzDesc=*/nullptr,
+                  /*dxDesc=*/data_desc_,
+                  /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
+                  /*activationDesc=*/nullptr,
+                  /*sizeInBytes=*/&workspace_size));
+
+      workspace_ptr = workspace_tensor.mutable_data(
+          ctx.GetPlace(), transformed_x.type(), workspace_size);
+
+      PADDLE_ENFORCE_CUDA_SUCCESS(
+          platform::dynload::cudnnBatchNormalizationBackwardEx(
+              /*handle=*/dev_ctx.cudnn_handle(),
+              /*mode=*/mode_,
+              /*bnOps=*/CUDNN_BATCHNORM_OPS_BN,
+              /*alphaDataDiff=*/CudnnDataType<T>::kOne(),
+              /*betaDataDiff=*/CudnnDataType<T>::kZero(),
+              /*alphaParamDiff=*/CudnnDataType<T>::kOne(),
+              /*betaParamDiff=*/CudnnDataType<T>::kZero(),
+              /*xDesc=*/data_desc_,
+              /*xData=*/transformed_x.template data<T>(),
+              /*yDesc=*/nullptr,
+              /*yData=*/nullptr,
+              /*dyDesc=*/data_desc_,
+              /*dyData=*/transformed_d_y.template data<T>(),
+              /*dzDesc=*/nullptr,
+              /*dzData=*/nullptr,
+              /*dxDesc=*/data_desc_,
+              /*dxData=*/transformed_d_x.template mutable_data<T>(
+                  ctx.GetPlace()),
+              /*dBnScaleBiasDesc=*/bn_param_desc_,
+              /*bnScaleData=*/scale->template data<BatchNormParamType<T>>(),
+              /*bnBiasData=*/nullptr,
+              /*dBnScaleData=*/d_scale
+                  ->template mutable_data<BatchNormParamType<T>>(
+                      ctx.GetPlace()),
+              /*dBnBiasData=*/d_bias
+                  ->template mutable_data<BatchNormParamType<T>>(
+                      ctx.GetPlace()),
+              /*epsilon=*/epsilon,
+              /*savedMean=*/saved_mean_data,
+              /*savedInvVariance=*/saved_var_data,
+              /*activationDesc=*/nullptr,
+              /*workspace=*/workspace_ptr,
+              /*workSpaceSizeInBytes=*/workspace_size,
+              /*reserveSpace=*/const_cast<T *>(
+                  reserve_space->template data<T>()),
+              /*reserveSpaceSizeInBytes=*/reserve_space_size));
+#endif  // CUDNN_VERSION_MIN(7, 4, 1)
     if (!called) {
       PADDLE_ENFORCE_CUDA_SUCCESS(
           platform::dynload::cudnnBatchNormalizationBackward(
@@ -764,6 +764,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
               ctx, &transformed_d_x, d_x);
         }
       } else {
+        // This branch call CUDA kernels
         if (compute_format == DataLayout::kNCHW) {
           if (d_x) {
             BNBackwardData<T, block, framework::DataLayout::kNCHW><<<
...
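
The reason `ReserveSpace` and the saved statistics now have to reach the gradient op (see the grad-maker and Python changes below) is that `cudnnBatchNormalizationBackwardEx` consumes them instead of recomputing the batch statistics. For reference, here is a minimal NumPy sketch of the gradients that training-mode batch norm produces from the saved mean and inverse standard deviation; it illustrates the math only, not the cuDNN kernel, and assumes an NCHW input.

import numpy as np

def batch_norm_backward(d_y, x, scale, saved_mean, saved_inv_std):
    # d_y, x: (N, C, H, W); scale, saved_mean, saved_inv_std: (C,)
    n = x.shape[0] * x.shape[2] * x.shape[3]  # elements per channel
    c = x.shape[1]
    mean = saved_mean.reshape(1, c, 1, 1)
    inv_std = saved_inv_std.reshape(1, c, 1, 1)
    x_hat = (x - mean) * inv_std

    d_bias = d_y.sum(axis=(0, 2, 3))
    d_scale = (d_y * x_hat).sum(axis=(0, 2, 3))
    # dL/dx folds in the dependence of the batch statistics on x.
    d_x = (scale.reshape(1, c, 1, 1) * inv_std / n) * (
        n * d_y
        - d_bias.reshape(1, c, 1, 1)
        - x_hat * d_scale.reshape(1, c, 1, 1))
    return d_x, d_scale, d_bias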
@@ -178,6 +178,9 @@ class InplaceABNOpGradMaker : public framework::SingleGradOpMaker<T> {
     op->SetInput("Bias", this->Input("Bias"));
     op->SetInput("SavedMean", this->Output("SavedMean"));
     op->SetInput("SavedVariance", this->Output("SavedVariance"));
+    if (this->HasOutput("ReserveSpace")) {
+      op->SetInput("ReserveSpace", this->Output("ReserveSpace"));
+    }
     // used when setting use_global_stats True during training
     if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) {
...
@@ -1309,12 +1309,6 @@ class BatchNorm(layers.Layer):
                 dtype=self._dtype)
             self._variance.stop_gradient = True

-        self._has_reserve_space = False
-        if data_layout == 'NHWC':
-            flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
-            if flag is not None and flag.lower() in ['true', '1']:
-                self._has_reserve_space = True
-
         self._in_place = in_place
         self._data_layout = data_layout
         self._momentum = momentum
@@ -1341,7 +1335,6 @@ class BatchNorm(layers.Layer):
             batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
                 input, self.weight, self.bias, self._mean, self._variance,
                 mean_out, variance_out, *attrs)
-
             return dygraph_utils._append_activation_in_dygraph(
                 batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn)
@@ -1371,11 +1364,8 @@ class BatchNorm(layers.Layer):
             dtype=self._dtype, stop_gradient=True)
         saved_variance = self._helper.create_variable_for_type_inference(
             dtype=self._dtype, stop_gradient=True)
-
-        reserve_space = None
-        if self._has_reserve_space:
-            reserve_space = self._helper.create_variable_for_type_inference(
-                dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
+        reserve_space = self._helper.create_variable_for_type_inference(
+            dtype=self._helper.input_dtype(input), stop_gradient=True)
         batch_norm_out = input if self._in_place else self._helper.create_variable_for_type_inference(
             self._dtype)
@@ -1388,7 +1378,7 @@ class BatchNorm(layers.Layer):
             "SavedVariance": [saved_variance]
         }
         if reserve_space is not None:
-            outputs["ReserveSpace"] = reserve_space
+            outputs["ReserveSpace"] = [reserve_space]
         self._helper.append_op(
             type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
...
@@ -2792,12 +2792,6 @@ def batch_norm(input,
                              'batch_norm')
     dtype = helper.input_dtype()

-    has_reserve_space = False
-    if data_layout == 'NHWC':
-        flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
-        if flag is not None and flag.lower() in ['true', '1']:
-            has_reserve_space = True
-
     # use fp32 for bn parameter
     if dtype == core.VarDesc.VarType.FP16:
         dtype = core.VarDesc.VarType.FP32
@@ -2845,17 +2839,16 @@ def batch_norm(input,
     # create output
     # mean and mean_out share the same memory
     mean_out = mean
-    # variance and variance out share the same memory
+    # variance and variance_out share the same memory
     variance_out = variance
     saved_mean = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
     saved_variance = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
     reserve_space = None
-    if has_reserve_space:
+    if not is_test:
         reserve_space = helper.create_variable_for_type_inference(
-            dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
+            dtype=helper.input_dtype(), stop_gradient=True)
     batch_norm_out = input if in_place else \
             helper.create_variable_for_type_inference(dtype)
@@ -2998,12 +2991,6 @@ def inplace_abn(input,
                              'inplace_abn')
     dtype = helper.input_dtype()

-    has_reserve_space = False
-    if data_layout == 'NHWC':
-        flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
-        if flag is not None and flag.lower() in ['true', '1']:
-            has_reserve_space = True
-
     input_shape = input.shape
     if data_layout == 'NCHW':
         channel_num = input_shape[1]
@@ -3053,12 +3040,8 @@ def inplace_abn(input,
         dtype=dtype, stop_gradient=True)
     saved_variance = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
-
-    reserve_space = None
-    if has_reserve_space:
-        reserve_space = helper.create_variable_for_type_inference(
-            dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
+    reserve_space = helper.create_variable_for_type_inference(
+        dtype=dtype, stop_gradient=True)
     batch_norm_out = input
     inputs = {
@@ -3082,7 +3065,6 @@ def inplace_abn(input,
         inputs['MomemtumTensor'] = momentum
     else:
         attrs['momentum'] = momentum
-
     outputs = {
         "Y": batch_norm_out,
         "MeanOut": mean_out,
...
@@ -440,16 +440,8 @@ class TestBatchNormOpTraining(unittest.TestCase):
                     "SavedMean": block.var('saved_mean'),
                     "SavedVariance": block.var('saved_variance')
                 }
-                has_reserve_space = False
-                if data_format == 'NHWC':
-                    flag = os.environ.get(
-                        'FLAGS_cudnn_batchnorm_spatial_persistent')
-                    if flag is not None and flag.lower() in ['true', '1']:
-                        has_reserve_space = True
-                if has_reserve_space:
-                    block.create_var(name="reserve_space", dtype='float16')
-                    outputs["ReserveSpace"] = block.var('reserve_space')
-                    del os.environ['FLAGS_cudnn_batchnorm_spatial_persistent']
+                block.create_var(name="reserve_space", dtype='float32')
+                outputs["ReserveSpace"] = block.var('reserve_space')
                 bn_op = block.append_op(
                     type="batch_norm",
                     inputs=inputs,
...
@@ -122,7 +122,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
            if not only_forward:
                others = [
                    'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
-                    'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
+                    'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD'
                ]
                fetch_names += others
            bn_fetches = exe.run(program=main,
@@ -142,7 +142,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
            if not only_forward:
                others = [
                    'batch_norm_0.tmp_0', 'batch_norm_0.tmp_1', 'bn_scale@GRAD',
-                    'bn_bias@GRAD', 'batch_norm_0.tmp_2@GRAD', 'conv2d_0.tmp_0@GRAD'
+                    'bn_bias@GRAD', 'batch_norm_0.tmp_3@GRAD', 'conv2d_0.tmp_0@GRAD'
                ]
                fetch_names += others
            for nm in fetch_names:
...
@@ -166,7 +166,6 @@ def batch_norm(x,
           batch_norm_out = paddle.nn.functional.batch_norm(x, rm, rv, w, b)
           print(batch_norm_out)
     """
-
     assert len(x.shape) >= 2, "input dim must be larger than 1"
     # input ad out must share the memory
@@ -196,7 +195,6 @@ def batch_norm(x,
         batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
             x, weight, bias, running_mean, running_var, mean_out, variance_out,
             *attrs)
-
         return dygraph_utils._append_activation_in_dygraph(
             batch_norm_out, act=None)
@@ -230,13 +228,16 @@ def batch_norm(x,
     saved_variance = helper.create_variable_for_type_inference(
         dtype=dtype, stop_gradient=True)
     batch_norm_out = helper.create_variable_for_type_inference(dtype)
+    reserve_space = helper.create_variable_for_type_inference(
+        dtype=x.dtype, stop_gradient=True)
     outputs = {
         "Y": [batch_norm_out],
         "MeanOut": [running_mean],
         "VarianceOut": [running_var],
         "SavedMean": [saved_mean],
-        "SavedVariance": [saved_variance]
+        "SavedVariance": [saved_variance],
+        "ReserveSpace": [reserve_space]
     }
     helper.append_op(
...