run.py 6.1 KB
Newer Older
C
ceci3 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
ceci3 已提交
15 16 17 18 19
import os
import sys
import argparse
import functools
from functools import partial
C
Chang Xu 已提交
20
import math
C
ceci3 已提交
21 22 23 24

import numpy as np
import paddle
import paddle.nn as nn
C
Chang Xu 已提交
25 26
from paddle.io import DataLoader
from imagenet_reader import ImageNetDataset
27
from paddleslim.common import load_config as load_slim_config
C
ceci3 已提交
28 29
from paddleslim.auto_compression import AutoCompression

30

31 32 33 34 35 36 37 38 39 40 41 42 43
def argsparser():
    """Build the command-line parser for this script.

    Options:
        --config_path: required path of the compression strategy config.
        --save_dir: directory to save the compressed model
            (default ``'output'``).
        --total_images: number of total training images
            (default ``1281167``).

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # A required option never falls back to a default, so none is given.
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help="path of compression strategy config.")
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    parser.add_argument(
        '--total_images',
        type=int,
        default=1281167,
        help="the number of total training images.")
    return parser


C
Chang Xu 已提交
52
# yapf: enable
53
def reader_wrapper(reader, input_name):
    """Wrap a batch reader so that it yields feed dictionaries.

    Args:
        reader: Callable that, when called with no arguments, returns an
            iterable of two-element ``(images, label)`` batches.
        input_name: Key under which the images are placed in each yielded
            dictionary.

    Returns:
        A zero-argument generator function. Each call iterates ``reader()``
        afresh and yields ``{input_name: images}`` per batch; the label is
        discarded.
    """

    def gen():
        for imgs, _ in reader():
            yield {input_name: imgs}

    return gen

C
ceci3 已提交
60

W
whs 已提交
61
def eval_reader(data_dir, batch_size, crop_size, resize_size, place=None):
    """Create the validation data loader.

    Args:
        data_dir: Dataset directory forwarded to ``ImageNetDataset``.
        batch_size: Number of samples per batch.
        crop_size: Crop size forwarded to ``ImageNetDataset``.
        resize_size: Resize size forwarded to ``ImageNetDataset``.
        place: Optional device. When given it is passed to the loader as
            ``places=[place]``; otherwise ``places`` is ``None``.

    Returns:
        A ``DataLoader`` over the ``'val'`` split, unshuffled, keeping the
        last partial batch, with ``num_workers=0``.
    """
    val_reader = ImageNetDataset(
        mode='val',
        data_dir=data_dir,
        crop_size=crop_size,
        resize_size=resize_size)
    val_loader = DataLoader(
        val_reader,
        places=[place] if place is not None else None,
        # Honour the caller's batch_size instead of reading the module-level
        # global_config, which made this parameter dead.
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    return val_loader
C
ceci3 已提交
75

C
ceci3 已提交
76

C
ceci3 已提交
77
def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    """Run the program over the validation set and return the first metric.

    Reads the module-level ``data_dir``, ``global_config``, ``img_size`` and
    ``resize_size`` set by ``main`` to build the validation loader.

    Args:
        exe: Executor; its ``run`` method is called once per batch.
        compiled_test_program: Program passed to ``exe.run``.
        test_feed_names: Feed names. With exactly one name only the image is
            fed and top-1/top-5 accuracy are computed here from the first
            fetched output. Otherwise image and label are fed and the
            fetched outputs are taken as the metrics themselves.
        test_fetch_list: Fetch targets passed to ``exe.run``.

    Returns:
        The first metric (top-1 accuracy), averaged over validation samples.

    Raises:
        ValueError: If the validation loader yields no batches.
    """
    val_loader = eval_reader(
        data_dir,
        batch_size=global_config['batch_size'],
        crop_size=img_size,
        resize_size=resize_size)

    results = []
    batch_sizes = []
    print('Evaluating...')
    for batch_id, (image, label) in enumerate(val_loader):
        image = np.array(image)
        label = np.array(label).astype('int64')
        if len(test_feed_names) == 1:
            # top1_acc, top5_acc computed from the raw predictions
            pred = exe.run(compiled_test_program,
                           feed={test_feed_names[0]: image},
                           fetch_list=test_fetch_list)
            pred = np.array(pred[0])
            # Column vector so the comparison broadcasts row-wise against
            # the per-sample candidate classes.
            target = label.reshape(-1, 1)
            sort_array = pred.argsort(axis=1)
            top_1 = np.mean(sort_array[:, -1:] == target)
            top_5 = np.mean(np.any(sort_array[:, -5:] == target, axis=1))
            results.append([top_1, top_5])
        else:
            # eval "eval model", which inputs are image and label, output is top1 and top5 accuracy
            result = exe.run(
                compiled_test_program,
                feed={test_feed_names[0]: image,
                      test_feed_names[1]: label},
                fetch_list=test_fetch_list)
            results.append([np.mean(r) for r in result])
        batch_sizes.append(len(label))
        if batch_id % 100 == 0:
            print('Eval iter: ', batch_id)

    if not results:
        raise ValueError('Validation loader produced no batches.')
    # The loader keeps the last partial batch, so weight each batch's
    # metrics by its sample count to get a per-sample average.
    result = np.average(np.array(results), axis=0, weights=batch_sizes)
    return result[0]


123
def main():
    """Load the config, build the data loaders and run auto compression.

    Reads the module-level ``args`` (parsed command-line options) and sets
    the module-level ``global_config``, ``data_dir``, ``img_size`` and
    ``resize_size`` that ``eval_function`` relies on.

    Raises:
        ValueError: If the loaded config has no ``'Global'`` section.
    """
    global global_config, data_dir, img_size, resize_size

    rank_id = paddle.distributed.get_rank()
    place = paddle.CUDAPlace(rank_id)
    all_config = load_slim_config(args.config_path)

    # Explicit raise rather than assert so the check survives ``python -O``.
    if "Global" not in all_config:
        raise ValueError(
            f"Key 'Global' not found in config file. \n{all_config}")
    global_config = all_config["Global"]

    gpu_num = paddle.distributed.get_world_size()
    lr_config = all_config['TrainConfig']['learning_rate']
    if isinstance(lr_config,
                  dict) and lr_config['type'] == 'CosineAnnealingDecay':
        # T_max is the number of iterations needed to consume total_images
        # once with the global batch (per-card batch size * number of cards).
        step = int(
            math.ceil(
                float(args.total_images) / (global_config['batch_size'] *
                                            gpu_num)))
        lr_config['T_max'] = step
        print('total training steps:', step)

    data_dir = global_config['data_dir']
    img_size = global_config.get('img_size', 224)
    resize_size = global_config.get('resize_size', 256)

    train_dataset = ImageNetDataset(
        mode='train',
        data_dir=data_dir,
        crop_size=img_size,
        resize_size=resize_size)

    train_loader = DataLoader(
        train_dataset,
        places=[place],
        batch_size=global_config['batch_size'],
        shuffle=True,
        drop_last=True,
        num_workers=0)
    train_dataloader = reader_wrapper(train_loader, global_config['input_name'])

    eval_dataloader = reader_wrapper(
        eval_reader(
            data_dir,
            global_config['batch_size'],
            crop_size=img_size,
            resize_size=resize_size,
            place=place),
        global_config['input_name'])

    ac = AutoCompression(
        model_dir=global_config['model_dir'],
        model_filename=global_config['model_filename'],
        params_filename=global_config['params_filename'],
        save_dir=args.save_dir,
        config=all_config,
        train_dataloader=train_dataloader,
        # Only rank 0 gets the evaluation callback.
        eval_callback=eval_function if rank_id == 0 else None,
        eval_dataloader=eval_dataloader)

    ac.compress()
184 185


186 187 188 189 190
# Script entry point: runs only when the file is executed directly.
if __name__ == '__main__':
    # Enable Paddle's static mode before main() runs.
    paddle.enable_static()
    parser = argsparser()
    # `args` is deliberately module-level: main() reads it as a global.
    args = parser.parse_args()
    main()