提交 7c8f7df2 编写于 作者: Z Zhen Wang

add some op_des funs to IrOpNode and add some var_des funs to IrVarNode. test=develop

上级 33f99d61
...@@ -231,14 +231,14 @@ class QuantizationTransformPass(object): ...@@ -231,14 +231,14 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()), name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
scale_var_node = graph.create_var_node( scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node( quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max', op_type='fake_quantize_abs_max',
attrs={ attrs={
...@@ -261,15 +261,15 @@ class QuantizationTransformPass(object): ...@@ -261,15 +261,15 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node( quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()), name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
scale_in_node = graph.create_persistable_node( scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()), name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1], shape=[1],
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001) self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var()) scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
...@@ -282,7 +282,7 @@ class QuantizationTransformPass(object): ...@@ -282,7 +282,7 @@ class QuantizationTransformPass(object):
name=unique_name.generate('scales'), name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR, var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size], shape=[self._window_size],
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0) self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node outputs['OutScales'] = scales_node
...@@ -317,9 +317,9 @@ class QuantizationTransformPass(object): ...@@ -317,9 +317,9 @@ class QuantizationTransformPass(object):
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()), name=self._dequantized_var_name(var_node.name()),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=var_node.var().dtype()) var_dtype=var_node.dtype())
max_range = (1 << (quant_bits - 1)) - 1 max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
...@@ -408,17 +408,17 @@ class QuantizationFreezePass(object): ...@@ -408,17 +408,17 @@ class QuantizationFreezePass(object):
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._fake_quant_op_names: if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0] input_arg_name = op_node.input('X')[0]
if input_arg_name in persistable_vars: if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max': if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name) param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param)) scale_v = np.max(np.abs(param))
else: else:
scale_v = self._load_var(op_node.op().output('OutScale') scale_v = self._load_var(
[0])[0] op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
else: else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0]) scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars: if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
...@@ -454,8 +454,8 @@ class QuantizationFreezePass(object): ...@@ -454,8 +454,8 @@ class QuantizationFreezePass(object):
return graph return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node): def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0] k = op_node.output('Out')[0]
v = op_node.op().input('X')[0] v = op_node.input('X')[0]
if v not in self._op_input_rename_map: if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v self._op_input_rename_map[k] = v
else: else:
...@@ -493,9 +493,9 @@ class QuantizationFreezePass(object): ...@@ -493,9 +493,9 @@ class QuantizationFreezePass(object):
output_var_node = op_node.outputs[0] output_var_node = op_node.outputs[0]
dequant_var_node = graph.create_var_node( dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()), name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(), var_type=output_var_node.type(),
shape=output_var_node.var().shape(), shape=output_var_node.shape(),
var_dtype=output_var_node.var().dtype()) var_dtype=output_var_node.dtype())
dequant_op_node = graph.create_op_node( dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs', op_type='fake_dequantize_max_abs',
attrs={ attrs={
...@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object): ...@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object):
int8_var_node_name = var_node.name() + ".int8" int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_persistable_node( int8_var_node = graph.create_persistable_node(
name=cpt.to_text(int8_var_node_name), name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(), var_type=var_node.type(),
shape=var_node.var().shape(), shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.INT8) var_dtype=core.VarDesc.VarType.INT8)
array = self._load_var(var_node.name()) array = self._load_var(var_node.name())
self._scope.var(int8_var_node_name) self._scope.var(int8_var_node_name)
...@@ -672,7 +672,7 @@ class TransformForMobilePass(object): ...@@ -672,7 +672,7 @@ class TransformForMobilePass(object):
for op_node in ops: for op_node in ops:
name = op_node.name() name = op_node.name()
if name in self._fake_quant_op_names: if name in self._fake_quant_op_names:
op_node.op().set_type('quantize') op_node.set_type('quantize')
quant_node = graph.create_op_node_from_desc(op_node.op()) quant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs: for input_node in op_node.inputs:
graph.link_to(input_node, quant_node) graph.link_to(input_node, quant_node)
...@@ -680,7 +680,7 @@ class TransformForMobilePass(object): ...@@ -680,7 +680,7 @@ class TransformForMobilePass(object):
graph.link_to(quant_node, output_node) graph.link_to(quant_node, output_node)
graph.safe_remove_nodes(op_node) graph.safe_remove_nodes(op_node)
if name in self._fake_dequant_op_names: if name in self._fake_dequant_op_names:
op_node.op().set_type('dequantize') op_node.set_type('dequantize')
dequant_node = graph.create_op_node_from_desc(op_node.op()) dequant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs: for input_node in op_node.inputs:
graph.link_to(input_node, dequant_node) graph.link_to(input_node, dequant_node)
......
...@@ -1754,6 +1754,39 @@ class IrVarNode(IrNode): ...@@ -1754,6 +1754,39 @@ class IrVarNode(IrNode):
"The node variable description cannot be None." "The node variable description cannot be None."
return self.node.var().persistable() return self.node.var().persistable()
def type(self):
"""
Return the variable type.
Returns:
core.VarDesc.VarType: the variable type.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().type()
def dtype(self):
"""
Return the variable data type.
Returns:
core.VarDesc.VarType: the variable data type.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().dtype()
def shape(self):
"""
Return the variable shape.
Returns:
list: the variable shape.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().shape()
@property @property
def inputs(self): def inputs(self):
""" """
...@@ -1804,6 +1837,45 @@ class IrOpNode(IrNode): ...@@ -1804,6 +1837,45 @@ class IrOpNode(IrNode):
"The node operator description cannot be None." "The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name) self.node.op()._rename_input(old_input_name, new_input_name)
def input(self, name):
"""
Get the argument name list by the parameter name for input.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input(name)
def output(self, name):
"""
Get the argument name list by the parameter name for output.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output(name)
def set_type(self, new_type):
"""
Change the operator type into new type.
Args:
new_type(str): new operator type to be set.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().set_type(new_type)
@property @property
def inputs(self): def inputs(self):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册