未验证 提交 27919efd 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #318 from wuyefeilin/dygraph

# 动态图执行
## 数据集设置
```
data_dir='data/path'
train_list='train/list/path'
val_list='val/list/path'
test_list='test/list/path'
num_classes=number/of/dataset/classes
```
## 训练
```
python3 train.py --model_name UNet \
--data_dir $data_dir \
--train_list $train_list \
--val_list $val_list \
--num_classes $num_classes \
--dataset OpticDiscSeg \
--input_size 192 192 \
--num_epochs 4 \
--num_epochs 10 \
--save_interval_epochs 1 \
--do_eval \
--save_dir output
```
## 评估
```
python3 val.py --model_name UNet \
--data_dir $data_dir \
--val_list $val_list \
--num_classes $num_classes \
--dataset OpticDiscSeg \
--input_size 192 192 \
--model_dir output/epoch_1 \
--model_dir output/best_model
```
## 预测
```
python3 infer.py --model_name UNet \
--data_dir $data_dir \
--test_list $test_list \
--num_classes $num_classes \
--input_size 192 192 \
--model_dir output/epoch_1 \
--dataset OpticDiscSeg \
--model_dir output/best_model \
--input_size 192 192
```
......@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
......@@ -14,8 +14,7 @@
import os
from paddle.fluid.io import Dataset
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
......@@ -70,16 +69,3 @@ class Cityscapes(Dataset):
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
im, im_info, label = self.transforms(im=image_path, label=grt_path)
if self.mode == 'train':
return im, label
elif self.mode == 'eval':
return im, label
if self.mode == 'test':
return im, im_info, image_path
def __len__(self):
return len(self.file_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import numpy as np
from PIL import Image
class Dataset(fluid.io.Dataset):
    """Dataset whose samples are enumerated by a text file list.

    Each non-blank line of the list file holds an image path and, separated
    by ``separator``, a label path. Both are joined onto ``data_dir``. The
    list that is read (train / val / test) is selected by ``mode``.

    Args:
        data_dir: Directory that the paths in the list file are joined onto.
        num_classes: Number of classes; stored on the instance unchanged.
        train_list: Path of the list file read when mode is 'train'.
        val_list: Path of the list file read when mode is 'eval'.
        test_list: Path of the list file read when mode is 'test'.
        separator: String separating the image and label entries of a line.
        transforms: Callable invoked as ``transforms(im=..., label=...)`` or
            ``transforms(im=...)``; its result is unpacked as a 3-tuple
            ``(im, im_info, label)``. Required.
        mode: 'train', 'eval' or 'test', matched case-insensitively.

    Raises:
        ValueError: If ``mode`` is not one of the supported values, if
            ``transforms`` is None, if the list for the selected mode is
            None, or if a train/eval line does not have exactly two entries.
        FileNotFoundError: If the list file for the selected mode is missing.
    """

    def __init__(self,
                 data_dir,
                 num_classes,
                 train_list=None,
                 val_list=None,
                 test_list=None,
                 separator=' ',
                 transforms=None,
                 mode='train'):
        # Normalise once so that validation, list selection and the mode
        # checks in __getitem__ all agree (e.g. 'Train' behaves as 'train').
        normalized_mode = mode.lower()
        if normalized_mode not in ('train', 'eval', 'test'):
            raise ValueError(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))
        if transforms is None:
            raise ValueError("transform is necessary, but it is None.")

        self.data_dir = data_dir
        self.transforms = transforms
        self.file_list = list()
        self.mode = normalized_mode
        self.num_classes = num_classes

        # Pick the list file (and its argument name, for error messages)
        # that belongs to the selected mode.
        list_name, list_path = {
            'train': ('train_list', train_list),
            'eval': ('val_list', val_list),
            'test': ('test_list', test_list),
        }[normalized_mode]
        if list_path is None:
            raise ValueError(
                'When mode is "{}", {} is need, but it is None.'.format(
                    normalized_mode, list_name))
        if not os.path.exists(list_path):
            raise FileNotFoundError(
                '{} is not found: {}'.format(list_name, list_path))

        with open(list_path, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no sample.
                    continue
                items = line.split(separator)
                if len(items) == 2:
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = os.path.join(self.data_dir, items[1])
                else:
                    # Only test lists may omit the label entry.
                    if normalized_mode != 'test':
                        raise ValueError(
                            "File list format incorrect! It should be"
                            " image_name{}label_name\\n".format(separator))
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = None
                self.file_list.append([image_path, grt_path])

    def __getitem__(self, idx):
        """Return the sample at ``idx``; the tuple layout depends on mode.

        * 'train': ``(im, label)``, both produced by ``transforms``.
        * 'eval': ``(im, im_info, label)``; ``im`` gets a new leading axis,
          and ``label`` is read from the label file without passing through
          ``transforms``, then given two new leading axes.
        * 'test': ``(im, im_info, image_path)``; ``im`` gets a new leading
          axis.

        For any other ``self.mode`` value None is returned.
        """
        image_path, grt_path = self.file_list[idx]
        if self.mode == 'train':
            im, im_info, label = self.transforms(im=image_path, label=grt_path)
            return im, label
        if self.mode == 'eval':
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            # Close the label file as soon as its pixels are copied out.
            with Image.open(grt_path) as label_image:
                label = np.asarray(label_image)
            # The indexing below requires a label with at least two axes;
            # presumably a single-channel (H, W) mask - TODO confirm.
            label = label[np.newaxis, np.newaxis, :, :]
            return im, im_info, label
        if self.mode == 'test':
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            return im, im_info, image_path
        return None

    def __len__(self):
        """Return the number of samples read from the file list."""
        return len(self.file_list)
......@@ -14,8 +14,7 @@
import os
from paddle.fluid.io import Dataset
from .dataset import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
......@@ -70,16 +69,3 @@ class OpticDiscSeg(Dataset):
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
im, im_info, label = self.transforms(im=image_path, label=grt_path)
if self.mode == 'train':
return im, label
elif self.mode == 'eval':
return im, label
if self.mode == 'test':
return im, im_info, image_path
def __len__(self):
return len(self.file_list)
......@@ -24,7 +24,7 @@ import tqdm
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info
......@@ -37,7 +37,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
help='Model type for testing, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......@@ -97,19 +98,20 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
logging.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = im[np.newaxis, ...]
im = to_variable(im)
pred, _ = model(im, mode='test')
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
keys = list(im_info.keys())
for k in keys[::-1]:
if k == 'shape_before_resize':
h, w = im_info[k][0], im_info[k][1]
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif k == 'shape_before_padding':
h, w = im_info[k][0], im_info[k][1]
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.data_dir, '')
if im_file[0] == '/':
......@@ -146,8 +148,11 @@ def main(args):
test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
test_dataset = dataset(transforms=test_transforms, mode='test')
if args.model_name == 'UNet':
model = models.UNet(num_classes=test_dataset.num_classes)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=test_dataset.num_classes)
infer(
model,
......
......@@ -13,3 +13,28 @@
# limitations under the License.
from .unet import UNet
from .hrnet import *
# Registry mapping a model name to the class that builds it.
# Each key maps to the class of the same name.
MODELS = {
    "UNet": UNet,
    "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
    "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
    "HRNet_W18": HRNet_W18,
    "HRNet_W30": HRNet_W30,
    "HRNet_W32": HRNet_W32,
    "HRNet_W40": HRNet_W40,
    "HRNet_W44": HRNet_W44,
    "HRNet_W48": HRNet_W48,
    "HRNet_W60": HRNet_W60,
    "HRNet_W64": HRNet_W64,
    "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
    "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
    "SE_HRNet_W18": SE_HRNet_W18,
    "SE_HRNet_W30": SE_HRNet_W30,
    "SE_HRNet_W32": SE_HRNet_W32,
    "SE_HRNet_W40": SE_HRNet_W40,
    "SE_HRNet_W44": SE_HRNet_W44,
    "SE_HRNet_W48": SE_HRNet_W48,
    "SE_HRNet_W60": SE_HRNet_W60,
    "SE_HRNet_W64": SE_HRNet_W64,
}
此差异已折叠。
......@@ -13,7 +13,11 @@
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
from paddle.fluid.dygraph import Conv2D, Pool2D
try:
from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
except:
from paddle.fluid.dygraph import BatchNorm
class UNet(fluid.dygraph.Layer):
......@@ -39,6 +43,8 @@ class UNet(fluid.dygraph.Layer):
return pred, score_map
def _get_loss(self, logit, label):
logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
label = fluid.layers.transpose(label, [0, 2, 3, 1])
mask = label != self.ignore_index
mask = fluid.layers.cast(mask, 'float32')
loss, probs = fluid.layers.softmax_with_cross_entropy(
......
......@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
......@@ -38,7 +38,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for traing, which is one of ('UNet')",
help='Model type for training, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......@@ -181,7 +182,7 @@ def train(model,
total_steps = steps_per_epoch * (num_epochs - start_epoch)
num_steps = 0
best_mean_iou = -1.0
best_model_epoch = 1
best_model_epoch = -1
for epoch in range(start_epoch, num_epochs):
for step, data in enumerate(loader):
images = data[0]
......@@ -229,10 +230,8 @@ def train(model,
mean_iou, mean_acc = evaluate(
model,
eval_dataset,
places=places,
model_dir=current_save_dir,
num_classes=num_classes,
batch_size=batch_size,
ignore_index=ignore_index,
epoch_id=epoch + 1)
if mean_iou > best_mean_iou:
......@@ -241,9 +240,9 @@ def train(model,
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(model.state_dict(),
os.path.join(best_model_dir, 'model'))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
logging.info(
'Current evaluated best model in eval_dataset is epoch_{}, miou={:4f}'
.format(best_model_epoch, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
......@@ -286,9 +285,11 @@ def main(args):
T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(
num_classes=train_dataset.num_classes, ignore_index=255)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=train_dataset.num_classes)
# Creat optimizer
# todo, may less one than len(loader)
......
此差异已折叠。
......@@ -52,7 +52,11 @@ def load_pretrained_model(model, pretrained_model):
logging.info('Load pretrained model from {}'.format(pretrained_model))
if os.path.exists(pretrained_model):
ckpt_path = os.path.join(pretrained_model, 'model')
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
try:
para_state_dict, _ = fluid.load_dygraph(ckpt_path)
except:
para_state_dict = fluid.load_program_state(pretrained_model)
model_state_dict = model.state_dict()
keys = model_state_dict.keys()
num_params_loaded = 0
......
......@@ -16,8 +16,10 @@ import argparse
import os
import math
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
......@@ -25,7 +27,7 @@ from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
......@@ -39,7 +41,8 @@ def parse_args():
parser.add_argument(
'--model_name',
dest='model_name',
help="Model type for evaluation, which is one of ('UNet')",
help='Model type for evaluation, which is one of {}'.format(
str(list(MODELS.keys()))),
type=str,
default='UNet')
......@@ -60,12 +63,6 @@ def parse_args():
nargs=2,
default=[512, 512],
type=int)
parser.add_argument(
'--batch_size',
dest='batch_size',
help='Mini batch size',
type=int,
default=2)
parser.add_argument(
'--model_dir',
dest='model_dir',
......@@ -78,10 +75,8 @@ def parse_args():
def evaluate(model,
eval_dataset=None,
places=None,
model_dir=None,
num_classes=None,
batch_size=2,
ignore_index=255,
epoch_id=None):
ckpt_path = os.path.join(model_dir, 'model')
......@@ -89,15 +84,7 @@ def evaluate(model,
model.set_dict(para_state_dict)
model.eval()
batch_sampler = BatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
loader = DataLoader(
eval_dataset,
batch_sampler=batch_sampler,
places=places,
return_list=True,
)
total_steps = len(batch_sampler)
total_steps = len(eval_dataset)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logging.info(
......@@ -105,15 +92,26 @@ def evaluate(model,
len(eval_dataset), total_steps))
timer = Timer()
timer.start()
for step, data in enumerate(loader):
images = data[0]
labels = data[1].astype('int64')
pred, _ = model(images, mode='eval')
pred = pred.numpy()
labels = labels.numpy()
mask = labels != ignore_index
conf_mat.calculate(pred=pred, label=labels, ignore=mask)
for step, (im, im_info, label) in enumerate(eval_dataset):
im = to_variable(im)
pred, _ = model(im, mode='eval')
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype('int64')
mask = label != ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou()
time_step = timer.elapsed_time()
......@@ -153,16 +151,17 @@ def main(args):
eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
eval_dataset = dataset(transforms=eval_transforms, mode='eval')
if args.model_name == 'UNet':
model = models.UNet(num_classes=eval_dataset.num_classes)
if args.model_name not in MODELS:
raise Exception(
'--model_name is invalid. it should be one of {}'.format(
str(list(MODELS.keys()))))
model = MODELS[args.model_name](num_classes=eval_dataset.num_classes)
evaluate(
model,
eval_dataset,
places=places,
model_dir=args.model_dir,
num_classes=eval_dataset.num_classes,
batch_size=args.batch_size)
num_classes=eval_dataset.num_classes)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册