未验证 提交 7f0ba045 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Fix none var added with op error (#55133)

* fix prim add fill_any_like bug

* polish code
上级 4d1b9f04
......@@ -2500,7 +2500,21 @@ def calc_gradient_helper(
for op in reversed(block.ops):
if op.type == "fill_any_like":
continue
# Some outputs of composite op are not needed and will be removed.
# Thus, those vars should not be added with another op.
keep_var_list = []
if op.type in core.ops_contain_none.keys():
values = core.ops_contain_none[op.type]
if isinstance(values, list):
none_vars = values
else:
none_vars = values(op)
for none_var_name in none_vars:
keep_var_list.append(op.output(none_var_name)[0])
for var_name in op.desc.output_arg_names():
if keep_var_list and (var_name in keep_var_list):
continue
grad_var_name = _append_grad_suffix_(var_name)
if grad_var_name not in grad_name_set:
op_desc = _create_op_desc_(
......
......@@ -465,6 +465,28 @@ def _test_use_sync(value):
prim_config = {"forward_blacklist": set(), "composite_ops_record": set()}
def _get_batch_norm_none_var(op):
"""Some outputs of batch_norm's replaced composite rule are not needed and will be removed."""
use_run_stat = (
op.attr("is_test") and (not op.attr("trainable_statistics"))
) or op.attr("use_global_stats")
if use_run_stat:
return ["ReserveSpace", "SavedMean", "SavedVariance"]
else:
return ["ReserveSpace"]
# In some cases, inputs or outputs of a composite op (or of its replaced
# composite rule) may be None, meaning the argument is no longer required in
# the program processed by the composite mechanism. Such special ops are
# recorded here in advance so their unused args can be released during the
# args check.
# Each value is either a static list of output arg names, or a callable
# taking the op and returning the list (when the unused outputs depend on
# the op's attributes, as for batch_norm).
ops_contain_none = {
    "batch_norm": _get_batch_norm_none_var,
    "flatten_contiguous_range": ["XShape"],
    "squeeze2": ["XShape"],
    "unsqueeze2": ["XShape"],
}
def _set_prim_forward_blacklist(ops=None):
if ops is None:
prim_config["forward_blacklist"] = []
......
......@@ -18,7 +18,7 @@ from collections import OrderedDict
import paddle
from paddle.fluid import framework
from paddle.fluid.core import prim_config
from paddle.fluid.core import ops_contain_none, prim_config
from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors
......@@ -549,17 +549,6 @@ def _lower(block, reverse, blacklist):
block._sync_with_cpp()
# In some case, inputs and outputs of composite op or its replaced composite rule might be None.
# It means such arg will be no longer required in processed program by composite mechanism.
# Therefore, such special ops should be recorded in advance and be released in args check.
ops_contain_none = (
"batch_norm",
"flatten_contiguous_range",
"squeeze2",
"unsqueeze2",
)
def _lower_composite(
block,
filter_: typing.Callable[[framework.Operator], bool] = lambda x: True,
......
......@@ -489,7 +489,7 @@ class TestPrimEvalBranch(unittest.TestCase):
def train(self, use_prim):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
net = BatchNorm(2, act="relu", is_test=True)
net = BatchNorm(2, is_test=True)
net = apply_to_static(net, False)
out = net(self.x)
loss = paddle.mean(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册