Unverified · Commit d68a02af · authored by Zhen Wang, committed by GitHub

Merge pull request #16456 from wzzju/fix_quan_hang

Fix quantization hang bugs.
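The hang fix removes the `Program`/`Executor`-based initialization of the persistable variables these passes create (scales, states, accumulators, the global step counter): instead of building a temporary init program and running it, the passes now write the initial values directly into the scope through a new `_init_var_node` helper, key the freeze pass's bookkeeping by graph nodes instead of names, and call the new `IrGraph.resolve_hazard()` before returning a graph. A minimal sketch of the new initialization idea (not the exact patch; `scope`, `place`, and `var_node` stand in for the pass's `self._scope`, `self._place`, and a freshly created persistable node):

```python
import numpy as np

def init_var_node(scope, place, var_node, value):
    # Write the initial value straight into the scope tensor backing the
    # graph variable, instead of running an initializer Program with an Executor.
    assert isinstance(value, np.ndarray), 'value must be a numpy array'
    tensor = scope.var(var_node.name()).get_tensor()
    tensor.set(value, place)
```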
@@ -14,15 +14,10 @@

 import collections
 import numpy as np
-import six
 from ..... import compat as cpt
 from .... import core
-from .... import Executor
 from ....framework import IrGraph
 from ....framework import IrNode
-from ....framework import Program
-from ....initializer import Constant
-from ....initializer import NumpyArrayInitializer
 from .... import unique_name

 __all__ = [
@@ -107,7 +102,6 @@ class QuantizationTransformPass(object):
         self._window_size = window_size
         self._moving_rate = moving_rate

-        self._need_initialized = collections.OrderedDict()
         self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._quantizable_grad_ops = [
@@ -127,7 +121,8 @@ class QuantizationTransformPass(object):
         """
         assert isinstance(graph,
                           IrGraph), 'graph must be the instance of IrGraph.'
-        self._need_initialized.clear()
+        #sequential_execution = core.get_pass('sequential_execution_pass')
+        #sequential_execution.apply(graph.graph)
         self._is_test = graph.is_test()
         # marked the variable which has been dequantized.
         dequantized_vars = collections.OrderedDict()
@@ -135,6 +130,8 @@ class QuantizationTransformPass(object):

         def _transform_forward(graph, op):
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                 else:
@@ -168,6 +165,8 @@ class QuantizationTransformPass(object):

         def _transform_backward(graph, op):
             no_dequanted_input_vars = True
             for var_node in op.inputs:
+                if var_node.name() not in op.input_arg_names():
+                    continue
                 if var_node.name() in dequantized_vars:
                     dequant_var_node = dequantized_vars[var_node.name()]
                     graph.update_input_link(var_node, dequant_var_node, op)
@@ -188,25 +187,7 @@ class QuantizationTransformPass(object):
         for op in ops:
             if op.name() in self._quantizable_grad_ops:
                 _transform_backward(graph, op)
-        if len(self._need_initialized) > 0:
-            assert self._scope is not None, \
-                'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
-            assert self._place is not None, \
-                'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
-            init_program = Program()
-            for var_desc, initializer in six.iteritems(self._need_initialized):
-                var = init_program.global_block().create_var(
-                    name=var_desc.name(),
-                    shape=var_desc.shape(),
-                    dtype=var_desc.dtype(),
-                    type=var_desc.type(),
-                    lod_level=var_desc.lod_level(),
-                    persistable=var_desc.persistable())
-                initializer(var, init_program.global_block())
-            exe = Executor(self._place)
-            exe.run(program=init_program, scope=self._scope)
+        graph.resolve_hazard()
         return graph

     def _create_global_step(self, graph):
@@ -222,8 +203,9 @@ class QuantizationTransformPass(object):
                     var_type=core.VarDesc.VarType.LOD_TENSOR,
                     shape=[1],
                     var_dtype=core.VarDesc.VarType.INT64)
-                self._need_initialized[global_step_in.var()] = \
-                    Constant(value=0, force_cpu=True)
+                self._init_var_node(
+                    global_step_in, np.zeros(
+                        [1], dtype='int64'))
                 global_step_out = graph.create_var_node_from_desc(
                     global_step_in.var())
                 # The attribute of `op_role` is needed by ParallelExecutor.
@@ -300,7 +282,9 @@ class QuantizationTransformPass(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())

         inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -313,7 +297,11 @@ class QuantizationTransformPass(object):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 shape=[self._window_size],
                 var_dtype=var_node.dtype())
-            self._need_initialized[scales_node.var()] = Constant(value=0)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(
+                scales_node, np.zeros(
+                    [self._window_size], dtype=data_type))
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
         attrs = {
@@ -353,7 +341,9 @@ class QuantizationTransformPass(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.dtype())
-        self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
+        data_type = 'float64' if var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())

         ins = {'X': var_node, 'InScale': scale_in_node}
@@ -364,13 +354,15 @@ class QuantizationTransformPass(object):
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[state_in_node.var()] = Constant(value=1)
+            data_type = 'float64' if var_node.dtype(
+            ) == core.VarDesc.VarType.FP64 else 'float32'
+            self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
             accum_in_node = graph.create_persistable_node(
                 name=unique_name.generate('accum'),
                 var_type=core.VarDesc.VarType.LOD_TENSOR,
                 var_dtype=var_node.dtype(),
                 shape=[1])
-            self._need_initialized[accum_in_node.var()] = Constant(value=1)
+            self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
             state_out_node = graph.create_var_node_from_desc(state_in_node.var(
             ))
             accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
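Each of the initializations above repeats the same dtype-selection idiom: the numpy value handed to `_init_var_node` has to match the dtype of the variable node being seeded, so `'float64'` is chosen only for FP64 nodes and `'float32'` otherwise. A hedged restatement of the idiom (the helper name is illustrative, not part of the patch):

```python
import numpy as np
from paddle.fluid import core

def numpy_dtype_for(var_node):
    # Map an IrVarNode dtype to the numpy dtype used for its initial value.
    return 'float64' if var_node.dtype() == core.VarDesc.VarType.FP64 else 'float32'

# e.g. the range_abs_max input scale starts at 0.001 in the matching precision:
# self._init_var_node(scale_in_node, np.array([0.001], dtype=numpy_dtype_for(var_node)))
```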
@@ -490,6 +482,16 @@ class QuantizationTransformPass(object):
         graph.link_to(dequant_op_node, dequant_var_node)
         return dequant_var_node

+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
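Because `_init_var_node` writes into `self._scope` on `self._place`, a transform pass that creates these stateful variables has to be constructed with both. A hedged usage sketch (the training program setup is assumed to exist elsewhere; argument values are illustrative):

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

place = fluid.CPUPlace()
scope = fluid.global_scope()
# `train_program` is assumed to be a fluid.Program built elsewhere.
graph = IrGraph(core.Graph(train_program.desc), for_test=False)

transform_pass = QuantizationTransformPass(
    scope=scope,   # _init_var_node writes scale/state/accum tensors into this scope
    place=place,   # ... and places them on this device
    activation_quantize_type='range_abs_max')
transform_pass.apply(graph)
```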
@@ -592,7 +594,8 @@ class QuantizationFreezePass(object):
                                                    self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
                 else:
-                    scale_v = graph.var_node(op_node.output('OutScale')[0])
+                    scale_v = self._to_node(op_node.outputs,
+                                            op_node.output('OutScale')[0])
                 self._var_scale_map[input_arg_name] = scale_v

         ops = graph.all_op_nodes()
@@ -613,32 +616,35 @@ class QuantizationFreezePass(object):
         for op_node in ops:
             # insert dequant_op after fc/conv, need to rename inputs of the followed ops
             for var_node in op_node.inputs:
-                name = var_node.name()
-                if name in self._op_output_rename_map:
-                    old_in = graph.var_node(name)
-                    new_in = self._op_output_rename_map[name]
+                if var_node.node in self._op_output_rename_map:
+                    old_in = var_node
+                    new_in = self._op_output_rename_map[var_node.node]
                     graph.update_input_link(old_in, new_in, op_node)

         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph

     def _remove_fake_quant_and_dequant_op(self, graph, op_node):
-        k = op_node.output('Out')[0]
-        v = op_node.input('X')[0]
-        if v not in self._op_input_rename_map:
-            self._op_input_rename_map[k] = v
+        k = self._to_node(op_node.outputs, op_node.output('Out')[0])
+        v = self._to_node(op_node.inputs, op_node.input('X')[0])
+        if v.node not in self._op_input_rename_map:
+            self._op_input_rename_map[k.node] = v
         else:
-            self._op_input_rename_map[k] = self._op_input_rename_map[v]
+            self._op_input_rename_map[k.node] = self._op_input_rename_map[
+                v.node]
         graph.safe_remove_nodes(op_node)

     def _insert_post_channel_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
@@ -653,28 +659,20 @@ class QuantizationFreezePass(object):
                 assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         weight_scale_node = graph.create_persistable_node(
             name=unique_name.generate('channel_scale'),
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[channel_scale.shape[0]],
             var_dtype=output_var_node.dtype())
-        init_program = Program()
-        weight_scale_var = init_program.global_block().create_var(
-            name=weight_scale_node.name(),
-            shape=weight_scale_node.shape(),
-            dtype=weight_scale_node.dtype(),
-            type=weight_scale_node.type(),
-            lod_level=weight_scale_node.var().lod_level(),
-            persistable=weight_scale_node.persistable())
-        initializer = NumpyArrayInitializer(value=channel_scale)
-        initializer(weight_scale_var, init_program.global_block())
-        exe = Executor(self._place)
-        exe.run(program=init_program, scope=self._scope)
+        data_type = 'float64' if output_var_node.dtype(
+        ) == core.VarDesc.VarType.FP64 else 'float32'
+        self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -695,16 +693,18 @@ class QuantizationFreezePass(object):
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(weight_scale_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node

     def _insert_post_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
-            if name in self._op_input_rename_map:
-                old_in = graph.var_node(name)
-                new_in = graph.var_node(self._op_input_rename_map[name])
+            if name not in op_node.input_arg_names():
+                continue
+            if var_node.node in self._op_input_rename_map:
+                old_in = var_node
+                new_in = self._op_input_rename_map[var_node.node]
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
@@ -720,11 +720,12 @@ class QuantizationFreezePass(object):
                 assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]

-        if len(op_node.outputs) != 1:
+        if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
                              " more than one output." % (op_node.name()))

-        output_var_node = op_node.outputs[0]
+        output_var_node = self._to_node(op_node.outputs,
+                                        op_node.output_arg_names()[0])
         dequant_var_node = graph.create_var_node(
             name=self._dequantized_var_name(output_var_node.name()),
             var_type=output_var_node.type(),
@@ -742,9 +743,27 @@ class QuantizationFreezePass(object):
         graph.link_to(output_var_node, dequant_op_node)
         graph.link_to(scale_var_node, dequant_op_node)
         graph.link_to(dequant_op_node, dequant_var_node)
-        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        self._op_output_rename_map[output_var_node.node] = dequant_var_node
         return dequant_var_node

+    def _init_var_node(self, var_node, value):
+        assert isinstance(
+            value, np.ndarray), 'The type of value should be numpy array.'
+        assert self._scope is not None, \
+            'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
+        assert self._place is not None, \
+            'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
+        tensor = self._scope.var(var_node.name()).get_tensor()
+        tensor.set(value, self._place)
+
+    def _to_node(self, nodes, node_name):
+        target_node = None
+        for n in nodes:
+            if n.name() == node_name:
+                target_node = n
+        assert target_node is not None, "Cannot find the target node in the giving set."
+        return target_node
+
     def _load_var(self, name):
         return np.array(self._scope.find_var(name).get_tensor())
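In the freeze pass, variables are now resolved against the op node's own `inputs`/`outputs` through `_to_node`, and the rename maps are keyed by the underlying `node` object rather than by variable name; since several nodes in the graph can share one name (see the warning added to `_find_var_node` further down), this keeps the relinking unambiguous. A hedged sketch of the pattern, mirroring `_remove_fake_quant_and_dequant_op`:

```python
# Resolve the op-desc argument names to the nodes actually wired to this op,
# then key the rename map by the node object, which stays unique even when names repeat.
k = self._to_node(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map:
    self._op_input_rename_map[k.node] = v
```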
@@ -848,6 +867,7 @@ class ConvertToInt8Pass(object):

         # remove the unused var node in the graph
         self._remove_unused_var_nodes(graph)
+        graph.resolve_hazard()
         return graph

     def _convert_to_int8(self, graph, var_node):
@@ -930,5 +950,5 @@ class TransformForMobilePass(object):
                 for output_node in op_node.outputs:
                     graph.link_to(dequant_node, output_node)
                 graph.safe_remove_nodes(op_node)
-
+        graph.resolve_hazard()
         return graph
@@ -2052,6 +2052,28 @@ class IrOpNode(IrNode):
         else:
             desc._set_attr(name, val)

+    def input_arg_names(self):
+        """
+        Return input arguments' names of this op node.
+
+        Returns:
+            list(str): input arguments' names of this op node.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        return self.node.op().input_arg_names()
+
+    def output_arg_names(self):
+        """
+        Return output arguments' names of this op node.
+
+        Returns:
+            list(str): output arguments' names of this op node.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        return self.node.op().output_arg_names()
+
     @property
     def inputs(self):
         """
@@ -2142,31 +2164,38 @@ class IrGraph(object):
         """
         return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}

-    def var_node(self, name):
+    def _find_var_node(self, key):
         """
-        Get a variable node by name from the graph.
+        Get a variable node by the `key` from this graph. The key
+        can be a node name or a node id.
+
+        WARNS:
+            There are some nodes may have the same name. So, be
+            cautious about using this method when you find the
+            target var node by its name.

         Args:
-            name(str): the name of the variable node.
+            key(str|int): The str type denotes that the target variable node's name.
+                And the int type denotes that the target variable node's id.

         Raises:
-            ValueError: The If input's type is not str, or this graph
-                doesn't have a variable with the giving name.
+            ValueError: If this graph doesn't have a variable with the giving name or id.

         Returns:
-            IrVarNode: the variable node with the giving name.
+            IrVarNode: the variable node with the giving name or id.
         """
-        if not isinstance(name, six.string_types):
-            raise TypeError(
-                "var require string as parameter, but get %s instead." %
-                (type(name)))
         target_var_node = None
         var_nodes = self.all_var_nodes()
-        for var_node in var_nodes:
-            if var_node.name() == name:
-                target_var_node = var_node
+        if isinstance(key, six.string_types):
+            for var_node in var_nodes:
+                if var_node.name() == key:
+                    target_var_node = var_node
+        elif isinstance(key, int):
+            for var_node in var_nodes:
+                if var_node.id() == key:
+                    target_var_node = var_node
         if target_var_node is None:
-            raise ValueError("var_node %s not in this graph" % name)
+            raise ValueError("var_node %s not in this graph" % key)
         return target_var_node

     def create_persistable_node(self, name, var_type, shape, var_dtype):
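`_find_var_node` replaces the old `var_node` lookup and accepts either a name or a node id; the id form avoids the ambiguity the new warning describes when several nodes share a name. A hedged usage sketch (`graph` is an `IrGraph`, the variable name is illustrative):

```python
# Lookup by name may be ambiguous if several nodes share the name:
scale_node = graph._find_var_node('conv2d_0.w_0.scale')
# Lookup by the unique node id is not:
same_node = graph._find_var_node(scale_node.id())
assert same_node.name() == scale_node.name()
```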
@@ -2312,6 +2341,34 @@ class IrGraph(object):
         original_nodes = {n.node for n in remove_nodes}
         core.graph_safe_remove_nodes(self.graph, original_nodes)

+    def resolve_hazard(self):
+        def _to_node(nodes, node_name):
+            target_node = None
+            for n in nodes:
+                if n.name() == node_name:
+                    target_node = n
+            assert target_node is not None, "Cannot find the target node in the giving set."
+            return target_node
+
+        ordered_nodes = core.topology_sort(self.graph)
+        var_nodes = dict()
+        for node in ordered_nodes:
+            if node.is_op() and node.op() is not None:
+                for each_var_name in node.op().input_arg_names():
+                    if each_var_name not in var_nodes:
+                        var_nodes[each_var_name] = [
+                            _to_node(node.inputs, each_var_name)
+                        ]
+                for each_var_name in node.op().output_arg_names():
+                    if each_var_name not in var_nodes:
+                        var_nodes[each_var_name] = [
+                            _to_node(node.outputs, each_var_name)
+                        ]
+                    else:
+                        var_nodes[each_var_name].append(
+                            _to_node(node.outputs, each_var_name))
+        self.graph.resolve_hazard(var_nodes)
+
     def has_circle(self):
         """
         Check if the graph has a circle.
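`resolve_hazard` walks the ops in topological order and records, for every variable name, the node that first feeds it followed by every node that later rewrites it, then hands that `var_nodes` map to the underlying `core` graph's `resolve_hazard`, presumably so reads and writes of the same name are ordered correctly at execution time. The quantization passes above now call it as the last step of `apply`. A hedged sketch of where it fits in a pass (the pass class below is hypothetical, not part of the patch):

```python
class MyGraphPass(object):
    def apply(self, graph):
        for op_node in graph.all_op_nodes():
            pass  # ... rewrite ops / relink variable nodes here ...
        # Regroup nodes by variable name in topological order before handing
        # the graph back, mirroring what the quantization passes do above.
        graph.resolve_hazard()
        return graph
```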