提交 a70e2a55 编写于 作者: C chenguowei01

update im_info format

上级 3b571aec
...@@ -107,14 +107,16 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'): ...@@ -107,14 +107,16 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
pred, _ = model(im, mode='test') pred, _ = model(im, mode='test')
pred = pred.numpy() pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8') pred = np.squeeze(pred).astype('uint8')
keys = list(im_info.keys()) for info in im_info[::-1]:
for k in keys[::-1]: if info[0] == 'resize':
if k == 'shape_before_resize': h, w = info[1][0], info[1][1]
h, w = im_info[k][0], im_info[k][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST) pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif k == 'shape_before_padding': elif info[0] == 'padding':
h, w = im_info[k][0], im_info[k][1] h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w] pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.data_dir, '') im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/': if im_file[0] == '/':
......
...@@ -12,12 +12,14 @@ ...@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .functional import *
import random import random
from collections import OrderedDict
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import cv2 import cv2
from collections import OrderedDict
from .functional import *
class Compose: class Compose:
...@@ -33,6 +35,7 @@ class Compose: ...@@ -33,6 +35,7 @@ class Compose:
ValueError: transforms元素个数小于1。 ValueError: transforms元素个数小于1。
""" """
def __init__(self, transforms, to_rgb=True): def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
...@@ -56,7 +59,7 @@ class Compose: ...@@ -56,7 +59,7 @@ class Compose:
""" """
if im_info is None: if im_info is None:
im_info = dict() im_info = list()
if isinstance(im, str): if isinstance(im, str):
im = cv2.imread(im).astype('float32') im = cv2.imread(im).astype('float32')
if isinstance(label, str): if isinstance(label, str):
...@@ -86,6 +89,7 @@ class RandomHorizontalFlip: ...@@ -86,6 +89,7 @@ class RandomHorizontalFlip:
prob (float): 随机水平翻转的概率。默认值为0.5。 prob (float): 随机水平翻转的概率。默认值为0.5。
""" """
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -117,6 +121,7 @@ class RandomVerticalFlip: ...@@ -117,6 +121,7 @@ class RandomVerticalFlip:
Args: Args:
prob (float): 随机垂直翻转的概率。默认值为0.1。 prob (float): 随机垂直翻转的概率。默认值为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -207,8 +212,8 @@ class Resize: ...@@ -207,8 +212,8 @@ class Resize:
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
""" """
if im_info is None: if im_info is None:
im_info = OrderedDict() im_info = list()
im_info['shape_before_resize'] = im.shape[:2] im_info.append(('resize', im.shape[:2]))
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
raise TypeError("Resize: image type is not numpy.") raise TypeError("Resize: image type is not numpy.")
if len(im.shape) != 3: if len(im.shape) != 3:
...@@ -233,6 +238,7 @@ class ResizeByLong: ...@@ -233,6 +238,7 @@ class ResizeByLong:
Args: Args:
long_size (int): resize后图像的长边大小。 long_size (int): resize后图像的长边大小。
""" """
def __init__(self, long_size): def __init__(self, long_size):
self.long_size = long_size self.long_size = long_size
...@@ -251,9 +257,9 @@ class ResizeByLong: ...@@ -251,9 +257,9 @@ class ResizeByLong:
-shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。 -shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
""" """
if im_info is None: if im_info is None:
im_info = OrderedDict() im_info = list()
im_info['shape_before_resize'] = im.shape[:2] im_info.append(('resize', im.shape[:2]))
im = resize_long(im, self.long_size) im = resize_long(im, self.long_size)
if label is not None: if label is not None:
label = resize_long(label, self.long_size, cv2.INTER_NEAREST) label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
...@@ -265,7 +271,7 @@ class ResizeByLong: ...@@ -265,7 +271,7 @@ class ResizeByLong:
class ResizeRangeScaling: class ResizeRangeScaling:
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。一般用于训练
Args: Args:
min_value (int): 图像长边resize后的最小值。默认值400。 min_value (int): 图像长边resize后的最小值。默认值400。
...@@ -274,6 +280,7 @@ class ResizeRangeScaling: ...@@ -274,6 +280,7 @@ class ResizeRangeScaling:
Raises: Raises:
ValueError: min_value大于max_value ValueError: min_value大于max_value
""" """
def __init__(self, min_value=400, max_value=600): def __init__(self, min_value=400, max_value=600):
if min_value > max_value: if min_value > max_value:
raise ValueError('min_value must be less than max_value, ' raise ValueError('min_value must be less than max_value, '
...@@ -311,7 +318,7 @@ class ResizeRangeScaling: ...@@ -311,7 +318,7 @@ class ResizeRangeScaling:
class ResizeStepScaling: class ResizeStepScaling:
"""对图像按照某一个比例resize,这个比例以scale_step_size为步长 """对图像按照某一个比例resize,这个比例以scale_step_size为步长
在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。 在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。一般用于训练
Args: Args:
min_scale_factor(float), resize最小尺度。默认值0.75。 min_scale_factor(float), resize最小尺度。默认值0.75。
...@@ -321,6 +328,7 @@ class ResizeStepScaling: ...@@ -321,6 +328,7 @@ class ResizeStepScaling:
Raises: Raises:
ValueError: min_scale_factor大于max_scale_factor ValueError: min_scale_factor大于max_scale_factor
""" """
def __init__(self, def __init__(self,
min_scale_factor=0.75, min_scale_factor=0.75,
max_scale_factor=1.25, max_scale_factor=1.25,
...@@ -386,6 +394,7 @@ class Normalize: ...@@ -386,6 +394,7 @@ class Normalize:
Raises: Raises:
ValueError: mean或std不是list对象。std包含0。 ValueError: mean或std不是list对象。std包含0。
""" """
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean self.mean = mean
self.std = std self.std = std
...@@ -431,6 +440,7 @@ class Padding: ...@@ -431,6 +440,7 @@ class Padding:
TypeError: target_size不是int|list|tuple。 TypeError: target_size不是int|list|tuple。
ValueError: target_size为list|tuple时元素个数不等于2。 ValueError: target_size为list|tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
target_size, target_size,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -466,8 +476,8 @@ class Padding: ...@@ -466,8 +476,8 @@ class Padding:
ValueError: 输入图像im或label的形状大于目标值 ValueError: 输入图像im或label的形状大于目标值
""" """
if im_info is None: if im_info is None:
im_info = OrderedDict() im_info = list()
im_info['shape_before_padding'] = im.shape[:2] im_info.append(('padding', im.shape[:2]))
im_height, im_width = im.shape[0], im.shape[1] im_height, im_width = im.shape[0], im.shape[1]
if isinstance(self.target_size, int): if isinstance(self.target_size, int):
...@@ -483,7 +493,8 @@ class Padding: ...@@ -483,7 +493,8 @@ class Padding:
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_width, im_height, target_width, target_height)) .format(im_width, im_height, target_width, target_height))
else: else:
im = cv2.copyMakeBorder(im, im = cv2.copyMakeBorder(
im,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -491,7 +502,8 @@ class Padding: ...@@ -491,7 +502,8 @@ class Padding:
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=self.im_padding_value) value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder(label, label = cv2.copyMakeBorder(
label,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -516,6 +528,7 @@ class RandomPaddingCrop: ...@@ -516,6 +528,7 @@ class RandomPaddingCrop:
TypeError: crop_size不是int/list/tuple。 TypeError: crop_size不是int/list/tuple。
ValueError: target_size为list/tuple时元素个数不等于2。 ValueError: target_size为list/tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
crop_size=512, crop_size=512,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -564,7 +577,8 @@ class RandomPaddingCrop: ...@@ -564,7 +577,8 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(im, im = cv2.copyMakeBorder(
im,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -572,7 +586,8 @@ class RandomPaddingCrop: ...@@ -572,7 +586,8 @@ class RandomPaddingCrop:
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=self.im_padding_value) value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder(label, label = cv2.copyMakeBorder(
label,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -586,11 +601,11 @@ class RandomPaddingCrop: ...@@ -586,11 +601,11 @@ class RandomPaddingCrop:
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
w_off = np.random.randint(img_width - crop_width + 1) w_off = np.random.randint(img_width - crop_width + 1)
im = im[h_off:(crop_height + h_off), w_off:(w_off + im = im[h_off:(crop_height + h_off), w_off:(
crop_width), :] w_off + crop_width), :]
if label is not None: if label is not None:
label = label[h_off:(crop_height + label = label[h_off:(crop_height + h_off), w_off:(
h_off), w_off:(w_off + crop_width)] w_off + crop_width)]
if label is None: if label is None:
return (im, im_info) return (im, im_info)
else: else:
...@@ -603,6 +618,7 @@ class RandomBlur: ...@@ -603,6 +618,7 @@ class RandomBlur:
Args: Args:
prob (float): 图像模糊概率。默认为0.1。 prob (float): 图像模糊概率。默认为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -650,6 +666,7 @@ class RandomRotation: ...@@ -650,6 +666,7 @@ class RandomRotation:
label_padding_value (int): 标注图像padding的值。默认为255。 label_padding_value (int): 标注图像padding的值。默认为255。
""" """
def __init__(self, def __init__(self,
max_rotation=15, max_rotation=15,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -686,13 +703,15 @@ class RandomRotation: ...@@ -686,13 +703,15 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine(im, im = cv2.warpAffine(
im,
r, r,
dsize=dsize, dsize=dsize,
flags=cv2.INTER_LINEAR, flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT, borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value) borderValue=self.im_padding_value)
label = cv2.warpAffine(label, label = cv2.warpAffine(
label,
r, r,
dsize=dsize, dsize=dsize,
flags=cv2.INTER_NEAREST, flags=cv2.INTER_NEAREST,
...@@ -713,6 +732,7 @@ class RandomScaleAspect: ...@@ -713,6 +732,7 @@ class RandomScaleAspect:
min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
""" """
def __init__(self, min_scale=0.5, aspect_ratio=0.33): def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale self.min_scale = min_scale
self.aspect_ratio = aspect_ratio self.aspect_ratio = aspect_ratio
...@@ -751,9 +771,11 @@ class RandomScaleAspect: ...@@ -751,9 +771,11 @@ class RandomScaleAspect:
im = im[h1:(h1 + dh), w1:(w1 + dw), :] im = im[h1:(h1 + dh), w1:(w1 + dw), :]
label = label[h1:(h1 + dh), w1:(w1 + dw)] label = label[h1:(h1 + dh), w1:(w1 + dw)]
im = cv2.resize(im, (img_width, img_height), im = cv2.resize(
im, (img_width, img_height),
interpolation=cv2.INTER_LINEAR) interpolation=cv2.INTER_LINEAR)
label = cv2.resize(label, (img_width, img_height), label = cv2.resize(
label, (img_width, img_height),
interpolation=cv2.INTER_NEAREST) interpolation=cv2.INTER_NEAREST)
break break
if label is None: if label is None:
...@@ -778,6 +800,7 @@ class RandomDistort: ...@@ -778,6 +800,7 @@ class RandomDistort:
hue_range (int): 色调因子的范围。默认为18。 hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。 hue_prob (float): 随机调整色调的概率。默认为0.5。
""" """
def __init__(self, def __init__(self,
brightness_range=0.5, brightness_range=0.5,
brightness_prob=0.5, brightness_prob=0.5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册