提交 70f3fab4 编写于 作者: C chenguowei01

Merge branch 'dygraph' of https://github.com/wuyefeilin/PaddleSeg into dygraph

...@@ -37,12 +37,8 @@ def parse_args(): ...@@ -37,12 +37,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for testing, which is one of {}'.format(
'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
......
...@@ -18,7 +18,8 @@ import paddle ...@@ -18,7 +18,8 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
__all__ = [ __all__ = [
"HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30",
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D from paddle.fluid.dygraph import Conv2D, Pool2D
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
class UNet(fluid.dygraph.Layer): class UNet(fluid.dygraph.Layer):
......
...@@ -38,12 +38,8 @@ def parse_args(): ...@@ -38,12 +38,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for training, which is one of {}'.format(
'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
...@@ -186,6 +182,7 @@ def train(model, ...@@ -186,6 +182,7 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch) total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0 num_steps = 0
best_mean_iou = -1.0 best_mean_iou = -1.0
best_model_epoch = -1
for epoch in range(start_epoch, num_epochs): for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader): for step, data in enumerate(loader):
images = data[0] images = data[0]
......
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
...@@ -33,6 +34,7 @@ class Compose: ...@@ -33,6 +34,7 @@ class Compose:
ValueError: transforms元素个数小于1。 ValueError: transforms元素个数小于1。
""" """
def __init__(self, transforms, to_rgb=True): def __init__(self, transforms, to_rgb=True):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
...@@ -86,6 +88,7 @@ class RandomHorizontalFlip: ...@@ -86,6 +88,7 @@ class RandomHorizontalFlip:
prob (float): 随机水平翻转的概率。默认值为0.5。 prob (float): 随机水平翻转的概率。默认值为0.5。
""" """
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -117,6 +120,7 @@ class RandomVerticalFlip: ...@@ -117,6 +120,7 @@ class RandomVerticalFlip:
Args: Args:
prob (float): 随机垂直翻转的概率。默认值为0.1。 prob (float): 随机垂直翻转的概率。默认值为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -233,6 +237,7 @@ class ResizeByLong: ...@@ -233,6 +237,7 @@ class ResizeByLong:
Args: Args:
long_size (int): resize后图像的长边大小。 long_size (int): resize后图像的长边大小。
""" """
def __init__(self, long_size): def __init__(self, long_size):
self.long_size = long_size self.long_size = long_size
...@@ -274,6 +279,7 @@ class ResizeRangeScaling: ...@@ -274,6 +279,7 @@ class ResizeRangeScaling:
Raises: Raises:
ValueError: min_value大于max_value ValueError: min_value大于max_value
""" """
def __init__(self, min_value=400, max_value=600): def __init__(self, min_value=400, max_value=600):
if min_value > max_value: if min_value > max_value:
raise ValueError('min_value must be less than max_value, ' raise ValueError('min_value must be less than max_value, '
...@@ -321,6 +327,7 @@ class ResizeStepScaling: ...@@ -321,6 +327,7 @@ class ResizeStepScaling:
Raises: Raises:
ValueError: min_scale_factor大于max_scale_factor ValueError: min_scale_factor大于max_scale_factor
""" """
def __init__(self, def __init__(self,
min_scale_factor=0.75, min_scale_factor=0.75,
max_scale_factor=1.25, max_scale_factor=1.25,
...@@ -386,6 +393,7 @@ class Normalize: ...@@ -386,6 +393,7 @@ class Normalize:
Raises: Raises:
ValueError: mean或std不是list对象。std包含0。 ValueError: mean或std不是list对象。std包含0。
""" """
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean self.mean = mean
self.std = std self.std = std
...@@ -431,6 +439,7 @@ class Padding: ...@@ -431,6 +439,7 @@ class Padding:
TypeError: target_size不是int|list|tuple。 TypeError: target_size不是int|list|tuple。
ValueError: target_size为list|tuple时元素个数不等于2。 ValueError: target_size为list|tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
target_size, target_size,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -483,7 +492,8 @@ class Padding: ...@@ -483,7 +492,8 @@ class Padding:
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_width, im_height, target_width, target_height)) .format(im_width, im_height, target_width, target_height))
else: else:
im = cv2.copyMakeBorder(im, im = cv2.copyMakeBorder(
im,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -491,7 +501,8 @@ class Padding: ...@@ -491,7 +501,8 @@ class Padding:
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=self.im_padding_value) value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder(label, label = cv2.copyMakeBorder(
label,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -516,6 +527,7 @@ class RandomPaddingCrop: ...@@ -516,6 +527,7 @@ class RandomPaddingCrop:
TypeError: crop_size不是int/list/tuple。 TypeError: crop_size不是int/list/tuple。
ValueError: target_size为list/tuple时元素个数不等于2。 ValueError: target_size为list/tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
crop_size=512, crop_size=512,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -564,7 +576,8 @@ class RandomPaddingCrop: ...@@ -564,7 +576,8 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder(im, im = cv2.copyMakeBorder(
im,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -572,7 +585,8 @@ class RandomPaddingCrop: ...@@ -572,7 +585,8 @@ class RandomPaddingCrop:
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=self.im_padding_value) value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder(label, label = cv2.copyMakeBorder(
label,
0, 0,
pad_height, pad_height,
0, 0,
...@@ -586,11 +600,11 @@ class RandomPaddingCrop: ...@@ -586,11 +600,11 @@ class RandomPaddingCrop:
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
w_off = np.random.randint(img_width - crop_width + 1) w_off = np.random.randint(img_width - crop_width + 1)
im = im[h_off:(crop_height + h_off), w_off:(w_off + im = im[h_off:(crop_height + h_off), w_off:(
crop_width), :] w_off + crop_width), :]
if label is not None: if label is not None:
label = label[h_off:(crop_height + label = label[h_off:(crop_height + h_off), w_off:(
h_off), w_off:(w_off + crop_width)] w_off + crop_width)]
if label is None: if label is None:
return (im, im_info) return (im, im_info)
else: else:
...@@ -603,6 +617,7 @@ class RandomBlur: ...@@ -603,6 +617,7 @@ class RandomBlur:
Args: Args:
prob (float): 图像模糊概率。默认为0.1。 prob (float): 图像模糊概率。默认为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -650,6 +665,7 @@ class RandomRotation: ...@@ -650,6 +665,7 @@ class RandomRotation:
label_padding_value (int): 标注图像padding的值。默认为255。 label_padding_value (int): 标注图像padding的值。默认为255。
""" """
def __init__(self, def __init__(self,
max_rotation=15, max_rotation=15,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -686,13 +702,15 @@ class RandomRotation: ...@@ -686,13 +702,15 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine(im, im = cv2.warpAffine(
im,
r, r,
dsize=dsize, dsize=dsize,
flags=cv2.INTER_LINEAR, flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT, borderMode=cv2.BORDER_CONSTANT,
borderValue=self.im_padding_value) borderValue=self.im_padding_value)
label = cv2.warpAffine(label, label = cv2.warpAffine(
label,
r, r,
dsize=dsize, dsize=dsize,
flags=cv2.INTER_NEAREST, flags=cv2.INTER_NEAREST,
...@@ -713,6 +731,7 @@ class RandomScaleAspect: ...@@ -713,6 +731,7 @@ class RandomScaleAspect:
min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
""" """
def __init__(self, min_scale=0.5, aspect_ratio=0.33): def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale self.min_scale = min_scale
self.aspect_ratio = aspect_ratio self.aspect_ratio = aspect_ratio
...@@ -751,9 +770,11 @@ class RandomScaleAspect: ...@@ -751,9 +770,11 @@ class RandomScaleAspect:
im = im[h1:(h1 + dh), w1:(w1 + dw), :] im = im[h1:(h1 + dh), w1:(w1 + dw), :]
label = label[h1:(h1 + dh), w1:(w1 + dw)] label = label[h1:(h1 + dh), w1:(w1 + dw)]
im = cv2.resize(im, (img_width, img_height), im = cv2.resize(
im, (img_width, img_height),
interpolation=cv2.INTER_LINEAR) interpolation=cv2.INTER_LINEAR)
label = cv2.resize(label, (img_width, img_height), label = cv2.resize(
label, (img_width, img_height),
interpolation=cv2.INTER_NEAREST) interpolation=cv2.INTER_NEAREST)
break break
if label is None: if label is None:
...@@ -778,6 +799,7 @@ class RandomDistort: ...@@ -778,6 +799,7 @@ class RandomDistort:
hue_range (int): 色调因子的范围。默认为18。 hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。 hue_prob (float): 随机调整色调的概率。默认为0.5。
""" """
def __init__(self, def __init__(self,
brightness_range=0.5, brightness_range=0.5,
brightness_prob=0.5, brightness_prob=0.5,
......
...@@ -39,12 +39,8 @@ def parse_args(): ...@@ -39,12 +39,8 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--model_name', '--model_name',
dest='model_name', dest='model_name',
help= help='Model type for evaluation, which is one of {}'.format(
'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' str(list(MODELS.keys()))),
'"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", '
'"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", '
'"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", '
'"SE_HRNet_W60", "SE_HRNet_W64")',
type=str, type=str,
default='UNet') default='UNet')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册