import argparse

import numpy as np
import paddle
import paddleseg.transforms as T
from paddleseg.core.infer import reverse_transform
from paddleseg.datasets import Dataset
from paddleseg.utils import metrics, worker_init_fn
from paddleslim.auto_compression import AutoCompression
from paddleslim.auto_compression.config_helpers import load_config


def parse_args():
    parser = argparse.ArgumentParser(description='Auto compression')
    parser.add_argument(
        '--model_dir',
        type=str,
        default=None,
        help="inference model directory.")
    parser.add_argument(
        '--model_filename',
        type=str,
        default=None,
        help="inference model filename.")
    parser.add_argument(
        '--params_filename',
        type=str,
        default=None,
        help="inference params filename.")
    parser.add_argument(
        '--save_dir',
        type=str,
        default=None,
        help="directory to save compressed model.")
    parser.add_argument(
        '--config_path',
        type=str,
        default=None,
        help="path of compression strategy config.")
    return parser.parse_args()


def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
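    """Evaluation callback passed to AutoCompression.

    Runs the compiled static inference program over the validation set and
    returns mIoU as the metric reported back to the compression pass.
    """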

    nranks = paddle.distributed.ParallelEnv().nranks

    batch_sampler = paddle.io.DistributedBatchSampler(
        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_sampler=batch_sampler,
        num_workers=1,
        return_list=True, )

    total_iters = len(loader)
    intersect_area_all = 0
    pred_area_all = 0
    label_area_all = 0

    print("Start evaluating (total_samples: {}, total_iters: {})...".format(
        len(eval_dataset), total_iters))
    print("nranks:", nranks)

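    # The exported model runs as a static-graph program, while the metric
    # helpers expect dygraph tensors, so graph mode is toggled around exe.run().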
    for iter, (image, label) in enumerate(loader):
        paddle.enable_static()

        label = np.array(label).astype('int64')
        ori_shape = np.array(label).shape[-2:]

        image = np.array(image)
        logits = exe.run(compiled_test_program,
                         feed={test_feed_names[0]: image},
                         fetch_list=test_fetch_list,
                         return_numpy=True)

        paddle.disable_static()
        logit = logits[0]

        logit = reverse_transform(
            paddle.to_tensor(logit),
            ori_shape,
            eval_dataset.transforms.transforms,
            mode='bilinear')

        pred = paddle.argmax(logit, axis=1, keepdim=True, dtype='int32')

        intersect_area, pred_area, label_area = metrics.calculate_area(
            pred,
            paddle.to_tensor(label),
            eval_dataset.num_classes,
            ignore_index=eval_dataset.ignore_index)

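        # Accumulate the per-batch areas; with multiple ranks, gather the
        # partial results from every worker first.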
        if nranks > 1:
            intersect_area_list = []
            pred_area_list = []
            label_area_list = []
            paddle.distributed.all_gather(intersect_area_list, intersect_area)
            paddle.distributed.all_gather(pred_area_list, pred_area)
            paddle.distributed.all_gather(label_area_list, label_area)

            # In the last iteration some samples may be evaluated twice across
            # ranks; drop the duplicated results.
            if (iter + 1) * nranks > len(eval_dataset):
                valid = len(eval_dataset) - iter * nranks
                intersect_area_list = intersect_area_list[:valid]
                pred_area_list = pred_area_list[:valid]
                label_area_list = label_area_list[:valid]

            for i in range(len(intersect_area_list)):
                intersect_area_all = intersect_area_all + intersect_area_list[i]
                pred_area_all = pred_area_all + pred_area_list[i]
                label_area_all = label_area_all + label_area_list[i]
        else:
            intersect_area_all = intersect_area_all + intersect_area
            pred_area_all = pred_area_all + pred_area
            label_area_all = label_area_all + label_area

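    # Reduce the accumulated areas into the final segmentation metrics.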
    class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
                                       label_area_all)
    class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
    kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)
    class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all,
                                     label_area_all)

    info = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
        len(eval_dataset), miou, acc, kappa, mdice)
    print(info)

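    # Switch back to static-graph mode before returning control to AutoCompression.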
    paddle.enable_static()
    return miou


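# AutoCompression consumes a generator of feed dicts. "x" is assumed to be the
# input name of the exported inference model; adjust it if the model uses a
# different input name.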
def reader_wrapper(reader):
    def gen():
        for i, data in enumerate(reader()):
            imgs = np.array(data[0])
            yield {"x": imgs}

    return gen


if __name__ == '__main__':

    args = parse_args()

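    # Build PaddleSeg datasets. dataset_root / train_path / val_path are
    # placeholders and must point at a dataset in PaddleSeg's file-list format.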
    transforms = [T.RandomPaddingCrop(crop_size=(512, 512)), T.Normalize()]
    train_dataset = Dataset(
        transforms=transforms,
        dataset_root='dataset_root',  # Need to fill in
        num_classes=2,
        train_path='train_path',  # Need to fill in
        mode='train')
    eval_dataset = Dataset(
        transforms=transforms,
        dataset_root='dataset_root',  # Need to fill in
        num_classes=2,
        train_path='val_path',  # Need to fill in
        mode='val')

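    # Training dataloader that feeds the compression strategies (e.g.
    # distillation or quant-aware training) defined in the YAML config.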
    batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=128, shuffle=True, drop_last=True)

    train_loader = paddle.io.DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        num_workers=2,
        return_list=True,
        worker_init_fn=worker_init_fn, )
    train_dataloader = reader_wrapper(train_loader)

    # Load the compression strategy and training hyper-parameters from the config file.
    compress_config, train_config = load_config(args.config_path)

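    # Assemble the auto-compression runner from the inference model, the
    # strategy/training configs, the training dataloader, and the eval callback.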
    ac = AutoCompression(
        model_dir=args.model_dir,
        model_filename=args.model_filename,
        params_filename=args.params_filename,
        save_dir=args.save_dir,
        strategy_config=compress_config,
        train_config=train_config,
        train_dataloader=train_dataloader,
        eval_callback=eval_function)

    ac.compress()