Unverified commit 50ec7275, authored by ceci3, committed by GitHub

add transformer prune unittest (#1362)

* add act unittest
Parent c3c6ef19
@@ -640,6 +640,11 @@ class AutoCompression:
             model_filename=self.model_filename,
             params_filename=self.params_filename,
             executor=self._exe)
+        if self.eval_function is None:
+            # If eval_function is None, ptq_hpo will use the EMD distance to evaluate the quantized model, so it needs a dataloader without labels.
+            eval_dataloader = self.train_dataloader
+        else:
+            eval_dataloader = self.eval_dataloader
         post_quant_hpo.quant_post_hpo(
             self._exe,
             self._places,
@@ -647,7 +652,7 @@ class AutoCompression:
             quantize_model_path=os.path.join(
                 self.tmp_dir, 'strategy_{}'.format(str(strategy_idx + 1))),
             train_dataloader=self.train_dataloader,
-            eval_dataloader=self.eval_dataloader,
+            eval_dataloader=eval_dataloader,
             eval_function=self.eval_function,
             model_filename=self.model_filename,
             params_filename=self.params_filename,
......
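The new fallback lets the PTQ-HPO strategy run even when no evaluation function is supplied: candidate quantization settings are then scored by the EMD distance between float and quantized outputs, for which unlabeled training batches suffice. A minimal usage sketch of that path, assuming an exported inference model; the paths and config keys below are illustrative, not taken from this commit:

# Hypothetical sketch: invoking ACT so the new fallback is exercised. With no
# eval_function/eval_dataloader given, PTQ-HPO reuses the (unlabeled) train
# dataloader for the EMD-based comparison.
import numpy as np
import paddle
from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()


class RandomImages(Dataset):
    def __len__(self):
        return 32

    def __getitem__(self, idx):
        return np.random.random([3, 224, 224]).astype('float32')


image = paddle.static.data(
    name='inputs', shape=[None, 3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
    RandomImages(), feed_list=[image], batch_size=4)

ac = AutoCompression(
    model_dir='./MobileNetV1_infer',  # assumed model location
    model_filename='inference.pdmodel',
    params_filename='inference.pdiparams',
    save_dir='./ptq_hpo_output',
    config={'QuantPost': {}, 'HyperParameterOptimization': {}},  # assumed keys
    train_dataloader=train_loader)  # no eval inputs: the fallback applies
ac.compress()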
@@ -47,125 +47,127 @@ def post_quant_fake(executor,
         set(_weight_supported_quantizable_op_type +
             _act_supported_quantizable_op_type + _dynamic_quantize_op_type))
     _place = executor.place
-    _scope = paddle.static.global_scope()
-    if is_full_quantize:
-        _quantizable_op_type = _support_quantize_op_type
-    else:
-        _quantizable_op_type = quantizable_op_type
-    for op_type in _quantizable_op_type:
-        assert op_type in _support_quantize_op_type, \
-            op_type + " is not supported for quantization."
-    _program, _feed_list, _fetch_list = load_inference_model(
-        model_dir,
-        executor,
-        model_filename=model_filename,
-        params_filename=params_filename)
-
-    graph = IrGraph(core.Graph(_program.desc), for_test=True)
-    # use QuantizationTransformPass to insert fake_quant/fake_dequantize op
-    major_quantizable_op_types = []
-    for op_type in _weight_supported_quantizable_op_type:
-        if op_type in _quantizable_op_type:
-            major_quantizable_op_types.append(op_type)
-    if onnx_format:
-        transform_pass = QuantizationTransformPassV2(
-            scope=_scope,
-            place=_place,
-            weight_bits=weight_bits,
-            activation_bits=activation_bits,
-            activation_quantize_type=activation_quantize_type,
-            weight_quantize_type=weight_quantize_type,
-            quantizable_op_type=major_quantizable_op_types)
-    else:
-        transform_pass = QuantizationTransformPass(
-            scope=_scope,
-            place=_place,
-            weight_bits=weight_bits,
-            activation_bits=activation_bits,
-            activation_quantize_type=activation_quantize_type,
-            weight_quantize_type=weight_quantize_type,
-            quantizable_op_type=major_quantizable_op_types)
-
-    for sub_graph in graph.all_sub_graphs():
-        # Insert fake_quant/fake_dequantize op must in test graph, so
-        # set per graph's _for_test is True.
-        sub_graph._for_test = True
-        transform_pass.apply(sub_graph)
-
-    # use AddQuantDequantPass to insert fake_quant_dequant op
-    minor_quantizable_op_types = []
-    for op_type in _act_supported_quantizable_op_type:
-        if op_type in _quantizable_op_type:
-            minor_quantizable_op_types.append(op_type)
-    if onnx_format:
-        add_quant_dequant_pass = AddQuantDequantPassV2(
-            scope=_scope,
-            place=_place,
-            quantizable_op_type=minor_quantizable_op_types,
-            is_full_quantized=is_full_quantize)
-    else:
-        add_quant_dequant_pass = AddQuantDequantPass(
-            scope=_scope,
-            place=_place,
-            quantizable_op_type=minor_quantizable_op_types)
-
-    for sub_graph in graph.all_sub_graphs():
-        sub_graph._for_test = True
-        add_quant_dequant_pass.apply(sub_graph)
-
-    # apply QuantizationFreezePass, and obtain the final quant model
-    if onnx_format:
-        quant_weight_pass = QuantWeightPass(_scope, _place)
-        for sub_graph in graph.all_sub_graphs():
-            sub_graph._for_test = True
-            quant_weight_pass.apply(sub_graph)
-    else:
-        freeze_pass = QuantizationFreezePass(
-            scope=_scope,
-            place=_place,
-            weight_bits=weight_bits,
-            activation_bits=activation_bits,
-            weight_quantize_type=weight_quantize_type,
-            quantizable_op_type=major_quantizable_op_types)
-
-        for sub_graph in graph.all_sub_graphs():
-            sub_graph._for_test = True
-            freeze_pass.apply(sub_graph)
-
-    _program = graph.to_program()
-
-    def save_info(op_node, out_var_name, out_info_name, quantized_type):
-        op_node._set_attr(out_info_name, 0.001)
-        op_node._set_attr("with_quant_attr", True)
-        if op_node.type in _quantizable_op_type:
-            op._set_attr("quantization_type", quantized_type)
-
-    def analysis_and_save_info(op_node, out_var_name):
-        argname_index = utils._get_output_name_index(op_node, out_var_name)
-        assert argname_index is not None, \
-            out_var_name + " is not the output of the op"
-
-        save_info(op_node, out_var_name, "out_threshold", "post_avg")
-        save_info(op_node, out_var_name,
-                  argname_index[0] + str(argname_index[1]) + "_threshold",
-                  "post_avg")
-
-    for block_id in range(len(_program.blocks)):
-        for op in _program.blocks[block_id].ops:
-            if op.type in (_quantizable_op_type + utils._out_scale_op_list):
-                out_var_names = utils._get_op_output_var_names(op)
-                for var_name in out_var_names:
-                    analysis_and_save_info(op, var_name)
-
-    feed_vars = [_program.global_block().var(name) for name in _feed_list]
-    model_name = model_filename.split('.')[
-        0] if model_filename is not None else 'model'
-    save_model_path = os.path.join(save_model_path, model_name)
-    paddle.static.save_inference_model(
-        path_prefix=save_model_path,
-        feed_vars=feed_vars,
-        fetch_vars=_fetch_list,
-        executor=executor,
-        program=_program)
-    print("The quantized model is saved in: " + save_model_path)
+    _scope = paddle.static.Scope()
+
+    with paddle.static.scope_guard(_scope):
+        if is_full_quantize:
+            _quantizable_op_type = _support_quantize_op_type
+        else:
+            _quantizable_op_type = quantizable_op_type
+        for op_type in _quantizable_op_type:
+            assert op_type in _support_quantize_op_type, \
+                op_type + " is not supported for quantization."
+        _program, _feed_list, _fetch_list = load_inference_model(
+            model_dir,
+            executor,
+            model_filename=model_filename,
+            params_filename=params_filename)
+        graph = IrGraph(core.Graph(_program.desc), for_test=True)
+
+        # use QuantizationTransformPass to insert fake_quant/fake_dequantize op
+        major_quantizable_op_types = []
+        for op_type in _weight_supported_quantizable_op_type:
+            if op_type in _quantizable_op_type:
+                major_quantizable_op_types.append(op_type)
+        if onnx_format:
+            transform_pass = QuantizationTransformPassV2(
+                scope=_scope,
+                place=_place,
+                weight_bits=weight_bits,
+                activation_bits=activation_bits,
+                activation_quantize_type=activation_quantize_type,
+                weight_quantize_type=weight_quantize_type,
+                quantizable_op_type=major_quantizable_op_types)
+        else:
+            transform_pass = QuantizationTransformPass(
+                scope=_scope,
+                place=_place,
+                weight_bits=weight_bits,
+                activation_bits=activation_bits,
+                activation_quantize_type=activation_quantize_type,
+                weight_quantize_type=weight_quantize_type,
+                quantizable_op_type=major_quantizable_op_types)
+
+        for sub_graph in graph.all_sub_graphs():
+            # Insert fake_quant/fake_dequantize op must in test graph, so
+            # set per graph's _for_test is True.
+            sub_graph._for_test = True
+            transform_pass.apply(sub_graph)
+
+        # use AddQuantDequantPass to insert fake_quant_dequant op
+        minor_quantizable_op_types = []
+        for op_type in _act_supported_quantizable_op_type:
+            if op_type in _quantizable_op_type:
+                minor_quantizable_op_types.append(op_type)
+        if onnx_format:
+            add_quant_dequant_pass = AddQuantDequantPassV2(
+                scope=_scope,
+                place=_place,
+                quantizable_op_type=minor_quantizable_op_types,
+                is_full_quantized=is_full_quantize)
+        else:
+            add_quant_dequant_pass = AddQuantDequantPass(
+                scope=_scope,
+                place=_place,
+                quantizable_op_type=minor_quantizable_op_types)
+
+        for sub_graph in graph.all_sub_graphs():
+            sub_graph._for_test = True
+            add_quant_dequant_pass.apply(sub_graph)
+
+        # apply QuantizationFreezePass, and obtain the final quant model
+        if onnx_format:
+            quant_weight_pass = QuantWeightPass(_scope, _place)
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                quant_weight_pass.apply(sub_graph)
+        else:
+            freeze_pass = QuantizationFreezePass(
+                scope=_scope,
+                place=_place,
+                weight_bits=weight_bits,
+                activation_bits=activation_bits,
+                weight_quantize_type=weight_quantize_type,
+                quantizable_op_type=major_quantizable_op_types)
+
+            for sub_graph in graph.all_sub_graphs():
+                sub_graph._for_test = True
+                freeze_pass.apply(sub_graph)
+
+        _program = graph.to_program()
+
+        def save_info(op_node, out_var_name, out_info_name, quantized_type):
+            op_node._set_attr(out_info_name, 0.001)
+            op_node._set_attr("with_quant_attr", True)
+            if op_node.type in _quantizable_op_type:
+                op._set_attr("quantization_type", quantized_type)
+
+        def analysis_and_save_info(op_node, out_var_name):
+            argname_index = utils._get_output_name_index(op_node, out_var_name)
+            assert argname_index is not None, \
+                out_var_name + " is not the output of the op"
+
+            save_info(op_node, out_var_name, "out_threshold", "post_avg")
+            save_info(op_node, out_var_name,
+                      argname_index[0] + str(argname_index[1]) + "_threshold",
+                      "post_avg")
+
+        for block_id in range(len(_program.blocks)):
+            for op in _program.blocks[block_id].ops:
+                if op.type in (_quantizable_op_type + utils._out_scale_op_list):
+                    out_var_names = utils._get_op_output_var_names(op)
+                    for var_name in out_var_names:
+                        analysis_and_save_info(op, var_name)
+
+        feed_vars = [_program.global_block().var(name) for name in _feed_list]
+        model_name = model_filename.split('.')[
+            0] if model_filename is not None else 'model'
+        save_model_path = os.path.join(save_model_path, model_name)
+        paddle.static.save_inference_model(
+            path_prefix=save_model_path,
+            feed_vars=feed_vars,
+            fetch_vars=_fetch_list,
+            executor=executor,
+            program=_program)
+        print("The quantized model is saved in: " + save_model_path)
@@ -10,7 +10,7 @@ Distillation:
 TrainConfig:
   epochs: 1
-  eval_iter: 1070
+  eval_iter: 1
   learning_rate: 2.0e-5
   optimizer_builder:
     optimizer:
......
@@ -125,7 +125,7 @@ class TestLoadONNXModel(ACTBase):
     def __init__(self, *args, **kwargs):
         super(TestLoadONNXModel, self).__init__(*args, **kwargs)
         os.system(
-            'wget https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx')
+            'wget -q https://paddle-slim-models.bj.bcebos.com/act/yolov5s.onnx')
         self.model_dir = 'yolov5s.onnx'

     def test_compress(self):
......
import os
import sys
import unittest
sys.path.append("../../")
import numpy as np
import paddle
from paddle.io import Dataset
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()


class RandomEvalDataset(Dataset):
    def __init__(self, num_samples, image_shape=[3, 398, 224], class_num=10):
        self.num_samples = num_samples
        self.image_shape = image_shape
        self.class_num = class_num

    def __getitem__(self, idx):
        image = np.random.random(self.image_shape).astype('float32')
        return image

    def __len__(self):
        return self.num_samples


class ACTSparse(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(ACTSparse, self).__init__(*args, **kwargs)
        if not os.path.exists('ppseg_lite_portrait_398x224_with_softmax'):
            os.system(
                "wget -q https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz"
            )
            os.system(
                'tar -xzvf ppseg_lite_portrait_398x224_with_softmax.tar.gz')
        self.create_dataloader()
        self.get_train_config()

    def create_dataloader(self):
        # define a random dataset
        self.eval_dataset = RandomEvalDataset(32)

    def get_train_config(self):
        self.train_config = {
            'TrainConfig': {
                'epochs': 1,
                'eval_iter': 1,
                'learning_rate': 5.0e-03,
                'optimizer_builder': {
                    'optimizer': {
                        'type': 'SGD'
                    },
                    "weight_decay": 0.0005,
                }
            }
        }

    def test_demo(self):
        image = paddle.static.data(
            name='x', shape=[-1, 3, 398, 224], dtype='float32')
        train_loader = paddle.io.DataLoader(
            self.eval_dataset, feed_list=[image], batch_size=4)
        ac = AutoCompression(
            model_dir="./ppseg_lite_portrait_398x224_with_softmax",
            model_filename="model.pdmodel",
            params_filename="model.pdiparams",
            input_shapes=[1, 3, 398, 224],
            config=self.train_config,
            save_dir="ppliteseg_output",
            train_dataloader=train_loader,
            deploy_hardware='SD710')
        ac.compress()
        os.system('rm -rf ppliteseg_output')


if __name__ == '__main__':
    unittest.main()
@@ -31,14 +31,16 @@ class ImageNetDataset(DatasetFolder):
 class ACTDemo(unittest.TestCase):
     def __init__(self, *args, **kwargs):
         super(ACTDemo, self).__init__(*args, **kwargs)
-        os.system(
-            'wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
-        )
-        os.system('tar -xf MobileNetV1_infer.tar')
-        os.system(
-            'wget https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
-        )
-        os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
+        if not os.path.exists('MobileNetV1_infer'):
+            os.system(
+                'wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
+            )
+            os.system('tar -xf MobileNetV1_infer.tar')
+        if not os.path.exists('ILSVRC2012_data_demo'):
+            os.system(
+                'wget https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
+            )
+            os.system('tar -xf ILSVRC2012_data_demo.tar.gz')

     def test_demo(self):
         train_dataset = ImageNetDataset(
......
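The guarded downloads above avoid re-fetching fixtures on every CI run; the same pattern, factored into a small helper for comparison (the names here are illustrative, not part of the test suite):

import os


def fetch_once(url, target_dir, tar_flags='-xf'):
    # Download and unpack an archive only if its extracted directory is
    # missing, mirroring the if-not-exists guards added above.
    if not os.path.exists(target_dir):
        os.system('wget -q ' + url)
        os.system('tar {} {}'.format(tar_flags, url.split('/')[-1]))


# Example: fetch_once('https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz',
#                     'ILSVRC2012_data_demo')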
import os
import sys
import unittest
import numpy as np
sys.path.append("../../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.auto_compression import AutoCompression

paddle.enable_static()


class ImageNetDataset(DatasetFolder):
    def __init__(self, data_dir, image_size=224, mode='train'):
        super(ImageNetDataset, self).__init__(data_dir)
        self.data_dir = data_dir
        normalize = transforms.Normalize(
            mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
        self.transform = transforms.Compose([
            transforms.Resize(256), transforms.CenterCrop(image_size),
            transforms.Transpose(), normalize
        ])
        self.mode = mode
        train_file_list = os.path.join(data_dir, 'train_list.txt')
        val_file_list = os.path.join(data_dir, 'val_list.txt')
        self.mode = mode
        if mode == 'train':
            with open(train_file_list) as flist:
                full_lines = [line.strip() for line in flist]
                np.random.shuffle(full_lines)
                lines = full_lines
            self.samples = [line.split() for line in lines]
        else:
            with open(val_file_list) as flist:
                lines = [line.strip() for line in flist]
                self.samples = [line.split() for line in lines]

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        if self.mode == 'train':
            return self.transform(
                Image.open(os.path.join(self.data_dir, img_path)).convert(
                    'RGB'))
        else:
            return self.transform(
                Image.open(os.path.join(self.data_dir, img_path)).convert(
                    'RGB')), np.array([label]).astype('int64')

    def __len__(self):
        return len(self.samples)


class ACTEvalFunction(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(ACTEvalFunction, self).__init__(*args, **kwargs)
        if not os.path.exists('MobileNetV1_infer'):
            os.system(
                'wget -q https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
            )
            os.system('tar -xf MobileNetV1_infer.tar')
        if not os.path.exists('ILSVRC2012_data_demo'):
            os.system(
                'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
            )
            os.system('tar -xf ILSVRC2012_data_demo.tar.gz')

    def test_demo(self):
        train_dataset = ImageNetDataset("./ILSVRC2012_data_demo/ILSVRC2012/")
        image = paddle.static.data(
            name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
        label = paddle.static.data(
            name='labels', shape=[None] + [1], dtype='float32')
        train_loader = paddle.io.DataLoader(
            train_dataset, feed_list=[image], batch_size=32, return_list=False)

        def reader_wrapper(reader, input_name):
            def gen():
                for i, (imgs, label) in enumerate(reader()):
                    yield {input_name: imgs}

            return gen

        def eval_reader(data_dir,
                        batch_size,
                        crop_size,
                        resize_size,
                        place=None):
            val_dataset = ImageNetDataset(
                "./ILSVRC2012_data_demo/ILSVRC2012/", mode='val')
            val_loader = paddle.io.DataLoader(
                val_dataset,
                feed_list=[image, label],
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                return_list=False)
            return val_loader

        def eval_function(exe, compiled_test_program, test_feed_names,
                          test_fetch_list):
            val_loader = eval_reader(
                './ILSVRC2012_data_demo/ILSVRC2012/',
                batch_size=32,
                crop_size=224,
                resize_size=256)
            results = []
            print('Evaluating...')
            for batch_id, data in enumerate(val_loader):
                image = data[0]['inputs']
                label = data[0]['labels']
                # top1_acc, top5_acc
                if len(test_feed_names) == 1:
                    image = np.array(image)
                    label = np.array(label).astype('int64')
                    pred = exe.run(compiled_test_program,
                                   feed={test_feed_names[0]: image},
                                   fetch_list=test_fetch_list)
                    pred = np.array(pred[0])
                    label = np.array(label)
                    sort_array = pred.argsort(axis=1)
                    top_1_pred = sort_array[:, -1:][:, ::-1]
                    top_1 = np.mean(label == top_1_pred)
                    top_5_pred = sort_array[:, -5:][:, ::-1]
                    acc_num = 0
                    for i in range(len(label)):
                        if label[i][0] in top_5_pred[i]:
                            acc_num += 1
                    top_5 = float(acc_num) / len(label)
                    results.append([top_1, top_5])
                else:
                    # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
                    image = np.array(image)
                    label = np.array(label).astype('int64')
                    result = exe.run(compiled_test_program,
                                     feed={
                                         test_feed_names[0]: image,
                                         test_feed_names[1]: label
                                     },
                                     fetch_list=test_fetch_list)
                    result = [np.mean(r) for r in result]
                    results.append(result)
                if batch_id % 100 == 0:
                    print('Eval iter: ', batch_id)
            result = np.mean(np.array(results), axis=0)
            return result[0]

        ac = AutoCompression(
            model_dir="./MobileNetV1_infer",
            model_filename="inference.pdmodel",
            params_filename="inference.pdiparams",
            save_dir="MobileNetV1_eval_quant",
            config='./qat_dist_train.yaml',
            train_dataloader=train_loader,
            eval_callback=eval_function)
        ac.compress()
        os.system('rm -rf MobileNetV1_eval_quant')


if __name__ == '__main__':
    unittest.main()
import os
import sys
sys.path.append("../../")
import numpy as np
import unittest
import paddle
from paddle.io import Dataset
from paddleslim.common import load_config
from paddleslim.auto_compression.compressor import AutoCompression


class RandomDataset(Dataset):
    def __init__(self, num_samples, sample_shape=[128]):
        self.num_samples = num_samples
        self.sample_shape = sample_shape

    def __getitem__(self, idx):
        input_ids = np.random.random(self.sample_shape).astype('int64')
        token_type_ids = np.random.random(self.sample_shape).astype('int64')
        return input_ids, token_type_ids

    def __len__(self):
        return self.num_samples


class RandomEvalDataset(Dataset):
    def __init__(self, num_samples, sample_shape=[128]):
        self.num_samples = num_samples
        self.sample_shape = sample_shape

    def __getitem__(self, idx):
        input_ids = np.random.random(self.sample_shape).astype('int64')
        token_type_ids = np.random.random(self.sample_shape).astype('int64')
        labels = np.ones(([1])).astype('int64')
        return input_ids, token_type_ids, labels

    def __len__(self):
        return self.num_samples


### select transformer_prune and qat
class NLPAutoCompress(unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(NLPAutoCompress, self).__init__(*args, **kwargs)
        paddle.enable_static()
        if not os.path.exists('afqmc'):
            os.system(
                'wget -q https://bj.bcebos.com/v1/paddle-slim-models/act/afqmc.tar'
            )
            os.system('tar -xf afqmc.tar')
        self.create_dataset()
        self.get_train_config()

    def create_dataset(self):
        self.fake_dataset = RandomDataset(32)
        self.fake_eval_dataset = RandomEvalDataset(32)

    def get_train_config(self):
        self.train_config = {
            'TrainConfig': {
                'epochs': 1,
                'eval_iter': 1,
                'learning_rate': 2.0e-5,
                'optimizer_builder': {
                    'optimizer': {
                        'type': 'AdamW'
                    },
                    'weight_decay': 0.01
                },
            }
        }

    def test_nlp(self):
        input_ids = paddle.static.data(
            name='input_ids', shape=[-1, -1], dtype='int64')
        token_type_ids = paddle.static.data(
            name='token_type_ids', shape=[-1, -1], dtype='int64')
        labels = paddle.static.data(name='labels', shape=[-1], dtype='int64')
        train_loader = paddle.io.DataLoader(
            self.fake_dataset,
            feed_list=[input_ids, token_type_ids],
            batch_size=32,
            return_list=False)
        eval_loader = paddle.io.DataLoader(
            self.fake_eval_dataset,
            feed_list=[input_ids, token_type_ids, labels],
            batch_size=32,
            return_list=False)
        ac = AutoCompression(
            model_dir='afqmc',
            model_filename="inference.pdmodel",
            params_filename="inference.pdiparams",
            config=self.train_config,
            save_dir="nlp_ac_output",
            train_dataloader=train_loader,
            eval_dataloader=eval_loader)
        ac.compress()
        os.system("rm -rf nlp_ac_output")
        os.system("rm -rf afqmc*")


if __name__ == '__main__':
    unittest.main()