提交 648cb508 编写于 作者: C cyber-pioneer

polish log

上级 d83fe716
......@@ -604,17 +604,28 @@ def _lower_composite(block, blacklist=[]):
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
change = True
op_name = op.type
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
orig_outs = expand_nested_list(
get_output_vars_from_comosite(op)
)
new_outs = expand_nested_list(
as_tensors(lower_fn(op, *input_args))
)
assert len(orig_outs) == len(new_outs), (
f'when replace origin op {op_name} with composite rule, num of origin outs should be equal to new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
)
for orig_out, new_out in zip(
expand_nested_list(get_output_vars_from_comosite(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
orig_outs,
new_outs,
):
if new_out is not None:
if orig_out.shape and new_out.shape:
assert orig_out.shape == new_out.shape, (
f'when replace origin op with composite rule, origin out shape should be equal to new out shape, '
f'when replace origin op {op_name} with composite rule, origin out shape should be equal to new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
assert not (orig_out is None) ^ (
......
......@@ -239,6 +239,7 @@ def get_output_vars_from_comosite(op):
for item in op_map[name]["outputs"].keys():
origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names:
# in some cases, some output of origin op is optional, so op name may not be in origin_output_names
continue
origin_output_var = get_var_block(
op.block, op.output(origin_output_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册