Commit 76fef60e authored by chenguowei01

add dataloader

Parent 9ccba793
......@@ -18,6 +18,7 @@ import os
import random
from paddle.fluid.io import Dataset
import cv2
from utils.download import download_file_and_uncompress
......@@ -37,6 +38,7 @@ class OpticDiscSeg(Dataset):
self.data_dir = data_dir
self.transforms = transforms
self.file_list = list()
self.mode = mode
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
......@@ -50,9 +52,8 @@ class OpticDiscSeg(Dataset):
if self.data_dir is None:
if not download:
raise Exception("data_file not set and auto download disabled.")
self.data_dir = download_file_and_uncompress(url=URL,
savepath=LOCAL_PATH,
extrapath=LOCAL_PATH)
self.data_dir = download_file_and_uncompress(
url=URL, savepath=LOCAL_PATH, extrapath=LOCAL_PATH)
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif mode == 'eval':
......@@ -83,9 +84,14 @@ class OpticDiscSeg(Dataset):
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
return self.transforms(im=image_path, label=grt_path)
im, im_info, label = self.transforms(im=image_path, label=grt_path)
        if self.mode in ('train', 'eval'):
            return im, label
        elif self.mode == 'test':
            return im, im_info
def __len__(self):
return len(self.file_list)
......@@ -35,74 +35,87 @@ def parse_args():
parser = argparse.ArgumentParser(description='Model training')
# params of model
parser.add_argument('--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='UNet')
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
type=str,
default='UNet')
# params of dataset
parser.add_argument('--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument('--train_list',
dest='train_list',
help='Train list file of dataset',
type=str)
parser.add_argument('--val_list',
dest='val_list',
help='Val list file of dataset',
type=str,
default=None)
parser.add_argument('--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
parser.add_argument(
'--data_dir',
dest='data_dir',
help='The root directory of dataset',
type=str)
parser.add_argument(
'--train_list',
dest='train_list',
help='Train list file of dataset',
type=str)
parser.add_argument(
'--val_list',
dest='val_list',
help='Val list file of dataset',
type=str,
default=None)
parser.add_argument(
'--num_classes',
dest='num_classes',
help='Number of classes',
type=int,
default=2)
# params of training
parser.add_argument("--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument('--num_epochs',
dest='num_epochs',
help='Number of epochs for training',
type=int,
default=100)
parser.add_argument('--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument('--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument('--pretrained_model',
dest='pretrained_model',
help='The path of pretrained weight',
type=str,
default=None)
parser.add_argument('--save_interval_epochs',
dest='save_interval_epochs',
help='The interval in epochs for saving a model snapshot',
type=int,
default=5)
parser.add_argument('--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument('--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
parser.add_argument(
"--input_size",
dest="input_size",
help="The image size for net inputs.",
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--num_epochs',
dest='num_epochs',
help='Number of epochs for training',
type=int,
default=100)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument(
'--learning_rate',
dest='learning_rate',
help='Learning rate',
type=float,
default=0.01)
parser.add_argument(
'--pretrained_model',
dest='pretrained_model',
help='The path of pretrained weight',
type=str,
default=None)
parser.add_argument(
'--save_interval_epochs',
dest='save_interval_epochs',
help='The interval in epochs for saving a model snapshot',
type=int,
default=5)
parser.add_argument(
'--save_dir',
dest='save_dir',
help='The directory for saving the model snapshot',
type=str,
default='./output')
parser.add_argument(
'--num_workers',
dest='num_workers',
help='Num workers for data loader',
type=int,
default=0)
return parser.parse_args()
......@@ -126,10 +139,8 @@ def train(model,
load_pretrained_model(model, pretrained_model)
batch_sampler = DistributedBatchSampler(train_dataset,
batch_size=batch_size,
shuffle=True,
drop_last=True)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
......@@ -142,10 +153,8 @@ def train(model,
for epoch in range(num_epochs):
for step, data in enumerate(loader):
images = np.array([d[0] for d in data])
labels = np.array([d[2] for d in data]).astype('int64')
images = to_variable(images)
labels = to_variable(labels)
images = data[0]
labels = data[1].astype('int64')
loss = model(images, labels, mode='train')
loss.backward()
optimizer.minimize(loss)
......@@ -165,13 +174,14 @@ def train(model,
if eval_dataset is not None:
model.eval()
evaluate(model,
eval_dataset,
model_dir=current_save_dir,
num_classes=num_classes,
batch_size=batch_size,
ignore_index=model.ignore_index,
epoch_id=epoch + 1)
evaluate(
model,
eval_dataset,
model_dir=current_save_dir,
num_classes=num_classes,
batch_size=batch_size,
ignore_index=model.ignore_index,
epoch_id=epoch + 1)
model.train()
......@@ -194,13 +204,14 @@ def main(args):
eval_transforms = T.Compose(
[T.Resize(args.input_size),
T.Normalize()])
eval_dataset = Dataset(data_dir=args.data_dir,
file_list=args.val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
eval_dataset = Dataset(
data_dir=args.data_dir,
file_list=args.val_list,
transforms=eval_transforms,
num_workers='auto',
buffer_size=100,
parallel_method='thread',
shuffle=False)
if args.model_name == 'UNet':
model = models.UNet(num_classes=args.num_classes, ignore_index=255)
......@@ -208,28 +219,27 @@ def main(args):
# Create optimizer
num_steps_each_epoch = len(train_dataset) // args.batch_size
decay_step = args.num_epochs * num_steps_each_epoch
lr_decay = fluid.layers.polynomial_decay(args.learning_rate,
decay_step,
end_learning_rate=0,
power=0.9)
lr_decay = fluid.layers.polynomial_decay(
args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
optimizer = fluid.optimizer.Momentum(
lr_decay,
momentum=0.9,
parameter_list=model.parameters(),
regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))
train(model,
train_dataset,
places=places,
eval_dataset=eval_dataset,
optimizer=optimizer,
save_dir=args.save_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
save_interval_epochs=args.save_interval_epochs,
num_classes=args.num_classes,
num_workers=args.num_workers)
train(
model,
train_dataset,
places=places,
eval_dataset=eval_dataset,
optimizer=optimizer,
save_dir=args.save_dir,
num_epochs=args.num_epochs,
batch_size=args.batch_size,
pretrained_model=args.pretrained_model,
save_interval_epochs=args.save_interval_epochs,
num_classes=args.num_classes,
num_workers=args.num_workers)
if __name__ == '__main__':
......