提交 9452bf66 编写于 作者: C chenguowei01

add optic_disc_seg dataset

上级 d73b5914
...@@ -13,3 +13,4 @@ ...@@ -13,3 +13,4 @@
# limitations under the License. # limitations under the License.
from .dataset import Dataset from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import random
from paddle.fluid.io import Dataset
from utils.download import download_file_and_uncompress
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
class OpticDiscSeg(Dataset):
def __init__(self,
data_dir=None,
train_list=None,
val_list=None,
test_list=None,
shuffle='False',
mode='train',
transform=None,
download=True):
self.data_dir = data_dir
self.shuffle = shuffle
self.transform = transform
self.file_list = list()
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if transform is None:
raise Exception("transform is necessary, but it is None.")
self.data_dir = data_dir
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(url=URL,
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif mode == 'eval':
file_list = os.paht.join(self.data_dir, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
else:
if mode == 'train':
file_list = train_list
elif mode == 'eval':
file_list = val_list
else:
file_list = test_list
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
if shuffle:
random.shuffle(self.file_list)
def __getitem__(self, idx):
print(idx)
image_path, grt_path = self.file_list[idx]
return self.transform(im=image_path, label=grt_path)
def __len__(self):
return len(self.file_list)
...@@ -33,8 +33,7 @@ class Compose: ...@@ -33,8 +33,7 @@ class Compose:
ValueError: transforms元素个数小于1。 ValueError: transforms元素个数小于1。
""" """
def __init__(self, transforms, to_rgb=True):
def __init__(self, transforms, to_rgb=False):
if not isinstance(transforms, list): if not isinstance(transforms, list):
raise TypeError('The transforms must be a list!') raise TypeError('The transforms must be a list!')
if len(transforms) < 1: if len(transforms) < 1:
...@@ -87,7 +86,6 @@ class RandomHorizontalFlip: ...@@ -87,7 +86,6 @@ class RandomHorizontalFlip:
prob (float): 随机水平翻转的概率。默认值为0.5。 prob (float): 随机水平翻转的概率。默认值为0.5。
""" """
def __init__(self, prob=0.5): def __init__(self, prob=0.5):
self.prob = prob self.prob = prob
...@@ -119,7 +117,6 @@ class RandomVerticalFlip: ...@@ -119,7 +117,6 @@ class RandomVerticalFlip:
Args: Args:
prob (float): 随机垂直翻转的概率。默认值为0.1。 prob (float): 随机垂直翻转的概率。默认值为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -236,7 +233,6 @@ class ResizeByLong: ...@@ -236,7 +233,6 @@ class ResizeByLong:
Args: Args:
long_size (int): resize后图像的长边大小。 long_size (int): resize后图像的长边大小。
""" """
def __init__(self, long_size): def __init__(self, long_size):
self.long_size = long_size self.long_size = long_size
...@@ -278,7 +274,6 @@ class ResizeRangeScaling: ...@@ -278,7 +274,6 @@ class ResizeRangeScaling:
Raises: Raises:
ValueError: min_value大于max_value ValueError: min_value大于max_value
""" """
def __init__(self, min_value=400, max_value=600): def __init__(self, min_value=400, max_value=600):
if min_value > max_value: if min_value > max_value:
raise ValueError('min_value must be less than max_value, ' raise ValueError('min_value must be less than max_value, '
...@@ -326,7 +321,6 @@ class ResizeStepScaling: ...@@ -326,7 +321,6 @@ class ResizeStepScaling:
Raises: Raises:
ValueError: min_scale_factor大于max_scale_factor ValueError: min_scale_factor大于max_scale_factor
""" """
def __init__(self, def __init__(self,
min_scale_factor=0.75, min_scale_factor=0.75,
max_scale_factor=1.25, max_scale_factor=1.25,
...@@ -392,7 +386,6 @@ class Normalize: ...@@ -392,7 +386,6 @@ class Normalize:
Raises: Raises:
ValueError: mean或std不是list对象。std包含0。 ValueError: mean或std不是list对象。std包含0。
""" """
def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): def __init__(self, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):
self.mean = mean self.mean = mean
self.std = std self.std = std
...@@ -438,7 +431,6 @@ class Padding: ...@@ -438,7 +431,6 @@ class Padding:
TypeError: target_size不是int|list|tuple。 TypeError: target_size不是int|list|tuple。
ValueError: target_size为list|tuple时元素个数不等于2。 ValueError: target_size为list|tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
target_size, target_size,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -491,23 +483,21 @@ class Padding: ...@@ -491,23 +483,21 @@ class Padding:
'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})' 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
.format(im_width, im_height, target_width, target_height)) .format(im_width, im_height, target_width, target_height))
else: else:
im = cv2.copyMakeBorder( im = cv2.copyMakeBorder(im,
im, 0,
0, pad_height,
pad_height, 0,
0, pad_width,
pad_width, cv2.BORDER_CONSTANT,
cv2.BORDER_CONSTANT, value=self.im_padding_value)
value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder( label = cv2.copyMakeBorder(label,
label, 0,
0, pad_height,
pad_height, 0,
0, pad_width,
pad_width, cv2.BORDER_CONSTANT,
cv2.BORDER_CONSTANT, value=self.label_padding_value)
value=self.label_padding_value)
if label is None: if label is None:
return (im, im_info) return (im, im_info)
else: else:
...@@ -526,7 +516,6 @@ class RandomPaddingCrop: ...@@ -526,7 +516,6 @@ class RandomPaddingCrop:
TypeError: crop_size不是int/list/tuple。 TypeError: crop_size不是int/list/tuple。
ValueError: target_size为list/tuple时元素个数不等于2。 ValueError: target_size为list/tuple时元素个数不等于2。
""" """
def __init__(self, def __init__(self,
crop_size=512, crop_size=512,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -575,23 +564,21 @@ class RandomPaddingCrop: ...@@ -575,23 +564,21 @@ class RandomPaddingCrop:
pad_height = max(crop_height - img_height, 0) pad_height = max(crop_height - img_height, 0)
pad_width = max(crop_width - img_width, 0) pad_width = max(crop_width - img_width, 0)
if (pad_height > 0 or pad_width > 0): if (pad_height > 0 or pad_width > 0):
im = cv2.copyMakeBorder( im = cv2.copyMakeBorder(im,
im, 0,
0, pad_height,
pad_height, 0,
0, pad_width,
pad_width, cv2.BORDER_CONSTANT,
cv2.BORDER_CONSTANT, value=self.im_padding_value)
value=self.im_padding_value)
if label is not None: if label is not None:
label = cv2.copyMakeBorder( label = cv2.copyMakeBorder(label,
label, 0,
0, pad_height,
pad_height, 0,
0, pad_width,
pad_width, cv2.BORDER_CONSTANT,
cv2.BORDER_CONSTANT, value=self.label_padding_value)
value=self.label_padding_value)
img_height = im.shape[0] img_height = im.shape[0]
img_width = im.shape[1] img_width = im.shape[1]
...@@ -599,11 +586,11 @@ class RandomPaddingCrop: ...@@ -599,11 +586,11 @@ class RandomPaddingCrop:
h_off = np.random.randint(img_height - crop_height + 1) h_off = np.random.randint(img_height - crop_height + 1)
w_off = np.random.randint(img_width - crop_width + 1) w_off = np.random.randint(img_width - crop_width + 1)
im = im[h_off:(crop_height + h_off), w_off:( im = im[h_off:(crop_height + h_off), w_off:(w_off +
w_off + crop_width), :] crop_width), :]
if label is not None: if label is not None:
label = label[h_off:(crop_height + h_off), w_off:( label = label[h_off:(crop_height +
w_off + crop_width)] h_off), w_off:(w_off + crop_width)]
if label is None: if label is None:
return (im, im_info) return (im, im_info)
else: else:
...@@ -616,7 +603,6 @@ class RandomBlur: ...@@ -616,7 +603,6 @@ class RandomBlur:
Args: Args:
prob (float): 图像模糊概率。默认为0.1。 prob (float): 图像模糊概率。默认为0.1。
""" """
def __init__(self, prob=0.1): def __init__(self, prob=0.1):
self.prob = prob self.prob = prob
...@@ -664,7 +650,6 @@ class RandomRotation: ...@@ -664,7 +650,6 @@ class RandomRotation:
label_padding_value (int): 标注图像padding的值。默认为255。 label_padding_value (int): 标注图像padding的值。默认为255。
""" """
def __init__(self, def __init__(self,
max_rotation=15, max_rotation=15,
im_padding_value=[127.5, 127.5, 127.5], im_padding_value=[127.5, 127.5, 127.5],
...@@ -701,20 +686,18 @@ class RandomRotation: ...@@ -701,20 +686,18 @@ class RandomRotation:
r[0, 2] += (nw / 2) - cx r[0, 2] += (nw / 2) - cx
r[1, 2] += (nh / 2) - cy r[1, 2] += (nh / 2) - cy
dsize = (nw, nh) dsize = (nw, nh)
im = cv2.warpAffine( im = cv2.warpAffine(im,
im, r,
r, dsize=dsize,
dsize=dsize, flags=cv2.INTER_LINEAR,
flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
borderMode=cv2.BORDER_CONSTANT, borderValue=self.im_padding_value)
borderValue=self.im_padding_value) label = cv2.warpAffine(label,
label = cv2.warpAffine( r,
label, dsize=dsize,
r, flags=cv2.INTER_NEAREST,
dsize=dsize, borderMode=cv2.BORDER_CONSTANT,
flags=cv2.INTER_NEAREST, borderValue=self.label_padding_value)
borderMode=cv2.BORDER_CONSTANT,
borderValue=self.label_padding_value)
if label is None: if label is None:
return (im, im_info) return (im, im_info)
...@@ -730,7 +713,6 @@ class RandomScaleAspect: ...@@ -730,7 +713,6 @@ class RandomScaleAspect:
min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。 min_scale (float):裁取图像占原始图像的面积比,取值[0,1],为0时则返回原图。默认为0.5。
aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。 aspect_ratio (float): 裁取图像的宽高比范围,非负值,为0时返回原图。默认为0.33。
""" """
def __init__(self, min_scale=0.5, aspect_ratio=0.33): def __init__(self, min_scale=0.5, aspect_ratio=0.33):
self.min_scale = min_scale self.min_scale = min_scale
self.aspect_ratio = aspect_ratio self.aspect_ratio = aspect_ratio
...@@ -769,12 +751,10 @@ class RandomScaleAspect: ...@@ -769,12 +751,10 @@ class RandomScaleAspect:
im = im[h1:(h1 + dh), w1:(w1 + dw), :] im = im[h1:(h1 + dh), w1:(w1 + dw), :]
label = label[h1:(h1 + dh), w1:(w1 + dw)] label = label[h1:(h1 + dh), w1:(w1 + dw)]
im = cv2.resize( im = cv2.resize(im, (img_width, img_height),
im, (img_width, img_height), interpolation=cv2.INTER_LINEAR)
interpolation=cv2.INTER_LINEAR) label = cv2.resize(label, (img_width, img_height),
label = cv2.resize( interpolation=cv2.INTER_NEAREST)
label, (img_width, img_height),
interpolation=cv2.INTER_NEAREST)
break break
if label is None: if label is None:
return (im, im_info) return (im, im_info)
...@@ -798,7 +778,6 @@ class RandomDistort: ...@@ -798,7 +778,6 @@ class RandomDistort:
hue_range (int): 色调因子的范围。默认为18。 hue_range (int): 色调因子的范围。默认为18。
hue_prob (float): 随机调整色调的概率。默认为0.5。 hue_prob (float): 随机调整色调的概率。默认为0.5。
""" """
def __init__(self, def __init__(self,
brightness_range=0.5, brightness_range=0.5,
brightness_prob=0.5, brightness_prob=0.5,
......
...@@ -13,6 +13,6 @@ ...@@ -13,6 +13,6 @@
# limitations under the License. # limitations under the License.
from . import logging from . import logging
from . import download
from .metrics import ConfusionMatrix from .metrics import ConfusionMatrix
from .download import download_file_and_uncompress
from .utils import * from .utils import *
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册