Commit 76fef60e authored by chenguowei01

add dataloader

Parent 9ccba793
...
@@ -18,6 +18,7 @@ import os
 import random
 from paddle.fluid.io import Dataset
+import cv2
 from utils.download import download_file_and_uncompress
...
@@ -37,6 +38,7 @@ class OpticDiscSeg(Dataset):
         self.data_dir = data_dir
         self.transforms = transforms
         self.file_list = list()
+        self.mode = mode
         if mode.lower() not in ['train', 'eval', 'test']:
             raise Exception(
...
@@ -50,9 +52,8 @@ class OpticDiscSeg(Dataset):
         if self.data_dir is None:
             if not download:
                 raise Exception("data_file not set and auto download disabled.")
-            self.data_dir = download_file_and_uncompress(url=URL,
-                                                         savepath=LOCAL_PATH,
-                                                         extrapath=LOCAL_PATH)
+            self.data_dir = download_file_and_uncompress(
+                url=URL, savepath=LOCAL_PATH, extrapath=LOCAL_PATH)
         if mode == 'train':
             file_list = os.path.join(self.data_dir, 'train_list.txt')
         elif mode == 'eval':
...
@@ -83,9 +84,14 @@ class OpticDiscSeg(Dataset):
             self.file_list.append([image_path, grt_path])

     def __getitem__(self, idx):
+        print(idx)
         image_path, grt_path = self.file_list[idx]
-        return self.transforms(im=image_path, label=grt_path)
+        im, im_info, label = self.transforms(im=image_path, label=grt_path)
+        if self.mode == 'train':
+            return im, label
+        elif self.mode == 'eval':
+            return im, label
+        if self.mode == 'test':
+            return im, im_info

     def __len__(self):
         return len(self.file_list)
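For context, here is a minimal usage sketch of the mode-dependent returns added above. The import paths and the transform pipeline are assumptions for illustration; only OpticDiscSeg's behaviour is taken from the diff.

```python
import transforms as T                    # repo-local transforms module (assumed path)
from optic_disc_seg import OpticDiscSeg   # assumed import path

# The transforms return (im, im_info, label); the dataset then slices
# that tuple according to its mode.
transforms = T.Compose([T.Resize([512, 512]), T.Normalize()])

train_dataset = OpticDiscSeg(transforms=transforms, mode='train')
im, label = train_dataset[0]       # 'train' and 'eval' return (im, label)

test_dataset = OpticDiscSeg(transforms=transforms, mode='test')
im, im_info = test_dataset[0]      # 'test' returns (im, im_info), keeping the
                                   # resize metadata needed to map predictions
                                   # back to the original image size
```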
...
@@ -35,74 +35,87 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Model training')

     # params of model
-    parser.add_argument('--model_name',
-                        dest='model_name',
-                        help="Model type for training, which is one of ('UNet')",
-                        type=str,
-                        default='UNet')
+    parser.add_argument(
+        '--model_name',
+        dest='model_name',
+        help="Model type for training, which is one of ('UNet')",
+        type=str,
+        default='UNet')

     # params of dataset
-    parser.add_argument('--data_dir',
-                        dest='data_dir',
-                        help='The root directory of dataset',
-                        type=str)
+    parser.add_argument(
+        '--data_dir',
+        dest='data_dir',
+        help='The root directory of dataset',
+        type=str)
-    parser.add_argument('--train_list',
-                        dest='train_list',
-                        help='Train list file of dataset',
-                        type=str)
+    parser.add_argument(
+        '--train_list',
+        dest='train_list',
+        help='Train list file of dataset',
+        type=str)
-    parser.add_argument('--val_list',
-                        dest='val_list',
-                        help='Val list file of dataset',
-                        type=str,
-                        default=None)
+    parser.add_argument(
+        '--val_list',
+        dest='val_list',
+        help='Val list file of dataset',
+        type=str,
+        default=None)
-    parser.add_argument('--num_classes',
-                        dest='num_classes',
-                        help='Number of classes',
-                        type=int,
-                        default=2)
+    parser.add_argument(
+        '--num_classes',
+        dest='num_classes',
+        help='Number of classes',
+        type=int,
+        default=2)

     # params of training
-    parser.add_argument("--input_size",
-                        dest="input_size",
-                        help="The image size for net inputs.",
-                        nargs=2,
-                        default=[512, 512],
-                        type=int)
+    parser.add_argument(
+        "--input_size",
+        dest="input_size",
+        help="The image size for net inputs.",
+        nargs=2,
+        default=[512, 512],
+        type=int)
-    parser.add_argument('--num_epochs',
-                        dest='num_epochs',
-                        help='Number of epochs for training',
-                        type=int,
-                        default=100)
+    parser.add_argument(
+        '--num_epochs',
+        dest='num_epochs',
+        help='Number of epochs for training',
+        type=int,
+        default=100)
-    parser.add_argument('--batch_size',
-                        dest='batch_size',
-                        help='Mini batch size',
-                        type=int,
-                        default=2)
+    parser.add_argument(
+        '--batch_size',
+        dest='batch_size',
+        help='Mini batch size',
+        type=int,
+        default=2)
-    parser.add_argument('--learning_rate',
-                        dest='learning_rate',
-                        help='Learning rate',
-                        type=float,
-                        default=0.01)
+    parser.add_argument(
+        '--learning_rate',
+        dest='learning_rate',
+        help='Learning rate',
+        type=float,
+        default=0.01)
-    parser.add_argument('--pretrained_model',
-                        dest='pretrained_model',
-                        help='The path of pretrained weight',
-                        type=str,
-                        default=None)
+    parser.add_argument(
+        '--pretrained_model',
+        dest='pretrained_model',
+        help='The path of pretrained weight',
+        type=str,
+        default=None)
-    parser.add_argument('--save_interval_epochs',
-                        dest='save_interval_epochs',
-                        help='The interval epochs for saving a model snapshot',
-                        type=int,
-                        default=5)
+    parser.add_argument(
+        '--save_interval_epochs',
+        dest='save_interval_epochs',
+        help='The interval epochs for saving a model snapshot',
+        type=int,
+        default=5)
-    parser.add_argument('--save_dir',
-                        dest='save_dir',
-                        help='The directory for saving the model snapshot',
-                        type=str,
-                        default='./output')
+    parser.add_argument(
+        '--save_dir',
+        dest='save_dir',
+        help='The directory for saving the model snapshot',
+        type=str,
+        default='./output')
-    parser.add_argument('--num_workers',
-                        dest='num_workers',
-                        help='Num workers for data loader',
-                        type=int,
-                        default=0)
+    parser.add_argument(
+        '--num_workers',
+        dest='num_workers',
+        help='Num workers for data loader',
+        type=int,
+        default=0)

     return parser.parse_args()
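A quick self-contained check of how the two-value `--input_size` flag parses; the example values are illustrative, not from the commit.

```python
import argparse

# Same definition as in the hunk above: nargs=2 collects exactly two
# whitespace-separated values and converts each one with type=int.
parser = argparse.ArgumentParser(description='Model training')
parser.add_argument(
    '--input_size',
    dest='input_size',
    help='The image size for net inputs.',
    nargs=2,
    default=[512, 512],
    type=int)

args = parser.parse_args(['--input_size', '640', '480'])
assert args.input_size == [640, 480]

args = parser.parse_args([])
assert args.input_size == [512, 512]   # default when the flag is omitted
```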
...
@@ -126,10 +139,8 @@ def train(model,
         load_pretrained_model(model, pretrained_model)

-    batch_sampler = DistributedBatchSampler(train_dataset,
-                                            batch_size=batch_size,
-                                            shuffle=True,
-                                            drop_last=True)
+    batch_sampler = DistributedBatchSampler(
+        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
     loader = DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
...
@@ -142,10 +153,8 @@ def train(model,
     for epoch in range(num_epochs):
         for step, data in enumerate(loader):
-            images = np.array([d[0] for d in data])
-            labels = np.array([d[2] for d in data]).astype('int64')
-            images = to_variable(images)
-            labels = to_variable(labels)
+            images = data[0]
+            labels = data[1].astype('int64')
             loss = model(images, labels, mode='train')
             loss.backward()
             optimizer.minimize(loss)
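One detail behind this hunk: the old loop indexed `d[2]` for labels because each sample was the raw transform output `(im, im_info, label)`; after this commit `__getitem__` returns `(im, label)` in train mode, so the already-collated batch is simply `data[0]`/`data[1]`, and the manual `np.array` stacking and `to_variable` calls disappear. A rough sketch of the resulting loop, with the pieces the diff elides (imports, DataLoader keywords such as `return_list`) marked as assumptions:

```python
# Assumed import paths; the hunks above do not show them.
from paddle.fluid.io import DataLoader, DistributedBatchSampler

batch_sampler = DistributedBatchSampler(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    places=places,
    return_list=True)   # assumption: yields each batch as [images, labels]

for epoch in range(num_epochs):
    for step, data in enumerate(loader):
        images = data[0]                  # already a batched variable
        labels = data[1].astype('int64')  # segmentation labels must be int64
        loss = model(images, labels, mode='train')
        loss.backward()
        optimizer.minimize(loss)
        model.clear_gradients()           # assumed; the hunk is truncated here
```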
...
@@ -165,13 +174,14 @@ def train(model,
             if eval_dataset is not None:
                 model.eval()
-                evaluate(model,
-                         eval_dataset,
-                         model_dir=current_save_dir,
-                         num_classes=num_classes,
-                         batch_size=batch_size,
-                         ignore_index=model.ignore_index,
-                         epoch_id=epoch + 1)
+                evaluate(
+                    model,
+                    eval_dataset,
+                    model_dir=current_save_dir,
+                    num_classes=num_classes,
+                    batch_size=batch_size,
+                    ignore_index=model.ignore_index,
+                    epoch_id=epoch + 1)
                 model.train()
...
@@ -194,13 +204,14 @@ def main(args):
         eval_transforms = T.Compose(
             [T.Resize(args.input_size),
              T.Normalize()])
-        eval_dataset = Dataset(data_dir=args.data_dir,
-                               file_list=args.val_list,
-                               transforms=eval_transforms,
-                               num_workers='auto',
-                               buffer_size=100,
-                               parallel_method='thread',
-                               shuffle=False)
+        eval_dataset = Dataset(
+            data_dir=args.data_dir,
+            file_list=args.val_list,
+            transforms=eval_transforms,
+            num_workers='auto',
+            buffer_size=100,
+            parallel_method='thread',
+            shuffle=False)

     if args.model_name == 'UNet':
         model = models.UNet(num_classes=args.num_classes, ignore_index=255)
...
@@ -208,28 +219,27 @@ def main(args):
     # Create optimizer
     num_steps_each_epoch = len(train_dataset) // args.batch_size
     decay_step = args.num_epochs * num_steps_each_epoch
-    lr_decay = fluid.layers.polynomial_decay(args.learning_rate,
-                                             decay_step,
-                                             end_learning_rate=0,
-                                             power=0.9)
+    lr_decay = fluid.layers.polynomial_decay(
+        args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
     optimizer = fluid.optimizer.Momentum(
         lr_decay,
         momentum=0.9,
         parameter_list=model.parameters(),
         regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

-    train(model,
-          train_dataset,
-          places=places,
-          eval_dataset=eval_dataset,
-          optimizer=optimizer,
-          save_dir=args.save_dir,
-          num_epochs=args.num_epochs,
-          batch_size=args.batch_size,
-          pretrained_model=args.pretrained_model,
-          save_interval_epochs=args.save_interval_epochs,
-          num_classes=args.num_classes,
-          num_workers=args.num_workers)
+    train(
+        model,
+        train_dataset,
+        places=places,
+        eval_dataset=eval_dataset,
+        optimizer=optimizer,
+        save_dir=args.save_dir,
+        num_epochs=args.num_epochs,
+        batch_size=args.batch_size,
+        pretrained_model=args.pretrained_model,
+        save_interval_epochs=args.save_interval_epochs,
+        num_classes=args.num_classes,
+        num_workers=args.num_workers)
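For reference, the polynomial schedule built in the hunk above reduces to a simple closed form (assuming fluid's default `cycle=False` behaviour, where the step count is clipped at `decay_step`):

```python
# lr(t) = (lr0 - end_lr) * (1 - t / decay_step) ** power + end_lr
# Illustrative numbers below; in the code above decay_step is
# num_epochs * (len(train_dataset) // batch_size).
lr0, end_lr, power = 0.01, 0.0, 0.9
decay_step = 100 * 33

def poly_lr(step):
    step = min(step, decay_step)
    return (lr0 - end_lr) * (1 - step / decay_step) ** power + end_lr

print(poly_lr(0))            # 0.01 at the first step
print(poly_lr(decay_step))   # decays to end_learning_rate = 0 by the last step
```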
 if __name__ == '__main__':
...