未验证 提交 a7936b3f 编写于 作者: C ceci3 提交者: GitHub

fix auto compress bugs (#1293)

* fix demo

* add unittest

* update unittest
上级 ae630a78
......@@ -47,7 +47,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
loader = paddle.io.DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
num_workers=1,
num_workers=0,
return_list=True, )
total_iters = len(loader)
......@@ -137,7 +137,7 @@ def main(args):
train_dataset,
places=[place],
batch_sampler=batch_sampler,
num_workers=2,
num_workers=0,
return_list=True,
worker_init_fn=worker_init_fn)
train_dataloader = reader_wrapper(train_loader)
......
......@@ -143,6 +143,8 @@ class AutoCompression:
self.train_config = TrainConfig(**config.pop('TrainConfig'))
else:
self.train_config = None
else:
self.train_config = None
self.strategy_config = extract_strategy_config(config)
# prepare dataloader
......@@ -596,7 +598,6 @@ class AutoCompression:
_logger.info(
"==> The ACT compression has been completed and the final model is saved in `{}`".
format(final_model_path))
os._exit(0)
def single_strategy_compress(self, strategy, config, strategy_idx,
train_config):
......
......@@ -144,12 +144,7 @@ def standardization(data):
"""standardization numpy array"""
mu = np.mean(data, axis=0)
sigma = np.std(data, axis=0)
if isinstance(sigma, list) or isinstance(sigma, np.ndarray):
for idx, sig in enumerate(sigma):
if sig == 0.:
sigma[idx] = 1e-13
else:
sigma = 1e-13 if sigma == 0. else sigma
sigma = 1e-13 if sigma == 0. else sigma
return (data - mu) / sigma
......@@ -246,15 +241,19 @@ def eval_quant_model():
if have_invalid_num(out_float) or have_invalid_num(out_quant):
continue
out_float_list.append(list(out_float))
out_quant_list.append(list(out_quant))
try:
if len(out_float) > 3:
out_float = standardization(out_float)
out_quant = standardization(out_quant)
except:
continue
out_float_list.append(out_float)
out_quant_list.append(out_quant)
valid_data_num += 1
if valid_data_num >= max_eval_data_num:
break
out_float_list = standardization(out_float_list)
out_quant_list = standardization(out_quant_list)
emd_sum = cal_emd_lose(out_float_list, out_quant_list,
out_len_sum / float(valid_data_num))
_logger.info("output diff: {}".format(emd_sum))
......
import os
import sys
import unittest
sys.path.append("../../")
import paddle
from PIL import Image
from paddle.vision.datasets import DatasetFolder
from paddle.vision.transforms import transforms
from paddleslim.auto_compression import AutoCompression
paddle.enable_static()
class ImageNetDataset(DatasetFolder):
def __init__(self, path, image_size=224):
super(ImageNetDataset, self).__init__(path)
normalize = transforms.Normalize(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.120, 57.375])
self.transform = transforms.Compose([
transforms.Resize(256), transforms.CenterCrop(image_size),
transforms.Transpose(), normalize
])
def __getitem__(self, idx):
img_path, _ = self.samples[idx]
return self.transform(Image.open(img_path).convert('RGB'))
def __len__(self):
return len(self.samples)
class ACTDemo(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(ACTDemo, self).__init__(*args, **kwargs)
os.system(
'wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/MobileNetV1_infer.tar'
)
os.system('tar -xf MobileNetV1_infer.tar')
os.system(
'wget https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
def test_demo(self):
train_dataset = ImageNetDataset(
"./ILSVRC2012_data_demo/ILSVRC2012/train/")
image = paddle.static.data(
name='inputs', shape=[None] + [3, 224, 224], dtype='float32')
train_loader = paddle.io.DataLoader(
train_dataset, feed_list=[image], batch_size=32, return_list=False)
ac = AutoCompression(
model_dir="./MobileNetV1_infer",
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
save_dir="MobileNetV1_quant",
config={
'Quantization': {},
"HyperParameterOptimization": {
'ptq_algo': ['avg'],
'max_quant_count': 3
}
},
train_dataloader=train_loader,
eval_dataloader=train_loader)
ac.compress()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册