Commit 8d24e350 authored by: M michaelowenliu

create paddleseg package with all apis

Parent 288c09f9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import models
from . import datasets
from . import transforms
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .train import train
from .val import evaluate
from .infer import infer
__all__ = ['train', 'evaluate', 'infer']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddle.fluid.dygraph.base import to_variable
import numpy as np
import paddle.fluid as fluid
import cv2
import tqdm
from dygraph import utils
import dygraph.utils.logger as logger
def mkdir(path):
sub_dir = os.path.dirname(path)
if not os.path.exists(sub_dir):
os.makedirs(sub_dir)
def infer(model, test_dataset=None, model_dir=None, save_dir='output'):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
added_saved_dir = os.path.join(save_dir, 'added')
pred_saved_dir = os.path.join(save_dir, 'prediction')
logger.info("Start to predict...")
for im, im_info, im_path in tqdm.tqdm(test_dataset):
im = to_variable(im)
pred, _ = model(im)
pred = pred.numpy()
pred = np.squeeze(pred).astype('uint8')
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
im_file = im_path.replace(test_dataset.dataset_root, '')
if im_file[0] == '/':
im_file = im_file[1:]
# save added image
added_image = utils.visualize(im_path, pred, weight=0.6)
added_image_path = os.path.join(added_saved_dir, im_file)
mkdir(added_image_path)
cv2.imwrite(added_image_path, added_image)
# save prediction
pred_im = utils.visualize(im_path, pred, weight=0.0)
pred_saved_path = os.path.join(pred_saved_dir, im_file)
mkdir(pred_saved_path)
cv2.imwrite(pred_saved_path, pred_im)
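A hedged usage sketch for the infer API above; the import paths, model class, and transform names are assumptions about the surrounding package, shown only to illustrate the call signature:

import paddle.fluid as fluid
from paddleseg.core import infer              # assumed import path for this API
from paddleseg.models import UNet             # hypothetical model choice
from paddleseg.datasets import OpticDiscSeg
import paddleseg.transforms as T              # assumed transforms module

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    transforms = [T.Resize(target_size=(512, 512)), T.Normalize()]
    test_dataset = OpticDiscSeg(transforms=transforms, mode='test')
    model = UNet(num_classes=2)
    # Loads <model_dir>/model.pdparams, predicts every test image, and writes
    # visualizations under <save_dir>/added and <save_dir>/prediction.
    infer(model, test_dataset=test_dataset,
          model_dir='output/best_model', save_dir='output/result')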
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import DataLoader
# from paddle.incubate.hapi.distributed import DistributedBatchSampler
from paddle.io import DistributedBatchSampler
import paddle.nn.functional as F
import dygraph.utils.logger as logger
from dygraph.utils import load_pretrained_model
from dygraph.utils import resume
from dygraph.utils import Timer, calculate_eta
from .val import evaluate
def check_logits_losses(logits, losses):
len_logits = len(logits)
len_losses = len(losses['types'])
if len_logits != len_losses:
raise RuntimeError(
'The length of logits should equal to the types of loss config: {} != {}.'
.format(len_logits, len_losses))
def loss_computation(logits, label, losses):
check_logits_losses(logits, losses)
loss = 0
for i in range(len(logits)):
logit = logits[i]
if logit.shape[-2:] != label.shape[-2:]:
logit = F.resize_bilinear(logit, label.shape[-2:])
loss_i = losses['types'][i](logit, label)
loss += losses['coef'][i] * loss_i
return loss
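# A sketch of the `losses` argument consumed by loss_computation() and train():
# a dict with one loss object per logit returned by the model and a matching
# list of coefficients. The loss class name below is illustrative only:
#
#     losses = {
#         'types': [CrossEntropyLoss()],
#         'coef': [1.0],
#     }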
def train(model,
train_dataset,
places=None,
eval_dataset=None,
optimizer=None,
save_dir='output',
iters=10000,
batch_size=2,
resume_model=None,
save_interval_iters=1000,
log_iters=10,
num_classes=None,
num_workers=8,
use_vdl=False,
losses=None,
ignore_index=255):
nranks = ParallelEnv().nranks
start_iter = 0
if resume_model is not None:
start_iter = resume(model, optimizer, resume_model)
if not os.path.isdir(save_dir):
if os.path.exists(save_dir):
os.remove(save_dir)
os.makedirs(save_dir)
if nranks > 1:
strategy = fluid.dygraph.prepare_context()
ddp_model = fluid.dygraph.DataParallel(model, strategy)
batch_sampler = DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
places=places,
num_workers=num_workers,
return_list=True,
)
if use_vdl:
from visualdl import LogWriter
log_writer = LogWriter(save_dir)
timer = Timer()
avg_loss = 0.0
iters_per_epoch = len(batch_sampler)
best_mean_iou = -1.0
best_model_iter = -1
train_reader_cost = 0.0
train_batch_cost = 0.0
timer.start()
iter = start_iter
while iter < iters:
for data in loader:
iter += 1
if iter > iters:
break
train_reader_cost += timer.elapsed_time()
images = data[0]
labels = data[1].astype('int64')
if nranks > 1:
logits = ddp_model(images)
loss = loss_computation(logits, labels, losses)
# loss = ddp_model(images, labels)
# apply_collective_grads sum grads over multiple gpus.
loss = ddp_model.scale_loss(loss)
loss.backward()
ddp_model.apply_collective_grads()
else:
logits = model(images)
loss = loss_computation(logits, labels, losses)
# loss = model(images, labels)
loss.backward()
optimizer.minimize(loss)
model.clear_gradients()
avg_loss += loss.numpy()[0]
lr = optimizer.current_step_lr()
train_batch_cost += timer.elapsed_time()
if (iter) % log_iters == 0 and ParallelEnv().local_rank == 0:
avg_loss /= log_iters
avg_train_reader_cost = train_reader_cost / log_iters
avg_train_batch_cost = train_batch_cost / log_iters
train_reader_cost = 0.0
train_batch_cost = 0.0
remain_iters = iters - iter
eta = calculate_eta(remain_iters, avg_train_batch_cost)
logger.info(
"[TRAIN] epoch={}, iter={}/{}, loss={:.4f}, lr={:.6f}, batch_cost={:.4f}, reader_cost={:.4f} | ETA {}"
.format((iter - 1) // iters_per_epoch + 1, iter, iters,
avg_loss * nranks, lr, avg_train_batch_cost,
avg_train_reader_cost, eta))
if use_vdl:
log_writer.add_scalar('Train/loss', avg_loss * nranks, iter)
log_writer.add_scalar('Train/lr', lr, iter)
log_writer.add_scalar('Train/batch_cost',
avg_train_batch_cost, iter)
log_writer.add_scalar('Train/reader_cost',
avg_train_reader_cost, iter)
avg_loss = 0.0
if (iter % save_interval_iters == 0
or iter == iters) and ParallelEnv().local_rank == 0:
current_save_dir = os.path.join(save_dir,
"iter_{}".format(iter))
if not os.path.isdir(current_save_dir):
os.makedirs(current_save_dir)
fluid.save_dygraph(model.state_dict(),
os.path.join(current_save_dir, 'model'))
fluid.save_dygraph(optimizer.state_dict(),
os.path.join(current_save_dir, 'model'))
if eval_dataset is not None:
mean_iou, avg_acc = evaluate(
model,
eval_dataset,
model_dir=current_save_dir,
num_classes=num_classes,
ignore_index=ignore_index,
iter_id=iter)
if mean_iou > best_mean_iou:
best_mean_iou = mean_iou
best_model_iter = iter
best_model_dir = os.path.join(save_dir, "best_model")
fluid.save_dygraph(
model.state_dict(),
os.path.join(best_model_dir, 'model'))
logger.info(
'Current evaluated best model in eval_dataset is iter_{}, miou={:.4f}'
.format(best_model_iter, best_mean_iou))
if use_vdl:
log_writer.add_scalar('Evaluate/mIoU', mean_iou, iter)
log_writer.add_scalar('Evaluate/aAcc', avg_acc, iter)
model.train()
timer.restart()
if use_vdl:
log_writer.close()
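A hedged end-to-end sketch of driving the train API above; the model, transforms, loss class, and optimizer are illustrative assumptions rather than code from this commit:

import paddle.fluid as fluid
from paddleseg.core import train                       # assumed import path
from paddleseg.models import UNet                      # hypothetical model choice
from paddleseg.models.losses import CrossEntropyLoss   # assumed loss class name
from paddleseg.datasets import OpticDiscSeg
import paddleseg.transforms as T                       # assumed transforms module

place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
    train_transforms = [T.Resize(target_size=(512, 512)), T.Normalize()]
    train_dataset = OpticDiscSeg(transforms=train_transforms, mode='train')
    model = UNet(num_classes=2)
    optimizer = fluid.optimizer.Momentum(
        learning_rate=0.01, momentum=0.9,
        parameter_list=model.parameters())
    losses = {'types': [CrossEntropyLoss()], 'coef': [1.0]}
    train(model, train_dataset, places=place, optimizer=optimizer,
          save_dir='output', iters=1000, batch_size=2,
          losses=losses, num_classes=2)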
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import tqdm
import cv2
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
import paddle.nn.functional as F
import paddle
import dygraph.utils.logger as logger
from dygraph.utils import ConfusionMatrix
from dygraph.utils import Timer, calculate_eta
def evaluate(model,
eval_dataset=None,
model_dir=None,
num_classes=None,
ignore_index=255,
iter_id=None):
ckpt_path = os.path.join(model_dir, 'model')
para_state_dict, opti_state_dict = fluid.load_dygraph(ckpt_path)
model.set_dict(para_state_dict)
model.eval()
total_iters = len(eval_dataset)
conf_mat = ConfusionMatrix(num_classes, streaming=True)
logger.info(
"Start evaluating (total_samples={}, total_iters={})...".format(
len(eval_dataset), total_iters))
timer = Timer()
timer.start()
for iter, (im, im_info, label) in tqdm.tqdm(
enumerate(eval_dataset), total=total_iters):
im = to_variable(im)
# pred, _ = model(im)
logits = model(im)
pred = paddle.argmax(logits[0], axis=1)
pred = pred.numpy().astype('float32')
pred = np.squeeze(pred)
for info in im_info[::-1]:
if info[0] == 'resize':
h, w = info[1][0], info[1][1]
pred = cv2.resize(pred, (w, h), interpolation=cv2.INTER_NEAREST)
elif info[0] == 'padding':
h, w = info[1][0], info[1][1]
pred = pred[0:h, 0:w]
else:
raise Exception("Unexpected info '{}' in im_info".format(
info[0]))
pred = pred[np.newaxis, :, :, np.newaxis]
pred = pred.astype('int64')
mask = label != ignore_index
conf_mat.calculate(pred=pred, label=label, ignore=mask)
_, iou = conf_mat.mean_iou()
time_iter = timer.elapsed_time()
remain_iter = total_iters - iter - 1
logger.debug(
"[EVAL] iter_id={}, iter={}/{}, iou={:.4f}, sec/iter={:.4f} | ETA {}"
.format(iter_id, iter + 1, total_iters, iou, time_iter,
calculate_eta(remain_iter, time_iter)))
timer.restart()
category_iou, miou = conf_mat.mean_iou()
category_acc, macc = conf_mat.accuracy()
logger.info("[EVAL] #Images={} mAcc={:.4f} mIoU={:.4f}".format(
len(eval_dataset), macc, miou))
logger.info("[EVAL] Category IoU: " + str(category_iou))
logger.info("[EVAL] Category Acc: " + str(category_acc))
logger.info("[EVAL] Kappa:{:.4f} ".format(conf_mat.kappa()))
return miou, macc
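A matching sketch for the evaluate API above, reading back a checkpoint written during training (the model and dataset classes are again only illustrative):

import paddle.fluid as fluid
from paddleseg.core import evaluate           # assumed import path
from paddleseg.models import UNet             # hypothetical model choice
from paddleseg.datasets import OpticDiscSeg
import paddleseg.transforms as T              # assumed transforms module

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    val_transforms = [T.Resize(target_size=(512, 512)), T.Normalize()]
    val_dataset = OpticDiscSeg(transforms=val_transforms, mode='val')
    model = UNet(num_classes=2)
    miou, macc = evaluate(model, eval_dataset=val_dataset,
                          model_dir='output/iter_1000', num_classes=2)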
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import manager
from . import param_init
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
import inspect
class ComponentManager:
"""
Implement a manager class to add new components properly.
Components can be added as either classes or functions.
For example:
>>> model_manager = ComponentManager()
>>> class AlexNet: ...
>>> class ResNet: ...
>>> model_manager.add_component(AlexNet)
>>> model_manager.add_component(ResNet)
or, alternatively, pass a sequence of them:
>>> model_manager.add_component([AlexNet, ResNet])
>>> print(model_manager.components_dict)
output: {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
Or, more conveniently, use it as a Python decorator by adding it above the class declaration:
>>> model_manager = ComponentManager()
>>> @model_manager.add_component
>>> class AlexNet: ...
>>> @model_manager.add_component
>>> class ResNet: ...
>>> print(model_manager.components_dict)
output: {'AlexNet': <class '__main__.AlexNet'>, 'ResNet': <class '__main__.ResNet'>}
"""
def __init__(self):
self._components_dict = dict()
def __len__(self):
return len(self._components_dict)
def __repr__(self):
return "{}:{}".format(self.__class__.__name__,
list(self._components_dict.keys()))
def __getitem__(self, item):
if item not in self._components_dict.keys():
raise KeyError("{} does not exist in the current {}".format(
item, self))
return self._components_dict[item]
@property
def components_dict(self):
return self._components_dict
def _add_single_component(self, component):
"""
Add a single component into the corresponding manager
Args:
component (function | class): a new component
Returns:
None
"""
# Currently only support class or function type
if not (inspect.isclass(component) or inspect.isfunction(component)):
raise TypeError(
"Expect class/function type, but received {}".format(
type(component)))
# Obtain the internal name of the component
component_name = component.__name__
# Check whether the component was added already
if component_name in self._components_dict.keys():
raise KeyError("{} exists already!".format(component_name))
else:
# Take the internal name of the component as its key
self._components_dict[component_name] = component
def add_component(self, components):
"""
Add component(s) into the corresponding manager
Args:
components (function | class | list | tuple): a component, or a sequence of components, to be added
Returns:
None
"""
# Check whether the type is a sequence
if isinstance(components, Sequence):
for component in components:
self._add_single_component(component)
else:
component = components
self._add_single_component(component)
return components
MODELS = ComponentManager()
BACKBONES = ComponentManager()
DATASETS = ComponentManager()
TRANSFORMS = ComponentManager()
LOSSES = ComponentManager()
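These five registries are how models, backbones, datasets, transforms, and losses become addressable by name. A brief sketch of registering and later retrieving a component (the toy class is a placeholder):

import paddle.fluid as fluid
from paddleseg.cvlibs import manager

@manager.MODELS.add_component
class MyToyModel(fluid.dygraph.Layer):
    def __init__(self, num_classes):
        super(MyToyModel, self).__init__()
        self.num_classes = num_classes

# Later (e.g. when resolving a config) the class is looked up by its registered name:
model_cls = manager.MODELS['MyToyModel']
model = model_cls(num_classes=2)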
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
def constant_init(param, **kwargs):
initializer = fluid.initializer.Constant(**kwargs)
initializer(param, param.block)
def normal_init(param, **kwargs):
initializer = fluid.initializer.Normal(**kwargs)
initializer(param, param.block)
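A short sketch of applying these initializers to a dygraph layer's parameters (the conv layer is illustrative):

import paddle.fluid as fluid
from paddleseg.cvlibs.param_init import constant_init, normal_init

with fluid.dygraph.guard():
    conv = fluid.dygraph.Conv2D(num_channels=3, num_filters=8, filter_size=3)
    # Draw the weights from a normal distribution and zero the bias.
    normal_init(conv.weight, loc=0.0, scale=0.01)
    constant_init(conv.bias, value=0.0)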
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .dataset import Dataset
from .optic_disc_seg import OpticDiscSeg
from .cityscapes import Cityscapes
from .voc import PascalVOC
from .ade import ADE20K
DATASETS = {
"OpticDiscSeg": OpticDiscSeg,
"Cityscapes": Cityscapes,
"PascalVOC": PascalVOC,
"ADE20K": ADE20K
}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from PIL import Image
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip"
@manager.DATASETS.add_component
class ADE20K(Dataset):
"""ADE20K dataset `http://sceneparsing.csail.mit.edu/`.
Args:
dataset_root: The dataset directory.
mode: Which part of the dataset to use. It is one of ('train', 'val'). Default: 'train'.
transforms: Transforms for image.
download: Whether to download dataset if `dataset_root` is None.
"""
def __init__(self,
dataset_root=None,
mode='train',
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 150
if mode.lower() not in ['train', 'val']:
raise Exception(
"`mode` should be one of ('train', 'val') in ADE20K dataset, but got {}."
.format(mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
if self.dataset_root is None:
if not download:
raise Exception(
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
extraname='ADEChallengeData2016')
elif not os.path.exists(self.dataset_root):
raise Exception('`dataset_root` does not exist: {}.'.format(
self.dataset_root))
if mode == 'train':
img_dir = os.path.join(self.dataset_root, 'images/training')
grt_dir = os.path.join(self.dataset_root, 'annotations/training')
elif mode == 'val':
img_dir = os.path.join(self.dataset_root, 'images/validation')
grt_dir = os.path.join(self.dataset_root, 'annotations/validation')
img_files = os.listdir(img_dir)
grt_files = [i.replace('.jpg', '.png') for i in img_files]
for i in range(len(img_files)):
img_path = os.path.join(img_dir, img_files[i])
grt_path = os.path.join(grt_dir, grt_files[i])
self.file_list.append([img_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
if self.mode == 'test':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, im_info, image_path
elif self.mode == 'val':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
label = np.asarray(Image.open(grt_path))
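# ADE20K annotations use 0 for 'unlabeled' and 1-150 for the classes; the
# subtraction below maps the classes to 0-149 and wraps the uint8 'unlabeled'
# value around to 255, which matches the default ignore_index.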
label = label - 1
label = label[np.newaxis, np.newaxis, :, :]
return im, im_info, label
else:
im, im_info, label = self.transforms(im=image_path, label=grt_path)
label = label - 1
return im, label
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import glob
from .dataset import Dataset
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
@manager.DATASETS.add_component
class Cityscapes(Dataset):
"""Cityscapes dataset `https://www.cityscapes-dataset.com/`.
The folder structure is as follow:
cityscapes
|
|--leftImg8bit
| |--train
| |--val
| |--test
|
|--gtFine
| |--train
| |--val
| |--test
Make sure there are **labelTrainIds.png files in the gtFine directory. If not, please run convert_cityscapes.py in tools.
Args:
dataset_root: Cityscapes dataset directory.
mode: Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
transforms: Transforms for image.
"""
def __init__(self, dataset_root, transforms=None, mode='train'):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 19
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
img_dir = os.path.join(self.dataset_root, 'leftImg8bit')
grt_dir = os.path.join(self.dataset_root, 'gtFine')
if self.dataset_root is None or not os.path.isdir(
self.dataset_root) or not os.path.isdir(
img_dir) or not os.path.isdir(grt_dir):
raise Exception(
"The dataset is not found, or the folder structure is nonconformant."
)
grt_files = sorted(
glob.glob(
os.path.join(grt_dir, mode, '*', '*_gtFine_labelTrainIds.png')))
img_files = sorted(
glob.glob(os.path.join(img_dir, mode, '*', '*_leftImg8bit.png')))
self.file_list = [[img_path, grt_path]
for img_path, grt_path in zip(img_files, grt_files)]
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
import numpy as np
from PIL import Image
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
@manager.DATASETS.add_component
class Dataset(fluid.io.Dataset):
"""Pass in a custom dataset that conforms to the format.
Args:
dataset_root: The dataset directory.
num_classes: Number of classes.
mode: Which part of the dataset to use. It is one of ('train', 'val', 'test'). Default: 'train'.
train_list: The train dataset file. When mode is 'train', train_list is necessary.
The contents of the train_list file are as follows:
image1.jpg ground_truth1.png
image2.jpg ground_truth2.png
val_list: The evaluation dataset file. When mode is 'val', val_list is necessary.
The contents are the same as train_list.
test_list: The test dataset file. When mode is 'test', test_list is necessary.
The annotation files are not necessary in the test_list file.
separator: The separator of dataset list. Default: ' '.
transforms: Transforms for image.
Examples:
todo
"""
def __init__(self,
dataset_root,
num_classes,
mode='train',
train_list=None,
val_list=None,
test_list=None,
separator=' ',
transforms=None):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = num_classes
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"mode should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
self.dataset_root = dataset_root
if not os.path.exists(self.dataset_root):
raise Exception('`dataset_root` does not exist: {}.'.format(
self.dataset_root))
if mode == 'train':
if train_list is None:
raise Exception(
'When `mode` is "train", `train_list` is necessary, but it is None.'
)
elif not os.path.exists(train_list):
raise Exception(
'`train_list` is not found: {}'.format(train_list))
else:
file_list = train_list
elif mode == 'val':
if val_list is None:
raise Exception(
'When `mode` is "val", `val_list` is necessary, but it is None.'
)
elif not os.path.exists(val_list):
raise Exception('`val_list` is not found: {}'.format(val_list))
else:
file_list = val_list
else:
if test_list is None:
raise Exception(
'When `mode` is "test", `test_list` is necessary, but it is None.'
)
elif not os.path.exists(test_list):
raise Exception(
'`test_list` is not found: {}'.format(test_list))
else:
file_list = test_list
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split(separator)
if len(items) != 2:
if mode == 'train' or mode == 'val':
raise Exception(
"File list format incorrect! In training or evaluation task it should be"
" image_name{}label_name\\n".format(separator))
image_path = os.path.join(self.dataset_root, items[0])
grt_path = None
else:
image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path])
def __getitem__(self, idx):
image_path, grt_path = self.file_list[idx]
if self.mode == 'test':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
return im, im_info, image_path
elif self.mode == 'val':
im, im_info, _ = self.transforms(im=image_path)
im = im[np.newaxis, ...]
label = np.asarray(Image.open(grt_path))
label = label[np.newaxis, np.newaxis, :, :]
return im, im_info, label
else:
im, im_info, label = self.transforms(im=image_path, label=grt_path)
return im, label
def __len__(self):
return len(self.file_list)
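The `Examples` entry in the docstring above is left as a todo; here is a hedged sketch of instantiating this generic Dataset with the list-file format it expects (the directory layout and transform names are assumptions):

from paddleseg.datasets import Dataset
import paddleseg.transforms as T              # assumed transforms module

transforms = [T.Resize(target_size=(512, 512)), T.Normalize()]
dataset = Dataset(dataset_root='data/my_dataset',
                  num_classes=2,
                  mode='train',
                  train_list='data/my_dataset/train_list.txt',
                  separator=' ',
                  transforms=transforms)
print(len(dataset))  # number of image/label pairs parsed from train_list.txt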
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "https://paddleseg.bj.bcebos.com/dataset/optic_disc_seg.zip"
@manager.DATASETS.add_component
class OpticDiscSeg(Dataset):
def __init__(self,
dataset_root=None,
transforms=None,
mode='train',
download=True):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.file_list = list()
self.mode = mode
self.num_classes = 2
if mode.lower() not in ['train', 'val', 'test']:
raise Exception(
"`mode` should be 'train', 'val' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
if self.dataset_root is None:
if not download:
raise Exception(
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL, savepath=DATA_HOME, extrapath=DATA_HOME)
elif not os.path.exists(self.dataset_root):
raise Exception('`dataset_root` does not exist: {}.'.format(
self.dataset_root))
if mode == 'train':
file_list = os.path.join(self.dataset_root, 'train_list.txt')
elif mode == 'val':
file_list = os.path.join(self.dataset_root, 'val_list.txt')
else:
file_list = os.path.join(self.dataset_root, 'test_list.txt')
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'val':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.dataset_root, items[0])
grt_path = None
else:
image_path = os.path.join(self.dataset_root, items[0])
grt_path = os.path.join(self.dataset_root, items[1])
self.file_list.append([image_path, grt_path])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from .dataset import Dataset
class Rice(Dataset):
def __init__(self, transforms=None, mode='train', download=True):
self.data_dir = "/mnt/liuyi22/PaddlePaddle/POC/rice_dataset"
self.transforms = transforms
self.file_list = list()
self.mode = mode
self.num_classes = 2
if mode.lower() not in ['train', 'eval', 'test']:
raise Exception(
"mode should be 'train', 'eval' or 'test', but got {}.".format(
mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
if mode == 'train':
file_list = os.path.join(self.data_dir, 'train_list.txt')
elif mode == 'eval':
file_list = os.path.join(self.data_dir, 'val_list.txt')
else:
file_list = os.path.join(self.data_dir, 'test_list.txt')
with open(file_list, 'r') as f:
for line in f:
items = line.strip().split()
if len(items) != 2:
if mode == 'train' or mode == 'eval':
raise Exception(
"File list format incorrect! It should be"
" image_name label_name\\n")
image_path = os.path.join(self.data_dir, items[0])
grt_path = None
else:
image_path = os.path.join(self.data_dir, items[0])
grt_path = os.path.join(self.data_dir, items[1])
self.file_list.append([image_path, grt_path])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from .dataset import Dataset
from paddleseg.utils.download import download_file_and_uncompress
from paddleseg.cvlibs import manager
from paddleseg.transforms import Compose
DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
URL = "http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"
@manager.DATASETS.add_component
class PascalVOC(Dataset):
"""Pascal VOC dataset `http://host.robots.ox.ac.uk/pascal/VOC/`. If you want to augment the dataset,
please run the voc_augment.py in tools.
Args:
dataset_root: The dataset directory.
mode: Which part of the dataset to use. It is one of ('train', 'trainval', 'trainaug', 'val'). Default: 'train'.
transforms: Transforms for image.
download: Whether to download dataset if dataset_root is None.
"""
def __init__(self,
dataset_root=None,
mode='train',
transforms=None,
download=True):
self.dataset_root = dataset_root
self.transforms = Compose(transforms)
self.mode = mode
self.file_list = list()
self.num_classes = 21
if mode.lower() not in ['train', 'trainval', 'trainaug', 'val']:
raise Exception(
"`mode` should be one of ('train', 'trainval', 'trainaug', 'val') in PascalVOC dataset, but got {}."
.format(mode))
if self.transforms is None:
raise Exception("`transforms` is necessary, but it is None.")
if self.dataset_root is None:
if not download:
raise Exception(
"`dataset_root` not set and auto download disabled.")
self.dataset_root = download_file_and_uncompress(
url=URL,
savepath=DATA_HOME,
extrapath=DATA_HOME,
extraname='VOCdevkit')
elif not os.path.exists(self.dataset_root):
raise Exception('`dataset_root` does not exist: {}.'.format(
self.dataset_root))
image_set_dir = os.path.join(self.dataset_root, 'VOC2012', 'ImageSets',
'Segmentation')
if mode == 'train':
file_list = os.path.join(image_set_dir, 'train.txt')
elif mode == 'val':
file_list = os.path.join(image_set_dir, 'val.txt')
elif mode == 'trainval':
file_list = os.path.join(image_set_dir, 'trainval.txt')
elif mode == 'trainaug':
file_list = os.path.join(image_set_dir, 'train.txt')
file_list_aug = os.path.join(image_set_dir, 'aug.txt')
if not os.path.exists(file_list_aug):
raise Exception(
"When `mode` is 'trainaug', the Pascal VOC dataset should be augmented. "
"Please make sure voc_augment.py has been properly run when using this mode."
)
img_dir = os.path.join(self.dataset_root, 'VOC2012', 'JPEGImages')
grt_dir = os.path.join(self.dataset_root, 'VOC2012',
'SegmentationClass')
grt_dir_aug = os.path.join(self.dataset_root, 'VOC2012',
'SegmentationClassAug')
with open(file_list, 'r') as f:
for line in f:
line = line.strip()
image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir, ''.join([line, '.png']))
self.file_list.append([image_path, grt_path])
if mode == 'trainaug':
with open(file_list_aug, 'r') as f:
for line in f:
line = line.strip()
image_path = os.path.join(img_dir, ''.join([line, '.jpg']))
grt_path = os.path.join(grt_dir_aug, ''.join([line,
'.png']))
self.file_list.append([image_path, grt_path])
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .backbones import *
from .losses import *
from .unet import UNet
from .deeplab import *
from .fcn import *
from .pspnet import *
from .ocrnet import *
from .fast_scnn import *
from .gcnet import *
from .ann import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_utils, model_utils
from paddleseg.utils import utils
@manager.MODELS.add_component
class ANN(nn.Layer):
"""
The ANN implementation based on PaddlePaddle.
The original article refers to
Zhu, Zhen, et al. "Asymmetric Non-local Neural Networks for Semantic Segmentation."
(https://arxiv.org/pdf/1908.07678.pdf)
It mainly consists of AFNB and APNB modules.
Args:
num_classes (int): the unique number of target classes.
backbone (Paddle.nn.Layer): backbone network, currently supporting ResNet50/101.
model_pretrained (str): the path of the pretrained model. Default to None.
backbone_indices (tuple): two values in the tuple indicate the indices of backbone outputs.
The first index will be taken as low-level features and the second one will be
taken as high-level features in the AFNB module. Usually the backbone consists of four
downsampling stages and returns an output of each stage, so we set the default to (2, 3),
which means taking the feature maps of the third and the fourth stages of the backbone.
backbone_channels (tuple): the same length as "backbone_indices". It indicates the channels of the corresponding indices.
key_value_channels (int): the key and value channels of the self-attention map in both AFNB and APNB modules.
Default to 256.
inter_channels (int): both input and output channels of the APNB modules.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
enable_auxiliary_loss (bool): a bool value indicating whether to add the auxiliary loss. Default to True.
"""
def __init__(self,
num_classes,
backbone,
model_pretrained=None,
backbone_indices=(2, 3),
backbone_channels=(1024, 2048),
key_value_channels=256,
inter_channels=512,
psp_size=(1, 3, 6, 8),
enable_auxiliary_loss=True):
super(ANN, self).__init__()
self.backbone = backbone
low_in_channels = backbone_channels[0]
high_in_channels = backbone_channels[1]
self.fusion = AFNB(
low_in_channels=low_in_channels,
high_in_channels=high_in_channels,
out_channels=high_in_channels,
key_channels=key_value_channels,
value_channels=key_value_channels,
dropout_prob=0.05,
sizes=([1]),
psp_size=psp_size)
self.context = nn.Sequential(
layer_utils.ConvBnRelu(
in_channels=high_in_channels,
out_channels=inter_channels,
kernel_size=3,
padding=1),
APNB(
in_channels=inter_channels,
out_channels=inter_channels,
key_channels=key_value_channels,
value_channels=key_value_channels,
dropout_prob=0.05,
sizes=([1]),
psp_size=psp_size))
self.cls = nn.Conv2d(
in_channels=inter_channels,
out_channels=num_classes,
kernel_size=1)
self.auxlayer = model_utils.AuxLayer(
in_channels=low_in_channels,
inter_channels=low_in_channels // 2,
out_channels=num_classes,
dropout_prob=0.05)
self.backbone_indices = backbone_indices
self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained)
def forward(self, input, label=None):
logit_list = []
_, feat_list = self.backbone(input)
low_level_x = feat_list[self.backbone_indices[0]]
high_level_x = feat_list[self.backbone_indices[1]]
x = self.fusion(low_level_x, high_level_x)
x = self.context(x)
logit = self.cls(x)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit)
if self.enable_auxiliary_loss:
auxiliary_logit = self.auxlayer(low_level_x)
auxiliary_logit = F.resize_bilinear(auxiliary_logit, input.shape[2:])
logit_list.append(auxiliary_logit)
return logit_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model (str, optional): the path of the pretrained model for the backbone. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self.backbone, pretrained_model)
class AFNB(nn.Layer):
"""
Asymmetric Fusion Non-local Block
Args:
low_in_channels (int): low-level-feature channels.
high_in_channels (int): high-level-feature channels.
out_channels (int): out channels of AFNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
dropout_prob (float): the dropout rate of output.
sizes (tuple): the number of AFNB modules. Default to ([1]).
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def __init__(self,
low_in_channels,
high_in_channels,
out_channels,
key_channels,
value_channels,
dropout_prob,
sizes=([1]),
psp_size=(1, 3, 6, 8)):
super(AFNB, self).__init__()
self.psp_size = psp_size
self.stages = nn.LayerList([
SelfAttentionBlock_AFNB(low_in_channels, high_in_channels,
key_channels, value_channels, out_channels,
size) for size in sizes
])
self.conv_bn = layer_utils.ConvBn(
in_channels=out_channels + high_in_channels,
out_channels=out_channels,
kernel_size=1)
self.dropout_prob = dropout_prob
def forward(self, low_feats, high_feats):
priors = [stage(low_feats, high_feats) for stage in self.stages]
context = priors[0]
for i in range(1, len(priors)):
context += priors[i]
output = self.conv_bn(paddle.concat([context, high_feats], axis=1))
output = F.dropout(output, p=self.dropout_prob) # dropout_prob
return output
class APNB(nn.Layer):
"""
Asymmetric Pyramid Non-local Block
Args:
in_channels (int): the input channels of APNB module.
out_channels (int): out channels of APNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
dropout_prob (float): the dropout rate of output.
sizes (tuple): the number of APNB modules. Default to ([1]).
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def __init__(self,
in_channels,
out_channels,
key_channels,
value_channels,
dropout_prob,
sizes=([1]),
psp_size=(1, 3, 6, 8)):
super(APNB, self).__init__()
self.psp_size = psp_size
self.stages = nn.LayerList([
SelfAttentionBlock_APNB(in_channels, out_channels, key_channels,
value_channels, size) for size in sizes
])
self.conv_bn = layer_utils.ConvBnRelu(
in_channels=in_channels * 2,
out_channels=out_channels,
kernel_size=1)
self.dropout_prob = dropout_prob
def forward(self, feats):
priors = [stage(feats) for stage in self.stages]
context = priors[0]
for i in range(1, len(priors)):
context += priors[i]
output = self.conv_bn(paddle.concat([context, feats], axis=1))
output = F.dropout(output, p=self.dropout_prob) # dropout_prob
return output
def _pp_module(x, psp_size):
n, c, h, w = x.shape
priors = []
for size in psp_size:
feat = F.adaptive_pool2d(x, pool_size=size, pool_type="avg")
feat = paddle.reshape(feat, shape=(n, c, -1))
priors.append(feat)
center = paddle.concat(priors, axis=-1)
return center
class SelfAttentionBlock_AFNB(nn.Layer):
"""
Self-Attention Block for AFNB module.
Args:
low_in_channels (int): low-level-feature channels.
high_in_channels (int): high-level-feature channels.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
out_channels (int): out channels of AFNB module.
scale (int): pooling size. Default to 1.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def __init__(self,
low_in_channels,
high_in_channels,
key_channels,
value_channels,
out_channels=None,
scale=1,
psp_size=(1, 3, 6, 8)):
super(SelfAttentionBlock_AFNB, self).__init__()
self.scale = scale
self.in_channels = low_in_channels
self.out_channels = out_channels
self.key_channels = key_channels
self.value_channels = value_channels
if out_channels is None:
self.out_channels = high_in_channels
self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max")
self.f_key = layer_utils.ConvBnRelu(
in_channels=low_in_channels,
out_channels=key_channels,
kernel_size=1)
self.f_query = layer_utils.ConvBnRelu(
in_channels=high_in_channels,
out_channels=key_channels,
kernel_size=1)
self.f_value = nn.Conv2d(
in_channels=low_in_channels,
out_channels=value_channels,
kernel_size=1)
self.W = nn.Conv2d(
in_channels=value_channels,
out_channels=self.out_channels,  # resolved value; falls back to high_in_channels
kernel_size=1)
self.psp_size = psp_size
def forward(self, low_feats, high_feats):
batch_size, _, h, w = high_feats.shape
value = self.f_value(low_feats)
value = _pp_module(value, self.psp_size)
value = paddle.transpose(value, (0, 2, 1))
query = self.f_query(high_feats)
query = paddle.reshape(query, shape=(batch_size, self.key_channels, -1))
query = paddle.transpose(query, perm=(0, 2, 1))
key = self.f_key(low_feats)
key = _pp_module(key, self.psp_size)
sim_map = paddle.matmul(query, key)
sim_map = (self.key_channels ** -.5) * sim_map
sim_map = F.softmax(sim_map, axis=-1)
context = paddle.matmul(sim_map, value)
context = paddle.transpose(context, perm=(0, 2, 1))
context = paddle.reshape(
context,
shape=[batch_size, self.value_channels, *high_feats.shape[2:]])
context = self.W(context)
return context
class SelfAttentionBlock_APNB(nn.Layer):
"""
Self-Attention Block for APNB module.
Args:
in_channels (int): the input channels of APNB module.
out_channels (int): out channels of APNB module.
key_channels (int): the key channels in self-attention block.
value_channels (int): the value channels in self-attention block.
scale (int): pooling size. Default to 1.
psp_size (tuple): the out size of pooled feature maps. Default to (1, 3, 6, 8).
"""
def __init__(self,
in_channels,
out_channels,
key_channels,
value_channels,
scale=1,
psp_size=(1, 3, 6, 8)):
super(SelfAttentionBlock_APNB, self).__init__()
self.scale = scale
self.in_channels = in_channels
self.out_channels = out_channels
self.key_channels = key_channels
self.value_channels = value_channels
self.pool = nn.Pool2D(pool_size=(scale, scale), pool_type="max")
self.f_key = layer_utils.ConvBnRelu(
in_channels=self.in_channels,
out_channels=self.key_channels,
kernel_size=1)
self.f_query = self.f_key
self.f_value = nn.Conv2d(
in_channels=self.in_channels,
out_channels=self.value_channels,
kernel_size=1)
self.W = nn.Conv2d(
in_channels=self.value_channels,
out_channels=self.out_channels,
kernel_size=1)
self.psp_size = psp_size
def forward(self, x):
batch_size, _, h, w = x.shape
if self.scale > 1:
x = self.pool(x)
value = self.f_value(x)
value = _pp_module(value, self.psp_size)
value = paddle.transpose(value, perm=(0, 2, 1))
query = self.f_query(x)
query = paddle.reshape(
query, shape=(batch_size, self.key_channels, -1))
query = paddle.transpose(query, perm=(0, 2, 1))
key = self.f_key(x)
key = _pp_module(key, self.psp_size)
sim_map = paddle.matmul(query, key)
sim_map = (self.key_channels ** -.5) * sim_map
sim_map = F.softmax(sim_map, axis=-1)
context = paddle.matmul(sim_map, value)
context = paddle.transpose(context, perm=(0, 2, 1))
context = paddle.reshape(
context, shape=[batch_size, self.value_channels, *x.shape[2:]])
context = self.W(context)
return context
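A construction sketch pairing the ANN head above with one of the ResNet backbones exported by paddleseg.models.backbones (the backbone's constructor arguments are an assumption and may differ in the actual implementation):

import paddle.fluid as fluid
from paddleseg.models import ANN
from paddleseg.models.backbones import ResNet50_vd

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    backbone = ResNet50_vd()      # assumed to be constructible with defaults
    model = ANN(num_classes=19,
                backbone=backbone,
                backbone_indices=(2, 3),
                backbone_channels=(1024, 2048))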
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .hrnet import *
from .resnet_vd import *
from .xception_deeplab import *
from .mobilenetv3 import *
This diff has been collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.models.common import layer_utils
from paddleseg.cvlibs import manager
from paddleseg.utils import utils
__all__ = [
"MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
"MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
"MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
"MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
"MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
]
def make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def get_padding_same(kernel_size, dilation_rate):
"""
SAME padding implementation given kernel_size and dilation_rate.
The calculation formula is as follows:
(F - (k + (k - 1) * (r - 1)) + 2 * p) / s + 1 = F_new
where F: input feature map size
k: kernel size, r: dilation rate, p: padding value, s: stride
F_new: output feature map size
Args:
kernel_size (int)
dilation_rate (int)
Returns:
padding_same (int): padding value
"""
k = kernel_size
r = dilation_rate
padding_same = (k + (k - 1) * (r - 1) - 1) // 2
return padding_same
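# Worked example: for kernel_size=3 and dilation_rate=2 the effective kernel
# covers k + (k - 1) * (r - 1) = 3 + 2 * 1 = 5 positions, so the SAME padding
# is (5 - 1) // 2 = 2, i.e. get_padding_same(3, 2) == 2.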
class MobileNetV3(fluid.dygraph.Layer):
def __init__(self,
backbone_pretrained=None,
scale=1.0,
model_name="small",
class_dim=1000,
output_stride=None):
super(MobileNetV3, self).__init__()
inplanes = 16
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, "relu", 1],
[3, 64, 24, False, "relu", 2],
[3, 72, 24, False, "relu", 1], # output 1 -> out_index=2
[5, 72, 40, True, "relu", 2],
[5, 120, 40, True, "relu", 1],
[5, 120, 40, True, "relu", 1], # output 2 -> out_index=5
[3, 240, 80, False, "hard_swish", 2],
[3, 200, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 184, 80, False, "hard_swish", 1],
[3, 480, 112, True, "hard_swish", 1],
[3, 672, 112, True, "hard_swish",
1], # output 3 -> out_index=11
[5, 672, 160, True, "hard_swish", 2],
[5, 960, 160, True, "hard_swish", 1],
[5, 960, 160, True, "hard_swish",
1], # output 4 -> out_index=14
]
self.out_indices = [2, 5, 11, 14]
self.cls_ch_squeeze = 960
self.cls_ch_expand = 1280
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, "relu", 2], # output 1 -> out_index=0
[3, 72, 24, False, "relu", 2],
[3, 88, 24, False, "relu", 1], # output 2 -> out_index=3
[5, 96, 40, True, "hard_swish", 2],
[5, 240, 40, True, "hard_swish", 1],
[5, 240, 40, True, "hard_swish", 1],
[5, 120, 48, True, "hard_swish", 1],
[5, 144, 48, True, "hard_swish", 1], # output 3 -> out_index=7
[5, 288, 96, True, "hard_swish", 2],
[5, 576, 96, True, "hard_swish", 1],
[5, 576, 96, True, "hard_swish", 1], # output 4 -> out_index=10
]
self.out_indices = [0, 3, 7, 10]
self.cls_ch_squeeze = 576
self.cls_ch_expand = 1280
else:
raise NotImplementedError(
"model_name [{}] is not implemented!".format(model_name))
###################################################
# modify stride and dilation based on output_stride
self.dilation_cfg = [1] * len(self.cfg)
self.modify_bottle_params(output_stride=output_stride)
###################################################
self.conv1 = ConvBNLayer(
in_c=3,
out_c=make_divisible(inplanes * scale),
filter_size=3,
stride=2,
padding=1,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv1")
self.block_list = []
inplanes = make_divisible(inplanes * scale)
for i, (k, exp, c, se, nl, s) in enumerate(self.cfg):
######################################
# add dilation rate
dilation_rate = self.dilation_cfg[i]
######################################
self.block_list.append(
ResidualUnit(
in_c=inplanes,
mid_c=make_divisible(scale * exp),
out_c=make_divisible(scale * c),
filter_size=k,
stride=s,
dilation=dilation_rate,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
self.add_sublayer(
sublayer=self.block_list[-1], name="conv" + str(i + 2))
inplanes = make_divisible(scale * c)
self.last_second_conv = ConvBNLayer(
in_c=inplanes,
out_c=make_divisible(scale * self.cls_ch_squeeze),
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
act="hard_swish",
name="conv_last")
self.pool = Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.last_conv = Conv2D(
num_channels=make_divisible(scale * self.cls_ch_squeeze),
num_filters=self.cls_ch_expand,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name="last_1x1_conv_weights"),
bias_attr=False)
self.out = Linear(
input_dim=self.cls_ch_expand,
output_dim=class_dim,
param_attr=ParamAttr("fc_weights"),
bias_attr=ParamAttr(name="fc_offset"))
self.init_weight(backbone_pretrained)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must be an even number")
if output_stride is not None:
stride = 2
rate = 1
for i, _cfg in enumerate(self.cfg):
stride = stride * _cfg[-1]
if stride > output_stride:
rate = rate * _cfg[-1]
self.cfg[i][-1] = 1
self.dilation_cfg[i] = rate
def forward(self, inputs, label=None, dropout_prob=0.2):
x = self.conv1(inputs)
# A feature list saves each downsampling feature.
feat_list = []
for i, block in enumerate(self.block_list):
x = block(x)
if i in self.out_indices:
feat_list.append(x)
#print("block {}:".format(i),x.shape, self.dilation_cfg[i])
x = self.last_second_conv(x)
x = self.pool(x)
x = self.last_conv(x)
x = fluid.layers.hard_swish(x)
x = fluid.layers.dropout(x=x, dropout_prob=dropout_prob)
x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
x = self.out(x)
return x, feat_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
pretrained_model ([str], optional): the path of pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
in_c,
out_c,
filter_size,
stride,
padding,
dilation=1,
num_groups=1,
if_act=True,
act=None,
use_cudnn=True,
name=""):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
self.conv = fluid.dygraph.Conv2D(
num_channels=in_c,
num_filters=out_c,
filter_size=filter_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=num_groups,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False,
use_cudnn=use_cudnn,
act=None)
self.bn = BatchNorm(
num_features=out_c,
weight_attr=ParamAttr(
name=name + "_bn_scale",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)),
bias_attr=ParamAttr(
name=name + "_bn_offset",
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0)))
self._act_op = layer_utils.Activation(act=None)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.if_act:
if self.act == "relu":
x = fluid.layers.relu(x)
elif self.act == "hard_swish":
x = fluid.layers.hard_swish(x)
else:
print("Unsupported activation function: {}".format(self.act))
exit()
return x
class ResidualUnit(fluid.dygraph.Layer):
def __init__(self,
in_c,
mid_c,
out_c,
filter_size,
stride,
use_se,
dilation=1,
act=None,
name=''):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_c == out_c
self.if_se = use_se
self.expand_conv = ConvBNLayer(
in_c=in_c,
out_c=mid_c,
filter_size=1,
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer(
in_c=mid_c,
out_c=mid_c,
filter_size=filter_size,
stride=stride,
padding=get_padding_same(
filter_size,
dilation), #int((filter_size - 1) // 2) + (dilation - 1),
dilation=dilation,
num_groups=mid_c,
if_act=True,
act=act,
name=name + "_depthwise")
if self.if_se:
self.mid_se = SEModule(mid_c, name=name + "_se")
self.linear_conv = ConvBNLayer(
in_c=mid_c,
out_c=out_c,
filter_size=1,
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
self.dilation = dilation
def forward(self, inputs):
x = self.expand_conv(inputs)
x = self.bottleneck_conv(x)
if self.if_se:
x = self.mid_se(x)
x = self.linear_conv(x)
if self.if_shortcut:
x = fluid.layers.elementwise_add(inputs, x)
return x
class SEModule(fluid.dygraph.Layer):
def __init__(self, channel, reduction=4, name=""):
super(SEModule, self).__init__()
self.avg_pool = fluid.dygraph.Pool2D(
pool_type="avg", global_pooling=True, use_cudnn=False)
self.conv1 = fluid.dygraph.Conv2D(
num_channels=channel,
num_filters=channel // reduction,
filter_size=1,
stride=1,
padding=0,
act="relu",
param_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = fluid.dygraph.Conv2D(
num_channels=channel // reduction,
num_filters=channel,
filter_size=1,
stride=1,
padding=0,
act=None,
param_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = self.conv2(outputs)
outputs = fluid.layers.hard_sigmoid(outputs)
return fluid.layers.elementwise_mul(x=inputs, y=outputs, axis=0)
def MobileNetV3_small_x0_35(**kwargs):
model = MobileNetV3(model_name="small", scale=0.35, **kwargs)
return model
def MobileNetV3_small_x0_5(**kwargs):
model = MobileNetV3(model_name="small", scale=0.5, **kwargs)
return model
def MobileNetV3_small_x0_75(**kwargs):
model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_small_x1_0(**kwargs):
model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
return model
def MobileNetV3_small_x1_25(**kwargs):
model = MobileNetV3(model_name="small", scale=1.25, **kwargs)
return model
def MobileNetV3_large_x0_35(**kwargs):
model = MobileNetV3(model_name="large", scale=0.35, **kwargs)
return model
def MobileNetV3_large_x0_5(**kwargs):
model = MobileNetV3(model_name="large", scale=0.5, **kwargs)
return model
def MobileNetV3_large_x0_75(**kwargs):
model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
return model
@manager.BACKBONES.add_component
def MobileNetV3_large_x1_0(**kwargs):
model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
return model
def MobileNetV3_large_x1_25(**kwargs):
model = MobileNetV3(model_name="large", scale=1.25, **kwargs)
return model
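# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it shows how one of the
# factories above can be used to build a backbone and run it on a random image
# inside a dygraph guard. The input shape and the dygraph setup are assumptions
# made only for illustration.
if __name__ == "__main__":
    import numpy as np

    with fluid.dygraph.guard():
        # MobileNetV3_large_x1_0 is registered in manager.BACKBONES above.
        backbone = MobileNetV3_large_x1_0()
        data = np.random.rand(1, 3, 224, 224).astype('float32')
        logits, feat_list = backbone(fluid.dygraph.to_variable(data))
        # feat_list contains the intermediate feature maps collected at out_indices.
        print(logits.shape, [feat.shape for feat in feat_list])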
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.utils import utils
from paddleseg.models.common import layer_utils
from paddleseg.cvlibs import manager
__all__ = [
"ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd"
]
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(
self,
num_channels,
num_filters,
filter_size,
stride=1,
dilation=1,
groups=1,
is_vd_mode=False,
act=None,
name=None,
):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = Pool2D(
pool_size=2,
pool_stride=2,
pool_padding=0,
pool_type='avg',
ceil_mode=True)
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
act=None,
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False)
if name == "conv1":
bn_name = "bn_" + name
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters,
weight_attr=ParamAttr(name=bn_name + '_scale'),
bias_attr=ParamAttr(bn_name + '_offset'))
self._act_op = layer_utils.Activation(act=act)
def forward(self, inputs):
if self.is_vd_mode:
inputs = self._pool2d_avg(inputs)
y = self._conv(inputs)
y = self._batch_norm(y)
y = self._act_op(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
dilation=1,
name=None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu',
name=name + "_branch2a")
self.dilation = dilation
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
dilation=dilation,
name=name + "_branch2b")
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=1,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
####################################################################
# If given dilation rate > 1, using corresponding padding
if self.dilation > 1:
padding = self.dilation
y = fluid.layers.pad(
y, [0, 0, 0, 0, padding, padding, padding, padding])
#####################################################################
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class BasicBlock(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv1)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet_vd(fluid.dygraph.Layer):
def __init__(self,
backbone_pretrained=None,
layers=50,
class_dim=1000,
output_stride=None,
multi_grid=(1, 2, 4)):
super(ResNet_vd, self).__init__()
self.layers = layers
supported_layers = [18, 34, 50, 101, 152, 200]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(
supported_layers, layers)
if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
elif layers == 200:
depth = [3, 12, 48, 3]
num_channels = [64, 256, 512, 1024
] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
dilation_dict = None
if output_stride == 8:
dilation_dict = {2: 2, 3: 4}
elif output_stride == 16:
dilation_dict = {3: 2}
self.conv1_1 = ConvBNLayer(
num_channels=3,
num_filters=32,
filter_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = ConvBNLayer(
num_channels=32,
num_filters=32,
filter_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = ConvBNLayer(
num_channels=32,
num_filters=64,
filter_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = Pool2D(
pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')
# self.block_list = []
self.stage_list = []
if layers >= 50:
for block in range(len(depth)):
shortcut = False
block_list = []
for i in range(depth[block]):
if layers in [101, 152] and block == 2:
if i == 0:
conv_name = "res" + str(block + 2) + "a"
else:
conv_name = "res" + str(block + 2) + "b" + str(i)
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
###############################################################################
# Add dilation rate for some segmentation tasks, if dilation_dict is not None.
dilation_rate = dilation_dict[
block] if dilation_dict and block in dilation_dict else 1
                    # Here 'block' is actually the stage index and 'i' is the block index within that stage.
                    # At stage 4, expand the dilation_rate using multi_grid, default (1, 2, 4).
if block == 3:
dilation_rate = dilation_rate * multi_grid[i]
#print("stage {}, block {}: dilation rate".format(block, i), dilation_rate)
###############################################################################
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0
and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
dilation=dilation_rate))
block_list.append(bottleneck_block)
shortcut = True
self.stage_list.append(block_list)
else:
for block in range(len(depth)):
shortcut = False
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BasicBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block],
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name))
block_list.append(basic_block)
shortcut = True
self.stage_list.append(block_list)
self.pool2d_avg = Pool2D(
pool_size=7, pool_type='avg', global_pooling=True)
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
param_attr=ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv),
name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
y = self.conv1_1(inputs)
y = self.conv1_2(y)
y = self.conv1_3(y)
y = self.pool2d_max(y)
# A feature list saves the output feature map of each stage.
feat_list = []
for i, stage in enumerate(self.stage_list):
for j, block in enumerate(stage):
y = block(y)
#print("stage {} block {}".format(i+1, j+1), y.shape)
feat_list.append(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y)
return y, feat_list
# def init_weight(self, pretrained_model=None):
# if pretrained_model is not None:
# if os.path.exists(pretrained_model):
# utils.load_pretrained_model(self, pretrained_model)
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
            pretrained_model (str, optional): the path of the pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def ResNet18_vd(**args):
model = ResNet_vd(layers=18, **args)
return model
def ResNet34_vd(**args):
model = ResNet_vd(layers=34, **args)
return model
@manager.BACKBONES.add_component
def ResNet50_vd(**args):
model = ResNet_vd(layers=50, **args)
return model
@manager.BACKBONES.add_component
def ResNet101_vd(**args):
model = ResNet_vd(layers=101, **args)
return model
def ResNet152_vd(**args):
model = ResNet_vd(layers=152, **args)
return model
def ResNet200_vd(**args):
model = ResNet_vd(layers=200, **args)
return model
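# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): with output_stride=8 the
# dilation_dict above keeps the last two stages at 1/8 of the input resolution,
# which is what the segmentation heads expect. The input shape and the dygraph
# setup are assumptions made only for illustration.
if __name__ == "__main__":
    with fluid.dygraph.guard():
        backbone = ResNet50_vd(output_stride=8)
        data = np.random.rand(1, 3, 513, 513).astype('float32')
        logits, feat_list = backbone(fluid.dygraph.to_variable(data))
        # One feature map per stage; the last two stages keep their spatial size.
        print([feat.shape for feat in feat_list])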
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, Dropout
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.models.common import layer_utils
from paddleseg.cvlibs import manager
from paddleseg.utils import utils
__all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"]
def check_data(data, number):
if type(data) == int:
return [data] * number
assert len(data) == number
return data
def check_stride(s, os):
if s <= os:
return True
else:
return False
def check_points(count, points):
if points is None:
return False
else:
if isinstance(points, list):
return (True if count in points else False)
else:
return (True if count == points else False)
def gen_bottleneck_params(backbone='xception_65'):
if backbone == 'xception_65':
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
elif backbone == 'xception_41':
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (8, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
elif backbone == 'xception_71':
bottleneck_params = {
"entry_flow": (5, [2, 1, 2, 1, 2], [128, 256, 256, 728, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
else:
raise Exception(
"xception backbont only support xception_41/xception_65/xception_71"
)
return bottleneck_params
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
filter_size,
stride=1,
padding=0,
act=None,
name=None):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=input_channels,
num_filters=output_channels,
filter_size=filter_size,
stride=stride,
padding=padding,
param_attr=ParamAttr(name=name + "/weights"),
bias_attr=False)
self._bn = BatchNorm(
num_features=output_channels,
epsilon=1e-3,
momentum=0.99,
weight_attr=ParamAttr(name=name + "/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/BatchNorm/beta"))
self._act_op = layer_utils.Activation(act=act)
def forward(self, inputs):
return self._act_op(self._bn(self._conv(inputs)))
class Seperate_Conv(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
stride,
filter,
dilation=1,
act=None,
name=None):
super(Seperate_Conv, self).__init__()
self._conv1 = Conv2D(
num_channels=input_channels,
num_filters=input_channels,
filter_size=filter,
stride=stride,
groups=input_channels,
padding=(filter) // 2 * dilation,
dilation=dilation,
param_attr=ParamAttr(name=name + "/depthwise/weights"),
bias_attr=False)
self._bn1 = BatchNorm(
input_channels,
epsilon=1e-3,
momentum=0.99,
weight_attr=ParamAttr(name=name + "/depthwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/depthwise/BatchNorm/beta"))
self._act_op1 = layer_utils.Activation(act=act)
self._conv2 = Conv2D(
input_channels,
output_channels,
1,
stride=1,
groups=1,
padding=0,
param_attr=ParamAttr(name=name + "/pointwise/weights"),
bias_attr=False)
self._bn2 = BatchNorm(
output_channels,
epsilon=1e-3,
momentum=0.99,
weight_attr=ParamAttr(name=name + "/pointwise/BatchNorm/gamma"),
bias_attr=ParamAttr(name=name + "/pointwise/BatchNorm/beta"))
self._act_op2 = layer_utils.Activation(act=act)
def forward(self, inputs):
x = self._conv1(inputs)
x = self._bn1(x)
x = self._act_op1(x)
x = self._conv2(x)
x = self._bn2(x)
x = self._act_op2(x)
return x
class Xception_Block(fluid.dygraph.Layer):
def __init__(self,
input_channels,
output_channels,
strides=1,
filter_size=3,
dilation=1,
skip_conv=True,
has_skip=True,
activation_fn_in_separable_conv=False,
name=None):
super(Xception_Block, self).__init__()
repeat_number = 3
output_channels = check_data(output_channels, repeat_number)
filter_size = check_data(filter_size, repeat_number)
strides = check_data(strides, repeat_number)
self.has_skip = has_skip
self.skip_conv = skip_conv
self.activation_fn_in_separable_conv = activation_fn_in_separable_conv
if not activation_fn_in_separable_conv:
self._conv1 = Seperate_Conv(
input_channels,
output_channels[0],
stride=strides[0],
filter=filter_size[0],
dilation=dilation,
name=name + "/separable_conv1")
self._conv2 = Seperate_Conv(
output_channels[0],
output_channels[1],
stride=strides[1],
filter=filter_size[1],
dilation=dilation,
name=name + "/separable_conv2")
self._conv3 = Seperate_Conv(
output_channels[1],
output_channels[2],
stride=strides[2],
filter=filter_size[2],
dilation=dilation,
name=name + "/separable_conv3")
else:
self._conv1 = Seperate_Conv(
input_channels,
output_channels[0],
stride=strides[0],
filter=filter_size[0],
act="relu",
dilation=dilation,
name=name + "/separable_conv1")
self._conv2 = Seperate_Conv(
output_channels[0],
output_channels[1],
stride=strides[1],
filter=filter_size[1],
act="relu",
dilation=dilation,
name=name + "/separable_conv2")
self._conv3 = Seperate_Conv(
output_channels[1],
output_channels[2],
stride=strides[2],
filter=filter_size[2],
act="relu",
dilation=dilation,
name=name + "/separable_conv3")
if has_skip and skip_conv:
self._short = ConvBNLayer(
input_channels,
output_channels[-1],
1,
stride=strides[-1],
padding=0,
name=name + "/shortcut")
def forward(self, inputs):
layer_helper = LayerHelper(self.full_name(), act='relu')
if not self.activation_fn_in_separable_conv:
x = layer_helper.append_activation(inputs)
x = self._conv1(x)
x = layer_helper.append_activation(x)
x = self._conv2(x)
x = layer_helper.append_activation(x)
x = self._conv3(x)
else:
x = self._conv1(inputs)
x = self._conv2(x)
x = self._conv3(x)
if self.has_skip is False:
return x
if self.skip_conv:
skip = self._short(inputs)
else:
skip = inputs
return fluid.layers.elementwise_add(x, skip)
class XceptionDeeplab(fluid.dygraph.Layer):
#def __init__(self, backbone, class_dim=1000):
# add output_stride
def __init__(self,
backbone,
backbone_pretrained=None,
output_stride=16,
class_dim=1000):
super(XceptionDeeplab, self).__init__()
bottleneck_params = gen_bottleneck_params(backbone)
self.backbone = backbone
self._conv1 = ConvBNLayer(
3,
32,
3,
stride=2,
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv1")
self._conv2 = ConvBNLayer(
32,
64,
3,
stride=1,
padding=1,
act="relu",
name=self.backbone + "/entry_flow/conv2")
"""
bottleneck_params = {
"entry_flow": (3, [2, 2, 2], [128, 256, 728]),
"middle_flow": (16, 1, 728),
"exit_flow": (2, [2, 1], [[728, 1024, 1024], [1536, 1536, 2048]])
}
if output_stride == 16:
entry_block3_stride = 2
middle_block_dilation = 1
exit_block_dilations = (1, 2)
elif output_stride == 8:
entry_block3_stride = 1
middle_block_dilation = 2
exit_block_dilations = (2, 4)
"""
self.block_num = bottleneck_params["entry_flow"][0]
self.strides = bottleneck_params["entry_flow"][1]
self.chns = bottleneck_params["entry_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
self.entry_flow = []
self.middle_flow = []
self.stride = 2
self.output_stride = output_stride
s = self.stride
for i in range(self.block_num):
stride = self.strides[i] if check_stride(s * self.strides[i],
self.output_stride) else 1
xception_block = self.add_sublayer(
self.backbone + "/entry_flow/block" + str(i + 1),
Xception_Block(
input_channels=64 if i == 0 else self.chns[i - 1],
output_channels=self.chns[i],
                    # use the per-block stride that has been checked against output_stride
                    strides=[1, 1, stride],
name=self.backbone + "/entry_flow/block" + str(i + 1)))
self.entry_flow.append(xception_block)
s = s * stride
self.stride = s
self.block_num = bottleneck_params["middle_flow"][0]
self.strides = bottleneck_params["middle_flow"][1]
self.chns = bottleneck_params["middle_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
s = self.stride
for i in range(self.block_num):
stride = self.strides[i] if check_stride(s * self.strides[i],
self.output_stride) else 1
xception_block = self.add_sublayer(
self.backbone + "/middle_flow/block" + str(i + 1),
Xception_Block(
input_channels=728,
output_channels=728,
strides=[1, 1, self.strides[i]],
skip_conv=False,
name=self.backbone + "/middle_flow/block" + str(i + 1)))
self.middle_flow.append(xception_block)
s = s * stride
self.stride = s
self.block_num = bottleneck_params["exit_flow"][0]
self.strides = bottleneck_params["exit_flow"][1]
self.chns = bottleneck_params["exit_flow"][2]
self.strides = check_data(self.strides, self.block_num)
self.chns = check_data(self.chns, self.block_num)
s = self.stride
stride = self.strides[0] if check_stride(s * self.strides[0],
self.output_stride) else 1
self._exit_flow_1 = Xception_Block(
728,
self.chns[0], [1, 1, stride],
name=self.backbone + "/exit_flow/block1")
s = s * stride
stride = self.strides[1] if check_stride(s * self.strides[1],
self.output_stride) else 1
self._exit_flow_2 = Xception_Block(
self.chns[0][-1],
self.chns[1], [1, 1, stride],
dilation=2,
has_skip=False,
activation_fn_in_separable_conv=True,
name=self.backbone + "/exit_flow/block2")
s = s * stride
self.stride = s
self._drop = Dropout(p=0.5)
self._pool = Pool2D(pool_type="avg", global_pooling=True)
self._fc = Linear(
self.chns[1][-1],
class_dim,
param_attr=ParamAttr(name="fc_weights"),
bias_attr=ParamAttr(name="fc_bias"))
self.init_weight(backbone_pretrained)
def forward(self, inputs):
x = self._conv1(inputs)
x = self._conv2(x)
feat_list = []
for i, ef in enumerate(self.entry_flow):
x = ef(x)
if i == 0:
feat_list.append(x)
for mf in self.middle_flow:
x = mf(x)
x = self._exit_flow_1(x)
x = self._exit_flow_2(x)
feat_list.append(x)
x = self._drop(x)
x = self._pool(x)
x = fluid.layers.squeeze(x, axes=[2, 3])
x = self._fc(x)
return x, feat_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
            pretrained_model (str, optional): the path of the pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
def Xception41_deeplab(**args):
model = XceptionDeeplab('xception_41', **args)
return model
@manager.BACKBONES.add_component
def Xception65_deeplab(**args):
model = XceptionDeeplab("xception_65", **args)
return model
def Xception71_deeplab(**args):
model = XceptionDeeplab("xception_71", **args)
return model
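# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it builds the Xception65
# backbone with output_stride=16 and runs it on a random image. feat_list holds
# the early entry-flow feature and the final exit-flow feature that DeepLabV3P
# consumes. Shapes and the dygraph setup are assumptions made for illustration.
if __name__ == "__main__":
    import numpy as np

    with fluid.dygraph.guard():
        backbone = Xception65_deeplab(output_stride=16)
        data = np.random.rand(1, 3, 512, 512).astype('float32')
        logits, feat_list = backbone(fluid.dygraph.to_variable(data))
        print([feat.shape for feat in feat_list])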
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import layer_utils
from . import model_utils
\ No newline at end of file
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Conv2d
from paddle.nn import SyncBatchNorm as BatchNorm
from paddle.nn.layer import activation
class ConvBnRelu(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvBnRelu, self).__init__()
self.conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self.batch_norm = BatchNorm(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
x = F.relu(x)
return x
class ConvBn(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvBn, self).__init__()
self.conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self.batch_norm = BatchNorm(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
return x
class ConvReluPool(nn.Layer):
def __init__(self, in_channels, out_channels):
super(ConvReluPool, self).__init__()
self.conv = Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
dilation=1)
def forward(self, x):
x = self.conv(x)
x = F.relu(x)
x = F.pool2d(x, pool_size=2, pool_type="max", pool_stride=2)
return x
# class ConvBnReluUpsample(nn.Layer):
# def __init__(self, in_channels, out_channels):
# super(ConvBnReluUpsample, self).__init__()
# self.conv_bn_relu = ConvBnRelu(in_channels, out_channels)
# def forward(self, x, upsample_scale=2):
# x = self.conv_bn_relu(x)
# new_shape = [x.shape[2] * upsample_scale, x.shape[3] * upsample_scale]
# x = F.resize_bilinear(x, new_shape)
# return x
class DepthwiseConvBnRelu(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(DepthwiseConvBnRelu, self).__init__()
self.depthwise_conv = ConvBn(
in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
**kwargs)
        self.pointwise_conv = ConvBnRelu(
            in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x):
x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)
return x
class Activation(nn.Layer):
"""
    A wrapper for activation layers.
For example:
>>> relu = Activation("relu")
>>> print(relu)
<class 'paddle.nn.layer.activation.ReLU'>
>>> sigmoid = Activation("sigmoid")
>>> print(sigmoid)
<class 'paddle.nn.layer.activation.Sigmoid'>
    >>> not_exist_one = Activation("not_exist_one")
    KeyError: "not_exist_one does not exist in the current dict_keys(['elu', 'gelu', 'hardshrink',
'tanh', 'hardtanh', 'prelu', 'relu', 'relu6', 'selu', 'leakyrelu', 'sigmoid', 'softmax',
'softplus', 'softshrink', 'softsign', 'tanhshrink', 'logsigmoid', 'logsoftmax', 'hsigmoid'])"
Args:
act (str): the activation name in lowercase
"""
def __init__(self, act=None):
super(Activation, self).__init__()
self._act = act
upper_act_names = activation.__all__
lower_act_names = [act.lower() for act in upper_act_names]
act_dict = dict(zip(lower_act_names, upper_act_names))
if act is not None:
if act in act_dict.keys():
act_name = act_dict[act]
                self.act_func = getattr(activation, act_name)()
else:
raise KeyError("{} does not exist in the current {}".format(
act, act_dict.keys()))
def forward(self, x):
if self._act is not None:
return self.act_func(x)
else:
return x
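# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it contrasts the plain
# ConvBnRelu block with DepthwiseConvBnRelu, which factors the 3x3 convolution
# into a per-channel conv followed by a 1x1 pointwise conv. Shapes and the
# dygraph setup are assumptions made only for illustration.
if __name__ == "__main__":
    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(
            np.random.rand(1, 32, 64, 64).astype('float32'))
        dense = ConvBnRelu(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        separable = DepthwiseConvBnRelu(
            in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Both blocks return a (1, 64, 64, 64) feature map.
        print(dense(x).shape, separable(x).shape)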
# -*- encoding: utf-8 -*-
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.models.common import layer_utils
class FCNHead(nn.Layer):
"""
    The FCNHead implementation used in the auxiliary layer.
Args:
in_channels (int): the number of input channels
out_channels (int): the number of output channels
"""
def __init__(self, in_channels, out_channels):
super(FCNHead, self).__init__()
inter_channels = in_channels // 4
self.conv_bn_relu = layer_utils.ConvBnRelu(
in_channels=in_channels,
out_channels=inter_channels,
kernel_size=3,
padding=1)
self.conv = nn.Conv2d(
in_channels=inter_channels,
out_channels=out_channels,
kernel_size=1)
def forward(self, x):
x = self.conv_bn_relu(x)
x = F.dropout(x, p=0.1)
x = self.conv(x)
return x
class AuxLayer(nn.Layer):
"""
    The auxiliary layer implementation for the auxiliary loss.
Args:
in_channels (int): the number of input channels.
inter_channels (int): intermediate channels.
out_channels (int): the number of output channels, which is usually num_classes.
"""
def __init__(self,
in_channels,
inter_channels,
out_channels,
dropout_prob=0.1):
super(AuxLayer, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(
in_channels=in_channels,
out_channels=inter_channels,
kernel_size=3,
padding=1)
self.conv = nn.Conv2d(
in_channels=inter_channels,
out_channels=out_channels,
kernel_size=1)
self.dropout_prob = dropout_prob
def forward(self, x):
x = self.conv_bn_relu(x)
x = F.dropout(x, p=self.dropout_prob)
x = self.conv(x)
return x
class PPModule(nn.Layer):
"""
Pyramid pooling module
Args:
        in_channels (int): the number of input channels to the pyramid pooling module.
        out_channels (int): the number of output channels after the pyramid pooling module.
        bin_sizes (tuple): the output sizes of the pooled feature maps. Default to (1, 2, 3, 6).
        dim_reduction (bool): whether to reduce the channel dimension after pooling. Default to True.
"""
def __init__(self,
in_channels,
out_channels,
bin_sizes=(1, 2, 3, 6),
dim_reduction=True):
super(PPModule, self).__init__()
self.bin_sizes = bin_sizes
inter_channels = in_channels
if dim_reduction:
inter_channels = in_channels // len(bin_sizes)
# we use dimension reduction after pooling mentioned in original implementation.
self.stages = nn.LayerList([
self._make_stage(in_channels, inter_channels, size)
for size in bin_sizes
])
self.conv_bn_relu2 = layer_utils.ConvBnRelu(
in_channels=in_channels + inter_channels * len(bin_sizes),
out_channels=out_channels,
kernel_size=3,
padding=1)
def _make_stage(self, in_channels, out_channels, size):
"""
Create one pooling layer.
        In our implementation, we adopt the same dimension reduction as the original paper, which might be
        slightly different from other implementations.
        After pooling, the channels are immediately reduced to 1/len(bin_sizes) of the input channels, while
        some other implementations keep the channels unchanged.
        Args:
            in_channels (int): the number of input channels to the pyramid pooling module.
            out_channels (int): the number of output channels of this pooling stage.
            size (int): the output size of the pooled layer.
        Returns:
            conv (nn.Layer): a conv-bn-relu layer that projects the pooled feature map.
"""
# this paddle version does not support AdaptiveAvgPool2d, so skip it here.
# prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
conv = layer_utils.ConvBnRelu(
in_channels=in_channels, out_channels=out_channels, kernel_size=1)
return conv
def forward(self, input):
cat_layers = []
for i, stage in enumerate(self.stages):
size = self.bin_sizes[i]
x = F.adaptive_pool2d(
input, pool_size=(size, size), pool_type="max")
x = stage(x)
x = F.resize_bilinear(x, out_shape=input.shape[2:])
cat_layers.append(x)
cat_layers = [input] + cat_layers[::-1]
cat = paddle.concat(cat_layers, axis=1)
out = self.conv_bn_relu2(cat)
return out
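# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it runs PPModule on a
# small feature map. Each bin is pooled, projected to in_channels // len(bin_sizes)
# channels, upsampled and concatenated with the input before the final 3x3 conv.
# Shapes and the dygraph setup are assumptions made only for illustration.
if __name__ == "__main__":
    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        feat = fluid.dygraph.to_variable(
            np.random.rand(1, 128, 32, 32).astype('float32'))
        ppm = PPModule(in_channels=128, out_channels=128, bin_sizes=(1, 2, 3, 6))
        print(ppm(feat).shape)  # expected: [1, 128, 32, 32]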
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_utils
from paddleseg.utils import utils
__all__ = ['DeepLabV3P', 'DeepLabV3']
@manager.MODELS.add_component
class DeepLabV3P(nn.Layer):
"""
The DeepLabV3Plus implementation based on PaddlePaddle.
    The original article refers to
"Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation"
Liang-Chieh Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam.
(https://arxiv.org/abs/1802.02611)
    The DeepLabV3P consists of three main components: Backbone, ASPP and Decoder.
Args:
num_classes (int): the unique number of target classes.
        backbone (paddle.nn.Layer): backbone network, currently supporting Xception65 and Resnet101_vd.
model_pretrained (str): the path of pretrained model.
output_stride (int): the ratio of input size and final feature size.
Support 16 or 8. Default to 16.
        backbone_indices (tuple): two values in the tuple indicate the indices of the backbone outputs.
            The first index will be taken as a low-level feature in the Decoder component;
            the second one will be taken as the input of the ASPP component.
            Usually backbone consists of four downsampling stages and returns an output of
            each stage, so we set the default to (0, 3), which means taking the feature map of the first
            stage in backbone as the low-level feature used in Decoder, and the feature map of the fourth
            stage as the input of ASPP.
        backbone_channels (tuple): the same length as "backbone_indices". It indicates the channels of the corresponding index.
"""
def __init__(self,
num_classes,
backbone,
model_pretrained=None,
backbone_indices=(0, 3),
backbone_channels=(256, 2048),
output_stride=16):
super(DeepLabV3P, self).__init__()
self.backbone = backbone
self.aspp = ASPP(output_stride, backbone_channels[1])
self.decoder = Decoder(num_classes, backbone_channels[0])
self.backbone_indices = backbone_indices
self.init_weight(model_pretrained)
def forward(self, input, label=None):
logit_list = []
_, feat_list = self.backbone(input)
low_level_feat = feat_list[self.backbone_indices[0]]
x = feat_list[self.backbone_indices[1]]
x = self.aspp(x)
logit = self.decoder(x, low_level_feat)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit)
return logit_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
            pretrained_model (str, optional): the path of the pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
@manager.MODELS.add_component
class DeepLabV3(nn.Layer):
"""
The DeepLabV3 implementation based on PaddlePaddle.
    The original article refers to
"Rethinking Atrous Convolution for Semantic Image Segmentation"
Liang-Chieh Chen, George Papandreou, Florian Schroff, Hartwig Adam.
(https://arxiv.org/pdf/1706.05587.pdf)
Args:
Refer to DeepLabV3P above
"""
def __init__(self,
num_classes,
backbone,
model_pretrained=None,
backbone_indices=(3,),
backbone_channels=(2048,),
output_stride=16):
super(DeepLabV3, self).__init__()
self.backbone = backbone
self.aspp = ASPP(output_stride, backbone_channels[0])
self.cls = nn.Conv2d(
in_channels=backbone_channels[0],
out_channels=num_classes,
kernel_size=1)
self.backbone_indices = backbone_indices
self.init_weight(model_pretrained)
def forward(self, input, label=None):
logit_list = []
_, feat_list = self.backbone(input)
x = feat_list[self.backbone_indices[0]]
logit = self.cls(x)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit)
return logit_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
            pretrained_model (str, optional): the path of the pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class ImageAverage(nn.Layer):
"""
Global average pooling
Args:
in_channels (int): the number of input channels.
"""
def __init__(self, in_channels):
super(ImageAverage, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(
in_channels, out_channels=256, kernel_size=1)
def forward(self, input):
x = paddle.reduce_mean(input, dim=[2, 3], keep_dim=True)
x = self.conv_bn_relu(x)
x = F.resize_bilinear(x, out_shape=input.shape[2:])
return x
class ASPP(nn.Layer):
"""
    The Atrous Spatial Pyramid Pooling (ASPP) module of the DeepLabV3P model.
    Args:
        output_stride (int): the ratio of input size and final feature size. Support 16 or 8.
        in_channels (int): the number of input channels to the ASPP module.
"""
def __init__(self, output_stride, in_channels):
super(ASPP, self).__init__()
if output_stride == 16:
aspp_ratios = (6, 12, 18)
elif output_stride == 8:
aspp_ratios = (12, 24, 36)
else:
raise NotImplementedError(
"Only support output_stride is 8 or 16, but received{}".format(
output_stride))
self.image_average = ImageAverage(in_channels=in_channels)
# The first aspp using 1*1 conv
self.aspp1 = layer_utils.DepthwiseConvBnRelu(
in_channels=in_channels, out_channels=256, kernel_size=1)
# The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0]
self.aspp2 = layer_utils.DepthwiseConvBnRelu(
in_channels=in_channels,
out_channels=256,
kernel_size=3,
dilation=aspp_ratios[0],
padding=aspp_ratios[0])
# The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1]
self.aspp3 = layer_utils.DepthwiseConvBnRelu(
in_channels=in_channels,
out_channels=256,
kernel_size=3,
dilation=aspp_ratios[1],
padding=aspp_ratios[1])
        # The fourth aspp using 3*3 (separable) conv at dilated rate aspp_ratios[2]
self.aspp4 = layer_utils.DepthwiseConvBnRelu(
in_channels=in_channels,
out_channels=256,
kernel_size=3,
dilation=aspp_ratios[2],
padding=aspp_ratios[2])
# After concat op, using 1*1 conv
self.conv_bn_relu = layer_utils.ConvBnRelu(
in_channels=1280, out_channels=256, kernel_size=1)
def forward(self, x):
x1 = self.image_average(x)
x2 = self.aspp1(x)
x3 = self.aspp2(x)
x4 = self.aspp3(x)
x5 = self.aspp4(x)
x = paddle.concat([x1, x2, x3, x4, x5], axis=1)
x = self.conv_bn_relu(x)
x = F.dropout(x, p=0.1) # dropout_prob
return x
class Decoder(nn.Layer):
"""
Decoder module of DeepLabV3P model
Args:
num_classes (int): the number of classes.
in_channels (int): the number of input channels in decoder module.
"""
def __init__(self, num_classes, in_channels):
super(Decoder, self).__init__()
self.conv_bn_relu1 = layer_utils.ConvBnRelu(
in_channels=in_channels, out_channels=48, kernel_size=1)
self.conv_bn_relu2 = layer_utils.DepthwiseConvBnRelu(
in_channels=304, out_channels=256, kernel_size=3, padding=1)
self.conv_bn_relu3 = layer_utils.DepthwiseConvBnRelu(
in_channels=256, out_channels=256, kernel_size=3, padding=1)
self.conv = nn.Conv2d(
in_channels=256, out_channels=num_classes, kernel_size=1)
def forward(self, x, low_level_feat):
low_level_feat = self.conv_bn_relu1(low_level_feat)
x = F.resize_bilinear(x, low_level_feat.shape[2:])
x = paddle.concat([x, low_level_feat], axis=1)
x = self.conv_bn_relu2(x)
x = self.conv_bn_relu3(x)
x = self.conv(x)
return x
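# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): it wires a ResNet50_vd
# backbone into DeepLabV3P as described in the class docstring. The backbone
# import path below is an assumption about where the backbone added in this
# commit lives; the input shape and the dygraph setup are also assumptions.
if __name__ == "__main__":
    import numpy as np
    import paddle.fluid as fluid
    # Assumed module path of the ResNet_vd factories defined earlier in this commit.
    from paddleseg.models.backbones.resnet_vd import ResNet50_vd

    with fluid.dygraph.guard():
        backbone = ResNet50_vd(output_stride=16)
        model = DeepLabV3P(
            num_classes=19,
            backbone=backbone,
            backbone_indices=(0, 3),
            backbone_channels=(256, 2048))
        data = np.random.rand(1, 3, 512, 512).astype('float32')
        logit_list = model(fluid.dygraph.to_variable(data))
        # A single logit map, bilinearly resized back to the input resolution.
        print(logit_list[0].shape)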
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.nn.functional as F
from paddle import nn
from paddleseg.cvlibs import manager
from paddleseg.models.common import layer_utils, model_utils
from paddleseg.utils import utils
@manager.MODELS.add_component
class FastSCNN(nn.Layer):
"""
The FastSCNN implementation based on PaddlePaddle.
As mentioned in the original paper, FastSCNN is a real-time segmentation algorithm (123.5fps)
even for high resolution images (1024x2048).
    The original article refers to
Poudel, Rudra PK, et al. "Fast-scnn: Fast semantic segmentation network."
(https://arxiv.org/pdf/1902.04502.pdf)
Args:
        num_classes (int): the unique number of target classes.
        model_pretrained (str): the path of the pretrained model. Default to None.
        enable_auxiliary_loss (bool): a bool value that indicates whether to add an auxiliary loss.
            If true, the auxiliary loss will be added after the LearningToDownsample module, where the weight is 0.4. Default to True.
"""
def __init__(self,
num_classes,
model_pretrained=None,
enable_auxiliary_loss=True):
super(FastSCNN, self).__init__()
self.learning_to_downsample = LearningToDownsample(32, 48, 64)
self.global_feature_extractor = GlobalFeatureExtractor(
64, [64, 96, 128], 128, 6, [3, 3, 3])
self.feature_fusion = FeatureFusionModule(64, 128, 128)
self.classifier = Classifier(128, num_classes)
if enable_auxiliary_loss:
self.auxlayer = model_utils.AuxLayer(64, 32, num_classes)
self.enable_auxiliary_loss = enable_auxiliary_loss
self.init_weight(model_pretrained)
def forward(self, input, label=None):
logit_list = []
higher_res_features = self.learning_to_downsample(input)
x = self.global_feature_extractor(higher_res_features)
x = self.feature_fusion(higher_res_features, x)
logit = self.classifier(x)
logit = F.resize_bilinear(logit, input.shape[2:])
logit_list.append(logit)
if self.enable_auxiliary_loss:
auxiliary_logit = self.auxlayer(higher_res_features)
auxiliary_logit = F.resize_bilinear(auxiliary_logit,
input.shape[2:])
logit_list.append(auxiliary_logit)
return logit_list
def init_weight(self, pretrained_model=None):
"""
Initialize the parameters of model parts.
Args:
            pretrained_model (str, optional): the path of the pretrained model. Defaults to None.
"""
if pretrained_model is not None:
if os.path.exists(pretrained_model):
utils.load_pretrained_model(self, pretrained_model)
else:
raise Exception('Pretrained model is not found: {}'.format(
pretrained_model))
class LearningToDownsample(nn.Layer):
"""
Learning to downsample module.
This module consists of three downsampling blocks (one Conv and two separable Conv)
Args:
dw_channels1 (int): the input channels of the first sep conv. Default to 32.
dw_channels2 (int): the input channels of the second sep conv. Default to 48.
out_channels (int): the output channels of LearningToDownsample module. Default to 64.
"""
def __init__(self, dw_channels1=32, dw_channels2=48, out_channels=64):
super(LearningToDownsample, self).__init__()
self.conv_bn_relu = layer_utils.ConvBnRelu(
in_channels=3, out_channels=dw_channels1, kernel_size=3, stride=2)
self.dsconv_bn_relu1 = layer_utils.DepthwiseConvBnRelu(
in_channels=dw_channels1,
out_channels=dw_channels2,
kernel_size=3,
stride=2,
padding=1)
self.dsconv_bn_relu2 = layer_utils.DepthwiseConvBnRelu(
in_channels=dw_channels2,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1)
def forward(self, x):
x = self.conv_bn_relu(x)
x = self.dsconv_bn_relu1(x)
x = self.dsconv_bn_relu2(x)
return x
class GlobalFeatureExtractor(nn.Layer):
"""
Global feature extractor module
    This module consists of three groups of LinearBottleneck blocks (inverted residuals as introduced by MobileNetV2) and
    a PPModule (introduced by PSPNet).
Args:
in_channels (int): the number of input channels to the module. Default to 64.
block_channels (tuple): a tuple represents output channels of each bottleneck block. Default to (64, 96, 128).
out_channels (int): the number of output channels of the module. Default to 128.
expansion (int): the expansion factor in bottleneck. Default to 6.
num_blocks (tuple): it indicates the repeat time of each bottleneck. Default to (3, 3, 3).
"""
def __init__(self,
in_channels=64,
block_channels=(64, 96, 128),
out_channels=128,
expansion=6,
num_blocks=(3, 3, 3)):
super(GlobalFeatureExtractor, self).__init__()
self.bottleneck1 = self._make_layer(LinearBottleneck, in_channels,
block_channels[0], num_blocks[0],
expansion, 2)
self.bottleneck2 = self._make_layer(LinearBottleneck, block_channels[0],
block_channels[1], num_blocks[1],
expansion, 2)
self.bottleneck3 = self._make_layer(LinearBottleneck, block_channels[1],
block_channels[2], num_blocks[2],
expansion, 1)
self.ppm = model_utils.PPModule(
block_channels[2], out_channels, dim_reduction=True)
def _make_layer(self,
block,
in_channels,
out_channels,
blocks,
expansion=6,
stride=1):
layers = []
layers.append(block(in_channels, out_channels, expansion, stride))
for i in range(1, blocks):
layers.append(block(out_channels, out_channels, expansion, 1))
return nn.Sequential(*layers)
def forward(self, x):
x = self.bottleneck1(x)
x = self.bottleneck2(x)
x = self.bottleneck3(x)
x = self.ppm(x)
return x
class LinearBottleneck(nn.Layer):
"""
Single bottleneck implementation.
Args:
in_channels (int): the number of input channels to bottleneck block.
out_channels (int): the number of output channels of bottleneck block.
        expansion (int): the expansion factor in bottleneck. Default to 6.
        stride (int): the stride used in the depth-wise conv. Default to 2.
"""
def __init__(self,
in_channels,
out_channels,
expansion=6,
stride=2,
**kwargs):
super(LinearBottleneck, self).__init__()
self.use_shortcut = stride == 1 and in_channels == out_channels
expand_channels = in_channels * expansion
self.block = nn.Sequential(
# pw
layer_utils.ConvBnRelu(
in_channels=in_channels,
out_channels=expand_channels,
kernel_size=1,
bias_attr=False),
# dw
layer_utils.ConvBnRelu(
in_channels=expand_channels,
out_channels=expand_channels,
kernel_size=3,
stride=stride,
padding=1,
groups=expand_channels,
bias_attr=False),
# pw-linear
nn.Conv2d(
in_channels=expand_channels,
out_channels=out_channels,
kernel_size=1,
bias_attr=False),
nn.SyncBatchNorm(out_channels))
def forward(self, x):
out = self.block(x)
if self.use_shortcut:
out = x + out
return out
class FeatureFusionModule(nn.Layer):
"""
    Feature Fusion Module implementation.
This module fuses high-resolution feature and low-resolution feature.
Args:
high_in_channels (int): the channels of high-resolution feature (output of LearningToDownsample).
        low_in_channels (int): the channels of low-resolution feature (output of GlobalFeatureExtractor).
        out_channels (int): the output channels of this module.
"""
def __init__(self, high_in_channels, low_in_channels, out_channels):
super(FeatureFusionModule, self).__init__()
        # Only a depth-wise conv is used here, WITHOUT a following point-wise conv.
self.dwconv = layer_utils.ConvBnRelu(
in_channels=low_in_channels,
out_channels=out_channels,
kernel_size=3,
padding=1,
groups=128,
bias_attr=False)
self.conv_low_res = nn.Sequential(
nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=1), nn.SyncBatchNorm(out_channels))
self.conv_high_res = nn.Sequential(
nn.Conv2d(
in_channels=high_in_channels,
out_channels=out_channels,
kernel_size=1), nn.SyncBatchNorm(out_channels))
self.relu = nn.ReLU(True)
def forward(self, high_res_input, low_res_input):
low_res_input = F.resize_bilinear(input=low_res_input, scale=4)
low_res_input = self.dwconv(low_res_input)
low_res_input = self.conv_low_res(low_res_input)
high_res_input = self.conv_high_res(high_res_input)
x = high_res_input + low_res_input
return self.relu(x)
class Classifier(nn.Layer):
"""
    The Classifier module implementation.
    This module consists of two depth-wise convs and one conv.
Args:
input_channels (int): the input channels to this module.
        num_classes (int): the unique number of target classes.
"""
def __init__(self, input_channels, num_classes):
super(Classifier, self).__init__()
self.dsconv1 = layer_utils.DepthwiseConvBnRelu(
in_channels=input_channels,
out_channels=input_channels,
kernel_size=3,
padding=1)
self.dsconv2 = layer_utils.DepthwiseConvBnRelu(
in_channels=input_channels,
out_channels=input_channels,
kernel_size=3,
padding=1)
self.conv = nn.Conv2d(
in_channels=input_channels, out_channels=num_classes, kernel_size=1)
def forward(self, x):
x = self.dsconv1(x)
x = self.dsconv2(x)
x = F.dropout(x, p=0.1) # dropout_prob
x = self.conv(x)
return x
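# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): FastSCNN needs no
# external backbone, so it can be instantiated directly. With the auxiliary
# loss enabled, forward returns the main logit map plus an auxiliary one.
# Shapes and the dygraph setup are assumptions made only for illustration.
if __name__ == "__main__":
    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        model = FastSCNN(num_classes=2, enable_auxiliary_loss=True)
        data = np.random.rand(1, 3, 1024, 1024).astype('float32')
        logit_list = model(fluid.dygraph.to_variable(data))
        print([logit.shape for logit in logit_list])  # two logit maps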
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.initializer import Normal
from paddle.nn import SyncBatchNorm as BatchNorm
from paddleseg.cvlibs import manager
from paddleseg import utils
from paddleseg.cvlibs import param_init
from paddleseg.utils import logger
__all__ = [
"fcn_hrnet_w18_small_v1", "fcn_hrnet_w18_small_v2", "fcn_hrnet_w18",
"fcn_hrnet_w30", "fcn_hrnet_w32", "fcn_hrnet_w40", "fcn_hrnet_w44",
"fcn_hrnet_w48", "fcn_hrnet_w60", "fcn_hrnet_w64"
]
@manager.MODELS.add_component
class FCN(fluid.dygraph.Layer):
"""
Fully Convolutional Networks for Semantic Segmentation.
https://arxiv.org/abs/1411.4038
Args:
num_classes (int): the unique number of target classes.
backbone (paddle.nn.Layer): backbone networks.
model_pretrained (str): the path of pretrained model.
        backbone_indices (tuple): one value in the tuple indicates the index of the backbone output. Default: (-1, ).
        backbone_channels (tuple): the same length as "backbone_indices". It indicates the channels of the corresponding index.
        channels (int): the channels of the conv layer before the last one.
"""
def __init__(self,
num_classes,
backbone,
backbone_pretrained=None,
model_pretrained=None,
backbone_indices=(-1, ),
backbone_channels=(270, ),
channels=None):
super(FCN, self).__init__()
self.num_classes = num_classes
self.backbone_pretrained = backbone_pretrained
self.model_pretrained = model_pretrained
self.backbone_indices = backbone_indices
if channels is None:
channels = backbone_channels[backbone_indices[0]]
self.backbone = backbone
self.conv_last_2 = ConvBNLayer(
num_channels=backbone_channels[backbone_indices[0]],
num_filters=channels,
filter_size=1,
stride=1)
self.conv_last_1 = Conv2D(
num_channels=channels,
num_filters=self.num_classes,
filter_size=1,
stride=1,
padding=0)
if self.training:
self.init_weight()
def forward(self, x):
input_shape = x.shape[2:]
fea_list = self.backbone(x)
x = fea_list[self.backbone_indices[0]]
x = self.conv_last_2(x)
logit = self.conv_last_1(x)
logit = fluid.layers.resize_bilinear(logit, input_shape)
return [logit]
def init_weight(self):
params = self.parameters()
for param in params:
param_name = param.name
if 'batch_norm' in param_name:
if 'w_0' in param_name:
param_init.constant_init(param, value=1.0)
elif 'b_0' in param_name:
param_init.constant_init(param, value=0.0)
if 'conv' in param_name and 'w_0' in param_name:
param_init.normal_init(param, scale=0.001)
if self.model_pretrained is not None:
if os.path.exists(self.model_pretrained):
utils.load_pretrained_model(self, self.model_pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
self.model_pretrained))
elif self.backbone_pretrained is not None:
if os.path.exists(self.backbone_pretrained):
utils.load_pretrained_model(self.backbone,
self.backbone_pretrained)
else:
raise Exception('Pretrained model is not found: {}'.format(
self.backbone_pretrained))
else:
logger.warning('No pretrained model to load, train from scratch')
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act="relu"):
super(ConvBNLayer, self).__init__()
self._conv = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
bias_attr=False)
self._batch_norm = BatchNorm(num_filters)
self.act = act
def forward(self, input):
y = self._conv(input)
y = self._batch_norm(y)
if self.act == 'relu':
y = fluid.layers.relu(y)
return y
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v1(*args, **kwargs):
    return FCN(backbone='HRNet_W18_Small_V1', backbone_channels=(240, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18_small_v2(*args, **kwargs):
    return FCN(backbone='HRNet_W18_Small_V2', backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w18(*args, **kwargs):
    return FCN(backbone='HRNet_W18', backbone_channels=(270, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w30(*args, **kwargs):
    return FCN(backbone='HRNet_W30', backbone_channels=(450, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w32(*args, **kwargs):
    return FCN(backbone='HRNet_W32', backbone_channels=(480, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w40(*args, **kwargs):
    return FCN(backbone='HRNet_W40', backbone_channels=(600, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w44(*args, **kwargs):
    return FCN(backbone='HRNet_W44', backbone_channels=(660, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w48(*args, **kwargs):
    return FCN(backbone='HRNet_W48', backbone_channels=(720, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w60(*args, **kwargs):
    return FCN(backbone='HRNet_W60', backbone_channels=(900, ), **kwargs)
@manager.MODELS.add_component
def fcn_hrnet_w64(*args, **kwargs):
    return FCN(backbone='HRNet_W64', backbone_channels=(960, ), **kwargs)
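# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): FCN expects the backbone
# to return a list of feature maps, with backbone_indices / backbone_channels
# describing which entry is consumed. DummyBackbone below is a hypothetical
# stand-in used only for illustration; in practice an HRNet backbone layer is
# passed in. Shapes and the dygraph setup are assumptions.
if __name__ == "__main__":
    import numpy as np

    class DummyBackbone(fluid.dygraph.Layer):
        """A hypothetical backbone returning a single 270-channel feature map."""

        def __init__(self):
            super(DummyBackbone, self).__init__()
            self.conv = Conv2D(
                num_channels=3, num_filters=270, filter_size=3, stride=4, padding=1)

        def forward(self, x):
            return [self.conv(x)]

    with fluid.dygraph.guard():
        model = FCN(
            num_classes=2,
            backbone=DummyBackbone(),
            backbone_indices=(-1, ),
            backbone_channels=(270, ))
        data = np.random.rand(1, 3, 256, 256).astype('float32')
        logit_list = model(fluid.dygraph.to_variable(data))
        print(logit_list[0].shape)  # resized back to the input resolution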
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .cross_entroy_loss import CrossEntropyLoss
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .transforms import *
from . import functional
This diff is collapsed.
This diff is collapsed.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import logger
from . import download
from .metrics import ConfusionMatrix
from .utils import *
from .timer import Timer, calculate_eta
from .get_environ_info import get_environ_info
from .config import Config
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.