未验证 提交 8785537c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Remove unnecessary operations of GroupNorm in eager mode (#47791)

上级 1c6013dd
......@@ -35,7 +35,6 @@ from ...framework import get_default_dtype
from ..initializer import Constant
from ...framework import ParamAttr
from ...fluid.data_feeder import check_variable_and_dtype
from ...fluid import dygraph_utils
from ..functional import batch_norm, layer_norm, instance_norm
......@@ -413,15 +412,8 @@ class GroupNorm(Layer):
)
def forward(self, input):
mean_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
variance_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
if in_dygraph_mode():
pre_act = _C_ops.group_norm(
return _C_ops.group_norm(
input,
self.weight,
self.bias,
......@@ -430,11 +422,14 @@ class GroupNorm(Layer):
self._data_format,
)
return dygraph_utils._append_activation_in_dygraph(
pre_act, act=None
)
mean_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
variance_out = self._helper.create_variable_for_type_inference(
dtype=input.dtype, stop_gradient=True
)
elif _in_legacy_dygraph():
if _in_legacy_dygraph():
pre_act, _, _ = _legacy_C_ops.group_norm(
input,
self.weight,
......@@ -446,9 +441,7 @@ class GroupNorm(Layer):
'groups',
self._num_groups,
)
return dygraph_utils._append_activation_in_dygraph(
pre_act, act=None
)
return pre_act
inputs = {'X': input}
if self.bias is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册