Unverified commit a77e2d68, authored by H handiz, committed by GitHub

Add new function: run PTQ first, then initialize the QAT scales with the PTQ scales (#1394)

* Add new function: run PTQ first, then initialize the QAT scales with the PTQ scales
Parent 58797143
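In short, this commit lets quant_aware run post-training quantization (PTQ) first and use the calibrated scales to initialize quantization-aware training (QAT). A minimal usage sketch, condensed from the unit test added in this commit; main_prog, val_prog and calib_loader are assumed to be an already-built static-graph train program, its test-mode clone, and a calibration DataLoader:

import paddle
from paddleslim.quant import quant_aware, convert

place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()

config = {
    'weight_quantize_type': 'channel_wise_abs_max',
    'activation_quantize_type': 'moving_average_abs_max',
    'quant_post_first': True,  # new: calibrate with PTQ before QAT
    'scale_trainable': True,  # new: let QAT fine-tune the PTQ scales
}
# The unit test below additionally passes 'feed_list'/'fetch_list' here.
calib_config = {'data_loader': calib_loader, 'algo': 'abs_max'}

# 1) PTQ runs on the test program and returns the calibrated scale_dict ...
quant_eval_prog, scale_dict, model_type, pattern_ops = quant_aware(
    val_prog, place, config, for_test=True,
    calib_config=calib_config, return_scale_dict=True)

# 2) ... which seeds the quantization scales of the train program.
quant_train_prog = quant_aware(
    main_prog, place, config, for_test=False,
    calib_config=calib_config, scale_dict=scale_dict,
    model_type=model_type, pattern_ops=pattern_ops,
    return_program=True)

# After QAT, fold the fake-quant ops for inference as usual.
quant_eval_prog = convert(quant_eval_prog, place, config)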
......@@ -53,7 +53,6 @@ def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
for op in pattern_ops:
if op.type() in ['matmul', 'matmul_v2']:
if not is_dynamic_weight_op(op):
skip_quant_tensor_list.extend(op._op.input('X'))
matmul_num += 1
if matmul_num == 2:
return True
......@@ -88,6 +87,8 @@ def get_patterns(program, only_final_node=True):
block_num = 0
model_type = None
for op in graph.ops():
if len(op.all_inputs()) == 0 or op.all_inputs()[0] is None:
continue
belonged_teacher = False
for inp in op.all_inputs():
if 'teacher' in inp._var.name:
......@@ -106,8 +107,9 @@ def get_patterns(program, only_final_node=True):
out_var_name = op.all_outputs()[0]._var.name
shortcut_start_op = shortcut_start_op[0]
next_op = graph.next_ops(op)
pattern_ops, pattern_ops_type = traversal_ops(
    shortcut_start_op, graph, next_op[0].idx())
pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
))
......
......@@ -51,6 +51,7 @@ def get_weight(op, return_name=True):
return inp.name()
else:
return inp
return None
def is_dynamic_weight_op(op):
......
......@@ -357,7 +357,8 @@ class GraphWrapper(object):
ops = []
for p in self.ops():
for out_var in op.all_outputs():
if len(p.all_inputs()) > 0 and p.all_inputs()[
        0] is not None and out_var in p.all_inputs():
if p.idx() != op.idx():
ops.append(p)
return sorted(ops)
......
......@@ -33,12 +33,16 @@ from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from ..common import get_logger
from ..common.patterns import get_patterns
from ..common.patterns_common import is_dynamic_weight_op, get_weight
from ..core.graph_wrapper import GraphWrapper
_logger = get_logger(__name__, level=logging.INFO)
try:
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import QuantWeightPass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram
except:
_logger.warning(
"Some functions fail to import, please update PaddlePaddle version to 2.3+"
......@@ -98,6 +102,10 @@ _quant_config_default = {
'is_full_quantize': False,
# if True, use onnx format to quant.
'onnx_format': False,
# whether to run post-training quantization first to get initial scales for quant_aware
'quant_post_first': False,
# whether the quantization scales are trainable
'scale_trainable': True
}
......@@ -254,7 +262,12 @@ def quant_aware(program,
optimizer_func=None,
executor=None,
return_program=False,
calib_config={},
draw_graph=False,
return_scale_dict=False,
scale_dict=None,
model_type=None,
pattern_ops=None):
"""Add quantization and dequantization operators to "program"
for quantization training or testing.
......@@ -271,7 +284,7 @@ def quant_aware(program,
Default: ``None``.
for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
Otherwise, set to ``False``. Default: False
weight_quantize_func(function): Function that defines how to quantize weight. Using this
can quickly test if user's quantization method works or not. In this function, user should
both define quantization function and dequantization function, that is, the function's input
is non-quantized weight and function returns dequantized weight. If None, will use
......@@ -301,6 +314,12 @@ def quant_aware(program,
Default is False.
draw_graph(bool): whether to draw the graph when quantization is initialized. To prevent a cycle,
this needs to be set to True for the ERNIE model. Default is False.
return_scale_dict(bool): If True, also return scale_dict, model_type and pattern_ops.
Default is False.
scale_dict(dict): A scale dict used to initialize the scales in the program. Default is None.
model_type(str): The model type, e.g. 'transformer'. If it is 'transformer', transformer patterns will be analyzed.
Default is None.
pattern_ops(dict): A dict mapping each pattern name to its corresponding ops. Default is None.
Returns:
paddle.static.CompiledProgram | paddle.static.Program: Program with quantization and dequantization ``operators``
"""
......@@ -313,52 +332,164 @@ def quant_aware(program,
config = _parse_configs(config)
_logger.info("quant_aware config {}".format(config))
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass_ops = []
quant_dequant_ops = []
for op_type in config['quantize_op_types']:
if op_type in TRANSFORM_PASS_OP_TYPES:
transform_pass_ops.append(op_type)
elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
quant_dequant_ops.append(op_type)
if len(transform_pass_ops) > 0:
transform_func = 'QuantizationTransformPassV2' if config[
'onnx_format'] else 'QuantizationTransformPass'
transform_pass = eval(transform_func)(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern'],
weight_quantize_func=weight_quantize_func,
act_quantize_func=act_quantize_func,
weight_preprocess_func=weight_preprocess_func,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor)
transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
qdq_func = 'AddQuantDequantPassV2' if config[
'onnx_format'] else 'AddQuantDequantPass'
quant_dequant_pass = eval(qdq_func)(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
quant_bits=config['activation_bits'],
skip_pattern=config['not_quant_pattern'],
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
def find_next_ops(program, var_name):
"""
Find all ops that take the given variable as an input.
"""
block = program.global_block()
res_ops = []
for op in block.ops:
if var_name in op.input_arg_names:
res_ops.append(op)
return res_ops
def find_pre_ops(program, var_name):
"""
Find all ops that produce the given variable as an output.
"""
block = program.global_block()
res_ops = []
for op in block.ops:
if var_name in op.output_arg_names:
res_ops.append(op)
return res_ops
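# Usage sketch (hypothetical op names): for `fc_out = mul(x, w)` followed by
# `y = relu(fc_out)`:
#   find_next_ops(program, 'fc_out') -> [the relu op]
#   find_pre_ops(program, 'fc_out')  -> [the mul op]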
def _is_skip_layernorm(program, op):
if get_weight(op) is not None:
return False
output_names = op._op.output_arg_names
for output_name in output_names:
for next_op in find_next_ops(program, output_name):
if next_op.type == 'layer_norm':
return True
return False
skip_tensor_list = []
same_scale_tensor_list = []
if model_type == 'transformer' and pattern_ops is None:
pattern_ops, _, model_type = get_patterns(program)
if model_type != 'transformer':
_logger.info(
'Warning! After analysis, the real model type is not transformer. If you encounter this situation, please raise an issue to let us know the case in which "get_patterns" decides the model type is not transformer.'
)
if model_type == 'transformer':
not_skip_quant_list = []
for part_name, ops in pattern_ops.items():
if 'MHA' in part_name:
qkv_weight_tensor = []
qkv_output_tensor = []
### collect the q/k/v projection weights and outputs
output_names = ops[0]._op.output_arg_names
for output_name in output_names:
for next_op in find_next_ops(program, output_name):
if next_op.type in ['mul', 'matmul_v2']:
qkv_weight_tensor.append(next_op.input('Y')[0])
same_scale_tensor_list.append(qkv_weight_tensor)
for op in ops:
if op._op.type in ['matmul', 'matmul_v2'] and (
not is_dynamic_weight_op(op)):
input_names = op._op.input_arg_names
for input_name in input_names:
pre_op = find_pre_ops(program, input_name)[0]
if pre_op.type == 'softmax' or pre_op.type == 'dropout':
continue
elif pre_op.type == 'scale':
qkv_output_tensor.append(
input_name + '#/#{}'.format(
pre_op.attr('scale')))
else:
qkv_output_tensor.append(input_name)
elif op._op.type == 'elementwise_add':
if _is_skip_layernorm(program, op):
not_skip_quant_list.append(op)
same_scale_tensor_list.append(qkv_output_tensor)
elif 'FFN' in part_name:
for op in ops:
if op._op.type == 'elementwise_add':
if _is_skip_layernorm(program, op):
not_skip_quant_list.append(op)
tmp_graph = GraphWrapper(program)
for op in tmp_graph.ops():
### mark elementwise_add ops outside the skip-layernorm patterns as skip_quant
if op._op.type == 'elementwise_add' and op not in not_skip_quant_list:
op._op._set_attr("op_namescope", "skip_quant")
is_test = True if for_test else not config['scale_trainable']
if config['quant_post_first'] and for_test:
if 'quantizable_op_type' not in calib_config:
calib_config['quantizable_op_type'] = config['quantize_op_types']
exe = paddle.static.Executor() if executor is None else executor
post_training_quantization = PostTrainingQuantizationProgram(
exe,
program,
freeze_model=False,
skip_tensor_list=skip_tensor_list,
same_scale_tensor_list=same_scale_tensor_list,
scale_trainable=config['scale_trainable'],
batch_nums=10,
scale_dict=scale_dict,
return_graph=True,
**calib_config)
main_graph = post_training_quantization.quantize()
scale_dict = post_training_quantization._scale_dict
else:
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass_ops = []
quant_dequant_ops = []
for op_type in config['quantize_op_types']:
if op_type in TRANSFORM_PASS_OP_TYPES:
transform_pass_ops.append(op_type)
elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
quant_dequant_ops.append(op_type)
if len(transform_pass_ops) > 0:
transform_func = 'QuantizationTransformPassV2' if config[
'onnx_format'] else 'QuantizationTransformPass'
transform_pass = eval(transform_func)(
scope=scope,
place=place,
weight_bits=config['weight_bits'],
activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern'],
weight_quantize_func=weight_quantize_func,
act_quantize_func=act_quantize_func,
weight_preprocess_func=weight_preprocess_func,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
is_test=is_test)
transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
qdq_func = 'AddQuantDequantPassV2' if config[
'onnx_format'] else 'AddQuantDequantPass'
quant_dequant_pass = eval(qdq_func)(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
quant_bits=config['activation_bits'],
skip_pattern=config['not_quant_pattern'],
quantizable_op_type=quant_dequant_ops,
is_test=is_test,
scale_dict=scale_dict)
quant_dequant_pass.apply(main_graph)
out_scale_training_pass = OutScaleForTrainingPass(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
is_test=is_test,
scale_dict=scale_dict)
out_scale_training_pass.apply(main_graph)
if (weight_preprocess_func is not None or
......@@ -378,7 +509,11 @@ def quant_aware(program,
quant_program = main_graph.to_program()
else:
quant_program = paddle.static.CompiledProgram(main_graph.graph)
return quant_program
if return_scale_dict:
return quant_program, scale_dict, model_type, pattern_ops
else:
return quant_program
def quant_post_static(
......
import sys
import random
sys.path.append("../")
import unittest
import paddle
import paddle.nn as nn
from paddle.io import Dataset
from paddleslim.quant import quant_aware, convert
from paddle.nn import TransformerEncoderLayer, TransformerEncoder, Linear
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
import numpy as np
np.random.seed(0)
random.seed(0)
paddle.seed(0)
class RandomDataset(Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
def __getitem__(self, idx):
enc_input = np.random.random([4, 128]).astype('float32')
attn_mask = np.random.random([2, 4, 4]).astype('float32')
label = np.random.randint(0, 2, (1, )).astype('int64')
return enc_input, attn_mask, label
def __len__(self):
return self.num_samples
class TestQuantPostQuantAwareCase1(StaticCase):
def test_accuracy(self):
def simple_transformer(enc_input, attn_mask):
encoder_layer = nn.TransformerEncoderLayer(128, 2, 512)
encoder = TransformerEncoder(encoder_layer, 2)
encoder_output = encoder(enc_input, attn_mask)
first_token = encoder_output[:, 0]
bias = paddle.full(shape=[1, 128], fill_value=1e-6)
linear = Linear(128, 2)
logits = linear(first_token + bias)
return logits
enc_input = paddle.static.data(
name='enc_input', shape=[None, 4, 128], dtype='float32')
attn_mask = paddle.static.data(
name='attn_mask', shape=[None, 2, 4, 4], dtype='float32')
label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
out = simple_transformer(enc_input, attn_mask)
cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
avg_cost = paddle.mean(x=cost)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
optimizer = paddle.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
weight_decay=paddle.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = paddle.static.default_main_program()
val_prog = main_prog.clone(for_test=True)
place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(paddle.static.default_startup_program())
train_dataset = RandomDataset(100)
test_dataset = RandomDataset(50)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[enc_input, attn_mask, label],
drop_last=True,
return_list=False,
batch_size=10)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
feed_list=[enc_input, attn_mask, label],
batch_size=10,
return_list=False)
def train(program):
iter = 0
for data in train_loader():
cost, top1 = exe.run(program,
feed=data,
fetch_list=[avg_cost, acc_top1])
iter += 1
if iter % 100 == 0:
print('train iter={}, avg loss {}, acc_top1 {}'.format(
iter, cost, top1))
def test(program):
iter = 0
result = [[], []]
for data in valid_loader():
cost, top1 = exe.run(program,
feed=data,
fetch_list=[avg_cost, acc_top1])
iter += 1
if iter % 100 == 0:
print('eval iter={}, avg loss {}, acc_top1 {}'.format(
iter, cost, top1))
result[0].append(cost)
result[1].append(top1)
print(' avg loss {}, acc_top1 {}'.format(
np.mean(result[0]), np.mean(result[1])))
return np.mean(result[1])
train(main_prog)
top1_1 = test(main_prog)
config = {
'weight_quantize_type': 'channel_wise_abs_max',
'activation_quantize_type': 'moving_average_abs_max',
'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'elementwise_add'],
'quant_post_first': True,
'scale_trainable': True
}
calib_config = {
'data_loader': valid_loader,
'algo': 'abs_max',
'feed_list': ['enc_input', 'attn_mask', 'label'],
'fetch_list': [avg_cost, acc_top1]
}
quant_eval_prog, scale_dict, _, _ = quant_aware(
val_prog,
place,
config,
for_test=True,
calib_config=calib_config,
model_type='transformer',
return_scale_dict=True)
quant_train_prog = quant_aware(
main_prog,
place,
config,
for_test=False,
calib_config=calib_config,
return_program=True,
scale_dict=scale_dict,
model_type='transformer')
train(quant_train_prog)
quant_eval_prog, int8_prog = convert(
quant_eval_prog, place, config, save_int8=True)
top1_2 = test(quant_eval_prog)
# values before quantization and after quantization should be close
print("before quantization: top1: {}".format(top1_1))
print("after quantization: top1: {}".format(top1_2))
if __name__ == '__main__':
unittest.main()