Commit d3ef4d2e
Authored May 25, 2020 by chenguowei01

fix quant online program

Parent 1a8bd484
Showing 2 changed files with 72 additions and 59 deletions (+72 −59):
contrib/HumanSeg/models/humanseg.py (+62 −54)
contrib/HumanSeg/models/load_model.py (+10 −5)
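In short: before this commit, checkpoints written during quantization-aware ("online") training went through `slim.quant.convert` and `fluid.io.save_inference_model` on every save, and `load_model` had no branch that could restore them. After it, such checkpoints are plain `fluid.save` program state tagged `status: QuantOnline` in `model.yml`, the conversion step moves into `export_quant_model(quant_type="online")`, and `load_model` rebuilds the quant-aware program before loading the parameters. A minimal round-trip sketch under those changes; the model class, the dataset setup, and the `quant` flag on `train()` are assumptions based on surrounding code, not shown in this diff:

```python
# Hedged usage sketch; HumanSegMobile, train_dataset and the train()
# signature are assumed from context, not confirmed by this commit.
from models import HumanSegMobile, load_model

model = HumanSegMobile(num_classes=2)
model.train(
    num_epochs=2,
    train_dataset=train_dataset,  # assumed to be prepared elsewhere
    save_dir='output',
    quant=True)  # quantization-aware ("online") training

# Checkpoints now carry status 'QuantOnline' in model.yml, so they can
# be restored with load_model; the deployable model lands in
# output/best_model_export via export_quant_model(quant_type="online").
model = load_model('output/best_model')
```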
contrib/HumanSeg/models/humanseg.py
```diff
@@ -242,30 +242,11 @@ class SegModel(object):
         if self.status == 'Normal':
             fluid.save(self.train_prog, osp.join(save_dir, 'model'))
+            model_info['status'] = 'Normal'
         elif self.status == 'Quant':
-            float_prog, _ = slim.quant.convert(
-                self.test_prog,
-                self.exe.place,
-                save_int8=True)
-            test_input_names = [
-                var.name for var in list(self.test_inputs.values())]
-            test_outputs = list(self.test_outputs.values())
-            fluid.io.save_inference_model(
-                dirname=save_dir,
-                executor=self.exe,
-                params_filename='__params__',
-                feeded_var_names=test_input_names,
-                target_vars=test_outputs,
-                main_program=float_prog)
-            model_info['_ModelInputsOutputs'] = dict()
-            model_info['_ModelInputsOutputs']['test_inputs'] = [[
-                k, v.name
-            ] for k, v in self.test_inputs.items()]
-            model_info['_ModelInputsOutputs']['test_outputs'] = [[
-                k, v.name
-            ] for k, v in self.test_outputs.items()]
-        model_info['status'] = self.status
+            fluid.save(self.test_prog, osp.join(save_dir, 'model'))
+            model_info['status'] = 'QuantOnline'
         with open(
                 osp.join(save_dir, 'model.yml'), encoding='utf-8',
                 mode='w') as f:
```
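With this hunk, saving during training no longer converts the program: both branches write raw program state via `fluid.save` (which produces `model.pdparams`, `model.pdopt`, and `model.pdmodel`), and `model.yml` records which kind of checkpoint the directory holds. A small helper sketch for inspecting that field; only the `status` key is confirmed by this diff, the rest of the yaml layout is an assumption:

```python
# Sketch: report which loading path a saved model directory will take.
import os.path as osp

import yaml


def checkpoint_status(model_dir):
    # model.yml is written by SegModel.save(); only 'status' is relied on here.
    with open(osp.join(model_dir, 'model.yml'), encoding='utf-8') as f:
        info = yaml.load(f, Loader=yaml.SafeLoader)
    # 'Normal'        -> plain training checkpoint (train_prog parameters)
    # 'QuantOnline'   -> quant-aware training checkpoint (test_prog parameters)
    # 'Infer'/'Quant' -> exported inference model, restored with
    #                    fluid.io.load_inference_model
    return info['status']
```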
@@ -307,40 +288,57 @@ class SegModel(object):
...
@@ -307,40 +288,57 @@ class SegModel(object):
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
def
export_quant_model
(
self
,
def
export_quant_model
(
self
,
dataset
,
dataset
=
None
,
save_dir
,
save_dir
=
None
,
batch_size
=
1
,
batch_size
=
1
,
batch_nums
=
10
,
batch_nums
=
10
,
cache_dir
=
"./.temp"
):
cache_dir
=
"./.temp"
,
self
.
arrange_transform
(
transforms
=
dataset
.
transforms
,
mode
=
'quant'
)
quant_type
=
"offline"
):
dataset
.
num_samples
=
batch_size
*
batch_nums
if
quant_type
==
"offline"
:
try
:
self
.
arrange_transform
(
transforms
=
dataset
.
transforms
,
mode
=
'quant'
)
from
utils
import
HumanSegPostTrainingQuantization
dataset
.
num_samples
=
batch_size
*
batch_nums
except
:
try
:
raise
Exception
(
from
utils
import
HumanSegPostTrainingQuantization
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
except
:
)
raise
Exception
(
is_use_cache_file
=
True
"Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
if
cache_dir
is
None
:
)
is_use_cache_file
=
False
is_use_cache_file
=
True
post_training_quantization
=
HumanSegPostTrainingQuantization
(
if
cache_dir
is
None
:
executor
=
self
.
exe
,
is_use_cache_file
=
False
dataset
=
dataset
,
post_training_quantization
=
HumanSegPostTrainingQuantization
(
program
=
self
.
test_prog
,
executor
=
self
.
exe
,
inputs
=
self
.
test_inputs
,
dataset
=
dataset
,
outputs
=
self
.
test_outputs
,
program
=
self
.
test_prog
,
batch_size
=
batch_size
,
inputs
=
self
.
test_inputs
,
batch_nums
=
batch_nums
,
outputs
=
self
.
test_outputs
,
scope
=
None
,
batch_size
=
batch_size
,
algo
=
'KL'
,
batch_nums
=
batch_nums
,
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
],
scope
=
None
,
is_full_quantize
=
False
,
algo
=
'KL'
,
is_use_cache_file
=
is_use_cache_file
,
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
],
cache_dir
=
cache_dir
)
is_full_quantize
=
False
,
post_training_quantization
.
quantize
()
is_use_cache_file
=
is_use_cache_file
,
post_training_quantization
.
save_quantized_model
(
save_dir
)
cache_dir
=
cache_dir
)
if
cache_dir
is
not
None
:
post_training_quantization
.
quantize
()
os
.
system
(
'rm -r'
+
cache_dir
)
post_training_quantization
.
save_quantized_model
(
save_dir
)
if
cache_dir
is
not
None
:
os
.
system
(
'rm -r '
+
cache_dir
)
else
:
float_prog
,
_
=
slim
.
quant
.
convert
(
self
.
test_prog
,
self
.
exe
.
place
,
save_int8
=
True
)
test_input_names
=
[
var
.
name
for
var
in
list
(
self
.
test_inputs
.
values
())
]
test_outputs
=
list
(
self
.
test_outputs
.
values
())
fluid
.
io
.
save_inference_model
(
dirname
=
save_dir
,
executor
=
self
.
exe
,
params_filename
=
'__params__'
,
feeded_var_names
=
test_input_names
,
target_vars
=
test_outputs
,
main_program
=
float_prog
)
model_info
=
self
.
get_model_info
()
model_info
=
self
.
get_model_info
()
model_info
[
'status'
]
=
'Quant'
model_info
[
'status'
]
=
'Quant'
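`export_quant_model` now covers both quantization modes: the existing post-training ("offline") path, which calibrates with `HumanSegPostTrainingQuantization` over a dataset, and the new "online" path, which only converts the already quant-aware `test_prog` and exports it. The new defaults (`dataset=None`, `save_dir=None`) exist so the online call does not need calibration data. A sketch of the two call patterns; the calibration dataset is assumed:

```python
# Offline: post-training quantization, needs calibration batches.
model.export_quant_model(
    dataset=quant_dataset,  # assumed calibration dataset with transforms
    save_dir='output/quant_offline',
    batch_size=1,
    batch_nums=10)  # quant_type defaults to "offline"

# Online: the program already carries fake-quant ops from training,
# so it is converted and exported directly, no dataset required.
model.export_quant_model(
    save_dir='output/quant_online',
    quant_type="online")
```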
```diff
@@ -592,6 +590,16 @@ class SegModel(object):
             'Current evaluated best model in eval_dataset is epoch_{}, miou={}'
             .format(best_model_epoch, best_miou))
+        if quant:
+            if osp.exists(osp.join(save_dir, "best_model")):
+                fluid.load(
+                    program=self.test_prog,
+                    model_path=osp.join(save_dir, "best_model"),
+                    executor=self.exe)
+            self.export_quant_model(
+                save_dir=osp.join(save_dir, "best_model_export"),
+                quant_type="online")
 
     def evaluate(self, eval_dataset, batch_size=1, epoch_id=None):
         """评估。
```
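At the end of quantization-aware training the best checkpoint is loaded back into `test_prog` and exported immediately, so `save_dir` ends up holding both a restorable checkpoint and a deployable model. Roughly this layout (file names under `best_model_export` follow `save_inference_model` defaults with `params_filename='__params__'`; `model.yml` contents beyond `status` are assumptions):

```
save_dir/
├── best_model/             # fluid.save checkpoint, model.yml status: QuantOnline
│   ├── model.pdparams
│   ├── model.pdmodel
│   ├── model.pdopt
│   └── model.yml
└── best_model_export/      # converted inference model, model.yml status: Quant
    ├── __model__
    ├── __params__
    └── model.yml
```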
contrib/HumanSeg/models/load_model.py
```diff
@@ -33,7 +33,7 @@ def load_model(model_dir):
         raise Exception("There's no attribute {} in models".format(
             info['Model']))
     model = getattr(models, info['Model'])(**info['_init_params'])
-    if status == "Normal":
+    if status in ["Normal", "QuantOnline"]:
         startup_prog = fluid.Program()
         model.test_prog = fluid.Program()
         with fluid.program_guard(model.test_prog, startup_prog):
```
@@ -41,11 +41,16 @@ def load_model(model_dir):
...
@@ -41,11 +41,16 @@ def load_model(model_dir):
model
.
test_inputs
,
model
.
test_outputs
=
model
.
build_net
(
model
.
test_inputs
,
model
.
test_outputs
=
model
.
build_net
(
mode
=
'test'
)
mode
=
'test'
)
model
.
test_prog
=
model
.
test_prog
.
clone
(
for_test
=
True
)
model
.
test_prog
=
model
.
test_prog
.
clone
(
for_test
=
True
)
if
status
==
"QuantOnline"
:
print
(
'test quant online'
)
import
paddleslim
as
slim
model
.
test_prog
=
slim
.
quant
.
quant_aware
(
model
.
test_prog
,
model
.
exe
.
place
,
for_test
=
True
)
model
.
exe
.
run
(
startup_prog
)
model
.
exe
.
run
(
startup_prog
)
import
pickle
fluid
.
load
(
model
.
test_prog
,
osp
.
join
(
model_dir
,
'model'
))
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
if
status
==
"QuantOnline"
:
load_dict
=
pickle
.
load
(
f
)
model
.
test_prog
=
slim
.
quant
.
convert
(
model
.
test_prog
,
fluid
.
io
.
set_program_state
(
model
.
test_prog
,
load_dict
)
model
.
exe
.
place
)
elif
status
in
[
'Infer'
,
'Quant'
]:
elif
status
in
[
'Infer'
,
'Quant'
]:
[
prog
,
input_names
,
outputs
]
=
fluid
.
io
.
load_inference_model
(
[
prog
,
input_names
,
outputs
]
=
fluid
.
io
.
load_inference_model
(
...
...
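The `QuantOnline` branch must rebuild the quant-aware graph before parameters can be matched by name: `slim.quant.quant_aware(..., for_test=True)` re-inserts the same fake-quant ops the training program had, `fluid.load` restores the saved state, and `slim.quant.convert` then folds the quantization into an inference-ready program. A condensed standalone sketch of that order of operations; `build_test_program` is an assumed stand-in for `model.build_net(mode='test')`:

```python
# Sketch of the QuantOnline restore path from load_model, assuming a
# caller-supplied network builder in place of model.build_net.
import os.path as osp

import paddle.fluid as fluid
import paddleslim as slim


def restore_quant_online(model_dir, exe, build_test_program):
    startup_prog = fluid.Program()
    test_prog = fluid.Program()
    with fluid.program_guard(test_prog, startup_prog):
        inputs, outputs = build_test_program()  # defines the network
    test_prog = test_prog.clone(for_test=True)
    # 1) re-insert the fake-quant ops that training added
    test_prog = slim.quant.quant_aware(test_prog, exe.place, for_test=True)
    exe.run(startup_prog)
    # 2) load the raw parameters written by fluid.save during training
    fluid.load(test_prog, osp.join(model_dir, 'model'))
    # 3) fold quantization into an inference-ready program
    test_prog = slim.quant.convert(test_prog, exe.place)
    return test_prog, inputs, outputs
```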