未验证 提交 1c06b49d 编写于 作者: W wuzewu 提交者: GitHub

Merge pull request #277 from wuyefeilin/dygraph

add dygraph
# 动态图执行
## 数据集设置
```
data_dir='data/path'
train_list='train/list/path'
val_list='val/list/path'
test_list='test/list/path'
num_classes=number/of/dataset/classes
```
## 训练
```
python3 train.py --model_name UNet \
--data_dir $data_dir \
--train_list $train_list \
--val_list $val_list \
--num_classes $num_classes \
--input_size 192 192 \
--num_epochs 4 \
--save_interval_epochs 1 \
--save_dir output
```
## 评估
```
python3 val.py --model_name UNet \
--data_dir $data_dir \
--val_list $val_list \
--num_classes $num_classes \
--input_size 192 192 \
--model_dir output/epoch_1
```
## 预测
```
python3 infer.py --model_name UNet \
--data_dir $data_dir \
--test_list $test_list \
--num_classes $num_classes \
--input_size 192 192 \
--model_dir output/epoch_1
```
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.io import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/cityscapes.tar"


class Cityscapes(Dataset):
    """Cityscapes dataset driven by ``train.list`` / ``val.list`` / ``test.list``.

    Each line of a list file holds an image path and, for train/eval, a label
    path, both relative to ``data_dir`` and separated by whitespace.

    Args:
        data_dir: Dataset root. When None the archive at ``URL`` is fetched
            into ``DATA_HOME`` (unless ``download`` is False).
        transforms: Callable invoked as ``transforms(im=..., label=...)`` and
            expected to return ``(im, im_info, label)``. Required.
        mode: One of 'train', 'eval' or 'test' (case-insensitive).
        download: Whether a missing ``data_dir`` may be downloaded.

    Raises:
        Exception: On an invalid mode, missing transforms, missing data with
            download disabled, or a malformed train/eval list line.
    """

    def __init__(self,
                 data_dir=None,
                 transforms=None,
                 mode='train',
                 download=True):
        self.transforms = transforms
        self.file_list = list()
        # Normalise once: validation was case-insensitive while every later
        # comparison was not, so e.g. 'Train' silently behaved like 'test'.
        self.mode = mode.lower()
        self.num_classes = 19

        if self.mode not in ['train', 'eval', 'test']:
            raise Exception(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise Exception("transform is necessary, but it is None.")

        self.data_dir = data_dir
        if self.data_dir is None:
            if not download:
                raise Exception("data_file not set and auto download disabled.")
            self.data_dir = download_file_and_uncompress(
                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)

        if self.mode == 'train':
            file_list = os.path.join(self.data_dir, 'train.list')
        elif self.mode == 'eval':
            file_list = os.path.join(self.data_dir, 'val.list')
        else:
            file_list = os.path.join(self.data_dir, 'test.list')

        with open(file_list, 'r') as f:
            for line in f:
                items = line.strip().split()
                if not items:
                    # Tolerate blank lines such as a trailing newline.
                    continue
                image_path = os.path.join(self.data_dir, items[0])
                if len(items) == 2:
                    grt_path = os.path.join(self.data_dir, items[1])
                elif self.mode == 'test':
                    # Test lists may omit the label column.
                    grt_path = None
                else:
                    raise Exception(
                        "File list format incorrect! It should be"
                        " image_name label_name\\n")
                self.file_list.append([image_path, grt_path])

    def __getitem__(self, idx):
        """Return ``(im, label)`` for train/eval, ``(im, im_info, image_path)`` for test."""
        image_path, grt_path = self.file_list[idx]
        im, im_info, label = self.transforms(im=image_path, label=grt_path)
        if self.mode == 'test':
            return im, im_info, image_path
        return im, label

    def __len__(self):
        """Return the number of samples listed for the selected mode."""
        return len(self.file_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.io import Dataset
from utils.download import download_file_and_uncompress
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"


class OpticDiscSeg(Dataset):
    """Optic-disc segmentation dataset driven by ``*_list.txt`` files.

    Each line of a list file holds an image path and, for train/eval, a label
    path, both relative to ``data_dir`` and separated by whitespace.

    Args:
        data_dir: Dataset root. When None the archive at ``URL`` is fetched
            into ``DATA_HOME`` (unless ``download`` is False).
        transforms: Callable invoked as ``transforms(im=..., label=...)`` and
            expected to return ``(im, im_info, label)``. Required.
        mode: One of 'train', 'eval' or 'test' (case-insensitive).
        download: Whether a missing ``data_dir`` may be downloaded.

    Raises:
        Exception: On an invalid mode, missing transforms, missing data with
            download disabled, or a malformed train/eval list line.
    """

    def __init__(self,
                 data_dir=None,
                 transforms=None,
                 mode='train',
                 download=True):
        self.transforms = transforms
        self.file_list = list()
        # Normalise once: validation was case-insensitive while every later
        # comparison was not, so e.g. 'Train' silently behaved like 'test'.
        self.mode = mode.lower()
        self.num_classes = 2

        if self.mode not in ['train', 'eval', 'test']:
            raise Exception(
                "mode should be 'train', 'eval' or 'test', but got {}.".format(
                    mode))

        if self.transforms is None:
            raise Exception("transform is necessary, but it is None.")

        self.data_dir = data_dir
        if self.data_dir is None:
            if not download:
                raise Exception("data_file not set and auto download disabled.")
            self.data_dir = download_file_and_uncompress(
                url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)

        if self.mode == 'train':
            file_list = os.path.join(self.data_dir, 'train_list.txt')
        elif self.mode == 'eval':
            file_list = os.path.join(self.data_dir, 'val_list.txt')
        else:
            file_list = os.path.join(self.data_dir, 'test_list.txt')

        with open(file_list, 'r') as f:
            for line in f:
                items = line.strip().split()
                if not items:
                    # Tolerate blank lines such as a trailing newline.
                    continue
                image_path = os.path.join(self.data_dir, items[0])
                if len(items) == 2:
                    grt_path = os.path.join(self.data_dir, items[1])
                elif self.mode == 'test':
                    # Test lists may omit the label column.
                    grt_path = None
                else:
                    raise Exception(
                        "File list format incorrect! It should be"
                        " image_name label_name\\n")
                self.file_list.append([image_path, grt_path])

    def __getitem__(self, idx):
        """Return ``(im, label)`` for train/eval, ``(im, im_info, image_path)`` for test."""
        image_path, grt_path = self.file_list[idx]
        im, im_info, label = self.transforms(im=image_path, label=grt_path)
        if self.mode == 'test':
            return im, im_info, image_path
        return im, label

    def __len__(self):
        """Return the number of samples listed for the selected mode."""
        return len(self.file_list)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
import cv2
import tqdm
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils
import utils.logging as logging
from utils import get_environ_info
def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``model_name``, ``dataset``, ``input_size``,
        ``batch_size``, ``model_dir`` and ``save_dir``.
    """
    # The description and help texts previously described training.
    parser = argparse.ArgumentParser(description='Model inference')

    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for inference, which is one of ('UNet')",
        type=str,
        default='UNet')

    # params of dataset
    parser.add_argument(
        '--dataset',
        dest='dataset',
        help=
        "The dataset you want to predict on, which is one of ('OpticDiscSeg', 'Cityscapes')",
        type=str,
        default='OpticDiscSeg')

    # params of prediction
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=2)
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='The path of the model used for inference',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output/result')

    return parser.parse_args()
def mkdir(path):
    """Ensure the parent directory of the file ``path`` exists.

    Args:
        path: Path of a file that is about to be written.
    """
    sub_dir = os.path.dirname(path)
    # A bare file name has no directory part; os.makedirs('') would raise.
    if sub_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(sub_dir, exist_ok=True)
def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
    """Predict every sample of ``test_dataset`` and save visualisations.

    Loads the checkpoint ``<model_dir>/model``, runs the model in test mode
    and writes, for each image, a blended overlay under ``<save_dir>/added``
    and the plain prediction under ``<save_dir>/prediction``, mirroring the
    image's path relative to ``test_dataset.data_dir``.

    Args:
        model: Network called as ``model(im, mode='test')``.
        test_dataset: Iterable yielding ``(im, im_info, im_path)`` and
            exposing ``data_dir``.
        model_dir: Directory that holds the saved ``model`` checkpoint.
        save_dir: Root directory for the output images.
    """
    ckpt_path = os.path.join(model_dir, 'model')
    # The optimizer state is not needed for inference.
    para_state_dict, _ = fluid.load_dygraph(ckpt_path)
    model.set_dict(para_state_dict)
    model.eval()

    added_saved_dir = os.path.join(save_dir, 'added')
    pred_saved_dir = os.path.join(save_dir, 'prediction')

    logging.info("Start to predict...")
    for im, im_info, im_path in tqdm.tqdm(test_dataset):
        im = im[np.newaxis, ...]
        im = to_variable(im)
        pred, _ = model(im, mode='test')
        pred = pred.numpy()
        pred = np.squeeze(pred).astype('uint8')

        # Undo the recorded preprocessing steps in reverse order.
        keys = list(im_info.keys())
        for k in keys[::-1]:
            if k == 'shape_before_resize':
                h, w = im_info[k][0], im_info[k][1]
                # The flag must be passed by keyword: the third positional
                # argument of cv2.resize is `dst`, so the flag was ignored
                # and class ids were blended by the default interpolation.
                pred = cv2.resize(
                    pred, (w, h), interpolation=cv2.INTER_NEAREST)
            elif k == 'shape_before_padding':
                h, w = im_info[k][0], im_info[k][1]
                pred = pred[0:h, 0:w]

        im_file = im_path.replace(test_dataset.data_dir, '')
        if im_file[0] == '/':
            im_file = im_file[1:]

        # save added image
        added_image = utils.visualize(im_path, pred, weight=0.6)
        added_image_path = os.path.join(added_saved_dir, im_file)
        mkdir(added_image_path)
        cv2.imwrite(added_image_path, added_image)

        # save prediction
        pred_im = utils.visualize(im_path, pred, weight=0.0)
        pred_saved_path = os.path.join(pred_saved_dir, im_file)
        mkdir(pred_saved_path)
        cv2.imwrite(pred_saved_path, pred_im)
def main(args):
    """Build the test dataset and model from ``args`` and run inference.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        Exception: If ``--dataset`` or ``--model_name`` is not supported.
    """
    env_info = get_environ_info()
    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
        else fluid.CPUPlace()

    if args.dataset.lower() == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif args.dataset.lower() == 'cityscapes':
        dataset = Cityscapes
    else:
        raise Exception(
            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

    # Reject unknown models up front; previously `model` was left undefined
    # and the script died later with a NameError.
    if args.model_name != 'UNet':
        raise Exception(
            "The --model_name set wrong. It should be one of ('UNet')")

    with fluid.dygraph.guard(places):
        test_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
        test_dataset = dataset(transforms=test_transforms, mode='test')

        model = models.UNet(num_classes=test_dataset.num_classes)

        infer(
            model,
            model_dir=args.model_dir,
            test_dataset=test_dataset,
            save_dir=args.save_dir)


if __name__ == '__main__':
    args = parse_args()
    main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .unet import UNet
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, BatchNorm, Pool2D
class UNet(fluid.dygraph.Layer):
    """Encoder/decoder segmentation network with a final logit convolution.

    Args:
        num_classes: Number of output channels of the logit layer.
        ignore_index: Label value excluded from the training loss.
    """

    def __init__(self, num_classes, ignore_index=255):
        super().__init__()
        self.encode = UnetEncoder()
        self.decode = UnetDecode()
        self.get_logit = GetLogit(64, num_classes)
        self.ignore_index = ignore_index
        self.EPS = 1e-5

    def forward(self, x, label=None, mode='train'):
        """Return the loss in 'train' mode, else ``(pred, score_map)``."""
        features, skips = self.encode(x)
        logit = self.get_logit(self.decode(features, skips))

        if mode == 'train':
            return self._get_loss(logit, label)

        # Softmax over axis 1, then move that axis last before the argmax.
        score_map = fluid.layers.transpose(
            fluid.layers.softmax(logit, axis=1), [0, 2, 3, 1])
        pred = fluid.layers.unsqueeze(
            fluid.layers.argmax(score_map, axis=3), axes=[3])
        return pred, score_map

    def _get_loss(self, logit, label):
        """Cross entropy averaged over positions whose label is not ignored."""
        mask = fluid.layers.cast(label != self.ignore_index, 'float32')
        loss, _ = fluid.layers.softmax_with_cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            return_softmax=True,
            axis=1)
        masked_loss = loss * mask
        # mean(loss) / mean(mask) equals the mean over kept positions;
        # EPS guards the case where every position is ignored.
        avg_loss = fluid.layers.mean(masked_loss) / (
            fluid.layers.mean(mask) + self.EPS)

        label.stop_gradient = True
        mask.stop_gradient = True
        return avg_loss
class UnetEncoder(fluid.dygraph.Layer):
    """Contracting path: one DoubleConv followed by four Down stages."""

    def __init__(self):
        super().__init__()
        self.double_conv = DoubleConv(3, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

    def forward(self, x):
        """Return the deepest feature map and the four skip tensors."""
        x = self.double_conv(x)
        short_cuts = [x]
        # The first three stages also feed the decoder; the last does not.
        for stage in (self.down1, self.down2, self.down3):
            x = stage(x)
            short_cuts.append(x)
        return self.down4(x), short_cuts
class UnetDecode(fluid.dygraph.Layer):
    """Expanding path: four Up stages consuming the skips deepest-first."""

    def __init__(self):
        super().__init__()
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.up4 = Up(64, 64)

    def forward(self, x, short_cuts):
        """Fuse ``x`` with ``short_cuts[3]`` down to ``short_cuts[0]``."""
        stages = (self.up1, self.up2, self.up3, self.up4)
        for stage, skip_idx in zip(stages, (3, 2, 1, 0)):
            x = stage(x, short_cuts[skip_idx])
        return x
class DoubleConv(fluid.dygraph.Layer):
    """Two consecutive 3x3 conv -> batch norm -> ReLU units.

    Args:
        num_channels: Input channel count of the first convolution.
        num_filters: Output channel count of both convolutions.
    """

    def __init__(self, num_channels, num_filters):
        super().__init__()
        conv_kwargs = dict(
            num_filters=num_filters, filter_size=3, stride=1, padding=1)
        self.conv0 = Conv2D(num_channels=num_channels, **conv_kwargs)
        self.bn0 = BatchNorm(num_channels=num_filters)
        self.conv1 = Conv2D(num_channels=num_filters, **conv_kwargs)
        self.bn1 = BatchNorm(num_channels=num_filters)

    def forward(self, x):
        """Apply both conv/bn/ReLU units in sequence."""
        for conv, bn in ((self.conv0, self.bn0), (self.conv1, self.bn1)):
            x = fluid.layers.relu(bn(conv(x)))
        return x
class Down(fluid.dygraph.Layer):
    """Downsampling stage: 2x2 max pool with stride 2, then a DoubleConv."""

    def __init__(self, num_channels, num_filters):
        super().__init__()
        self.max_pool = Pool2D(
            pool_size=2, pool_type='max', pool_stride=2, pool_padding=0)
        self.double_conv = DoubleConv(num_channels, num_filters)

    def forward(self, x):
        """Pool ``x`` and pass the result through the convolutions."""
        return self.double_conv(self.max_pool(x))
class Up(fluid.dygraph.Layer):
    """Upsampling stage: resize to the skip's size, concatenate, DoubleConv.

    The DoubleConv takes ``2 * num_channels`` inputs because the upsampled
    tensor and the skip tensor are concatenated along axis 1.
    """

    def __init__(self, num_channels, num_filters):
        super().__init__()
        self.double_conv = DoubleConv(2 * num_channels, num_filters)

    def forward(self, x, short_cut):
        """Bilinearly resize ``x`` to ``short_cut``'s trailing dims and fuse."""
        target_size = fluid.layers.shape(short_cut)[2:]
        upsampled = fluid.layers.resize_bilinear(x, target_size)
        fused = fluid.layers.concat([upsampled, short_cut], axis=1)
        return self.double_conv(fused)
class GetLogit(fluid.dygraph.Layer):
    """Final 3x3 convolution mapping features to ``num_classes`` channels."""

    def __init__(self, num_channels, num_classes):
        super().__init__()
        self.conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_classes,
            filter_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        """Return the per-class logits for ``x``."""
        return self.conv(x)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.incubate.hapi.distributed import DistributedBatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import load_pretrained_model
from val import evaluate
def parse_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with the model, dataset and training options
        defined below.
    """
    parser = argparse.ArgumentParser(description='Model training')

    # params of model
    parser.add_argument(
        '--model_name',
        dest='model_name',
        help="Model type for training, which is one of ('UNet')",
        type=str,
        default='UNet')

    # params of dataset
    parser.add_argument(
        '--dataset',
        dest='dataset',
        help=
        "The dataset you want to train, which is one of ('OpticDiscSeg', 'Cityscapes')",
        type=str,
        default='OpticDiscSeg')

    # params of training
    parser.add_argument(
        "--input_size",
        dest="input_size",
        help="The image size for net inputs.",
        nargs=2,
        default=[512, 512],
        type=int)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='Number of epochs for training',
        type=int,
        default=100)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size of one gpu or cpu',
        type=int,
        default=2)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)
    parser.add_argument(
        '--pretrained_model',
        dest='pretrained_model',
        help='The path of pretrained weight',
        type=str,
        default=None)
    parser.add_argument(
        '--save_interval_epochs',
        dest='save_interval_epochs',
        help='The interval in epochs between model snapshots',
        type=int,
        default=5)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    parser.add_argument(
        '--num_workers',
        dest='num_workers',
        help='Num workers for data loader',
        type=int,
        default=0)
    parser.add_argument(
        '--do_eval',
        dest='do_eval',
        help='Eval while training',
        action='store_true')

    return parser.parse_args()
def train(model,
          train_dataset,
          places=None,
          eval_dataset=None,
          optimizer=None,
          save_dir='output',
          num_epochs=100,
          batch_size=2,
          pretrained_model=None,
          save_interval_epochs=1,
          num_classes=None,
          num_workers=8):
    """Train ``model`` on ``train_dataset`` and save periodic snapshots.

    A snapshot is written to ``<save_dir>/epoch_<n>`` every
    ``save_interval_epochs`` epochs and after the final epoch, on local
    rank 0 only. When ``eval_dataset`` is given, each snapshot is evaluated
    right after it is saved.

    Args:
        model: Network called as ``model(images, labels, mode='train')``;
            must expose ``ignore_index``.
        train_dataset: Dataset fed to the data loader.
        places: Device place(s) passed to the loader and to ``evaluate``.
        eval_dataset: Optional dataset evaluated after each snapshot.
        optimizer: Optimizer whose ``minimize`` is called on every step.
        save_dir: Root directory for snapshots.
        num_epochs: Number of passes over ``train_dataset``.
        batch_size: Per-device mini-batch size.
        pretrained_model: Passed to ``load_pretrained_model``.
        save_interval_epochs: Snapshot period in epochs.
        num_classes: Forwarded to ``evaluate``.
        num_workers: Number of data-loader workers.
    """
    ignore_index = model.ignore_index
    nranks = ParallelEnv().nranks

    load_pretrained_model(model, pretrained_model)

    if not os.path.isdir(save_dir):
        # A regular file squatting on the name is replaced by a directory.
        if os.path.exists(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir)

    if nranks > 1:
        strategy = fluid.dygraph.prepare_context()
        model_parallel = fluid.dygraph.DataParallel(model, strategy)

    batch_sampler = DistributedBatchSampler(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    loader = DataLoader(
        train_dataset,
        batch_sampler=batch_sampler,
        places=places,
        num_workers=num_workers,
        return_list=True,
    )

    for epoch in range(num_epochs):
        for step, data in enumerate(loader):
            images = data[0]
            labels = data[1].astype('int64')
            if nranks > 1:
                loss = model_parallel(images, labels, mode='train')
                loss = model_parallel.scale_loss(loss)
                loss.backward()
                model_parallel.apply_collective_grads()
            else:
                loss = model(images, labels, mode='train')
                loss.backward()
            optimizer.minimize(loss)
            model.clear_gradients()
            logging.info("[TRAIN] Epoch={}/{}, Step={}/{}, loss={}".format(
                epoch + 1, num_epochs, step + 1, len(batch_sampler),
                loss.numpy()))

        # Always snapshot the final epoch. The old test compared the number
        # of steps per epoch with num_epochs - 1, so it fired by accident
        # or not at all.
        is_save_epoch = ((epoch + 1) % save_interval_epochs == 0
                         or epoch == num_epochs - 1)
        if is_save_epoch and ParallelEnv().local_rank == 0:
            current_save_dir = os.path.join(save_dir,
                                            "epoch_{}".format(epoch + 1))
            if not os.path.isdir(current_save_dir):
                os.makedirs(current_save_dir)
            fluid.save_dygraph(model.state_dict(),
                               os.path.join(current_save_dir, 'model'))

            if eval_dataset is not None:
                evaluate(
                    model,
                    eval_dataset,
                    places=places,
                    model_dir=current_save_dir,
                    num_classes=num_classes,
                    batch_size=batch_size,
                    ignore_index=ignore_index,
                    epoch_id=epoch + 1)
                # Switch back to training mode after evaluation.
                model.train()
def main(args):
    """Build datasets, model and optimizer from ``args`` and start training.

    Args:
        args: Namespace produced by ``parse_args``.

    Raises:
        Exception: If ``--dataset`` or ``--model_name`` is not supported.
    """
    env_info = get_environ_info()
    places = fluid.CUDAPlace(ParallelEnv().dev_id) \
        if env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda() \
        else fluid.CPUPlace()

    if args.dataset.lower() == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif args.dataset.lower() == 'cityscapes':
        dataset = Cityscapes
    else:
        raise Exception(
            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

    # Reject unknown models up front; previously `model` was left undefined
    # and the script died later with a NameError.
    if args.model_name != 'UNet':
        raise Exception(
            "The --model_name set wrong. It should be one of ('UNet')")

    with fluid.dygraph.guard(places):
        # Create dataset reader
        train_transforms = T.Compose([
            T.Resize(args.input_size),
            T.RandomHorizontalFlip(),
            T.Normalize()
        ])
        train_dataset = dataset(transforms=train_transforms, mode='train')

        eval_dataset = None
        if args.do_eval:
            eval_transforms = T.Compose(
                [T.Resize(args.input_size),
                 T.Normalize()])
            eval_dataset = dataset(transforms=eval_transforms, mode='eval')

        model = models.UNet(
            num_classes=train_dataset.num_classes, ignore_index=255)

        # Create optimizer: polynomial decay to 0 over the whole run.
        num_steps_each_epoch = len(train_dataset) // args.batch_size
        decay_step = args.num_epochs * num_steps_each_epoch
        lr_decay = fluid.layers.polynomial_decay(
            args.learning_rate, decay_step, end_learning_rate=0, power=0.9)
        optimizer = fluid.optimizer.Momentum(
            lr_decay,
            momentum=0.9,
            parameter_list=model.parameters(),
            regularization=fluid.regularizer.L2Decay(regularization_coeff=4e-5))

        train(
            model,
            train_dataset,
            places=places,
            eval_dataset=eval_dataset,
            optimizer=optimizer,
            save_dir=args.save_dir,
            num_epochs=args.num_epochs,
            batch_size=args.batch_size,
            pretrained_model=args.pretrained_model,
            save_interval_epochs=args.save_interval_epochs,
            num_classes=train_dataset.num_classes,
            num_workers=args.num_workers)


if __name__ == '__main__':
    args = parse_args()
    main(args)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
from . import functional
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
from PIL import Image, ImageEnhance
def normalize(im, mean, std):
    """Scale ``im`` to [0, 1] as float32, then subtract ``mean`` and divide by ``std``.

    The subtraction and division run in place on the scaled array.
    """
    scaled = im.astype(np.float32, copy=False) / 255.0
    np.subtract(scaled, mean, out=scaled)
    np.divide(scaled, std, out=scaled)
    return scaled
def permute(im):
    """Reorder the axes of ``im`` from (0, 1, 2) to (2, 0, 1)."""
    return np.transpose(im, axes=(2, 0, 1))
def resize(im, target_size=608, interp=cv2.INTER_LINEAR):
    """Resize ``im`` to ``target_size`` using interpolation ``interp``.

    ``target_size`` is either a ``(w, h)`` list/tuple or a single value used
    for both sides.
    """
    if isinstance(target_size, (list, tuple)):
        w, h = target_size[0], target_size[1]
    else:
        w = h = target_size
    return cv2.resize(im, (w, h), interpolation=interp)
def resize_long(im, long_size=224, interpolation=cv2.INTER_LINEAR):
    """Resize ``im`` so its longer side equals ``long_size``, keeping the aspect ratio."""
    height, width = im.shape[0], im.shape[1]
    scale = float(long_size) / float(max(height, width))
    new_size = (int(round(width * scale)), int(round(height * scale)))
    return cv2.resize(im, new_size, interpolation=interpolation)
def horizontal_flip(im):
    """Reverse axis 1 of a 2-D or 3-D array; other ranks are returned as-is."""
    rank = len(im.shape)
    if rank == 3:
        return im[:, ::-1, :]
    if rank == 2:
        return im[:, ::-1]
    return im
def vertical_flip(im):
    """Reverse axis 0 of a 2-D or 3-D array; other ranks are returned as-is."""
    rank = len(im.shape)
    if rank == 3:
        return im[::-1, :, :]
    if rank == 2:
        return im[::-1, :]
    return im
def brightness(im, brightness_lower, brightness_upper):
    """Apply a brightness factor drawn uniformly from the given bounds."""
    factor = np.random.uniform(brightness_lower, brightness_upper)
    enhancer = ImageEnhance.Brightness(im)
    return enhancer.enhance(factor)
def contrast(im, contrast_lower, contrast_upper):
    """Apply a contrast factor drawn uniformly from the given bounds."""
    factor = np.random.uniform(contrast_lower, contrast_upper)
    enhancer = ImageEnhance.Contrast(im)
    return enhancer.enhance(factor)
def saturation(im, saturation_lower, saturation_upper):
    """Apply a colour (saturation) factor drawn uniformly from the given bounds."""
    factor = np.random.uniform(saturation_lower, saturation_upper)
    enhancer = ImageEnhance.Color(im)
    return enhancer.enhance(factor)
def hue(im, hue_lower, hue_upper):
    """Shift channel 0 of the HSV form of ``im`` by a random offset and return RGB.

    The offset is drawn uniformly from ``[hue_lower, hue_upper)``.
    """
    shift = np.random.uniform(hue_lower, hue_upper)
    hsv = np.array(im.convert('HSV'))
    # Plain assignment (not +=) so the float sum is cast back into the array.
    hsv[:, :, 0] = hsv[:, :, 0] + shift
    shifted = Image.fromarray(hsv, mode='HSV')
    return shifted.convert('RGB')
def rotate(im, rotate_lower, rotate_upper):
    """Rotate ``im`` by a random angle, truncated to an integer.

    The angle is drawn uniformly from ``[rotate_lower, rotate_upper)``.
    """
    angle = int(np.random.uniform(rotate_lower, rotate_upper))
    return im.rotate(angle)
此差异已折叠。
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import logging
from . import download
from .metrics import ConfusionMatrix
from .utils import *
import os
import sys
import time
import requests
import tarfile
import zipfile
import shutil
import functools
lasttime = time.time()
FLUSH_INTERVAL = 0.1


def progress(str, end=False):
    """Rewrite the current stdout line with ``str``, throttled by FLUSH_INTERVAL.

    With ``end=True`` a newline is appended and the throttle is bypassed so
    the final state is always printed.
    """
    global lasttime
    message = str
    if end:
        message = message + "\n"
        # Resetting the timestamp forces the write below.
        lasttime = 0
    if time.time() - lasttime < FLUSH_INTERVAL:
        return
    sys.stdout.write("\r%s" % message)
    lasttime = time.time()
    sys.stdout.flush()
def _download_file(url, savepath, print_progress):
    """Stream ``url`` into the file ``savepath``.

    The data is written to ``savepath + '.part'`` and renamed on success, so
    an interrupted download never leaves a truncated file at ``savepath``
    that a later existence check would treat as a complete archive.

    Args:
        url: Address to download.
        savepath: Destination file path.
        print_progress: Whether to print a progress bar (only possible when
            the server reports a content length).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    r = requests.get(url, stream=True)
    # Fail fast rather than storing an HTML error page as the archive.
    r.raise_for_status()
    total_length = r.headers.get('content-length')
    tmp_path = savepath + '.part'

    try:
        with open(tmp_path, 'wb') as f:
            if total_length is None:
                shutil.copyfileobj(r.raw, f)
            else:
                dl = 0
                total_length = int(total_length)
                if print_progress:
                    print("Downloading %s" % os.path.basename(savepath))
                for data in r.iter_content(chunk_size=4096):
                    dl += len(data)
                    f.write(data)
                    if print_progress:
                        done = int(50 * dl / total_length)
                        progress("[%-50s] %.2f%%" %
                                 ('=' * done, float(100 * dl) / total_length))
    except BaseException:
        # Drop the partial file (also on Ctrl-C) before propagating.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    os.replace(tmp_path, savepath)
    if total_length is not None and print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
def _uncompress_file_zip(filepath, extrapath):
files = zipfile.ZipFile(filepath, 'r')
filelist = files.namelist()
rootpath = filelist[0]
total_num = len(filelist)
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file_tar(filepath, extrapath, mode="r:gz"):
files = tarfile.open(filepath, mode)
filelist = files.getnames()
total_num = len(filelist)
rootpath = filelist[0]
for index, file in enumerate(filelist):
files.extract(file, extrapath)
yield total_num, index, rootpath
files.close()
yield total_num, index, rootpath
def _uncompress_file(filepath, extrapath, delete_file, print_progress):
    """Extract ``filepath`` into ``extrapath`` and return the archive's first entry name.

    The extractor is picked from the file suffix: ``zip`` uses the zip
    handler, ``tgz`` the tar handler with its default mode, and anything
    else the tar handler with mode ``"r"``. The archive is removed
    afterwards when ``delete_file`` is true.
    """
    if print_progress:
        print("Uncompress %s" % os.path.basename(filepath))

    if filepath.endswith("zip"):
        extractor = _uncompress_file_zip
    elif filepath.endswith("tgz"):
        extractor = _uncompress_file_tar
    else:
        extractor = functools.partial(_uncompress_file_tar, mode="r")

    for total_num, index, rootpath in extractor(filepath, extrapath):
        if not print_progress:
            continue
        filled = int(50 * float(index) / total_num)
        progress("[%-50s] %.2f%%" %
                 ('=' * filled, float(100 * index) / total_num))

    if print_progress:
        progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)

    if delete_file:
        os.remove(filepath)

    return rootpath
def _remove_path(path):
    """Delete ``path`` whether it is a file, a symlink or a directory tree."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.lexists(path):
        os.remove(path)


def download_file_and_uncompress(url,
                                 savepath=None,
                                 extrapath=None,
                                 extraname=None,
                                 print_progress=True,
                                 cover=False,
                                 delete_file=True):
    """Download an archive, extract it and return the extracted location.

    Steps whose result already exists on disk are skipped: the download when
    the archive file is present, the extraction when the directory named
    after the archive is present, and everything when ``extraname`` exists.

    Args:
        url: Archive address; its last path segment names the archive file.
        savepath: Directory for the downloaded archive (default ``"."``).
        extrapath: Directory to extract into (default ``"."``).
        extraname: Final name of the extracted data inside ``extrapath``;
            defaults to the archive name without its last suffix.
        print_progress: Whether to print progress bars.
        cover: Remove any existing archive/extracted data first.
        delete_file: Remove the archive after extraction.

    Returns:
        Path of the extracted data, ``os.path.join(extrapath, extraname)``.
    """
    if savepath is None:
        savepath = "."
    if extrapath is None:
        extrapath = "."

    # The cache directories may not exist yet on first use.
    os.makedirs(savepath, exist_ok=True)
    os.makedirs(extrapath, exist_ok=True)

    savename = url.split("/")[-1]
    savepath = os.path.join(savepath, savename)
    savename = ".".join(savename.split(".")[:-1])
    savename = os.path.join(extrapath, savename)
    extraname = savename if extraname is None else os.path.join(
        extrapath, extraname)

    if cover:
        # savepath is an archive *file*; rmtree on it used to raise.
        for stale in (savepath, savename, extraname):
            _remove_path(stale)

    if not os.path.exists(extraname):
        if not os.path.exists(savename):
            if not os.path.exists(savepath):
                _download_file(url, savepath, print_progress)
            savename = _uncompress_file(savepath, extrapath, delete_file,
                                        print_progress)
            savename = os.path.join(extrapath, savename)
        shutil.move(savename, extraname)
    # Return the final location; `savename` may just have been moved away.
    return extraname
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import os
import sys
from paddle.fluid.dygraph.parallel import ParallelEnv
levels = {0: 'ERROR', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}
log_level = 2


def log(level=2, message=""):
    """Print a timestamped ``message`` on local rank 0 if ``level`` is enabled.

    A message is emitted when ``level`` does not exceed the module-wide
    ``log_level``; other ranks stay silent.
    """
    if ParallelEnv().local_rank != 0:
        return
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S",
                              time.localtime(time.time()))
    if log_level < level:
        return
    line = "{} [{}]\t{}".format(timestamp, levels[level], message)
    # NOTE(review): re-decoding the UTF-8 bytes as latin1 alters non-ASCII
    # characters; confirm this is intended for the target console.
    print(line.encode("utf-8").decode("latin1"))
    sys.stdout.flush()
def debug(message=""):
    """Log ``message`` at DEBUG level (3)."""
    log(3, message)
def info(message=""):
    """Log ``message`` at INFO level (2)."""
    log(2, message)
def warning(message=""):
    """Log ``message`` at WARNING level (1)."""
    log(1, message)
def error(message=""):
    """Log ``message`` at ERROR level (0)."""
    log(0, message)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import numpy as np
from scipy.sparse import csr_matrix
class ConfusionMatrix(object):
    """
    Confusion Matrix for segmentation evaluation

    Entry ``[i][j]`` counts positions whose label is ``i`` and whose
    prediction is ``j``.

    Args:
        num_classes: Number of classes (matrix side length).
        streaming: When True, ``calculate`` accumulates across calls; when
            False, every call starts from a zeroed matrix.
    """

    def __init__(self, num_classes=2, streaming=False):
        self.confusion_matrix = np.zeros([num_classes, num_classes],
                                         dtype='int64')
        self.num_classes = num_classes
        self.streaming = streaming

    def calculate(self, pred, label, ignore=None):
        """Add the counts of one batch to the matrix.

        Args:
            pred: Predicted class ids, laid out like ``label`` after the
                latter's axes are permuted with ``(0, 2, 3, 1)``.
            label: 4-D array of ground-truth class ids.
            ignore: Optional 4-D array laid out like ``label``; only
                positions where it equals 1 are counted. When None, every
                position is counted.
        """
        # If not in streaming mode, clear matrix everytime when call `calculate`
        if not self.streaming:
            self.zero_matrix()

        label = np.transpose(label, (0, 2, 3, 1))
        if ignore is None:
            # The default used to crash in np.transpose; count everything.
            mask = np.ones(label.shape, dtype=bool)
        else:
            ignore = np.transpose(ignore, (0, 2, 3, 1))
            mask = np.array(ignore) == 1

        label = np.asarray(label)[mask]
        pred = np.asarray(pred)[mask]
        # Count in int64 so a narrow pred dtype (e.g. uint8) cannot overflow.
        one = np.ones_like(pred, dtype='int64')
        # Accumulate ([row=label, col=pred], 1) into sparse matrix
        spm = csr_matrix((one, (label, pred)),
                         shape=(self.num_classes, self.num_classes))
        self.confusion_matrix += spm.toarray()

    def zero_matrix(self):
        """ Clear confusion matrix """
        self.confusion_matrix = np.zeros([self.num_classes, self.num_classes],
                                         dtype='int64')

    def mean_iou(self):
        """Return ``(per_class_iou, mean_iou)``.

        A class that never occurs in labels or predictions gets IoU 0.
        """
        true_pos = np.diag(self.confusion_matrix)
        label_count = self.confusion_matrix.sum(axis=1)
        pred_count = self.confusion_matrix.sum(axis=0)
        union = label_count + pred_count - true_pos

        iou = np.zeros(self.num_classes, dtype='float64')
        np.divide(true_pos, union, out=iou, where=union != 0)
        avg_iou = float(iou.sum()) / float(self.num_classes)
        return iou, avg_iou

    def accuracy(self):
        """Return ``(per_class_acc, overall_acc)``.

        The per-class value divides the diagonal entry by the column sum
        (number of positions predicted as that class); it is 0 when the
        class was never predicted. The overall value is 0 for an empty
        matrix.
        """
        true_pos = np.diag(self.confusion_matrix)
        total = self.confusion_matrix.sum()
        if total == 0:
            avg_acc = 0
        else:
            avg_acc = float(true_pos.sum()) / total

        pred_count = self.confusion_matrix.sum(axis=0)
        acc = np.zeros(self.num_classes, dtype='float64')
        np.divide(true_pos, pred_count, out=acc, where=pred_count != 0)
        return acc, avg_acc

    def kappa(self):
        """Return Cohen's kappa of the accumulated matrix.

        The result is NaN when the matrix is empty or when the expected
        agreement equals 1 (the formula divides by zero in both cases).
        """
        # Float64 arithmetic makes the former 10000.0 rescaling unnecessary.
        label_count = self.confusion_matrix.sum(axis=1).astype('float64')
        pred_count = self.confusion_matrix.sum(axis=0).astype('float64')
        total = np.float64(self.confusion_matrix.sum())

        po = np.float64(np.trace(self.confusion_matrix)) / total
        pe = np.dot(label_count, pred_count) / (total * total)
        kappa = (po - pe) / (1 - pe)
        return kappa
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import math
import cv2
import paddle.fluid as fluid
from . import logging
def seconds_to_hms(seconds):
    """Format a duration in seconds as ``"h:m:s"`` without zero padding.

    Args:
        seconds: Duration in seconds (int or float).

    Returns:
        A string such as ``"1:1:1"``; the seconds part is truncated to int.
    """
    hours = math.floor(seconds / 3600)
    remainder = seconds - hours * 3600
    minutes = math.floor(remainder / 60)
    secs = int(remainder - minutes * 60)
    return ":".join(str(part) for part in (hours, minutes, secs))
def get_environ_info():
    """Describe the device the process should run on.

    Returns:
        A dict with two keys: ``'place'`` (``'cpu'`` or ``'cuda'``) and
        ``'num'`` (the ``CPU_NUM`` environment value, default 1, or the CUDA
        device count when at least one GPU is usable).
    """
    info = {
        'place': 'cpu',
        'num': int(os.environ.get('CPU_NUM', 1)),
    }

    # An explicitly empty CUDA_VISIBLE_DEVICES means "do not use GPUs".
    if os.environ.get('CUDA_VISIBLE_DEVICES', None) == "":
        return info
    if not hasattr(fluid.core, 'get_cuda_device_count'):
        return info

    try:
        gpu_num = fluid.core.get_cuda_device_count()
    except Exception:
        # Best effort: if the count cannot be queried, hide the GPUs and fall
        # back to CPU. Catching Exception (not a bare except) lets
        # KeyboardInterrupt / SystemExit propagate.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''
        gpu_num = 0

    if gpu_num > 0:
        info['place'] = 'cuda'
        # Reuse the count obtained above instead of querying a second time.
        info['num'] = gpu_num
    return info
def load_pretrained_model(model, pretrained_model):
    """Copy matching parameters from a saved checkpoint into ``model``.

    Parameters are read with ``fluid.load_dygraph`` from
    ``<pretrained_model>/model``. A parameter is copied only when its name
    exists in the checkpoint and its shape matches; anything else is
    reported with a warning and left untouched.

    Args:
        model: Object exposing ``state_dict()`` and ``set_dict()``.
        pretrained_model: Checkpoint directory, or None to do nothing.

    Raises:
        ValueError: If ``pretrained_model`` is given but does not exist.
    """
    if pretrained_model is None:
        return

    logging.info('Load pretrained model!')
    if not os.path.exists(pretrained_model):
        # The original spelled this `.formnat`, which raised AttributeError
        # instead of the intended ValueError.
        raise ValueError(
            'The pretrained model directory is not Found: {}'.format(
                pretrained_model))

    ckpt_path = os.path.join(pretrained_model, 'model')
    para_state_dict, _ = fluid.load_dygraph(ckpt_path)

    model_state_dict = model.state_dict()
    num_params_loaded = 0
    # Only values of existing keys are reassigned, so iterating the keys
    # while updating the dict is safe.
    for k in model_state_dict.keys():
        if k not in para_state_dict:
            logging.warning("{} is not in pretrained model".format(k))
        elif list(para_state_dict[k].shape) != list(
                model_state_dict[k].shape):
            logging.warning(
                "[SKIP] Shape of pretrained params {} doesn't match.(Pretrained: {}, Actual: {})"
                .format(k, para_state_dict[k].shape,
                        model_state_dict[k].shape))
        else:
            model_state_dict[k] = para_state_dict[k]
            num_params_loaded += 1

    model.set_dict(model_state_dict)
    logging.info("There are {}/{} varaibles are loaded.".format(
        num_params_loaded, len(model_state_dict)))
def visualize(image, result, save_dir=None, weight=0.6):
    """Colorize a segmentation result and blend it over the source image.

    Args:
        image: Path of the original image (read with ``cv2.imread``).
        result: Predicted label map for the image; it is passed to
            ``cv2.LUT`` (assumes an 8-bit array - TODO confirm with caller).
        save_dir: Directory for the blended image. When None, nothing is
            written and the blended image is returned instead.
        weight: Blend weight of the original image; the colorized result
            gets ``1 - weight``.

    Returns:
        The blended image when ``save_dir`` is None, otherwise None.
    """
    palette = np.array(get_color_map_list(256)).astype("uint8")

    # Map labels to colors one channel at a time with an OpenCV lookup table.
    channels = [cv2.LUT(result, palette[:, ch]) for ch in range(3)]
    pseudo_img = np.dstack(channels)

    source = cv2.imread(image)
    blended = cv2.addWeighted(source, weight, pseudo_img, 1 - weight, 0)

    if save_dir is None:
        return blended

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Keep the original file name inside the output directory.
    target = os.path.join(save_dir, os.path.basename(image))
    cv2.imwrite(target, blended)
def get_color_map_list(num_classes):
    """Build a color map for visualizing a segmentation mask.

    Colors are generated for labels ``1 .. num_classes``: the entry for
    label 0 is not part of the returned list, so index ``i`` holds the color
    generated for label ``i + 1``.

    Args:
        num_classes: Number of classes.

    Returns:
        A list of ``num_classes`` three-element lists, one per class.
    """
    palette = []
    for label in range(1, num_classes + 1):
        color = [0, 0, 0]
        shift = 7
        bits = label
        # Consume the label three bits at a time, spreading one bit to each
        # channel from the most significant position downwards.
        while bits:
            for channel in range(3):
                color[channel] |= ((bits >> channel) & 1) << shift
            shift -= 1
            bits >>= 3
        palette.append(color)
    return palette
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import math
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
from paddle.fluid.dataloader import BatchSampler
from datasets import OpticDiscSeg, Cityscapes
import transforms as T
import models
import utils.logging as logging
from utils import get_environ_info
from utils import ConfusionMatrix
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse.Namespace`` with ``model_name``, ``dataset``,
        ``input_size``, ``batch_size`` and ``model_dir``.
    """
    parser = argparse.ArgumentParser(description='Model evaluation')

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        # params of model
        ('--model_name',
         dict(
             dest='model_name',
             help="Model type for evaluation, which is one of ('UNet')",
             type=str,
             default='UNet')),
        # params of dataset
        ('--dataset',
         dict(
             dest='dataset',
             help=
             "The dataset you want to evaluation, which is one of ('OpticDiscSeg', 'Cityscapes')",
             type=str,
             default='OpticDiscSeg')),
        # params of evaluate
        ("--input_size",
         dict(
             dest="input_size",
             help="The image size for net inputs.",
             nargs=2,
             default=[512, 512],
             type=int)),
        ('--batch_size',
         dict(
             dest='batch_size',
             help='Mini batch size',
             type=int,
             default=2)),
        ('--model_dir',
         dict(
             dest='model_dir',
             help='The path of model for evaluation',
             type=str,
             default=None)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def evaluate(model,
             eval_dataset=None,
             places=None,
             model_dir=None,
             num_classes=None,
             batch_size=2,
             ignore_index=255,
             epoch_id=None):
    """Evaluate ``model`` on ``eval_dataset`` and log the metrics.

    Parameters are loaded from ``<model_dir>/model``, the model is switched
    to eval mode, and a streaming confusion matrix is accumulated over all
    batches. IoU is logged after every step; accuracy, per-category IoU /
    accuracy and kappa are logged at the end. Nothing is returned.

    Args:
        model: Network called as ``model(images, mode='eval')``.
        eval_dataset: Dataset to iterate over.
        places: Device place handed to the ``DataLoader``.
        model_dir: Directory that contains the saved ``model`` checkpoint.
        num_classes: Number of classes for the confusion matrix.
        batch_size: Mini-batch size.
        ignore_index: Label value excluded from the statistics.
        epoch_id: Epoch number, used only in the log line.
    """
    params, _ = fluid.load_dygraph(os.path.join(model_dir, 'model'))
    model.set_dict(params)
    model.eval()

    sampler = BatchSampler(
        eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
    loader = DataLoader(
        eval_dataset,
        batch_sampler=sampler,
        places=places,
        return_list=True,
    )

    num_samples = len(eval_dataset)
    total_steps = math.ceil(num_samples * 1.0 / batch_size)
    conf_mat = ConfusionMatrix(num_classes, streaming=True)

    logging.info(
        "Start to evaluating(total_samples={}, total_steps={})...".format(
            num_samples, total_steps))
    for step, batch in enumerate(loader, start=1):
        images = batch[0]
        labels = batch[1].astype('int64')
        pred, _ = model(images, mode='eval')

        pred_np = pred.numpy()
        labels_np = labels.numpy()
        # True where the pixel carries a real label and should be counted.
        valid = labels_np != ignore_index

        conf_mat.calculate(pred=pred_np, label=labels_np, ignore=valid)
        running_iou = conf_mat.mean_iou()[1]
        logging.info("[EVAL] Epoch={}, Step={}/{}, iou={}".format(
            epoch_id, step, total_steps, running_iou))

    category_iou, miou = conf_mat.mean_iou()
    category_acc, macc = conf_mat.accuracy()
    logging.info("[EVAL] #image={} acc={:.4f} IoU={:.4f}".format(
        len(eval_dataset), macc, miou))
    logging.info("[EVAL] Category IoU: " + str(category_iou))
    logging.info("[EVAL] Category Acc: " + str(category_acc))
    logging.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
def main(args):
    """Build the dataset and model from ``args`` and run the evaluation.

    Args:
        args: Parsed options providing ``dataset``, ``model_name``,
            ``input_size``, ``batch_size`` and ``model_dir``.

    Raises:
        Exception: If ``args.dataset`` or ``args.model_name`` is not one of
            the supported values.
    """
    env_info = get_environ_info()
    use_cuda = env_info['place'] == 'cuda' and fluid.is_compiled_with_cuda()
    if use_cuda:
        places = fluid.CUDAPlace(ParallelEnv().dev_id)
    else:
        places = fluid.CPUPlace()

    dataset_name = args.dataset.lower()
    if dataset_name == 'opticdiscseg':
        dataset = OpticDiscSeg
    elif dataset_name == 'cityscapes':
        dataset = Cityscapes
    else:
        raise Exception(
            "The --dataset set wrong. It should be one of ('OpticDiscSeg', 'Cityscapes')"
        )

    # Reject unsupported models before any dataset is built; previously an
    # unknown name left `model` unbound and failed with UnboundLocalError.
    if args.model_name != 'UNet':
        raise Exception(
            "The --model_name set wrong. It should be one of ('UNet')")

    with fluid.dygraph.guard(places):
        eval_transforms = T.Compose([T.Resize(args.input_size), T.Normalize()])
        eval_dataset = dataset(transforms=eval_transforms, mode='eval')

        model = models.UNet(num_classes=eval_dataset.num_classes)

        evaluate(
            model,
            eval_dataset,
            places=places,
            model_dir=args.model_dir,
            num_classes=eval_dataset.num_classes,
            batch_size=args.batch_size)
# Script entry point: parse the command-line options and run the evaluation.
if __name__ == '__main__':
    args = parse_args()
    main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册