提交 2ccbfd5e 编写于 作者: Z Zhen Wang

Fix some bugs for quantization passes.

上级 ec11135d
......@@ -14,15 +14,10 @@
import collections
import numpy as np
import six
from ..... import compat as cpt
from .... import core
from .... import Executor
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Program
from ....initializer import Constant
from ....initializer import NumpyArrayInitializer
from .... import unique_name
__all__ = [
......@@ -31,6 +26,35 @@ __all__ = [
]
def _resolve_hazard(graph):
def _to_node(nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n.node
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
ordered_nodes = graph.topology_sort()
var_nodes = dict()
for node in ordered_nodes:
if node.is_op() and node.op() is not None:
for each_var_name in node.op().input_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.inputs, each_var_name)
]
for each_var_name in node.op().output_arg_names():
if each_var_name not in var_nodes:
var_nodes[each_var_name] = [
_to_node(node.outputs, each_var_name)
]
else:
var_nodes[each_var_name].append(
_to_node(node.outputs, each_var_name))
graph.graph.resolve_hazard(var_nodes)
class QuantizationTransformPass(object):
def __init__(self,
scope=None,
......@@ -107,7 +131,6 @@ class QuantizationTransformPass(object):
self._window_size = window_size
self._moving_rate = moving_rate
self._need_initialized = collections.OrderedDict()
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
......@@ -127,7 +150,8 @@ class QuantizationTransformPass(object):
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear()
sequential_execution = core.get_pass('sequential_execution_pass')
sequential_execution.apply(graph.graph)
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
......@@ -135,6 +159,8 @@ class QuantizationTransformPass(object):
def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
......@@ -168,6 +194,8 @@ class QuantizationTransformPass(object):
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
graph.update_input_link(var_node, dequant_var_node, op)
......@@ -188,25 +216,7 @@ class QuantizationTransformPass(object):
for op in ops:
if op.name() in self._quantizable_grad_ops:
_transform_backward(graph, op)
if len(self._need_initialized) > 0:
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
initializer(var, init_program.global_block())
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
_resolve_hazard(graph)
return graph
def _create_global_step(self, graph):
......@@ -222,8 +232,9 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
self._need_initialized[global_step_in.var()] = \
Constant(value=0, force_cpu=True)
self._init_var_node(
global_step_in, np.zeros(
[1], dtype='int64'))
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
......@@ -300,7 +311,9 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
......@@ -313,7 +326,11 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self._window_size],
var_dtype=var_node.dtype())
self._need_initialized[scales_node.var()] = Constant(value=0)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(
scales_node, np.zeros(
[self._window_size], dtype=data_type))
inputs['Iter'] = self._global_step
outputs['OutScales'] = scales_node
attrs = {
......@@ -353,7 +370,9 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.dtype())
self._need_initialized[scale_in_node.var()] = Constant(value=0.001)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.array([0.001], dtype=data_type))
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
ins = {'X': var_node, 'InScale': scale_in_node}
......@@ -364,13 +383,15 @@ class QuantizationTransformPass(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[state_in_node.var()] = Constant(value=1)
data_type = 'float64' if var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(scale_in_node, np.ones([1], dtype=data_type))
accum_in_node = graph.create_persistable_node(
name=unique_name.generate('accum'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
var_dtype=var_node.dtype(),
shape=[1])
self._need_initialized[accum_in_node.var()] = Constant(value=1)
self._init_var_node(accum_in_node, np.ones([1], dtype=data_type))
state_out_node = graph.create_var_node_from_desc(state_in_node.var(
))
accum_out_node = graph.create_var_node_from_desc(accum_in_node.var(
......@@ -490,6 +511,16 @@ class QuantizationTransformPass(object):
graph.link_to(dequant_op_node, dequant_var_node)
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _quantized_var_name(self, var_name):
"""
Return quantized variable name for the input `var_name`.
......@@ -592,7 +623,8 @@ class QuantizationFreezePass(object):
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
else:
scale_v = graph.var_node(op_node.output('OutScale')[0])
scale_v = self._to_node(op_node.outputs,
op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
ops = graph.all_op_nodes()
......@@ -613,10 +645,9 @@ class QuantizationFreezePass(object):
for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_output_rename_map:
old_in = graph.var_node(name)
new_in = self._op_output_rename_map[name]
if var_node.node in self._op_output_rename_map:
old_in = var_node
new_in = self._op_output_rename_map[var_node.node]
graph.update_input_link(old_in, new_in, op_node)
# remove the unused var node in the graph
......@@ -624,21 +655,24 @@ class QuantizationFreezePass(object):
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.output('Out')[0]
v = op_node.input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
k = self._to_node(op_node.outputs, op_node.output('Out')[0])
v = self._to_node(op_node.inputs, op_node.input('X')[0])
if v.node not in self._op_input_rename_map:
self._op_input_rename_map[k.node] = v
else:
self._op_input_rename_map[k] = self._op_input_rename_map[v]
self._op_input_rename_map[k.node] = self._op_input_rename_map[
v.node]
graph.safe_remove_nodes(op_node)
def _insert_post_channel_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
if name not in op_node.input_arg_names():
continue
if var_node.node in self._op_input_rename_map:
old_in = var_node
new_in = self._op_input_rename_map[var_node.node]
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
......@@ -653,28 +687,20 @@ class QuantizationFreezePass(object):
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = op_node.outputs[0]
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
weight_scale_node = graph.create_persistable_node(
name=unique_name.generate('channel_scale'),
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[channel_scale.shape[0]],
var_dtype=output_var_node.dtype())
init_program = Program()
weight_scale_var = init_program.global_block().create_var(
name=weight_scale_node.name(),
shape=weight_scale_node.shape(),
dtype=weight_scale_node.dtype(),
type=weight_scale_node.type(),
lod_level=weight_scale_node.var().lod_level(),
persistable=weight_scale_node.persistable())
initializer = NumpyArrayInitializer(value=channel_scale)
initializer(weight_scale_var, init_program.global_block())
exe = Executor(self._place)
exe.run(program=init_program, scope=self._scope)
data_type = 'float64' if output_var_node.dtype(
) == core.VarDesc.VarType.FP64 else 'float32'
self._init_var_node(weight_scale_node, channel_scale.astype(data_type))
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -695,16 +721,18 @@ class QuantizationFreezePass(object):
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(weight_scale_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node
def _insert_post_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
if name not in op_node.input_arg_names():
continue
if var_node.node in self._op_input_rename_map:
old_in = var_node
new_in = self._op_input_rename_map[var_node.node]
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
......@@ -720,11 +748,12 @@ class QuantizationFreezePass(object):
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = op_node.outputs[0]
output_var_node = self._to_node(op_node.outputs,
op_node.output_arg_names()[0])
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.type(),
......@@ -742,9 +771,27 @@ class QuantizationFreezePass(object):
graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
self._op_output_rename_map[output_var_node.node] = dequant_var_node
return dequant_var_node
def _init_var_node(self, var_node, value):
assert isinstance(
value, np.ndarray), 'The type of value should be numpy array.'
assert self._scope is not None, \
'The scope cannot be set None when activation_quantize_type equals to range_abs_max.'
assert self._place is not None, \
'The place cannot be set None when activation_quantize_type equals to range_abs_max.'
tensor = self._scope.var(var_node.name()).get_tensor()
tensor.set(value, self._place)
def _to_node(self, nodes, node_name):
target_node = None
for n in nodes:
if n.name() == node_name:
target_node = n
assert target_node is not None, "Cannot find the target node in the giving set."
return target_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
......
......@@ -1964,6 +1964,28 @@ class IrOpNode(IrNode):
else:
desc._set_attr(name, val)
def input_arg_names(self):
"""
Return input arguments' names of this op node.
Returns:
list(str): input arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().input_arg_names()
def output_arg_names(self):
"""
Return output arguments' names of this op node.
Returns:
list(str): output arguments' names of this op node.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
return self.node.op().output_arg_names()
@property
def inputs(self):
"""
......@@ -2054,31 +2076,38 @@ class IrGraph(object):
"""
return {IrOpNode(node) for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
def _find_var_node(self, key):
"""
Get a variable node by name from the graph.
Get a variable node by the `key` from this graph. The key
can be a node name or a node id.
WARNS:
There are some nodes may have the same name. So, be
cautious about using this method when you find the
target var node by its name.
Args:
name(str): the name of the variable node.
key(str|int): The str type denotes that the target variable node's name.
And the int type denotes that the target variable node's id.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
ValueError: If this graph doesn't have a variable with the giving name or id.
Returns:
IrVarNode: the variable node with the giving name.
IrVarNode: the variable node with the giving name or id.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_var_nodes()
if isinstance(key, six.string_types):
for var_node in var_nodes:
if var_node.name() == key:
target_var_node = var_node
elif isinstance(key, int):
for var_node in var_nodes:
if var_node.name() == name:
if var_node.id() == key:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
raise ValueError("var_node %s not in this graph" % key)
return target_var_node
def create_persistable_node(self, name, var_type, shape, var_dtype):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册