PaddlePaddle / PaddleX
Commit be77e22f (unverified)
Authored June 23, 2020 by Jason; committed by GitHub on June 23, 2020
Merge pull request #150 from FlyingQianMM/develop_draw
add hrnet_w18_small_v1 for segmentation
Parents: 7bc6fe62, 41183ea7

Showing 16 changed files with 1393 additions and 42 deletions
docs/apis/models/semantic_segmentation.md  +2 −2
examples/human_segmentation/README.md  +181 −0
examples/human_segmentation/bg_replace.py  +314 −0
examples/human_segmentation/data/download_data.py  +33 −0
examples/human_segmentation/eval.py  +85 −0
examples/human_segmentation/infer.py  +109 −0
examples/human_segmentation/postprocess.py  +125 −0
examples/human_segmentation/pretrain_weights/download_pretrain_weights.py  +40 −0
examples/human_segmentation/quant_offline.py  +85 −0
examples/human_segmentation/train.py  +156 −0
examples/human_segmentation/video_infer.py  +187 −0
paddlex/cv/datasets/seg_dataset.py  +10 −9
paddlex/cv/models/hrnet.py  +8 −7
paddlex/cv/models/utils/pretrain_weights.py  +1 −1
paddlex/cv/nets/hrnet.py  +55 −22
paddlex/cv/nets/segmentation/hrnet.py  +2 −1
docs/apis/models/semantic_segmentation.md
...
@@ -186,10 +186,10 @@ paddlex.seg.HRNet(num_classes=2, width=18, use_bce_loss=False, use_dice_loss=Fal
> **Parameters**
> > - **num_classes** (int): Number of classes.
-> > - **width** (int): Number of channels of the feature maps in the high-resolution branch. Defaults to 18. Valid values are [18, 30, 32, 40, 44, 48, 60, 64].
+> > - **width** (int|str): Number of channels of the feature maps in the high-resolution branch. Defaults to 18. Valid values are [18, 30, 32, 40, 44, 48, 60, 64, '18_small_v1']; '18_small_v1' is a lightweight version of 18.
> > - **use_bce_loss** (bool): Whether to use BCE loss as the network's loss function; only usable for two-class segmentation. Can be combined with dice loss. Defaults to False.
> > - **use_dice_loss** (bool): Whether to use dice loss as the network's loss function; only usable for two-class segmentation and combinable with BCE loss. When both use_bce_loss and use_dice_loss are False, cross-entropy loss is used. Defaults to False.
-> > - **class_weight** (list/str): Per-class weights for the cross-entropy loss. When `class_weight` is a list, its length should be `num_classes`. When `class_weight` is a str, weight.lower() should be 'dynamic', in which case the weights are computed each round from the per-class pixel ratios; each class's weight is: class ratio * num_classes. With the default value None, every class has weight 1, i.e. the usual cross-entropy loss.
+> > - **class_weight** (list|str): Per-class weights for the cross-entropy loss. When `class_weight` is a list, its length should be `num_classes`. When `class_weight` is a str, weight.lower() should be 'dynamic', in which case the weights are computed each round from the per-class pixel ratios; each class's weight is: class ratio * num_classes. With the default value None, every class has weight 1, i.e. the usual cross-entropy loss.
> > - **ignore_index** (int): Ignored label value; pixels whose label equals `ignore_index` do not contribute to the loss. Defaults to 255.
### train
...
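The new `'18_small_v1'` option can be passed straight to the constructor. A minimal sketch, mirroring the usage in `examples/human_segmentation/train.py` from this commit:

```python
import paddlex as pdx

# '18_small_v1' selects the lightweight HRNet-W18 variant added in this commit.
model = pdx.seg.HRNet(num_classes=2, width='18_small_v1')
```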
examples/human_segmentation/README.md
0 → 100644
# HumanSeg Human Segmentation Models

Built on the core segmentation networks of PaddleX, this tutorial is an end-to-end guide to the human segmentation scenario: pretrained models, fine-tuning, and video segmentation deployment.

## Installation

**Prerequisites**

* paddlepaddle >= 1.8.0
* python >= 3.5

```
pip install paddlex -i https://mirror.baidu.com/pypi/simple
```

For installation issues, see [PaddleX Installation](https://paddlex.readthedocs.io/zh_CN/latest/install.html).

## Pretrained Models

HumanSeg provides two pretrained models trained on large-scale human image data, covering a range of deployment scenarios.

| Model type | Checkpoint Parameter | Inference Model | Quant Inference Model | Notes |
| --- | --- | --- | --- | --- |
| HumanSeg-server | [humanseg_server_params](https://paddlex.bj.bcebos.com/humanseg/models/humanseg_server.pdparams) | [humanseg_server_inference](https://paddlex.bj.bcebos.com/humanseg/models/humanseg_server_inference.zip) | -- | High-accuracy model for server-side GPUs and human scenes with complex backgrounds; architecture DeepLabv3+/Xception65, input size (512, 512) |
| HumanSeg-mobile | [humanseg_mobile_params](https://paddlex.bj.bcebos.com/humanseg/models/humanseg_mobile.pdparams) | [humanseg_mobile_inference](https://paddlex.bj.bcebos.com/humanseg/models/humanseg_mobile_inference.zip) | [humanseg_mobile_quant](https://paddlex.bj.bcebos.com/humanseg/models/humanseg_mobile_quant.zip) | Lightweight model for front-camera scenes on mobile devices or server-side CPUs; architecture HRNet_w18_small_v1, input size (192, 192) |

Model performance:

| Model | Model size | Inference time |
| --- | --- | --- |
| humanseg_server_inference | 158M | - |
| humanseg_mobile_inference | 5.8M | 42.35ms |
| humanseg_mobile_quant | 1.6M | 24.93ms |

Timing environment: Xiaomi phone, CPU: Snapdragon 855, RAM: 6GB, image size: 192 * 192.

**NOTE:**

* Checkpoint Parameter holds the model weights, used for fine-tuning.
* Inference Model and Quant Inference Model are deployment models, containing the `__model__` computation graph, the `__params__` model parameters, and the `model.yaml` basic model configuration.
* Inference Model targets server-side CPU and GPU deployment; Quant Inference Model is the quantized version, for edge-side (e.g. mobile) deployment via Paddle Lite.

Run the following script to download the HumanSeg pretrained models:

```bash
python pretrain_weights/download_pretrain_weights.py
```
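Once downloaded, an inference model can be loaded and run directly from Python. A minimal sketch; the `load_model`/`predict` calls mirror the example scripts in this PR, and the image path is a placeholder:

```python
import paddlex as pdx

# Load a downloaded deployment model (contains __model__, __params__, model.yaml).
model = pdx.load_model('pretrain_weights/humanseg_mobile_inference')

# result holds a 'label_map' (per-pixel class) and a 'score_map' (per-pixel probabilities).
result = model.predict('data/human_image.jpg')
```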
## Download Test Data

We take the human segmentation dataset **Supervisely Persons** published by [supervise.ly](https://supervise.ly/), randomly sample a small subset of it, and convert it into a format that PaddleX can load directly. Run the following to download it quickly; the download also includes a human test video shot with a phone's front camera, `video_test.mp4`.

```bash
python data/download_data.py
```

## Quick Start: Video Human Segmentation

The predictions of the DIS (Dense Inverse Search-based method) optical flow algorithm are fused with the segmentation results to improve human segmentation on video streams.

```bash
# Real-time segmentation from the computer's camera
python video_infer.py --model_dir pretrain_weights/humanseg_mobile_inference

# Segment a human video
python video_infer.py --model_dir pretrain_weights/humanseg_mobile_inference --video_path data/video_test.mp4
```

Video segmentation results:

<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/result.gif" width="20%" height="20%">

Replace the background with one of your choosing; the background can be an image or a video.

```bash
# Real-time background replacement from the computer's camera; a background video can also be passed via '--background_video_path'
python bg_replace.py --model_dir pretrain_weights/humanseg_mobile_inference --background_image_path data/background.jpg

# Background replacement on a human video; a background video can also be passed via '--background_video_path'
python bg_replace.py --model_dir pretrain_weights/humanseg_mobile_inference --video_path data/video_test.mp4 --background_image_path data/background.jpg

# Background replacement on a single image
python bg_replace.py --model_dir pretrain_weights/humanseg_mobile_inference --image_path data/human_image.jpg --background_image_path data/background.jpg
```
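Under the hood, the replacement is plain alpha blending with the predicted mask. A minimal sketch of the core step, mirroring the `bg_replace` helper in `bg_replace.py`:

```python
import cv2
import numpy as np

def bg_replace(label_map, img, bg):
    # Resize the background to the source frame, broadcast the 0/1 mask to
    # 3 channels, then blend: foreground where mask==1, background elsewhere.
    h, w, _ = img.shape
    bg = cv2.resize(bg, (w, h))
    label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
    return (label_map * img + (1 - label_map) * bg).astype(np.uint8)
```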
Background replacement results:

<img src="https://paddleseg.bj.bcebos.com/humanseg/data/video_test.gif" width="20%" height="20%"><img src="https://paddleseg.bj.bcebos.com/humanseg/data/bg_replace.gif" width="20%" height="20%">

**NOTE**:

Video processing takes several minutes; please be patient.

The provided models target portrait footage from a phone's front camera; results on landscape footage will be somewhat worse.

## Training

Use the command below to fine-tune from a pretrained model. Make sure the chosen architecture `model_type` matches the model weights `pretrain_weights`.

```bash
# Select the GPU card (card 0 as an example)
export CUDA_VISIBLE_DEVICES=0
# To run without a GPU, set CUDA_VISIBLE_DEVICES to empty
# export CUDA_VISIBLE_DEVICES=

python train.py --model_type HumanSegMobile \
--save_dir output/ \
--data_dir data/mini_supervisely \
--train_list data/mini_supervisely/train.txt \
--val_list data/mini_supervisely/val.txt \
--pretrain_weights pretrain_weights/humanseg_mobile_params \
--batch_size 8 \
--learning_rate 0.001 \
--num_epochs 10 \
--image_shape 192 192
```

The parameters are:

* `--model_type`: model type; one of HumanSegServer and HumanSegMobile
* `--save_dir`: path for saving the model
* `--data_dir`: dataset path
* `--train_list`: path of the training set list
* `--val_list`: path of the validation set list
* `--pretrain_weights`: path of the pretrained model
* `--batch_size`: batch size
* `--learning_rate`: initial learning rate
* `--num_epochs`: number of training epochs
* `--image_shape`: network input image size (w, h)

For more command-line help, run:

```bash
python train.py --help
```

**NOTE**: you can quickly try out different models by swapping `--model_type` together with a matching `--pretrain_weights`.

## Evaluation

Use the command below to evaluate:

```bash
python eval.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--val_list data/mini_supervisely/val.txt \
--image_shape 192 192
```

The parameters are:

* `--model_dir`: model path
* `--data_dir`: dataset path
* `--val_list`: path of the validation set list
* `--image_shape`: network input image size (w, h)

## Prediction

Use the command below to predict; results are saved under `./output/result/` by default.

```bash
python infer.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--test_list data/mini_supervisely/test.txt \
--save_dir output/result \
--image_shape 192 192
```

The parameters are:

* `--model_dir`: model path
* `--data_dir`: dataset path
* `--test_list`: path of the test set list
* `--image_shape`: network input image size (w, h)

## Model Export

```bash
paddlex --export_inference --model_dir output/best_model \
--save_dir output/export
```

The parameters are:

* `--model_dir`: model path
* `--save_dir`: path for saving the exported model

## Offline Quantization

```bash
python quant_offline.py --model_dir output/best_model \
--data_dir data/mini_supervisely \
--quant_list data/mini_supervisely/val.txt \
--save_dir output/quant_offline \
--image_shape 192 192
```

The parameters are:

* `--model_dir`: path of the model to be quantized
* `--data_dir`: dataset path
* `--quant_list`: path of the quantization dataset list, usually the training or validation set
* `--save_dir`: path for saving the quantized model
* `--image_shape`: network input image size (w, h)
examples/human_segmentation/bg_replace.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp
import cv2
import numpy as np

from postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms


def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg inference for video')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for inference',
        type=str)
    parser.add_argument(
        '--image_path',
        dest='image_path',
        help='Image including human',
        type=str,
        default=None)
    parser.add_argument(
        '--background_image_path',
        dest='background_image_path',
        help='Background image for replacing',
        type=str,
        default=None)
    parser.add_argument(
        '--video_path',
        dest='video_path',
        help='Video path for inference',
        type=str,
        default=None)
    parser.add_argument(
        '--background_video_path',
        dest='background_video_path',
        help='Background video path for replacing',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output')
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def bg_replace(label_map, img, bg):
    # Blend foreground (mask == 1) over the resized background.
    h, w, _ = img.shape
    bg = cv2.resize(bg, (w, h))
    label_map = np.repeat(label_map[:, :, np.newaxis], 3, axis=2)
    comb = (label_map * img + (1 - label_map) * bg).astype(np.uint8)
    return comb


def recover(img, im_info):
    # Restore the prediction to the original frame size.
    if im_info[0] == 'resize':
        w, h = im_info[1][1], im_info[1][0]
        img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
    elif im_info[0] == 'padding':
        # Fixed: the original read im_info[1][0] twice for both w and h.
        w, h = im_info[1][1], im_info[1][0]
        img = img[0:h, 0:w, :]
    return img


def infer(args):
    resize_h = args.image_shape[1]
    resize_w = args.image_shape[0]

    test_transforms = transforms.Compose([transforms.Normalize()])
    model = pdx.load_model(args.model_dir)

    if not osp.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Image background replacement
    if args.image_path is not None:
        if not osp.exists(args.image_path):
            raise Exception('The --image_path does not exist: {}'.format(
                args.image_path))
        if args.background_image_path is None:
            raise Exception(
                'The --background_image_path is not set. Please set it')
        else:
            if not osp.exists(args.background_image_path):
                raise Exception(
                    'The --background_image_path does not exist: {}'.format(
                        args.background_image_path))
        img = cv2.imread(args.image_path)
        im_shape = img.shape
        im_scale_x = float(resize_w) / float(im_shape[1])
        im_scale_y = float(resize_h) / float(im_shape[0])
        im = cv2.resize(
            img,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=cv2.INTER_LINEAR)
        image = im.astype('float32')
        im_info = ('resize', im_shape[0:2])
        pred = model.predict(image, test_transforms)
        label_map = pred['label_map']
        label_map = recover(label_map, im_info)
        bg = cv2.imread(args.background_image_path)
        save_name = osp.basename(args.image_path)
        save_path = osp.join(args.save_dir, save_name)
        result = bg_replace(label_map, img, bg)
        cv2.imwrite(save_path, result)

    # Video background replacement: use the background video if one is given,
    # otherwise fall back to the background image
    else:
        is_video_bg = False
        if args.background_video_path is not None:
            if not osp.exists(args.background_video_path):
                raise Exception(
                    'The --background_video_path does not exist: {}'.format(
                        args.background_video_path))
            is_video_bg = True
        elif args.background_image_path is not None:
            if not osp.exists(args.background_image_path):
                raise Exception(
                    'The --background_image_path does not exist: {}'.format(
                        args.background_image_path))
        else:
            raise Exception(
                'Please offer a background image or video. You should set '
                '--background_image_path or --background_video_path')

        disflow = cv2.DISOpticalFlow_create(
            cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
        prev_gray = np.zeros((resize_h, resize_w), np.uint8)
        prev_cfd = np.zeros((resize_h, resize_w), np.float32)
        is_init = True

        if args.video_path is not None:
            logging.info('Please wait. It is computing......')
            if not osp.exists(args.video_path):
                raise Exception('The --video_path does not exist: {}'.format(
                    args.video_path))

            cap_video = cv2.VideoCapture(args.video_path)
            fps = cap_video.get(cv2.CAP_PROP_FPS)
            width = int(cap_video.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
            save_name = osp.basename(args.video_path)
            save_name = save_name.split('.')[0]
            save_path = osp.join(args.save_dir, save_name + '.avi')
            cap_out = cv2.VideoWriter(
                save_path,
                cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps,
                (width, height))

            if is_video_bg:
                cap_bg = cv2.VideoCapture(args.background_video_path)
                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
                current_frame_bg = 1
            else:
                img_bg = cv2.imread(args.background_image_path)
            while cap_video.isOpened():
                ret, frame = cap_video.read()
                if ret:
                    im_shape = frame.shape
                    im_scale_x = float(resize_w) / float(im_shape[1])
                    im_scale_y = float(resize_h) / float(im_shape[0])
                    im = cv2.resize(
                        frame,
                        None,
                        None,
                        fx=im_scale_x,
                        fy=im_scale_y,
                        interpolation=cv2.INTER_LINEAR)
                    image = im.astype('float32')
                    im_info = ('resize', im_shape[0:2])
                    pred = model.predict(image, test_transforms)
                    score_map = pred['score_map']
                    cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                    score_map = 255 * score_map[:, :, 1]
                    optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                              prev_cfd, disflow, is_init)
                    prev_gray = cur_gray.copy()
                    prev_cfd = optflow_map.copy()
                    is_init = False
                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                    optflow_map = threshold_mask(
                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                    score_map = recover(optflow_map, im_info)

                    # Read background frames in a loop, rewinding at the end
                    if is_video_bg:
                        ret_bg, frame_bg = cap_bg.read()
                        if ret_bg:
                            if current_frame_bg == frames_bg:
                                current_frame_bg = 1
                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        else:
                            break
                        current_frame_bg += 1
                        comb = bg_replace(score_map, frame, frame_bg)
                    else:
                        comb = bg_replace(score_map, frame, img_bg)

                    cap_out.write(comb)
                else:
                    break

            if is_video_bg:
                cap_bg.release()
            cap_video.release()
            cap_out.release()

        # If neither an image nor a video is given, open the camera
        else:
            cap_video = cv2.VideoCapture(0)
            if not cap_video.isOpened():
                raise IOError(
                    "Error opening video stream or file. Check whether "
                    "--video_path exists: {} and whether the camera is "
                    "working".format(args.video_path))
            if is_video_bg:
                cap_bg = cv2.VideoCapture(args.background_video_path)
                frames_bg = cap_bg.get(cv2.CAP_PROP_FRAME_COUNT)
                current_frame_bg = 1
            else:
                img_bg = cv2.imread(args.background_image_path)
            while cap_video.isOpened():
                ret, frame = cap_video.read()
                if ret:
                    im_shape = frame.shape
                    im_scale_x = float(resize_w) / float(im_shape[1])
                    im_scale_y = float(resize_h) / float(im_shape[0])
                    im = cv2.resize(
                        frame,
                        None,
                        None,
                        fx=im_scale_x,
                        fy=im_scale_y,
                        interpolation=cv2.INTER_LINEAR)
                    image = im.astype('float32')
                    im_info = ('resize', im_shape[0:2])
                    pred = model.predict(image, test_transforms)
                    score_map = pred['score_map']
                    cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                    cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                    score_map = 255 * score_map[:, :, 1]
                    optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                              prev_cfd, disflow, is_init)
                    prev_gray = cur_gray.copy()
                    prev_cfd = optflow_map.copy()
                    is_init = False
                    optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                    optflow_map = threshold_mask(
                        optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                    score_map = recover(optflow_map, im_info)

                    # Read background frames in a loop, rewinding at the end
                    if is_video_bg:
                        ret_bg, frame_bg = cap_bg.read()
                        if ret_bg:
                            if current_frame_bg == frames_bg:
                                current_frame_bg = 1
                                cap_bg.set(cv2.CAP_PROP_POS_FRAMES, 0)
                        else:
                            break
                        current_frame_bg += 1
                        comb = bg_replace(score_map, frame, frame_bg)
                    else:
                        comb = bg_replace(score_map, frame, img_bg)
                    cv2.imshow('HumanSegmentation', comb)
                    if cv2.waitKey(1) & 0xFF == ord('q'):
                        break
                else:
                    break
            if is_video_bg:
                cap_bg.release()
            cap_video.release()


if __name__ == "__main__":
    args = parse_args()
    infer(args)
examples/human_segmentation/data/download_data.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os

LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))

import paddlex as pdx


def download_data(savepath):
    url = "https://paddleseg.bj.bcebos.com/humanseg/data/mini_supervisely.zip"
    pdx.utils.download_and_decompress(url=url, path=savepath)

    url = "https://paddleseg.bj.bcebos.com/humanseg/data/video_test.zip"
    pdx.utils.download_and_decompress(url=url, path=savepath)


if __name__ == "__main__":
    download_data(LOCAL_PATH)
    print("Data download finish!")
examples/human_segmentation/eval.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms


def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg evaluation')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for evaluating',
        type=str,
        default='output/best_model')
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        type=str)
    parser.add_argument(
        '--val_list',
        dest='val_list',
        help='Val list file of dataset',
        type=str,
        default=None)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=128)
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def dict2str(dict_input):
    out = ''
    for k, v in dict_input.items():
        try:
            v = round(float(v), 6)
        except:
            pass
        out = out + '{}={}, '.format(k, v)
    return out.strip(', ')


def evaluate(args):
    eval_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape), transforms.Normalize()])

    eval_dataset = pdx.datasets.SegDataset(
        data_dir=args.data_dir,
        file_list=args.val_list,
        transforms=eval_transforms)

    model = pdx.load_model(args.model_dir)
    metrics = model.evaluate(eval_dataset, args.batch_size)
    logging.info('[EVAL] Finished, {} .'.format(dict2str(metrics)))


if __name__ == '__main__':
    args = parse_args()
    evaluate(args)
examples/human_segmentation/infer.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp
import cv2
import numpy as np
import tqdm
import paddlex as pdx
from paddlex.seg import transforms


def parse_args():
    parser = argparse.ArgumentParser(
        description='HumanSeg prediction and visualization')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for prediction',
        type=str)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        type=str)
    parser.add_argument(
        '--test_list',
        dest='test_list',
        help='Test list file of dataset',
        type=str)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output/result')
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def infer(args):
    def makedir(path):
        sub_dir = osp.dirname(path)
        if not osp.exists(sub_dir):
            os.makedirs(sub_dir)

    test_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape), transforms.Normalize()])
    model = pdx.load_model(args.model_dir)
    added_saved_path = osp.join(args.save_dir, 'added')
    mat_saved_path = osp.join(args.save_dir, 'mat')
    scoremap_saved_path = osp.join(args.save_dir, 'scoremap')

    with open(args.test_list, 'r') as f:
        files = f.readlines()

    for file in tqdm.tqdm(files):
        file = file.strip()
        im_file = osp.join(args.data_dir, file)
        im = cv2.imread(im_file)
        result = model.predict(im_file, transforms=test_transforms)

        # save added image
        added_image = pdx.seg.visualize(
            im_file, result, weight=0.6, save_dir=None)
        added_image_file = osp.join(added_saved_path, file)
        makedir(added_image_file)
        cv2.imwrite(added_image_file, added_image)

        # save score map
        score_map = result['score_map'][:, :, 1]
        score_map = (score_map * 255).astype(np.uint8)
        score_map_file = osp.join(scoremap_saved_path, file)
        makedir(score_map_file)
        cv2.imwrite(score_map_file, score_map)

        # save mat image
        score_map = np.expand_dims(score_map, axis=-1)
        mat_image = np.concatenate([im, score_map], axis=2)
        mat_file = osp.join(mat_saved_path, file)
        ext = osp.splitext(mat_file)[-1]
        mat_file = mat_file.replace(ext, '.png')
        makedir(mat_file)
        cv2.imwrite(mat_file, mat_image)


if __name__ == '__main__':
    args = parse_args()
    infer(args)
examples/human_segmentation/postprocess.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


def cal_optical_flow_tracking(pre_gray, cur_gray, prev_cfd, dl_weights,
                              disflow):
    """Compute optical-flow tracked matches and the tracked confidence map.

    Args:
        pre_gray: grayscale image of the previous frame
        cur_gray: grayscale image of the current frame
        prev_cfd: confidence (fusion) map of the previous frame
        dl_weights: fusion weight map
        disflow: optical flow handle

    Returns:
        track_cfd: optical-flow tracked confidence map
        is_track: binary map of tracked points (whether a flow match exists)
        dl_weights: updated fusion weight map
    """
    check_thres = 8
    h, w = pre_gray.shape[:2]
    track_cfd = np.zeros_like(prev_cfd)
    is_track = np.zeros_like(pre_gray)
    flow_fw = disflow.calc(pre_gray, cur_gray, None)
    flow_bw = disflow.calc(cur_gray, pre_gray, None)
    flow_fw = np.round(flow_fw).astype(np.int)
    flow_bw = np.round(flow_bw).astype(np.int)
    y_list = np.array(range(h))
    x_list = np.array(range(w))
    yv, xv = np.meshgrid(y_list, x_list)
    yv, xv = yv.T, xv.T
    cur_x = xv + flow_fw[:, :, 0]
    cur_y = yv + flow_fw[:, :, 1]

    # Do not track points that move outside the frame
    not_track = (cur_x < 0) + (cur_x >= w) + (cur_y < 0) + (cur_y >= h)
    flow_bw[~not_track] = flow_bw[cur_y[~not_track], cur_x[~not_track]]
    not_track += (np.square(flow_fw[:, :, 0] + flow_bw[:, :, 0]) + np.square(
        flow_fw[:, :, 1] + flow_bw[:, :, 1])) >= check_thres
    track_cfd[cur_y[~not_track], cur_x[~not_track]] = prev_cfd[~not_track]
    is_track[cur_y[~not_track], cur_x[~not_track]] = 1

    not_flow = np.all(
        np.abs(flow_fw) == 0, axis=-1) * np.all(
            np.abs(flow_bw) == 0, axis=-1)
    dl_weights[cur_y[not_flow], cur_x[not_flow]] = 0.05
    return track_cfd, is_track, dl_weights


def fuse_optical_flow_tracking(track_cfd, dl_cfd, dl_weights, is_track):
    """Fuse the optical-flow tracked map with the segmentation result.

    Args:
        track_cfd: optical-flow tracked confidence map
        dl_cfd: segmentation result of the current frame
        dl_weights: fusion weight map
        is_track: binary map of optical-flow matched points

    Returns:
        fusion_cfd: fusion of the tracked map and the segmentation result
    """
    fusion_cfd = dl_cfd.copy()
    is_track = is_track.astype(np.bool)
    fusion_cfd[is_track] = dl_weights[is_track] * dl_cfd[is_track] + (
        1 - dl_weights[is_track]) * track_cfd[is_track]
    # High-certainty regions
    index_certain = ((dl_cfd > 0.9) + (dl_cfd < 0.1)) * is_track
    index_less01 = (dl_weights < 0.1) * index_certain
    fusion_cfd[index_less01] = 0.3 * dl_cfd[index_less01] + 0.7 * track_cfd[
        index_less01]
    index_larger09 = (dl_weights >= 0.1) * index_certain
    fusion_cfd[index_larger09] = 0.4 * dl_cfd[index_larger09] + 0.6 * track_cfd[
        index_larger09]
    return fusion_cfd


def threshold_mask(img, thresh_bg, thresh_fg):
    # Rescale [thresh_bg, thresh_fg] to [0, 1] and clip outside that range.
    dst = (img / 255.0 - thresh_bg) / (thresh_fg - thresh_bg)
    dst[np.where(dst > 1)] = 1
    dst[np.where(dst < 0)] = 0
    return dst.astype(np.float32)


def postprocess(cur_gray, scoremap, prev_gray, pre_cfd, disflow, is_init):
    """Optical-flow based refinement.

    Args:
        cur_gray: grayscale image of the current frame
        scoremap: segmentation result of the current frame
        prev_gray: grayscale image of the previous frame
        pre_cfd: fusion result of the previous frame
        disflow: optical flow handle
        is_init: whether this is the first frame

    Returns:
        fusion_cfd: fusion of the tracked map and the prediction
    """
    h, w = scoremap.shape
    cur_cfd = scoremap.copy()

    if is_init:
        if h <= 64 or w <= 64:
            disflow.setFinestScale(1)
        elif h <= 160 or w <= 160:
            disflow.setFinestScale(2)
        else:
            disflow.setFinestScale(3)
        fusion_cfd = cur_cfd
    else:
        weights = np.ones((h, w), np.float32) * 0.3
        track_cfd, is_track, weights = cal_optical_flow_tracking(
            prev_gray, cur_gray, pre_cfd, weights, disflow)
        fusion_cfd = fuse_optical_flow_tracking(track_cfd, cur_cfd, weights,
                                                is_track)
    return fusion_cfd
examples/human_segmentation/pretrain_weights/download_pretrain_weights.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os

LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))

import paddlex as pdx
import paddlehub as hub

model_urls = {
    "PaddleX_HumanSeg_Server_Params":
    "https://bj.bcebos.com/paddlex/models/humanseg/humanseg_server_params.tar",
    "PaddleX_HumanSeg_Server_Inference":
    "https://bj.bcebos.com/paddlex/models/humanseg/humanseg_server_inference.tar",
    "PaddleX_HumanSeg_Mobile_Params":
    "https://bj.bcebos.com/paddlex/models/humanseg/humanseg_mobile_params.tar",
    "PaddleX_HumanSeg_Mobile_Inference":
    "https://bj.bcebos.com/paddlex/models/humanseg/humanseg_mobile_inference.tar",
    "PaddleX_HumanSeg_Mobile_Quant":
    "https://bj.bcebos.com/paddlex/models/humanseg/humanseg_mobile_quant.tar"
}

if __name__ == "__main__":
    for model_name, url in model_urls.items():
        pdx.utils.download_and_decompress(url=url, path=LOCAL_PATH)
    print("Pretrained Model download success!")
examples/human_segmentation/quant_offline.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import paddlex as pdx
from paddlex.seg import transforms


def parse_args():
    parser = argparse.ArgumentParser(
        description='HumanSeg offline quantization')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for quant',
        type=str,
        default='output/best_model')
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=1)
    parser.add_argument(
        '--batch_nums',
        dest='batch_nums',
        help='Batch number for quant',
        type=int,
        default=10)
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        type=str)
    parser.add_argument(
        '--quant_list',
        dest='quant_list',
        help='Image file list for model quantization, it can be val.txt or train.txt',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the quant model',
        type=str,
        default='./output/quant_offline')
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def evaluate(args):
    eval_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape), transforms.Normalize()])

    eval_dataset = pdx.datasets.SegDataset(
        data_dir=args.data_dir,
        file_list=args.quant_list,
        transforms=eval_transforms)

    model = pdx.load_model(args.model_dir)
    pdx.slim.export_quant_model(model, eval_dataset, args.batch_size,
                                args.batch_nums, args.save_dir)


if __name__ == '__main__':
    args = parse_args()
    evaluate(args)
examples/human_segmentation/train.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import paddlex as pdx
from paddlex.seg import transforms

MODEL_TYPE = ['HumanSegMobile', 'HumanSegServer']


def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg training')
    parser.add_argument(
        '--model_type',
        dest='model_type',
        help="Model type for training, which is one of ('HumanSegMobile', 'HumanSegServer')",
        type=str,
        default='HumanSegMobile')
    parser.add_argument(
        '--data_dir',
        dest='data_dir',
        help='The root directory of dataset',
        type=str)
    parser.add_argument(
        '--train_list',
        dest='train_list',
        help='Train list file of dataset',
        type=str)
    parser.add_argument(
        '--val_list',
        dest='val_list',
        help='Val list file of dataset',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the model snapshot',
        type=str,
        default='./output')
    parser.add_argument(
        '--num_classes',
        dest='num_classes',
        help='Number of classes',
        type=int,
        default=2)
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    parser.add_argument(
        '--num_epochs',
        dest='num_epochs',
        help='Number of epochs for training',
        type=int,
        default=100)
    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        help='Mini batch size',
        type=int,
        default=128)
    parser.add_argument(
        '--learning_rate',
        dest='learning_rate',
        help='Learning rate',
        type=float,
        default=0.01)
    parser.add_argument(
        '--pretrain_weights',
        dest='pretrain_weights',
        help='The path of pretrained weight',
        type=str,
        default=None)
    parser.add_argument(
        '--resume_checkpoint',
        dest='resume_checkpoint',
        help='The path of resume checkpoint',
        type=str,
        default=None)
    parser.add_argument(
        '--use_vdl',
        dest='use_vdl',
        help='Whether to use visualdl',
        action='store_true')
    parser.add_argument(
        '--save_interval_epochs',
        dest='save_interval_epochs',
        help='The interval epochs for save a model snapshot',
        type=int,
        default=5)
    return parser.parse_args()


def train(args):
    train_transforms = transforms.Compose([
        transforms.Resize(args.image_shape), transforms.RandomHorizontalFlip(),
        transforms.Normalize()
    ])

    eval_transforms = transforms.Compose(
        [transforms.Resize(args.image_shape), transforms.Normalize()])

    train_dataset = pdx.datasets.SegDataset(
        data_dir=args.data_dir,
        file_list=args.train_list,
        transforms=train_transforms,
        shuffle=True)
    eval_dataset = pdx.datasets.SegDataset(
        data_dir=args.data_dir,
        file_list=args.val_list,
        transforms=eval_transforms)

    if args.model_type == 'HumanSegMobile':
        model = pdx.seg.HRNet(
            num_classes=args.num_classes, width='18_small_v1')
    elif args.model_type == 'HumanSegServer':
        model = pdx.seg.DeepLabv3p(
            num_classes=args.num_classes, backbone='Xception65')
    else:
        raise ValueError(
            "--model_type: {} is set wrong, it should be one of "
            "('HumanSegMobile', 'HumanSegServer')".format(args.model_type))

    model.train(
        num_epochs=args.num_epochs,
        train_dataset=train_dataset,
        train_batch_size=args.batch_size,
        eval_dataset=eval_dataset,
        save_interval_epochs=args.save_interval_epochs,
        learning_rate=args.learning_rate,
        pretrain_weights=args.pretrain_weights,
        resume_checkpoint=args.resume_checkpoint,
        save_dir=args.save_dir,
        use_vdl=args.use_vdl)


if __name__ == '__main__':
    args = parse_args()
    train(args)
examples/human_segmentation/video_infer.py
0 → 100644
# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import os.path as osp
import cv2
import numpy as np

from postprocess import postprocess, threshold_mask
import paddlex as pdx
import paddlex.utils.logging as logging
from paddlex.seg import transforms


def parse_args():
    parser = argparse.ArgumentParser(description='HumanSeg inference for video')
    parser.add_argument(
        '--model_dir',
        dest='model_dir',
        help='Model path for inference',
        type=str)
    parser.add_argument(
        '--video_path',
        dest='video_path',
        help='Video path for inference, camera will be used if the path not existing',
        type=str,
        default=None)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output')
    parser.add_argument(
        "--image_shape",
        dest="image_shape",
        help="The image shape for net inputs.",
        nargs=2,
        default=[192, 192],
        type=int)
    return parser.parse_args()


def recover(img, im_info):
    # Restore the prediction to the original frame size.
    if im_info[0] == 'resize':
        w, h = im_info[1][1], im_info[1][0]
        img = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
    elif im_info[0] == 'padding':
        # Fixed: the original read im_info[1][0] twice for both w and h.
        w, h = im_info[1][1], im_info[1][0]
        img = img[0:h, 0:w, :]
    return img


def video_infer(args):
    resize_h = args.image_shape[1]
    resize_w = args.image_shape[0]

    model = pdx.load_model(args.model_dir)
    test_transforms = transforms.Compose([transforms.Normalize()])

    if not args.video_path:
        cap = cv2.VideoCapture(0)
    else:
        cap = cv2.VideoCapture(args.video_path)
    if not cap.isOpened():
        raise IOError("Error opening video stream or file. Check whether "
                      "--video_path exists: {} and whether the camera is "
                      "working".format(args.video_path))

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    disflow = cv2.DISOpticalFlow_create(cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST)
    prev_gray = np.zeros((resize_h, resize_w), np.uint8)
    prev_cfd = np.zeros((resize_h, resize_w), np.float32)
    is_init = True

    fps = cap.get(cv2.CAP_PROP_FPS)
    if args.video_path:
        logging.info("Please wait. It is computing......")
        # Writer used to save the prediction result video
        if not osp.exists(args.save_dir):
            os.makedirs(args.save_dir)
        out = cv2.VideoWriter(
            osp.join(args.save_dir, 'result.avi'),
            cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, (width, height))
        # Start reading video frames
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                im_shape = frame.shape
                im_scale_x = float(resize_w) / float(im_shape[1])
                im_scale_y = float(resize_h) / float(im_shape[0])
                im = cv2.resize(
                    frame,
                    None,
                    None,
                    fx=im_scale_x,
                    fy=im_scale_y,
                    interpolation=cv2.INTER_LINEAR)
                image = im.astype('float32')
                im_info = ('resize', im_shape[0:2])
                pred = model.predict(image, test_transforms)
                score_map = pred['score_map']
                cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                # Keep the flow input at the network input size
                # (matches the camera branch and bg_replace.py).
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                score_map = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                          prev_cfd, disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_matting = np.repeat(
                    optflow_map[:, :, np.newaxis], 3, axis=2)
                img_matting = recover(img_matting, im_info)
                bg_im = np.ones_like(img_matting) * 255
                comb = (img_matting * frame +
                        (1 - img_matting) * bg_im).astype(np.uint8)
                out.write(comb)
            else:
                break
        cap.release()
        out.release()
    else:
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                im_shape = frame.shape
                im_scale_x = float(resize_w) / float(im_shape[1])
                im_scale_y = float(resize_h) / float(im_shape[0])
                im = cv2.resize(
                    frame,
                    None,
                    None,
                    fx=im_scale_x,
                    fy=im_scale_y,
                    interpolation=cv2.INTER_LINEAR)
                image = im.astype('float32')
                im_info = ('resize', im_shape[0:2])
                pred = model.predict(image, test_transforms)
                score_map = pred['score_map']
                cur_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
                cur_gray = cv2.resize(cur_gray, (resize_w, resize_h))
                score_map = 255 * score_map[:, :, 1]
                optflow_map = postprocess(cur_gray, score_map, prev_gray,
                                          prev_cfd, disflow, is_init)
                prev_gray = cur_gray.copy()
                prev_cfd = optflow_map.copy()
                is_init = False
                optflow_map = cv2.GaussianBlur(optflow_map, (3, 3), 0)
                optflow_map = threshold_mask(
                    optflow_map, thresh_bg=0.2, thresh_fg=0.8)
                img_matting = np.repeat(
                    optflow_map[:, :, np.newaxis], 3, axis=2)
                img_matting = recover(img_matting, im_info)
                bg_im = np.ones_like(img_matting) * 255
                comb = (img_matting * frame +
                        (1 - img_matting) * bg_im).astype(np.uint8)
                cv2.imshow('HumanSegmentation', comb)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap.release()


if __name__ == "__main__":
    args = parse_args()
    video_infer(args)
paddlex/cv/datasets/seg_dataset.py
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
...
@@ -28,7 +28,7 @@ class SegDataset(Dataset):
     Args:
         data_dir (str): Directory where the dataset resides.
         file_list (str): File listing the dataset images and their annotation files (paths in each line are relative to data_dir).
-        label_list (str): Path of the file describing the classes contained in the dataset.
+        label_list (str): Path of the file describing the classes contained in the dataset. Defaults to None.
         transforms (list): Preprocessing/augmentation operators for each sample in the dataset.
         num_workers (int): Number of threads or processes used for sample preprocessing. Defaults to 4.
         buffer_size (int): Buffer length (in samples) of the preprocessing queue. Defaults to 100.
...
@@ -40,7 +40,7 @@ class SegDataset(Dataset):
     def __init__(self,
                  data_dir,
                  file_list,
-                 label_list,
+                 label_list=None,
                  transforms=None,
                  num_workers='auto',
                  buffer_size=100,
...
@@ -56,10 +56,11 @@ class SegDataset(Dataset):
         self.labels = list()
         self._epoch = 0
-        with open(label_list, encoding=get_encoding(label_list)) as f:
-            for line in f:
-                item = line.strip()
-                self.labels.append(item)
+        if label_list is not None:
+            with open(label_list, encoding=get_encoding(label_list)) as f:
+                for line in f:
+                    item = line.strip()
+                    self.labels.append(item)
         with open(file_list, encoding=get_encoding(file_list)) as f:
             for line in f:
...
@@ -69,8 +70,8 @@ class SegDataset(Dataset):
                 full_path_im = osp.join(data_dir, items[0])
                 full_path_label = osp.join(data_dir, items[1])
                 if not osp.exists(full_path_im):
-                    raise IOError(
-                        'The image file {} is not exist!'.format(full_path_im))
+                    raise IOError('The image file {} is not exist!'.format(
+                        full_path_im))
                 if not osp.exists(full_path_label):
                     raise IOError('The image file {} is not exist!'.format(full_path_label))
...
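With this change, `label_list` becomes optional when constructing a segmentation dataset. A minimal sketch, mirroring the `SegDataset` calls in the new `train.py`; the paths are placeholders:

```python
import paddlex as pdx
from paddlex.seg import transforms

eval_transforms = transforms.Compose(
    [transforms.Resize([192, 192]), transforms.Normalize()])

# label_list can now be omitted (it defaults to None).
dataset = pdx.datasets.SegDataset(
    data_dir='data/mini_supervisely',
    file_list='data/mini_supervisely/val.txt',
    transforms=eval_transforms)
```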
paddlex/cv/models/hrnet.py
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -24,11 +24,12 @@ class HRNet(DeepLabv3p):
     Args:
         num_classes (int): Number of classes.
-        width (int): Number of channels of the feature maps in the high-resolution branch. Defaults to 18. Valid values are [18, 30, 32, 40, 44, 48, 60, 64].
+        width (int|str): Number of channels of the feature maps in the high-resolution branch. Defaults to 18. Valid values are [18, 30, 32, 40, 44, 48, 60, 64, '18_small_v1'].
+            '18_small_v1' is a lightweight version of 18.
         use_bce_loss (bool): Whether to use BCE loss as the network's loss function; only usable for two-class segmentation. Can be combined with dice loss. Defaults to False.
         use_dice_loss (bool): Whether to use dice loss as the network's loss function; only usable for two-class segmentation, combinable with BCE loss.
             When both use_bce_loss and use_dice_loss are False, cross-entropy loss is used. Defaults to False.
-        class_weight (list/str): Per-class weights for the cross-entropy loss. When class_weight is a list, its length should be
+        class_weight (list|str): Per-class weights for the cross-entropy loss. When class_weight is a list, its length should be
             num_classes. When class_weight is a str, weight.lower() should be 'dynamic', in which case the weights are computed
             each round from the per-class pixel ratios; each class's weight is: class ratio * num_classes. With the default None,
             all classes have weight 1, i.e. the usual cross-entropy loss.
...
@@ -168,6 +169,6 @@ class HRNet(DeepLabv3p):
         return super(HRNet, self).train(
             num_epochs, train_dataset, train_batch_size, eval_dataset,
             save_interval_epochs, log_interval_steps, save_dir,
-            pretrain_weights, optimizer, learning_rate, lr_decay_power,
-            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
+            pretrain_weights, optimizer, learning_rate, lr_decay_power,
+            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
             early_stop_patience, resume_checkpoint)
paddlex/cv/models/utils/pretrain_weights.py
...
@@ -81,7 +81,7 @@ coco_pretrain = {
     'YOLOv3_MobileNetV1_COCO':
     'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v1.tar',
     'YOLOv3_MobileNetV3_large_COCO':
-    'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_mobilenet_v3.pdparams',
+    'https://bj.bcebos.com/paddlex/models/yolov3_mobilenet_v3.tar',
     'YOLOv3_ResNet34_COCO':
     'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_r34.tar',
     'YOLOv3_ResNet50_vd_COCO':
...
paddlex/cv/nets/hrnet.py
...
@@ -51,15 +51,38 @@ class HRNet(object):
         self.width = width
         self.has_se = has_se
+        self.num_modules = {
+            '18_small_v1': [1, 1, 1, 1],
+            '18': [1, 1, 4, 3],
+            '30': [1, 1, 4, 3],
+            '32': [1, 1, 4, 3],
+            '40': [1, 1, 4, 3],
+            '44': [1, 1, 4, 3],
+            '48': [1, 1, 4, 3],
+            '60': [1, 1, 4, 3],
+            '64': [1, 1, 4, 3]
+        }
+        self.num_blocks = {
+            '18_small_v1': [[1], [2, 2], [2, 2, 2], [2, 2, 2, 2]],
+            '18': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '30': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '32': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '40': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '44': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '48': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '60': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]],
+            '64': [[4], [4, 4], [4, 4, 4], [4, 4, 4, 4]]
+        }
         self.channels = {
-            18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
-            30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
-            32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
-            40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
-            44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
-            48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
-            60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
-            64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]],
+            '18_small_v1': [[32], [16, 32], [16, 32, 64], [16, 32, 64, 128]],
+            '18': [[64], [18, 36], [18, 36, 72], [18, 36, 72, 144]],
+            '30': [[64], [30, 60], [30, 60, 120], [30, 60, 120, 240]],
+            '32': [[64], [32, 64], [32, 64, 128], [32, 64, 128, 256]],
+            '40': [[64], [40, 80], [40, 80, 160], [40, 80, 160, 320]],
+            '44': [[64], [44, 88], [44, 88, 176], [44, 88, 176, 352]],
+            '48': [[64], [48, 96], [48, 96, 192], [48, 96, 192, 384]],
+            '60': [[64], [60, 120], [60, 120, 240], [60, 120, 240, 480]],
+            '64': [[64], [64, 128], [64, 128, 256], [64, 128, 256, 512]],
         }
         self.freeze_at = freeze_at
...
@@ -73,31 +96,38 @@ class HRNet(object):
     def net(self, input):
         width = self.width
-        channels_2, channels_3, channels_4 = self.channels[width]
-        num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
+        channels_1, channels_2, channels_3, channels_4 = self.channels[str(
+            width)]
+        num_modules_1, num_modules_2, num_modules_3, num_modules_4 = self.num_modules[
+            str(width)]
+        num_blocks_1, num_blocks_2, num_blocks_3, num_blocks_4 = self.num_blocks[
+            str(width)]
         x = self.conv_bn_layer(
             input=input,
             filter_size=3,
-            num_filters=64,
+            num_filters=channels_1[0],
             stride=2,
             if_act=True,
             name='layer1_1')
         x = self.conv_bn_layer(
             input=x,
             filter_size=3,
-            num_filters=64,
+            num_filters=channels_1[0],
             stride=2,
             if_act=True,
             name='layer1_2')
-        la1 = self.layer1(x, name='layer2')
+        la1 = self.layer1(x, num_blocks_1, channels_1, name='layer2')
         tr1 = self.transition_layer([la1], [256], channels_2, name='tr1')
-        st2 = self.stage(tr1, num_modules_2, channels_2, name='st2')
+        st2 = self.stage(tr1, num_modules_2, num_blocks_2, channels_2,
                          name='st2')
         tr2 = self.transition_layer(st2, channels_2, channels_3, name='tr2')
-        st3 = self.stage(tr2, num_modules_3, channels_3, name='st3')
+        st3 = self.stage(tr2, num_modules_3, num_blocks_3, channels_3,
                          name='st3')
         tr3 = self.transition_layer(st3, channels_3, channels_4, name='tr3')
-        st4 = self.stage(tr3, num_modules_4, channels_4, name='st4')
+        st4 = self.stage(tr3, num_modules_4, num_blocks_4, channels_4,
                          name='st4')

         # classification
         if self.num_classes:
...
@@ -139,12 +169,12 @@ class HRNet(object):
             self.end_points = st4
         return st4[-1]

-    def layer1(self, input, name=None):
+    def layer1(self, input, num_blocks, channels, name=None):
         conv = input
-        for i in range(4):
+        for i in range(num_blocks[0]):
             conv = self.bottleneck_block(
                 conv,
-                num_filters=64,
+                num_filters=channels[0],
                 downsample=True if i == 0 else False,
                 name=name + '_' + str(i + 1))
         return conv
...
@@ -178,7 +208,7 @@ class HRNet(object):
         out = []
         for i in range(len(channels)):
             residual = x[i]
-            for j in range(block_num):
+            for j in range(block_num[i]):
                 residual = self.basic_block(
                     residual,
                     channels[i],
...
@@ -240,10 +270,11 @@ class HRNet(object):
     def high_resolution_module(self,
                                x,
+                               num_blocks,
                                channels,
                                multi_scale_output=True,
                                name=None):
-        residual = self.branches(x, 4, channels, name=name)
+        residual = self.branches(x, num_blocks, channels, name=name)
         out = self.fuse_layers(
             residual,
             channels,
...
@@ -254,6 +285,7 @@ class HRNet(object):
     def stage(self,
               x,
               num_modules,
+              num_blocks,
               channels,
               multi_scale_output=True,
               name=None):
...
@@ -262,12 +294,13 @@ class HRNet(object):
             if i == num_modules - 1 and multi_scale_output == False:
                 out = self.high_resolution_module(
                     out,
+                    num_blocks,
                     channels,
                     multi_scale_output=False,
                     name=name + '_' + str(i + 1))
             else:
-                out = self.high_resolution_module(
-                    out, channels, name=name + '_' + str(i + 1))
+                out = self.high_resolution_module(
+                    out, num_blocks, channels, name=name + '_' + str(i + 1))

         return out
...
paddlex/cv/nets/segmentation/hrnet.py
...
@@ -82,7 +82,8 @@ class HRNet(object):
             st4[3] = fluid.layers.resize_bilinear(st4[3], out_shape=shape)

         out = fluid.layers.concat(st4, axis=1)
-        last_channels = sum(self.backbone.channels[self.backbone.width][-1])
+        last_channels = sum(
+            self.backbone.channels[str(self.backbone.width)][-1])

         out = self._conv_bn_layer(
             input=out,
...