提交 5433fc46 编写于 作者: C Channingss

fix some bug

上级 d1b71369
...@@ -125,7 +125,6 @@ def fix_input_shape(info, fixed_input_shape=None): ...@@ -125,7 +125,6 @@ def fix_input_shape(info, fixed_input_shape=None):
logging.warning( logging.warning(
"fixed_input_shape must == input shape when trainning") "fixed_input_shape must == input shape when trainning")
else: else:
print("*" * 10)
resize['ResizeByShort']['short_size'] = min(fixed_input_shape) resize['ResizeByShort']['short_size'] = min(fixed_input_shape)
resize['ResizeByShort']['max_size'] = max(fixed_input_shape) resize['ResizeByShort']['max_size'] = max(fixed_input_shape)
padding['Padding']['target_size'] = list(fixed_input_shape) padding['Padding']['target_size'] = list(fixed_input_shape)
......
...@@ -208,10 +208,10 @@ class Padding: ...@@ -208,10 +208,10 @@ class Padding:
Args: Args:
coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。 coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
target_size (int|list): 填充后的图像长、宽,默认为1 target_size (int|list): 填充后的图像长、宽,默认为None
""" """
def __init__(self, coarsest_stride=1, target_size=1): def __init__(self, coarsest_stride=1, target_size=None):
self.coarsest_stride = coarsest_stride self.coarsest_stride = coarsest_stride
self.target_size = target_size self.target_size = target_size
...@@ -230,11 +230,11 @@ class Padding: ...@@ -230,11 +230,11 @@ class Padding:
Raises: Raises:
TypeError: 形参数据类型不满足需求。 TypeError: 形参数据类型不满足需求。
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
ValueError: coarsest_stride,target_size需有且只有一个被指定,coarset_stride优先级更高。
ValueError: target_size小于原图的大小。 ValueError: target_size小于原图的大小。
""" """
if self.coarsest_stride == 1: if self.coarsest_stride == 1 and self.target_size is None:
if isinstance(self.target_size, int) and self.target_size == 1:
if label_info is None: if label_info is None:
return (im, im_info) return (im, im_info)
else: else:
...@@ -251,13 +251,16 @@ class Padding: ...@@ -251,13 +251,16 @@ class Padding:
np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
padding_im_w = int( padding_im_w = int(
np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride) np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
elif isinstance(self.target_size, int):
if isinstance(self.target_size, int) and self.target_size != 1:
padding_im_h = self.target_size padding_im_h = self.target_size
padding_im_w = self.target_size padding_im_w = self.target_size
elif isinstance(self.target_size, list): elif isinstance(self.target_size, list):
padding_im_w = self.target_size[0] padding_im_w = self.target_size[0]
padding_im_h = self.target_size[1] padding_im_h = self.target_size[1]
else:
raise ValueError(
"coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
)
pad_height = padding_im_h - im_h pad_height = padding_im_h - im_h
pad_width = padding_im_w - im_w pad_width = padding_im_w - im_w
if pad_height < 0 or pad_width < 0: if pad_height < 0 or pad_width < 0:
......
...@@ -287,6 +287,7 @@ class ResizeByLong: ...@@ -287,6 +287,7 @@ class ResizeByLong:
else: else:
return (im, im_info, label) return (im, im_info, label)
class ResizeByShort: class ResizeByShort:
"""根据图像的短边调整图像大小(resize)。 """根据图像的短边调整图像大小(resize)。
...@@ -315,12 +316,12 @@ class ResizeByShort: ...@@ -315,12 +316,12 @@ class ResizeByShort:
if not (isinstance(self.max_size, int)): if not (isinstance(self.max_size, int)):
raise TypeError("max_size: input type is invalid.") raise TypeError("max_size: input type is invalid.")
def __call__(self, im, im_info=None, label_info=None): def __call__(self, im, im_info=None, label=None):
""" """
Args: Args:
im (numnp.ndarraypy): 图像np.ndarray数据。 im (numnp.ndarraypy): 图像np.ndarray数据。
im_info (dict, 可选): 存储与图像相关的信息。 im_info (dict, 可选): 存储与图像相关的信息。
label_info (dict, 可选): 存储与标注框相关的信息 label (np.ndarray): 标注图像np.ndarray数据
Returns: Returns:
tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典; tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
...@@ -335,11 +336,12 @@ class ResizeByShort: ...@@ -335,11 +336,12 @@ class ResizeByShort:
ValueError: 数据长度不匹配。 ValueError: 数据长度不匹配。
""" """
if im_info is None: if im_info is None:
im_info = dict() im_info = OrderedDict()
if not isinstance(im, np.ndarray): if not isinstance(im, np.ndarray):
raise TypeError("ResizeByShort: image type is not numpy.") raise TypeError("ResizeByShort: image type is not numpy.")
if len(im.shape) != 3: if len(im.shape) != 3:
raise ValueError('ResizeByShort: image is not 3-dimensional.') raise ValueError('ResizeByShort: image is not 3-dimensional.')
im_info['shape_before_resize'] = im.shape[:2]
im_short_size = min(im.shape[0], im.shape[1]) im_short_size = min(im.shape[0], im.shape[1])
im_long_size = max(im.shape[0], im.shape[1]) im_long_size = max(im.shape[0], im.shape[1])
scale = float(self.short_size) / im_short_size scale = float(self.short_size) / im_short_size
...@@ -348,15 +350,18 @@ class ResizeByShort: ...@@ -348,15 +350,18 @@ class ResizeByShort:
scale = float(self.max_size) / float(im_long_size) scale = float(self.max_size) / float(im_long_size)
resized_width = int(round(im.shape[1] * scale)) resized_width = int(round(im.shape[1] * scale))
resized_height = int(round(im.shape[0] * scale)) resized_height = int(round(im.shape[0] * scale))
im_resize_info = [resized_height, resized_width, scale]
im = cv2.resize( im = cv2.resize(
im, (resized_width, resized_height), im, (resized_width, resized_height),
interpolation=cv2.INTER_LINEAR) interpolation=cv2.INTER_NEAREST)
im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32) if label is not None:
if label_info is None: im = cv2.resize(
label, (resized_width, resized_height),
interpolation=cv2.INTER_NEAREST)
if label is None:
return (im, im_info) return (im, im_info)
else: else:
return (im, im_info, label_info) return (im, im_info, label)
class ResizeRangeScaling: class ResizeRangeScaling:
"""对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。 """对图像长边随机resize到指定范围内,短边按比例进行缩放。当存在标注图像时,则同步进行处理。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册