From 6c89ca2157ec9af94a3dd76e793694ad82e2a098 Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Tue, 26 May 2020 12:52:14 +0800
Subject: [PATCH] Add output threshold for ops that have several output
 activations, test=develop (#24726)

---
 .../slim/quantization/quantization_pass.py    | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index a6ab2aa86d0..44a872bd762 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -1156,14 +1156,13 @@ class OutScaleForTrainingPass(object):
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
         self._is_test = graph.is_test()
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                in_node = graph._find_node_by_name(
-                    op_node.outputs, op_node.output_arg_names()[0])
+        target_ops = []
+        for op in graph.all_op_nodes():
+            if op.name() in self._teller_set:
+                target_ops.append(op)
+        for op in target_ops:
+            for output_var_name in _get_op_output_var_names(op):
+                in_node = graph._find_node_by_name(op.outputs, output_var_name)
                 out_node = graph.create_var_node_from_desc(in_node.var())
                 scale_node = graph.create_persistable_node(
                     name=self._scale_name(in_node.name()),
@@ -1263,13 +1262,13 @@ class OutScaleForInferencePass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        ops = graph.all_op_nodes()
-        for op_node in ops:
-            name = op_node.name()
-            if name in self._teller_set:
-                if len(op_node.output_arg_names()) != 1:
-                    continue
-                scale_name = self._scale_name(op_node.output_arg_names()[0])
+        op_nodes = graph.all_op_nodes()
+        for op_node in op_nodes:
+            if op_node.name() in self._teller_set:
+                output_var_name = _get_op_output_var_names(op_node)
+                assert len(output_var_name) == 1, "Only support collecting " \
+                    "output for op that only has an activation output for now."
+                scale_name = self._scale_name(output_var_name[0])
                 scale_v = np.array(
                     self._scope.find_var(scale_name).get_tensor())[0]
                 op_node.op()._set_attr("out_threshold", float(scale_v))
-- 
GitLab
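
Note (not part of the patch): after OutScaleForInferencePass runs and the graph
is converted back to a program, each matched op carries an "out_threshold"
attribute holding the recorded output scale. Below is a minimal sketch of how
one might read those attributes back for inspection, assuming a fluid.Program
named `program` that has already been through the pass; the helper name
`collect_out_thresholds` is hypothetical, not part of the Paddle API.

    def collect_out_thresholds(program):
        """Map op type -> list of "out_threshold" values found in program.

        Sketch only: walks every block and op, collecting the attribute
        that OutScaleForInferencePass attaches via _set_attr.
        """
        thresholds = {}
        for block in program.blocks:
            for op in block.ops:
                if op.has_attr("out_threshold"):
                    thresholds.setdefault(op.type, []).append(
                        op.attr("out_threshold"))
        return thresholds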