提交 4b8b4c71 编写于 作者: C cyber-pioneer

map output from composite rule to origin op

上级 37b33973
...@@ -98,6 +98,6 @@ def composite_batchnorm( ...@@ -98,6 +98,6 @@ def composite_batchnorm(
run_mean_ = assign(run_mean) run_mean_ = assign(run_mean)
run_var_ = assign(run_var) run_var_ = assign(run_var)
if trainable_statistics or not is_test: if trainable_statistics or not is_test:
return run_mean_, None, batch_mean_, batch_var_, run_var_, y return y, run_mean_, run_var_, batch_mean_, batch_var_, None
else: else:
return run_mean_, batch_mean_, batch_var_, run_var_, y return y, run_mean_, run_var_, batch_mean_, batch_var_
...@@ -36,6 +36,7 @@ from .utils import ( ...@@ -36,6 +36,7 @@ from .utils import (
flatten_and_remove_none, flatten_and_remove_none,
get_input_var_list, get_input_var_list,
get_output_var_list, get_output_var_list,
get_output_vars_from_comosite,
prepare_python_api_arguments, prepare_python_api_arguments,
) )
...@@ -605,14 +606,15 @@ def _lower_composite(block, blacklist=[]): ...@@ -605,14 +606,15 @@ def _lower_composite(block, blacklist=[]):
bind(input_args, to_bind, value_table) bind(input_args, to_bind, value_table)
for orig_out, new_out in zip( for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)), expand_nested_list(get_output_vars_from_comosite(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))), expand_nested_list(as_tensors(lower_fn(op, *input_args))),
): ):
if new_out is not None: if new_out is not None:
assert orig_out.shape == new_out.shape, ( if orig_out.shape and new_out.shape:
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' assert orig_out.shape == new_out.shape, (
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' f'when replace origin op with composite rule, origin out shape should be equal to new out shape, '
) f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
assert not (orig_out is None) ^ ( assert not (orig_out is None) ^ (
new_out is None new_out is None
), "orig_out and new_out should match." ), "orig_out and new_out should match."
......
...@@ -219,6 +219,30 @@ def get_output_var_list(op): ...@@ -219,6 +219,30 @@ def get_output_var_list(op):
] ]
def get_output_vars_from_comosite(op):
"""origin op outputs must be mapped into outputs of composite rule."""
origin_output_names = op.output_names
if origin_output_names is None:
return []
else:
name = op.type
res = []
if op_map[name].get("outputs"):
for item in op_map[name]["outputs"].keys():
origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names:
continue
origin_output_var = get_var_block(op.block, op.output(origin_output_name))
res.append(origin_output_var)
elif len(origin_output_names) == 1:
# When origin output num is 1, map info is not needed.
origin_output_var = get_var_block(op.block, op.output(origin_output_names[0]))
res.append(origin_output_var)
else:
raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.")
return res
def flatten(inp): def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable): if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp] return [inp]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册