Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSeg
提交
6e54823f
P
PaddleSeg
项目概览
PaddlePaddle
/
PaddleSeg
通知
285
Star
8
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSeg
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6e54823f
编写于
5月 13, 2020
作者:
W
wuyefeilin
提交者:
GitHub
5月 13, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #243 from wuyefeilin/humanseg
add image_shape argparse
上级
1d009734
4968e6c2
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
66 addition
and
15 deletion
+66
-15
contrib/HumanSeg/README.md
contrib/HumanSeg/README.md
+15
-5
contrib/HumanSeg/infer.py
contrib/HumanSeg/infer.py
+8
-1
contrib/HumanSeg/quant_offline.py
contrib/HumanSeg/quant_offline.py
+8
-1
contrib/HumanSeg/quant_online.py
contrib/HumanSeg/quant_online.py
+9
-2
contrib/HumanSeg/train.py
contrib/HumanSeg/train.py
+9
-2
contrib/HumanSeg/val.py
contrib/HumanSeg/val.py
+8
-1
contrib/HumanSeg/video_infer.py
contrib/HumanSeg/video_infer.py
+9
-3
未找到文件。
contrib/HumanSeg/README.md
浏览文件 @
6e54823f
...
@@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \
...
@@ -65,7 +65,8 @@ python train.py --model_type HumanSegMobile \
--pretrained_weights
pretrained_weights/humanseg_mobile
\
--pretrained_weights
pretrained_weights/humanseg_mobile
\
--batch_size
8
\
--batch_size
8
\
--learning_rate
0.001
\
--learning_rate
0.001
\
--num_epochs
10
--num_epochs
10
\
--image_shape
192 192
```
```
其中参数含义如下:
其中参数含义如下:
*
`--model_type`
: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
*
`--model_type`
: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
...
@@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \
...
@@ -77,6 +78,7 @@ python train.py --model_type HumanSegMobile \
*
`--batch_size`
: 批大小
*
`--batch_size`
: 批大小
*
`--learning_rate`
: 初始学习率
*
`--learning_rate`
: 初始学习率
*
`--num_epochs`
: 训练轮数
*
`--num_epochs`
: 训练轮数
*
`--image_shape`
: 网络输入图像大小(w, h)
更多命令行帮助可运行下述命令进行查看:
更多命令行帮助可运行下述命令进行查看:
```
bash
```
bash
...
@@ -90,24 +92,28 @@ python train.py --help
...
@@ -90,24 +92,28 @@ python train.py --help
```
bash
```
bash
python val.py
--model_dir
output/best_model
\
python val.py
--model_dir
output/best_model
\
--data_dir
data/mini_supervisely
\
--data_dir
data/mini_supervisely
\
--val_list
data/mini_supervisely/val.txt
--val_list
data/mini_supervisely/val.txt
\
--image_shape
192 192
```
```
其中参数含义如下:
其中参数含义如下:
*
`--model_dir`
: 模型路径
*
`--model_dir`
: 模型路径
*
`--data_dir`
: 数据集路径
*
`--data_dir`
: 数据集路径
*
`--val_list`
: 验证集列表路径
*
`--val_list`
: 验证集列表路径
*
`--image_shape`
: 网络输入图像大小(w, h)
## 预测
## 预测
使用下述命令进行预测
使用下述命令进行预测
```
bash
```
bash
python infer.py
--model_dir
output/best_model
\
python infer.py
--model_dir
output/best_model
\
--data_dir
data/mini_supervisely
\
--data_dir
data/mini_supervisely
\
--test_list
data/mini_supervisely/test.txt
--test_list
data/mini_supervisely/test.txt
\
--image_shape
192 192
```
```
其中参数含义如下:
其中参数含义如下:
*
`--model_dir`
: 模型路径
*
`--model_dir`
: 模型路径
*
`--data_dir`
: 数据集路径
*
`--data_dir`
: 数据集路径
*
`--test_list`
: 测试集列表路径
*
`--test_list`
: 测试集列表路径
*
`--image_shape`
: 网络输入图像大小(w, h)
## 模型导出
## 模型导出
```
bash
```
bash
...
@@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \
...
@@ -124,13 +130,15 @@ python export.py --model_dir output/best_model \
python quant_offline.py
--model_dir
output/best_model
\
python quant_offline.py
--model_dir
output/best_model
\
--data_dir
data/mini_supervisely
\
--data_dir
data/mini_supervisely
\
--quant_list
data/mini_supervisely/val.txt
\
--quant_list
data/mini_supervisely/val.txt
\
--save_dir
output/quant_offline
--save_dir
output/quant_offline
\
--image_shape
192 192
```
```
其中参数含义如下:
其中参数含义如下:
*
`--model_dir`
: 待量化模型路径
*
`--model_dir`
: 待量化模型路径
*
`--data_dir`
: 数据集路径
*
`--data_dir`
: 数据集路径
*
`--quant_list`
: 量化数据集列表路径,一般直接选择训练集或验证集
*
`--quant_list`
: 量化数据集列表路径,一般直接选择训练集或验证集
*
`--save_dir`
: 量化模型保存路径
*
`--save_dir`
: 量化模型保存路径
*
`--image_shape`
: 网络输入图像大小(w, h)
## 在线量化
## 在线量化
利用float训练模型进行在线量化。
利用float训练模型进行在线量化。
...
@@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \
...
@@ -143,7 +151,8 @@ python quant_online.py --model_type HumanSegMobile \
--pretrained_weights
output/best_model
\
--pretrained_weights
output/best_model
\
--batch_size
2
\
--batch_size
2
\
--learning_rate
0.001
\
--learning_rate
0.001
\
--num_epochs
2
--num_epochs
2
\
--image_shape
192 192
```
```
其中参数含义如下:
其中参数含义如下:
*
`--model_type`
: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
*
`--model_type`
: 模型类型,可选项为:HumanSegServer、HumanSegMobile和HumanSegLite
...
@@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \
...
@@ -155,3 +164,4 @@ python quant_online.py --model_type HumanSegMobile \
*
`--batch_size`
: 批大小
*
`--batch_size`
: 批大小
*
`--learning_rate`
: 初始学习率
*
`--learning_rate`
: 初始学习率
*
`--num_epochs`
: 训练轮数
*
`--num_epochs`
: 训练轮数
*
`--image_shape`
: 网络输入图像大小(w, h)
contrib/HumanSeg/infer.py
浏览文件 @
6e54823f
...
@@ -34,6 +34,13 @@ def parse_args():
...
@@ -34,6 +34,13 @@ def parse_args():
help
=
'The directory for saving the inference results'
,
help
=
'The directory for saving the inference results'
,
type
=
str
,
type
=
str
,
default
=
'./output/result'
)
default
=
'./output/result'
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -45,7 +52,7 @@ def mkdir(path):
...
@@ -45,7 +52,7 @@ def mkdir(path):
def
infer
(
args
):
def
infer
(
args
):
test_transforms
=
transforms
.
Compose
(
test_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
192
,
192
)
),
[
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()])
transforms
.
Normalize
()])
model
=
models
.
load_model
(
args
.
model_dir
)
model
=
models
.
load_model
(
args
.
model_dir
)
added_saveed_path
=
osp
.
join
(
args
.
save_dir
,
'added'
)
added_saveed_path
=
osp
.
join
(
args
.
save_dir
,
'added'
)
...
...
contrib/HumanSeg/quant_offline.py
浏览文件 @
6e54823f
...
@@ -42,12 +42,19 @@ def parse_args():
...
@@ -42,12 +42,19 @@ def parse_args():
help
=
'The directory for saving the quant model'
,
help
=
'The directory for saving the quant model'
,
type
=
str
,
type
=
str
,
default
=
'./output/quant_offline'
)
default
=
'./output/quant_offline'
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
evaluate
(
args
):
def
evaluate
(
args
):
eval_transforms
=
transforms
.
Compose
(
eval_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
192
,
192
)
),
[
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()])
transforms
.
Normalize
()])
eval_dataset
=
Dataset
(
eval_dataset
=
Dataset
(
...
...
contrib/HumanSeg/quant_online.py
浏览文件 @
6e54823f
...
@@ -73,6 +73,13 @@ def parse_args():
...
@@ -73,6 +73,13 @@ def parse_args():
help
=
'The interval epochs for save a model snapshot'
,
help
=
'The interval epochs for save a model snapshot'
,
type
=
int
,
type
=
int
,
default
=
1
)
default
=
1
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -80,12 +87,12 @@ def parse_args():
...
@@ -80,12 +87,12 @@ def parse_args():
def
train
(
args
):
def
train
(
args
):
train_transforms
=
transforms
.
Compose
([
train_transforms
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
Resize
(
(
192
,
192
)
),
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()
transforms
.
Normalize
()
])
])
eval_transforms
=
transforms
.
Compose
(
eval_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
192
,
192
)
),
[
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()])
transforms
.
Normalize
()])
train_dataset
=
Dataset
(
train_dataset
=
Dataset
(
...
...
contrib/HumanSeg/train.py
浏览文件 @
6e54823f
...
@@ -43,6 +43,13 @@ def parse_args():
...
@@ -43,6 +43,13 @@ def parse_args():
help
=
'Number of classes'
,
help
=
'Number of classes'
,
type
=
int
,
type
=
int
,
default
=
2
)
default
=
2
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
parser
.
add_argument
(
parser
.
add_argument
(
'--num_epochs'
,
'--num_epochs'
,
dest
=
'num_epochs'
,
dest
=
'num_epochs'
,
...
@@ -91,13 +98,13 @@ def parse_args():
...
@@ -91,13 +98,13 @@ def parse_args():
def
train
(
args
):
def
train
(
args
):
train_transforms
=
transforms
.
Compose
([
train_transforms
=
transforms
.
Compose
([
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomHorizontalFlip
(),
transforms
.
Resize
((
192
,
192
)),
transforms
.
Normalize
()
transforms
.
Normalize
()
])
])
eval_transforms
=
transforms
.
Compose
(
eval_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
192
,
192
)
),
[
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()])
transforms
.
Normalize
()])
train_dataset
=
Dataset
(
train_dataset
=
Dataset
(
...
...
contrib/HumanSeg/val.py
浏览文件 @
6e54823f
...
@@ -29,12 +29,19 @@ def parse_args():
...
@@ -29,12 +29,19 @@ def parse_args():
help
=
'Mini batch size'
,
help
=
'Mini batch size'
,
type
=
int
,
type
=
int
,
default
=
128
)
default
=
128
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
def
evaluate
(
args
):
def
evaluate
(
args
):
eval_transforms
=
transforms
.
Compose
(
eval_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
(
(
192
,
192
)
),
[
transforms
.
Resize
(
args
.
image_shape
),
transforms
.
Normalize
()])
transforms
.
Normalize
()])
eval_dataset
=
Dataset
(
eval_dataset
=
Dataset
(
...
...
contrib/HumanSeg/video_infer.py
浏览文件 @
6e54823f
...
@@ -29,6 +29,13 @@ def parse_args():
...
@@ -29,6 +29,13 @@ def parse_args():
help
=
'The directory for saving the inference results'
,
help
=
'The directory for saving the inference results'
,
type
=
str
,
type
=
str
,
default
=
'./output'
)
default
=
'./output'
)
parser
.
add_argument
(
"--image_shape"
,
dest
=
"image_shape"
,
help
=
"The image shape for net inputs."
,
nargs
=
2
,
default
=
[
192
,
192
],
type
=
int
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -60,9 +67,8 @@ def recover(img, im_info):
...
@@ -60,9 +67,8 @@ def recover(img, im_info):
def
video_infer
(
args
):
def
video_infer
(
args
):
resize_h
=
args
.
image_shape
[
1
]
resize_h
=
192
resize_w
=
args
.
image_shape
[
0
]
resize_w
=
192
test_transforms
=
transforms
.
Compose
(
test_transforms
=
transforms
.
Compose
(
[
transforms
.
Resize
((
resize_w
,
resize_h
)),
[
transforms
.
Resize
((
resize_w
,
resize_h
)),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录