# coding: utf8 # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import print_function import sys import os import math import random import functools import io import time import codecs import numpy as np import paddle import paddle.fluid as fluid import cv2 import data_aug as aug from utils.config import cfg from data_utils import GeneratorEnqueuer from models.model_builder import ModelPhase def cv2_imread(file_path, flag=cv2.IMREAD_COLOR): # resolve cv2.imread open Chinese file path issues on Windows Platform. return cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), flag) class SegDataset(object): def __init__(self, file_list, data_dir, shuffle=False, mode=ModelPhase.TRAIN): self.mode = mode self.shuffle = shuffle self.data_dir = data_dir # NOTE: Please ensure file list was save in UTF-8 coding format with codecs.open(file_list, 'r', 'utf-8') as flist: self.lines = [line.strip() for line in flist] if shuffle: np.random.shuffle(self.lines) def generator(self): if self.shuffle: np.random.shuffle(self.lines) for line in self.lines: yield self.process_image(line, self.data_dir, self.mode) def sharding_generator(self, pid=0, num_processes=1): """ Use line id as shard key for multiprocess io It's a normal generator if pid=0, num_processes=1 """ for index, line in enumerate(self.lines): # Use index and pid to shard file list if index % num_processes == pid: yield self.process_image(line, self.data_dir, self.mode) def batch_reader(self, batch_size): br = self.batch(self.reader, batch_size) for batch in br: yield batch[0], batch[1], batch[2] def multiprocess_generator(self, max_queue_size=32, num_processes=8): # Re-shuffle file list if self.shuffle: np.random.shuffle(self.lines) # Create multiple sharding generators according to num_processes for multiple processes generators = [] for pid in range(num_processes): generators.append(self.sharding_generator(pid, num_processes)) try: enqueuer = GeneratorEnqueuer(generators) enqueuer.start(max_queue_size=max_queue_size, workers=num_processes) while True: generator_out = None while enqueuer.is_running(): if not enqueuer.queue.empty(): generator_out = enqueuer.queue.get(timeout=5) break else: time.sleep(0.01) if generator_out is None: break yield generator_out finally: if enqueuer is not None: enqueuer.stop() def batch(self, reader, batch_size, is_test=False, drop_last=False): def batch_reader(is_test=False, drop_last=drop_last): if is_test: imgs, img_names, valid_shapes, org_shapes = [], [], [], [] for img, img_name, valid_shape, org_shape in reader(): imgs.append(img) img_names.append(img_name) valid_shapes.append(valid_shape) org_shapes.append(org_shape) if len(imgs) == batch_size: yield np.array(imgs), img_names, np.array( valid_shapes), np.array(org_shapes) imgs, img_names, valid_shapes, org_shapes = [], [], [], [] if not drop_last and len(imgs) > 0: yield np.array(imgs), img_names, np.array( valid_shapes), np.array(org_shapes) else: imgs, labs, ignore = [], [], [] bs = 0 for img, lab, ig in reader(): imgs.append(img) labs.append(lab) ignore.append(ig) bs += 1 if bs == batch_size: yield np.array(imgs), np.array(labs), np.array(ignore) bs = 0 imgs, labs, ignore = [], [], [] if not drop_last and bs > 0: yield np.array(imgs), np.array(labs), np.array(ignore) return batch_reader(is_test, drop_last) def load_image(self, line, src_dir, mode=ModelPhase.TRAIN): # original image cv2.imread flag setting cv2_imread_flag = cv2.IMREAD_COLOR if cfg.DATASET.IMAGE_TYPE == "rgba": # If use RBGA 4 channel ImageType, use IMREAD_UNCHANGED flags to # reserver alpha channel cv2_imread_flag = cv2.IMREAD_UNCHANGED if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL: parts = line.strip().split(cfg.DATASET.SEPARATOR) if len(parts) != 2: raise Exception("File list format incorrect! It should be" " image_name{}label_name\\n".format( cfg.DATASET.SEPARATOR)) img_name, grt_name = parts[0], parts[1] img_path = os.path.join(src_dir, img_name) grt_path = os.path.join(src_dir, grt_name) img = cv2_imread(img_path, cv2_imread_flag) grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE) if img is None or grt is None: raise Exception( "Empty image, src_dir: {}, img: {} & lab: {}".format( src_dir, img_path, grt_path)) img_height = img.shape[0] img_width = img.shape[1] grt_height = grt.shape[0] grt_width = grt.shape[1] if img_height != grt_height or img_width != grt_width: raise Exception( "source img and label img must has the same size") if len(img.shape) < 3: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img_channels = img.shape[2] if img_channels < 3: raise Exception( "PaddleSeg only supports gray, rgb or rgba image") if img_channels != cfg.DATASET.DATA_DIM: raise Exception( "Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}" .format(img_channels, cfg.DATASET.DATADIM, img_name)) if img_channels != len(cfg.MEAN): raise Exception( "img name {}, img chns {} mean size {}, size unequal". format(img_name, img_channels, len(cfg.MEAN))) if img_channels != len(cfg.STD): raise Exception( "img name {}, img chns {} std size {}, size unequal".format( img_name, img_channels, len(cfg.STD))) # visualization mode elif mode == ModelPhase.VISUAL: if cfg.DATASET.SEPARATOR in line: parts = line.strip().split(cfg.DATASET.SEPARATOR) img_name = parts[0] else: img_name = line.strip() img_path = os.path.join(src_dir, img_name) img = cv2_imread(img_path, cv2_imread_flag) if img is None: raise Exception("empty image, src_dir:{}, img: {}".format( src_dir, img_name)) # Convert grayscale image to BGR 3 channel image if len(img.shape) < 3: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img_height = img.shape[0] img_width = img.shape[1] img_channels = img.shape[2] if img_channels < 3: raise Exception("this repo only recept gray, rgb or rgba image") if img_channels != cfg.DATASET.DATA_DIM: raise Exception("data dim must equal to image channels") if img_channels != len(cfg.MEAN): raise Exception( "img name {}, img chns {} mean size {}, size unequal". format(img_name, img_channels, len(cfg.MEAN))) if img_channels != len(cfg.STD): raise Exception( "img name {}, img chns {} std size {}, size unequal".format( img_name, img_channels, len(cfg.STD))) grt = None grt_name = None else: raise ValueError("mode error: {}".format(mode)) return img, grt, img_name, grt_name def normalize_image(self, img): """ 像素归一化后减均值除方差 """ img = img.transpose((2, 0, 1)).astype('float32') / 255.0 img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1)) img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1)) img -= img_mean img /= img_std return img def process_image(self, line, data_dir, mode): """ process_image """ img, grt, img_name, grt_name = self.load_image( line, data_dir, mode=mode) if mode == ModelPhase.TRAIN: img, grt = aug.resize(img, grt, mode) if cfg.AUG.RICH_CROP.ENABLE: if cfg.AUG.RICH_CROP.BLUR: if cfg.AUG.RICH_CROP.BLUR_RATIO <= 0: n = 0 elif cfg.AUG.RICH_CROP.BLUR_RATIO >= 1: n = 1 else: n = int(1.0 / cfg.AUG.RICH_CROP.BLUR_RATIO) if n > 0: if np.random.randint(0, n) == 0: radius = np.random.randint(3, 10) if radius % 2 != 1: radius = radius + 1 if radius > 9: radius = 9 img = cv2.GaussianBlur(img, (radius, radius), 0, 0) img, grt = aug.random_rotation( img, grt, rich_crop_max_rotation=cfg.AUG.RICH_CROP.MAX_ROTATION, mean_value=cfg.MEAN) img, grt = aug.rand_scale_aspect( img, grt, rich_crop_min_scale=cfg.AUG.RICH_CROP.MIN_AREA_RATIO, rich_crop_aspect_ratio=cfg.AUG.RICH_CROP.ASPECT_RATIO) img = aug.hsv_color_jitter( img, brightness_jitter_ratio=cfg.AUG.RICH_CROP. BRIGHTNESS_JITTER_RATIO, saturation_jitter_ratio=cfg.AUG.RICH_CROP. SATURATION_JITTER_RATIO, contrast_jitter_ratio=cfg.AUG.RICH_CROP. CONTRAST_JITTER_RATIO) if cfg.AUG.RICH_CROP.FLIP: if cfg.AUG.RICH_CROP.FLIP_RATIO <= 0: n = 0 elif cfg.AUG.RICH_CROP.FLIP_RATIO >= 1: n = 1 else: n = int(1.0 / cfg.AUG.RICH_CROP.FLIP_RATIO) if n > 0: if np.random.randint(0, n) == 0: img = img[::-1, :, :] grt = grt[::-1, :] if cfg.AUG.MIRROR: if np.random.randint(0, 2) == 1: img = img[:, ::-1, :] grt = grt[:, ::-1] img, grt = aug.rand_crop(img, grt, mode=mode) elif ModelPhase.is_eval(mode): img, grt = aug.resize(img, grt, mode=mode) img, grt = aug.rand_crop(img, grt, mode=mode) elif ModelPhase.is_visual(mode): org_shape = [img.shape[0], img.shape[1]] img, grt = aug.resize(img, grt, mode=mode) valid_shape = [img.shape[0], img.shape[1]] img, grt = aug.rand_crop(img, grt, mode=mode) else: raise ValueError("Dataset mode={} Error!".format(mode)) # Normalize image img = self.normalize_image(img) if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode): grt = np.expand_dims(np.array(grt).astype('int32'), axis=0) ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32') if ModelPhase.is_train(mode): return (img, grt, ignore) elif ModelPhase.is_eval(mode): return (img, grt, ignore) elif ModelPhase.is_visual(mode): return (img, img_name, valid_shape, org_shape)