# baseseg.py (9.5 KB) - committed by Rosun

from __future__ import print_function
import sys
import os
import math
import random
import functools
import io
import time
import codecs
import numpy as np
import paddle
import paddle.fluid as fluid
import cv2
import copy
from PIL import Image, ImageOps, ImageFilter, ImageEnhance

from src.models.model_builder import ModelPhase
from src.utils.config import cfg
from .data_utils import GeneratorEnqueuer


class BaseSeg(object):
    """Base dataset reader for semantic segmentation.

    Reads a UTF-8 file list (one sample description per line), optionally
    shuffles it every epoch (and shards it across trainers), and yields
    preprocessed samples. Subclasses must implement ``load_image``, which turns
    one file-list line into ``(img, grt, img_name, grt_name)``.

    Args:
        file_list (str): path to the UTF-8 encoded file list.
        data_dir (str): directory passed through to ``load_image``.
        shuffle (bool): whether to shuffle the file list every epoch.
        mode: a ``ModelPhase`` value selecting train / eval / visual processing.
        base_size (int): target short edge (before random scaling) in training.
        crop_size (int): side length of the square training crop.
        rand_scale (bool): if True, randomly scale the short edge in training.
    """

    def __init__(self, file_list, data_dir, shuffle=False, mode=ModelPhase.TRAIN,
                 base_size=1024, crop_size=769, rand_scale=True):
        self.mode = mode
        self.shuffle = shuffle
        self.data_dir = data_dir
        # Seed for the multi-trainer shuffle; advanced once per epoch so all
        # trainers walk the same sequence of permutations.
        self.shuffle_seed = 0

        self.crop_size = crop_size
        self.base_size = base_size  # short edge when training
        self.rand_scale = rand_scale

        # NOTE: Please ensure the file list was saved in UTF-8 encoding.
        with codecs.open(file_list, 'r', 'utf-8') as flist:
            self.lines = [line.strip() for line in flist]
        # Full, unsharded copy of the list; each epoch's shard is cut from it.
        self.all_lines = list(self.lines)
        if shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
        elif shuffle:
            np.random.shuffle(self.lines)
        # NOTE(review): sharding across trainers only happens when shuffle is
        # True (see _reshuffle_lines); with shuffle=False every trainer reads
        # the full list - confirm this is intended.
        self.num_trainers = cfg.NUM_TRAINERS
        self.trainer_id = cfg.TRAINER_ID

    def _reshuffle_lines(self):
        """Prepare ``self.lines`` for a new epoch (no-op unless shuffling).

        With more than one trainer, ``self.all_lines`` is shuffled with the
        shared seed ``self.shuffle_seed`` (then incremented), and this trainer
        keeps only its contiguous slice of ``len(all_lines) // num_trainers``
        lines; any remainder at the end of the shuffled list is skipped for
        that epoch. With a single trainer, ``self.lines`` is shuffled in place.
        """
        if not self.shuffle:
            return
        if self.num_trainers > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // self.num_trainers
            start = num_lines * self.trainer_id
            self.lines = self.all_lines[start:start + num_lines]
            self.shuffle_seed += 1
        else:
            np.random.shuffle(self.lines)

    def generator(self):
        """Yield processed samples for one epoch, reshuffling first if enabled."""
        self._reshuffle_lines()
        for line in self.lines:
            yield self.process_image(line, self.data_dir, self.mode)

    def sharding_generator(self, pid=0, num_processes=1):
        """Yield processed samples for the lines assigned to worker ``pid``.

        Line ``i`` of ``self.lines`` belongs to worker ``i % num_processes``;
        with the defaults (pid=0, num_processes=1) this is a plain generator
        over every line. It does not reshuffle the file list itself.
        """
        for index, line in enumerate(self.lines):
            if index % num_processes == pid:
                yield self.process_image(line, self.data_dir, self.mode)

    def batch_reader(self, batch_size):
        """Yield ``(imgs, labels, ignore_masks)`` batches built from ``self.reader``.

        NOTE(review): ``self.reader`` is not defined in this class; presumably a
        subclass or caller assigns it - confirm before relying on this method.
        """
        br = self.batch(self.reader, batch_size)
        for batch in br:
            yield batch[0], batch[1], batch[2]

    def multiprocess_generator(self, max_queue_size=32, num_processes=8):
        """Yield processed samples produced by ``num_processes`` parallel workers.

        The file list is reshuffled as in ``generator``, split into
        ``num_processes`` interleaved shards via ``sharding_generator``, and the
        shards are consumed through a ``GeneratorEnqueuer``. The enqueuer is
        stopped in a ``finally`` clause, so it is also stopped when the
        consumer stops iterating early.
        """
        self._reshuffle_lines()

        # One sharding generator per worker so the workers split the file list.
        generators = [self.sharding_generator(pid, num_processes)
                      for pid in range(num_processes)]

        # Bound before `try` so `finally` never hits an unbound name if the
        # enqueuer constructor itself raises.
        enqueuer = None
        try:
            enqueuer = GeneratorEnqueuer(generators)
            enqueuer.start(max_queue_size=max_queue_size, workers=num_processes)
            while True:
                generator_out = None
                # Poll the queue while workers run.
                # NOTE(review): if is_running() can turn False while items are
                # still queued, those items are not drained - confirm against
                # GeneratorEnqueuer's semantics.
                while enqueuer.is_running():
                    if not enqueuer.queue.empty():
                        generator_out = enqueuer.queue.get(timeout=5)
                        break
                    time.sleep(0.01)
                if generator_out is None:
                    break
                yield generator_out
        finally:
            if enqueuer is not None:
                enqueuer.stop()

    def batch(self, reader, batch_size, is_test=False, drop_last=False):
        """Group samples from ``reader()`` into batches of ``batch_size``.

        Args:
            reader: zero-argument callable returning an iterable of samples.
            batch_size (int): number of samples per batch.
            is_test (bool): if True, samples are 5-tuples
                ``(img, grt, img_name, valid_shape, org_shape)`` and each batch
                is ``(imgs, grts, img_names, valid_shapes, org_shapes)`` with
                ``img_names`` kept as a list; otherwise samples are 3-tuples
                ``(img, lab, ignore)`` and each batch is three numpy arrays.
            drop_last (bool): if True, a final partial batch is discarded.

        Returns:
            A generator over the batches.
        """
        def batch_reader(is_test=False, drop_last=drop_last):
            if is_test:
                imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []
                for img, grt, img_name, valid_shape, org_shape in reader():
                    imgs.append(img)
                    grts.append(grt)
                    img_names.append(img_name)
                    valid_shapes.append(valid_shape)
                    org_shapes.append(org_shape)
                    if len(imgs) == batch_size:
                        yield np.array(imgs), np.array(
                            grts), img_names, np.array(valid_shapes), np.array(
                                org_shapes)
                        imgs, grts, img_names, valid_shapes, org_shapes = [], [], [], [], []

                if not drop_last and len(imgs) > 0:
                    yield np.array(imgs), np.array(grts), img_names, np.array(
                        valid_shapes), np.array(org_shapes)
            else:
                imgs, labs, ignore = [], [], []
                bs = 0
                for img, lab, ig in reader():
                    imgs.append(img)
                    labs.append(lab)
                    ignore.append(ig)
                    bs += 1
                    if bs == batch_size:
                        yield np.array(imgs), np.array(labs), np.array(ignore)
                        bs = 0
                        imgs, labs, ignore = [], [], []

                if not drop_last and bs > 0:
                    yield np.array(imgs), np.array(labs), np.array(ignore)

        return batch_reader(is_test, drop_last)

    def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
        """Load the sample described by one file-list line.

        Must be overridden by subclasses; ``process_image`` expects it to return
        ``(img, grt, img_name, grt_name)``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            "load_image() must be implemented by a BaseSeg subclass")

    def pil_imread(self, file_path):
        """Read a (pseudo-color) label image with PIL and return it as a numpy array.

        No mode conversion is applied. The file is closed before returning.
        """
        with Image.open(file_path) as im:
            return np.asarray(im)

    def cv2_imread(self, file_path, flag=cv2.IMREAD_COLOR):
        """Read an image with OpenCV via ``np.fromfile`` + ``cv2.imdecode``.

        Decoding from an in-memory buffer works around ``cv2.imread`` failing on
        non-ASCII (e.g. Chinese) paths on Windows. NOTE(review): presumably
        returns None for undecodable data (cv2.imdecode behaviour) - callers
        should check.
        """
        return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag)

    def normalize_image(self, img):
        """Convert an HWC image to a normalized CHW float32 array.

        Values are divided by 255 (assumes 8-bit input - TODO confirm), then
        each channel becomes ``(x - cfg.MEAN[c]) / cfg.STD[c]``.
        """
        img = img.transpose((2, 0, 1)).astype('float32') / 255.0
        img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1))
        img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1))
        # In-place ops keep the float32 dtype of `img`.
        img -= img_mean
        img /= img_std

        return img

    def _augment_train(self, img, grt):
        """Apply training augmentation: random scale, flip, pad, crop, blur.

        Takes a BGR image array and a label array (as returned by
        ``load_image``) and returns ``(img, grt)`` as a BGR image array and a
        label array, both ``crop_size x crop_size``.
        """
        # numpy (BGR) -> PIL (RGB) so PIL's geometric ops and filters can be used.
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        grt = Image.fromarray(grt.astype('uint8')).convert('L')

        crop_size = self.crop_size

        # Random scale: resize so the short edge equals short_size, keeping
        # the aspect ratio.
        if self.rand_scale:
            short_size = random.randint(
                int(self.base_size * cfg.DATAAUG.RAND_SCALE_MIN),
                int(self.base_size * cfg.DATAAUG.RAND_SCALE_MAX))
        else:
            short_size = self.base_size
        w, h = img.size
        if h > w:
            out_w = short_size
            out_h = int(1.0 * h / w * out_w)
        else:
            out_h = short_size
            out_w = int(1.0 * w / h * out_h)
        img = img.resize((out_w, out_h), Image.BILINEAR)
        # NEAREST so no new (interpolated) label values are introduced.
        grt = grt.resize((out_w, out_h), Image.NEAREST)

        # Random horizontal flip.
        if random.random() > 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            grt = grt.transpose(Image.FLIP_LEFT_RIGHT)

        # Pad (centered) up to crop_size when the scaled image is smaller.
        # Padded label pixels get IGNORE_INDEX, so the ignore mask built in
        # process_image marks them as ignored.
        if short_size < crop_size:
            pad_h = crop_size - out_h if out_h < crop_size else 0
            pad_w = crop_size - out_w if out_w < crop_size else 0
            border = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
            img = ImageOps.expand(img, border=border, fill=0)
            grt = ImageOps.expand(grt, border=border, fill=cfg.DATASET.IGNORE_INDEX)

        # Random crop_size x crop_size crop.
        w, h = img.size
        x = random.randint(0, w - crop_size)
        y = random.randint(0, h - crop_size)
        img = img.crop((x, y, x + crop_size, y + crop_size))
        grt = grt.crop((x, y, x + crop_size, y + crop_size))

        # Optional gaussian blur, applied with probability ~0.3.
        if cfg.DATAAUG_EXTRA and random.random() > 0.7:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))

        # PIL (RGB) -> numpy (BGR)
        img = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
        grt = np.array(grt)
        return img, grt

    def process_image(self, line, data_dir, mode):
        """Load and preprocess one sample.

        Returns:
            train / eval: ``(img, grt, ignore)`` where ``img`` is the normalized
                CHW float32 image, ``grt`` is the int32 label with a leading
                axis of size 1, and ``ignore`` is an int32 mask that is 1 where
                ``grt != cfg.DATASET.IGNORE_INDEX`` and 0 elsewhere.
            visual: ``(img, grt, img_name, valid_shape, org_shape)``, where
                ``grt`` is passed through from ``load_image`` unchanged and
                both shapes are ``[height, width]`` of the loaded image.

        Raises:
            ValueError: if ``mode`` is not a train, eval or visual phase.
        """
        # img / grt are numpy arrays from load_image; grt_name is not used here.
        img, grt, img_name, _ = self.load_image(line, data_dir, mode=mode)
        if mode == ModelPhase.TRAIN:
            img, grt = self._augment_train(img, grt)
        elif ModelPhase.is_eval(mode):
            pass  # no augmentation in eval; the image is only normalized below
        elif ModelPhase.is_visual(mode):
            org_shape = [img.shape[0], img.shape[1]]
            valid_shape = [img.shape[0], img.shape[1]]
        else:
            raise ValueError("Dataset mode={} Error!".format(mode))

        img = self.normalize_image(img)

        if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
            grt = np.expand_dims(np.array(grt).astype('int32'), axis=0)
            ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32')
            return (img, grt, ignore)
        elif ModelPhase.is_visual(mode):
            return (img, grt, img_name, valid_shape, org_shape)