diff --git a/dygraph/infer.py b/dygraph/infer.py index 1cc15d319f09e86693eb35006fd6d7efc3f5becc..f5caf7a435d3083f7d84106024096684a9d4f3b8 100644 --- a/dygraph/infer.py +++ b/dygraph/infer.py @@ -37,12 +37,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for testing, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for testing, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet') diff --git a/dygraph/models/hrnet.py b/dygraph/models/hrnet.py index fac8a929be40acce2d801c3cdbbe89bb634bead3..bccc303bb435e48554991a21b4fd72dd90a3cb37 100644 --- a/dygraph/models/hrnet.py +++ b/dygraph/models/hrnet.py @@ -18,7 +18,8 @@ import paddle import paddle.fluid as fluid from paddle.fluid.param_attr import ParamAttr from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm __all__ = [ "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", "HRNet_W18", "HRNet_W30", diff --git a/dygraph/models/unet.py b/dygraph/models/unet.py index b55e3614b6988a0102eb3e6f17093e59673eae70..6e04c8b2f17aeca763dc9653b6b2da73835979c7 100644 --- a/dygraph/models/unet.py +++ b/dygraph/models/unet.py @@ -13,7 +13,8 @@ # limitations under the License. import paddle.fluid as fluid -from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D +from paddle.fluid.dygraph import Conv2D, Pool2D +from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm class UNet(fluid.dygraph.Layer): diff --git a/dygraph/train.py b/dygraph/train.py index 8573591e25f2964610bd3da33b224c52bcfe1da9..709a66bb8c7f55dac0a83e5435a42893eb4d2e9a 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -38,12 +38,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for training, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for training, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet') @@ -186,6 +182,7 @@ def train(model, total_steps = steps_per_epoch * (num_epochs - start_epoch) num_steps = 0 best_mean_iou = -1.0 + best_model_epoch = -1 for epoch in range(start_epoch, num_epochs): for step, data in enumerate(loader): images = data[0] @@ -245,9 +242,9 @@ def train(model, best_model_dir = os.path.join(save_dir, "best_model") fluid.save_dygraph(model.state_dict(), os.path.join(best_model_dir, 'model')) - logging.info( - 'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}' - .format(best_model_epoch, best_mean_iou)) + logging.info( + 'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}' + .format(best_model_epoch, best_mean_iou)) if use_vdl: log_writer.add_scalar('Evaluate/mean_iou', mean_iou, diff --git a/dygraph/transforms/transforms.py b/dygraph/transforms/transforms.py index 38c3be18a2ae885bfa6238304a614935401a6330..f2b24fbad48b53930d4ba1b16a9a08ee6ae3c10b 100644 --- a/dygraph/transforms/transforms.py +++ b/dygraph/transforms/transforms.py @@ -1,3 +1,4 @@ +# coding: utf8 # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -33,6 +34,7 @@ class Compose: ValueError: transforms元素个数小于1。 """ + def __init__(self, transforms, to_rgb=True): if not isinstance(transforms, list): raise TypeError('The transforms must be a list!') @@ -86,6 +88,7 @@ class RandomHorizontalFlip: prob (float): 随机水平翻转的概率。默认值为0.5。 """ + def __init__(self, prob=0.5): self.prob = prob @@ -117,6 +120,7 @@ class RandomVerticalFlip: Args: prob (float): 随机垂直翻转的概率。默认值为0.1。 """ + def __init__(self, prob=0.1): self.prob = prob @@ -233,6 +237,7 @@ class ResizeByLong: Args: long_size (int): resize后图像的长边大小。 """ + def __init__(self, long_size): self.long_size = long_size @@ -274,6 +279,7 @@ class ResizeRangeScaling: Raises: ValueError: min_value大于max_value """ + def __init__(self, min_value=400, max_value=600): if min_value > max_value: raise ValueError('min_value must be less than max_value, ' @@ -321,6 +327,7 @@ class ResizeStepScaling: Raises: ValueError: min_scale_factor大于max_scale_factor """ + def __init__(self, min_scale_factor=0.75, max_scale_factor=1.25, @@ -386,6 +393,7 @@ class Normalize: Raises: ValueError: mean或std不是list对象。std包含0。 """ + def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): self.mean = mean self.std = std @@ -431,6 +439,7 @@ class Padding: TypeError: target_size不是int|list|tuple。 ValueError: target_size为list|tuple时元素个数不等于2。 """ + def __init__(self, target_size, im_padding_value=[127.5, 127.5, 127.5], @@ -483,21 +492,23 @@ class Padding: 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' .format(im_width, im_height, target_width, target_height)) else: - im = cv2.copyMakeBorder(im, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.im_padding_value) + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) if label is not None: - label = cv2.copyMakeBorder(label, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.label_padding_value) + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) if label is None: return (im, im_info) else: @@ -516,6 +527,7 @@ class RandomPaddingCrop: TypeError: crop_size不是int/list/tuple。 ValueError: target_size为list/tuple时元素个数不等于2。 """ + def __init__(self, crop_size=512, im_padding_value=[127.5, 127.5, 127.5], @@ -564,21 +576,23 @@ class RandomPaddingCrop: pad_height = max(crop_height - img_height, 0) pad_width = max(crop_width - img_width, 0) if (pad_height > 0 or pad_width > 0): - im = cv2.copyMakeBorder(im, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.im_padding_value) + im = cv2.copyMakeBorder( + im, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.im_padding_value) if label is not None: - label = cv2.copyMakeBorder(label, - 0, - pad_height, - 0, - pad_width, - cv2.BORDER_CONSTANT, - value=self.label_padding_value) + label = cv2.copyMakeBorder( + label, + 0, + pad_height, + 0, + pad_width, + cv2.BORDER_CONSTANT, + value=self.label_padding_value) img_height = im.shape[0] img_width = im.shape[1] @@ -586,11 +600,11 @@ class RandomPaddingCrop: h_off = np.random.randint(img_height - crop_height + 1) w_off = np.random.randint(img_width - crop_width + 1) - im = im[h_off:(crop_height + h_off), w_off:(w_off + - crop_width), :] + im = im[h_off:(crop_height + h_off), w_off:( + w_off + crop_width), :] if label is not None: - label = label[h_off:(crop_height + - h_off), w_off:(w_off + crop_width)] + label = label[h_off:(crop_height + h_off), w_off:( + w_off + crop_width)] if label is None: return (im, im_info) else: @@ -603,6 +617,7 @@ class RandomBlur: Args: prob (float): 图像模糊概率。默认为0.1。 """ + def __init__(self, prob=0.1): self.prob = prob @@ -650,6 +665,7 @@ class RandomRotation: label_padding_value (int): 标注图像padding的值。默认为255。 """ + def __init__(self, max_rotation=15, im_padding_value=[127.5, 127.5, 127.5], @@ -686,18 +702,20 @@ class RandomRotation: r[0, 2] += (nw / 2) - cx r[1, 2] += (nh / 2) - cy dsize = (nw, nh) - im = cv2.warpAffine(im, - r, - dsize=dsize, - flags=cv2.INTER_LINEAR, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.im_padding_value) - label = cv2.warpAffine(label, - r, - dsize=dsize, - flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, - borderValue=self.label_padding_value) + im = cv2.warpAffine( + im, + r, + dsize=dsize, + flags=cv2.INTER_LINEAR, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.im_padding_value) + label = cv2.warpAffine( + label, + r, + dsize=dsize, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=self.label_padding_value) if label is None: return (im, im_info) @@ -713,6 +731,7 @@ class RandomScaleAspect: min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 """ + def __init__(self, min_scale=0.5, aspect_ratio=0.33): self.min_scale = min_scale self.aspect_ratio = aspect_ratio @@ -751,10 +770,12 @@ class RandomScaleAspect: im = im[h1:(h1 + dh), w1:(w1 + dw), :] label = label[h1:(h1 + dh), w1:(w1 + dw)] - im = cv2.resize(im, (img_width, img_height), - interpolation=cv2.INTER_LINEAR) - label = cv2.resize(label, (img_width, img_height), - interpolation=cv2.INTER_NEAREST) + im = cv2.resize( + im, (img_width, img_height), + interpolation=cv2.INTER_LINEAR) + label = cv2.resize( + label, (img_width, img_height), + interpolation=cv2.INTER_NEAREST) break if label is None: return (im, im_info) @@ -778,6 +799,7 @@ class RandomDistort: hue_range (int): 色调因子的范围。默认为18。 hue_prob (float): 随机调整色调的概率。默认为0.5。 """ + def __init__(self, brightness_range=0.5, brightness_prob=0.5, diff --git a/dygraph/val.py b/dygraph/val.py index 36d4242966f0e98e381130c533d67c46b31aefe1..41d0d33485d1052bef3b1c4d70b546cdf89d3922 100644 --- a/dygraph/val.py +++ b/dygraph/val.py @@ -39,12 +39,8 @@ def parse_args(): parser.add_argument( '--model_name', dest='model_name', - help= - 'Model type for evaluation, which is one of ("UNet", "HRNet_W18_Small_V1", "HRNet_W18_Small_V2", ' - '"HRNet_W18", "HRNet_W30", "HRNet_W32", "HRNet_W40", "HRNet_W44", "HRNet_W48", ' - '"HRNet_W60", "HRNet_W64", "SE_HRNet_W18_Small_V1", "SE_HRNet_W18_Small_V2", "SE_HRNet_W18", ' - '"SE_HRNet_W30", "SE_HRNet_W32", "SE_HRNet_W40","SE_HRNet_W44", "SE_HRNet_W48", ' - '"SE_HRNet_W60", "SE_HRNet_W64")', + help='Model type for evaluation, which is one of {}'.format( + str(list(MODELS.keys()))), type=str, default='UNet')