未验证 提交 6c89ca21 编写于 作者: C cc 提交者: GitHub

Add output threshold for ops that have several output activations, test=develop (#24726)

上级 dbe24977
......
@@ -1156,14 +1156,13 @@ class OutScaleForTrainingPass(object):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._is_test = graph.is_test()
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._teller_set:
if len(op_node.output_arg_names()) != 1:
continue
in_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
target_ops = []
for op in graph.all_op_nodes():
if op.name() in self._teller_set:
target_ops.append(op)
for op in target_ops:
for output_var_name in _get_op_output_var_names(op):
in_node = graph._find_node_by_name(op.outputs, output_var_name)
out_node = graph.create_var_node_from_desc(in_node.var())
scale_node = graph.create_persistable_node(
name=self._scale_name(in_node.name()),
......
@@ -1263,13 +1262,13 @@ class OutScaleForInferencePass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
ops = graph.all_op_nodes()
for op_node in ops:
name = op_node.name()
if name in self._teller_set:
if len(op_node.output_arg_names()) != 1:
continue
scale_name = self._scale_name(op_node.output_arg_names()[0])
op_nodes = graph.all_op_nodes()
for op_node in op_nodes:
if op_node.name() in self._teller_set:
output_var_name = _get_op_output_var_names(op_node)
assert len(output_var_name) == 1, "Only support collecting " \
"output for op that only has an activation output for now."
scale_name = self._scale_name(output_var_name[0])
scale_v = np.array(
self._scope.find_var(scale_name).get_tensor())[0]
op_node.op()._set_attr("out_threshold", float(scale_v))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册