diff --git a/dygraph/infer.py b/dygraph/infer.py index 1cc15d319f09e86693eb35006fd6d7efc3f5becc..6eaebfa94c5ba8f0733fc4b3c9474d1fa58bca49 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -107,14 +107,16 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'): pred, _ = model(im, mode='test') pred = pred.numpy() pred = np.squeeze(pred).astype('uint8') - keys = list(im_info.keys()) - for k in keys[::-1]: - if k == 'shape_before_resize': - h, w = im_info[k][0], im_info[k][1] + for info in im_info[::-1]: + if info[0] == 'resize': + h, w = info[1][0], info[1][1] pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) - elif k == 'shape_before_padding': - h, w = im_info[k][0], im_info[k][1] + elif info[0] == 'padding': + h, w = info[1][0], info[1][1] pred = pred[0:h, 0:w] + else: + raise Exception("Unexpected info '{}' in im_info".format( + info[0])) im_file = im_path.replace(test_dataset.data_dir, '') if im_file[0] == '/': diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py index 38c3be18a2ae885bfa6238304a614935401a6330..44746658f6672067427feec7f1c88e5a42fa92b7 100644 --- a/dygraph/transforms/transforms.py +++ b/dygraph/transforms/transforms.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .functional import * import random +from collections import OrderedDict + import numpy as np from PIL import Image import cv2 -from collections import OrderedDict + +from .functional import * class Compose: @@ -33,6 +35,7 @@ class Compose: ValueError: transforms元素个数小于1。 """ + def __init__(self, transforms, to_rgb=True): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') @@ -56,7 +59,7 @@ class Compose: """ if im_info is None: - im_info = dict() + im_info = list() if isinstance(im, str): im = cv2.imread(im).astype('float32') if isinstance(label, str): @@ -86,6 +89,7 @@ class RandomHorizontalFlip: prob (float): 随机水平翻转的概率。默认值为0.5。 """ + def __init__(self, prob=0.5): self.prob = prob @@ -117,6 +121,7 @@ class RandomVerticalFlip: Args: prob (float): 随机垂直翻转的概率。默认值为0.1。 """ + def __init__(self, prob=0.1): self.prob = prob @@ -207,8 +212,8 @@ class Resize: ValueError: 数据长度不匹配。 """ if im_info is None: - im_info = OrderedDict() - im_info['shape_before_resize'] = im.shape[:2] + im_info = list() + im_info.append(('resize', im.shape[:2])) if not isinstance(im, np.ndarray): raise TypeError("Resize: image type is not numpy.") if len(im.shape) != 3: @@ -233,6 +238,7 @@ class ResizeByLong: Args: long_size (int): resize后图像的长边大小。 """ + def __init__(self, long_size): self.long_size = long_size @@ -251,9 +257,9 @@ class ResizeByLong: -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 """ if im_info is None: - im_info = OrderedDict() + im_info = list() - im_info['shape_before_resize'] = im.shape[:2] + im_info.append(('resize', im.shape[:2])) im = resize_long(im, self.long_size) if label is not None: label = resize_long(label, self.long_size, cv2.INTER_NEAREST) @@ -265,7 +271,7 @@ class ResizeByLong: class ResizeRangeScaling: - """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 + """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。一般用于训练 Args: min_value (int): 图像长边resize后的最小值。默认值400。 @@ -274,6 +280,7 @@ class ResizeRangeScaling: Raises: ValueError: min_value大于max_value """ + def __init__(self, min_value=400, max_value=600): if min_value > max_value: raise ValueError('min_value must be less than max_value, ' @@ -311,7 +318,7 @@ class ResizeRangeScaling: class ResizeStepScaling: """对图像按照某一个比例resize,这个比例以scale_step_size为步长 - 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 + 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。一般用于训练 Args: min_scale_factor(float), resize最小尺度。默认值0.75。 @@ -321,6 +328,7 @@ class ResizeStepScaling: Raises: ValueError: min_scale_factor大于max_scale_factor """ + def __init__(self, min_scale_factor=0.75, max_scale_factor=1.25, @@ -386,6 +394,7 @@ class Normalize: Raises: ValueError: mean或std不是list对象。std包含0。 """ + def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): self.mean = mean self.std = std @@ -431,6 +440,7 @@ class Padding: TypeError: target_size不是int|list|tuple。 ValueError: target_size为list|tuple时元素个数不等于2。 """ + def __init__(self, target_size, im_padding_value=[127.5, 127.5, 127.5], @@ -466,8 +476,8 @@ class Padding: ValueError: 输入图像im或label的形状大于目标值 """ if im_info is None: - im_info = OrderedDict() - im_info['shape_before_padding'] = im.shape[:2] + im_info = list() + im_info.append(('padding', im.shape[:2])) im_height, im_width = im.shape[0], im.shape[1] if isinstance(self.target_size, int): @@ -483,21 +493,23 @@ class Padding: 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' .format(im_width, im_height, target_width, target_height)) else: - im = cv2.copyMakeBorder(im, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.im_padding_value) + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) if label is not None: - label = cv2.copyMakeBorder(label, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.label_padding_value) + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) if label is None: return (im, im_info) else: @@ -516,6 +528,7 @@ class RandomPaddingCrop: TypeError: crop_size不是int/list/tuple。 ValueError: target_size为list/tuple时元素个数不等于2。 """ + def __init__(self, crop_size=512, im_padding_value=[127.5, 127.5, 127.5], @@ -564,21 +577,23 @@ class RandomPaddingCrop: pad_height = max(crop_height - img_height, 0) pad_width = max(crop_width - img_width, 0) if (pad_height > 0 or pad_width > 0): - im = cv2.copyMakeBorder(im, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.im_padding_value) + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) if label is not None: - label = cv2.copyMakeBorder(label, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.label_padding_value) + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) img_height = im.shape[0] img_width = im.shape[1] @@ -586,11 +601,11 @@ class RandomPaddingCrop: h_off = np.random.randint(img_height - crop_height + 1) w_off = np.random.randint(img_width - crop_width + 1) - im = im[h_off:(crop_height + h_off), w_off:(w_off + - crop_width), :] + im = im[h_off:(crop_height + h_off), w_off:( + w_off + crop_width), :] if label is not None: - label = label[h_off:(crop_height + - h_off), w_off:(w_off + crop_width)] + label = label[h_off:(crop_height + h_off), w_off:( + w_off + crop_width)] if label is None: return (im, im_info) else: @@ -603,6 +618,7 @@ class RandomBlur: Args: prob (float): 图像模糊概率。默认为0.1。 """ + def __init__(self, prob=0.1): self.prob = prob @@ -650,6 +666,7 @@ class RandomRotation: label_padding_value (int): 标注图像padding的值。默认为255。 """ + def __init__(self, max_rotation=15, im_padding_value=[127.5, 127.5, 127.5], @@ -686,18 +703,20 @@ class RandomRotation: r[0, 2] += (nw / 2) - cx r[1, 2] += (nh / 2) - cy dsize = (nw, nh) - im = cv2.warpAffine(im, - r, - dsize=dsize, - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.im_padding_value) - label = cv2.warpAffine(label, - r, - dsize=dsize, - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.label_padding_value) + im = cv2.warpAffine( + im, + r, + dsize=dsize, + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.im_padding_value) + label = cv2.warpAffine( + label, + r, + dsize=dsize, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.label_padding_value) if label is None: return (im, im_info) @@ -713,6 +732,7 @@ class RandomScaleAspect: min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 """ + def __init__(self, min_scale=0.5, aspect_ratio=0.33): self.min_scale = min_scale self.aspect_ratio = aspect_ratio @@ -751,10 +771,12 @@ class RandomScaleAspect: im = im[h1:(h1 + dh), w1:(w1 + dw), :] label = label[h1:(h1 + dh), w1:(w1 + dw)] - im = cv2.resize(im, (img_width, img_height), - interpolation=cv2.INTER_LINEAR) - label = cv2.resize(label, (img_width, img_height), - interpolation=cv2.INTER_NEAREST) + im = cv2.resize( + im, (img_width, img_height), + interpolation=cv2.INTER_LINEAR) + label = cv2.resize( + label, (img_width, img_height), + interpolation=cv2.INTER_NEAREST) break if label is None: return (im, im_info) @@ -778,6 +800,7 @@ class RandomDistort: hue_range (int): 色调因子的范围。默认为18。 hue_prob (float): 随机调整色调的概率。默认为0.5。 """ + def __init__(self, brightness_range=0.5, brightness_prob=0.5,