提交 5a169e44 编写于 作者: F FlyingQianMM

add fastscnn tutorial

上级 a755889f
import os
# 选择使用0号卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms
# 下载和解压视盘分割数据集
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# 定义训练和验证时的transforms
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms = transforms.ComposedSegTransforms(
mode='train', train_crop_size=[769, 769])
eval_transforms = transforms.ComposedSegTransforms(mode='eval')
# 定义训练和验证所用的数据集
# API说明: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/val_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=eval_transforms)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标
# VisualDL启动方式: visualdl --logdir output/unet/vdl_log --port 8001
# 浏览器打开 https://0.0.0.0:8001即可
# 其中0.0.0.0为本机访问,如为远程服务, 改成相应机器IP
# https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes = len(train_dataset.labels)
model = pdx.seg.FastSCNN(num_classes=num_classes)
model.train(
num_epochs=20,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
learning_rate=0.01,
save_dir='output/fastscnn',
use_vdl=True)
......@@ -194,9 +194,8 @@ class BaseAPI:
if os.path.exists(pretrain_dir):
os.remove(pretrain_dir)
os.makedirs(pretrain_dir)
if pretrain_weights is not None and \
not os.path.isdir(pretrain_weights) \
and not os.path.isfile(pretrain_weights):
if pretrain_weights is not None and not os.path.exists(
pretrain_weights):
if self.model_type == 'classifier':
if pretrain_weights not in ['IMAGENET']:
logging.warning(
......@@ -245,8 +244,8 @@ class BaseAPI:
logging.info(
"Load pretrain weights from {}.".format(pretrain_weights),
use_color=True)
paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
pretrain_weights, fuse_bn)
paddlex.utils.utils.load_pretrain_weights(
self.exe, self.train_prog, pretrain_weights, fuse_bn)
# 进行裁剪
if sensitivities_file is not None:
import paddleslim
......@@ -350,7 +349,9 @@ class BaseAPI:
logging.info("Model saved in {}.".format(save_dir))
def export_inference_model(self, save_dir):
test_input_names = [var.name for var in list(self.test_inputs.values())]
test_input_names = [
var.name for var in list(self.test_inputs.values())
]
test_outputs = list(self.test_outputs.values())
if self.__class__.__name__ == 'MaskRCNN':
from paddlex.utils.save import save_mask_inference_model
......@@ -387,7 +388,8 @@ class BaseAPI:
# 模型保存成功的标志
open(osp.join(save_dir, '.success'), 'w').close()
logging.info("Model for inference deploy saved in {}.".format(save_dir))
logging.info("Model for inference deploy saved in {}.".format(
save_dir))
def train_loop(self,
num_epochs,
......@@ -511,10 +513,12 @@ class BaseAPI:
eta = ((num_epochs - i) * total_num_steps - step - 1
) * avg_step_time
if time_eval_one_epoch is not None:
eval_eta = (total_eval_times - i // save_interval_epochs
eval_eta = (
total_eval_times - i // save_interval_epochs
) * time_eval_one_epoch
else:
eval_eta = (total_eval_times - i // save_interval_epochs
eval_eta = (
total_eval_times - i // save_interval_epochs
) * total_num_steps_eval * avg_step_time
eta_str = seconds_to_hms(eta + eval_eta)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import paddle.fluid as fluid
import paddlex
from collections import OrderedDict
from .deeplabv3p import DeepLabv3p
class FastSCNN(DeepLabv3p):
"""实现Fast SCNN网络的构建并进行训练、评估、预测和模型导出。
Args:
num_classes (int): 类别数。
use_bce_loss (bool): 是否使用bce loss作为网络的损失函数,只能用于两类分割。可与dice loss同时使用。默认False。
use_dice_loss (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用。
当use_bce_loss和use_dice_loss都为False时,使用交叉熵损失函数。默认False。
class_weight (list/str): 交叉熵损失函数各类损失的权重。当class_weight为list的时候,长度应为
num_classes。当class_weight为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,
即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
multi_loss_weight (list): 多分支上的loss权重。默认计算一个分支上的loss,即默认值为[1.0]。
也支持计算两个分支或三个分支上的loss,权重按[fusion_branch_weight, higher_branch_weight, lower_branch_weight]排列,
fusion_branch_weight为空间细节分支和全局上下文分支融合后的分支上的loss权重,higher_branch_weight为空间细节分支上的loss权重,
lower_branch_weight为全局上下文分支上的loss权重,若higher_branch_weight和lower_branch_weight未设置则不会计算这两个分支上的loss。
Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: class_weight为list, 但长度不等于num_class。
class_weight为str, 但class_weight.low()不等于dynamic。
TypeError: class_weight不为None时,其类型不是list或str。
TypeError: multi_loss_weight不为list。
ValueError: multi_loss_weight为list但长度小于0或者大于3。
"""
def __init__(self,
num_classes=2,
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
ignore_index=255,
multi_loss_weight=[1.0]):
self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
if not isinstance(multi_loss_weight, list):
raise TypeError(
'Expect multi_loss_weight is a list but receive {}'.format(
type(multi_loss_weight)))
if len(multi_loss_weight) > 3 or len(multi_loss_weight) < 0:
raise ValueError(
"Length of multi_loss_weight should be lower than or equal to 3 but greater than 0."
)
self.num_classes = num_classes
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.multi_loss_weight = multi_loss_weight
self.ignore_index = ignore_index
self.labels = None
self.fixed_input_shape = None
def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.FastSCNN(
self.num_classes,
mode=mode,
use_bce_loss=self.use_bce_loss,
use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight,
ignore_index=self.ignore_index,
multi_loss_weight=self.multi_loss_weight,
fixed_input_shape=self.fixed_input_shape)
inputs = model.generate_inputs()
model_out = model.build_net(inputs)
outputs = OrderedDict()
if mode == 'train':
self.optimizer.minimize(model_out)
outputs['loss'] = model_out
else:
outputs['pred'] = model_out[0]
outputs['logit'] = model_out[1]
return inputs, outputs
def train(self,
num_epochs,
train_dataset,
train_batch_size=2,
eval_dataset=None,
save_interval_epochs=1,
log_interval_steps=2,
save_dir='output',
pretrain_weights='CITYSCAPES',
optimizer=None,
learning_rate=0.01,
lr_decay_power=0.9,
use_vdl=False,
sensitivities_file=None,
eval_metric_loss=0.05,
early_stop=False,
early_stop_patience=5,
resume_checkpoint=None):
"""训练。
Args:
num_epochs (int): 训练迭代轮数。
train_dataset (paddlex.datasets): 训练数据读取器。
train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认2。
eval_dataset (paddlex.datasets): 评估数据读取器。
save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为2。
save_dir (str): 模型保存路径。默认'output'。
pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'CITYSCAPES'
则自动下载在CITYSCAPES图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'CITYSCAPES'。
optimizer (paddle.fluid.optimizer): 优化器。当改参数为None时,使用默认的优化器:使用
fluid.optimizer.Momentum优化方法,polynomial的学习率衰减策略。
learning_rate (float): 默认优化器的初始学习率。默认0.01。
lr_decay_power (float): 默认优化器学习率多项式衰减系数。默认0.9。
use_vdl (bool): 是否使用VisualDL进行可视化。默认False。
sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
则自动下载在Cityscapes图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
early_stop (bool): 是否使用提前终止训练策略。默认值为False。
early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
连续下降或持平,则终止训练。默认值为5。
resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
Raises:
ValueError: 模型从inference model进行加载。
"""
return super(FastSCNN, self).train(
num_epochs, train_dataset, train_batch_size, eval_dataset,
save_interval_epochs, log_interval_steps, save_dir,
pretrain_weights, optimizer, learning_rate, lr_decay_power,
use_vdl, sensitivities_file, eval_metric_loss, early_stop,
early_stop_patience, resume_checkpoint)
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import paddle.fluid as fluid
from .model_utils.libs import scope
from .model_utils.libs import bn, bn_relu, relu, conv_bn_layer
from .model_utils.libs import conv, avg_pool
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss
class FastSCNN(object):
def __init__(self,
num_classes,
mode='train',
use_bce_loss=False,
use_dice_loss=False,
class_weight=None,
multi_loss_weight=[1.0],
ignore_index=255,
fixed_input_shape=None):
# dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError(
"dice loss and bce loss is only applicable to binary classfication"
)
if class_weight is not None:
if isinstance(class_weight, list):
if len(class_weight) != num_classes:
raise ValueError(
"Length of class_weight should be equal to number of classes"
)
elif isinstance(class_weight, str):
if class_weight.lower() != 'dynamic':
raise ValueError(
"if class_weight is string, must be dynamic!")
else:
raise TypeError(
'Expect class_weight is a list or string but receive {}'.
format(type(class_weight)))
self.num_classes = num_classes
self.mode = mode
self.use_bce_loss = use_bce_loss
self.use_dice_loss = use_dice_loss
self.class_weight = class_weight
self.ignore_index = ignore_index
self.multi_loss_weight = multi_loss_weight
self.fixed_input_shape = fixed_input_shape
def build_net(self, inputs):
if self.use_dice_loss or self.use_bce_loss:
self.num_classes = 1
image = inputs['image']
size = fluid.layers.shape(image)[2:]
with scope('learning_to_downsample'):
higher_res_features = self._learning_to_downsample(image, 32, 48,
64)
with scope('global_feature_extractor'):
lower_res_feature = self._global_feature_extractor(
higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
with scope('feature_fusion'):
x = self._feature_fusion(higher_res_features, lower_res_feature,
64, 128, 128)
with scope('classifier'):
logit = self._classifier(x, 128)
logit = fluid.layers.resize_bilinear(logit, size, align_mode=0)
if len(self.multi_loss_weight) == 3:
with scope('aux_layer_higher'):
higher_logit = self._aux_layer(higher_res_features,
self.num_classes)
higher_logit = fluid.layers.resize_bilinear(
higher_logit, size, align_mode=0)
with scope('aux_layer_lower'):
lower_logit = self._aux_layer(lower_res_feature,
self.num_classes)
lower_logit = fluid.layers.resize_bilinear(
lower_logit, size, align_mode=0)
logit = (logit, higher_logit, lower_logit)
elif len(self.multi_loss_weight) == 2:
with scope('aux_layer_higher'):
higher_logit = self._aux_layer(higher_res_features,
self.num_classes)
higher_logit = fluid.layers.resize_bilinear(
higher_logit, size, align_mode=0)
logit = (logit, higher_logit)
else:
logit = (logit, )
if self.num_classes == 1:
out = sigmoid_to_softmax(logit[0])
out = fluid.layers.transpose(out, [0, 2, 3, 1])
else:
out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])
pred = fluid.layers.argmax(out, axis=3)
pred = fluid.layers.unsqueeze(pred, axes=[3])
if self.mode == 'train':
label = inputs['label']
return self._get_loss(logit, label)
elif self.mode == 'eval':
label = inputs['label']
loss = self._get_loss(logit, label)
return loss, pred, label, mask
else:
if self.num_classes == 1:
logit = sigmoid_to_softmax(logit[0])
else:
logit = fluid.layers.softmax(logit[0], axis=1)
return pred, logit
def generate_inputs(self):
inputs = OrderedDict()
if self.fixed_input_shape is not None:
input_shape = [
None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
]
inputs['image'] = fluid.data(
dtype='float32', shape=input_shape, name='image')
else:
inputs['image'] = fluid.data(
dtype='float32', shape=[None, 3, None, None], name='image')
if self.mode == 'train':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
elif self.mode == 'eval':
inputs['label'] = fluid.data(
dtype='int32', shape=[None, 1, None, None], name='label')
return inputs
def _get_loss(self, logits, label):
avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss):
for i, logit in enumerate(logits):
logit_mask = (
label.astype('int32') != self.ignore_index).astype('int32')
loss = softmax_with_loss(
logit,
label,
logit_mask,
num_classes=self.num_classes,
weight=self.class_weight,
ignore_index=self.ignore_index)
avg_loss += self.multi_loss_weight[i] * loss
else:
if self.use_dice_loss:
for i, logit in enumerate(logits):
logit_mask = (label.astype('int32') != self.ignore_index
).astype('int32')
loss = dice_loss(logit, label, logit_mask)
avg_loss += self.multi_loss_weight[i] * loss
if self.use_bce_loss:
for i, logit in enumerate(logits):
#logit_label = fluid.layers.resize_nearest(label, logit_shape[2:])
logit_mask = (label.astype('int32') != self.ignore_index
).astype('int32')
loss = bce_loss(
logit,
label,
logit_mask,
ignore_index=self.ignore_index)
avg_loss += self.multi_loss_weight[i] * loss
return avg_loss
def _learning_to_downsample(self,
x,
dw_channels1=32,
dw_channels2=48,
out_channels=64):
x = relu(bn(conv(x, dw_channels1, 3, 2)))
with scope('dsconv1'):
x = separate_conv(
x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
with scope('dsconv2'):
x = separate_conv(
x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
return x
def _shortcut(self, input, data_residual):
return fluid.layers.elementwise_add(input, data_residual)
def _dropout2d(self, input, prob, is_train=False):
if not is_train:
return input
keep_prob = 1.0 - prob
shape = fluid.layers.shape(input)
channels = shape[1]
random_tensor = keep_prob + fluid.layers.uniform_random(
[shape[0], channels, 1, 1], min=0., max=1.)
binary_tensor = fluid.layers.floor(random_tensor)
output = input / keep_prob * binary_tensor
return output
def _inverted_residual_unit(self,
input,
num_in_filter,
num_filters,
ifshortcut,
stride,
filter_size,
padding,
expansion_factor,
name=None):
num_expfilter = int(round(num_in_filter * expansion_factor))
channel_expand = conv_bn_layer(
input=input,
num_filters=num_expfilter,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=True,
name=name + '_expand')
bottleneck_conv = conv_bn_layer(
input=channel_expand,
num_filters=num_expfilter,
filter_size=filter_size,
stride=stride,
padding=padding,
num_groups=num_expfilter,
if_act=True,
name=name + '_dwise',
use_cudnn=False)
depthwise_output = bottleneck_conv
linear_out = conv_bn_layer(
input=bottleneck_conv,
num_filters=num_filters,
filter_size=1,
stride=1,
padding=0,
num_groups=1,
if_act=False,
name=name + '_linear')
if ifshortcut:
out = self._shortcut(input=input, data_residual=linear_out)
return out, depthwise_output
else:
return linear_out, depthwise_output
def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
first_block, depthwise_output = self._inverted_residual_unit(
input=input,
num_in_filter=in_c,
num_filters=c,
ifshortcut=False,
stride=s,
filter_size=3,
padding=1,
expansion_factor=t,
name=name + '_1')
last_residual_block = first_block
last_c = c
for i in range(1, n):
last_residual_block, depthwise_output = self._inverted_residual_unit(
input=last_residual_block,
num_in_filter=last_c,
num_filters=c,
ifshortcut=True,
stride=1,
filter_size=3,
padding=1,
expansion_factor=t,
name=name + '_' + str(i + 1))
return last_residual_block, depthwise_output
def _psp_module(self, input, out_features):
cat_layers = []
sizes = (1, 2, 3, 6)
for size in sizes:
psp_name = "psp" + str(size)
with scope(psp_name):
pool = fluid.layers.adaptive_pool2d(
input,
pool_size=[size, size],
pool_type='avg',
name=psp_name + '_adapool')
data = conv(
pool,
out_features,
filter_size=1,
bias_attr=False,
name=psp_name + '_conv')
data_bn = bn(data, act='relu')
interp = fluid.layers.resize_bilinear(
data_bn,
out_shape=fluid.layers.shape(input)[2:],
name=psp_name + '_interp',
align_mode=0)
cat_layers.append(interp)
cat_layers = [input] + cat_layers
out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
return out
def _aux_layer(self, x, num_classes):
x = relu(bn(conv(x, 32, 3, padding=1)))
x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
with scope('logit'):
x = conv(x, num_classes, 1, bias_attr=True)
return x
def _feature_fusion(self,
higher_res_feature,
lower_res_feature,
higher_in_channels,
lower_in_channels,
out_channels,
scale_factor=4):
shape = fluid.layers.shape(higher_res_feature)
w = shape[-1]
h = shape[-2]
lower_res_feature = fluid.layers.resize_bilinear(
lower_res_feature, [h, w], align_mode=0)
with scope('dwconv'):
lower_res_feature = relu(
bn(conv(lower_res_feature, out_channels,
1))) #(lower_res_feature)
with scope('conv_lower_res'):
lower_res_feature = bn(
conv(
lower_res_feature, out_channels, 1, bias_attr=True))
with scope('conv_higher_res'):
higher_res_feature = bn(
conv(
higher_res_feature, out_channels, 1, bias_attr=True))
out = higher_res_feature + lower_res_feature
return relu(out)
def _global_feature_extractor(self,
x,
in_channels=64,
block_channels=(64, 96, 128),
out_channels=128,
t=6,
num_blocks=(3, 3, 3)):
x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
num_blocks[0], 2, 'inverted_block_1')
x, _ = self._inverted_blocks(x, block_channels[0], t,
block_channels[1], num_blocks[1], 2,
'inverted_block_2')
x, _ = self._inverted_blocks(x, block_channels[1], t,
block_channels[2], num_blocks[2], 1,
'inverted_block_3')
x = self._psp_module(x, block_channels[2] // 4)
with scope('out'):
x = relu(bn(conv(x, out_channels, 1)))
return x
def _classifier(self, x, dw_channels, stride=1):
with scope('dsconv1'):
x = separate_conv(
x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
with scope('dsconv2'):
x = separate_conv(
x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
x = conv(x, self.num_classes, 1, bias_attr=True)
return x
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册