未验证 提交 dc1202b2 编写于 作者: P PuQing 提交者: GitHub

[Bug Fix] fix parameter not passed in InstanceNorm (#53900)

* fix parameter not passed

* fix repr
上级 20512dee
...@@ -70,10 +70,12 @@ class _InstanceNormBase(Layer): ...@@ -70,10 +70,12 @@ class _InstanceNormBase(Layer):
assert ( assert (
weight_attr == bias_attr weight_attr == bias_attr
), "weight_attr and bias_attr must be set to False at the same time in InstanceNorm" ), "weight_attr and bias_attr must be set to False at the same time in InstanceNorm"
self._momentum = momentum
self._epsilon = epsilon self._epsilon = epsilon
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._bias_attr = bias_attr self._bias_attr = bias_attr
self._num_features = num_features self._num_features = num_features
self._data_format = data_format
if weight_attr is not False and bias_attr is not False: if weight_attr is not False and bias_attr is not False:
self.scale = self.create_parameter( self.scale = self.create_parameter(
...@@ -99,7 +101,12 @@ class _InstanceNormBase(Layer): ...@@ -99,7 +101,12 @@ class _InstanceNormBase(Layer):
self._check_input_dim(input) self._check_input_dim(input)
return instance_norm( return instance_norm(
input, weight=self.scale, bias=self.bias, eps=self._epsilon input,
weight=self.scale,
bias=self.bias,
momentum=self._momentum,
eps=self._epsilon,
data_format=self._data_format,
) )
def extra_repr(self): def extra_repr(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册