run.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

C
ceci3 已提交
15 16
import os
import sys
17
import numpy as np
18
import argparse
19
import paddle
C
ceci3 已提交
20 21
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
G
Guanghua Yu 已提交
22
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
23 24
from paddleslim.auto_compression.config_helpers import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression
G
Guanghua Yu 已提交
25
from keypoint_utils import keypoint_post_process
G
Guanghua Yu 已提交
26
from post_process import PPYOLOEPostProcess
27 28


29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
def argsparser():
    """Build the command-line parser for this compression script.

    Returns:
        argparse.ArgumentParser: parser accepting ``--config_path`` (required),
        ``--save_dir`` (default ``'output'``) and ``--devices`` (default
        ``'gpu'``, restricted to cpu/gpu/xpu/npu).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help="path of compression strategy config.")
    parser.add_argument(
        '--save_dir',
        type=str,
        default='output',
        help="directory to save compressed model.")
    # ``choices`` lets argparse reject unsupported devices with a usage
    # message instead of relying on a later ``assert`` (stripped under -O).
    parser.add_argument(
        '--devices',
        type=str,
        default='gpu',
        choices=['cpu', 'gpu', 'xpu', 'npu'],
        help="which device used to compress.")

    return parser


51
def reader_wrapper(reader, input_list):
    """Adapt a data reader into a factory of feed-dict generators.

    Args:
        reader: Iterable of per-batch mappings keyed by field name.
        input_list: Either a list of field names to copy through under the
            same name, or a dict mapping a reader field name to the name it
            should be emitted under. Any other type yields empty dicts.

    Returns:
        A zero-argument callable; each call iterates ``reader`` again and
        yields one feed dict per batch.
    """

    def gen():
        for batch in reader:
            feed = {}
            if isinstance(input_list, list):
                feed = {name: batch[name] for name in input_list}
            elif isinstance(input_list, dict):
                feed = {dst: batch[src] for src, dst in input_list.items()}
            yield feed

    return gen


G
Guanghua Yu 已提交
66 67 68 69 70 71 72 73 74 75 76
def convert_numpy_data(data, metric):
    """Convert every field of a batch to a NumPy array for metric updates.

    Args:
        data: Mapping of field name to batch values from the eval loader.
        metric: The metric the converted batch will be fed to. For a
            ``VOCMetric``, any field whose first element is not already an
            ``np.ndarray`` is rebuilt as an array of per-element
            ``np.array`` conversions.

    Returns:
        dict: Same keys as ``data``, each value wrapped in ``np.array``.
    """
    data_all = {k: np.array(v) for k, v in data.items()}
    if isinstance(metric, VOCMetric):
        # Reassigning existing keys does not resize the dict, so mutating
        # during iteration is safe here.
        for k, v in data_all.items():
            # 0-d or empty arrays have no first element to inspect; leave
            # them unchanged rather than raising IndexError on ``v[0]``.
            if v.ndim == 0 or len(v) == 0:
                continue
            if not isinstance(v[0], np.ndarray):
                data_all[k] = np.array([np.array(t) for t in v])
    return data_all
79 80 81


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    """Run the program over ``val_loader`` and return the primary metric value.

    Relies on the module globals ``global_config`` (must contain ``'metric'``
    and ``'input_list'``, may contain ``'arch'``) and ``val_loader``, both of
    which are set up by ``main``.

    Args:
        exe: Executor used to run ``compiled_test_program``.
        compiled_test_program: Program being evaluated.
        test_feed_names: Feed variable names; used to filter batch fields
            when ``input_list`` is a list.
        test_fetch_list: Fetch targets passed to ``exe.run``.

    Returns:
        ``metric.get_results()['keypoint'][0]`` for keypoint models, otherwise
        ``metric.get_results()['bbox'][0]``.
    """
    metric = global_config['metric']
    input_list = global_config['input_list']
    arch = global_config.get('arch')

    # Thresholds are constant across batches, so build the post-processor
    # once. NOTE(review): assumes PPYOLOEPostProcess keeps no per-call state.
    postprocess = None
    if arch == 'PPYOLOE':
        postprocess = PPYOLOEPostProcess(
            score_threshold=0.01, nms_threshold=0.6)

    for batch_id, data in enumerate(val_loader):
        data_all = convert_numpy_data(data, metric)
        data_input = {}
        for k, v in data.items():
            if isinstance(input_list, list):
                if k in test_feed_names:
                    data_input[k] = np.array(v)
            elif isinstance(input_list, dict):
                if k in input_list:
                    data_input[input_list[k]] = np.array(v)
        outs = exe.run(compiled_test_program,
                       feed=data_input,
                       fetch_list=test_fetch_list,
                       return_numpy=False)
        # Exactly one post-processing path per architecture; previously the
        # generic branch also ran for keypoint models and polluted ``res``.
        if arch == 'keypoint':
            res = keypoint_post_process(data, data_input, exe,
                                        compiled_test_program, test_fetch_list,
                                        outs)
        elif arch == 'PPYOLOE':
            res = postprocess(np.array(outs[0]), data_all['scale_factor'])
        else:
            res = {}
            for out in outs:
                v = np.array(out)
                # Outputs with more than one dimension are stored as 'bbox',
                # 1-D outputs as 'bbox_num'.
                if len(v.shape) > 1:
                    res['bbox'] = v
                else:
                    res['bbox_num'] = v

        metric.update(data_all, res)
        if batch_id % 100 == 0:
            print('Eval iter:', batch_id)
    metric.accumulate()
    metric.log()
    map_res = metric.get_results()
    metric.reset()
    map_key = 'keypoint' if arch == 'keypoint' else 'bbox'
    return map_res[map_key][0]
124 125


126
def _create_metric(reader_cfg, dataset):
    """Instantiate the evaluation metric named by ``reader_cfg['metric']``.

    Args:
        reader_cfg: Reader config; must contain ``'metric'``, plus
            ``'num_classes'`` and ``'map_type'`` when the metric is ``'VOC'``.
        dataset: The evaluation dataset object from ``reader_cfg``.

    Returns:
        A COCOMetric, VOCMetric or KeyPointTopDownCOCOEval instance.

    Raises:
        ValueError: if the configured metric name is not supported.
    """
    metric_name = reader_cfg['metric']
    if metric_name == 'COCO':
        clsid2catid = {v: k for k, v in dataset.catid2clsid.items()}
        return COCOMetric(
            anno_file=dataset.get_anno(),
            clsid2catid=clsid2catid,
            IouType='bbox')
    if metric_name == 'VOC':
        return VOCMetric(
            label_list=dataset.get_label_list(),
            class_num=reader_cfg['num_classes'],
            map_type=reader_cfg['map_type'])
    if metric_name == 'KeyPointTopDownCOCOEval':
        # NOTE(review): 17 is presumably the COCO keypoint (joint) count and
        # 'output_eval' an output directory - confirm against ppdet.
        return KeyPointTopDownCOCOEval(dataset.get_anno(),
                                       len(dataset), 17, 'output_eval')
    raise ValueError(
        "metric currently only supports COCO, VOC and KeyPointTopDownCOCOEval.")


def main():
    """Load configs, build loaders and metric, then run auto-compression.

    Reads the module-level ``FLAGS`` (parsed under ``__main__``) and sets the
    module globals ``global_config`` and, when evaluation is enabled,
    ``val_loader``; ``eval_function`` reads both.

    Raises:
        ValueError: if the compression config has no ``'Global'`` section or
            the configured metric is unsupported.
    """
    global global_config, val_loader
    all_config = load_slim_config(FLAGS.config_path)
    if "Global" not in all_config:
        raise ValueError(
            f"Key 'Global' not found in config file. \n{all_config}")
    global_config = all_config["Global"]
    reader_cfg = load_config(global_config['reader_config'])

    # The training/calibration data is read through the 'EvalReader' pipeline.
    train_loader = create('EvalReader')(reader_cfg['TrainDataset'],
                                        reader_cfg['worker_num'],
                                        return_list=True)
    train_loader = reader_wrapper(train_loader, global_config['input_list'])

    # Only rank 0 evaluates in distributed runs.
    if global_config.get('Evaluation') and paddle.distributed.get_rank() == 0:
        eval_func = eval_function
        dataset = reader_cfg['EvalDataset']
        _eval_batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
        val_loader = create('EvalReader')(dataset,
                                          reader_cfg['worker_num'],
                                          batch_sampler=_eval_batch_sampler,
                                          return_list=True)
        global_config['metric'] = _create_metric(reader_cfg, dataset)
    else:
        eval_func = None

    ac = AutoCompression(
        model_dir=global_config["model_dir"],
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"],
        save_dir=FLAGS.save_dir,
        config=all_config,
        train_dataloader=train_loader,
        eval_callback=eval_func)
    ac.compress()
179 180 181 182


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if FLAGS.devices not in ['cpu', 'gpu', 'xpu', 'npu']:
        parser.error(f"unsupported --devices {FLAGS.devices!r}; "
                     "choose from cpu, gpu, xpu, npu")
    paddle.set_device(FLAGS.devices)

    main()