diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 147e7fca3ff19dcdd56cfbf6ccf288df2811e14c..b0e0fe323437d0e29583c37caa3f43e58bf399d0 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -564,19 +564,25 @@ class _BatchNormBase(Layer):
         self._use_global_stats = use_global_stats
 
         if get_default_dtype() == 'float16':
-            set_default_dtype('float32')
+            self._dtype = 'float32'
+        else:
+            self._dtype = get_default_dtype()
 
         param_shape = [num_features]
 
         # create parameter
         if weight_attr == False:
             self.weight = self.create_parameter(
-                attr=None, shape=param_shape, default_initializer=Constant(1.0))
+                attr=None,
+                shape=param_shape,
+                dtype=self._dtype,
+                default_initializer=Constant(1.0))
             self.weight.stop_gradient = True
         else:
             self.weight = self.create_parameter(
                 attr=self._weight_attr,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(1.0))
             self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
 
@@ -584,12 +590,16 @@ class _BatchNormBase(Layer):
             self.bias = self.create_parameter(
                 attr=None,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(0.0),
                 is_bias=True)
             self.bias.stop_gradient = True
         else:
             self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True)
+                attr=self._bias_attr,
+                shape=param_shape,
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.
 
         moving_mean_name = None
@@ -600,6 +610,7 @@ class _BatchNormBase(Layer):
             moving_variance_name = name + "_variance"
 
         self._mean = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_mean_name,
                 initializer=Constant(0.0),
@@ -609,6 +620,7 @@ class _BatchNormBase(Layer):
         self._mean.stop_gradient = True
 
         self._variance = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_variance_name,
                 initializer=Constant(1.0),
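
Not part of the patch: a minimal usage sketch of the behavior this change targets. Previously, constructing a batch norm layer under a float16 default dtype called set_default_dtype('float32'), mutating the process-wide default as a side effect; with the change, the layer keeps a local self._dtype and passes it to create_parameter, so parameters and running statistics are created as float32 without touching the global setting. The use of BatchNorm1D and the expected outputs below are assumptions based on Paddle's public API, not taken from this diff.

    import paddle

    # Assumption: under a float16 default dtype, the layer's parameters and
    # running statistics are still created in float32, and the global default
    # dtype is left unchanged.
    paddle.set_default_dtype('float16')

    bn = paddle.nn.BatchNorm1D(8)       # a _BatchNormBase subclass

    print(bn.weight.dtype)              # expected: float32
    print(bn._variance.dtype)           # expected: float32
    print(paddle.get_default_dtype())   # expected: 'float16' (no longer reset to 'float32')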