Commit f3a3c81c (unverified)
Authored Jun 12, 2020 by Jason; committed via GitHub on Jun 12, 2020.

Merge pull request #149 from FlyingQianMM/develop_qh

add fastscnn for segmentation

Parents: 8391440b, 5a169e44

Showing 17 changed files with 750 additions and 43 deletions (+750 −43).
docs/apis/models/semantic_segmentation.md          +90  −5
new_tutorials/train/README.md                       +3  −0
new_tutorials/train/segmentation/fast_scnn.py      +48  −0
paddlex/cv/models/__init__.py                       +1  −0
paddlex/cv/models/base.py                          +18 −14
paddlex/cv/models/deeplabv3p.py                     +1  −1
paddlex/cv/models/fast_scnn.py                    +169  −0
paddlex/cv/models/hrnet.py                          +1  −6
paddlex/cv/models/unet.py                           +1  −6
paddlex/cv/models/utils/pretrain_weights.py        +15  −4
paddlex/cv/nets/__init__.py                         +1  −0
paddlex/cv/nets/segmentation/__init__.py            +1  −0
paddlex/cv/nets/segmentation/deeplabv3p.py          +0  −1
paddlex/cv/nets/segmentation/fast_scnn.py         +395  −0
paddlex/cv/nets/segmentation/hrnet.py               +0  −1
paddlex/cv/nets/segmentation/unet.py                +5  −5
paddlex/seg.py                                      +1  −0
docs/apis/models/semantic_segmentation.md  (+90 −5)

@@ -40,12 +40,12 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev

In the DeepLabv3p train() parameter list, two descriptions are updated: the pretrain_weights entry gets a small wording fix in the Chinese source (the missing 若 is inserted before 为None), and the sensitivities_file entry now states that the 'DEFAULT' sensitivity file was obtained on Cityscapes rather than ImageNet. The affected block reads:

> > - **save_interval_epochs** (int): Model saving interval, in epochs. Default: 1.
> > - **log_interval_steps** (int): Training log output interval, in steps. Default: 2.
> > - **save_dir** (str): Directory in which to save the model. Default: 'output'.
> > - **pretrain_weights** (str): If set to a path, the pretrained model under that path is loaded; if 'IMAGENET', weights pretrained on ImageNet images are downloaded automatically; if 'COCO', weights pretrained on the COCO dataset are downloaded automatically (note: COCO pretrained models are not yet provided for Xception41, MobileNetV2_x0.25, MobileNetV2_x0.5, MobileNetV2_x1.5 and MobileNetV2_x2.0); if 'CITYSCAPES', weights pretrained on the CITYSCAPES dataset are downloaded automatically (note: CITYSCAPES pretrained models are not yet provided for Xception41, MobileNetV2_x0.25, MobileNetV2_x0.5, MobileNetV2_x1.5 and MobileNetV2_x2.0); if None, no pretrained model is used. Default: 'IMAGENET'.
> > - **optimizer** (paddle.fluid.optimizer): Optimizer. When None, the default optimizer is used: fluid.optimizer.Momentum with a polynomial learning-rate decay.
> > - **learning_rate** (float): Initial learning rate of the default optimizer. Default: 0.01.
> > - **lr_decay_power** (float): Learning-rate decay exponent of the default optimizer. Default: 0.9.
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Default: False.
> > - **sensitivities_file** (str): If set to a path, sensitivity information is loaded from it for pruning; if 'DEFAULT', sensitivity information obtained on Cityscapes images (previously: ImageNet) is downloaded automatically; if None, no pruning is performed. Default: None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Default: 0.05.
> > - **early_stop** (bool): Whether to use early stopping. Default: False.
> > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Default: 5.

@@ -129,7 +129,7 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev

A second train() parameter list receives the same sensitivities_file correction (ImageNet → Cityscapes); the surrounding learning_rate, lr_decay_power, use_vdl, eval_metric_loss, early_stop and early_stop_patience entries appear as unchanged context.

@@ -209,12 +209,12 @@ train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, ev

The HRNet train() parameter list receives the same two edits: the 若 wording fix in the pretrain_weights entry (which documents that 'IMAGENET' downloads ImageNet-pretrained weights, 'CITYSCAPES' downloads CITYSCAPES-pretrained weights — currently provided only for `width`=18 — and None disables pretraining, default 'IMAGENET'), and the ImageNet → Cityscapes correction in the sensitivities_file entry. The optimizer, learning_rate, lr_decay_power, use_vdl, eval_metric_loss, early_stop and early_stop_patience entries appear as unchanged context.

@@ -258,3 +258,88 @@ predict(self, im_file, transforms=None):

> **Return value**
> >
> > - **dict**: A dict with keys 'label_map' and 'score_map'; 'label_map' stores the predicted gray-scale label map whose pixel values are the class indices, and 'score_map' stores the per-class probabilities with shape (h, w, num_classes).

The rest of the hunk appends the new FastSCNN documentation section:
## FastSCNN class

```python
paddlex.seg.FastSCNN(num_classes=2, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, multi_loss_weight=[1.0])
```

> Build a FastSCNN segmenter.

> **Parameters**

> > - **num_classes** (int): Number of classes.
> > - **use_bce_loss** (bool): Whether to use BCE loss as the network loss; only applicable to binary segmentation, and it can be combined with dice loss. Default: False.
> > - **use_dice_loss** (bool): Whether to use dice loss as the network loss; only applicable to binary segmentation, and it can be combined with BCE loss. When both use_bce_loss and use_dice_loss are False, the cross-entropy loss is used. Default: False.
> > - **class_weight** (list/str): Per-class weights of the cross-entropy loss. When `class_weight` is a list, its length must equal `num_classes`. When it is a str, weight.lower() must be 'dynamic', in which case the weights are recomputed each epoch from the per-class pixel ratios, each class weighted as its ratio * num_classes. With the default None every class has weight 1, i.e. the ordinary cross-entropy loss.
> > - **ignore_index** (int): Label value to ignore; pixels whose label equals `ignore_index` do not contribute to the loss. Default: 255.
> > - **multi_loss_weight** (list): Loss weights of the multiple branches. By default the loss is computed on one branch only, i.e. the default value is [1.0]. Losses on two or three branches are also supported, with weights ordered as [fusion_branch_weight, higher_branch_weight, lower_branch_weight]: fusion_branch_weight is the loss weight of the branch obtained by fusing the spatial-detail branch with the global-context branch, higher_branch_weight is the loss weight of the spatial-detail branch, and lower_branch_weight is the loss weight of the global-context branch. If higher_branch_weight and lower_branch_weight are not set, the losses on those two branches are not computed (a short construction sketch follows this parameter list).
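The branch weighting is easiest to see in a short construction sketch. This is not part of the commit; it is a minimal illustration of the documented constructor, and the 0.4 weights are made-up values rather than recommended defaults.

```python
import paddlex as pdx

# Default: only the fused branch is supervised, loss = 1.0 * fusion_loss.
model = pdx.seg.FastSCNN(num_classes=2)

# Supervise all three branches; the total loss becomes
# 1.0 * fusion_loss + 0.4 * spatial_detail_loss + 0.4 * global_context_loss.
model_aux = pdx.seg.FastSCNN(
    num_classes=2, multi_loss_weight=[1.0, 0.4, 0.4])
```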
### train (training interface)

```python
train(self, num_epochs, train_dataset, train_batch_size=2, eval_dataset=None, eval_batch_size=1, save_interval_epochs=1, log_interval_steps=2, save_dir='output', pretrain_weights='CITYSCAPES', optimizer=None, learning_rate=0.01, lr_decay_power=0.9, use_vdl=False, sensitivities_file=None, eval_metric_loss=0.05, early_stop=False, early_stop_patience=5, resume_checkpoint=None):
```

> FastSCNN model training interface.

> **Parameters**
> >
> > - **num_epochs** (int): Number of training epochs.
> > - **train_dataset** (paddlex.datasets): Training dataset reader.
> > - **train_batch_size** (int): Batch size of the training data; also used as the validation batch size. Default: 2.
> > - **eval_dataset** (paddlex.datasets): Evaluation dataset reader.
> > - **save_interval_epochs** (int): Model saving interval, in epochs. Default: 1.
> > - **log_interval_steps** (int): Training log output interval, in steps. Default: 2.
> > - **save_dir** (str): Directory in which to save the model. Default: 'output'.
> > - **pretrain_weights** (str): If set to a path, the pretrained model under that path is loaded; if 'CITYSCAPES', weights pretrained on CITYSCAPES images are downloaded automatically; if None, no pretrained model is used. Default: 'CITYSCAPES'.
> > - **optimizer** (paddle.fluid.optimizer): Optimizer. When None, the default optimizer is used: fluid.optimizer.Momentum with a polynomial learning-rate decay.
> > - **learning_rate** (float): Initial learning rate of the default optimizer. Default: 0.01.
> > - **lr_decay_power** (float): Learning-rate decay exponent of the default optimizer. Default: 0.9.
> > - **use_vdl** (bool): Whether to use VisualDL for visualization. Default: False.
> > - **sensitivities_file** (str): If set to a path, sensitivity information is loaded from it for pruning; if 'DEFAULT', sensitivity information obtained on Cityscapes images is downloaded automatically; if None, no pruning is performed. Default: None.
> > - **eval_metric_loss** (float): Tolerable accuracy loss. Default: 0.05.
> > - **early_stop** (bool): Whether to use early stopping. Default: False.
> > - **early_stop_patience** (int): With early stopping enabled, training stops if the validation accuracy drops or stays flat for `early_stop_patience` consecutive epochs. Default: 5.
> > - **resume_checkpoint** (str): Path of the model saved by an earlier run, used to resume training. If None, training is not resumed. Default: None.

#### evaluate (evaluation interface)

```
evaluate(self, eval_dataset, batch_size=1, epoch_id=None, return_details=False):
```

> FastSCNN model evaluation interface.

> **Parameters**
> >
> > - **eval_dataset** (paddlex.datasets): Evaluation dataset reader.
> > - **batch_size** (int): Batch size used during evaluation. Default: 1.
> > - **epoch_id** (int): Training epoch the evaluated model belongs to.
> > - **return_details** (bool): Whether to return detailed information. Default: False.

> **Return value**
> >
> > - **dict**: When return_details is False, a dict with the keys 'miou', 'category_iou', 'macc', 'category_acc' and 'kappa', i.e. the mean IoU, per-class IoU, mean accuracy, per-class accuracy and kappa coefficient.
> > - **tuple** (metrics, eval_details): When return_details is True, an additional dict (eval_details) is returned, with the key 'confusion_matrix' holding the confusion matrix of the evaluation.

#### predict (prediction interface)

```
predict(self, im_file, transforms=None):
```

> FastSCNN model prediction interface. Note that the image preprocessing pipeline used at prediction time is only saved into `FastSCNN.test_transforms` and `FastSCNN.eval_transforms` if eval_dataset was defined during training. If eval_dataset was not defined at training time, the user has to define test_transforms again and pass it to the `predict` interface when calling it.

> **Parameters**
> >
> > - **im_file** (str): Path of the image to predict.
> > - **transforms** (paddlex.seg.transforms): Data preprocessing transforms.

> **Return value**
> >
> > - **dict**: A dict with keys 'label_map' and 'score_map'; 'label_map' stores the predicted gray-scale label map whose pixel values are the class indices, and 'score_map' stores the per-class probabilities with shape (h, w, num_classes).
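Taken together, the train/evaluate/predict interfaces documented above are typically used as in the following sketch. It is not part of the commit: the file paths are placeholders, and `pdx.load_model` plus the `best_model` directory name follow standard PaddleX usage rather than anything shown in this diff.

```python
import paddlex as pdx
from paddlex.seg import transforms

# Validation dataset, set up the same way as in the tutorial script below.
eval_transforms = transforms.ComposedSegTransforms(mode='eval')
eval_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/val_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=eval_transforms)

# Assumes a model already trained and saved under output/fastscnn.
model = pdx.load_model('output/fastscnn/best_model')

# evaluate() returns 'miou', 'category_iou', 'macc', 'category_acc', 'kappa'.
metrics = model.evaluate(eval_dataset, batch_size=1)
print(metrics)

# predict() returns a dict with 'label_map' and 'score_map'
# (the image path below is a placeholder).
result = model.predict('optic_disc_seg/JPEGImages/example.jpg')
```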
new_tutorials/train/README.md  (+3 −0)

@@ -10,6 +10,9 @@

|detection/mask_rcnn_f50_fpn.py | Instance segmentation MaskRCNN | Garbage sorting |
|segmentation/deeplabv3p.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
|segmentation/unet.py | Semantic segmentation UNet | Optic disc segmentation |
|segmentation/hrnet.py | Semantic segmentation HRNet | Optic disc segmentation |
|segmentation/fast_scnn.py | Semantic segmentation FastSCNN | Optic disc segmentation |

## Getting started with training

After installing PaddleX, start training with the following command
new_tutorials/train/segmentation/fast_scnn.py  (new file, +48 −0)

```python
import os
# Use GPU card 0
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms

# Download and decompress the optic disc segmentation dataset
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')

# Define the transforms used for training and validation
# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms
train_transforms = transforms.ComposedSegTransforms(
    mode='train', train_crop_size=[769, 769])
eval_transforms = transforms.ComposedSegTransforms(mode='eval')

# Define the datasets used for training and validation
# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset
train_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/train_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/val_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=eval_transforms)

# Initialize the model and start training.
# Training metrics can be inspected with VisualDL, launched with:
#   visualdl --logdir output/unet/vdl_log --port 8001
# then open https://0.0.0.0:8001 in a browser
# (0.0.0.0 is for local access; for a remote server use that machine's IP).
# https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet
num_classes = len(train_dataset.labels)
model = pdx.seg.FastSCNN(num_classes=num_classes)
model.train(
    num_epochs=20,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.01,
    save_dir='output/fastscnn',
    use_vdl=True)
```
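A possible follow-up to the script above, not part of the commit: exporting the trained model for deployment. A minimal sketch, assuming it is appended to the tutorial script (the save_dir value is illustrative); `export_inference_model(self, save_dir)` is defined on BaseAPI in paddlex/cv/models/base.py, whose diff appears below.

```python
# Export the in-memory `model` trained above to an inference-ready format.
model.export_inference_model(save_dir='output/fastscnn/inference')
```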
paddlex/cv/models/__init__.py  (+1 −0)

```diff
@@ -43,5 +43,6 @@ from .mask_rcnn import MaskRCNN
 from .unet import UNet
 from .deeplabv3p import DeepLabv3p
 from .hrnet import HRNet
+from .fast_scnn import FastSCNN
 from .load_model import load_model
 from .slim import prune
```
paddlex/cv/models/base.py  (+18 −14)

The BaseAPI changes are line-wrapping reflows of existing statements, apart from simplifying the pretrained-weights existence check; the file's Apache 2.0 license header is shown as leading context. The exact re-wrapping below is reconstructed from the diff's line counts.

```diff
@@ -194,9 +194,8 @@ class BaseAPI:
         if os.path.exists(pretrain_dir):
             os.remove(pretrain_dir)
         os.makedirs(pretrain_dir)
-        if pretrain_weights is not None and \
-                not os.path.isdir(pretrain_weights) \
-                and not os.path.isfile(pretrain_weights):
+        if pretrain_weights is not None and not os.path.exists(
+                pretrain_weights):
             if self.model_type == 'classifier':
                 if pretrain_weights not in ['IMAGENET']:
                     logging.warning(
@@ -245,8 +244,8 @@ class BaseAPI:
             logging.info(
                 "Load pretrain weights from {}.".format(pretrain_weights),
                 use_color=True)
-            paddlex.utils.utils.load_pretrain_weights(
-                self.exe, self.train_prog, pretrain_weights, fuse_bn)
+            paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog,
+                                                      pretrain_weights, fuse_bn)
         # perform pruning
         if sensitivities_file is not None:
             import paddleslim
@@ -350,7 +349,9 @@ class BaseAPI:
         logging.info("Model saved in {}.".format(save_dir))

     def export_inference_model(self, save_dir):
-        test_input_names = [var.name for var in list(self.test_inputs.values())]
+        test_input_names = [
+            var.name for var in list(self.test_inputs.values())
+        ]
         test_outputs = list(self.test_outputs.values())
         if self.__class__.__name__ == 'MaskRCNN':
             from paddlex.utils.save import save_mask_inference_model
@@ -387,7 +388,8 @@ class BaseAPI:
         # marker that the model was saved successfully
         open(osp.join(save_dir, '.success'), 'w').close()
-        logging.info("Model for inference deploy saved in {}.".format(save_dir))
+        logging.info("Model for inference deploy saved in {}.".format(
+            save_dir))

     def train_loop(self,
                    num_epochs,
@@ -511,11 +513,13 @@ class BaseAPI:
                     eta = ((num_epochs - i) * total_num_steps - step - 1) * avg_step_time
                     if time_eval_one_epoch is not None:
-                        eval_eta = (total_eval_times - i // save_interval_epochs) * time_eval_one_epoch
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * time_eval_one_epoch
                     else:
-                        eval_eta = (total_eval_times - i // save_interval_epochs) * total_num_steps_eval * avg_step_time
+                        eval_eta = (
+                            total_eval_times - i // save_interval_epochs
+                        ) * total_num_steps_eval * avg_step_time
                     eta_str = seconds_to_hms(eta + eval_eta)
                     logging.info(
```
paddlex/cv/models/deeplabv3p.py  (+1 −1)

@@ -251,7 +251,7 @@ class DeepLabv3p(BaseAPI):

In the train() docstring, the sensitivities_file description is corrected the same way as in the Markdown docs: with 'DEFAULT', the sensitivity file automatically downloaded for pruning is the one obtained on Cityscapes images, not ImageNet. The neighbouring lr_decay_power, use_vdl, eval_metric_loss, early_stop and early_stop_patience entries appear as unchanged context.
paddlex/cv/models/fast_scnn.py  (new file, +169 −0)

```python
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
import paddle.fluid as fluid
import paddlex
from collections import OrderedDict
from .deeplabv3p import DeepLabv3p


class FastSCNN(DeepLabv3p):
    """Build the Fast-SCNN network and run training, evaluation, prediction and model export.

    Args:
        num_classes (int): Number of classes.
        use_bce_loss (bool): Whether to use BCE loss as the network loss; only applicable to
            binary segmentation, and it can be combined with dice loss. Default: False.
        use_dice_loss (bool): Whether to use dice loss as the network loss; only applicable to
            binary segmentation, and it can be combined with BCE loss. When both use_bce_loss
            and use_dice_loss are False, the cross-entropy loss is used. Default: False.
        class_weight (list/str): Per-class weights of the cross-entropy loss. When class_weight
            is a list, its length must equal num_classes. When it is a str, weight.lower() must
            be 'dynamic', in which case the weights are recomputed each epoch from the per-class
            pixel ratios, each class weighted as its ratio * num_classes. With the default None
            every class has weight 1, i.e. the ordinary cross-entropy loss.
        ignore_index (int): Label value to ignore; pixels whose label equals ignore_index do not
            contribute to the loss. Default: 255.
        multi_loss_weight (list): Loss weights of the multiple branches. By default the loss is
            computed on one branch only, i.e. the default value is [1.0]. Losses on two or three
            branches are also supported, with weights ordered as [fusion_branch_weight,
            higher_branch_weight, lower_branch_weight]: fusion_branch_weight is the loss weight
            of the branch fusing the spatial-detail and global-context branches,
            higher_branch_weight is the loss weight of the spatial-detail branch, and
            lower_branch_weight is the loss weight of the global-context branch. If
            higher_branch_weight and lower_branch_weight are not set, the losses on those two
            branches are not computed.

    Raises:
        ValueError: use_bce_loss or use_dice_loss is True and num_classes > 2.
        ValueError: class_weight is a list whose length does not equal num_classes, or a str
            whose lowercased value is not 'dynamic'.
        TypeError: class_weight is not None and is neither a list nor a str.
        TypeError: multi_loss_weight is not a list.
        ValueError: multi_loss_weight is a list whose length is greater than 3 or less than 0.
    """

    def __init__(self,
                 num_classes=2,
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255,
                 multi_loss_weight=[1.0]):
        self.init_params = locals()
        super(DeepLabv3p, self).__init__('segmenter')
        # dice_loss or bce_loss is only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classfication")
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes")
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))
        if not isinstance(multi_loss_weight, list):
            raise TypeError(
                'Expect multi_loss_weight is a list but receive {}'.format(
                    type(multi_loss_weight)))
        if len(multi_loss_weight) > 3 or len(multi_loss_weight) < 0:
            raise ValueError(
                "Length of multi_loss_weight should be lower than or equal to 3 but greater than 0."
            )

        self.num_classes = num_classes
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.multi_loss_weight = multi_loss_weight
        self.ignore_index = ignore_index
        self.labels = None
        self.fixed_input_shape = None

    def build_net(self, mode='train'):
        model = paddlex.cv.nets.segmentation.FastSCNN(
            self.num_classes,
            mode=mode,
            use_bce_loss=self.use_bce_loss,
            use_dice_loss=self.use_dice_loss,
            class_weight=self.class_weight,
            ignore_index=self.ignore_index,
            multi_loss_weight=self.multi_loss_weight,
            fixed_input_shape=self.fixed_input_shape)
        inputs = model.generate_inputs()
        model_out = model.build_net(inputs)
        outputs = OrderedDict()
        if mode == 'train':
            self.optimizer.minimize(model_out)
            outputs['loss'] = model_out
        else:
            outputs['pred'] = model_out[0]
            outputs['logit'] = model_out[1]
        return inputs, outputs

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              optimizer=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              use_vdl=False,
              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        Args:
            num_epochs (int): Number of training epochs.
            train_dataset (paddlex.datasets): Training dataset reader.
            train_batch_size (int): Batch size of the training data; also used as the
                validation batch size. Default: 2.
            eval_dataset (paddlex.datasets): Evaluation dataset reader.
            save_interval_epochs (int): Model saving interval, in epochs. Default: 1.
            log_interval_steps (int): Training log output interval, in steps. Default: 2.
            save_dir (str): Directory in which to save the model. Default: 'output'.
            pretrain_weights (str): If set to a path, the pretrained model under that path is
                loaded; if 'CITYSCAPES', weights pretrained on CITYSCAPES images are downloaded
                automatically; if None, no pretrained model is used. Default: 'CITYSCAPES'.
            optimizer (paddle.fluid.optimizer): Optimizer. When None, the default optimizer is
                used: fluid.optimizer.Momentum with a polynomial learning-rate decay.
            learning_rate (float): Initial learning rate of the default optimizer. Default: 0.01.
            lr_decay_power (float): Polynomial decay power of the default optimizer's learning
                rate. Default: 0.9.
            use_vdl (bool): Whether to use VisualDL for visualization. Default: False.
            sensitivities_file (str): If set to a path, sensitivity information is loaded from
                it for pruning; if 'DEFAULT', sensitivity information obtained on Cityscapes
                images is downloaded automatically; if None, no pruning is performed.
                Default: None.
            eval_metric_loss (float): Tolerable accuracy loss. Default: 0.05.
            early_stop (bool): Whether to use early stopping. Default: False.
            early_stop_patience (int): With early stopping enabled, training stops if the
                validation accuracy drops or stays flat for `early_stop_patience` consecutive
                epochs. Default: 5.
            resume_checkpoint (str): Path of the model saved by an earlier run, used to resume
                training. If None, training is not resumed. Default: None.

        Raises:
            ValueError: the model was loaded from an inference model.
        """
        return super(FastSCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, optimizer, learning_rate, lr_decay_power,
            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
            early_stop_patience, resume_checkpoint)
```
paddlex/cv/models/hrnet.py  (+1 −6)

```diff
@@ -95,11 +95,6 @@ class HRNet(DeepLabv3p):
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs['loss'] = model_out
-        elif mode == 'eval':
-            outputs['loss'] = model_out[0]
-            outputs['pred'] = model_out[1]
-            outputs['label'] = model_out[2]
-            outputs['mask'] = model_out[3]
         else:
             outputs['pred'] = model_out[0]
             outputs['logit'] = model_out[1]
```

@@ -160,7 +155,7 @@ class HRNet(DeepLabv3p):

In the train() docstring, the sensitivities_file description gets the same ImageNet → Cityscapes correction; the surrounding lr_decay_power, use_vdl, eval_metric_loss, early_stop and early_stop_patience entries appear as unchanged context.
paddlex/cv/models/unet.py  (+1 −6)

```diff
@@ -95,11 +95,6 @@ class UNet(DeepLabv3p):
         if mode == 'train':
             self.optimizer.minimize(model_out)
             outputs['loss'] = model_out
-        elif mode == 'eval':
-            outputs['loss'] = model_out[0]
-            outputs['pred'] = model_out[1]
-            outputs['label'] = model_out[2]
-            outputs['mask'] = model_out[3]
         else:
             outputs['pred'] = model_out[0]
             outputs['logit'] = model_out[1]
```

@@ -141,7 +136,7 @@ class UNet(DeepLabv3p):

The sensitivities_file entry in the train() docstring receives the same ImageNet → Cityscapes correction, with lr_decay_power, use_vdl, eval_metric_loss, early_stop and early_stop_patience as unchanged context.
paddlex/cv/models/utils/pretrain_weights.py  (+15 −4)

```diff
@@ -117,7 +117,9 @@ cityscapes_pretrain = {
     'DeepLabv3p_Xception65_CITYSCAPES':
     'https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz',
     'HRNet_W18_CITYSCAPES':
-    'https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz'
+    'https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz',
+    'FastSCNN_CITYSCAPES':
+    'https://paddleseg.bj.bcebos.com/models/fast_scnn_cityscape.tar'
 }
@@ -139,6 +141,10 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         elif class_name == 'HRNet':
             logging.warning(warning_info.format(class_name, flag, 'IMAGENET'))
             flag = 'IMAGENET'
+        elif class_name == 'FastSCNN':
+            logging.warning(
+                warning_info.format(class_name, flag, 'CITYSCAPES'))
+            flag = 'CITYSCAPES'
     elif flag == 'CITYSCAPES':
         model_name = '{}_{}'.format(class_name, backbone)
         if class_name == 'UNet':
@@ -155,9 +161,14 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
             model_name = '{}_{}'.format(class_name, backbone)
             logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
             flag = 'IMAGENET'
-    elif flag == 'IMAGENET' and class_name == 'UNet':
-        logging.warning(warning_info.format(class_name, flag, 'COCO'))
-        flag = 'COCO'
+    elif flag == 'IMAGENET':
+        if class_name == 'UNet':
+            logging.warning(warning_info.format(class_name, flag, 'COCO'))
+            flag = 'COCO'
+        elif class_name == 'FastSCNN':
+            logging.warning(
+                warning_info.format(class_name, flag, 'CITYSCAPES'))
+            flag = 'CITYSCAPES'
     if flag == 'IMAGENET':
         new_save_dir = save_dir
```
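The practical effect of these get_pretrain_weights changes is that a FastSCNN model asking for a flag with no published FastSCNN weights is redirected, with a warning, to the 'CITYSCAPES' checkpoint registered in cityscapes_pretrain above. This sketch is not part of the commit; it assumes the optic disc dataset downloaded by the tutorial script earlier in this page, and the epoch count and save_dir are illustrative.

```python
import paddlex as pdx
from paddlex.seg import transforms

# Dataset set up as in the tutorial script added by this commit.
train_transforms = transforms.ComposedSegTransforms(
    mode='train', train_crop_size=[769, 769])
train_dataset = pdx.datasets.SegDataset(
    data_dir='optic_disc_seg',
    file_list='optic_disc_seg/train_list.txt',
    label_list='optic_disc_seg/labels.txt',
    transforms=train_transforms,
    shuffle=True)

# Requesting ImageNet weights for FastSCNN logs a warning and falls back to
# 'CITYSCAPES', the only FastSCNN checkpoint listed in cityscapes_pretrain.
model = pdx.seg.FastSCNN(num_classes=2)
model.train(
    num_epochs=2,
    train_dataset=train_dataset,
    pretrain_weights='IMAGENET',
    save_dir='output/fastscnn_fallback')
```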
paddlex/cv/nets/__init__.py  (+1 −0)

```diff
@@ -20,6 +20,7 @@ from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
 from .segmentation import UNet
 from .segmentation import DeepLabv3p
+from .segmentation import FastSCNN
 from .xception import Xception
 from .densenet import DenseNet
 from .shufflenet_v2 import ShuffleNetV2
```
paddlex/cv/nets/segmentation/__init__.py  (+1 −0)

```diff
@@ -15,5 +15,6 @@
 from .unet import UNet
 from .deeplabv3p import DeepLabv3p
 from .hrnet import HRNet
+from .fast_scnn import FastSCNN
 from .model_utils import libs
 from .model_utils import loss
```
paddlex/cv/nets/segmentation/deeplabv3p.py  (+0 −1)

```diff
@@ -28,7 +28,6 @@ from .model_utils.libs import sigmoid_to_softmax
 from .model_utils.loss import softmax_with_loss
 from .model_utils.loss import dice_loss
 from .model_utils.loss import bce_loss
-import paddlex.utils.logging as logging
 from paddlex.cv.nets.xception import Xception
 from paddlex.cv.nets.mobilenet_v2 import MobileNetV2
```
paddlex/cv/nets/segmentation/fast_scnn.py  (new file, +395 −0)

```python
# coding: utf8
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import OrderedDict

import paddle.fluid as fluid

from .model_utils.libs import scope
from .model_utils.libs import bn, bn_relu, relu, conv_bn_layer
from .model_utils.libs import conv, avg_pool
from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax
from .model_utils.loss import softmax_with_loss
from .model_utils.loss import dice_loss
from .model_utils.loss import bce_loss


class FastSCNN(object):
    def __init__(self,
                 num_classes,
                 mode='train',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 multi_loss_weight=[1.0],
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice_loss or bce_loss is only applicable to binary segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classfication")
        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes")
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))

        self.num_classes = num_classes
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        self.multi_loss_weight = multi_loss_weight
        self.fixed_input_shape = fixed_input_shape

    def build_net(self, inputs):
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        size = fluid.layers.shape(image)[2:]
        with scope('learning_to_downsample'):
            higher_res_features = self._learning_to_downsample(image, 32, 48,
                                                               64)
        with scope('global_feature_extractor'):
            lower_res_feature = self._global_feature_extractor(
                higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
        with scope('feature_fusion'):
            x = self._feature_fusion(higher_res_features, lower_res_feature,
                                     64, 128, 128)
        with scope('classifier'):
            logit = self._classifier(x, 128)
            logit = fluid.layers.resize_bilinear(logit, size, align_mode=0)

        if len(self.multi_loss_weight) == 3:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            with scope('aux_layer_lower'):
                lower_logit = self._aux_layer(lower_res_feature,
                                              self.num_classes)
                lower_logit = fluid.layers.resize_bilinear(
                    lower_logit, size, align_mode=0)
            logit = (logit, higher_logit, lower_logit)
        elif len(self.multi_loss_weight) == 2:
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            logit = (logit, higher_logit)
        else:
            logit = (logit, )

        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit[0])
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])
        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            return self._get_loss(logit, label)
        elif self.mode == 'eval':
            label = inputs['label']
            loss = self._get_loss(logit, label)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit[0])
            else:
                logit = fluid.layers.softmax(logit[0], axis=1)
            return pred, logit

    def generate_inputs(self):
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            input_shape = [
                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode == 'train':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        elif self.mode == 'eval':
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs

    def _get_loss(self, logits, label):
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            for i, logit in enumerate(logits):
                logit_mask = (
                    label.astype('int32') != self.ignore_index).astype('int32')
                loss = softmax_with_loss(
                    logit,
                    label,
                    logit_mask,
                    num_classes=self.num_classes,
                    weight=self.class_weight,
                    ignore_index=self.ignore_index)
                avg_loss += self.multi_loss_weight[i] * loss
        else:
            if self.use_dice_loss:
                for i, logit in enumerate(logits):
                    logit_mask = (label.astype('int32') !=
                                  self.ignore_index).astype('int32')
                    loss = dice_loss(logit, label, logit_mask)
                    avg_loss += self.multi_loss_weight[i] * loss
            if self.use_bce_loss:
                for i, logit in enumerate(logits):
                    #logit_label = fluid.layers.resize_nearest(label, logit_shape[2:])
                    logit_mask = (label.astype('int32') !=
                                  self.ignore_index).astype('int32')
                    loss = bce_loss(
                        logit, label, logit_mask,
                        ignore_index=self.ignore_index)
                    avg_loss += self.multi_loss_weight[i] * loss
        return avg_loss

    def _learning_to_downsample(self,
                                x,
                                dw_channels1=32,
                                dw_channels2=48,
                                out_channels=64):
        x = relu(bn(conv(x, dw_channels1, 3, 2)))
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
        return x

    def _shortcut(self, input, data_residual):
        return fluid.layers.elementwise_add(input, data_residual)

    def _dropout2d(self, input, prob, is_train=False):
        if not is_train:
            return input
        keep_prob = 1.0 - prob
        shape = fluid.layers.shape(input)
        channels = shape[1]
        random_tensor = keep_prob + fluid.layers.uniform_random(
            [shape[0], channels, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = input / keep_prob * binary_tensor
        return output

    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
                                ifshortcut,
                                stride,
                                filter_size,
                                padding,
                                expansion_factor,
                                name=None):
        num_expfilter = int(round(num_in_filter * expansion_factor))
        channel_expand = conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')
        bottleneck_conv = conv_bn_layer(
            input=channel_expand,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)
        depthwise_output = bottleneck_conv
        linear_out = conv_bn_layer(
            input=bottleneck_conv,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')
        if ifshortcut:
            out = self._shortcut(input=input, data_residual=linear_out)
            return out, depthwise_output
        else:
            return linear_out, depthwise_output

    def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
        first_block, depthwise_output = self._inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')
        last_residual_block = first_block
        last_c = c
        for i in range(1, n):
            last_residual_block, depthwise_output = self._inverted_residual_unit(
                input=last_residual_block,
                num_in_filter=last_c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(i + 1))
        return last_residual_block, depthwise_output

    def _psp_module(self, input, out_features):
        cat_layers = []
        sizes = (1, 2, 3, 6)
        for size in sizes:
            psp_name = "psp" + str(size)
            with scope(psp_name):
                pool = fluid.layers.adaptive_pool2d(
                    input,
                    pool_size=[size, size],
                    pool_type='avg',
                    name=psp_name + '_adapool')
                data = conv(
                    pool,
                    out_features,
                    filter_size=1,
                    bias_attr=False,
                    name=psp_name + '_conv')
                data_bn = bn(data, act='relu')
                interp = fluid.layers.resize_bilinear(
                    data_bn,
                    out_shape=fluid.layers.shape(input)[2:],
                    name=psp_name + '_interp',
                    align_mode=0)
            cat_layers.append(interp)
        cat_layers = [input] + cat_layers
        out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')
        return out

    def _aux_layer(self, x, num_classes):
        x = relu(bn(conv(x, 32, 3, padding=1)))
        x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
        with scope('logit'):
            x = conv(x, num_classes, 1, bias_attr=True)
        return x

    def _feature_fusion(self,
                        higher_res_feature,
                        lower_res_feature,
                        higher_in_channels,
                        lower_in_channels,
                        out_channels,
                        scale_factor=4):
        shape = fluid.layers.shape(higher_res_feature)
        w = shape[-1]
        h = shape[-2]
        lower_res_feature = fluid.layers.resize_bilinear(
            lower_res_feature, [h, w], align_mode=0)

        with scope('dwconv'):
            lower_res_feature = relu(
                bn(conv(lower_res_feature, out_channels,
                        1)))  #(lower_res_feature)
        with scope('conv_lower_res'):
            lower_res_feature = bn(
                conv(lower_res_feature, out_channels, 1, bias_attr=True))
        with scope('conv_higher_res'):
            higher_res_feature = bn(
                conv(higher_res_feature, out_channels, 1, bias_attr=True))
        out = higher_res_feature + lower_res_feature
        return relu(out)

    def _global_feature_extractor(self,
                                  x,
                                  in_channels=64,
                                  block_channels=(64, 96, 128),
                                  out_channels=128,
                                  t=6,
                                  num_blocks=(3, 3, 3)):
        x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
                                     num_blocks[0], 2, 'inverted_block_1')
        x, _ = self._inverted_blocks(x, block_channels[0], t,
                                     block_channels[1], num_blocks[1], 2,
                                     'inverted_block_2')
        x, _ = self._inverted_blocks(x, block_channels[1], t,
                                     block_channels[2], num_blocks[2], 1,
                                     'inverted_block_3')
        x = self._psp_module(x, block_channels[2] // 4)
        with scope('out'):
            x = relu(bn(conv(x, out_channels, 1)))
        return x

    def _classifier(self, x, dw_channels, stride=1):
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
        x = conv(x, self.num_classes, 1, bias_attr=True)
        return x
```
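The `_get_loss` method above reduces to a weighted sum over however many logits `build_net` returned, one weight per branch from `multi_loss_weight`. The following plain-Python sketch of that aggregation is not part of the commit; the per-branch loss values are made up, whereas the real code computes softmax_with_loss (or dice/BCE loss) on fluid tensors.

```python
# Illustrative aggregation mirroring _get_loss for three supervised branches.
multi_loss_weight = [1.0, 0.4, 0.4]   # fusion, spatial-detail, global-context
branch_losses = [0.52, 0.61, 0.75]    # hypothetical per-branch loss values

avg_loss = 0.0
for weight, loss in zip(multi_loss_weight, branch_losses):
    avg_loss += weight * loss         # 1.0*0.52 + 0.4*0.61 + 0.4*0.75 ≈ 1.064

print(avg_loss)
```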
paddlex/cv/nets/segmentation/hrnet.py  (+0 −1)

```diff
@@ -27,7 +27,6 @@ from .model_utils.loss import softmax_with_loss
 from .model_utils.loss import dice_loss
 from .model_utils.loss import bce_loss
 import paddlex
-import paddlex.utils.logging as logging


 class HRNet(object):
```
paddlex/cv/nets/segmentation/unet.py  (+5 −5)

```diff
@@ -27,7 +27,6 @@ from .model_utils.libs import sigmoid_to_softmax
 from .model_utils.loss import softmax_with_loss
 from .model_utils.loss import dice_loss
 from .model_utils.loss import bce_loss
-import paddlex.utils.logging as logging


 class UNet(object):
@@ -106,7 +105,8 @@ class UNet(object):
                 name='weights',
                 regularizer=fluid.regularizer.L2DecayRegularizer(
                     regularization_coeff=0.0),
-                initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.33))
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.33))
         with scope("conv0"):
             data = bn_relu(
                 conv(
@@ -140,8 +140,7 @@ class UNet(object):
                 name='weights',
                 regularizer=fluid.regularizer.L2DecayRegularizer(
                     regularization_coeff=0.0),
-                initializer=fluid.initializer.XavierInitializer(),
-            )
+                initializer=fluid.initializer.XavierInitializer(), )
         with scope("up"):
             if self.upsample_mode == 'bilinear':
                 short_cut_shape = fluid.layers.shape(short_cut)
@@ -197,7 +196,8 @@ class UNet(object):
                 name='weights',
                 regularizer=fluid.regularizer.L2DecayRegularizer(
                     regularization_coeff=0.0),
-                initializer=fluid.initializer.TruncatedNormal(loc=0.0, scale=0.01))
+                initializer=fluid.initializer.TruncatedNormal(
+                    loc=0.0, scale=0.01))
         with scope("logit"):
             data = conv(
                 data,
```
paddlex/seg.py  (+1 −0)

```diff
@@ -18,5 +18,6 @@ from . import cv
 UNet = cv.models.UNet
 DeepLabv3p = cv.models.DeepLabv3p
 HRNet = cv.models.HRNet
+FastSCNN = cv.models.FastSCNN
 transforms = cv.transforms.seg_transforms
 visualize = cv.models.utils.visualize.visualize_segmentation
```