提交 7c8f7df2 编写于 作者: Z Zhen Wang

add some op_des funs to IrOpNode and add some var_des funs to IrVarNode. test=develop

上级 33f99d61
......@@ -231,14 +231,14 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_var_node = graph.create_var_node(
name=self._quantized_scale_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={
......@@ -261,15 +261,15 @@ class QuantizationTransformPass(object):
quant_var_node = graph.create_var_node(
name=self._quantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
scale_in_node = graph.create_persistable_node(
name=self._quantized_scale_name(var_node.name()),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.var().dtype())
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
......@@ -282,7 +282,7 @@ class QuantizationTransformPass(object):
name=unique_name.generate('scales'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
var_dtype=var_node.var().dtype())
var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
......@@ -317,9 +317,9 @@ class QuantizationTransformPass(object):
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(var_node.name()),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=var_node.var().dtype())
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=var_node.dtype())
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
......@@ -408,17 +408,17 @@ class QuantizationFreezePass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0]
input_arg_name = op_node.input('X')[0]
if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param))
else:
scale_v = self._load_var(op_node.op().output('OutScale')
[0])[0]
scale_v = self._load_var(
op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v
else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0])
scale_v = graph.var_node(op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node)
......@@ -454,8 +454,8 @@ class QuantizationFreezePass(object):
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0]
v = op_node.op().input('X')[0]
k = op_node.output('Out')[0]
v = op_node.input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
else:
......@@ -493,9 +493,9 @@ class QuantizationFreezePass(object):
output_var_node = op_node.outputs[0]
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(),
shape=output_var_node.var().shape(),
var_dtype=output_var_node.var().dtype())
var_type=output_var_node.type(),
shape=output_var_node.shape(),
var_dtype=output_var_node.dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={
......@@ -615,8 +615,8 @@ class ConvertToInt8Pass(object):
int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_persistable_node(
name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_type=var_node.type(),
shape=var_node.shape(),
var_dtype=core.VarDesc.VarType.INT8)
array = self._load_var(var_node.name())
self._scope.var(int8_var_node_name)
......@@ -672,7 +672,7 @@ class TransformForMobilePass(object):
for op_node in ops:
name = op_node.name()
if name in self._fake_quant_op_names:
op_node.op().set_type('quantize')
op_node.set_type('quantize')
quant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, quant_node)
......@@ -680,7 +680,7 @@ class TransformForMobilePass(object):
graph.link_to(quant_node, output_node)
graph.safe_remove_nodes(op_node)
if name in self._fake_dequant_op_names:
op_node.op().set_type('dequantize')
op_node.set_type('dequantize')
dequant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, dequant_node)
......
......@@ -1754,6 +1754,39 @@ class IrVarNode(IrNode):
"The node variable description cannot be None."
return self.node.var().persistable()
def type(self):
"""
Return the variable type.
Returns:
core.VarDesc.VarType: the variable type.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().type()
def dtype(self):
"""
Return the variable data type.
Returns:
core.VarDesc.VarType: the variable data type.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().dtype()
def shape(self):
"""
Return the variable shape.
Returns:
list: the variable shape.
"""
assert self.node.var() is not None, \
"The node variable description cannot be None."
return self.node.var().shape()
@property
def inputs(self):
"""
......@@ -1804,6 +1837,45 @@ class IrOpNode(IrNode):
"The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name)
def input(self, name):
"""
Get the argument name list by the parameter name for input.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input(name)
def output(self, name):
"""
Get the argument name list by the parameter name for output.
Args:
name(str): the parameter name.
Returns:
list(str): the argument name list.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output(name)
def set_type(self, new_type):
"""
Change the operator type into new type.
Args:
new_type(str): new operator type to be set.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().set_type(new_type)
@property
def inputs(self):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册