Unverified commit 8fd1b6ad, authored by Guoxia Wang and committed by GitHub

fix BatchNorm for fp16 (#36376)

* fix BatchNorm for fp16
Parent: d7064f04
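This commit stops `_BatchNormBase.__init__` from overwriting the process-wide default dtype when it is 'float16'. Instead, the layer records `self._dtype = 'float32'` (or the current default dtype otherwise) and passes it explicitly to every `create_parameter` call, so the scale, bias, and moving statistics are created in float32 even under an fp16 default. A minimal sketch of the intended behavior, not part of the patch, assuming `paddle.nn.BatchNorm2D` as the entry point:

```python
import paddle

# Run with float16 as the global default dtype.
paddle.set_default_dtype('float16')

bn = paddle.nn.BatchNorm2D(8)

# Before this fix, constructing the layer silently reset the global default
# dtype to 'float32'; with this change it is left untouched.
print(paddle.get_default_dtype())  # expected to stay 'float16'

# weight, bias, and the moving statistics are created explicitly in float32.
print(bn.weight.dtype)     # expected: float32
print(bn._variance.dtype)  # expected: float32
```

Keeping these parameters in full precision preserves numerically stable normalization statistics under fp16 training, and avoids the previous side effect of mutating global state from a layer constructor.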
@@ -564,19 +564,25 @@ class _BatchNormBase(Layer):
         self._use_global_stats = use_global_stats
 
         if get_default_dtype() == 'float16':
-            set_default_dtype('float32')
+            self._dtype = 'float32'
+        else:
+            self._dtype = get_default_dtype()
 
         param_shape = [num_features]
 
         # create parameter
         if weight_attr == False:
             self.weight = self.create_parameter(
-                attr=None, shape=param_shape, default_initializer=Constant(1.0))
+                attr=None,
+                shape=param_shape,
+                dtype=self._dtype,
+                default_initializer=Constant(1.0))
             self.weight.stop_gradient = True
         else:
             self.weight = self.create_parameter(
                 attr=self._weight_attr,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(1.0))
             self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
@@ -584,12 +590,16 @@ class _BatchNormBase(Layer):
             self.bias = self.create_parameter(
                 attr=None,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(0.0),
                 is_bias=True)
             self.bias.stop_gradient = True
         else:
             self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True)
+                attr=self._bias_attr,
+                shape=param_shape,
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.
 
         moving_mean_name = None
@@ -600,6 +610,7 @@ class _BatchNormBase(Layer):
             moving_variance_name = name + "_variance"
 
         self._mean = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_mean_name,
                 initializer=Constant(0.0),
@@ -609,6 +620,7 @@ class _BatchNormBase(Layer):
         self._mean.stop_gradient = True
 
         self._variance = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_variance_name,
                 initializer=Constant(1.0),
...