Unverified commit 0fe72469, authored by Zhen Wang, committed by GitHub

Add max-pool2d quantization support and partial quantization support. (#19310)

* Add pool2d quantization support (max pooling only).

* Add partial quantization support: ops can be excluded via a skip pattern.
Parent d49c2bad
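For context, here is a minimal usage sketch of the new `skip_pattern` option, pieced together from the test changes in this diff rather than copied from the commit; the tiny network and the parameter choices (`image` shape, layer sizes, `moving_average_abs_max`) are illustrative assumptions:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    hidden = fluid.layers.fc(input=img, size=100, act='relu')
    # Ops built under this name scope carry the pattern in their
    # 'op_namescope' attribute, so the pass marks them with 'skip_quant'.
    with fluid.name_scope('skip_quant'):
        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')

graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(),
    place=fluid.CPUPlace(),
    activation_quantize_type='moving_average_abs_max',
    skip_pattern='skip_quant')
transform_pass.apply(graph)
```

Ops flagged with `skip_quant` are then left untouched by QuantizationTransformPass, QuantizationFreezePass, and ConvertToInt8Pass, as the checks added below show.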
......@@ -26,6 +26,23 @@ __all__ = [
'AddQuantDequantPass'
]
_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d']
_fake_quant_op_list = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max', 'fake_channel_wise_quantize_abs_max'
]
_fake_dequant_op_list = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
_out_scale_op_list = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
"dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
]
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
......@@ -47,7 +64,8 @@ class QuantizationTransformPass(object):
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000,
moving_rate=0.9):
moving_rate=0.9,
skip_pattern='skip_quant'):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
......@@ -92,6 +110,7 @@ class QuantizationTransformPass(object):
self._place = place
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._skip_pattern = skip_pattern
quant_type = [
'abs_max', 'channel_wise_abs_max', 'range_abs_max',
......@@ -114,7 +133,7 @@ class QuantizationTransformPass(object):
self._window_size = window_size
self._moving_rate = moving_rate
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._quantizable_ops = _quantizable_op_list
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
......@@ -138,6 +157,16 @@ class QuantizationTransformPass(object):
dequantized_vars = collections.OrderedDict()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
def _quant_preprocess(op_node):
    # Mark ops that should not be quantized: average pooling (only
    # max pooling is supported) and ops whose name scope matches the
    # user-specified skip pattern.
    pool_skipped = op_node.op().has_attr("pooling_type") and \
        op_node.op().attr("pooling_type") == 'avg'
    user_skipped = isinstance(self._skip_pattern, str) and \
        op_node.op().has_attr("op_namescope") and \
        op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
    if pool_skipped or user_skipped:
        op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
......@@ -188,14 +217,28 @@ class QuantizationTransformPass(object):
if not self._is_test:
self._create_global_step(graph)
ops = graph.all_op_nodes()
# Preprocess for quantization, e.g. mark the ops that should be
# skipped and left unquantized.
for op in ops:
if op.name() in self._quantizable_ops or \
op.name() in self._quantizable_grad_ops:
_quant_preprocess(op)
# _transform_forward and _transform_backward must run in two separate loops.
# The first loop transforms the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
skipped = op.op().has_attr("skip_quant") and \
op.op().attr("skip_quant")
if skipped:
continue
_transform_forward(graph, op)
# The second loop renames the inputs of the backward ops.
for op in ops:
if op.name() in self._quantizable_grad_ops:
skipped = op.op().has_attr("skip_quant") and \
op.op().attr("skip_quant")
if skipped:
continue
_transform_backward(graph, op)
graph.resolve_hazard()
return graph
......@@ -571,16 +614,10 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._quantizable_ops = _quantizable_op_list
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
......@@ -635,6 +672,10 @@ class QuantizationFreezePass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant")
if skipped:
continue
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
......@@ -727,6 +768,13 @@ class QuantizationFreezePass(object):
def _insert_post_dequant_op(self, graph, op_node):
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
raise ValueError("The op %s has more than one inputs "
"and all of them are not persistable. "
"Now, it is not supported!" % (op_node.name()))
max_range = 1
param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1
for var_node in op_node.inputs:
name = var_node.name()
if name not in op_node.input_arg_names():
......@@ -739,13 +787,12 @@ class QuantizationFreezePass(object):
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
if original_var_name in persistable_vars:
param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1
assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % (
original_var_name)
max_range = param_range * act_range / scale_v
max_range *= param_range / scale_v
else:
max_range *= act_range
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
......@@ -850,7 +897,7 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.'
self._scope = scope
self._place = place
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._quantizable_ops = _quantizable_op_list
def apply(self, graph):
"""
......@@ -866,6 +913,10 @@ class ConvertToInt8Pass(object):
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
skipped = op_node.op().has_attr("skip_quant") and \
op_node.op().attr("skip_quant")
if skipped:
continue
for var_node in op_node.inputs:
name = var_node.name()
if name in persistable_vars:
......@@ -924,14 +975,8 @@ class TransformForMobilePass(object):
"""
def __init__(self):
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_moving_average_abs_max',
'fake_channel_wise_quantize_abs_max'
]
self._fake_dequant_op_names = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
def apply(self, graph):
"""
......@@ -980,12 +1025,7 @@ class ScaleForTrainingPass(object):
self._place = place
self._moving_rate = moving_rate
self._is_test = None
self._teller_set = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
"conv2d_transpose", "leaky_relu"
]
self._teller_set = _out_scale_op_list
def apply(self, graph):
"""
......@@ -1087,12 +1127,7 @@ class ScaleForInferencePass(object):
scope(fluid.Scope): The scope is used to initialize these new parameters.
"""
self._scope = scope
self._teller_set = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
"depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
"elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
"conv2d_transpose", "leaky_relu"
]
self._teller_set = _out_scale_op_list
def apply(self, graph):
"""
......@@ -1135,7 +1170,7 @@ class AddQuantDequantPass(object):
self._moving_rate = moving_rate
self._quant_bits = quant_bits
self._is_test = None
self._target_ops = ["elementwise_add", "pool2d"]
self._target_ops = ["elementwise_add"]
def apply(self, graph):
"""
......
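The `_insert_post_dequant_op` hunk above replaces the single `param_range * act_range / scale` formula with a per-input accumulation of `max_range`. A small self-contained sketch of that accumulation, using illustrative helper and variable names rather than code from the pass:

```python
# Integer ranges for 8-bit weights and activations, as in the pass.
weight_bits = 8
activation_bits = 8
param_range = (1 << (weight_bits - 1)) - 1    # 127
act_range = (1 << (activation_bits - 1)) - 1  # 127

def accumulate_max_range(input_scales):
    """input_scales: list of (is_persistable, scale) pairs for one op's inputs."""
    max_range = 1
    for is_persistable, scale in input_scales:
        if is_persistable:
            # Weight input: fold its float scale into max_range.
            max_range *= param_range / scale
        else:
            # Activation input: its scale lives in a scale var node,
            # so only the integer range is folded in here.
            max_range *= act_range
    return max_range

# conv2d/mul with one weight input (scale 0.5) and one activation input:
# 127 / 0.5 * 127 equals the old param_range * act_range / 0.5.
print(accumulate_max_range([(True, 0.5), (False, None)]))
```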
......@@ -123,7 +123,7 @@ class TestGraph(unittest.TestCase):
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
backup_graph.draw('./origin', 'backup', backup_marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
......
......@@ -72,13 +72,14 @@ def residual_block(num):
return loss
def conv_net(img, label):
def conv_net(img, label, quant_skip_pattern):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
pool_type='max',
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
......@@ -87,8 +88,11 @@ def conv_net(img, label):
num_filters=50,
pool_size=2,
pool_stride=2,
pool_type='avg',
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
with fluid.name_scope(quant_skip_pattern):
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
......@@ -107,7 +111,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
'mul_grad': ['X', 'Y']
}
def check_program(self, transform_pass, program):
def check_program(self, program):
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
......@@ -127,7 +131,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
arg_name.endswith('.quantized.dequantized'))
self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, activation_quant_type, for_ci=False):
def linear_fc_quant(self, activation_quant_type, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -135,7 +139,6 @@ class TestQuantizationTransformPass(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
......@@ -150,7 +153,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
graph.draw('.', 'quantize_fc_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
self.check_program(program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not for_ci:
val_marked_nodes = set()
......@@ -169,7 +172,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
def test_linear_fc_quant_moving_average_abs_max(self):
self.linear_fc_quant('moving_average_abs_max', for_ci=True)
def residual_block_quant(self, activation_quant_type, for_ci=False):
def residual_block_quant(self, activation_quant_type, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -177,7 +180,6 @@ class TestQuantizationTransformPass(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
graph = IrGraph(core.Graph(main.desc), for_test=False)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(),
......@@ -192,7 +194,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
graph.draw('.', 'quantize_residual_' + activation_quant_type,
marked_nodes)
program = graph.to_program()
self.check_program(transform_pass, program)
self.check_program(program)
val_graph = IrGraph(core.Graph(program.desc), for_test=False)
if not for_ci:
val_marked_nodes = set()
......@@ -218,7 +220,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
seed,
activation_quant_type,
weight_quant_type='abs_max',
for_ci=False):
for_ci=True,
quant_skip_pattern='skip_quant'):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
......@@ -228,7 +231,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
loss = conv_net(img, label, quant_skip_pattern)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
......@@ -255,7 +258,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
scope=scope,
place=place,
activation_quantize_type=activation_quant_type,
weight_quantize_type=weight_quant_type)
weight_quantize_type=weight_quant_type,
skip_pattern=quant_skip_pattern)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
......
......@@ -2728,6 +2728,8 @@ class IrGraph(object):
if self.graph.has('__graphviz__marked_node__'):
self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
if not os.path.exists(save_path):
os.makedirs(save_path)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
viz_pass.set('graph_viz_path', viz_dot_path)
......
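With the `os.makedirs` call added above, `IrGraph.draw` no longer requires the output directory to exist beforehand. A small sketch mirroring the updated test; the toy program below is an illustrative assumption:

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

prog = fluid.Program()
with fluid.program_guard(prog):
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(input=img, num_filters=8, filter_size=3)

graph = IrGraph(core.Graph(prog.desc), for_test=True)
marked_nodes = set()
for op in graph.all_op_nodes():
    if op.name().find('conv2d') > -1:
        marked_nodes.add(op)
# './origin' is created on demand before the .dot file is written.
graph.draw('./origin', 'backup', marked_nodes)
```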