提交 4b8b4c71 编写于 作者: C cyber-pioneer

map output from composite rule to origin op

上级 37b33973
......@@ -98,6 +98,6 @@ def composite_batchnorm(
run_mean_ = assign(run_mean)
run_var_ = assign(run_var)
if trainable_statistics or not is_test:
return run_mean_, None, batch_mean_, batch_var_, run_var_, y
return y, run_mean_, run_var_, batch_mean_, batch_var_, None
else:
return run_mean_, batch_mean_, batch_var_, run_var_, y
return y, run_mean_, run_var_, batch_mean_, batch_var_
......@@ -36,6 +36,7 @@ from .utils import (
flatten_and_remove_none,
get_input_var_list,
get_output_var_list,
get_output_vars_from_comosite,
prepare_python_api_arguments,
)
......@@ -605,14 +606,15 @@ def _lower_composite(block, blacklist=[]):
bind(input_args, to_bind, value_table)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(get_output_vars_from_comosite(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
if new_out is not None:
assert orig_out.shape == new_out.shape, (
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
if orig_out.shape and new_out.shape:
assert orig_out.shape == new_out.shape, (
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
......
......@@ -219,6 +219,30 @@ def get_output_var_list(op):
]
def get_output_vars_from_comosite(op):
"""origin op outputs must be mapped into outputs of composite rule."""
origin_output_names = op.output_names
if origin_output_names is None:
return []
else:
name = op.type
res = []
if op_map[name].get("outputs"):
for item in op_map[name]["outputs"].keys():
origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names:
continue
origin_output_var = get_var_block(op.block, op.output(origin_output_name))
res.append(origin_output_var)
elif len(origin_output_names) == 1:
# When origin output num is 1, map info is not needed.
origin_output_var = get_var_block(op.block, op.output(origin_output_names[0]))
res.append(origin_output_var)
else:
raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.")
return res
def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册