# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import numpy as np
import argparse
import paddle
from tqdm import tqdm
from ppdet.core.workspace import load_config, merge_config
from ppdet.core.workspace import create
from ppdet.metrics import COCOMetric, VOCMetric, KeyPointTopDownCOCOEval
from paddleslim.common import load_config as load_slim_config
from paddleslim.auto_compression import AutoCompression

from post_process import PicoDetPostProcess


def argsparser():
    """Build the command-line parser for this compression script.

    Options:
        --config_path (required): compression strategy config file.
        --save_dir: directory the compressed model is written to
            (default ``output``).
        --devices: device name later handed to ``paddle.set_device``
            (default ``gpu``).

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    option_table = (
        ('--config_path',
         dict(
             type=str,
             default=None,
             help="path of compression strategy config.",
             required=True)),
        ('--save_dir',
         dict(
             type=str,
             default='output',
             help="directory to save compressed model.")),
        ('--devices',
         dict(
             type=str,
             default='gpu',
             help="which device used to compress.")),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def reader_wrapper(reader, input_list):
    """Wrap ``reader`` so each sample is reduced to a model feed dict.

    Args:
        reader: iterable of dict-like samples.
        input_list: either a list of sample keys to keep unchanged, or a
            dict mapping sample key -> feed name to rename to. Any other
            type yields an empty dict for every sample.

    Returns:
        A zero-argument function that, when called, returns a generator
        yielding one feed dict per sample of ``reader``.
    """

    def gen():
        for sample in reader:
            if isinstance(input_list, list):
                feed = {key: sample[key] for key in input_list}
            elif isinstance(input_list, dict):
                feed = {
                    feed_name: sample[src_key]
                    for src_key, feed_name in input_list.items()
                }
            else:
                feed = {}
            yield feed

    return gen


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
    """Run the evaluation loop and return the metric's first bbox result.

    Passed to ``AutoCompression`` as ``eval_callback`` (see ``main``). It
    reads three module-level globals that ``main`` assigns:
    ``global_config`` (must contain ``'metric'`` and
    ``'include_post_process'``), ``val_loader`` and ``num_classes``.

    Args:
        exe: executor used to run ``compiled_test_program``.
        compiled_test_program: the program being evaluated.
        test_feed_names: feed variable names; only batch entries whose key
            is in this collection are fed to the program.
        test_fetch_list: fetch targets passed to ``exe.run``.

    Returns:
        ``metric.get_results()['bbox'][0]``. This is presumably the COCO bbox
        AP for the ``COCOMetric`` that ``main`` installs (TODO confirm).
    """
    metric = global_config['metric']
    bar_format = 'Evaluation stage, Run batch:|{bar}| {n_fmt}/{total_fmt}'
    with tqdm(total=len(val_loader), bar_format=bar_format,
              ncols=80) as progress:
        for batch in val_loader:
            # Numpy copy of the whole batch; the metric consumes it below.
            batch_np = {name: np.array(value) for name, value in batch.items()}
            n_images = batch_np['image'].shape[0]
            feed = {
                name: np.array(value)
                for name, value in batch.items() if name in test_feed_names
            }

            outputs = exe.run(compiled_test_program,
                              feed=feed,
                              fetch_list=test_fetch_list,
                              return_numpy=False)

            if global_config['include_post_process']:
                # Post-processing is part of the model: classify each fetched
                # tensor by rank. Rank > 1 goes to 'bbox', anything else to
                # 'bbox_num'.
                res = {}
                for out in outputs:
                    arr = np.array(out)
                    res['bbox' if arr.ndim > 1 else 'bbox_num'] = arr
            else:
                # Raw head outputs. The first four fetched tensors are
                # treated as class-score maps and the rest as box-regression
                # maps. Box maps have last dim 32, presumably 4 sides x 8
                # distribution bins; confirm against the exported head.
                score_maps, box_maps = [], []
                for idx, out in enumerate(outputs):
                    bucket, last_dim = ((score_maps, num_classes)
                                        if idx < 4 else (box_maps, 32))
                    bucket.append(
                        np.array(out).reshape(n_images, -1, last_dim))
                # shape[2:] of the image batch; presumably (H, W) for NCHW.
                decoder = PicoDetPostProcess(
                    batch_np['image'].shape[2:],
                    batch_np['im_shape'],
                    batch_np['scale_factor'],
                    score_threshold=0.01,
                    nms_threshold=0.6)
                res = decoder(score_maps, box_maps)

            metric.update(batch_np, res)
            progress.update()

    metric.accumulate()
    metric.log()
    results = metric.get_results()
    metric.reset()
    return results['bbox'][0]


def main():
    """Build the data pipeline and run PaddleSlim auto-compression.

    Steps:
      1. Load the slim config from ``--config_path``. It must contain a
         ``Global`` section, otherwise an AssertionError is raised.
      2. Load the ppdet reader config named by ``Global.reader_config``.
      3. Build the TrainReader and adapt it with ``reader_wrapper`` using
         ``Global.input_list``.
      4. If ``Global.Evaluation`` is present and truthy, and this process is
         distributed rank 0, set the module globals ``val_loader`` and
         ``num_classes``. Also store a ``COCOMetric`` in
         ``global_config['metric']`` and use ``eval_function`` as the
         evaluation callback. Otherwise no callback is passed.
      5. Run ``AutoCompression(...).compress()``, saving to ``--save_dir``.
    """
    # eval_function reads all three of these as module-level globals.
    global global_config, val_loader, num_classes

    all_config = load_slim_config(FLAGS.config_path)
    assert "Global" in all_config, f"Key 'Global' not found in config file. \n{all_config}"
    global_config = all_config["Global"]
    reader_cfg = load_config(global_config['reader_config'])

    train_reader = create('TrainReader')(reader_cfg['TrainDataset'],
                                         reader_cfg['worker_num'],
                                         return_list=True)
    train_loader = reader_wrapper(train_reader, global_config['input_list'])

    eval_func = None
    wants_eval = ('Evaluation' in global_config and
                  global_config['Evaluation'])
    if wants_eval and paddle.distributed.get_rank() == 0:
        eval_func = eval_function
        eval_dataset = reader_cfg['EvalDataset']
        eval_sampler = paddle.io.BatchSampler(
            eval_dataset, batch_size=reader_cfg['EvalReader']['batch_size'])
        val_loader = create('EvalReader')(eval_dataset,
                                          reader_cfg['worker_num'],
                                          batch_sampler=eval_sampler,
                                          return_list=True)
        num_classes = reader_cfg['num_classes']
        # Invert the dataset's category-id -> class-id mapping.
        clsid2catid = {
            clsid: catid
            for catid, clsid in eval_dataset.catid2clsid.items()
        }
        global_config['metric'] = COCOMetric(
            anno_file=eval_dataset.get_anno(),
            clsid2catid=clsid2catid,
            IouType='bbox')

    compressor = AutoCompression(
        model_dir=global_config["model_dir"],
        model_filename=global_config["model_filename"],
        params_filename=global_config["params_filename"],
        save_dir=FLAGS.save_dir,
        config=all_config,
        train_dataloader=train_loader,
        eval_callback=eval_func)
    compressor.compress()


if __name__ == '__main__':
    # The evaluation callback runs a static program through an executor,
    # so static-graph mode is enabled before anything else.
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    # Validate explicitly instead of with `assert`. An assert is stripped
    # under `python -O` and otherwise fails with a bare, message-less
    # AssertionError. parser.error() prints usage and exits with status 2.
    supported_devices = ['cpu', 'gpu', 'xpu', 'npu']
    if FLAGS.devices not in supported_devices:
        parser.error(f"unsupported --devices '{FLAGS.devices}'; "
                     f"expected one of {', '.join(supported_devices)}")
    paddle.set_device(FLAGS.devices)

    main()
