From 897d86ac4697b0268f6d038346602f02b78079fa Mon Sep 17 00:00:00 2001 From: sunxl1988 <47514455+sunxl1988@users.noreply.github.com> Date: Tue, 21 Jul 2020 16:38:50 +0800 Subject: [PATCH] test=dygraph sync reader from static ppdet (#1084) sync reader from static ppdet --- ppdet/data/reader.py | 34 +- ppdet/data/source/widerface.py | 125 +++++--- ppdet/data/transform/batch_operators.py | 168 ++++++---- ppdet/data/transform/gridmask_utils.py | 83 +++++ ppdet/data/transform/op_helper.py | 28 +- ppdet/data/transform/operators.py | 406 ++++++++++++++++++++++-- 6 files changed, 701 insertions(+), 143 deletions(-) create mode 100644 ppdet/data/transform/gridmask_utils.py diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 5ed4a3f3d..7d808b589 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import copy import functools import collections @@ -167,6 +168,8 @@ class Reader(object): Default True. mixup_epoch (int): mixup epoc number. Default is -1, meaning not use mixup. + cutmix_epoch (int): cutmix epoc number. Default is -1, meaning + not use cutmix. class_aware_sampling (bool): whether use class-aware sampling or not. Default False. worker_num (int): number of working threads/processes. @@ -191,6 +194,7 @@ class Reader(object): drop_last=False, drop_empty=True, mixup_epoch=-1, + cutmix_epoch=-1, class_aware_sampling=False, worker_num=-1, use_process=False, @@ -241,6 +245,7 @@ class Reader(object): # sampling self._mixup_epoch = mixup_epoch + self._cutmix_epoch = cutmix_epoch self._class_aware_sampling = class_aware_sampling self._load_img = False @@ -253,6 +258,8 @@ class Reader(object): self._pos = -1 self._epoch = -1 + self._curr_iter = 0 + # multi-process self._worker_num = worker_num self._parallel = None @@ -274,6 +281,11 @@ class Reader(object): def reset(self): """implementation of Dataset.reset """ + if self._epoch < 0: + self._epoch = 0 + else: + self._epoch += 1 + self.indexes = [i for i in range(self.size())] if self._class_aware_sampling: self.indexes = np.random.choice( @@ -283,17 +295,18 @@ class Reader(object): p=self.img_weights) if self._shuffle: + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0)) + np.random.seed(self._epoch + trainer_id) np.random.shuffle(self.indexes) if self._mixup_epoch > 0 and len(self.indexes) < 2: logger.debug("Disable mixup for dataset samples " "less than 2 samples") self._mixup_epoch = -1 - - if self._epoch < 0: - self._epoch = 0 - else: - self._epoch += 1 + if self._cutmix_epoch > 0 and len(self.indexes) < 2: + logger.info("Disable cutmix for dataset samples " + "less than 2 samples") + self._cutmix_epoch = -1 self._pos = 0 @@ -306,6 +319,7 @@ class Reader(object): if self.drained(): raise StopIteration batch = self._load_batch() + self._curr_iter += 1 if self._drop_last and len(batch) < self._batch_size: raise StopIteration if self._worker_num > -1: @@ -321,6 +335,7 @@ class Reader(object): break pos = self.indexes[self._pos] sample = copy.deepcopy(self._roidbs[pos]) + sample["curr_iter"] = self._curr_iter self._pos += 1 if self._drop_empty and self._fields and 'gt_mask' in self._fields: @@ -343,9 +358,18 @@ class Reader(object): mix_idx = np.random.randint(1, num) mix_idx = self.indexes[(mix_idx + self._pos - 1) % num] sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx]) + sample['mixup']["curr_iter"] = self._curr_iter if self._load_img: sample['mixup']['image'] = self._load_image(sample['mixup'][ 'im_file']) + if self._epoch < self._cutmix_epoch: + num = len(self.indexes) + mix_idx = np.random.randint(1, num) + sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx]) + sample['cutmix']["curr_iter"] = self._curr_iter + if self._load_img: + sample['cutmix']['image'] = self._load_image(sample[ + 'cutmix']['im_file']) batch.append(sample) bs += 1 diff --git a/ppdet/data/source/widerface.py b/ppdet/data/source/widerface.py index 311430559..7aab15337 100644 --- a/ppdet/data/source/widerface.py +++ b/ppdet/data/source/widerface.py @@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet): image_dir=None, anno_path=None, sample_num=-1, - with_background=True): + with_background=True, + with_lmk=False): super(WIDERFaceDataSet, self).__init__( image_dir=image_dir, anno_path=anno_path, @@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet): self.with_background = with_background self.roidbs = None self.cname2cid = None + self.with_lmk = with_lmk def load_roidb_and_cname2cid(self): anno_path = os.path.join(self.dataset_dir, self.anno_path) @@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet): records = [] ct = 0 - file_lists = _load_file_list(txt_file) + file_lists = self._load_file_list(txt_file) cname2cid = widerface_label(self.with_background) for item in file_lists: im_fname = item[0] im_id = np.array([ct]) - gt_bbox = np.zeros((len(item) - 2, 4), dtype=np.float32) - gt_class = np.ones((len(item) - 2, 1), dtype=np.int32) + gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32) + gt_class = np.ones((len(item) - 1, 1), dtype=np.int32) + gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32) + lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32) for index_box in range(len(item)): - if index_box >= 2: - temp_info_box = item[index_box].split(' ') - xmin = float(temp_info_box[0]) - ymin = float(temp_info_box[1]) - w = float(temp_info_box[2]) - h = float(temp_info_box[3]) - # Filter out wrong labels - if w < 0 or h < 0: - logger.warn('Illegal box with w: {}, h: {} in ' - 'img: {}, and it will be ignored'.format( - w, h, im_fname)) - continue - xmin = max(0, xmin) - ymin = max(0, ymin) - xmax = xmin + w - ymax = ymin + h - gt_bbox[index_box - 2] = [xmin, ymin, xmax, ymax] - + if index_box < 1: + continue + gt_bbox[index_box - 1] = item[index_box][0] + if self.with_lmk: + gt_lmk_labels[index_box - 1] = item[index_box][1] + lmk_ignore_flag[index_box - 1] = item[index_box][2] im_fname = os.path.join(image_dir, im_fname) if image_dir else im_fname widerface_rec = { @@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet): 'gt_bbox': gt_bbox, 'gt_class': gt_class, } - # logger.debug + if self.with_lmk: + widerface_rec['gt_keypoint'] = gt_lmk_labels + widerface_rec['keypoint_ignore'] = lmk_ignore_flag + if len(item) != 0: records.append(widerface_rec) @@ -108,34 +103,64 @@ class WIDERFaceDataSet(DataSet): logger.debug('{} samples in file {}'.format(ct, anno_path)) self.roidbs, self.cname2cid = records, cname2cid - -def _load_file_list(input_txt): - with open(input_txt, 'r') as f_dir: - lines_input_txt = f_dir.readlines() - - file_dict = {} - num_class = 0 - for i in range(len(lines_input_txt)): - line_txt = lines_input_txt[i].strip('\n\t\r') - if '.jpg' in line_txt: - if i != 0: - num_class += 1 - file_dict[num_class] = [] - file_dict[num_class].append(line_txt) - if '.jpg' not in line_txt: - if len(line_txt) > 6: - split_str = line_txt.split(' ') - x1_min = float(split_str[0]) - y1_min = float(split_str[1]) - x2_max = float(split_str[2]) - y2_max = float(split_str[3]) - line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str( - x2_max) + ' ' + str(y2_max) + def _load_file_list(self, input_txt): + with open(input_txt, 'r') as f_dir: + lines_input_txt = f_dir.readlines() + + file_dict = {} + num_class = 0 + for i in range(len(lines_input_txt)): + line_txt = lines_input_txt[i].strip('\n\t\r') + if '.jpg' in line_txt: + if i != 0: + num_class += 1 + file_dict[num_class] = [] file_dict[num_class].append(line_txt) - else: - file_dict[num_class].append(line_txt) - - return list(file_dict.values()) + if '.jpg' not in line_txt: + if len(line_txt) <= 6: + continue + result_boxs = [] + split_str = line_txt.split(' ') + xmin = float(split_str[0]) + ymin = float(split_str[1]) + w = float(split_str[2]) + h = float(split_str[3]) + # Filter out wrong labels + if w < 0 or h < 0: + logger.warn('Illegal box with w: {}, h: {} in ' + 'img: {}, and it will be ignored'.format( + w, h, file_dict[num_class][0])) + continue + xmin = max(0, xmin) + ymin = max(0, ymin) + xmax = xmin + w + ymax = ymin + h + gt_bbox = [xmin, ymin, xmax, ymax] + result_boxs.append(gt_bbox) + if self.with_lmk: + assert len(split_str) > 18, 'When `with_lmk=True`, the number' \ + 'of characters per line in the annotation file should' \ + 'exceed 18.' + lmk0_x = float(split_str[5]) + lmk0_y = float(split_str[6]) + lmk1_x = float(split_str[8]) + lmk1_y = float(split_str[9]) + lmk2_x = float(split_str[11]) + lmk2_y = float(split_str[12]) + lmk3_x = float(split_str[14]) + lmk3_y = float(split_str[15]) + lmk4_x = float(split_str[17]) + lmk4_y = float(split_str[18]) + lmk_ignore_flag = 0 if lmk0_x == -1 else 1 + gt_lmk_label = [ + lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x, + lmk3_y, lmk4_x, lmk4_y + ] + result_boxs.append(gt_lmk_label) + result_boxs.append(lmk_ignore_flag) + file_dict[num_class].append(result_boxs) + + return list(file_dict.values()) def widerface_label(with_background=True): diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 068614ab5..1bed5edaf 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -26,13 +26,17 @@ import cv2 import numpy as np from .operators import register_op, BaseOperator -from .op_helper import jaccard_overlap +from .op_helper import jaccard_overlap, gaussian2D logger = logging.getLogger(__name__) __all__ = [ - 'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget', - 'Gt2FCOSTarget' + 'PadBatch', + 'RandomShape', + 'PadMultiScaleTest', + 'Gt2YoloTarget', + 'Gt2FCOSTarget', + 'Gt2TTFTarget', ] @@ -41,17 +45,15 @@ class PadBatch(BaseOperator): """ Pad a batch of samples so they can be divisible by a stride. The layout of each image should be 'CHW'. - Args: pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure height and width is divisible by `pad_to_stride`. """ - def __init__(self, pad_to_stride=0, use_padded_im_info=True, pad_gt=False): + def __init__(self, pad_to_stride=0, use_padded_im_info=True): super(PadBatch, self).__init__() self.pad_to_stride = pad_to_stride self.use_padded_im_info = use_padded_im_info - self.pad_gt = pad_gt def __call__(self, samples, context=None): """ @@ -61,9 +63,9 @@ class PadBatch(BaseOperator): coarsest_stride = self.pad_to_stride if coarsest_stride == 0: return samples - max_shape = np.array([data['image'].shape for data in samples]).max( axis=0) + if coarsest_stride > 0: max_shape[1] = int( np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) @@ -80,52 +82,6 @@ class PadBatch(BaseOperator): data['image'] = padding_im if self.use_padded_im_info: data['im_info'][:2] = max_shape[1:3] - - if self.pad_gt: - gt_num = [] - if data['gt_poly'] is not None and len(data['gt_poly']) > 0: - pad_mask = True - else: - pad_mask = False - - if pad_mask: - poly_num = [] - poly_part_num = [] - point_num = [] - for data in samples: - gt_num.append(data['gt_bbox'].shape[0]) - if pad_mask: - poly_num.append(len(data['gt_poly'])) - for poly in data['gt_poly']: - poly_part_num.append(int(len(poly))) - for p_p in poly: - point_num.append(int(len(p_p) / 2)) - gt_num_max = max(gt_num) - gt_box_data = np.zeros([gt_num_max, 4]) - gt_class_data = np.zeros([gt_num_max]) - is_crowd_data = np.ones([gt_num_max]) - - if pad_mask: - poly_num_max = max(poly_num) - poly_part_num_max = max(poly_part_num) - point_num_max = max(point_num) - gt_masks_data = -np.ones( - [poly_num_max, poly_part_num_max, point_num_max, 2]) - - for i, data in enumerate(samples): - gt_num = data['gt_bbox'].shape[0] - gt_box_data[0:gt_num, :] = data['gt_bbox'] - gt_class_data[0:gt_num] = np.squeeze(data['gt_class']) - is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd']) - if pad_mask: - for j, poly in enumerate(data['gt_poly']): - for k, p_p in enumerate(poly): - pp_np = np.array(p_p).reshape(-1, 2) - gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np - data['gt_poly'] = gt_masks_data - data['gt_bbox'] = gt_box_data - data['gt_class'] = gt_class_data - data['is_crowd_data'] = is_crowd_data return samples @@ -136,13 +92,12 @@ class RandomShape(BaseOperator): select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is False, use cv2.INTER_NEAREST. - Args: sizes (list): list of int, random choose a size from these random_inter (bool): whether to randomly interpolation, defalut true. """ - def __init__(self, sizes=[], random_inter=False): + def __init__(self, sizes=[], random_inter=False, resize_box=False): super(RandomShape, self).__init__() self.sizes = sizes self.random_inter = random_inter @@ -153,6 +108,7 @@ class RandomShape(BaseOperator): cv2.INTER_CUBIC, cv2.INTER_LANCZOS4, ] if random_inter else [] + self.resize_box = resize_box def __call__(self, samples, context=None): shape = np.random.choice(self.sizes) @@ -166,6 +122,12 @@ class RandomShape(BaseOperator): im = cv2.resize( im, None, None, fx=scale_x, fy=scale_y, interpolation=method) samples[i]['image'] = im + if self.resize_box and 'gt_bbox' in samples[i] and len(samples[0][ + 'gt_bbox']) > 0: + scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) + samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] * + scale_array, 0, + float(shape) - 1) return samples @@ -525,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator): sample['centerness{}'.format(lvl)] = np.reshape( ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) return samples + + +@register_op +class Gt2TTFTarget(BaseOperator): + """ + Gt2TTFTarget + Generate TTFNet targets by ground truth data + + Args: + num_classes(int): the number of classes. + down_ratio(int): the down ratio from images to heatmap, 4 by default. + alpha(float): the alpha parameter to generate gaussian target. + 0.54 by default. + """ + + def __init__(self, num_classes, down_ratio=4, alpha=0.54): + super(Gt2TTFTarget, self).__init__() + self.down_ratio = down_ratio + self.num_classes = num_classes + self.alpha = alpha + + def __call__(self, samples, context=None): + output_size = samples[0]['image'].shape[1] + feat_size = output_size // self.down_ratio + for sample in samples: + heatmap = np.zeros( + (self.num_classes, feat_size, feat_size), dtype='float32') + box_target = np.ones( + (4, feat_size, feat_size), dtype='float32') * -1 + reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32') + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + + bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1 + bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1 + area = bbox_w * bbox_h + boxes_areas_log = np.log(area) + boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1] + boxes_area_topk_log = boxes_areas_log[boxes_ind] + gt_bbox = gt_bbox[boxes_ind] + gt_class = gt_class[boxes_ind] + + feat_gt_bbox = gt_bbox / self.down_ratio + feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1) + feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1], + feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0]) + + ct_inds = np.stack( + [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2, + (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2], + axis=1) / self.down_ratio + + h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32') + w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32') + + for k in range(len(gt_bbox)): + cls_id = gt_class[k] + fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32') + self.draw_truncate_gaussian(fake_heatmap, ct_inds[k], + h_radiuses_alpha[k], + w_radiuses_alpha[k]) + + heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap) + box_target_inds = fake_heatmap > 0 + box_target[:, box_target_inds] = gt_bbox[k][:, None] + + local_heatmap = fake_heatmap[box_target_inds] + ct_div = np.sum(local_heatmap) + local_heatmap *= boxes_area_topk_log[k] + reg_weight[0, box_target_inds] = local_heatmap / ct_div + sample['ttf_heatmap'] = heatmap + sample['ttf_box_target'] = box_target + sample['ttf_reg_weight'] = reg_weight + return samples + + def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): + h, w = 2 * h_radius + 1, 2 * w_radius + 1 + sigma_x = w / 6 + sigma_y = h / 6 + gaussian = gaussian2D((h, w), sigma_x, sigma_y) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, w_radius), min(width - x, w_radius + 1) + top, bottom = min(y, h_radius), min(height - y, h_radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius - + left:w_radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + heatmap[y - top:y + bottom, x - left:x + right] = np.maximum( + masked_heatmap, masked_gaussian) + return heatmap diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py new file mode 100644 index 000000000..a23e69b20 --- /dev/null +++ b/ppdet/data/transform/gridmask_utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np +from PIL import Image + + +class GridMask(object): + def __init__(self, + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7, + upper_iter=360000): + super(GridMask, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + self.st_prob = prob + self.upper_iter = upper_iter + + def __call__(self, x, curr_iter): + self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter) + if np.random.rand() > self.prob: + return x + _, h, w = x.shape + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(2, h) + self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.l, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + self.l, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2 + + w].astype(np.float32) + + if self.mode == 1: + mask = 1 - mask + mask = np.expand_dims(mask, axis=0) + if self.offset: + offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32) + x = (x * mask + offset * (1 - mask)).astype(x.dtype) + else: + x = (x * mask).astype(x.dtype) + + return x diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index d41efd934..02d219546 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox): return True -def filter_and_process(sample_bbox, bboxes, labels, scores=None): +def filter_and_process(sample_bbox, bboxes, labels, scores=None, + keypoints=None): new_bboxes = [] new_labels = [] new_scores = [] + new_keypoints = [] + new_kp_ignore = [] for i in range(len(bboxes)): new_bbox = [0, 0, 0, 0] obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]] @@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None): new_labels.append([labels[i][0]]) if scores is not None: new_scores.append([scores[i][0]]) + if keypoints is not None: + sample_keypoint = keypoints[0][i] + for j in range(len(sample_keypoint)): + kp_len = sample_height if j % 2 else sample_width + sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0] + sample_keypoint[j] = ( + sample_keypoint[j] - sample_coord) / kp_len + sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0) + new_keypoints.append(sample_keypoint) + new_kp_ignore.append(keypoints[1][i]) + bboxes = np.array(new_bboxes) labels = np.array(new_labels) scores = np.array(new_scores) + if keypoints is not None: + keypoints = np.array(new_keypoints) + new_kp_ignore = np.array(new_kp_ignore) + return bboxes, labels, scores, (keypoints, new_kp_ignore) return bboxes, labels, scores @@ -420,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap): def draw_gaussian(heatmap, center, radius, k=1, delte=6): diameter = 2 * radius + 1 - gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte) + sigma = diameter / delte + gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma) x, y = center @@ -435,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6): np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) -def gaussian2D(shape, sigma=1): +def gaussian2D(shape, sigma_x=1, sigma_y=1): m, n = [(ss - 1.) / 2. for ss in shape] y, x = np.ogrid[-m:m + 1, -n:n + 1] - h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * + sigma_y))) h[h < np.finfo(h.dtype).eps * h.max()] = 0 return h diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 6d04454e5..db73e4174 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -32,9 +32,10 @@ import logging import random import math import numpy as np +import os import cv2 -from PIL import Image, ImageEnhance +from PIL import Image, ImageEnhance, ImageDraw from ppdet.core.workspace import serializable from ppdet.modeling.ops import AnchorGrid @@ -89,21 +90,24 @@ class BaseOperator(object): @register_op class DecodeImage(BaseOperator): - def __init__(self, to_rgb=True, with_mixup=False): + def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False): """ Transform the image data to numpy format. - Args: to_rgb (bool): whether to convert BGR to RGB with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score + with_cutmix (bool): whether or not to cutmix image and gt_bbbox/gt_score """ super(DecodeImage, self).__init__() self.to_rgb = to_rgb self.with_mixup = with_mixup + self.with_cutmix = with_cutmix if not isinstance(self.to_rgb, bool): raise TypeError("{}: input type is invalid.".format(self)) if not isinstance(self.with_mixup, bool): raise TypeError("{}: input type is invalid.".format(self)) + if not isinstance(self.with_cutmix, bool): + raise TypeError("{}: input type is invalid.".format(self)) def __call__(self, sample, context=None): """ load image if 'im_file' field is not empty but 'image' is""" @@ -142,6 +146,10 @@ class DecodeImage(BaseOperator): # decode mixup image if self.with_mixup and 'mixup' in sample: self.__call__(sample['mixup'], context) + # decode cutmix image + if self.with_cutmix and 'cutmix' in sample: + self.__call__(sample['cutmix'], context) + return sample @@ -156,7 +164,6 @@ class MultiscaleTestResize(BaseOperator): use_flip=True): """ Rescale image to the each size in target size, and capped at max_size. - Args: origin_target_size(int): original target size of image's short side. origin_max_size(int): original max size of image. @@ -265,7 +272,6 @@ class ResizeImage(BaseOperator): if max_size != 0. If target_size is list, selected a scale randomly as the specified target size. - Args: target_size (int|list): the target size of image's short side, multi-scale training is adopted when type is list. @@ -392,6 +398,16 @@ class RandomFlipImage(BaseOperator): flipped_segms.append(_flip_rle(segm, height, width)) return flipped_segms + def flip_keypoint(self, gt_keypoint, width): + for i in range(gt_keypoint.shape[1]): + if i % 2 == 0: + old_x = gt_keypoint[:, i].copy() + if self.is_normalized: + gt_keypoint[:, i] = 1 - old_x + else: + gt_keypoint[:, i] = width - old_x - 1 + return gt_keypoint + def __call__(self, sample, context=None): """Filp the image and bounding box. Operators: @@ -439,12 +455,130 @@ class RandomFlipImage(BaseOperator): if self.is_mask_flip and len(sample['gt_poly']) != 0: sample['gt_poly'] = self.flip_segms(sample['gt_poly'], height, width) + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = self.flip_keypoint( + sample['gt_keypoint'], width) sample['flipped'] = True sample['image'] = im sample = samples if batch_input else samples[0] return sample +@register_op +class RandomErasingImage(BaseOperator): + def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3): + """ + Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896 + Args: + prob (float): probability to carry out random erasing + sl (float): lower limit of the erasing area ratio + sh (float): upper limit of the erasing area ratio + r1 (float): aspect ratio of the erasing region + """ + super(RandomErasingImage, self).__init__() + self.prob = prob + self.sl = sl + self.sh = sh + self.r1 = r1 + + def __call__(self, sample, context=None): + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + gt_bbox = sample['gt_bbox'] + im = sample['image'] + if not isinstance(im, np.ndarray): + raise TypeError("{}: image is not a numpy array.".format(self)) + if len(im.shape) != 3: + raise ImageError("{}: image is not 3-dimensional.".format(self)) + + for idx in range(gt_bbox.shape[0]): + if self.prob <= np.random.rand(): + continue + + x1, y1, x2, y2 = gt_bbox[idx, :] + w_bbox = x2 - x1 + 1 + h_bbox = y2 - y1 + 1 + area = w_bbox * h_bbox + + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.r1, 1 / self.r1) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + + if w < w_bbox and h < h_bbox: + off_y1 = random.randint(0, int(h_bbox - h)) + off_x1 = random.randint(0, int(w_bbox - w)) + im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1): + int(x1 + off_x1 + w), :] = 0 + sample['image'] = im + + sample = samples if batch_input else samples[0] + return sample + + +@register_op +class GridMaskOp(BaseOperator): + def __init__(self, + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7, + upper_iter=360000): + """ + GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086 + Args: + use_h (bool): whether to mask vertically + use_w (boo;): whether to mask horizontally + rotate (float): angle for the mask to rotate + offset (float): mask offset + ratio (float): mask ratio + mode (int): gridmask mode + prob (float): max probability to carry out gridmask + upper_iter (int): suggested to be equal to global max_iter + """ + super(GridMaskOp, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + self.upper_iter = upper_iter + + from .gridmask_utils import GridMask + self.gridmask_op = GridMask( + use_h, + use_w, + rotate=rotate, + offset=offset, + ratio=ratio, + mode=mode, + prob=prob, + upper_iter=upper_iter) + + def __call__(self, sample, context=None): + samples = sample + batch_input = True + if not isinstance(samples, Sequence): + batch_input = False + samples = [samples] + for sample in samples: + sample['image'] = self.gridmask_op(sample['image'], + sample['curr_iter']) + if not batch_input: + samples = samples[0] + return sample + + @register_op class AutoAugmentImage(BaseOperator): def __init__(self, is_normalized=False, autoaug_type="v1"): @@ -733,8 +867,17 @@ class ExpandImage(BaseOperator): im = Image.fromarray(im) expand_im.paste(im, (int(w_off), int(h_off))) expand_im = np.asarray(expand_im) - gt_bbox, gt_class, _ = filter_and_process(expand_bbox, gt_bbox, - gt_class) + if 'gt_keypoint' in sample.keys( + ) and 'keypoint_ignore' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + gt_bbox, gt_class, _, gt_keypoints = filter_and_process( + expand_bbox, gt_bbox, gt_class, keypoints=keypoints) + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] + else: + gt_bbox, gt_class, _ = filter_and_process(expand_bbox, + gt_bbox, gt_class) sample['image'] = expand_im sample['gt_bbox'] = gt_bbox sample['gt_class'] = gt_class @@ -808,7 +951,7 @@ class CropImage(BaseOperator): sample_bbox = sampled_bbox.pop(idx) sample_bbox = clip_bbox(sample_bbox) crop_bbox, crop_class, crop_score = \ - filter_and_process(sample_bbox, gt_bbox, gt_class, gt_score) + filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score) if self.avoid_no_bbox: if len(crop_bbox) < 1: continue @@ -911,8 +1054,16 @@ class CropImageWithDataAchorSampling(BaseOperator): idx = int(np.random.uniform(0, len(sampled_bbox))) sample_bbox = sampled_bbox.pop(idx) - crop_bbox, crop_class, crop_score = filter_and_process( - sample_bbox, gt_bbox, gt_class, gt_score) + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) crop_bbox, crop_class, crop_score = bbox_area_sampling( crop_bbox, crop_class, crop_score, self.target_size, self.min_size) @@ -926,6 +1077,9 @@ class CropImageWithDataAchorSampling(BaseOperator): sample['gt_bbox'] = crop_bbox sample['gt_class'] = crop_class sample['gt_score'] = crop_score + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] return sample return sample @@ -947,8 +1101,16 @@ class CropImageWithDataAchorSampling(BaseOperator): sample_bbox = sampled_bbox.pop(idx) sample_bbox = clip_bbox(sample_bbox) - crop_bbox, crop_class, crop_score = filter_and_process( - sample_bbox, gt_bbox, gt_class, gt_score) + if 'gt_keypoint' in sample.keys(): + keypoints = (sample['gt_keypoint'], + sample['keypoint_ignore']) + crop_bbox, crop_class, crop_score, gt_keypoints = \ + filter_and_process(sample_bbox, gt_bbox, gt_class, + scores=gt_score, + keypoints=keypoints) + else: + crop_bbox, crop_class, crop_score = filter_and_process( + sample_bbox, gt_bbox, gt_class, scores=gt_score) # sampling bbox according the bbox area crop_bbox, crop_class, crop_score = bbox_area_sampling( crop_bbox, crop_class, crop_score, self.target_size, @@ -966,6 +1128,9 @@ class CropImageWithDataAchorSampling(BaseOperator): sample['gt_bbox'] = crop_bbox sample['gt_class'] = crop_class sample['gt_score'] = crop_score + if 'gt_keypoint' in sample.keys(): + sample['gt_keypoint'] = gt_keypoints[0] + sample['keypoint_ignore'] = gt_keypoints[1] return sample return sample @@ -987,6 +1152,17 @@ class NormalizeBox(BaseOperator): gt_bbox[i][2] = gt_bbox[i][2] / width gt_bbox[i][3] = gt_bbox[i][3] / height sample['gt_bbox'] = gt_bbox + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] / height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] / width + sample['gt_keypoint'] = gt_keypoint + return sample @@ -998,7 +1174,6 @@ class Permute(BaseOperator): Args: to_bgr (bool): confirm whether to convert RGB to BGR channel_first (bool): confirm whether to change channel - """ super(Permute, self).__init__() self.to_bgr = to_bgr @@ -1094,6 +1269,84 @@ class MixupImage(BaseOperator): return sample +@register_op +class CutmixImage(BaseOperator): + def __init__(self, alpha=1.5, beta=1.5): + """ + CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://https://arxiv.org/abs/1905.04899 + Cutmix image and gt_bbbox/gt_score + Args: + alpha (float): alpha parameter of beta distribute + beta (float): beta parameter of beta distribute + """ + super(CutmixImage, self).__init__() + self.alpha = alpha + self.beta = beta + if self.alpha <= 0.0: + raise ValueError("alpha shold be positive in {}".format(self)) + if self.beta <= 0.0: + raise ValueError("beta shold be positive in {}".format(self)) + + def _rand_bbox(self, img1, img2, factor): + """ _rand_bbox """ + h = max(img1.shape[0], img2.shape[0]) + w = max(img1.shape[1], img2.shape[1]) + cut_rat = np.sqrt(1. - factor) + + cut_w = np.int(w * cut_rat) + cut_h = np.int(h * cut_rat) + + # uniform + cx = np.random.randint(w) + cy = np.random.randint(h) + + bbx1 = np.clip(cx - cut_w // 2, 0, w) + bby1 = np.clip(cy - cut_h // 2, 0, h) + bbx2 = np.clip(cx + cut_w // 2, 0, w) + bby2 = np.clip(cy + cut_h // 2, 0, h) + + img_1 = np.zeros((h, w, img1.shape[2]), 'float32') + img_1[:img1.shape[0], :img1.shape[1], :] = \ + img1.astype('float32') + img_2 = np.zeros((h, w, img2.shape[2]), 'float32') + img_2[:img2.shape[0], :img2.shape[1], :] = \ + img2.astype('float32') + img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :] + return img_1 + + def __call__(self, sample, context=None): + if 'cutmix' not in sample: + return sample + factor = np.random.beta(self.alpha, self.beta) + factor = max(0.0, min(1.0, factor)) + if factor >= 1.0: + sample.pop('cutmix') + return sample + if factor <= 0.0: + return sample['cutmix'] + img1 = sample['image'] + img2 = sample['cutmix']['image'] + img = self._rand_bbox(img1, img2, factor) + gt_bbox1 = sample['gt_bbox'] + gt_bbox2 = sample['cutmix']['gt_bbox'] + gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0) + gt_class1 = sample['gt_class'] + gt_class2 = sample['cutmix']['gt_class'] + gt_class = np.concatenate((gt_class1, gt_class2), axis=0) + gt_score1 = sample['gt_score'] + gt_score2 = sample['cutmix']['gt_score'] + gt_score = np.concatenate( + (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0) + sample['image'] = img + sample['gt_bbox'] = gt_bbox + sample['gt_score'] = gt_score + sample['gt_class'] = gt_class + sample['h'] = img.shape[0] + sample['w'] = img.shape[1] + sample.pop('cutmix') + return sample + + @register_op class RandomInterpImage(BaseOperator): def __init__(self, target_size=0, max_size=0): @@ -1129,7 +1382,6 @@ class RandomInterpImage(BaseOperator): @register_op class Resize(BaseOperator): """Resize image and bbox. - Args: target_dim (int or list): target size, can be a single number or a list (for random shape). @@ -1162,6 +1414,7 @@ class Resize(BaseOperator): scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0, dim - 1) + sample['scale_factor'] = [scale_x, scale_y] * 2 sample['h'] = resize_h sample['w'] = resize_w @@ -1173,7 +1426,6 @@ class Resize(BaseOperator): @register_op class ColorDistort(BaseOperator): """Random color distortion. - Args: hue (list): hue settings. in [lower, upper, probability] format. @@ -1185,6 +1437,8 @@ class ColorDistort(BaseOperator): in [lower, upper, probability] format. random_apply (bool): whether to apply in random (yolo) or fixed (SSD) order. + hsv_format (bool): whether to convert color from BGR to HSV + random_channel (bool): whether to swap channels randomly """ def __init__(self, @@ -1192,13 +1446,17 @@ class ColorDistort(BaseOperator): saturation=[0.5, 1.5, 0.5], contrast=[0.5, 1.5, 0.5], brightness=[0.5, 1.5, 0.5], - random_apply=True): + random_apply=True, + hsv_format=False, + random_channel=False): super(ColorDistort, self).__init__() self.hue = hue self.saturation = saturation self.contrast = contrast self.brightness = brightness self.random_apply = random_apply + self.hsv_format = hsv_format + self.random_channel = random_channel def apply_hue(self, img): low, high, prob = self.hue @@ -1206,6 +1464,11 @@ class ColorDistort(BaseOperator): return img img = img.astype(np.float32) + if self.hsv_format: + img[..., 0] += random.uniform(low, high) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + return img # XXX works, but result differ from HSV version delta = np.random.uniform(low, high) @@ -1225,8 +1488,10 @@ class ColorDistort(BaseOperator): if np.random.uniform(0., 1.) < prob: return img delta = np.random.uniform(low, high) - img = img.astype(np.float32) + if self.hsv_format: + img[..., 1] *= delta + return img gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32) gray = gray.sum(axis=2, keepdims=True) gray *= (1.0 - delta) @@ -1273,12 +1538,24 @@ class ColorDistort(BaseOperator): if np.random.randint(0, 2): img = self.apply_contrast(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) img = self.apply_saturation(img) img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) else: + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) img = self.apply_saturation(img) img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) img = self.apply_contrast(img) + + if self.random_channel: + if np.random.randint(0, 2): + img = img[..., np.random.permutation(3)] sample['image'] = img return sample @@ -1346,7 +1623,6 @@ class CornerRandColor(ColorDistort): @register_op class NormalizePermute(BaseOperator): """Normalize and permute channel order. - Args: mean (list): mean values in RGB order. std (list): std values in RGB order. @@ -1376,7 +1652,6 @@ class NormalizePermute(BaseOperator): @register_op class RandomExpand(BaseOperator): """Random expand the canvas. - Args: ratio (float): maximum expansion ratio. prob (float): probability to expand. @@ -1468,7 +1743,6 @@ class RandomExpand(BaseOperator): @register_op class RandomCrop(BaseOperator): """Random crop image and bboxes. - Args: aspect_ratio (list): aspect ratio of cropped region. in [min, max] format. @@ -1595,11 +1869,23 @@ class RandomCrop(BaseOperator): found = False for i in range(self.num_attempts): scale = np.random.uniform(*self.scaling) - min_ar, max_ar = self.aspect_ratio - aspect_ratio = np.random.uniform( - max(min_ar, scale**2), min(max_ar, scale**-2)) - crop_h = int(h * scale / np.sqrt(aspect_ratio)) - crop_w = int(w * scale * np.sqrt(aspect_ratio)) + if self.aspect_ratio is not None: + min_ar, max_ar = self.aspect_ratio + aspect_ratio = np.random.uniform( + max(min_ar, scale**2), min(max_ar, scale**-2)) + h_scale = scale / np.sqrt(aspect_ratio) + w_scale = scale * np.sqrt(aspect_ratio) + else: + h_scale = np.random.uniform(*self.scaling) + w_scale = np.random.uniform(*self.scaling) + crop_h = h * h_scale + crop_w = w * w_scale + if self.aspect_ratio is None: + if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0: + continue + + crop_h = int(crop_h) + crop_w = int(crop_w) crop_y = np.random.randint(0, h - crop_h) crop_x = np.random.randint(0, w - crop_w) crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] @@ -1751,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator): return sample -@register_op class Lighting(BaseOperator): """ Lighting the imagen by eigenvalues and eigenvectors @@ -1991,7 +2276,6 @@ class CornerRatio(BaseOperator): class RandomScaledCrop(BaseOperator): """Resize image and bbox based on long side (with optional random scaling), then crop or pad image to target size. - Args: target_dim (int): target size. scale_range (list): random scale range. @@ -2046,7 +2330,6 @@ class RandomScaledCrop(BaseOperator): @register_op class ResizeAndPad(BaseOperator): """Resize image and bbox, then pad image to target size. - Args: target_dim (int): target size interp (int): interpolation method, default to `cv2.INTER_LINEAR`. @@ -2085,7 +2368,6 @@ class ResizeAndPad(BaseOperator): @register_op class TargetAssign(BaseOperator): """Assign regression target and labels. - Args: image_size (int or list): input image size, a single integer or list of [h, w]. Default: 512 @@ -2184,3 +2466,69 @@ class TargetAssign(BaseOperator): targets[matched_indices] = matched_targets sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32) return sample + + +@register_op +class DebugVisibleImage(BaseOperator): + """ + In debug mode, visualize images according to `gt_box`. + (Currently only supported when not cropping and flipping image.) + """ + + def __init__(self, output_dir='output/debug', is_normalized=False): + super(DebugVisibleImage, self).__init__() + self.is_normalized = is_normalized + self.output_dir = output_dir + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + if not isinstance(self.is_normalized, bool): + raise TypeError("{}: input type is invalid.".format(self)) + + def __call__(self, sample, context=None): + image = Image.open(sample['im_file']).convert('RGB') + out_file_name = sample['im_file'].split('/')[-1] + width = sample['w'] + height = sample['h'] + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + draw = ImageDraw.Draw(image) + for i in range(gt_bbox.shape[0]): + if self.is_normalized: + gt_bbox[i][0] = gt_bbox[i][0] * width + gt_bbox[i][1] = gt_bbox[i][1] * height + gt_bbox[i][2] = gt_bbox[i][2] * width + gt_bbox[i][3] = gt_bbox[i][3] * height + + xmin, ymin, xmax, ymax = gt_bbox[i] + draw.line( + [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), + (xmin, ymin)], + width=2, + fill='green') + # draw label + text = str(gt_class[i][0]) + tw, th = draw.textsize(text) + draw.rectangle( + [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green') + draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255)) + + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] + if self.is_normalized: + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] * height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] * width + for i in range(gt_keypoint.shape[0]): + keypoint = gt_keypoint[i] + for j in range(int(keypoint.shape[0] / 2)): + x1 = round(keypoint[2 * j]).astype(np.int32) + y1 = round(keypoint[2 * j + 1]).astype(np.int32) + draw.ellipse( + (x1, y1, x1 + 5, y1i + 5), + fill='green', + outline='green') + save_path = os.path.join(self.output_dir, out_file_name) + image.save(save_path, quality=95) + return sample -- GitLab