Unverified · Commit 0fe72469 · Authored by: Zhen Wang · Committed by: GitHub

Add max-pool2d quantization support and partial quantization support. (#19310)

* add pool2d quantization support, only for max-pooling.

* add partial quantization support.
Parent d49c2bad
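The partial-quantization support added here is driven by the new skip_pattern argument of QuantizationTransformPass: ops whose op_namescope contains the pattern (and avg-pool pool2d ops) are marked skip_quant and left untouched by the pass. A minimal usage sketch, assuming the pass is imported from paddle.fluid.contrib.slim.quantization as in the tests below; the small network is purely illustrative:

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    main, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(main, startup):
        img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
        hidden = fluid.layers.fc(input=img, size=100, act='relu')
        # Layers built under this name scope carry 'skip_quant' in op_namescope,
        # so the transform pass will not insert fake quant/dequant ops for them.
        with fluid.name_scope('skip_quant'):
            prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')

    graph = IrGraph(core.Graph(main.desc), for_test=False)
    transform_pass = QuantizationTransformPass(
        scope=fluid.global_scope(),
        place=fluid.CPUPlace(),
        activation_quantize_type='moving_average_abs_max',
        weight_quantize_type='abs_max',
        skip_pattern='skip_quant')  # the new argument introduced by this commit
    transform_pass.apply(graph)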
@@ -26,6 +26,23 @@ __all__ = [
     'AddQuantDequantPass'
 ]
 
+_quantizable_op_list = ['conv2d', 'depthwise_conv2d', 'mul', 'pool2d']
+
+_fake_quant_op_list = [
+    'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
+    'fake_quantize_moving_average_abs_max', 'fake_channel_wise_quantize_abs_max'
+]
+
+_fake_dequant_op_list = [
+    'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
+]
+
+_out_scale_op_list = [
+    "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
+    "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
+    "dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
+]
+
 
 def _init_var_node(var_node, value, scope, place):
     assert isinstance(value,
@@ -47,7 +64,8 @@ class QuantizationTransformPass(object):
                  activation_quantize_type='abs_max',
                  weight_quantize_type='abs_max',
                  window_size=10000,
-                 moving_rate=0.9):
+                 moving_rate=0.9,
+                 skip_pattern='skip_quant'):
         """
         Convert and rewrite the IrGraph according to weight and
         activation quantization type.
@@ -92,6 +110,7 @@ class QuantizationTransformPass(object):
         self._place = place
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
+        self._skip_pattern = skip_pattern
 
         quant_type = [
             'abs_max', 'channel_wise_abs_max', 'range_abs_max',
@@ -114,7 +133,7 @@ class QuantizationTransformPass(object):
         self._window_size = window_size
         self._moving_rate = moving_rate
 
-        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._quantizable_ops = _quantizable_op_list
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
@@ -138,6 +157,16 @@ class QuantizationTransformPass(object):
         dequantized_vars = collections.OrderedDict()
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+
+        def _quant_preprocess(op_node):
+            pool_skipped = op_node.op().has_attr("pooling_type") and \
+                op_node.op().attr("pooling_type") == 'avg'
+            user_skipped = isinstance(self._skip_pattern, str) and \
+                op_node.op().has_attr("op_namescope") and \
+                op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
+            if pool_skipped or user_skipped:
+                op_node.op()._set_attr("skip_quant", True)
 
         def _transform_forward(graph, op):
             for var_node in op.inputs:
                 if var_node.name() not in op.input_arg_names():
@@ -188,14 +217,28 @@ class QuantizationTransformPass(object):
         if not self._is_test:
             self._create_global_step(graph)
         ops = graph.all_op_nodes()
+        # Do the preprocessing of quantization, such as marking ops that
+        # should not be quantized.
+        for op in ops:
+            if op.name() in self._quantizable_ops or \
+                    op.name() in self._quantizable_grad_ops:
+                _quant_preprocess(op)
         # The process of _transform_forward and _transform_backward is needed in two for loops.
         # The loop for transforming the forward graph:
         for op in ops:
             if op.name() in self._quantizable_ops:
+                skipped = op.op().has_attr("skip_quant") and \
+                    op.op().attr("skip_quant")
+                if skipped:
+                    continue
                 _transform_forward(graph, op)
         # The loop for renaming the inputs of backward op.
         for op in ops:
             if op.name() in self._quantizable_grad_ops:
+                skipped = op.op().has_attr("skip_quant") and \
+                    op.op().attr("skip_quant")
+                if skipped:
+                    continue
                 _transform_backward(graph, op)
         graph.resolve_hazard()
         return graph
@@ -571,16 +614,10 @@ class QuantizationFreezePass(object):
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
         self._weight_quantize_type = weight_quantize_type
-        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._quantizable_ops = _quantizable_op_list
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
-        self._fake_quant_op_names = [
-            'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
-            'fake_quantize_moving_average_abs_max',
-            'fake_channel_wise_quantize_abs_max'
-        ]
-        self._fake_dequant_op_names = [
-            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
-        ]
+        self._fake_quant_op_names = _fake_quant_op_list
+        self._fake_dequant_op_names = _fake_dequant_op_list
         self._op_input_rename_map = collections.OrderedDict()
         self._op_output_rename_map = collections.OrderedDict()
         self._var_scale_map = collections.OrderedDict()
@@ -635,6 +672,10 @@ class QuantizationFreezePass(object):
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
+                skipped = op_node.op().has_attr("skip_quant") and \
+                    op_node.op().attr("skip_quant")
+                if skipped:
+                    continue
                 if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
                     self._insert_post_channel_dequant_op(graph, op_node)
                 else:
@@ -727,6 +768,13 @@ class QuantizationFreezePass(object):
 
     def _insert_post_dequant_op(self, graph, op_node):
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        if len(op_node.input_arg_names()) >= 2 and len(persistable_vars) == 0:
+            raise ValueError("The op %s has more than one inputs "
+                             "and all of them are not persistable. "
+                             "Now, it is not supported!" % (op_node.name()))
+        max_range = 1
+        param_range = (1 << (self._weight_bits - 1)) - 1
+        act_range = (1 << (self._activation_bits - 1)) - 1
         for var_node in op_node.inputs:
             name = var_node.name()
             if name not in op_node.input_arg_names():
@@ -739,13 +787,12 @@ class QuantizationFreezePass(object):
             original_var_name = self._original_var_name(name)
             scale_v = self._var_scale_map[original_var_name]
             if original_var_name in persistable_vars:
-                param_range = (1 << (self._weight_bits - 1)) - 1
-                act_range = (1 << (self._activation_bits - 1)) - 1
                 assert self._is_float(
                     scale_v), 'The scale of parameter %s is not a float.' % (
                         original_var_name)
-                max_range = param_range * act_range / scale_v
+                max_range *= param_range / scale_v
             else:
+                max_range *= act_range
                 assert isinstance(scale_v, IrNode)
                 scale_var_node = self._var_scale_map[original_var_name]
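The hunk above changes how the dequantization range is computed in _insert_post_dequant_op: instead of assuming exactly one persistable weight input, max_range now starts at 1 and is accumulated over all quantized inputs. A minimal sketch of that accumulation (standalone Python with illustrative 8-bit settings and made-up scale values, not the pass itself):

    # Sketch of the max_range accumulation introduced above (values are illustrative).
    weight_bits, activation_bits = 8, 8
    param_range = (1 << (weight_bits - 1)) - 1    # 127 for 8-bit weights
    act_range = (1 << (activation_bits - 1)) - 1  # 127 for 8-bit activations

    # Hypothetical quantized inputs of one op: a persistable weight with a float
    # scale, and a non-persistable activation whose scale lives in a scale var node.
    inputs = [
        {"persistable": True, "scale": 0.05},
        {"persistable": False},
    ]

    max_range = 1
    for var in inputs:
        if var["persistable"]:
            max_range *= param_range / var["scale"]  # weight contributes param_range / scale
        else:
            max_range *= act_range                   # activation contributes act_range
    # max_range is then passed to the inserted fake_dequantize_max_abs op.
    print(max_range)  # 127 / 0.05 * 127 = 322580.0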
@@ -850,7 +897,7 @@ class ConvertToInt8Pass(object):
             'The place cannot be set None.'
         self._scope = scope
         self._place = place
-        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._quantizable_ops = _quantizable_op_list
 
     def apply(self, graph):
         """
@@ -866,6 +913,10 @@ class ConvertToInt8Pass(object):
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
+                skipped = op_node.op().has_attr("skip_quant") and \
+                    op_node.op().attr("skip_quant")
+                if skipped:
+                    continue
                 for var_node in op_node.inputs:
                     name = var_node.name()
                     if name in persistable_vars:
@@ -924,14 +975,8 @@ class TransformForMobilePass(object):
     """
 
     def __init__(self):
-        self._fake_quant_op_names = [
-            'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
-            'fake_quantize_moving_average_abs_max',
-            'fake_channel_wise_quantize_abs_max'
-        ]
-        self._fake_dequant_op_names = [
-            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
-        ]
+        self._fake_quant_op_names = _fake_quant_op_list
+        self._fake_dequant_op_names = _fake_dequant_op_list
 
     def apply(self, graph):
         """
@@ -980,12 +1025,7 @@ class ScaleForTrainingPass(object):
         self._place = place
         self._moving_rate = moving_rate
         self._is_test = None
-        self._teller_set = [
-            "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
-            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-            "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
-            "conv2d_transpose", "leaky_relu"
-        ]
+        self._teller_set = _out_scale_op_list
 
     def apply(self, graph):
         """
@@ -1087,12 +1127,7 @@ class ScaleForInferencePass(object):
             scope(fluid.Scope): The scope is used to initialize these new parameters.
         """
         self._scope = scope
-        self._teller_set = [
-            "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
-            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
-            "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
-            "conv2d_transpose", "leaky_relu"
-        ]
+        self._teller_set = _out_scale_op_list
 
     def apply(self, graph):
         """
@@ -1135,7 +1170,7 @@ class AddQuantDequantPass(object):
         self._moving_rate = moving_rate
         self._quant_bits = quant_bits
         self._is_test = None
-        self._target_ops = ["elementwise_add", "pool2d"]
+        self._target_ops = ["elementwise_add"]
 
     def apply(self, graph):
         """
......
@@ -123,7 +123,7 @@ class TestGraph(unittest.TestCase):
             for op in backup_graph.all_op_nodes():
                 if op.name().find('conv2d') > -1:
                     backup_marked_nodes.add(op)
-            backup_graph.draw('.', 'backup', backup_marked_nodes)
+            backup_graph.draw('./origin', 'backup', backup_marked_nodes)
         self.assertFalse(graph.has_circle())
         self.assertEqual(graph.graph_num(), 1)
         nodes = graph.topology_sort()
......
@@ -72,13 +72,14 @@ def residual_block(num):
     return loss
 
 
-def conv_net(img, label):
+def conv_net(img, label, quant_skip_pattern):
     conv_pool_1 = fluid.nets.simple_img_conv_pool(
         input=img,
         filter_size=5,
         num_filters=20,
         pool_size=2,
         pool_stride=2,
+        pool_type='max',
         act="relu")
     conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
     conv_pool_2 = fluid.nets.simple_img_conv_pool(
@@ -87,8 +88,11 @@ def conv_net(img, label):
         num_filters=50,
         pool_size=2,
         pool_stride=2,
+        pool_type='avg',
         act="relu")
-    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
+    hidden = fluid.layers.fc(input=conv_pool_2, size=100, act='relu')
+    with fluid.name_scope(quant_skip_pattern):
+        prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
     loss = fluid.layers.cross_entropy(input=prediction, label=label)
     avg_loss = fluid.layers.mean(loss)
     return avg_loss
@@ -107,7 +111,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             'mul_grad': ['X', 'Y']
         }
 
-    def check_program(self, transform_pass, program):
+    def check_program(self, program):
         quantized_ops = set()
         for block in program.blocks:
             for op in block.ops:
@@ -127,7 +131,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
                             arg_name.endswith('.quantized.dequantized'))
                         self.assertTrue(arg_name in quantized_ops)
 
-    def linear_fc_quant(self, activation_quant_type, for_ci=False):
+    def linear_fc_quant(self, activation_quant_type, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -135,7 +139,6 @@ class TestQuantizationTransformPass(unittest.TestCase):
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
@@ -150,7 +153,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             graph.draw('.', 'quantize_fc_' + activation_quant_type,
                        marked_nodes)
         program = graph.to_program()
-        self.check_program(transform_pass, program)
+        self.check_program(program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         if not for_ci:
             val_marked_nodes = set()
@@ -169,7 +172,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
     def test_linear_fc_quant_moving_average_abs_max(self):
         self.linear_fc_quant('moving_average_abs_max', for_ci=True)
 
-    def residual_block_quant(self, activation_quant_type, for_ci=False):
+    def residual_block_quant(self, activation_quant_type, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
@@ -177,7 +180,6 @@ class TestQuantizationTransformPass(unittest.TestCase):
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
         place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
         transform_pass = QuantizationTransformPass(
             scope=fluid.global_scope(),
@@ -192,7 +194,7 @@ class TestQuantizationTransformPass(unittest.TestCase):
             graph.draw('.', 'quantize_residual_' + activation_quant_type,
                        marked_nodes)
         program = graph.to_program()
-        self.check_program(transform_pass, program)
+        self.check_program(program)
         val_graph = IrGraph(core.Graph(program.desc), for_test=False)
         if not for_ci:
             val_marked_nodes = set()
@@ -218,7 +220,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
                      seed,
                      activation_quant_type,
                      weight_quant_type='abs_max',
-                     for_ci=False):
+                     for_ci=True,
+                     quant_skip_pattern='skip_quant'):
         def build_program(main, startup, is_test):
             main.random_seed = seed
             startup.random_seed = seed
@@ -228,7 +231,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                     name='image', shape=[1, 28, 28], dtype='float32')
                 label = fluid.layers.data(
                     name='label', shape=[1], dtype='int64')
-                loss = conv_net(img, label)
+                loss = conv_net(img, label, quant_skip_pattern)
                 if not is_test:
                     opt = fluid.optimizer.Adam(learning_rate=0.001)
                     opt.minimize(loss)
@@ -255,7 +258,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
             scope=scope,
             place=place,
             activation_quantize_type=activation_quant_type,
-            weight_quantize_type=weight_quant_type)
+            weight_quantize_type=weight_quant_type,
+            skip_pattern=quant_skip_pattern)
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
         dev_name = '_gpu_' if use_cuda else '_cpu_'
......
@@ -2728,6 +2728,8 @@ class IrGraph(object):
             if self.graph.has('__graphviz__marked_node__'):
                 self.graph.erase('__graphviz__marked_node__')
             self.graph.set('__graphviz__marked_node__', marked_nodes)
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
         viz_dot_path = os.path.join(save_path, name) + '.dot'
         viz_pass = core.get_pass('graph_viz_pass')
         viz_pass.set('graph_viz_path', viz_dot_path)
......
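The framework.py hunk above makes IrGraph.draw create the output directory when it does not already exist, which is what lets the test above pass './origin' directly. A small illustrative sketch (graph construction omitted; directory and file names are arbitrary):

    # Assuming `graph` is an IrGraph, e.g. graph = IrGraph(core.Graph(main.desc), for_test=False).
    marked_nodes = set()
    for op in graph.all_op_nodes():
        if op.name().find('conv2d') > -1:  # highlight conv ops in the rendered dot file
            marked_nodes.add(op)
    # './origin' no longer has to exist beforehand; draw() now creates it.
    graph.draw('./origin', 'quantized_graph', marked_nodes)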