未验证 提交 f7f5044b 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #16489 from wzzju/fix_slim_quant_bugs

Clean codes and fix some bugs.
......@@ -26,6 +26,17 @@ __all__ = [
]
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
np.ndarray), 'The type of value should be numpy array.'
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
tensor = scope.var(var_node.name()).get_tensor()
tensor.set(value, place)
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
......@@ -88,14 +99,14 @@ class QuantizationTransformPass(object):
assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
if activation_quantize_type not in quant_type:
raise ValueError(
"Unknown activation_quantize_type : '%s'. It can only be ",
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(activation_quantize_type))
"Unknown activation_quantize_type : '%s'. It can only be "
"'abs_max' or 'range_abs_max' or 'moving_average_abs_max'." %
(str(activation_quantize_type)))
if weight_quantize_type not in quant_type:
raise ValueError(
"Unknown weight_quantize_type: '%s'. It can only be ",
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
str(weight_quantize_type))
"Unknown weight_quantize_type: '%s'. It can only be "
"'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'."
% (str(weight_quantize_type)))
self._activation_quantize_type = activation_quantize_type
self._weight_quantize_type = weight_quantize_type
......@@ -121,8 +132,6 @@ class QuantizationTransformPass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
#sequential_execution = core.get_pass('sequential_execution_pass')
#sequential_execution.apply(graph.graph)
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
......@@ -203,9 +212,12 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self._init_var_node(
global_step_in, np.zeros(
[1], dtype='int64'))
_init_var_node(
global_step_in,
np.zeros(
[1], dtype='int64'),
self._scope,
self._place)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
......@@ -284,7 +296,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
_init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
......@@ -299,9 +316,13 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(
scales_node, np.zeros(
[self._window_size], dtype=data_type))
_init_var_node(
scales_node,
np.zeros(
[self._window_size], dtype=data_type),
self._scope,
self._place)
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
......@@ -343,7 +364,12 @@ class QuantizationTransformPass(object):
var_dtype=var_node.dtype())
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
_init_var_node(
scale_in_node,
np.array(
[0.001], dtype=data_type),
self._scope,
self._place)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node}
......@@ -356,13 +382,23 @@ class QuantizationTransformPass(object):
shape=[1])
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
_init_var_node(
scale_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
_init_var_node(
accum_in_node,
np.ones(
[1], dtype=data_type),
self._scope,
self._place)
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
......@@ -482,16 +518,6 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
......@@ -594,8 +620,8 @@ class QuantizationFreezePass(object):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = self._to_node(op_node.outputs,
op_node.output('OutScale')[0])
scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes()
......@@ -627,8 +653,8 @@ class QuantizationFreezePass(object):
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = self._to_node(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0])
k = graph._find_node_by_name(op_node.outputs, op_node.output('Out')[0])
v = graph._find_node_by_name(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map:
self._op_input_rename_map[k.node] = v
else:
......@@ -663,8 +689,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
......@@ -672,7 +698,9 @@ class QuantizationFreezePass(object):
var_dtype=output_var_node.dtype())
data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
_init_var_node(weight_scale_node,
channel_scale.astype(data_type), self._scope,
self._place)
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -724,8 +752,8 @@ class QuantizationFreezePass(object):
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
output_var_node = graph._find_node_by_name(
op_node.outputs, op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -746,24 +774,6 @@ class QuantizationFreezePass(object):
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _to_node(self, nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
......
......@@ -45,13 +45,14 @@ class QuantizationStrategy(Strategy):
activation_bits=8,
weight_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
save_in_nodes=None,
save_out_nodes=None):
"""
Args:
start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
float_model_save_path(str): The path to save model with float weights.
float_model_save_path(str): The path to save model with float weights.
None means it doesn't save float model. defalut: None.
mobile_model_save_path(str): The path to save model for paddle-mobile execution.
None means it doesn't save mobile model. defalut: None.
......@@ -66,9 +67,11 @@ class QuantizationStrategy(Strategy):
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
save_in_nodes(list<str>): A list of variable names used to prune graph
weight_quantize_type (str): quantization type for weights, support 'abs_max' and 'channel_wise_abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the model is well trained.
save_in_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
save_out_nodes(list<str>): A list of variable names used to prune graph
save_out_nodes(list<str>): A list of variable names used to prune graph
for saving inference model.
"""
......@@ -81,6 +84,7 @@ class QuantizationStrategy(Strategy):
self.activation_bits = activation_bits
self.weight_bits = weight_bits
self.activation_quantize_type = activation_quantize_type
self.weight_quantize_type = weight_quantize_type
self.save_out_nodes = save_out_nodes
self.save_in_nodes = save_in_nodes
......@@ -100,7 +104,8 @@ class QuantizationStrategy(Strategy):
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits,
activation_quantize_type=self.activation_quantize_type)
activation_quantize_type=self.activation_quantize_type,
weight_quantize_type=self.weight_quantize_type)
transform_pass.apply(train_ir_graph)
transform_pass.apply(test_ir_graph)
......@@ -134,7 +139,8 @@ class QuantizationStrategy(Strategy):
scope=context.scope,
place=context.place,
weight_bits=self.weight_bits,
activation_bits=self.activation_bits)
activation_bits=self.activation_bits,
weight_quantize_type=self.weight_quantize_type)
freeze_pass.apply(test_ir_graph)
# for other strategies
......
......@@ -35,6 +35,8 @@ strategies:
start_epoch: 0
end_epoch: 0
float_model_save_path: './output/float'
mobile_model_save_path: './output/mobile'
int8_model_save_path: './output/int8'
weight_bits: 8
activation_bits: 8
weight_quantize_type: 'abs_max'
......
......@@ -256,8 +256,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type)
#transform_pass = QuantizationTransformPass(
# scope=scope, place=place, activation_quantize_type=activation_quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
......@@ -315,7 +313,6 @@ class TestQuantizationFreezePass(unittest.TestCase):
# Freeze graph for inference, but the weight of fc/conv is still float type.
freeze_pass = QuantizationFreezePass(
scope=scope, place=place, weight_quantize_type=weight_quant_type)
#freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
if not for_ci:
marked_nodes = set()
......
......@@ -2347,40 +2347,6 @@ class IrGraph(object):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def _find_var_node(self, key):
"""
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name or id.
"""
target_var_node = None
var_nodes = self.all_var_nodes()
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. In IrGraph,
......@@ -2525,14 +2491,6 @@ class IrGraph(object):
core.graph_safe_remove_nodes(self.graph, original_nodes)
def resolve_hazard(self):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = core.topology_sort(self.graph)
var_nodes = dict()
for node in ordered_nodes:
......@@ -2540,16 +2498,17 @@ class IrGraph(object):
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
self._find_node_by_name(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
self._find_node_by_name(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
self._find_node_by_name(node.outputs,
each_var_name))
self.graph.resolve_hazard(var_nodes)
def has_circle(self):
......@@ -2662,6 +2621,17 @@ class IrGraph(object):
program = Program._construct_from_desc(desc)
return program
def _find_node_by_name(self, nodes, node_name):
"""
Find a node in the giving nodes set by the name.
"""
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _update_desc_attr(self, desc, name, val):
"""
Update the value of desc's attribute by attribute's name.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册