From 6cf74c257cd0184177fe9c8840ca25858e3939a6 Mon Sep 17 00:00:00 2001 From: zhang wenhui Date: Mon, 7 Sep 2020 15:35:45 +0800 Subject: [PATCH] fix batchnorm ,test=develop (#27036) --- python/paddle/nn/functional/norm.py | 7 +++---- python/paddle/nn/layer/norm.py | 6 +++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index f63fc33525..9e8f365f6d 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -176,14 +176,13 @@ def batch_norm(x, mean_out = running_mean variance_out = running_var - true_data_format = ['NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW'] + true_data_format = ['NC', 'NCL', 'NCHW', 'NCDHW'] if data_format not in true_data_format: raise ValueError( - "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCWH', 'NCDHW', but receive {}". + "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', but receive {}". format(data_format)) - if data_format != 'NCWH': - data_format = 'NCHW' + data_format = 'NCHW' if in_dygraph_mode(): # for dygraph need tuple diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py index 8bdb09c769..d13bf66ba5 100644 --- a/python/paddle/nn/layer/norm.py +++ b/python/paddle/nn/layer/norm.py @@ -811,7 +811,7 @@ class BatchNorm2d(_BatchNormBase): If it is set to None or one attribute of ParamAttr, batch_norm will create ParamAttr as bias_attr. If it is set to Fasle, the weight is not learnable. If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. - data_format(str, optional): Specify the input data format, the data format can be "NCHW" or "NHWC". Default: NCHW. + data_format(str, optional): Specify the input data format, the data format can be "NCHW". Default: NCHW. track_running_stats(bool, optional): Whether to use global mean and variance. In train period, True will track global mean and variance used for inference. When inference, track_running_stats must be True. Default: True. @@ -844,10 +844,10 @@ class BatchNorm2d(_BatchNormBase): """ def _check_data_format(self, input): - if input == 'NCHW' or input == 'NCWH': + if input == 'NCHW': self._data_format = input else: - raise ValueError('expected NCHW or NCWH for data_format input') + raise ValueError('expected NCHW for data_format input') def _check_input_dim(self, input): if len(input.shape) != 4: -- GitLab