未验证 提交 89f1cdab 编写于 作者: C ceci3 提交者: GitHub

update shortcut (#1173)

上级 3ae057ac
...@@ -97,6 +97,12 @@ def _find_next_target_op(op, graph, target_op_idx, sc_path): ...@@ -97,6 +97,12 @@ def _find_next_target_op(op, graph, target_op_idx, sc_path):
return False return False
def _is_identity_op(op):
if op.type() == 'scale' and op.attr('scale') == 1:
return True
return False
def is_shortcut(op, graph, sc_path, shortcut_start_op): def is_shortcut(op, graph, sc_path, shortcut_start_op):
""" """
op /```````````````````\\ add op /```````````````````\\ add
...@@ -105,12 +111,16 @@ def is_shortcut(op, graph, sc_path, shortcut_start_op): ...@@ -105,12 +111,16 @@ def is_shortcut(op, graph, sc_path, shortcut_start_op):
inps = op.all_inputs() inps = op.all_inputs()
pre_ops = graph.pre_ops(op) pre_ops = graph.pre_ops(op)
for p_op in pre_ops: for p_op in pre_ops:
if _is_identity_op(p_op):
p_op = graph.pre_ops(p_op)[0]
n_ops = graph.next_ops(p_op) n_ops = graph.next_ops(p_op)
if len(n_ops) == 1: if len(n_ops) == 1:
continue continue
### note: only support one branch donnot have op ### note: only support one branch donnot have op or has one scale op
has_sc = False has_sc = False
for n_op in n_ops: for n_op in n_ops:
if _is_identity_op(n_op):
n_op = graph.next_ops(n_op)[0]
if n_op.idx() == op.idx(): if n_op.idx() == op.idx():
shortcut_start_op.append(p_op) shortcut_start_op.append(p_op)
has_sc = True has_sc = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册