未验证 提交 8785537c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Remove unnecessary operations of GroupNorm in eager mode (#47791)

上级 1c6013dd
...@@ -35,7 +35,6 @@ from ...framework import get_default_dtype ...@@ -35,7 +35,6 @@ from ...framework import get_default_dtype
from ..initializer import Constant from ..initializer import Constant
from ...framework import ParamAttr from ...framework import ParamAttr
from ...fluid.data_feeder import check_variable_and_dtype from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import dygraph_utils
from ..functional import batch_norm, layer_norm, instance_norm from ..functional import batch_norm, layer_norm, instance_norm
...@@ -413,15 +412,8 @@ class GroupNorm(Layer): ...@@ -413,15 +412,8 @@ class GroupNorm(Layer):
) )
def forward(self, input): def forward(self, input):
mean_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
variance_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
if in_dygraph_mode(): if in_dygraph_mode():
pre_act = _C_ops.group_norm( return _C_ops.group_norm(
input, input,
self.weight, self.weight,
self.bias, self.bias,
...@@ -430,11 +422,14 @@ class GroupNorm(Layer): ...@@ -430,11 +422,14 @@ class GroupNorm(Layer):
self._data_format, self._data_format,
) )
return dygraph_utils._append_activation_in_dygraph( mean_out = self._helper.create_variable_for_type_inference(
pre_act, act=None dtype=input.dtype, stop_gradient=True
)
variance_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
) )
elif _in_legacy_dygraph(): if _in_legacy_dygraph():
pre_act, _, _ = _legacy_C_ops.group_norm( pre_act, _, _ = _legacy_C_ops.group_norm(
input, input,
self.weight, self.weight,
...@@ -446,9 +441,7 @@ class GroupNorm(Layer): ...@@ -446,9 +441,7 @@ class GroupNorm(Layer):
'groups', 'groups',
self._num_groups, self._num_groups,
) )
return dygraph_utils._append_activation_in_dygraph( return pre_act
pre_act, act=None
)
inputs = {'X': input} inputs = {'X': input}
if self.bias is not None: if self.bias is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册