diff --git a/demo/imagenet_reader.py b/demo/imagenet_reader.py index dd102eb915ec98ff3e48cbd239e59637f44a37bd..609bfba168520610f4749d0546a69f3bf321fa55 100644 --- a/demo/imagenet_reader.py +++ b/demo/imagenet_reader.py @@ -11,6 +11,7 @@ random.seed(0) np.random.seed(0) DATA_DIM = 224 +RESIZE_DIM = 256 THREAD = 16 BUF_SIZE = 10240 @@ -34,8 +35,8 @@ def crop_image(img, target_size, center): width, height = img.size size = target_size if center == True: - w_start = (width - size) / 2 - h_start = (height - size) / 2 + w_start = (width - size) // 2 + h_start = (height - size) // 2 else: w_start = np.random.randint(0, width - size + 1) h_start = np.random.randint(0, height - size + 1) @@ -98,7 +99,7 @@ def distort_color(img): return img -def process_image(sample, mode, color_jitter, rotate): +def process_image(sample, mode, color_jitter, rotate, crop_size, resize_size): img_path = sample[0] try: @@ -108,10 +109,10 @@ def process_image(sample, mode, color_jitter, rotate): return None if mode == 'train': if rotate: img = rotate_image(img) - img = random_crop(img, DATA_DIM) + img = random_crop(img, crop_size) else: - img = resize_short(img, target_size=256) - img = crop_image(img, target_size=DATA_DIM, center=True) + img = resize_short(img, target_size=resize_size) + img = crop_image(img, target_size=crop_size, center=True) if mode == 'train': if color_jitter: img = distort_color(img) @@ -185,9 +186,15 @@ def test(data_dir=DATA_DIR): class ImageNetDataset(Dataset): - def __init__(self, data_dir=DATA_DIR, mode='train'): + def __init__(self, + data_dir=DATA_DIR, + mode='train', + crop_size=DATA_DIM, + resize_size=RESIZE_DIM): super(ImageNetDataset, self).__init__() self.data_dir = data_dir + self.crop_size = crop_size + self.resize_size = resize_size train_file_list = os.path.join(data_dir, 'train_list.txt') val_file_list = os.path.join(data_dir, 'val_list.txt') test_file_list = os.path.join(data_dir, 'test_list.txt') @@ -211,21 +218,27 @@ class ImageNetDataset(Dataset): [data_path, sample[1]], mode='train', color_jitter=False, - rotate=False) + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) return data, np.array([label]).astype('int64') elif self.mode == 'val': data, label = process_image( [data_path, sample[1]], mode='val', color_jitter=False, - rotate=False) + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) return data, np.array([label]).astype('int64') elif self.mode == 'test': data = process_image( [data_path, sample[1]], mode='test', color_jitter=False, - rotate=False) + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) return data def __len__(self): diff --git a/example/auto_compression/image_classification/README.md b/example/auto_compression/image_classification/README.md index a6a08a321f1e6d724303e0fb0322e17e02453700..438bd5016a4f4e69cf2573886a2cfcf8fe091cfe 100644 --- a/example/auto_compression/image_classification/README.md +++ b/example/auto_compression/image_classification/README.md @@ -24,25 +24,25 @@ | 模型 | 策略 | Top-1 Acc | GPU 耗时(ms) | ARM CPU 耗时(ms) | |:------:|:------:|:------:|:------:|:------:| | MobileNetV1 | Baseline | 70.90 | - | 33.15 | -| MobileNetV1 | 量化+蒸馏 | 70.49 | - | 13.64 | +| MobileNetV1 | 量化+蒸馏 | 70.57 | - | 13.64 | | ResNet50_vd | Baseline | 79.12 | 3.19 | - | -| ResNet50_vd | 量化+蒸馏 | 78.55 | 0.92 | - | +| ResNet50_vd | 量化+蒸馏 | 78.74 | 0.92 | - | | ShuffleNetV2_x1_0 | Baseline | 68.65 | - | 10.43 | -| ShuffleNetV2_x1_0 | 量化+蒸馏 | 67.78 | - | 5.51 | +| ShuffleNetV2_x1_0 | 量化+蒸馏 | 68.32 | - | 
5.51 | | SqueezeNet1_0_infer | Baseline | 59.60 | - | 35.98 | -| SqueezeNet1_0_infer | 量化+蒸馏 | 59.13 | - | 16.96 | +| SqueezeNet1_0_infer | 量化+蒸馏 | 59.45 | - | 16.96 | | PPLCNetV2_base | Baseline | 76.86 | - | 36.50 | | PPLCNetV2_base | 量化+蒸馏 | 76.43 | - | 15.79 | | PPHGNet_tiny | Baseline | 79.59 | 2.82 | - | -| PPHGNet_tiny | 量化+蒸馏 | 79.19 | 0.98 | - | +| PPHGNet_tiny | 量化+蒸馏 | 79.20 | 0.98 | - | +| InceptionV3 | Baseline | 79.14 | 4.79 | - | +| InceptionV3 | 量化+蒸馏 | 78.32 | 1.47 | - | | EfficientNetB0 | Baseline | 77.02 | 1.95 | - | -| EfficientNetB0 | 量化+蒸馏 | 73.61 | 1.44 | - | +| EfficientNetB0 | 量化+蒸馏 | 75.39 | 1.44 | - | | GhostNet_x1_0 | Baseline | 74.02 | 2.93 | - | -| GhostNet_x1_0 | 量化+蒸馏 | 71.11 | 1.03 | - | -| InceptionV3 | Baseline | 79.14 | 4.79 | - | -| InceptionV3 | 量化+蒸馏 | 73.16 | 1.47 | - | +| GhostNet_x1_0 | 量化+蒸馏 | 72.62 | 1.03 | - | | MobileNetV3_large_x1_0 | Baseline | 75.32 | - | 16.62 | -| MobileNetV3_large_x1_0 | 量化+蒸馏 | 68.84 | - | 9.85 | +| MobileNetV3_large_x1_0 | 量化+蒸馏 | 70.93 | - | 9.85 | - ARM CPU 测试环境:`SDM865(4xA77+4xA55)` - Nvidia GPU 测试环境: @@ -119,7 +119,7 @@ python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' - 准备好inference模型后,使用以下命令进行预测: ```shell -python infer.py -c configs/infer.yaml +python infer.py --config_path="configs/infer.yaml" ``` 在配置文件```configs/infer.yaml```中有以下字段用于配置预测参数: @@ -134,7 +134,7 @@ python infer.py -c configs/infer.yaml - ```PostProcess.Topk.class_id_map_file```:数据集 label 的映射文件,默认为```./images/imagenet1k_label_list.txt```,该文件为 PaddleClas 所使用的 ImageNet 数据集 label 映射文件 注意: -- 请注意模型的输入数据尺寸,部分模型需要修改参数:```PreProcess.resize_short```, ```PreProcess.resize``` +- 请注意模型的输入数据尺寸,如InceptionV3输入尺寸为299,所以部分模型需要修改参数:```PreProcess.resize_short```, ```PreProcess.resize``` - 如果希望提升评测模型速度,使用 ```GPU``` 评测时,建议开启 ```TensorRT``` 加速预测,使用 ```CPU``` 评测时,建议开启 ```MKL-DNN``` 加速预测 - 若使用 TesorRT 预测引擎,需安装 ```WITH_TRT=ON``` 的Paddle,下载地址:[Python预测库](https://paddleinference.paddlepaddle.org.cn/master/user_guides/download_lib.html#python) diff --git a/example/auto_compression/image_classification/configs/EfficientNetB0/qat_dis.yaml b/example/auto_compression/image_classification/configs/EfficientNetB0/qat_dis.yaml index 608e091ce4259fa5429122981dd51e96a50da9f2..f93f9dcf54d6d4bd4862227a7689f01ed3ac2af9 100644 --- a/example/auto_compression/image_classification/configs/EfficientNetB0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/EfficientNetB0/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/GhostNet_x1_0/qat_dis.yaml b/example/auto_compression/image_classification/configs/GhostNet_x1_0/qat_dis.yaml index ef715027e481e84e149804b01af8550d17b311c0..1a13ce1e4006a4f639a3794386d9f5668c12c137 100644 --- a/example/auto_compression/image_classification/configs/GhostNet_x1_0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/GhostNet_x1_0/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant @@ -33,3 +33,4 @@ TrainConfig: optimizer: type: Momentum weight_decay: 0.00002 + origin_metric: 0.7402 diff --git 
a/example/auto_compression/image_classification/configs/InceptionV3/prune_dis.yaml b/example/auto_compression/image_classification/configs/InceptionV3/prune_dis.yaml index a7edd881e870acbee59ec9578d3e1ac1d02340fa..1b57873facd9fb3c110a56d257831a423bfc82d0 100644 --- a/example/auto_compression/image_classification/configs/InceptionV3/prune_dis.yaml +++ b/example/auto_compression/image_classification/configs/InceptionV3/prune_dis.yaml @@ -4,6 +4,8 @@ Global: model_filename: inference.pdmodel params_filename: inference.pdiparams batch_size: 32 + resize_size: 320 + crop_size: 299 data_dir: /ILSVRC2012 Distillation: diff --git a/example/auto_compression/image_classification/configs/InceptionV3/qat_dis.yaml b/example/auto_compression/image_classification/configs/InceptionV3/qat_dis.yaml index 61fb42bbb9047f2491534e8b3e0ae444bc62a0e5..e8b1630aaaf5f35996ed8a874f4bc776f7251a67 100644 --- a/example/auto_compression/image_classification/configs/InceptionV3/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/InceptionV3/qat_dis.yaml @@ -1,19 +1,21 @@ Global: input_name: x - model_dir: InceptionV3_infer + model_dir: save_quant_inception model_filename: inference.pdmodel params_filename: inference.pdiparams batch_size: 32 - data_dir: /ILSVRC2012 + resize_size: 320 + img_size: 299 + data_dir: /workspace/dataset/ILSVRC2012 Distillation: - alpha: 10.0 + alpha: 1.0 loss: l2 node: - softmax_1.tmp_0 Quantization: is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant @@ -21,6 +23,7 @@ Quantization: - conv2d - depthwise_conv2d weight_bits: 8 + TrainConfig: epochs: 1 eval_iter: 500 diff --git a/example/auto_compression/image_classification/configs/MobileNetV1/qat_dis.yaml b/example/auto_compression/image_classification/configs/MobileNetV1/qat_dis.yaml index e22a450b6f6bf4922fbc47c92ad5d5237cba1fcb..8b9aae31897489333a2cfb32264aa815705cdc6f 100644 --- a/example/auto_compression/image_classification/configs/MobileNetV1/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/MobileNetV1/qat_dis.yaml @@ -16,7 +16,7 @@ Quantization: activation_bits: 8 is_full_quantize: false activation_quantize_type: moving_average_abs_max - weight_quantize_type: abs_max + weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant quantize_op_types: diff --git a/example/auto_compression/image_classification/configs/MobileNetV3_large_x1_0/qat_dis.yaml b/example/auto_compression/image_classification/configs/MobileNetV3_large_x1_0/qat_dis.yaml index a890a63952736e7c0317f24e2f7ae63fd53c951b..a1fbef1bc02207d6934ee82d9b3ee3b41ec7a6f1 100644 --- a/example/auto_compression/image_classification/configs/MobileNetV3_large_x1_0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/MobileNetV3_large_x1_0/qat_dis.yaml @@ -12,10 +12,10 @@ Distillation: node: - softmax_0.tmp_0 Quantization: + use_pact: true activation_bits: 8 is_full_quantize: false - use_pact: true - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant @@ -25,10 +25,10 @@ Quantization: weight_bits: 8 TrainConfig: epochs: 1 - eval_iter: 2000 + eval_iter: 500 learning_rate: type: CosineAnnealingDecay - learning_rate: 0.0001 + learning_rate: 0.015 optimizer_builder: optimizer: type: Momentum diff --git 
a/example/auto_compression/image_classification/configs/PPHGNet_tiny/prune_dis.yaml b/example/auto_compression/image_classification/configs/PPHGNet_tiny/prune_dis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d886b6d9ac9695c5cab77d3680c5b84fc78fe41 --- /dev/null +++ b/example/auto_compression/image_classification/configs/PPHGNet_tiny/prune_dis.yaml @@ -0,0 +1,37 @@ +Global: + input_name: x + model_dir: PPHGNet_tiny_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 32 + data_dir: /ILSVRC2012 + +Distillation: + alpha: 1.0 + loss: l2 + node: + - softmax_1.tmp_0 +UnstructurePrune: + prune_strategy: gmp + prune_mode: ratio + ratio: 0.75 + gmp_config: + stable_iterations: 0 + pruning_iterations: 4500 + tunning_iterations: 4500 + resume_iteration: -1 + pruning_steps: 100 + initial_ratio: 0.15 + prune_params_type: conv1x1_only + local_sparsity: True +TrainConfig: + epochs: 1 + eval_iter: 500 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.015 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.7959 \ No newline at end of file diff --git a/example/auto_compression/image_classification/configs/PPHGNet_tiny/qat_dis.yaml b/example/auto_compression/image_classification/configs/PPHGNet_tiny/qat_dis.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6c2462f91b2b5b52f40f2c041c757abe2d7be860 --- /dev/null +++ b/example/auto_compression/image_classification/configs/PPHGNet_tiny/qat_dis.yaml @@ -0,0 +1,36 @@ +Global: + input_name: x + model_dir: PPHGNet_tiny_infer + model_filename: inference.pdmodel + params_filename: inference.pdiparams + batch_size: 32 + data_dir: /ILSVRC2012 + +Distillation: + alpha: 1.0 + loss: l2 + node: + - softmax_1.tmp_0 +Quantization: + use_pact: true + activation_bits: 8 + is_full_quantize: false + activation_quantize_type: moving_average_abs_max + weight_quantize_type: channel_wise_abs_max + not_quant_pattern: + - skip_quant + quantize_op_types: + - conv2d + - depthwise_conv2d + weight_bits: 8 +TrainConfig: + epochs: 1 + eval_iter: 500 + learning_rate: + type: CosineAnnealingDecay + learning_rate: 0.015 + optimizer_builder: + optimizer: + type: Momentum + weight_decay: 0.00002 + origin_metric: 0.7959 diff --git a/example/auto_compression/image_classification/configs/PPLCNetV2_base/qat_dis.yaml b/example/auto_compression/image_classification/configs/PPLCNetV2_base/qat_dis.yaml index e9097d26f6d9e20ef3267cbea93078ff1bf97f82..19fdd97aafcb18203f4babf4347790003576bdb7 100644 --- a/example/auto_compression/image_classification/configs/PPLCNetV2_base/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/PPLCNetV2_base/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/PPLCNet_x1_0/qat_dis.yaml b/example/auto_compression/image_classification/configs/PPLCNet_x1_0/qat_dis.yaml index 109dec68b56be46226e6e9070f3dd2eede1367d9..2754b5d8a5a6c0e05a00345adcdb92171280078a 100644 --- a/example/auto_compression/image_classification/configs/PPLCNet_x1_0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/PPLCNet_x1_0/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - 
activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/ResNet50_vd/qat_dis.yaml b/example/auto_compression/image_classification/configs/ResNet50_vd/qat_dis.yaml index e99dddda110bbb907faf8e21627e3a8c9f5fdba4..05d51b7099f730ba947ad4755091ffb885f70a85 100644 --- a/example/auto_compression/image_classification/configs/ResNet50_vd/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/ResNet50_vd/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/ShuffleNetV2_x1_0/qat_dis.yaml b/example/auto_compression/image_classification/configs/ShuffleNetV2_x1_0/qat_dis.yaml index 1b1cd8b860eba13f7b9f0bf192ad57a52a36053d..815dccaed50675eb923ff44a22125466ed226944 100644 --- a/example/auto_compression/image_classification/configs/ShuffleNetV2_x1_0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/ShuffleNetV2_x1_0/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/SqueezeNet1_0/qat_dis.yaml b/example/auto_compression/image_classification/configs/SqueezeNet1_0/qat_dis.yaml index ed240ec0f6299efaf37b94cd61cecb8ab6d80722..33d7cc3f0dcf6b783672b6d0b98813c452fb491c 100644 --- a/example/auto_compression/image_classification/configs/SqueezeNet1_0/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/SqueezeNet1_0/qat_dis.yaml @@ -11,13 +11,10 @@ Distillation: loss: l2 node: - softmax_0.tmp_0 - teacher_model_dir: SqueezeNet1_0_infer - teacher_model_filename: inference.pdmodel - teacher_params_filename: inference.pdiparams Quantization: activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/SwinTransformer_base_patch4_window7_224/qat_dis.yaml b/example/auto_compression/image_classification/configs/SwinTransformer_base_patch4_window7_224/qat_dis.yaml index 4941ee92a4d27314fb340294dec89990a112b007..148085943fc77107f04cdd83fcec0cb31ffd427a 100644 --- a/example/auto_compression/image_classification/configs/SwinTransformer_base_patch4_window7_224/qat_dis.yaml +++ b/example/auto_compression/image_classification/configs/SwinTransformer_base_patch4_window7_224/qat_dis.yaml @@ -15,7 +15,7 @@ Quantization: use_pact: true activation_bits: 8 is_full_quantize: false - activation_quantize_type: range_abs_max + activation_quantize_type: moving_average_abs_max weight_quantize_type: channel_wise_abs_max not_quant_pattern: - skip_quant diff --git a/example/auto_compression/image_classification/configs/infer.yaml b/example/auto_compression/image_classification/configs/infer.yaml index 945575015dd403010645473a2ff08f38659397f6..c877526b46893d507205ebdd30aea49c96d69bd7 100644 --- 
a/example/auto_compression/image_classification/configs/infer.yaml +++ b/example/auto_compression/image_classification/configs/infer.yaml @@ -1,40 +1,17 @@ -Global: - infer_imgs: "./images/ILSVRC2012_val_00000010.jpeg" - inference_model_dir: "./MobileNetV1_infer" - model_filename: "inference.pdmodel" - params_filename: "inference.pdiparams" - batch_size: 1 - use_gpu: True - enable_mkldnn: True - cpu_num_threads: 10 - enable_benchmark: True - use_fp16: False - use_int8: False - ir_optim: True - use_tensorrt: True - gpu_mem: 8000 - enable_profile: False - benchmark: True - -PreProcess: - transform_ops: - - ResizeImage: - resize_short: 256 - - CropImage: - size: 224 - - NormalizeImage: - scale: 0.00392157 - mean: [0.485, 0.456, 0.406] - std: [0.229, 0.224, 0.225] - order: '' - channel_num: 3 - - ToCHWImage: - -PostProcess: - main_indicator: Topk - Topk: - topk: 5 - class_id_map_file: "./images/imagenet1k_label_list.txt" - SavePreLabel: - save_dir: ./pre_label/ +inference_model_dir: "./MobileNetV1_infer" +model_filename: "inference.pdmodel" +params_filename: "inference.pdiparams" +batch_size: 1 +image_size: 224 +use_gpu: True +enable_mkldnn: True +cpu_num_threads: 10 +enable_benchmark: True +use_fp16: False +use_int8: False +ir_optim: True +use_tensorrt: True +gpu_mem: 8000 +enable_profile: False +benchmark: True diff --git a/example/auto_compression/image_classification/eval.py b/example/auto_compression/image_classification/eval.py index 5d8a327aa1344354682cbe0ef59d5b1150e88008..ccb02fdb732cbd8a1a5684b076ec13ac1a659b90 100644 --- a/example/auto_compression/image_classification/eval.py +++ b/example/auto_compression/image_classification/eval.py @@ -16,6 +16,7 @@ import os import sys sys.path[0] = os.path.join( os.path.dirname("__file__"), os.path.pardir, os.path.pardir) +print(sys.path[0]) import argparse import functools from functools import partial @@ -23,8 +24,8 @@ from functools import partial import numpy as np import paddle import paddle.nn as nn -from paddle.io import Dataset, BatchSampler, DataLoader -import imagenet_reader as reader +from paddle.io import DataLoader +from imagenet_reader import ImageNetDataset from paddleslim.auto_compression.config_helpers import load_config as load_slim_config @@ -36,12 +37,22 @@ def argsparser(): default=None, help="path of compression strategy config.", required=True) + return parser -def eval_reader(data_dir, batch_size): - val_reader = paddle.batch( - reader.val(data_dir=data_dir), batch_size=batch_size) - return val_reader +def eval_reader(data_dir, batch_size, crop_size, resize_size): + val_reader = ImageNetDataset( + mode='val', + data_dir=data_dir, + crop_size=crop_size, + resize_size=resize_size) + val_loader = DataLoader( + val_reader, + batch_size=global_config['batch_size'], + shuffle=False, + drop_last=False, + num_workers=0) + return val_loader def eval(): @@ -55,19 +66,16 @@ def eval(): params_filename=global_config["params_filename"]) print('Loaded model from: {}'.format(global_config["model_dir"])) - val_reader = eval_reader(data_dir, batch_size=global_config['batch_size']) - image = paddle.static.data( - name=global_config['input_name'], - shape=[None, 3, 224, 224], - dtype='float32') - label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + val_loader = eval_reader( + data_dir, + batch_size=global_config['batch_size'], + crop_size=img_size, + resize_size=resize_size) results = [] - print('Evaluating... It will take a while. 
Please wait...') - for batch_id, data in enumerate(val_reader()): - # top1_acc, top5_acc - image = np.array([[d[0]] for d in data]) - image = image.reshape((len(data), 3, 224, 224)) - label = [[d[1]] for d in data] + print('Evaluating...') + for batch_id, (image, label) in enumerate(val_loader): + image = np.array(image) + label = np.array(label).astype('int64') pred = exe.run(val_program, feed={feed_target_names[0]: image}, fetch_list=fetch_targets) @@ -92,8 +100,15 @@ def main(): all_config = load_slim_config(args.config_path) assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" global_config = all_config["Global"] + global data_dir data_dir = global_config['data_dir'] + + global img_size, resize_size + img_size = global_config['img_size'] if 'img_size' in global_config else 224 + resize_size = global_config[ + 'resize_size'] if 'resize_size' in global_config else 256 + result = eval() print('Eval Top1:', result) diff --git a/example/auto_compression/image_classification/imagenet_reader.py b/example/auto_compression/image_classification/imagenet_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..609bfba168520610f4749d0546a69f3bf321fa55 --- /dev/null +++ b/example/auto_compression/image_classification/imagenet_reader.py @@ -0,0 +1,245 @@ +import os +import math +import random +import functools +import numpy as np +import paddle +from PIL import Image, ImageEnhance +from paddle.io import Dataset + +random.seed(0) +np.random.seed(0) + +DATA_DIM = 224 +RESIZE_DIM = 256 + +THREAD = 16 +BUF_SIZE = 10240 + +DATA_DIR = 'data/ILSVRC2012/' +DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR) + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + + +def resize_short(img, target_size): + percent = float(target_size) / min(img.size[0], img.size[1]) + resized_width = int(round(img.size[0] * percent)) + resized_height = int(round(img.size[1] * percent)) + img = img.resize((resized_width, resized_height), Image.LANCZOS) + return img + + +def crop_image(img, target_size, center): + width, height = img.size + size = target_size + if center == True: + w_start = (width - size) // 2 + h_start = (height - size) // 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img.crop((w_start, h_start, w_end, h_end)) + return img + + +def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]): + aspect_ratio = math.sqrt(np.random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. 
/ aspect_ratio + + bound = min((float(img.size[0]) / img.size[1]) / (w**2), + (float(img.size[1]) / img.size[0]) / (h**2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = np.random.randint(0, img.size[0] - w + 1) + j = np.random.randint(0, img.size[1] - h + 1) + + img = img.crop((i, j, i + w, j + h)) + img = img.resize((size, size), Image.LANCZOS) + return img + + +def rotate_image(img): + angle = np.random.randint(-10, 11) + img = img.rotate(angle) + return img + + +def distort_color(img): + def random_brightness(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Brightness(img).enhance(e) + + def random_contrast(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Contrast(img).enhance(e) + + def random_color(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Color(img).enhance(e) + + ops = [random_brightness, random_contrast, random_color] + np.random.shuffle(ops) + + img = ops[0](img) + img = ops[1](img) + img = ops[2](img) + + return img + + +def process_image(sample, mode, color_jitter, rotate, crop_size, resize_size): + img_path = sample[0] + + try: + img = Image.open(img_path) + except: + print(img_path, "not exists!") + return None + if mode == 'train': + if rotate: img = rotate_image(img) + img = random_crop(img, crop_size) + else: + img = resize_short(img, target_size=resize_size) + img = crop_image(img, target_size=crop_size, center=True) + if mode == 'train': + if color_jitter: + img = distort_color(img) + if np.random.randint(0, 2) == 1: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode != 'RGB': + img = img.convert('RGB') + + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 + img -= img_mean + img /= img_std + + if mode == 'train' or mode == 'val': + return img, sample[1] + elif mode == 'test': + return [img] + + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=DATA_DIR, + batch_size=1): + def reader(): + try: + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + lines = full_lines + for line in lines: + if mode == 'train' or mode == 'val': + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + yield img_path, int(label) + elif mode == 'test': + img_path = os.path.join(data_dir, line) + yield [img_path] + except Exception as e: + print("Reader failed!\n{}".format(str(e))) + os._exit(1) + + mapper = functools.partial( + process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) + + return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) + + +def train(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'train_list.txt') + return _reader_creator( + file_list, + 'train', + shuffle=True, + color_jitter=False, + rotate=False, + data_dir=data_dir) + + +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) + + +def test(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'test_list.txt') + return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir) + + +class ImageNetDataset(Dataset): + def __init__(self, + data_dir=DATA_DIR, + 
mode='train', + crop_size=DATA_DIM, + resize_size=RESIZE_DIM): + super(ImageNetDataset, self).__init__() + self.data_dir = data_dir + self.crop_size = crop_size + self.resize_size = resize_size + train_file_list = os.path.join(data_dir, 'train_list.txt') + val_file_list = os.path.join(data_dir, 'val_list.txt') + test_file_list = os.path.join(data_dir, 'test_list.txt') + self.mode = mode + if mode == 'train': + with open(train_file_list) as flist: + full_lines = [line.strip() for line in flist] + np.random.shuffle(full_lines) + lines = full_lines + self.data = [line.split() for line in lines] + else: + with open(val_file_list) as flist: + lines = [line.strip() for line in flist] + self.data = [line.split() for line in lines] + + def __getitem__(self, index): + sample = self.data[index] + data_path = os.path.join(self.data_dir, sample[0]) + if self.mode == 'train': + data, label = process_image( + [data_path, sample[1]], + mode='train', + color_jitter=False, + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) + return data, np.array([label]).astype('int64') + elif self.mode == 'val': + data, label = process_image( + [data_path, sample[1]], + mode='val', + color_jitter=False, + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) + return data, np.array([label]).astype('int64') + elif self.mode == 'test': + data = process_image( + [data_path, sample[1]], + mode='test', + color_jitter=False, + rotate=False, + crop_size=self.crop_size, + resize_size=self.resize_size) + return data + + def __len__(self): + return len(self.data) diff --git a/example/auto_compression/image_classification/images/ILSVRC2012_val_00000010.jpeg b/example/auto_compression/image_classification/images/ILSVRC2012_val_00000010.jpeg deleted file mode 100644 index 6fcafb4d9dc86d3c10b2aca20b975da89d56e086..0000000000000000000000000000000000000000 Binary files a/example/auto_compression/image_classification/images/ILSVRC2012_val_00000010.jpeg and /dev/null differ diff --git a/example/auto_compression/image_classification/images/imagenet1k_label_list.txt b/example/auto_compression/image_classification/images/imagenet1k_label_list.txt deleted file mode 100644 index 376e18021d543bc45e33df771b5dc7acdd5f2e4f..0000000000000000000000000000000000000000 --- a/example/auto_compression/image_classification/images/imagenet1k_label_list.txt +++ /dev/null @@ -1,1000 +0,0 @@ -0 tench, Tinca tinca -1 goldfish, Carassius auratus -2 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias -3 tiger shark, Galeocerdo cuvieri -4 hammerhead, hammerhead shark -5 electric ray, crampfish, numbfish, torpedo -6 stingray -7 cock -8 hen -9 ostrich, Struthio camelus -10 brambling, Fringilla montifringilla -11 goldfinch, Carduelis carduelis -12 house finch, linnet, Carpodacus mexicanus -13 junco, snowbird -14 indigo bunting, indigo finch, indigo bird, Passerina cyanea -15 robin, American robin, Turdus migratorius -16 bulbul -17 jay -18 magpie -19 chickadee -20 water ouzel, dipper -21 kite -22 bald eagle, American eagle, Haliaeetus leucocephalus -23 vulture -24 great grey owl, great gray owl, Strix nebulosa -25 European fire salamander, Salamandra salamandra -26 common newt, Triturus vulgaris -27 eft -28 spotted salamander, Ambystoma maculatum -29 axolotl, mud puppy, Ambystoma mexicanum -30 bullfrog, Rana catesbeiana -31 tree frog, tree-frog -32 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui -33 loggerhead, loggerhead turtle, Caretta caretta -34 leatherback turtle, 
leatherback, leathery turtle, Dermochelys coriacea -35 mud turtle -36 terrapin -37 box turtle, box tortoise -38 banded gecko -39 common iguana, iguana, Iguana iguana -40 American chameleon, anole, Anolis carolinensis -41 whiptail, whiptail lizard -42 agama -43 frilled lizard, Chlamydosaurus kingi -44 alligator lizard -45 Gila monster, Heloderma suspectum -46 green lizard, Lacerta viridis -47 African chameleon, Chamaeleo chamaeleon -48 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis -49 African crocodile, Nile crocodile, Crocodylus niloticus -50 American alligator, Alligator mississipiensis -51 triceratops -52 thunder snake, worm snake, Carphophis amoenus -53 ringneck snake, ring-necked snake, ring snake -54 hognose snake, puff adder, sand viper -55 green snake, grass snake -56 king snake, kingsnake -57 garter snake, grass snake -58 water snake -59 vine snake -60 night snake, Hypsiglena torquata -61 boa constrictor, Constrictor constrictor -62 rock python, rock snake, Python sebae -63 Indian cobra, Naja naja -64 green mamba -65 sea snake -66 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus -67 diamondback, diamondback rattlesnake, Crotalus adamanteus -68 sidewinder, horned rattlesnake, Crotalus cerastes -69 trilobite -70 harvestman, daddy longlegs, Phalangium opilio -71 scorpion -72 black and gold garden spider, Argiope aurantia -73 barn spider, Araneus cavaticus -74 garden spider, Aranea diademata -75 black widow, Latrodectus mactans -76 tarantula -77 wolf spider, hunting spider -78 tick -79 centipede -80 black grouse -81 ptarmigan -82 ruffed grouse, partridge, Bonasa umbellus -83 prairie chicken, prairie grouse, prairie fowl -84 peacock -85 quail -86 partridge -87 African grey, African gray, Psittacus erithacus -88 macaw -89 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita -90 lorikeet -91 coucal -92 bee eater -93 hornbill -94 hummingbird -95 jacamar -96 toucan -97 drake -98 red-breasted merganser, Mergus serrator -99 goose -100 black swan, Cygnus atratus -101 tusker -102 echidna, spiny anteater, anteater -103 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus -104 wallaby, brush kangaroo -105 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus -106 wombat -107 jellyfish -108 sea anemone, anemone -109 brain coral -110 flatworm, platyhelminth -111 nematode, nematode worm, roundworm -112 conch -113 snail -114 slug -115 sea slug, nudibranch -116 chiton, coat-of-mail shell, sea cradle, polyplacophore -117 chambered nautilus, pearly nautilus, nautilus -118 Dungeness crab, Cancer magister -119 rock crab, Cancer irroratus -120 fiddler crab -121 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica -122 American lobster, Northern lobster, Maine lobster, Homarus americanus -123 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish -124 crayfish, crawfish, crawdad, crawdaddy -125 hermit crab -126 isopod -127 white stork, Ciconia ciconia -128 black stork, Ciconia nigra -129 spoonbill -130 flamingo -131 little blue heron, Egretta caerulea -132 American egret, great white heron, Egretta albus -133 bittern -134 crane -135 limpkin, Aramus pictus -136 European gallinule, Porphyrio porphyrio -137 American coot, marsh hen, mud hen, water hen, Fulica americana -138 bustard -139 ruddy turnstone, Arenaria interpres -140 red-backed sandpiper, dunlin, Erolia alpina -141 redshank, Tringa totanus -142 dowitcher -143 oystercatcher, oyster catcher 
-144 pelican -145 king penguin, Aptenodytes patagonica -146 albatross, mollymawk -147 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus -148 killer whale, killer, orca, grampus, sea wolf, Orcinus orca -149 dugong, Dugong dugon -150 sea lion -151 Chihuahua -152 Japanese spaniel -153 Maltese dog, Maltese terrier, Maltese -154 Pekinese, Pekingese, Peke -155 Shih-Tzu -156 Blenheim spaniel -157 papillon -158 toy terrier -159 Rhodesian ridgeback -160 Afghan hound, Afghan -161 basset, basset hound -162 beagle -163 bloodhound, sleuthhound -164 bluetick -165 black-and-tan coonhound -166 Walker hound, Walker foxhound -167 English foxhound -168 redbone -169 borzoi, Russian wolfhound -170 Irish wolfhound -171 Italian greyhound -172 whippet -173 Ibizan hound, Ibizan Podenco -174 Norwegian elkhound, elkhound -175 otterhound, otter hound -176 Saluki, gazelle hound -177 Scottish deerhound, deerhound -178 Weimaraner -179 Staffordshire bullterrier, Staffordshire bull terrier -180 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier -181 Bedlington terrier -182 Border terrier -183 Kerry blue terrier -184 Irish terrier -185 Norfolk terrier -186 Norwich terrier -187 Yorkshire terrier -188 wire-haired fox terrier -189 Lakeland terrier -190 Sealyham terrier, Sealyham -191 Airedale, Airedale terrier -192 cairn, cairn terrier -193 Australian terrier -194 Dandie Dinmont, Dandie Dinmont terrier -195 Boston bull, Boston terrier -196 miniature schnauzer -197 giant schnauzer -198 standard schnauzer -199 Scotch terrier, Scottish terrier, Scottie -200 Tibetan terrier, chrysanthemum dog -201 silky terrier, Sydney silky -202 soft-coated wheaten terrier -203 West Highland white terrier -204 Lhasa, Lhasa apso -205 flat-coated retriever -206 curly-coated retriever -207 golden retriever -208 Labrador retriever -209 Chesapeake Bay retriever -210 German short-haired pointer -211 vizsla, Hungarian pointer -212 English setter -213 Irish setter, red setter -214 Gordon setter -215 Brittany spaniel -216 clumber, clumber spaniel -217 English springer, English springer spaniel -218 Welsh springer spaniel -219 cocker spaniel, English cocker spaniel, cocker -220 Sussex spaniel -221 Irish water spaniel -222 kuvasz -223 schipperke -224 groenendael -225 malinois -226 briard -227 kelpie -228 komondor -229 Old English sheepdog, bobtail -230 Shetland sheepdog, Shetland sheep dog, Shetland -231 collie -232 Border collie -233 Bouvier des Flandres, Bouviers des Flandres -234 Rottweiler -235 German shepherd, German shepherd dog, German police dog, alsatian -236 Doberman, Doberman pinscher -237 miniature pinscher -238 Greater Swiss Mountain dog -239 Bernese mountain dog -240 Appenzeller -241 EntleBucher -242 boxer -243 bull mastiff -244 Tibetan mastiff -245 French bulldog -246 Great Dane -247 Saint Bernard, St Bernard -248 Eskimo dog, husky -249 malamute, malemute, Alaskan malamute -250 Siberian husky -251 dalmatian, coach dog, carriage dog -252 affenpinscher, monkey pinscher, monkey dog -253 basenji -254 pug, pug-dog -255 Leonberg -256 Newfoundland, Newfoundland dog -257 Great Pyrenees -258 Samoyed, Samoyede -259 Pomeranian -260 chow, chow chow -261 keeshond -262 Brabancon griffon -263 Pembroke, Pembroke Welsh corgi -264 Cardigan, Cardigan Welsh corgi -265 toy poodle -266 miniature poodle -267 standard poodle -268 Mexican hairless -269 timber wolf, grey wolf, gray wolf, Canis lupus -270 white wolf, Arctic wolf, Canis lupus tundrarum -271 red wolf, maned wolf, Canis rufus, 
Canis niger -272 coyote, prairie wolf, brush wolf, Canis latrans -273 dingo, warrigal, warragal, Canis dingo -274 dhole, Cuon alpinus -275 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus -276 hyena, hyaena -277 red fox, Vulpes vulpes -278 kit fox, Vulpes macrotis -279 Arctic fox, white fox, Alopex lagopus -280 grey fox, gray fox, Urocyon cinereoargenteus -281 tabby, tabby cat -282 tiger cat -283 Persian cat -284 Siamese cat, Siamese -285 Egyptian cat -286 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor -287 lynx, catamount -288 leopard, Panthera pardus -289 snow leopard, ounce, Panthera uncia -290 jaguar, panther, Panthera onca, Felis onca -291 lion, king of beasts, Panthera leo -292 tiger, Panthera tigris -293 cheetah, chetah, Acinonyx jubatus -294 brown bear, bruin, Ursus arctos -295 American black bear, black bear, Ursus americanus, Euarctos americanus -296 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus -297 sloth bear, Melursus ursinus, Ursus ursinus -298 mongoose -299 meerkat, mierkat -300 tiger beetle -301 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle -302 ground beetle, carabid beetle -303 long-horned beetle, longicorn, longicorn beetle -304 leaf beetle, chrysomelid -305 dung beetle -306 rhinoceros beetle -307 weevil -308 fly -309 bee -310 ant, emmet, pismire -311 grasshopper, hopper -312 cricket -313 walking stick, walkingstick, stick insect -314 cockroach, roach -315 mantis, mantid -316 cicada, cicala -317 leafhopper -318 lacewing, lacewing fly -319 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk -320 damselfly -321 admiral -322 ringlet, ringlet butterfly -323 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus -324 cabbage butterfly -325 sulphur butterfly, sulfur butterfly -326 lycaenid, lycaenid butterfly -327 starfish, sea star -328 sea urchin -329 sea cucumber, holothurian -330 wood rabbit, cottontail, cottontail rabbit -331 hare -332 Angora, Angora rabbit -333 hamster -334 porcupine, hedgehog -335 fox squirrel, eastern fox squirrel, Sciurus niger -336 marmot -337 beaver -338 guinea pig, Cavia cobaya -339 sorrel -340 zebra -341 hog, pig, grunter, squealer, Sus scrofa -342 wild boar, boar, Sus scrofa -343 warthog -344 hippopotamus, hippo, river horse, Hippopotamus amphibius -345 ox -346 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis -347 bison -348 ram, tup -349 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis -350 ibex, Capra ibex -351 hartebeest -352 impala, Aepyceros melampus -353 gazelle -354 Arabian camel, dromedary, Camelus dromedarius -355 llama -356 weasel -357 mink -358 polecat, fitch, foulmart, foumart, Mustela putorius -359 black-footed ferret, ferret, Mustela nigripes -360 otter -361 skunk, polecat, wood pussy -362 badger -363 armadillo -364 three-toed sloth, ai, Bradypus tridactylus -365 orangutan, orang, orangutang, Pongo pygmaeus -366 gorilla, Gorilla gorilla -367 chimpanzee, chimp, Pan troglodytes -368 gibbon, Hylobates lar -369 siamang, Hylobates syndactylus, Symphalangus syndactylus -370 guenon, guenon monkey -371 patas, hussar monkey, Erythrocebus patas -372 baboon -373 macaque -374 langur -375 colobus, colobus monkey -376 proboscis monkey, Nasalis larvatus -377 marmoset -378 capuchin, ringtail, Cebus capucinus -379 howler monkey, howler -380 titi, titi monkey -381 spider monkey, Ateles geoffroyi -382 squirrel monkey, Saimiri sciureus -383 Madagascar 
cat, ring-tailed lemur, Lemur catta -384 indri, indris, Indri indri, Indri brevicaudatus -385 Indian elephant, Elephas maximus -386 African elephant, Loxodonta africana -387 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens -388 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca -389 barracouta, snoek -390 eel -391 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch -392 rock beauty, Holocanthus tricolor -393 anemone fish -394 sturgeon -395 gar, garfish, garpike, billfish, Lepisosteus osseus -396 lionfish -397 puffer, pufferfish, blowfish, globefish -398 abacus -399 abaya -400 academic gown, academic robe, judge's robe -401 accordion, piano accordion, squeeze box -402 acoustic guitar -403 aircraft carrier, carrier, flattop, attack aircraft carrier -404 airliner -405 airship, dirigible -406 altar -407 ambulance -408 amphibian, amphibious vehicle -409 analog clock -410 apiary, bee house -411 apron -412 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin -413 assault rifle, assault gun -414 backpack, back pack, knapsack, packsack, rucksack, haversack -415 bakery, bakeshop, bakehouse -416 balance beam, beam -417 balloon -418 ballpoint, ballpoint pen, ballpen, Biro -419 Band Aid -420 banjo -421 bannister, banister, balustrade, balusters, handrail -422 barbell -423 barber chair -424 barbershop -425 barn -426 barometer -427 barrel, cask -428 barrow, garden cart, lawn cart, wheelbarrow -429 baseball -430 basketball -431 bassinet -432 bassoon -433 bathing cap, swimming cap -434 bath towel -435 bathtub, bathing tub, bath, tub -436 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon -437 beacon, lighthouse, beacon light, pharos -438 beaker -439 bearskin, busby, shako -440 beer bottle -441 beer glass -442 bell cote, bell cot -443 bib -444 bicycle-built-for-two, tandem bicycle, tandem -445 bikini, two-piece -446 binder, ring-binder -447 binoculars, field glasses, opera glasses -448 birdhouse -449 boathouse -450 bobsled, bobsleigh, bob -451 bolo tie, bolo, bola tie, bola -452 bonnet, poke bonnet -453 bookcase -454 bookshop, bookstore, bookstall -455 bottlecap -456 bow -457 bow tie, bow-tie, bowtie -458 brass, memorial tablet, plaque -459 brassiere, bra, bandeau -460 breakwater, groin, groyne, mole, bulwark, seawall, jetty -461 breastplate, aegis, egis -462 broom -463 bucket, pail -464 buckle -465 bulletproof vest -466 bullet train, bullet -467 butcher shop, meat market -468 cab, hack, taxi, taxicab -469 caldron, cauldron -470 candle, taper, wax light -471 cannon -472 canoe -473 can opener, tin opener -474 cardigan -475 car mirror -476 carousel, carrousel, merry-go-round, roundabout, whirligig -477 carpenter's kit, tool kit -478 carton -479 car wheel -480 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM -481 cassette -482 cassette player -483 castle -484 catamaran -485 CD player -486 cello, violoncello -487 cellular telephone, cellular phone, cellphone, cell, mobile phone -488 chain -489 chainlink fence -490 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour -491 chain saw, chainsaw -492 chest -493 chiffonier, commode -494 chime, bell, gong -495 china cabinet, china closet -496 Christmas stocking -497 church, church building -498 cinema, movie theater, movie theatre, movie house, picture palace -499 cleaver, meat cleaver, chopper -500 cliff dwelling -501 cloak -502 clog, 
geta, patten, sabot -503 cocktail shaker -504 coffee mug -505 coffeepot -506 coil, spiral, volute, whorl, helix -507 combination lock -508 computer keyboard, keypad -509 confectionery, confectionary, candy store -510 container ship, containership, container vessel -511 convertible -512 corkscrew, bottle screw -513 cornet, horn, trumpet, trump -514 cowboy boot -515 cowboy hat, ten-gallon hat -516 cradle -517 crane -518 crash helmet -519 crate -520 crib, cot -521 Crock Pot -522 croquet ball -523 crutch -524 cuirass -525 dam, dike, dyke -526 desk -527 desktop computer -528 dial telephone, dial phone -529 diaper, nappy, napkin -530 digital clock -531 digital watch -532 dining table, board -533 dishrag, dishcloth -534 dishwasher, dish washer, dishwashing machine -535 disk brake, disc brake -536 dock, dockage, docking facility -537 dogsled, dog sled, dog sleigh -538 dome -539 doormat, welcome mat -540 drilling platform, offshore rig -541 drum, membranophone, tympan -542 drumstick -543 dumbbell -544 Dutch oven -545 electric fan, blower -546 electric guitar -547 electric locomotive -548 entertainment center -549 envelope -550 espresso maker -551 face powder -552 feather boa, boa -553 file, file cabinet, filing cabinet -554 fireboat -555 fire engine, fire truck -556 fire screen, fireguard -557 flagpole, flagstaff -558 flute, transverse flute -559 folding chair -560 football helmet -561 forklift -562 fountain -563 fountain pen -564 four-poster -565 freight car -566 French horn, horn -567 frying pan, frypan, skillet -568 fur coat -569 garbage truck, dustcart -570 gasmask, respirator, gas helmet -571 gas pump, gasoline pump, petrol pump, island dispenser -572 goblet -573 go-kart -574 golf ball -575 golfcart, golf cart -576 gondola -577 gong, tam-tam -578 gown -579 grand piano, grand -580 greenhouse, nursery, glasshouse -581 grille, radiator grille -582 grocery store, grocery, food market, market -583 guillotine -584 hair slide -585 hair spray -586 half track -587 hammer -588 hamper -589 hand blower, blow dryer, blow drier, hair dryer, hair drier -590 hand-held computer, hand-held microcomputer -591 handkerchief, hankie, hanky, hankey -592 hard disc, hard disk, fixed disk -593 harmonica, mouth organ, harp, mouth harp -594 harp -595 harvester, reaper -596 hatchet -597 holster -598 home theater, home theatre -599 honeycomb -600 hook, claw -601 hoopskirt, crinoline -602 horizontal bar, high bar -603 horse cart, horse-cart -604 hourglass -605 iPod -606 iron, smoothing iron -607 jack-o'-lantern -608 jean, blue jean, denim -609 jeep, landrover -610 jersey, T-shirt, tee shirt -611 jigsaw puzzle -612 jinrikisha, ricksha, rickshaw -613 joystick -614 kimono -615 knee pad -616 knot -617 lab coat, laboratory coat -618 ladle -619 lampshade, lamp shade -620 laptop, laptop computer -621 lawn mower, mower -622 lens cap, lens cover -623 letter opener, paper knife, paperknife -624 library -625 lifeboat -626 lighter, light, igniter, ignitor -627 limousine, limo -628 liner, ocean liner -629 lipstick, lip rouge -630 Loafer -631 lotion -632 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system -633 loupe, jeweler's loupe -634 lumbermill, sawmill -635 magnetic compass -636 mailbag, postbag -637 mailbox, letter box -638 maillot -639 maillot, tank suit -640 manhole cover -641 maraca -642 marimba, xylophone -643 mask -644 matchstick -645 maypole -646 maze, labyrinth -647 measuring cup -648 medicine chest, medicine cabinet -649 megalith, megalithic structure -650 microphone, mike -651 microwave, microwave oven 
-652 military uniform -653 milk can -654 minibus -655 miniskirt, mini -656 minivan -657 missile -658 mitten -659 mixing bowl -660 mobile home, manufactured home -661 Model T -662 modem -663 monastery -664 monitor -665 moped -666 mortar -667 mortarboard -668 mosque -669 mosquito net -670 motor scooter, scooter -671 mountain bike, all-terrain bike, off-roader -672 mountain tent -673 mouse, computer mouse -674 mousetrap -675 moving van -676 muzzle -677 nail -678 neck brace -679 necklace -680 nipple -681 notebook, notebook computer -682 obelisk -683 oboe, hautboy, hautbois -684 ocarina, sweet potato -685 odometer, hodometer, mileometer, milometer -686 oil filter -687 organ, pipe organ -688 oscilloscope, scope, cathode-ray oscilloscope, CRO -689 overskirt -690 oxcart -691 oxygen mask -692 packet -693 paddle, boat paddle -694 paddlewheel, paddle wheel -695 padlock -696 paintbrush -697 pajama, pyjama, pj's, jammies -698 palace -699 panpipe, pandean pipe, syrinx -700 paper towel -701 parachute, chute -702 parallel bars, bars -703 park bench -704 parking meter -705 passenger car, coach, carriage -706 patio, terrace -707 pay-phone, pay-station -708 pedestal, plinth, footstall -709 pencil box, pencil case -710 pencil sharpener -711 perfume, essence -712 Petri dish -713 photocopier -714 pick, plectrum, plectron -715 pickelhaube -716 picket fence, paling -717 pickup, pickup truck -718 pier -719 piggy bank, penny bank -720 pill bottle -721 pillow -722 ping-pong ball -723 pinwheel -724 pirate, pirate ship -725 pitcher, ewer -726 plane, carpenter's plane, woodworking plane -727 planetarium -728 plastic bag -729 plate rack -730 plow, plough -731 plunger, plumber's helper -732 Polaroid camera, Polaroid Land camera -733 pole -734 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria -735 poncho -736 pool table, billiard table, snooker table -737 pop bottle, soda bottle -738 pot, flowerpot -739 potter's wheel -740 power drill -741 prayer rug, prayer mat -742 printer -743 prison, prison house -744 projectile, missile -745 projector -746 puck, hockey puck -747 punching bag, punch bag, punching ball, punchball -748 purse -749 quill, quill pen -750 quilt, comforter, comfort, puff -751 racer, race car, racing car -752 racket, racquet -753 radiator -754 radio, wireless -755 radio telescope, radio reflector -756 rain barrel -757 recreational vehicle, RV, R.V. 
-758 reel -759 reflex camera -760 refrigerator, icebox -761 remote control, remote -762 restaurant, eating house, eating place, eatery -763 revolver, six-gun, six-shooter -764 rifle -765 rocking chair, rocker -766 rotisserie -767 rubber eraser, rubber, pencil eraser -768 rugby ball -769 rule, ruler -770 running shoe -771 safe -772 safety pin -773 saltshaker, salt shaker -774 sandal -775 sarong -776 sax, saxophone -777 scabbard -778 scale, weighing machine -779 school bus -780 schooner -781 scoreboard -782 screen, CRT screen -783 screw -784 screwdriver -785 seat belt, seatbelt -786 sewing machine -787 shield, buckler -788 shoe shop, shoe-shop, shoe store -789 shoji -790 shopping basket -791 shopping cart -792 shovel -793 shower cap -794 shower curtain -795 ski -796 ski mask -797 sleeping bag -798 slide rule, slipstick -799 sliding door -800 slot, one-armed bandit -801 snorkel -802 snowmobile -803 snowplow, snowplough -804 soap dispenser -805 soccer ball -806 sock -807 solar dish, solar collector, solar furnace -808 sombrero -809 soup bowl -810 space bar -811 space heater -812 space shuttle -813 spatula -814 speedboat -815 spider web, spider's web -816 spindle -817 sports car, sport car -818 spotlight, spot -819 stage -820 steam locomotive -821 steel arch bridge -822 steel drum -823 stethoscope -824 stole -825 stone wall -826 stopwatch, stop watch -827 stove -828 strainer -829 streetcar, tram, tramcar, trolley, trolley car -830 stretcher -831 studio couch, day bed -832 stupa, tope -833 submarine, pigboat, sub, U-boat -834 suit, suit of clothes -835 sundial -836 sunglass -837 sunglasses, dark glasses, shades -838 sunscreen, sunblock, sun blocker -839 suspension bridge -840 swab, swob, mop -841 sweatshirt -842 swimming trunks, bathing trunks -843 swing -844 switch, electric switch, electrical switch -845 syringe -846 table lamp -847 tank, army tank, armored combat vehicle, armoured combat vehicle -848 tape player -849 teapot -850 teddy, teddy bear -851 television, television system -852 tennis ball -853 thatch, thatched roof -854 theater curtain, theatre curtain -855 thimble -856 thresher, thrasher, threshing machine -857 throne -858 tile roof -859 toaster -860 tobacco shop, tobacconist shop, tobacconist -861 toilet seat -862 torch -863 totem pole -864 tow truck, tow car, wrecker -865 toyshop -866 tractor -867 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi -868 tray -869 trench coat -870 tricycle, trike, velocipede -871 trimaran -872 tripod -873 triumphal arch -874 trolleybus, trolley coach, trackless trolley -875 trombone -876 tub, vat -877 turnstile -878 typewriter keyboard -879 umbrella -880 unicycle, monocycle -881 upright, upright piano -882 vacuum, vacuum cleaner -883 vase -884 vault -885 velvet -886 vending machine -887 vestment -888 viaduct -889 violin, fiddle -890 volleyball -891 waffle iron -892 wall clock -893 wallet, billfold, notecase, pocketbook -894 wardrobe, closet, press -895 warplane, military plane -896 washbasin, handbasin, washbowl, lavabo, wash-hand basin -897 washer, automatic washer, washing machine -898 water bottle -899 water jug -900 water tower -901 whiskey jug -902 whistle -903 wig -904 window screen -905 window shade -906 Windsor tie -907 wine bottle -908 wing -909 wok -910 wooden spoon -911 wool, woolen, woollen -912 worm fence, snake fence, snake-rail fence, Virginia fence -913 wreck -914 yawl -915 yurt -916 web site, website, internet site, site -917 comic book -918 crossword puzzle, crossword -919 street sign -920 traffic light, 
traffic signal, stoplight -921 book jacket, dust cover, dust jacket, dust wrapper -922 menu -923 plate -924 guacamole -925 consomme -926 hot pot, hotpot -927 trifle -928 ice cream, icecream -929 ice lolly, lolly, lollipop, popsicle -930 French loaf -931 bagel, beigel -932 pretzel -933 cheeseburger -934 hotdog, hot dog, red hot -935 mashed potato -936 head cabbage -937 broccoli -938 cauliflower -939 zucchini, courgette -940 spaghetti squash -941 acorn squash -942 butternut squash -943 cucumber, cuke -944 artichoke, globe artichoke -945 bell pepper -946 cardoon -947 mushroom -948 Granny Smith -949 strawberry -950 orange -951 lemon -952 fig -953 pineapple, ananas -954 banana -955 jackfruit, jak, jack -956 custard apple -957 pomegranate -958 hay -959 carbonara -960 chocolate sauce, chocolate syrup -961 dough -962 meat loaf, meatloaf -963 pizza, pizza pie -964 potpie -965 burrito -966 red wine -967 espresso -968 cup -969 eggnog -970 alp -971 bubble -972 cliff, drop, drop-off -973 coral reef -974 geyser -975 lakeside, lakeshore -976 promontory, headland, head, foreland -977 sandbar, sand bar -978 seashore, coast, seacoast, sea-coast -979 valley, vale -980 volcano -981 ballplayer, baseball player -982 groom, bridegroom -983 scuba diver -984 rapeseed -985 daisy -986 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum -987 corn -988 acorn -989 hip, rose hip, rosehip -990 buckeye, horse chestnut, conker -991 coral fungus -992 agaric -993 gyromitra -994 stinkhorn, carrion fungus -995 earthstar -996 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa -997 bolete -998 ear, spike, capitulum -999 toilet tissue, toilet paper, bathroom tissue diff --git a/example/auto_compression/image_classification/infer.py b/example/auto_compression/image_classification/infer.py index 88e4b82de2f13b448a6e9425174c2297421a1594..06dd90eb5687fd301f857194536d55b8ce530253 100644 --- a/example/auto_compression/image_classification/infer.py +++ b/example/auto_compression/image_classification/infer.py @@ -13,141 +13,76 @@ # limitations under the License. 
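
Editor's note (not part of the patch): the `infer.py` hunk that follows replaces the PaddleClas-style predictor wrapper with a direct Paddle Inference setup. For reference, a minimal standalone sketch of that same construction path is shown below; it assumes a PaddlePaddle install with the inference API and an exported model at `./MobileNetV1_infer`, and the helper name `build_predictor` plus the random input batch are illustrative only.

```python
# Illustrative sketch only. Assumes an exported inference model at MODEL_DIR
# (inference.pdmodel / inference.pdiparams), mirroring the rewritten infer.py hunk.
import os
import numpy as np
import paddle

MODEL_DIR = "./MobileNetV1_infer"  # assumption: path to the exported model


def build_predictor(use_gpu=False, use_tensorrt=False, use_int8=False, use_fp16=False):
    config = paddle.inference.Config(
        os.path.join(MODEL_DIR, "inference.pdmodel"),
        os.path.join(MODEL_DIR, "inference.pdiparams"))
    if use_gpu:
        config.enable_use_gpu(8000, 0)  # 8000 MB GPU memory pool on device 0
    else:
        config.disable_gpu()
    if use_tensorrt:
        # Precision selection follows the same priority as the patched code:
        # int8 > fp16 > fp32.
        precision = paddle.inference.Config.Precision.Float32
        if use_int8:
            precision = paddle.inference.Config.Precision.Int8
        elif use_fp16:
            precision = paddle.inference.Config.Precision.Half
        config.enable_tensorrt_engine(
            precision_mode=precision,
            max_batch_size=1,
            workspace_size=1 << 30,
            min_subgraph_size=30,
            use_calib_mode=False)
    config.enable_memory_optim()
    config.switch_use_feed_fetch_ops(False)  # zero-copy input/output handles
    return paddle.inference.create_predictor(config)


predictor = build_predictor()
input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
output_handle = predictor.get_output_handle(predictor.get_output_names()[0])

# Stand-in for a preprocessed image batch (NCHW, normalized float32).
fake_batch = np.random.rand(1, 3, 224, 224).astype("float32")
input_handle.copy_from_cpu(fake_batch)
predictor.run()
print(output_handle.copy_to_cpu().shape)  # e.g. (1, 1000) class scores
```
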
import os -import sys -import cv2 import numpy as np -import platform +import cv2 +import time +import sys import argparse -import base64 -import shutil +import yaml +from utils import preprocess, postprocess import paddle -from postprocess import build_postprocess -from preprocess import create_operators +from paddle.inference import create_predictor from paddleslim.auto_compression.config_helpers import load_config def argsparser(): parser = argparse.ArgumentParser(description=__doc__) parser.add_argument( - '-c', - '--config', + '--config_path', type=str, - default='configs/config.yaml', + default='configs/infer.yaml', help='config file path') return parser -def print_arguments(args): - print('----------- Running Arguments -----------') - for arg, value in args.items(): - print('%s: %s' % (arg, value)) - print('------------------------------------------') - - -def get_image_list(img_file): - imgs_lists = [] - if img_file is None or not os.path.exists(img_file): - raise Exception("not found any img file in {}".format(img_file)) - - img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] - if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: - imgs_lists.append(img_file) - elif os.path.isdir(img_file): - for single_file in os.listdir(img_file): - if single_file.split('.')[-1] in img_end: - imgs_lists.append(os.path.join(img_file, single_file)) - if len(imgs_lists) == 0: - raise Exception("not found any img file in {}".format(img_file)) - imgs_lists = sorted(imgs_lists) - return imgs_lists - - class Predictor(object): def __init__(self, config): - predict_args = config['Global'] - # HALF precission predict only work when using tensorrt - if predict_args['use_fp16'] is True: - assert predict_args.use_tensorrt is True - self.args = predict_args - if self.args.get("use_onnx", False): - self.predictor, self.config = self.create_onnx_predictor( - predict_args) - else: - self.predictor, self.config = self.create_paddle_predictor( - predict_args) - - self.preprocess_ops = [] - self.postprocess = None - if "PreProcess" in config: - if "transform_ops" in config["PreProcess"]: - self.preprocess_ops = create_operators(config["PreProcess"][ - "transform_ops"]) - if "PostProcess" in config: - self.postprocess = build_postprocess(config["PostProcess"]) - - # for whole_chain project to test each repo of paddle - self.benchmark = config["Global"].get("benchmark", False) - if self.benchmark: - import auto_log - import os - pid = os.getpid() - size = config["PreProcess"]["transform_ops"][1]["CropImage"]["size"] - if config["Global"].get("use_int8", False): - precision = "int8" - elif config["Global"].get("use_fp16", False): - precision = "fp16" - else: - precision = "fp32" - self.auto_logger = auto_log.AutoLogger( - model_name=config["Global"].get("model_name", "cls"), - model_precision=precision, - batch_size=config["Global"].get("batch_size", 1), - data_shape=[3, size, size], - save_path=config["Global"].get("save_log_path", - "./auto_log.log"), - inference_config=self.config, - pids=pid, - process_name=None, - gpu_ids=None, - time_keys=[ - 'preprocess_time', 'inference_time', 'postprocess_time' - ], - warmup=2) - - def create_paddle_predictor(self, args): - inference_model_dir = args['inference_model_dir'] - - params_file = os.path.join(inference_model_dir, args['params_filename']) - model_file = os.path.join(inference_model_dir, args['model_filename']) + # HALF precission predict only work when using tensorrt + if config['use_fp16'] is True: + assert config['use_tensorrt'] is True + 
self.config = config + + self.paddle_predictor = self.create_paddle_predictor() + input_names = self.paddle_predictor.get_input_names() + self.input_tensor = self.paddle_predictor.get_input_handle(input_names[ + 0]) + + output_names = self.paddle_predictor.get_output_names() + self.output_tensor = self.paddle_predictor.get_output_handle( + output_names[0]) + + def create_paddle_predictor(self): + inference_model_dir = self.config['inference_model_dir'] + model_file = os.path.join(inference_model_dir, + self.config['model_filename']) + params_file = os.path.join(inference_model_dir, + self.config['params_filename']) config = paddle.inference.Config(model_file, params_file) - - if args['use_gpu']: - config.enable_use_gpu(args['gpu_mem'], 0) + precision = paddle.inference.Config.Precision.Float32 + if self.config['use_int8']: + precision = paddle.inference.Config.Precision.Int8 + elif self.config['use_fp16']: + precision = paddle.inference.Config.Precision.Half + + if self.config['use_gpu']: + config.enable_use_gpu(self.config['gpu_mem'], 0) else: config.disable_gpu() - if args['enable_mkldnn']: - # there is no set_mkldnn_cache_capatity() on macOS - if platform.system() != "Darwin": - # cache 10 different shapes for mkldnn to avoid memory leak - config.set_mkldnn_cache_capacity(10) + if self.config['enable_mkldnn']: + # cache 10 different shapes for mkldnn to avoid memory leak + config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() - config.set_cpu_math_library_num_threads(args['cpu_num_threads']) + config.set_cpu_math_library_num_threads(self.config['cpu_num_threads']) - if args['enable_profile']: + if self.config['enable_profile']: config.enable_profile() config.disable_glog_info() - config.switch_ir_optim(args['ir_optim']) # default true - if args['use_tensorrt']: - precision = paddle.inference.Config.Precision.Float32 - if args.get("use_int8", False): - precision = paddle.inference.Config.Precision.Int8 - elif args.get("use_fp16", False): - precision = paddle.inference.Config.Precision.Half - + config.switch_ir_optim(self.config['ir_optim']) # default true + if self.config['use_tensorrt']: config.enable_tensorrt_engine( precision_mode=precision, - max_batch_size=args['batch_size'], + max_batch_size=self.config['batch_size'], workspace_size=1 << 30, min_subgraph_size=30, use_calib_mode=False) @@ -155,112 +90,36 @@ class Predictor(object): config.enable_memory_optim() # use zero copy config.switch_use_feed_fetch_ops(False) - predictor = paddle.inference.create_predictor(config) - - return predictor, config - - def create_onnx_predictor(self, args): - import onnxruntime as ort - inference_model_dir = args['inference_model_dir'] - model_file = os.path.join(inference_model_dir, args['model_filename']) - config = ort.SessionOptions() - if args['use_gpu']: - raise ValueError( - "onnx inference now only supports cpu! please specify use_gpu false." 
- ) - else: - config.intra_op_num_threads = args['cpu_num_threads'] - if args['ir_optim']: - config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL - predictor = ort.InferenceSession(model_file, sess_options=config) - return predictor, config - - def predict(self, images): - use_onnx = self.args.get("use_onnx", False) - if not use_onnx: - input_names = self.predictor.get_input_names() - input_tensor = self.predictor.get_input_handle(input_names[0]) - - output_names = self.predictor.get_output_names() - output_tensor = self.predictor.get_output_handle(output_names[0]) - else: - input_names = self.predictor.get_inputs()[0].name - output_names = self.predictor.get_outputs()[0].name - - if self.benchmark: - self.auto_logger.times.start() - if not isinstance(images, (list, )): - images = [images] - for idx in range(len(images)): - for ops in self.preprocess_ops: - images[idx] = ops(images[idx]) - image = np.array(images) - if self.benchmark: - self.auto_logger.times.stamp() - - if not use_onnx: - input_tensor.copy_from_cpu(image) - self.predictor.run() - batch_output = output_tensor.copy_to_cpu() - else: - batch_output = self.predictor.run( - output_names=[output_names], input_feed={input_names: image})[0] - - if self.benchmark: - self.auto_logger.times.stamp() - if self.postprocess is not None: - batch_output = self.postprocess(batch_output) - if self.benchmark: - self.auto_logger.times.end(stamp=True) - return batch_output - - -def main(config): - predictor = Predictor(config) - image_list = get_image_list(config["Global"]["infer_imgs"]) - image_list = image_list * 1000 - batch_imgs = [] - batch_names = [] - cnt = 0 - for idx, img_path in enumerate(image_list): - img = cv2.imread(img_path) - if img is None: - logger.warning( - "Image file failed to read and has been skipped. The path: {}". 
- format(img_path)) - else: - img = img[:, :, ::-1] - batch_imgs.append(img) - img_name = os.path.basename(img_path) - batch_names.append(img_name) - cnt += 1 - - if cnt % config["Global"]["batch_size"] == 0 or (idx + 1 - ) == len(image_list): - if len(batch_imgs) == 0: - continue - batch_results = predictor.predict(batch_imgs) - for number, result_dict in enumerate(batch_results): - if "PersonAttribute" in config[ - "PostProcess"] or "VehicleAttribute" in config[ - "PostProcess"]: - filename = batch_names[number] - else: - filename = batch_names[number] - clas_ids = result_dict["class_ids"] - scores_str = "[{}]".format(", ".join("{:.2f}".format( - r) for r in result_dict["scores"])) - label_names = result_dict["label_names"] - batch_imgs = [] - batch_names = [] - if predictor.benchmark: - predictor.auto_logger.report() - return + predictor = create_predictor(config) + + return predictor + + def predict(self): + test_num = 1000 + test_time = 0.0 + for i in range(0, test_num + 10): + inputs = np.random.rand(config['batch_size'], 3, + config['image_size'], + config['image_size']).astype(np.float32) + start_time = time.time() + self.input_tensor.copy_from_cpu(inputs) + self.paddle_predictor.run() + batch_output = self.output_tensor.copy_to_cpu().flatten() + if i >= 10: + test_time += time.time() - start_time + time.sleep(0.01) # sleep for T4 GPU + + fp_message = "FP16" if config['use_fp16'] else "FP32" + trt_msg = "using tensorrt" if config[ + 'use_tensorrt'] else "not using tensorrt" + print("{0}\t{1}\tbatch size: {2}\ttime(ms): {3}".format( + trt_msg, fp_message, config[ + 'batch_size'], 1000 * test_time / test_num)) if __name__ == "__main__": parser = argsparser() args = parser.parse_args() - config = load_config(args.config) - print_arguments(config['Global']) - main(config) + config = load_config(args.config_path) + predictor = Predictor(config) + predictor.predict() diff --git a/example/auto_compression/image_classification/postprocess.py b/example/auto_compression/image_classification/postprocess.py deleted file mode 100644 index 9b9b4afcb68374b867f148a8cfe283d268141820..0000000000000000000000000000000000000000 --- a/example/auto_compression/image_classification/postprocess.py +++ /dev/null @@ -1,131 +0,0 @@ -# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
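# NOTE: minimal, self-contained sketch of the random-input latency loop that the
# rewritten infer.py above implements with Paddle Inference. It is illustrative
# only: sizes are passed as plain arguments rather than read from the YAML config
# dict, and the 10 warm-up / 1000 timed iterations mirror the hunk above.
import time
import numpy as np

def benchmark_latency(predictor, batch_size, image_size, test_num=1000, warmup=10):
    # `predictor` is a paddle.inference predictor built via create_predictor(config)
    input_tensor = predictor.get_input_handle(predictor.get_input_names()[0])
    output_tensor = predictor.get_output_handle(predictor.get_output_names()[0])
    elapsed = 0.0
    for i in range(test_num + warmup):
        data = np.random.rand(batch_size, 3, image_size,
                              image_size).astype(np.float32)
        start = time.time()
        input_tensor.copy_from_cpu(data)
        predictor.run()
        _ = output_tensor.copy_to_cpu()
        if i >= warmup:  # only count iterations after warm-up
            elapsed += time.time() - start
    return 1000.0 * elapsed / test_num  # average latency in ms per timed batch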
- -import os -import copy -import shutil -from functools import partial -import importlib -import numpy as np -import paddle -import paddle.nn.functional as F - - -def build_postprocess(config): - if config is None: - return None - - mod = importlib.import_module(__name__) - config = copy.deepcopy(config) - - main_indicator = config.pop( - "main_indicator") if "main_indicator" in config else None - main_indicator = main_indicator if main_indicator else "" - - func_list = [] - for func in config: - func_list.append(getattr(mod, func)(**config[func])) - return PostProcesser(func_list, main_indicator) - - -class PostProcesser(object): - def __init__(self, func_list, main_indicator="Topk"): - self.func_list = func_list - self.main_indicator = main_indicator - - def __call__(self, x, image_file=None): - rtn = None - for func in self.func_list: - tmp = func(x, image_file) - if type(func).__name__ in self.main_indicator: - rtn = tmp - return rtn - - -class Topk(object): - def __init__(self, topk=1, class_id_map_file=None): - assert isinstance(topk, (int, )) - self.class_id_map = self.parse_class_id_map(class_id_map_file) - self.topk = topk - - def parse_class_id_map(self, class_id_map_file): - if class_id_map_file is None: - return None - - if not os.path.exists(class_id_map_file): - print( - "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!" - ) - return None - - try: - class_id_map = {} - with open(class_id_map_file, "r") as fin: - lines = fin.readlines() - for line in lines: - partition = line.split("\n")[0].partition(" ") - class_id_map[int(partition[0])] = str(partition[-1]) - except Exception as ex: - print(ex) - class_id_map = None - return class_id_map - - def __call__(self, x, file_names=None, multilabel=False): - if file_names is not None: - assert x.shape[0] == len(file_names) - y = [] - for idx, probs in enumerate(x): - index = probs.argsort(axis=0)[-self.topk:][::-1].astype( - "int32") if not multilabel else np.where( - probs >= 0.5)[0].astype("int32") - clas_id_list = [] - score_list = [] - label_name_list = [] - for i in index: - clas_id_list.append(i.item()) - score_list.append(probs[i].item()) - if self.class_id_map is not None: - label_name_list.append(self.class_id_map[i.item()]) - result = { - "class_ids": clas_id_list, - "scores": np.around( - score_list, decimals=5).tolist(), - } - if file_names is not None: - result["file_name"] = file_names[idx] - if label_name_list is not None: - result["label_names"] = label_name_list - y.append(result) - return y - - -class SavePreLabel(object): - def __init__(self, save_dir): - if save_dir is None: - raise Exception( - "Please specify save_dir if SavePreLabel specified.") - self.save_dir = partial(os.path.join, save_dir) - - def __call__(self, x, file_names=None): - if file_names is None: - return - assert x.shape[0] == len(file_names) - for idx, probs in enumerate(x): - index = probs.argsort(axis=0)[-1].astype("int32") - self.save(index, file_names[idx]) - - def save(self, id, image_file): - output_dir = self.save_dir(str(id)) - os.makedirs(output_dir, exist_ok=True) - shutil.copy(image_file, output_dir) diff --git a/example/auto_compression/image_classification/preprocess.py b/example/auto_compression/image_classification/preprocess.py deleted file mode 100644 index 95561698b3e181d2d45f1c735211e01cfc3425af..0000000000000000000000000000000000000000 --- a/example/auto_compression/image_classification/preprocess.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) 2022 
PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -from functools import partial -import six -import math -import random -import cv2 -import numpy as np -import importlib -from PIL import Image - - -def create_operators(params): - """ - create operators based on the config - - Args: - params(list): a dict list, used to create some operators - """ - assert isinstance(params, list), ('operator config should be a list') - mod = importlib.import_module(__name__) - ops = [] - for operator in params: - assert isinstance(operator, - dict) and len(operator) == 1, "yaml format error" - op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name] - op = getattr(mod, op_name)(**param) - ops.append(op) - - return ops - - -class UnifiedResize(object): - def __init__(self, interpolation=None, backend="cv2"): - _cv2_interp_from_str = { - 'nearest': cv2.INTER_NEAREST, - 'bilinear': cv2.INTER_LINEAR, - 'area': cv2.INTER_AREA, - 'bicubic': cv2.INTER_CUBIC, - 'lanczos': cv2.INTER_LANCZOS4 - } - _pil_interp_from_str = { - 'nearest': Image.NEAREST, - 'bilinear': Image.BILINEAR, - 'bicubic': Image.BICUBIC, - 'box': Image.BOX, - 'lanczos': Image.LANCZOS, - 'hamming': Image.HAMMING - } - - def _pil_resize(src, size, resample): - pil_img = Image.fromarray(src) - pil_img = pil_img.resize(size, resample) - return np.asarray(pil_img) - - if backend.lower() == "cv2": - if isinstance(interpolation, str): - interpolation = _cv2_interp_from_str[interpolation.lower()] - # compatible with opencv < version 4.4.0 - elif interpolation is None: - interpolation = cv2.INTER_LINEAR - self.resize_func = partial(cv2.resize, interpolation=interpolation) - elif backend.lower() == "pil": - if isinstance(interpolation, str): - interpolation = _pil_interp_from_str[interpolation.lower()] - self.resize_func = partial(_pil_resize, resample=interpolation) - else: - logger.warning( - f"The backend of Resize only support \"cv2\" or \"PIL\". \"f{backend}\" is unavailable. Use \"cv2\" instead." 
- ) - self.resize_func = cv2.resize - - def __call__(self, src, size): - return self.resize_func(src, size) - - -class OperatorParamError(ValueError): - """ OperatorParamError - """ - pass - - -class ResizeImage(object): - """ resize image """ - - def __init__(self, - size=None, - resize_short=None, - interpolation=None, - backend="cv2"): - if resize_short is not None and resize_short > 0: - self.resize_short = resize_short - self.w = None - self.h = None - elif size is not None: - self.resize_short = None - self.w = size if type(size) is int else size[0] - self.h = size if type(size) is int else size[1] - else: - raise OperatorParamError("invalid params for ReisizeImage for '\ - 'both 'size' and 'resize_short' are None") - - self._resize_func = UnifiedResize( - interpolation=interpolation, backend=backend) - - def __call__(self, img): - img_h, img_w = img.shape[:2] - if self.resize_short is not None: - percent = float(self.resize_short) / min(img_w, img_h) - w = int(round(img_w * percent)) - h = int(round(img_h * percent)) - else: - w = self.w - h = self.h - return self._resize_func(img, (w, h)) - - -class CropImage(object): - """ crop image """ - - def __init__(self, size): - if type(size) is int: - self.size = (size, size) - else: - self.size = size # (h, w) - - def __call__(self, img): - w, h = self.size - img_h, img_w = img.shape[:2] - - if img_h < h or img_w < w: - raise Exception( - f"The size({h}, {w}) of CropImage must be greater than size({img_h}, {img_w}) of image. Please check image original size and size of ResizeImage if used." - ) - - w_start = (img_w - w) // 2 - h_start = (img_h - h) // 2 - - w_end = w_start + w - h_end = h_start + h - return img[h_start:h_end, w_start:w_end, :] - - -class NormalizeImage(object): - """ normalize image such as substract mean, divide std - """ - - def __init__(self, - scale=None, - mean=None, - std=None, - order='chw', - output_fp16=False, - channel_num=3): - if isinstance(scale, str): - scale = eval(scale) - assert channel_num in [ - 3, 4 - ], "channel number of input image should be set to 3 or 4." 
- self.channel_num = channel_num - self.output_dtype = 'float16' if output_fp16 else 'float32' - self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) - self.order = order - mean = mean if mean is not None else [0.485, 0.456, 0.406] - std = std if std is not None else [0.229, 0.224, 0.225] - - shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3) - self.mean = np.array(mean).reshape(shape).astype('float32') - self.std = np.array(std).reshape(shape).astype('float32') - - def __call__(self, img): - from PIL import Image - if isinstance(img, Image.Image): - img = np.array(img) - - assert isinstance(img, - np.ndarray), "invalid input 'img' in NormalizeImage" - - img = (img.astype('float32') * self.scale - self.mean) / self.std - - if self.channel_num == 4: - img_h = img.shape[1] if self.order == 'chw' else img.shape[0] - img_w = img.shape[2] if self.order == 'chw' else img.shape[1] - pad_zeros = np.zeros( - (1, img_h, img_w)) if self.order == 'chw' else np.zeros( - (img_h, img_w, 1)) - img = (np.concatenate( - (img, pad_zeros), axis=0) - if self.order == 'chw' else np.concatenate( - (img, pad_zeros), axis=2)) - return img.astype(self.output_dtype) - - -class ToCHWImage(object): - """ convert hwc image to chw image - """ - - def __init__(self): - pass - - def __call__(self, img): - from PIL import Image - if isinstance(img, Image.Image): - img = np.array(img) - - return img.transpose((2, 0, 1)) diff --git a/example/auto_compression/image_classification/run.py b/example/auto_compression/image_classification/run.py index c3ecf1377d9a74fcf3a6b3b9161d50067eef1a7f..448ea099e8bdc6a21174aa3820d72a0de098422a 100644 --- a/example/auto_compression/image_classification/run.py +++ b/example/auto_compression/image_classification/run.py @@ -19,13 +19,13 @@ sys.path[0] = os.path.join( import argparse import functools from functools import partial +import math import numpy as np -import math import paddle import paddle.nn as nn -from paddle.io import Dataset, BatchSampler, DataLoader -import imagenet_reader as reader +from paddle.io import DataLoader +from imagenet_reader import ImageNetDataset from paddleslim.auto_compression.config_helpers import load_config as load_slim_config from paddleslim.auto_compression import AutoCompression @@ -54,35 +54,41 @@ def argsparser(): # yapf: enable def reader_wrapper(reader, input_name): def gen(): - for i, data in enumerate(reader()): - imgs = np.float32([item[0] for item in data]) + for i, (imgs, label) in enumerate(reader()): yield {input_name: imgs} return gen -def eval_reader(data_dir, batch_size): - val_reader = paddle.batch( - reader.val(data_dir=data_dir), batch_size=batch_size) - return val_reader +def eval_reader(data_dir, batch_size, crop_size, resize_size): + val_reader = ImageNetDataset( + mode='val', + data_dir=data_dir, + crop_size=crop_size, + resize_size=resize_size) + val_loader = DataLoader( + val_reader, + batch_size=global_config['batch_size'], + shuffle=False, + drop_last=False, + num_workers=0) + return val_loader def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): - val_reader = eval_reader(data_dir, batch_size=global_config['batch_size']) - image = paddle.static.data( - name=global_config['input_name'], - shape=[None, 3, 224, 224], - dtype='float32') - label = paddle.static.data(name='label', shape=[None, 1], dtype='int64') + val_loader = eval_reader( + data_dir, + batch_size=global_config['batch_size'], + crop_size=img_size, + resize_size=resize_size) results = [] - print('Evaluating... 
It will take a while. Please wait...') - for batch_id, data in enumerate(val_reader()): + print('Evaluating...') + for batch_id, (image, label) in enumerate(val_loader): # top1_acc, top5_acc if len(test_feed_names) == 1: - image = np.array([[d[0]] for d in data]) - image = image.reshape((len(data), 3, 224, 224)) - label = [[d[1]] for d in data] + image = np.array(image) + label = np.array(label).astype('int64') pred = exe.run(compiled_test_program, feed={test_feed_names[0]: image}, fetch_list=test_fetch_list) @@ -100,9 +106,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): results.append([top_1, top_5]) else: # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy - image = np.array([[d[0]] for d in data]) - image = image.reshape((len(data), 3, 224, 224)) - label = [[d[1]] for d in data] + image = np.array(image) + label = np.array(label).astype('int64') result = exe.run( compiled_test_program, feed={test_feed_names[0]: image, @@ -110,6 +115,8 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): fetch_list=test_fetch_list) result = [np.mean(r) for r in result] results.append(result) + if batch_id % 100 == 0: + print('Eval iter: ', batch_id) result = np.mean(np.array(results), axis=0) return result[0] @@ -117,8 +124,10 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list): def main(): global global_config all_config = load_slim_config(args.config_path) + assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}" global_config = all_config["Global"] + gpu_num = paddle.distributed.get_world_size() if isinstance(all_config['TrainConfig']['learning_rate'], dict) and all_config['TrainConfig']['learning_rate'][ @@ -129,12 +138,28 @@ def main(): gpu_num))) all_config['TrainConfig']['learning_rate']['T_max'] = step print('total training steps:', step) + global data_dir data_dir = global_config['data_dir'] - train_reader = paddle.batch( - reader.train(data_dir=data_dir), batch_size=global_config['batch_size']) - train_dataloader = reader_wrapper(train_reader, global_config['input_name']) + global img_size, resize_size + img_size = global_config['img_size'] if 'img_size' in global_config else 224 + resize_size = global_config[ + 'resize_size'] if 'resize_size' in global_config else 256 + + train_dataset = ImageNetDataset( + mode='train', + data_dir=data_dir, + crop_size=img_size, + resize_size=resize_size) + + train_loader = DataLoader( + train_dataset, + batch_size=global_config['batch_size'], + shuffle=True, + drop_last=True, + num_workers=0) + train_dataloader = reader_wrapper(train_loader, global_config['input_name']) ac = AutoCompression( model_dir=global_config['model_dir'], @@ -145,7 +170,11 @@ def main(): train_dataloader=train_dataloader, eval_callback=eval_function, eval_dataloader=reader_wrapper( - eval_reader(data_dir, global_config['batch_size']), + eval_reader( + data_dir, + global_config['batch_size'], + crop_size=img_size, + resize_size=resize_size), global_config['input_name'])) ac.compress() diff --git a/example/auto_compression/image_classification/run.sh b/example/auto_compression/image_classification/run.sh deleted file mode 100644 index 4d8777d88694bf99a315059d8b4e0d5534a9cebe..0000000000000000000000000000000000000000 --- a/example/auto_compression/image_classification/run.sh +++ /dev/null @@ -1,8 +0,0 @@ -# 单卡启动 -export CUDA_VISIBLE_DEVICES=0 -python3.7 eval.py --save_dir='./save_quant_mobilev1/' 
--config_path='./configs/MobileNetV1/qat_dis.yaml' - -# 多卡启动 -export CUDA_VISIBLE_DEVICES=0,1,2,3 -python -m paddle.distributed.launch run.py --save_dir='./save_quant_mobilev1/' --config_path='./configs/MobileNetV1/qat_dis.yaml' - diff --git a/example/auto_compression/image_classification/utils.py b/example/auto_compression/image_classification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a94e67bfe37bb5961acc1bfd6460d3c3fe4022 --- /dev/null +++ b/example/auto_compression/image_classification/utils.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import base64 +import shutil +import cv2 +import numpy as np + + +def preprocess(img, args): + resize_op = ResizeImage(resize_short=args.resize_short) + img = resize_op(img) + crop_op = CropImage(size=(args.resize, args.resize)) + img = crop_op(img) + if args.normalize: + img_mean = [0.485, 0.456, 0.406] + img_std = [0.229, 0.224, 0.225] + img_scale = 1.0 / 255.0 + normalize_op = NormalizeImage( + scale=img_scale, mean=img_mean, std=img_std) + img = normalize_op(img) + tensor_op = ToTensor() + img = tensor_op(img) + return img + + +def postprocess(batch_outputs, topk=5, multilabel=False): + batch_results = [] + for probs in batch_outputs: + results = [] + if multilabel: + index = np.where(probs >= 0.5)[0].astype('int32') + else: + index = probs.argsort(axis=0)[-topk:][::-1].astype("int32") + clas_id_list = [] + score_list = [] + for i in index: + clas_id_list.append(i.item()) + score_list.append(probs[i].item()) + batch_results.append({"clas_ids": clas_id_list, "scores": score_list}) + return batch_results + + +class ResizeImage(object): + def __init__(self, resize_short=None): + self.resize_short = resize_short + + def __call__(self, img): + img_h, img_w = img.shape[:2] + percent = float(self.resize_short) / min(img_w, img_h) + w = int(round(img_w * percent)) + h = int(round(img_h * percent)) + return cv2.resize(img, (w, h)) + + +class CropImage(object): + def __init__(self, size): + if type(size) is int: + self.size = (size, size) + else: + self.size = size + + def __call__(self, img): + w, h = self.size + img_h, img_w = img.shape[:2] + w_start = (img_w - w) // 2 + h_start = (img_h - h) // 2 + + w_end = w_start + w + h_end = h_start + h + return img[h_start:h_end, w_start:w_end, :] + + +class NormalizeImage(object): + def __init__(self, scale=None, mean=None, std=None): + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + return (img.astype('float32') * self.scale - self.mean) / self.std + + +class ToTensor(object): + def __init__(self): + pass + + def __call__(self, img): + img = 
img.transpose((2, 0, 1)) + return img
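# NOTE: usage sketch for the preprocess/postprocess helpers added above. The
# SimpleNamespace fields match the attribute names preprocess() reads
# (resize_short, resize, normalize); the image path and the random stand-in for
# model output are purely illustrative, and 224/256 follow the crop/resize
# defaults used elsewhere in this example.
import cv2
import numpy as np
from types import SimpleNamespace
from utils import preprocess, postprocess

args = SimpleNamespace(resize_short=256, resize=224, normalize=True)
img = cv2.imread('demo.jpg')      # hypothetical input image
img = img[:, :, ::-1]             # BGR -> RGB, as the readers in this example do
batch = np.expand_dims(preprocess(img, args), axis=0).astype(np.float32)
# feed `batch` to a predictor; a random vector stands in here for class scores
probs = np.random.rand(1, 1000).astype(np.float32)
print(postprocess(probs, topk=5))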