提交 648cb508 编写于 作者: C cyber-pioneer

polish log

上级 d83fe716
...@@ -604,17 +604,28 @@ def _lower_composite(block, blacklist=[]): ...@@ -604,17 +604,28 @@ def _lower_composite(block, blacklist=[]):
ops_to_remove.append(op_idx) ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist: if lookup_fn(op.type) is not None and op.type not in blacklist:
change = True change = True
op_name = op.type
input_args = prepare_python_api_arguments(op) input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table) bind(input_args, to_bind, value_table)
orig_outs = expand_nested_list(
get_output_vars_from_comosite(op)
)
new_outs = expand_nested_list(
as_tensors(lower_fn(op, *input_args))
)
assert len(orig_outs) == len(new_outs), (
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
)
for orig_out, new_out in zip( for orig_out, new_out in zip(
expand_nested_list(get_output_vars_from_comosite(op)), orig_outs,
expand_nested_list(as_tensors(lower_fn(op, *input_args))), new_outs,
): ):
if new_out is not None: if new_out is not None:
if orig_out.shape and new_out.shape: if orig_out.shape and new_out.shape:
assert orig_out.shape == new_out.shape, ( assert orig_out.shape == new_out.shape, (
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, ' f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}' f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
) )
assert not (orig_out is None) ^ ( assert not (orig_out is None) ^ (
......
...@@ -239,6 +239,7 @@ def get_output_vars_from_comosite(op): ...@@ -239,6 +239,7 @@ def get_output_vars_from_comosite(op):
for item in op_map[name]["outputs"].keys(): for item in op_map[name]["outputs"].keys():
origin_output_name = op_map[name]["outputs"][item] origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names: if origin_output_name not in origin_output_names:
# in some cases, some output of origin op is optional, so op name may not be in origin_output_names
continue continue
origin_output_var = get_var_block( origin_output_var = get_var_block(
op.block, op.output(origin_output_name) op.block, op.output(origin_output_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册