提交 a70e2a55 编写于 作者: C chenguowei01

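Update im_info format: store it as an ordered list of (operation, shape) tuples, e.g. ('resize', (h, w)) or ('padding', (h, w)), instead of a dict keyed by 'shape_before_resize' / 'shape_before_padding'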

上级 3b571aec
......@@ -107,14 +107,16 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
pred, _ = model(im, mode='test')
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/':
......
......@@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .functional import *
import random
from collections import OrderedDict
import numpy as np
from PIL import Image
import cv2
from collections import OrderedDict
from .functional import *
class Compose:
......@@ -33,6 +35,7 @@ class Compose:
ValueError: transforms元素个数小于1。
"""
def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!')
......@@ -56,7 +59,7 @@ class Compose:
"""
if im_info is None:
im_info = dict()
im_info = list()
if isinstance(im, str):
im = cv2.imread(im).astype('float32')
if isinstance(label, str):
......@@ -86,6 +89,7 @@ class RandomHorizontalFlip:
prob (float): 随机水平翻转的概率。默认值为0.5。
"""
def __init__(self, prob=0.5):
self.prob = prob
......@@ -117,6 +121,7 @@ class RandomVerticalFlip:
Args:
prob (float): 随机垂直翻转的概率。默认值为0.1。
"""
def __init__(self, prob=0.1):
self.prob = prob
......@@ -207,8 +212,8 @@ class Resize:
ValueError: 数据长度不匹配。
"""
if im_info is None:
im_info = OrderedDict()
im_info['shape_before_resize'] = im.shape[:2]
im_info = list()
im_info.append(('resize', im.shape[:2]))
if not isinstance(im, np.ndarray):
raise TypeError("Resize: image type is not numpy.")
if len(im.shape) != 3:
......@@ -233,6 +238,7 @@ class ResizeByLong:
Args:
long_size (int): resize后图像的长边大小。
"""
def __init__(self, long_size):
self.long_size = long_size
......@@ -251,9 +257,9 @@ class ResizeByLong:
-shape_before_resize (tuple): 保存resize之前图像的形状(h, w)。
"""
if im_info is None:
im_info = OrderedDict()
im_info = list()
im_info['shape_before_resize'] = im.shape[:2]
im_info.append(('resize', im.shape[:2]))
im = resize_long(im, self.long_size)
if label is not None:
label = resize_long(label, self.long_size, cv2.INTER_NEAREST)
......@@ -265,7 +271,7 @@ class ResizeByLong:
class ResizeRangeScaling:
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。一般用于训练
Args:
min_value (int): 图像长边resize后的最小值。默认值400。
......@@ -274,6 +280,7 @@ class ResizeRangeScaling:
Raises:
ValueError: min_value大于max_value
"""
def __init__(self, min_value=400, max_value=600):
if min_value > max_value:
raise ValueError('min_value must be less than max_value, '
......@@ -311,7 +318,7 @@ class ResizeRangeScaling:
class ResizeStepScaling:
"""对图像按照某一个比例resize,这个比例以scale_step_size为步长
在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。
在[min_scale_factor, max_scale_factor]随机变动。当存在标注图像时,则同步进行处理。一般用于训练
Args:
min_scale_factor(float), resize最小尺度。默认值0.75。
......@@ -321,6 +328,7 @@ class ResizeStepScaling:
Raises:
ValueError: min_scale_factor大于max_scale_factor
"""
def __init__(self,
min_scale_factor=0.75,
max_scale_factor=1.25,
......@@ -386,6 +394,7 @@ class Normalize:
Raises:
ValueError: mean或std不是list对象。std包含0。
"""
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean
self.std = std
......@@ -431,6 +440,7 @@ class Padding:
TypeError: target_size不是int|list|tuple。
ValueError: target_size为list|tuple时元素个数不等于2。
"""
def __init__(self,
target_size,
im_padding_value=[127.5, 127.5, 127.5],
......@@ -466,8 +476,8 @@ class Padding:
ValueError: 输入图像im或label的形状大于目标值
"""
if im_info is None:
im_info = OrderedDict()
im_info['shape_before_padding'] = im.shape[:2]
im_info = list()
im_info.append(('padding', im.shape[:2]))
im_height, im_width = im.shape[0], im.shape[1]
if isinstance(self.target_size, int):
......@@ -483,21 +493,23 @@ class Padding:
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_width, im_height, target_width, target_height))
else:
im = cv2.copyMakeBorder(im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
im = cv2.copyMakeBorder(
im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if label is not None:
label = cv2.copyMakeBorder(label,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.label_padding_value)
label = cv2.copyMakeBorder(
label,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.label_padding_value)
if label is None:
return (im, im_info)
else:
......@@ -516,6 +528,7 @@ class RandomPaddingCrop:
TypeError: crop_size不是int/list/tuple。
ValueError: target_size为list/tuple时元素个数不等于2。
"""
def __init__(self,
crop_size=512,
im_padding_value=[127.5, 127.5, 127.5],
......@@ -564,21 +577,23 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
im = cv2.copyMakeBorder(
im,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.im_padding_value)
if label is not None:
label = cv2.copyMakeBorder(label,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.label_padding_value)
label = cv2.copyMakeBorder(
label,
0,
pad_height,
0,
pad_width,
cv2.BORDER_CONSTANT,
value=self.label_padding_value)
img_height = im.shape[0]
img_width = im.shape[1]
......@@ -586,11 +601,11 @@ class RandomPaddingCrop:
h_off = np.random.randint(img_height - crop_height + 1)
w_off = np.random.randint(img_width - crop_width + 1)
im = im[h_off:(crop_height + h_off), w_off:(w_off +
crop_width), :]
im = im[h_off:(crop_height + h_off), w_off:(
w_off + crop_width), :]
if label is not None:
label = label[h_off:(crop_height +
h_off), w_off:(w_off + crop_width)]
label = label[h_off:(crop_height + h_off), w_off:(
w_off + crop_width)]
if label is None:
return (im, im_info)
else:
......@@ -603,6 +618,7 @@ class RandomBlur:
Args:
prob (float): 图像模糊概率。默认为0.1。
"""
def __init__(self, prob=0.1):
self.prob = prob
......@@ -650,6 +666,7 @@ class RandomRotation:
label_padding_value (int): 标注图像padding的值。默认为255。
"""
def __init__(self,
max_rotation=15,
im_padding_value=[127.5, 127.5, 127.5],
......@@ -686,18 +703,20 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy
dsize = (nw, nh)
im = cv2.warpAffine(im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
label = cv2.warpAffine(label,
r,
dsize=dsize,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.label_padding_value)
im = cv2.warpAffine(
im,
r,
dsize=dsize,
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value)
label = cv2.warpAffine(
label,
r,
dsize=dsize,
flags=cv2.INTER_NEAREST,
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.label_padding_value)
if label is None:
return (im, im_info)
......@@ -713,6 +732,7 @@ class RandomScaleAspect:
min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
"""
def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale
self.aspect_ratio = aspect_ratio
......@@ -751,10 +771,12 @@ class RandomScaleAspect:
im = im[h1:(h1 + dh), w1:(w1 + dw), :]
label = label[h1:(h1 + dh), w1:(w1 + dw)]
im = cv2.resize(im, (img_width, img_height),
interpolation=cv2.INTER_LINEAR)
label = cv2.resize(label, (img_width, img_height),
interpolation=cv2.INTER_NEAREST)
im = cv2.resize(
im, (img_width, img_height),
interpolation=cv2.INTER_LINEAR)
label = cv2.resize(
label, (img_width, img_height),
interpolation=cv2.INTER_NEAREST)
break
if label is None:
return (im, im_info)
......@@ -778,6 +800,7 @@ class RandomDistort:
hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。
"""
def __init__(self,
brightness_range=0.5,
brightness_prob=0.5,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册