Unverified · Commit 27919efd · authored by wuzewu, committed by GitHub

Merge pull request #318 from wuyefeilin/dygraph

# Dynamic Graph Execution

Datasets are now selected with the `--dataset` flag (e.g. `OpticDiscSeg`); the former "Dataset settings" block of `data_dir`/`train_list`/`val_list`/`test_list`/`num_classes` variables is removed.

## Training
```
python3 train.py --model_name UNet \
--dataset OpticDiscSeg \
--input_size 192 192 \
--num_epochs 10 \
--save_interval_epochs 1 \
--do_eval \
--save_dir output
```

## Evaluation
```
python3 val.py --model_name UNet \
--dataset OpticDiscSeg \
--input_size 192 192 \
--model_dir output/best_model
```

## Prediction
```
python3 infer.py --model_name UNet \
--dataset OpticDiscSeg \
--model_dir output/best_model \
--input_size 192 192
```
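The `--dataset` flag picks one of the built-in dataset classes. The exact lookup code is not part of this diff, so the registry below is a hypothetical sketch of the mapping it implies; the class names come from the imports in train.py, val.py, and infer.py.

```
import transforms as T
from datasets import OpticDiscSeg, Cityscapes

# Hypothetical registry mirroring what main() does with args.dataset.
DATASETS = {'OpticDiscSeg': OpticDiscSeg, 'Cityscapes': Cityscapes}
test_transforms = T.Compose([T.Resize([192, 192]), T.Normalize()])
test_dataset = DATASETS['OpticDiscSeg'](transforms=test_transforms, mode='test')
```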
@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
@@ -14,8 +14,7 @@
import os

from .dataset import Dataset
from utils.download import download_file_and_uncompress

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
@@ -70,16 +69,3 @@ class Cityscapes(Dataset):
            image_path = os.path.join(self.data_dir, items[0])
            grt_path = os.path.join(self.data_dir, items[1])
            self.file_list.append([image_path, grt_path])
# (removed: the per-dataset __getitem__ and __len__ methods; they are now
# inherited from the shared Dataset base class below)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import paddle.fluid as fluid
import numpy as np
from PIL import Image


class Dataset(fluid.io.Dataset):
    def __init__(self,
                 data_dir,
                 num_classes,
                 train_list=None,
                 val_list=None,
                 test_list=None,
                 separator=' ',
                 transforms=None,
                 mode='train'):
        mode = mode.lower()
        self.data_dir = data_dir
        self.transforms = transforms
        self.file_list = list()
        self.mode = mode
        self.num_classes = num_classes

        if mode not in ['train', 'eval', 'test']:
            raise Exception(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise Exception("transforms is necessary, but it is None.")

        if mode == 'train':
            if train_list is None:
                raise Exception(
                    'When mode is "train", train_list is needed, but it is None.')
            elif not os.path.exists(train_list):
                raise Exception(
                    'train_list is not found: {}'.format(train_list))
            else:
                file_list = train_list
        elif mode == 'eval':
            if val_list is None:
                raise Exception(
                    'When mode is "eval", val_list is needed, but it is None.')
            elif not os.path.exists(val_list):
                raise Exception('val_list is not found: {}'.format(val_list))
            else:
                file_list = val_list
        else:
            if test_list is None:
                raise Exception(
                    'When mode is "test", test_list is needed, but it is None.')
            elif not os.path.exists(test_list):
                raise Exception('test_list is not found: {}'.format(test_list))
            else:
                file_list = test_list

        with open(file_list, 'r') as f:
            for line in f:
                items = line.strip().split(separator)
                if len(items) != 2:
                    if mode == 'train' or mode == 'eval':
                        raise Exception(
                            "File list format incorrect! It should be"
                            " image_name{}label_name\\n".format(separator))
                    # test lists may contain images only
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = None
                else:
                    image_path = os.path.join(self.data_dir, items[0])
                    grt_path = os.path.join(self.data_dir, items[1])
                self.file_list.append([image_path, grt_path])

    def __getitem__(self, idx):
        image_path, grt_path = self.file_list[idx]
        if self.mode == 'train':
            im, im_info, label = self.transforms(im=image_path, label=grt_path)
            return im, label
        elif self.mode == 'eval':
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            label = np.asarray(Image.open(grt_path))
            label = label[np.newaxis, np.newaxis, :, :]
            return im, im_info, label
        else:  # 'test'
            im, im_info, _ = self.transforms(im=image_path)
            im = im[np.newaxis, ...]
            return im, im_info, image_path

    def __len__(self):
        return len(self.file_list)
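A minimal usage sketch of this base class with a custom file list; the paths and the two-line list are invented, and the built-in OpticDiscSeg and Cityscapes classes layer downloading and defaults on top of this same constructor.

```
import transforms as T
from datasets import Dataset

# data/custom/train_list.txt (invented), one "image<sep>label" pair per line:
#   images/0001.jpg annotations/0001.png
#   images/0002.jpg annotations/0002.png
train_transforms = T.Compose([T.Resize([192, 192]), T.Normalize()])
train_dataset = Dataset(
    data_dir='data/custom',
    num_classes=2,
    train_list='data/custom/train_list.txt',
    transforms=train_transforms,
    mode='train')

im, label = train_dataset[0]  # __getitem__ runs the transforms
print(len(train_dataset))     # number of listed samples
```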
@@ -14,8 +14,7 @@
import os

from .dataset import Dataset
from utils.download import download_file_and_uncompress

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
@@ -70,16 +69,3 @@ class OpticDiscSeg(Dataset):
            image_path = os.path.join(self.data_dir, items[0])
            grt_path = os.path.join(self.data_dir, items[1])
            self.file_list.append([image_path, grt_path])
# (removed: __getitem__ and __len__, now inherited from the shared Dataset base class)
@@ -24,7 +24,7 @@ import tqdm
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
from models import MODELS
import utils
import utils.logging as logging
from utils import get_environ_info
@@ -37,7 +37,8 @@ def parse_args():
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help='Model type for testing, which is one of {}'.format(
            str(list(MODELS.keys()))),
        type=str,
        default='UNet')
@@ -97,19 +98,20 @@ def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
    logging.info("Start to predict...")
    for im, im_info, im_path in tqdm.tqdm(test_dataset):
        # the dataset already adds the batch axis, so no np.newaxis here
        im = to_variable(im)
        pred, _ = model(im, mode='test')
        pred = pred.numpy()
        pred = np.squeeze(pred).astype('uint8')
        # undo the recorded transforms in reverse order to restore the
        # prediction to the original image size
        for info in im_info[::-1]:
            if info[0] == 'resize':
                h, w = info[1][0], info[1][1]
                pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
            elif info[0] == 'padding':
                h, w = info[1][0], info[1][1]
                pred = pred[0:h, 0:w]
            else:
                raise Exception("Unexpected info '{}' in im_info".format(
                    info[0]))

        im_file = im_path.replace(test_dataset.data_dir, '')
        if im_file[0] == '/':
@@ -146,8 +148,11 @@ def main(args):
    test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
    test_dataset = dataset(transforms=test_transforms, mode='test')

    if args.model_name not in MODELS:
        raise Exception(
            '--model_name is invalid. It should be one of {}'.format(
                str(list(MODELS.keys()))))
    model = MODELS[args.model_name](num_classes=test_dataset.num_classes)

    infer(
        model,
...
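The restoration loop above assumes `im_info` is an ordered list of `(op_name, shape_before_op)` records appended by the transforms as they run; replaying it backwards returns the prediction to the source resolution. A toy sketch of that contract, with all shapes invented:

```
import numpy as np
import cv2

# The original 600x800 image was resized to 192x192, then padded to 200x200;
# each step recorded the shape it started from.
im_info = [('resize', (600, 800)), ('padding', (192, 192))]
pred = np.zeros((200, 200), dtype='uint8')  # network output at padded size

for info in im_info[::-1]:
    if info[0] == 'padding':   # first crop the padding away
        h, w = info[1]
        pred = pred[0:h, 0:w]
    elif info[0] == 'resize':  # then scale back to the original size
        h, w = info[1]
        pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)

print(pred.shape)  # (600, 800)
```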
@@ -13,3 +13,28 @@
# limitations under the License.

from .unet import UNet
from .hrnet import *

MODELS = {
    "UNet": UNet,
    "HRNet_W18_Small_V1": HRNet_W18_Small_V1,
    "HRNet_W18_Small_V2": HRNet_W18_Small_V2,
    "HRNet_W18": HRNet_W18,
    "HRNet_W30": HRNet_W30,
    "HRNet_W32": HRNet_W32,
    "HRNet_W40": HRNet_W40,
    "HRNet_W44": HRNet_W44,
    "HRNet_W48": HRNet_W48,
    "HRNet_W60": HRNet_W60,
    "HRNet_W64": HRNet_W64,
    "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1,
    "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2,
    "SE_HRNet_W18": SE_HRNet_W18,
    "SE_HRNet_W30": SE_HRNet_W30,
    "SE_HRNet_W32": SE_HRNet_W32,
    "SE_HRNet_W40": SE_HRNet_W40,
    "SE_HRNet_W44": SE_HRNet_W44,
    "SE_HRNet_W48": SE_HRNet_W48,
    "SE_HRNet_W60": SE_HRNet_W60,
    "SE_HRNet_W64": SE_HRNet_W64
}
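For reference, a minimal sketch of how this registry is consumed, mirroring the lookup that train.py, val.py, and infer.py perform (the num_classes value is arbitrary):

```
from models import MODELS

model_name = 'HRNet_W18'
if model_name not in MODELS:
    raise Exception('model_name should be one of {}'.format(
        list(MODELS.keys())))
# every registry entry is a class taking num_classes
model = MODELS[model_name](num_classes=2)
```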
This diff is collapsed.
@@ -13,7 +13,11 @@
# limitations under the License.

import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Pool2D
try:
    # prefer SyncBatchNorm when this Paddle build provides it
    from paddle.fluid.dygraph import SyncBatchNorm as BatchNorm
except ImportError:
    from paddle.fluid.dygraph import BatchNorm


class UNet(fluid.dygraph.Layer):
@@ -39,6 +43,8 @@ class UNet(fluid.dygraph.Layer):
        return pred, score_map

    def _get_loss(self, logit, label):
        # move the class axis last so the cross entropy is computed per pixel
        logit = fluid.layers.transpose(logit, [0, 2, 3, 1])
        label = fluid.layers.transpose(label, [0, 2, 3, 1])
        mask = label != self.ignore_index
        mask = fluid.layers.cast(mask, 'float32')
        loss, probs = fluid.layers.softmax_with_cross_entropy(
...
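The two added transposes exist so that `softmax_with_cross_entropy`, which normalizes over the last axis, scores each pixel as its own classification. A shape-only sketch in NumPy, with sizes invented:

```
import numpy as np

logit = np.zeros((4, 19, 192, 192))       # NCHW: batch, classes, height, width
label = np.zeros((4, 1, 192, 192))        # one class id per pixel
logit_nhwc = logit.transpose(0, 2, 3, 1)  # (4, 192, 192, 19): classes last
label_nhwc = label.transpose(0, 2, 3, 1)  # (4, 192, 192, 1)
# the cross entropy now softmaxes over axis -1, i.e. independently per pixel
```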
@@ -22,7 +22,7 @@ from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
@@ -38,7 +38,8 @@ def parse_args():
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help='Model type for training, which is one of {}'.format(
            str(list(MODELS.keys()))),
        type=str,
        default='UNet')
@@ -181,7 +182,7 @@ def train(model,
    total_steps = steps_per_epoch * (num_epochs - start_epoch)
    num_steps = 0
    best_mean_iou = -1.0
    best_model_epoch = -1
    for epoch in range(start_epoch, num_epochs):
        for step, data in enumerate(loader):
            images = data[0]
@@ -229,10 +230,8 @@ def train(model,
        mean_iou, mean_acc = evaluate(
            model,
            eval_dataset,
            model_dir=current_save_dir,
            num_classes=num_classes,
            ignore_index=ignore_index,
            epoch_id=epoch + 1)
        if mean_iou > best_mean_iou:
@@ -241,9 +240,9 @@ def train(model,
            best_model_dir = os.path.join(save_dir, "best_model")
            fluid.save_dygraph(model.state_dict(),
                               os.path.join(best_model_dir, 'model'))
        logging.info(
            'Current evaluated best model in eval_dataset is epoch_{}, miou={:.4f}'
            .format(best_model_epoch, best_mean_iou))

        if use_vdl:
            log_writer.add_scalar('Evaluate/mean_iou', mean_iou,
@@ -286,9 +285,11 @@ def main(args):
                         T.Normalize()])
    eval_dataset = dataset(transforms=eval_transforms, mode='eval')

    if args.model_name not in MODELS:
        raise Exception(
            '--model_name is invalid. It should be one of {}'.format(
                str(list(MODELS.keys()))))
    model = MODELS[args.model_name](num_classes=train_dataset.num_classes)

    # Create optimizer
    # TODO: may be one less than len(loader)
...
This diff is collapsed.
@@ -52,7 +52,11 @@ def load_pretrained_model(model, pretrained_model):
    logging.info('Load pretrained model from {}'.format(pretrained_model))
    if os.path.exists(pretrained_model):
        ckpt_path = os.path.join(pretrained_model, 'model')
        try:
            para_state_dict, _ = fluid.load_dygraph(ckpt_path)
        except Exception:
            # fall back to weights saved by the static-graph API
            para_state_dict = fluid.load_program_state(pretrained_model)

        model_state_dict = model.state_dict()
        keys = model_state_dict.keys()
        num_params_loaded = 0
...
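As a usage sketch, the fallback lets the same helper consume either checkpoint flavor; the directory name below is hypothetical.

```
from models import MODELS
from utils import load_pretrained_model

model = MODELS['UNet'](num_classes=2)
# works for a dygraph checkpoint saved via fluid.save_dygraph, or a
# static-graph parameter directory readable by fluid.load_program_state
load_pretrained_model(model, 'pretrained_model/unet_optic')
```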
@@ -16,8 +16,10 @@ import argparse
import os
import math

import numpy as np
import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
@@ -25,7 +27,7 @@ from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
from models import MODELS
import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
@@ -39,7 +41,8 @@ def parse_args():
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help='Model type for evaluation, which is one of {}'.format(
            str(list(MODELS.keys()))),
        type=str,
        default='UNet')
@@ -60,12 +63,6 @@ def parse_args():
        nargs=2,
        default=[512, 512],
        type=int)
    # (removed: the --batch_size argument; evaluation now runs sample by sample)
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
@@ -78,10 +75,8 @@ def parse_args():

def evaluate(model,
             eval_dataset=None,
             model_dir=None,
             num_classes=None,
             ignore_index=255,
             epoch_id=None):
    ckpt_path = os.path.join(model_dir, 'model')
@@ -89,15 +84,7 @@ def evaluate(model,
    model.set_dict(para_state_dict)
    model.eval()

    total_steps = len(eval_dataset)
    conf_mat = ConfusionMatrix(num_classes, streaming=True)
    logging.info(
@@ -105,15 +92,26 @@ def evaluate(model,
        len(eval_dataset), total_steps))
    timer = Timer()
    timer.start()
    for step, (im, im_info, label) in enumerate(eval_dataset):
        im = to_variable(im)
        pred, _ = model(im, mode='eval')
        pred = pred.numpy().astype('float32')
        pred = np.squeeze(pred)
        # undo the recorded transforms in reverse order so the prediction
        # is compared against the full-resolution ground truth
        for info in im_info[::-1]:
            if info[0] == 'resize':
                h, w = info[1][0], info[1][1]
                pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
            elif info[0] == 'padding':
                h, w = info[1][0], info[1][1]
                pred = pred[0:h, 0:w]
            else:
                raise Exception("Unexpected info '{}' in im_info".format(
                    info[0]))
        pred = pred[np.newaxis, :, :, np.newaxis]
        pred = pred.astype('int64')
        mask = label != ignore_index
        conf_mat.calculate(pred=pred, label=label, ignore=mask)
        _, iou = conf_mat.mean_iou()

        time_step = timer.elapsed_time()
@@ -153,16 +151,17 @@ def main(args):
    eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
    eval_dataset = dataset(transforms=eval_transforms, mode='eval')

    if args.model_name not in MODELS:
        raise Exception(
            '--model_name is invalid. It should be one of {}'.format(
                str(list(MODELS.keys()))))
    model = MODELS[args.model_name](num_classes=eval_dataset.num_classes)

    evaluate(
        model,
        eval_dataset,
        model_dir=args.model_dir,
        num_classes=eval_dataset.num_classes)


if __name__ == '__main__':
...
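Finally, a small sketch of the streaming ConfusionMatrix contract that evaluate() relies on; only the calls visible above are used, and the array shapes mirror the reshaping done in the loop.

```
import numpy as np
from utils import ConfusionMatrix

conf_mat = ConfusionMatrix(num_classes=2, streaming=True)  # accumulates across batches
pred = np.zeros((1, 4, 4, 1), dtype='int64')   # as reshaped in evaluate(): N, H, W, 1
label = np.zeros((1, 1, 4, 4), dtype='int64')  # as yielded by the eval dataset
mask = label != 255                            # exclude ignore_index pixels
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, miou = conf_mat.mean_iou()                  # second element is the mean IoU
print(miou)
```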