Unverified commit 417b22d2, authored by Guoxia Wang, committed by GitHub

fix BatchNorm for fp16 (#36376) (#36691)

* fix BatchNorm for fp16
Parent 64643d50
@@ -564,19 +564,25 @@ class _BatchNormBase(Layer):
         self._use_global_stats = use_global_stats
         if get_default_dtype() == 'float16':
-            set_default_dtype('float32')
+            self._dtype = 'float32'
+        else:
+            self._dtype = get_default_dtype()
         param_shape = [num_features]
         # create parameter
         if weight_attr == False:
             self.weight = self.create_parameter(
-                attr=None, shape=param_shape, default_initializer=Constant(1.0))
+                attr=None,
+                shape=param_shape,
+                dtype=self._dtype,
+                default_initializer=Constant(1.0))
             self.weight.stop_gradient = True
         else:
             self.weight = self.create_parameter(
                 attr=self._weight_attr,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(1.0))
             self.weight.stop_gradient = self._weight_attr != None and self._weight_attr.learning_rate == 0.
@@ -584,12 +590,16 @@ class _BatchNormBase(Layer):
             self.bias = self.create_parameter(
                 attr=None,
                 shape=param_shape,
+                dtype=self._dtype,
                 default_initializer=Constant(0.0),
                 is_bias=True)
             self.bias.stop_gradient = True
         else:
             self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True)
+                attr=self._bias_attr,
+                shape=param_shape,
+                dtype=self._dtype,
+                is_bias=True)
             self.bias.stop_gradient = self._bias_attr != None and self._bias_attr.learning_rate == 0.
         moving_mean_name = None
@@ -600,6 +610,7 @@ class _BatchNormBase(Layer):
             moving_variance_name = name + "_variance"
         self._mean = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_mean_name,
                 initializer=Constant(0.0),
@@ -609,6 +620,7 @@ class _BatchNormBase(Layer):
         self._mean.stop_gradient = True
         self._variance = self.create_parameter(
+            dtype=self._dtype,
             attr=ParamAttr(
                 name=moving_variance_name,
                 initializer=Constant(1.0),
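Not part of the commit, but as a minimal sketch of the behavior the diff targets (assuming a Paddle build that includes this fix): when the global default dtype is float16, a BatchNorm layer should create its weight, bias, and moving statistics in float32 via the locally stored self._dtype passed to create_parameter, rather than switching the global default dtype as a side effect.

import paddle
import paddle.nn as nn

# Illustrative check, not taken from the commit itself.
paddle.set_default_dtype('float16')
bn = nn.BatchNorm2D(num_features=8)

print(bn.weight.dtype)             # expected: paddle.float32
print(bn._mean.dtype)              # expected: paddle.float32 (the moving mean shown in the diff)
print(paddle.get_default_dtype())  # expected: 'float16', i.e. left unchanged by constructing the layer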