Unverified · Commit 4032c2e4 authored by cc, committed by GitHub

Refine QuantizeTranspilerV2 to support distributed training (#33781)

* Refine the old code

* Support moving_average_abs_max and channel_wise_abs_max quantization

* Add the moving_average_abs_max_scale op

* Convert the test program
Parent commit: d5bdfaf1
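For orientation, here is a minimal end-to-end sketch of how the refined transpiler is meant to be used. It is not taken from the patch: it assumes the fluid 1.x static-graph API and that QuantizeTranspilerV2 is importable from the quantize_transpiler_v2 module implied by the relative imports below; the toy network, batch size, and output path are illustrative only.

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 \
        import QuantizeTranspilerV2

    paddle.enable_static()  # needed on Paddle 2.x

    def network():
        img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        conv = fluid.layers.conv2d(img, num_filters=6, filter_size=3, act='relu')
        pool = fluid.layers.pool2d(conv, pool_size=2)
        pred = fluid.layers.fc(pool, size=10, act='softmax')
        loss = fluid.layers.mean(fluid.layers.cross_entropy(pred, label))
        return img, label, loss

    train_prog, startup_prog, test_prog = fluid.Program(), fluid.Program(), fluid.Program()
    with fluid.program_guard(train_prog, startup_prog):
        with fluid.unique_name.guard():
            img, label, loss = network()
            fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)
    with fluid.program_guard(test_prog, startup_prog):
        with fluid.unique_name.guard():
            network()

    # Rewrite both programs: the train program gets the full fake-quant state,
    # the test program only the inference-time fake-quant/scale ops.
    qt = QuantizeTranspilerV2(
        activation_quantize_type='moving_average_abs_max',
        weight_quantize_type='channel_wise_abs_max')
    qt.apply(train_prog, startup_prog, is_test=False)
    qt.apply(test_prog, startup_prog, is_test=True)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(startup_prog)
    feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
    reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=8)
    for _, data in zip(range(5), reader()):
        exe.run(train_prog, feed=feeder.feed(data), fetch_list=[loss])

    # Fold the collected output scales into the quantized ops and save for inference.
    qt.convert(test_prog, fluid.global_scope())
    fluid.io.save_inference_model('./infer_model', ['image', 'label'], [loss],
                                  exe, test_prog)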
@@ -17,6 +17,7 @@ import logging
import numpy as np
from .... import core
from ....framework import Program, Operator, Variable, program_guard
from ....executor import global_scope
from .... import unique_name
from ....layer_helper import LayerHelper
from ....param_attr import ParamAttr
@@ -27,26 +28,49 @@ _logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
def find_next_ops(block, var_name):
    """
    Find all the ops that take the given variable as an input.
    """
    res_ops = []
    for op in block.ops:
        if var_name in op.input_arg_names:
            res_ops.append(op)
    return res_ops


def load_variable_data(scope, var_name):
    '''
    Load the value of a variable from the scope as a numpy array.
    '''
    var_node = scope.find_var(var_name)
    assert var_node is not None, \
        "Cannot find " + var_name + " in scope."
    return np.array(var_node.get_tensor())
class QuantizeTranspilerV2(object):
    def __init__(self,
                 weight_bits=8,
                 activation_bits=8,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='moving_average_abs_max',
                 quantizable_op_type=[
                     'conv2d',
                     'depthwise_conv2d',
                     'mul',
                 ],
                 skip_pattern=['skip_quant']):
""" """
Add quant_dequant op before the quantized op to quantize the fluid Program. Apply fake quant for the quantized ops.
It is a patch for distributed quantization, we will support others module for
distributed quantization.
Args: Args:
weight_bits(int): the bit of quantized weight. weight_bits(int): the bit of quantized weight.
activation_bits(int): the bit of quantized activation. activation_bits(int): the bit of quantized activation.
weight_quantize_type(str): the quantization type for weight. weight_quantize_type(str): the quantization type for weight.
Only support to be 'abs_max' for now. Only support to be 'abs_max' and 'channel_wise_abs_max'.
activation_quantize_type(str): the quantization type for activation. activation_quantize_type(str): the quantization type for activation.
Only support to be 'abs_max' for now. Only support to be 'abs_max' and 'moving_average_abs_max'.
quantizable_op_type(str): set the op type for quantization. quantizable_op_type(str): set the op type for quantization.
skip_pattern(str|list): The user-defined quantization skip pattern, which skip_pattern(str|list): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is will be presented in the name scope of an op. When the skip pattern is
...@@ -55,28 +79,37 @@ class QuantizeTranspilerV2(object): ...@@ -55,28 +79,37 @@ class QuantizeTranspilerV2(object):
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits

        assert activation_quantize_type in \
            ["abs_max", "moving_average_abs_max"], \
            "activation_quantize_type should be abs_max " \
            "or moving_average_abs_max for now."
        assert weight_quantize_type in ["abs_max", "channel_wise_abs_max"], \
            "weight_quantize_type should be abs_max or channel_wise_abs_max."
        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type

        for op_type in quantizable_op_type:
            assert op_type in ['conv2d', 'depthwise_conv2d', 'mul'], \
                "Quantize op should be ['conv2d', 'depthwise_conv2d', 'mul']"
        self._quantizable_ops = quantizable_op_type
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._skip_pattern = skip_pattern
        self._helper = LayerHelper(self.__class__.__name__)

        self._moving_rate = 0.9
        self._out_ch_axis1_ops = ['conv2d_transpose', 'mul', 'matmul']

    def apply(self, program, startup_program, is_test=False):
""" """
Apply quantization to fluid Program. Apply quantization to fluid Program.
Args: Args:
program(Program): the train or test program to be quantized. program(Program): the train or test program to be quantized.
startup_program(Program): the corresponding startup_program. startup_program(Program): the corresponding startup_program.
is_test(bool): Whethe the program is used for test.
Returns: Returns:
None None
""" """
@@ -85,7 +118,7 @@ class QuantizeTranspilerV2(object):
        assert isinstance(startup_program, Program), \
            "startup_program must be the instance of Program"

        var_rename_map = [
            collections.OrderedDict() for _ in range(len(program.blocks))
        ]
        with program_guard(program, startup_program):
@@ -94,13 +127,104 @@
                for op in ops:
                    if op.type in self._quantizable_ops and \
                        (not self._is_skip_quant(op)):
                        self._transform_forward(block, op, var_rename_map,
                                                is_test)

            for block in program.blocks:
                ops = list(block.ops)
                for op in ops:
                    if op.type in self._quantizable_grad_ops and \
                        (not self._is_skip_quant(op)):
                        self._transform_backward(block, op, var_rename_map)
    def convert(self, test_program, scope=None):
        """
        Convert the test program.
        Get the out scale from the moving_average_abs_max_scale op and save the
        out scale into the quantized op.

        Args:
            test_program(Program): the test program to be converted.
            scope(fluid.Scope, optional): The scope of the program, use it to load
                and save variables. If scope=None, get scope by global_scope().
        """
        scope = global_scope() if scope == None else scope

        for block in test_program.blocks:
            for op in block.ops:
                if op.has_attr("quantization_type") \
                    and op.attr("quantization_type") == "qat_with_weight":
                    # quant op -> var1 -> fake op -> var2
                    assert len(op.output_arg_names) == 1
                    var1_name = op.output_arg_names[0]

                    fake_ops = find_next_ops(block, var1_name)
                    assert len(fake_ops) == 1
                    fake_op = fake_ops[0]
                    assert fake_op.type == "moving_average_abs_max_scale"

                    out_scale_name = fake_op.output("OutScale")
                    out_threshold = load_variable_data(scope, out_scale_name[0])
                    op._set_attr("out_threshold", float(out_threshold))

                    var2_name = fake_op.output("Out")[0]
                    op._rename_output(var1_name, var2_name)
                    fake_op._rename_output(var2_name, var1_name)
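    # Illustration of the rename swap above, with hypothetical op/variable names
    # (added for orientation, not taken from the patch):
    #   after apply():   conv2d -> conv_out -> moving_average_abs_max_scale -> conv_out.tmp -> next ops
    #   after convert(): conv2d -> conv_out.tmp -> next ops    (conv2d now carries "out_threshold")
    #                    moving_average_abs_max_scale -> conv_out    (no consumers left)
    # Downstream ops keep reading conv_out.tmp, so the scale op drops out of the
    # inference path and can be pruned afterwards.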
    def _transform_forward(self, block, op, var_rename_map, is_test):
        """
        Insert fake quant ops before the target op.
        """
        op._set_attr("quantization_type", "qat_with_weight")

        # insert fake quant op before the quantized op
        for in_name in op.input_arg_names:
            block_id = block.idx
            idx = block.ops.index(op)

            if in_name in var_rename_map[block_id]:
                new_in_name = var_rename_map[block_id][in_name]
            else:
                in_var = block.var(in_name)
                if in_var.dtype != core.VarDesc.VarType.FP32:
                    continue

                quant_bits = self._weight_bits if in_var.persistable \
                    else self._activation_bits
                quant_type = self._weight_quantize_type if in_var.persistable \
                    else self._activation_quantize_type

                if quant_type == "abs_max":
                    new_var = self._insert_abs_max_fq_op(block, idx, in_var,
                                                         quant_bits)
                elif quant_type == "moving_average_abs_max":
                    new_var = self._insert_ma_abs_max_fq_op(block, idx, in_var,
                                                            quant_bits, is_test)
                elif quant_type == "channel_wise_abs_max":
                    ch_axis = 1 if op.type in self._out_ch_axis1_ops else 0
                    new_var = self._insert_pc_abs_max_fq_op(block, idx, in_var,
                                                            quant_bits, ch_axis)
                else:
                    _logger.error("Don't support the quant_type: %s" %
                                  quant_type)
                    continue

                new_in_name = new_var.name
                var_rename_map[block_id][in_name] = new_in_name

            op._rename_input(in_name, new_in_name)

        # insert the out scale op after the quantized op
        for out_name in op.output_arg_names:
            next_ops = find_next_ops(block, out_name)

            idx = block.ops.index(op)
            out_var = block.var(out_name)
            new_out_var = self._insert_ma_abs_max_scale_op(
                block, idx + 1, out_var, is_test, True)

            for next_op in next_ops:
                if "_grad" not in next_op.type:
                    next_op._rename_input(out_name, new_out_var.name)
    def _is_skip_quant(self, op):
        """
@@ -117,49 +241,35 @@ class QuantizeTranspilerV2(object):
                    self._skip_pattern) != -1
        return user_skipped
    def _transform_backward(self, block, op, var_rename_map):
        """
        Update the backward ops of the target ops.
        Note: for the grad ops, only rename the inputs; skip renaming the outputs.
        """
        block_id = block.idx
        no_dequanted_input_vars = True
        for name in op.input_arg_names:
            if name in var_rename_map[block_id]:
                new_var_name = var_rename_map[block_id][name]
                op._rename_input(name, new_var_name)
                no_dequanted_input_vars = False
        if no_dequanted_input_vars:
            raise ValueError("There is no dequanted inputs for op %s." %
                             (op.type))
    def _insert_abs_max_fq_op(self, block, idx, in_var, quant_bits):
        """
        Insert the abs_max fake quant op.
        """
        quant_dequant_var = block.create_var(
            type=in_var.type,
            name="{}.quant_dequant".format(in_var.name),
            shape=in_var.shape,
            dtype=in_var.dtype)
        scale_var = self._helper.create_parameter(
            attr=ParamAttr(
                name="{}.quant_dequant.scale".format(in_var.name),
                initializer=Constant(0.),
                trainable=False),
            shape=[1],
            dtype=in_var.dtype)
@@ -175,3 +285,157 @@ class QuantizeTranspilerV2(object):
            inputs=inputs,
            outputs=outputs)
        return quant_dequant_var
    def _insert_ma_abs_max_fq_op(self, block, idx, in_var, quant_bits, is_test):
        """
        Insert moving average abs max fake quant op.
        """
        quant_dequant_var = block.create_var(
            type=in_var.type,
            name="{}.quant_dequant".format(in_var.name),
            shape=in_var.shape,
            dtype=in_var.dtype)

        scale_var = self._helper.create_parameter(
            attr=ParamAttr(
                name="{}.quant_dequant.scale".format(in_var.name),
                initializer=Constant(0.),
                trainable=False),
            shape=[1],
            dtype=in_var.dtype)
        scale_var.stop_gradient = True

        if not is_test:
            state_var = self._helper.create_parameter(
                attr=ParamAttr(
                    name="{}.quant_dequant.state".format(in_var.name),
                    initializer=Constant(0),
                    trainable=False),
                shape=[1],
                dtype=in_var.dtype)
            state_var.stop_gradient = True

            accum_var = self._helper.create_parameter(
                attr=ParamAttr(
                    name="{}.quant_dequant.accum".format(in_var.name),
                    initializer=Constant(0),
                    trainable=False),
                shape=[1],
                dtype=in_var.dtype)
            accum_var.stop_gradient = True

        attrs = {
            'moving_rate': self._moving_rate,
            'bit_length': quant_bits,
            'is_test': is_test
        }
        inputs = {'X': in_var, 'InScale': scale_var}
        outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
        if not is_test:
            inputs['InState'] = state_var
            inputs['InAccum'] = accum_var
            outputs['OutState'] = state_var
            outputs['OutAccum'] = accum_var

        block._insert_op(
            idx,
            type='fake_quantize_dequantize_moving_average_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)
        return quant_dequant_var
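    # For reference, the documented update rule of Paddle's
    # fake_quantize_dequantize_moving_average_abs_max op (paraphrased here; the op
    # itself is implemented elsewhere) during training is roughly:
    #     state <- moving_rate * state + 1
    #     accum <- moving_rate * accum + max(|x|)
    #     scale <- accum / state
    # which is why the extra state/accum parameters above are created only when
    # is_test is False.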
    def _insert_pc_abs_max_fq_op(self, block, idx, in_var, quant_bits, ch_axis):
        """
        Insert per channel abs max fake quant op.
        """
        quant_dequant_var = block.create_var(
            type=in_var.type,
            name="{}.quant_dequant".format(in_var.name),
            shape=in_var.shape,
            dtype=in_var.dtype)

        scale_var = self._helper.create_parameter(
            attr=ParamAttr(
                name="{}.quant_dequant.scale".format(in_var.name),
                initializer=Constant(0.),
                trainable=False),
            shape=[in_var.shape[ch_axis]],
            dtype=in_var.dtype)
        scale_var.stop_gradient = True

        inputs = {'X': in_var}
        outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
        attrs = {'bit_length': quant_bits, 'quant_axis': ch_axis}
        block._insert_op(
            idx,
            type='fake_channel_wise_quantize_dequantize_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)
        return quant_dequant_var
    def _insert_ma_abs_max_scale_op(self,
                                    block,
                                    idx,
                                    in_var,
                                    is_test,
                                    has_out_var=False):
        """
        Insert moving average abs max scale op.
        """
        scale_var = self._helper.create_parameter(
            attr=ParamAttr(
                name="{}.outscale.scale".format(in_var.name),
                initializer=Constant(0.),
                trainable=False),
            shape=[1],
            dtype=in_var.dtype)
        scale_var.stop_gradient = True

        attrs = {'moving_rate': self._moving_rate, 'is_test': is_test}
        inputs = {'X': in_var}
        outputs = {'OutScale': scale_var}

        if not is_test:
            state_var = self._helper.create_parameter(
                attr=ParamAttr(
                    name="{}.outscale.state".format(in_var.name),
                    initializer=Constant(0),
                    trainable=False),
                shape=[1],
                dtype=in_var.dtype)
            state_var.stop_gradient = True

            accum_var = self._helper.create_parameter(
                attr=ParamAttr(
                    name="{}.outscale.accum".format(in_var.name),
                    initializer=Constant(0),
                    trainable=False),
                shape=[1],
                dtype=in_var.dtype)
            accum_var.stop_gradient = True

            inputs['InState'] = state_var
            inputs['InAccum'] = accum_var
            outputs['OutState'] = state_var
            outputs['OutAccum'] = accum_var

        if has_out_var:
            out_var = block.create_var(
                type=in_var.type,
                name="{}.tmp".format(in_var.name),
                shape=in_var.shape,
                dtype=in_var.dtype)

            outputs['Out'] = out_var

        block._insert_op(
            idx,
            type='moving_average_abs_max_scale',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)

        if has_out_var:
            return out_var
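Taken together, a rough picture of what apply() does around a single quantized op (op and variable names are illustrative, assuming activation_quantize_type='moving_average_abs_max' and weight_quantize_type='channel_wise_abs_max'):

    # before:  image -> conv2d(W) -> conv_out -> relu -> ...
    # after:   image -> fake_quantize_dequantize_moving_average_abs_max -> image.quant_dequant
    #          W     -> fake_channel_wise_quantize_dequantize_abs_max   -> W.quant_dequant
    #          conv2d(image.quant_dequant, W.quant_dequant) -> conv_out
    #          conv_out -> moving_average_abs_max_scale -> conv_out.tmp -> relu -> ...
    # The grad ops are only re-wired by _transform_backward to read the
    # *.quant_dequant inputs; their outputs keep the original names.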
@@ -79,6 +79,7 @@ class TestQuantizeProgramPass(unittest.TestCase):
        random.seed(0)
        np.random.seed(0)

        # 1 Define program
        train_program = fluid.Program()
        startup_program = fluid.Program()
        test_program = fluid.Program()
@@ -93,15 +94,14 @@ class TestQuantizeProgramPass(unittest.TestCase):
            test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
            test_graph.draw('.', 'test_program_1')
        # 2 Apply quantization
        qt = QuantizeTranspilerV2(
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type)
        qt.apply(train_program, startup_program, is_test=False)
        qt.apply(test_program, startup_program, is_test=True)

        # 3 Train
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
@@ -120,28 +120,32 @@ class TestQuantizeProgramPass(unittest.TestCase):
        build_strategy.fuse_all_reduce_ops = False
        binary = fluid.CompiledProgram(train_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        iters = 5
        batch_size = 8

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
        with fluid.scope_guard(scope):
            for idx in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                if not for_ci and idx % 20 == 0:
                    print('{}: {}'.format('loss', np.mean(loss_v)))
            print('{}: {}'.format('loss', np.mean(loss_v)))

        # 4 Convert
        qt.convert(test_program, scope)

        if not for_ci:
            with fluid.scope_guard(scope):
                fluid.io.save_inference_model('./infer_model',
                                              ['image', 'label'], [loss], exe,
                                              test_program)
    def test_gpu_1(self):
        if fluid.core.is_compiled_with_cuda():
            self.quantize_program(
                use_cuda=True,
@@ -150,7 +154,16 @@ class TestQuantizeProgramPass(unittest.TestCase):
                weight_quant_type='abs_max',
                for_ci=True)
    def test_gpu_2(self):
        if fluid.core.is_compiled_with_cuda():
            self.quantize_program(
                use_cuda=True,
                seed=1,
                activation_quant_type='moving_average_abs_max',
                weight_quant_type='channel_wise_abs_max',
                for_ci=True)

    def test_cpu_1(self):
        self.quantize_program(
            use_cuda=False,
            seed=2,
@@ -158,6 +171,14 @@ class TestQuantizeProgramPass(unittest.TestCase):
            weight_quant_type='abs_max',
            for_ci=True)

    def test_cpu_2(self):
        self.quantize_program(
            use_cuda=False,
            seed=2,
            activation_quant_type='moving_average_abs_max',
            weight_quant_type='channel_wise_abs_max',
            for_ci=True)
if __name__ == '__main__':
    unittest.main()