From 7c8f7df2fe3922c0a492522d890e47fb5af34cb7 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 21 Feb 2019 15:02:52 +0800 Subject: [PATCH] add some op_des funs to IrOpNode and add some var_des funs to IrVarNode. test=develop --- .../slim/quantization/quantization_pass.py | 54 +++++++------- python/paddle/fluid/framework.py | 72 +++++++++++++++++++ 2 files changed, 99 insertions(+), 27 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py index 5764d9d94..622add484 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py @@ -231,14 +231,14 @@ class QuantizationTransformPass(object): quant_var_node = graph.create_var_node( name=self._quantized_var_name(var_node.name()), - var_type=var_node.var().type(), - shape=var_node.var().shape(), - var_dtype=var_node.var().dtype()) + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) scale_var_node = graph.create_var_node( name=self._quantized_scale_name(var_node.name()), - var_type=var_node.var().type(), - shape=var_node.var().shape(), - var_dtype=var_node.var().dtype()) + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) quant_op_node = graph.create_op_node( op_type='fake_quantize_abs_max', attrs={ @@ -261,15 +261,15 @@ class QuantizationTransformPass(object): quant_var_node = graph.create_var_node( name=self._quantized_var_name(var_node.name()), - var_type=var_node.var().type(), - shape=var_node.var().shape(), - var_dtype=var_node.var().dtype()) + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) scale_in_node = graph.create_persistable_node( name=self._quantized_scale_name(var_node.name()), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[1], - var_dtype=var_node.var().dtype()) + var_dtype=var_node.dtype()) self._need_initialized[scale_in_node.var()] = Constant(value=0.001) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) @@ -282,7 +282,7 @@ class QuantizationTransformPass(object): name=unique_name.generate('scales'), var_type=core.VarDesc.VarType.LOD_TENSOR, shape=[self._window_size], - var_dtype=var_node.var().dtype()) + var_dtype=var_node.dtype()) self._need_initialized[scales_node.var()] = Constant(value=0) inputs['Iter'] = self._global_step outputs['OutScales'] = scales_node @@ -317,9 +317,9 @@ class QuantizationTransformPass(object): dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(var_node.name()), - var_type=var_node.var().type(), - shape=var_node.var().shape(), - var_dtype=var_node.var().dtype()) + var_type=var_node.type(), + shape=var_node.shape(), + var_dtype=var_node.dtype()) max_range = (1 << (quant_bits - 1)) - 1 dequant_op_node = graph.create_op_node( op_type='fake_dequantize_max_abs', @@ -408,17 +408,17 @@ class QuantizationFreezePass(object): for op_node in ops: op_name = op_node.name() if op_name in self._fake_quant_op_names: - input_arg_name = op_node.op().input('X')[0] + input_arg_name = op_node.input('X')[0] if input_arg_name in persistable_vars: if self._weight_quantize_type == 'abs_max': param = self._load_var(input_arg_name) scale_v = np.max(np.abs(param)) else: - scale_v = self._load_var(op_node.op().output('OutScale') - [0])[0] + scale_v = self._load_var( + op_node.output('OutScale')[0])[0] self._var_scale_map[input_arg_name] = scale_v else: - scale_v = graph.var_node(op_node.op().output('OutScale')[0]) + scale_v = graph.var_node(op_node.output('OutScale')[0]) self._var_scale_map[input_arg_name] = scale_v if input_arg_name in persistable_vars: self._remove_fake_quant_and_dequant_op(graph, op_node) @@ -454,8 +454,8 @@ class QuantizationFreezePass(object): return graph def _remove_fake_quant_and_dequant_op(self, graph, op_node): - k = op_node.op().output('Out')[0] - v = op_node.op().input('X')[0] + k = op_node.output('Out')[0] + v = op_node.input('X')[0] if v not in self._op_input_rename_map: self._op_input_rename_map[k] = v else: @@ -493,9 +493,9 @@ class QuantizationFreezePass(object): output_var_node = op_node.outputs[0] dequant_var_node = graph.create_var_node( name=self._dequantized_var_name(output_var_node.name()), - var_type=output_var_node.var().type(), - shape=output_var_node.var().shape(), - var_dtype=output_var_node.var().dtype()) + var_type=output_var_node.type(), + shape=output_var_node.shape(), + var_dtype=output_var_node.dtype()) dequant_op_node = graph.create_op_node( op_type='fake_dequantize_max_abs', attrs={ @@ -615,8 +615,8 @@ class ConvertToInt8Pass(object): int8_var_node_name = var_node.name() + ".int8" int8_var_node = graph.create_persistable_node( name=cpt.to_text(int8_var_node_name), - var_type=var_node.var().type(), - shape=var_node.var().shape(), + var_type=var_node.type(), + shape=var_node.shape(), var_dtype=core.VarDesc.VarType.INT8) array = self._load_var(var_node.name()) self._scope.var(int8_var_node_name) @@ -672,7 +672,7 @@ class TransformForMobilePass(object): for op_node in ops: name = op_node.name() if name in self._fake_quant_op_names: - op_node.op().set_type('quantize') + op_node.set_type('quantize') quant_node = graph.create_op_node_from_desc(op_node.op()) for input_node in op_node.inputs: graph.link_to(input_node, quant_node) @@ -680,7 +680,7 @@ class TransformForMobilePass(object): graph.link_to(quant_node, output_node) graph.safe_remove_nodes(op_node) if name in self._fake_dequant_op_names: - op_node.op().set_type('dequantize') + op_node.set_type('dequantize') dequant_node = graph.create_op_node_from_desc(op_node.op()) for input_node in op_node.inputs: graph.link_to(input_node, dequant_node) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 70c100d9e..8c62d2f28 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1754,6 +1754,39 @@ class IrVarNode(IrNode): "The node variable description cannot be None." return self.node.var().persistable() + def type(self): + """ + Return the variable type. + + Returns: + core.VarDesc.VarType: the variable type. + """ + assert self.node.var() is not None, \ + "The node variable description cannot be None." + return self.node.var().type() + + def dtype(self): + """ + Return the variable data type. + + Returns: + core.VarDesc.VarType: the variable data type. + """ + assert self.node.var() is not None, \ + "The node variable description cannot be None." + return self.node.var().dtype() + + def shape(self): + """ + Return the variable shape. + + Returns: + list: the variable shape. + """ + assert self.node.var() is not None, \ + "The node variable description cannot be None." + return self.node.var().shape() + @property def inputs(self): """ @@ -1804,6 +1837,45 @@ class IrOpNode(IrNode): "The node operator description cannot be None." self.node.op()._rename_input(old_input_name, new_input_name) + def input(self, name): + """ + Get the argument name list by the parameter name for input. + + Args: + name(str): the parameter name. + + Returns: + list(str): the argument name list. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + return self.node.op().input(name) + + def output(self, name): + """ + Get the argument name list by the parameter name for output. + + Args: + name(str): the parameter name. + + Returns: + list(str): the argument name list. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + return self.node.op().output(name) + + def set_type(self, new_type): + """ + Change the operator type into new type. + + Args: + new_type(str): new operator type to be set. + """ + assert self.node.op() is not None, \ + "The node operator description cannot be None." + return self.node.op().set_type(new_type) + @property def inputs(self): """ -- GitLab