from __future__ import print_function
import sys
import os
import math
import random
import functools
import io
import time
import codecs
import numpy as np
import paddle
import paddle.fluid as fluid
import cv2
import copy
from PIL import Image, ImageOps, ImageFilter, ImageEnhance

from src.models.model_builder import ModelPhase
from src.utils.config import cfg
from .data_utils import GeneratorEnqueuer


class BaseSeg(object):
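    """Base reader for semantic segmentation datasets.

    Loads an image/label file list and provides single-process and
    multi-process sample generators. In TRAIN mode it applies random scaling,
    horizontal flipping, padding and cropping; subclasses implement
    load_image() to decode one sample from a file-list line.
    """
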
    def __init__(self, file_list, data_dir, shuffle=False, mode=ModelPhase.TRAIN,
                 base_size=1024, crop_size=769, rand_scale=True):
        self.mode = mode
        self.shuffle = shuffle
        self.data_dir = data_dir
        self.shuffle_seed = 0

        self.crop_size = crop_size  # size of the square training crops
        self.base_size = base_size  # length of the shorter edge during training
        self.rand_scale = rand_scale

        # NOTE: please make sure the file list is saved in UTF-8 encoding
        with codecs.open(file_list, 'r', 'utf-8') as flist:
            self.lines = [line.strip() for line in flist]
            self.all_lines = copy.deepcopy(self.lines)
            if shuffle and cfg.NUM_TRAINERS > 1:
                np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            elif shuffle:
                np.random.shuffle(self.lines)
        self.num_trainers = cfg.NUM_TRAINERS
        self.trainer_id = cfg.TRAINER_ID

    def generator(self):
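        """Yield processed samples one by one.

        With distributed training (cfg.NUM_TRAINERS > 1) the file list is
        re-shuffled with a shared seed and sharded per trainer before iteration.
        """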
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
            self.lines = self.all_lines[num_lines * cfg.TRAINER_ID: num_lines * (cfg.TRAINER_ID + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)

        for line in self.lines:
            yield self.process_image(line, self.data_dir, self.mode)

    def sharding_generator(self, pid=0, num_processes=1):
        """
        Use line id as shard key for multiprocess io
        It's a normal generator if pid=0, num_processes=1
        """
        for index, line in enumerate(self.lines):
            # Use index and pid to shard file list
                if index % num_processes == pid:
                    yield self.process_image(line, self.data_dir, self.mode)

    def batch_reader(self, batch_size):
        br = self.batch(self.generator, batch_size)
        for batch in br:
            yield batch[0], batch[1], batch[2]

    def multiprocess_generator(self, max_queue_size=32, num_processes=8):
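        """Read samples in parallel.

        The file list is split across num_processes sharding generators whose
        outputs are merged through a GeneratorEnqueuer queue.
        """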
        # Re-shuffle file list
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // self.num_trainers
            self.lines = self.all_lines[num_lines * self.trainer_id: num_lines * (self.trainer_id + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)

        # Create multiple sharding generators according to num_processes for multiple processes
        generators = []
        for pid in range(num_processes):
            generators.append(self.sharding_generator(pid, num_processes))

        enqueuer = None
        try:
            enqueuer = GeneratorEnqueuer(generators)
            enqueuer.start(max_queue_size=max_queue_size, workers=num_processes)
            while True:
                generator_out = None
                while enqueuer.is_running():
                    if not enqueuer.queue.empty():
                        generator_out = enqueuer.queue.get(timeout=5)
                        break
                    else:
                        time.sleep(0.01)
                if generator_out is None:
                    break
                yield generator_out
        finally:
            if enqueuer is not None:
                enqueuer.stop()

    def batch(self, reader, batch_size, is_test=False, drop_last=False):
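        """Group samples from reader into batches of batch_size.

        Test batches additionally carry image names, valid shapes and original
        shapes; the last incomplete batch is kept unless drop_last is True.
        """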
        def batch_reader(is_test=False, drop_last=drop_last):
            if is_test:
                imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
                for img, grt, img_name, valid_shape, org_shape in reader():
                    imgs.append(img)
                    grts.append(grt)
                    img_names.append(img_name)
                    valid_shapes.append(valid_shape)
                    org_shapes.append(org_shape)
                    if len(imgs) == batch_size:
                        yield (np.array(imgs), np.array(grts), img_names,
                               np.array(valid_shapes), np.array(org_shapes))
                        imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []

                if not drop_last and len(imgs) > 0:
                    yield (np.array(imgs), np.array(grts), img_names,
                           np.array(valid_shapes), np.array(org_shapes))
            else:
                imgs, labs, ignore = [], [], []
                bs = 0
                for img, lab, ig in reader():
                    imgs.append(img)
                    labs.append(lab)
                    ignore.append(ig)
                    bs += 1
                    if bs == batch_size:
                        yield np.array(imgs), np.array(labs), np.array(ignore)
                        bs = 0
                        imgs, labs, ignore = [], [], []

                if not drop_last and bs > 0:
                    yield np.array(imgs), np.array(labs), np.array(ignore)

        return batch_reader(is_test, drop_last)

    def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
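        """Decode one file-list line into (img, grt, img_name, grt_name); implemented by subclasses."""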
        raise NotImplementedError

    def pil_imread(self, file_path):
        """read pseudo-color label"""
        im = Image.open(file_path)
        return np.asarray(im)

    def cv2_imread(self, file_path, flag=cv2.IMREAD_COLOR):
        # work around cv2.imread failing on non-ASCII (e.g. Chinese) file paths on Windows
        return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)

    def normalize_image(self, img):
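        """Convert an HWC uint8 image to CHW float32 in [0, 1], then normalize with cfg.MEAN / cfg.STD."""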
        img = img.transpose((2, 0, 1)).astype('float32') / 255.0
        img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1))
        img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1))
        img -= img_mean
        img /= img_std

        return img

    def process_image(self, line, data_dir, mode):
        """ process_image """
        img, grt, img_name, grt_name = self.load_image( line, data_dir, mode=mode)  # img.type: numpy.array, grt.type: numpy.array
        if mode == ModelPhase.TRAIN:
            # convert numpy.ndarray (BGR) to PIL.Image (RGB)
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            grt = Image.fromarray(grt.astype('uint8')).convert('L')

            crop_size = self.crop_size
            # random scale: pick the target shorter edge in [RAND_SCALE_MIN, RAND_SCALE_MAX] * base_size
            if self.rand_scale:
                short_size = random.randint(int(self.base_size * cfg.DATAAUG.RAND_SCALE_MIN), int(self.base_size * cfg.DATAAUG.RAND_SCALE_MAX))
            else:
                short_size = self.base_size
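
            # resize so the shorter edge equals short_size, keeping the aspect ratio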
            w, h = img.size
            if h > w:
                out_w = short_size
                out_h = int(1.0 * h / w * out_w)
            else:
                out_h = short_size
                out_w = int(1.0 * w / h * out_h)
            img = img.resize((out_w, out_h), Image.BILINEAR)
            grt = grt.resize((out_w, out_h), Image.NEAREST)

            # random horizontal flip
            if random.random() > 0.5:
                img = img.transpose(Image.FLIP_LEFT_RIGHT)
                grt = grt.transpose(Image.FLIP_LEFT_RIGHT)

            # pad up to crop_size when the scaled image is smaller (labels padded with IGNORE_INDEX)
            if short_size < crop_size:
                pad_h = crop_size - out_h if out_h < crop_size else 0
                pad_w = crop_size - out_w if out_w < crop_size else 0
                border = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
                img = ImageOps.expand(img, border=border, fill=0)
                grt = ImageOps.expand(grt, border=border, fill=cfg.DATASET.IGNORE_INDEX)

            # random crop
            w, h = img.size
            x = random.randint(0, w - crop_size)
            y = random.randint(0, h - crop_size)
            img = img.crop((x, y, x + crop_size, y + crop_size))
            grt = grt.crop((x, y, x + crop_size, y + crop_size))


            # extra augmentation: random Gaussian blur
            if cfg.DATAAUG_EXTRA and random.random() > 0.7:
                img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))

            # convert PIL.Image back to numpy.ndarray (RGB -> BGR)
            img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
            grt = np.array(grt)

        elif ModelPhase.is_eval(mode):
            org_shape = [img.shape[0], img.shape[1]]  # e.g. 1024 x 2048 for Cityscapes

        elif ModelPhase.is_visual(mode):
            org_shape = [img.shape[0], img.shape[1]]
            #img, grt = resize(img, grt, mode=mode)
            valid_shape = [img.shape[0], img.shape[1]]
            #img, grt = rand_crop(img, grt, mode=mode)
        else:
            raise ValueError("Dataset mode={} Error!".format(mode))

        # Normalize image
        img = self.normalize_image(img)

        if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
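            # add a channel dimension to the label and build a mask of valid (non-ignored) pixels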
            grt = np.expand_dims(np.array(grt).astype('int32'), axis=0)
            ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32')


        # train and eval share the same output format
        if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
            return (img, grt, ignore)
        elif ModelPhase.is_visual(mode):
            return (img, grt, img_name, valid_shape, org_shape)
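

# Usage sketch (illustrative only): a concrete dataset is expected to subclass
# BaseSeg and implement load_image(). The names below (CityscapesSeg,
# 'train_list.txt', './data') are placeholders, and the assumption that each
# file-list line holds an image path and a label path separated by whitespace
# may differ from the actual list format.
#
#   class CityscapesSeg(BaseSeg):
#       def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
#           img_path, grt_path = line.split()
#           img = self.cv2_imread(os.path.join(src_dir, img_path))
#           grt = self.pil_imread(os.path.join(src_dir, grt_path))
#           return img, grt, img_path, grt_path
#
#   dataset = CityscapesSeg('train_list.txt', './data', shuffle=True,
#                           mode=ModelPhase.TRAIN)
#   for imgs, labels, ignore_mask in dataset.batch_reader(batch_size=4):
#       pass  # imgs: NCHW float32, labels/ignore_mask: N x 1 x H x W int32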