Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleX
提交
3cb494f3
P
PaddleX
项目概览
PaddlePaddle
/
PaddleX
通知
138
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
43
列表
看板
标记
里程碑
合并请求
5
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleX
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
43
Issue
43
列表
看板
标记
里程碑
合并请求
5
合并请求
5
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3cb494f3
编写于
7月 06, 2020
作者:
J
jiangjiajun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add scope for model
上级
1c1b6052
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
182 addition
and
197 deletion
+182
-197
docs/apis/deploy.md
docs/apis/deploy.md
+0
-38
docs/apis/index.rst
docs/apis/index.rst
+0
-1
paddlex/cv/models/base.py
paddlex/cv/models/base.py
+14
-14
paddlex/cv/models/classifier.py
paddlex/cv/models/classifier.py
+17
-13
paddlex/cv/models/deeplabv3p.py
paddlex/cv/models/deeplabv3p.py
+21
-19
paddlex/cv/models/faster_rcnn.py
paddlex/cv/models/faster_rcnn.py
+19
-16
paddlex/cv/models/load_model.py
paddlex/cv/models/load_model.py
+36
-32
paddlex/cv/models/mask_rcnn.py
paddlex/cv/models/mask_rcnn.py
+19
-16
paddlex/cv/models/slim/post_quantization.py
paddlex/cv/models/slim/post_quantization.py
+40
-35
paddlex/cv/models/yolo_v3.py
paddlex/cv/models/yolo_v3.py
+16
-13
未找到文件。
docs/apis/deploy.md
已删除
100755 → 0
浏览文件 @
1c1b6052
# 预测部署-paddlex.deploy
使用Paddle Inference进行高性能的Python预测部署。更多关于Paddle Inference信息请参考
[
Paddle Inference文档
](
https://paddle-inference.readthedocs.io/en/latest/#
)
## Predictor类
```
paddlex.deploy.Predictor(model_dir, use_gpu=False, gpu_id=0, use_mkl=False, use_trt=False, use_glog=False, memory_optimize=True)
```
> **参数**
> > * **model_dir**: 训练过程中保存的模型路径, 注意需要使用导出的inference模型
> > * **use_gpu**: 是否使用GPU进行预测
> > * **gpu_id**: 使用的GPU序列号
> > * **use_mkl**: 是否使用mkldnn加速库
> > * **use_trt**: 是否使用TensorRT预测引擎
> > * **use_glog**: 是否打印中间日志
> > * **memory_optimize**: 是否优化内存使用
> > ### 示例
> >
> > ```
> > import paddlex
> >
> > model = paddlex.deploy.Predictor(model_dir, use_gpu=True)
> > result = model.predict(image_file)
> > ```
### predict 接口
> ```
> predict(image, topk=1)
> ```
> **参数
*
**image(str|np.ndarray)**
: 待预测的图片路径或np.ndarray,若为后者需注意为BGR格式
*
**topk(int)**
: 图像分类时使用的参数,表示预测前topk个可能的分类
docs/apis/index.rst
浏览文件 @
3cb494f3
...
...
@@ -10,4 +10,3 @@ API接口说明
slim.md
load_model.md
visualize.md
deploy.md
paddlex/cv/models/base.py
浏览文件 @
3cb494f3
...
...
@@ -73,6 +73,7 @@ class BaseAPI:
self
.
status
=
'Normal'
# 已完成迭代轮数,为恢复训练时的起始轮数
self
.
completed_epochs
=
0
self
.
scope
=
fluid
.
global_scope
()
def
_get_single_card_bs
(
self
,
batch_size
):
if
batch_size
%
len
(
self
.
places
)
==
0
:
...
...
@@ -84,6 +85,10 @@ class BaseAPI:
'place'
]))
def
build_program
(
self
):
if
hasattr
(
paddlex
,
'model_built'
)
and
paddlex
.
model_built
:
logging
.
error
(
"Function model.train() only can be called once in your code."
)
paddlex
.
model_built
=
True
# 构建训练网络
self
.
train_inputs
,
self
.
train_outputs
=
self
.
build_net
(
mode
=
'train'
)
self
.
train_prog
=
fluid
.
default_main_program
()
...
...
@@ -155,7 +160,7 @@ class BaseAPI:
outputs
=
self
.
test_outputs
,
batch_size
=
batch_size
,
batch_nums
=
batch_num
,
scope
=
Non
e
,
scope
=
self
.
scop
e
,
algo
=
'KL'
,
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
],
is_full_quantize
=
False
,
...
...
@@ -244,8 +249,8 @@ class BaseAPI:
logging
.
info
(
"Load pretrain weights from {}."
.
format
(
pretrain_weights
),
use_color
=
True
)
paddlex
.
utils
.
utils
.
load_pretrain_weights
(
self
.
exe
,
self
.
train_prog
,
pretrain_weights
,
fuse_bn
)
paddlex
.
utils
.
utils
.
load_pretrain_weights
(
self
.
exe
,
self
.
train_prog
,
pretrain_weights
,
fuse_bn
)
# 进行裁剪
if
sensitivities_file
is
not
None
:
import
paddleslim
...
...
@@ -349,9 +354,7 @@ class BaseAPI:
logging
.
info
(
"Model saved in {}."
.
format
(
save_dir
))
def
export_inference_model
(
self
,
save_dir
):
test_input_names
=
[
var
.
name
for
var
in
list
(
self
.
test_inputs
.
values
())
]
test_input_names
=
[
var
.
name
for
var
in
list
(
self
.
test_inputs
.
values
())]
test_outputs
=
list
(
self
.
test_outputs
.
values
())
if
self
.
__class__
.
__name__
==
'MaskRCNN'
:
from
paddlex.utils.save
import
save_mask_inference_model
...
...
@@ -388,8 +391,7 @@ class BaseAPI:
# 模型保存成功的标志
open
(
osp
.
join
(
save_dir
,
'.success'
),
'w'
).
close
()
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
logging
.
info
(
"Model for inference deploy saved in {}."
.
format
(
save_dir
))
def
train_loop
(
self
,
num_epochs
,
...
...
@@ -513,13 +515,11 @@ class BaseAPI:
eta
=
((
num_epochs
-
i
)
*
total_num_steps
-
step
-
1
)
*
avg_step_time
if
time_eval_one_epoch
is
not
None
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
time_eval_one_epoch
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
time_eval_one_epoch
else
:
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
total_num_steps_eval
*
avg_step_time
eval_eta
=
(
total_eval_times
-
i
//
save_interval_epochs
)
*
total_num_steps_eval
*
avg_step_time
eta_str
=
seconds_to_hms
(
eta
+
eval_eta
)
logging
.
info
(
...
...
paddlex/cv/models/classifier.py
浏览文件 @
3cb494f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -227,9 +227,10 @@ class BaseClassifier(BaseAPI):
true_labels
=
list
()
pred_scores
=
list
()
if
not
hasattr
(
self
,
'parallel_test_prog'
):
self
.
parallel_test_prog
=
fluid
.
CompiledProgram
(
self
.
test_prog
).
with_data_parallel
(
share_vars_from
=
self
.
parallel_train_prog
)
with
fluid
.
scope_guard
(
self
.
scope
):
self
.
parallel_test_prog
=
fluid
.
CompiledProgram
(
self
.
test_prog
).
with_data_parallel
(
share_vars_from
=
self
.
parallel_train_prog
)
batch_size_each_gpu
=
self
.
_get_single_card_bs
(
batch_size
)
logging
.
info
(
"Start to evaluating(total_samples={}, total_steps={})..."
.
format
(
eval_dataset
.
num_samples
,
total_steps
))
...
...
@@ -242,9 +243,11 @@ class BaseClassifier(BaseAPI):
num_pad_samples
=
batch_size
-
num_samples
pad_images
=
np
.
tile
(
images
[
0
:
1
],
(
num_pad_samples
,
1
,
1
,
1
))
images
=
np
.
concatenate
([
images
,
pad_images
])
outputs
=
self
.
exe
.
run
(
self
.
parallel_test_prog
,
feed
=
{
'image'
:
images
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()))
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
parallel_test_prog
,
feed
=
{
'image'
:
images
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()))
outputs
=
[
outputs
[
0
][:
num_samples
]]
true_labels
.
extend
(
labels
)
pred_scores
.
extend
(
outputs
[
0
].
tolist
())
...
...
@@ -286,10 +289,11 @@ class BaseClassifier(BaseAPI):
self
.
arrange_transforms
(
transforms
=
self
.
test_transforms
,
mode
=
'test'
)
im
=
self
.
test_transforms
(
img_file
)
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
use_program_cache
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
use_program_cache
=
True
)
pred_label
=
np
.
argsort
(
result
[
0
][
0
])[::
-
1
][:
true_topk
]
res
=
[{
'category_id'
:
l
,
...
...
paddlex/cv/models/deeplabv3p.py
浏览文件 @
3cb494f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -317,19 +317,18 @@ class DeepLabv3p(BaseAPI):
tuple (metrics, eval_details):当return_details为True时,增加返回dict (eval_details),
包含关键字:'confusion_matrix',表示评估的混淆矩阵。
"""
self
.
arrange_transforms
(
transforms
=
eval_dataset
.
transforms
,
mode
=
'eval'
)
self
.
arrange_transforms
(
transforms
=
eval_dataset
.
transforms
,
mode
=
'eval'
)
total_steps
=
math
.
ceil
(
eval_dataset
.
num_samples
*
1.0
/
batch_size
)
conf_mat
=
ConfusionMatrix
(
self
.
num_classes
,
streaming
=
True
)
data_generator
=
eval_dataset
.
generator
(
batch_size
=
batch_size
,
drop_last
=
False
)
if
not
hasattr
(
self
,
'parallel_test_prog'
):
self
.
parallel_test_prog
=
fluid
.
CompiledProgram
(
self
.
test_prog
).
with_data_parallel
(
s
hare_vars_from
=
self
.
parallel_train_prog
)
logging
.
info
(
"Start to evaluating(total_samples={}, total_steps={})..."
.
format
(
eval_dataset
.
num_samples
,
total_steps
))
with
fluid
.
scope_guard
(
self
.
scope
):
self
.
parallel_test_prog
=
fluid
.
CompiledProgram
(
s
elf
.
test_prog
).
with_data_parallel
(
share_vars_from
=
self
.
parallel_train_prog
)
logging
.
info
(
"Start to evaluating(total_samples={}, total_steps={})..."
.
format
(
eval_dataset
.
num_samples
,
total_steps
))
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
data_generator
()),
total
=
total_steps
):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
...
...
@@ -350,10 +349,12 @@ class DeepLabv3p(BaseAPI):
pad_images
=
np
.
tile
(
images
[
0
:
1
],
(
num_pad_samples
,
1
,
1
,
1
))
images
=
np
.
concatenate
([
images
,
pad_images
])
feed_data
=
{
'image'
:
images
}
outputs
=
self
.
exe
.
run
(
self
.
parallel_test_prog
,
feed
=
feed_data
,
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
parallel_test_prog
,
feed
=
feed_data
,
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
True
)
pred
=
outputs
[
0
]
if
num_samples
<
batch_size
:
pred
=
pred
[
0
:
num_samples
]
...
...
@@ -399,10 +400,11 @@ class DeepLabv3p(BaseAPI):
transforms
=
self
.
test_transforms
,
mode
=
'test'
)
im
,
im_info
=
self
.
test_transforms
(
im_file
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
use_program_cache
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
result
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
use_program_cache
=
True
)
pred
=
result
[
0
]
pred
=
np
.
squeeze
(
pred
).
astype
(
'uint8'
)
logit
=
result
[
1
]
...
...
paddlex/cv/models/faster_rcnn.py
浏览文件 @
3cb494f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -325,10 +325,12 @@ class FasterRCNN(BaseAPI):
'im_info'
:
im_infos
,
'im_shape'
:
im_shapes
,
}
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
'bbox'
:
(
np
.
array
(
outputs
[
0
]),
outputs
[
0
].
recursive_sequence_lengths
())
...
...
@@ -388,15 +390,16 @@ class FasterRCNN(BaseAPI):
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im_resize_info
=
np
.
expand_dims
(
im_resize_info
,
axis
=
0
)
im_shape
=
np
.
expand_dims
(
im_shape
,
axis
=
0
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_info'
:
im_resize_info
,
'im_shape'
:
im_shape
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_info'
:
im_resize_info
,
'im_shape'
:
im_shape
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
outputs
)
...
...
paddlex/cv/models/load_model.py
浏览文件 @
3cb494f3
...
...
@@ -24,6 +24,7 @@ import paddlex.utils.logging as logging
def
load_model
(
model_dir
,
fixed_input_shape
=
None
):
model_scope
=
fluid
.
Scope
()
if
not
osp
.
exists
(
osp
.
join
(
model_dir
,
"model.yml"
)):
raise
Exception
(
"There's not model.yml in {}"
.
format
(
model_dir
))
with
open
(
osp
.
join
(
model_dir
,
"model.yml"
))
as
f
:
...
...
@@ -51,38 +52,40 @@ def load_model(model_dir, fixed_input_shape=None):
format
(
fixed_input_shape
))
model
.
fixed_input_shape
=
fixed_input_shape
if
status
==
"Normal"
or
\
status
==
"Prune"
or
status
==
"fluid.save"
:
startup_prog
=
fluid
.
Program
()
model
.
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
model
.
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
.
test_inputs
,
model
.
test_outputs
=
model
.
build_net
(
mode
=
'test'
)
model
.
test_prog
=
model
.
test_prog
.
clone
(
for_test
=
True
)
model
.
exe
.
run
(
startup_prog
)
if
status
==
"Prune"
:
from
.slim.prune
import
update_program
model
.
test_prog
=
update_program
(
model
.
test_prog
,
model_dir
,
model
.
places
[
0
])
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
load_dict
=
pickle
.
load
(
f
)
fluid
.
io
.
set_program_state
(
model
.
test_prog
,
load_dict
)
elif
status
==
"Infer"
or
\
status
==
"Quant"
or
status
==
"fluid.save_inference_model"
:
[
prog
,
input_names
,
outputs
]
=
fluid
.
io
.
load_inference_model
(
model_dir
,
model
.
exe
,
params_filename
=
'__params__'
)
model
.
test_prog
=
prog
test_outputs_info
=
info
[
'_ModelInputsOutputs'
][
'test_outputs'
]
model
.
test_inputs
=
OrderedDict
()
model
.
test_outputs
=
OrderedDict
()
for
name
in
input_names
:
model
.
test_inputs
[
name
]
=
model
.
test_prog
.
global_block
().
var
(
name
)
for
i
,
out
in
enumerate
(
outputs
):
var_desc
=
test_outputs_info
[
i
]
model
.
test_outputs
[
var_desc
[
0
]]
=
out
with
fluid
.
scope_guard
(
model_scope
):
if
status
==
"Normal"
or
\
status
==
"Prune"
or
status
==
"fluid.save"
:
startup_prog
=
fluid
.
Program
()
model
.
test_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
model
.
test_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
model
.
test_inputs
,
model
.
test_outputs
=
model
.
build_net
(
mode
=
'test'
)
model
.
test_prog
=
model
.
test_prog
.
clone
(
for_test
=
True
)
model
.
exe
.
run
(
startup_prog
)
if
status
==
"Prune"
:
from
.slim.prune
import
update_program
model
.
test_prog
=
update_program
(
model
.
test_prog
,
model_dir
,
model
.
places
[
0
])
import
pickle
with
open
(
osp
.
join
(
model_dir
,
'model.pdparams'
),
'rb'
)
as
f
:
load_dict
=
pickle
.
load
(
f
)
fluid
.
io
.
set_program_state
(
model
.
test_prog
,
load_dict
)
elif
status
==
"Infer"
or
\
status
==
"Quant"
or
status
==
"fluid.save_inference_model"
:
[
prog
,
input_names
,
outputs
]
=
fluid
.
io
.
load_inference_model
(
model_dir
,
model
.
exe
,
params_filename
=
'__params__'
)
model
.
test_prog
=
prog
test_outputs_info
=
info
[
'_ModelInputsOutputs'
][
'test_outputs'
]
model
.
test_inputs
=
OrderedDict
()
model
.
test_outputs
=
OrderedDict
()
for
name
in
input_names
:
model
.
test_inputs
[
name
]
=
model
.
test_prog
.
global_block
().
var
(
name
)
for
i
,
out
in
enumerate
(
outputs
):
var_desc
=
test_outputs_info
[
i
]
model
.
test_outputs
[
var_desc
[
0
]]
=
out
if
'Transforms'
in
info
:
transforms_mode
=
info
.
get
(
'TransformsMode'
,
'RGB'
)
# 固定模型的输入shape
...
...
@@ -107,6 +110,7 @@ def load_model(model_dir, fixed_input_shape=None):
model
.
__dict__
[
k
]
=
v
logging
.
info
(
"Model[{}] loaded."
.
format
(
info
[
'Model'
]))
model
.
scope
=
model_scope
model
.
trainable
=
False
model
.
status
=
status
return
model
...
...
paddlex/cv/models/mask_rcnn.py
浏览文件 @
3cb494f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -286,10 +286,12 @@ class MaskRCNN(FasterRCNN):
'im_info'
:
im_infos
,
'im_shape'
:
im_shapes
,
}
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
'bbox'
:
(
np
.
array
(
outputs
[
0
]),
outputs
[
0
].
recursive_sequence_lengths
()),
...
...
@@ -356,15 +358,16 @@ class MaskRCNN(FasterRCNN):
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im_resize_info
=
np
.
expand_dims
(
im_resize_info
,
axis
=
0
)
im_shape
=
np
.
expand_dims
(
im_shape
,
axis
=
0
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_info'
:
im_resize_info
,
'im_shape'
:
im_shape
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_info'
:
im_resize_info
,
'im_shape'
:
im_shape
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
outputs
)
...
...
paddlex/cv/models/slim/post_quantization.py
浏览文件 @
3cb494f3
...
...
@@ -85,13 +85,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self
.
_support_quantize_op_type
=
\
list
(
set
(
QuantizationTransformPass
.
_supported_quantizable_op_type
+
AddQuantDequantPass
.
_supported_quantizable_op_type
))
# Check inputs
assert
executor
is
not
None
,
"The executor cannot be None."
assert
batch_size
>
0
,
"The batch_size should be greater than 0."
assert
algo
in
self
.
_support_algo_type
,
\
"The algo should be KL, abs_max or min_max."
self
.
_executor
=
executor
self
.
_dataset
=
dataset
self
.
_batch_size
=
batch_size
...
...
@@ -154,20 +154,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
logging
.
info
(
"Start to run batch!"
)
for
data
in
self
.
_data_loader
():
start
=
time
.
time
()
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
)
with
fluid
.
scope_guard
(
self
.
_scope
):
self
.
_executor
.
run
(
program
=
self
.
_program
,
feed
=
data
,
fetch_list
=
self
.
_fetch_list
,
return_numpy
=
False
)
if
self
.
_algo
==
"KL"
:
self
.
_sample_data
(
batch_id
)
else
:
self
.
_sample_threshold
()
end
=
time
.
time
()
logging
.
debug
(
'[Run batch data] Batch={}/{}, time_each_batch={} s.'
.
format
(
str
(
batch_id
+
1
),
str
(
batch_ct
),
str
(
end
-
start
)))
logging
.
debug
(
'[Run batch data] Batch={}/{}, time_each_batch={} s.'
.
format
(
str
(
batch_id
+
1
),
str
(
batch_ct
),
str
(
end
-
start
)))
batch_id
+=
1
if
self
.
_batch_nums
and
batch_id
>=
self
.
_batch_nums
:
break
...
...
@@ -194,15 +193,16 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
Returns:
None
'''
feed_vars_names
=
[
var
.
name
for
var
in
self
.
_feed_list
]
fluid
.
io
.
save_inference_model
(
dirname
=
save_model_path
,
feeded_var_names
=
feed_vars_names
,
target_vars
=
self
.
_fetch_list
,
executor
=
self
.
_executor
,
params_filename
=
'__params__'
,
main_program
=
self
.
_program
)
with
fluid
.
scope_guard
(
self
.
_scope
):
feed_vars_names
=
[
var
.
name
for
var
in
self
.
_feed_list
]
fluid
.
io
.
save_inference_model
(
dirname
=
save_model_path
,
feeded_var_names
=
feed_vars_names
,
target_vars
=
self
.
_fetch_list
,
executor
=
self
.
_executor
,
params_filename
=
'__params__'
,
main_program
=
self
.
_program
)
def
_load_model_data
(
self
):
'''
Set data loader.
...
...
@@ -212,7 +212,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self
.
_data_loader
=
fluid
.
io
.
DataLoader
.
from_generator
(
feed_list
=
feed_vars
,
capacity
=
3
*
self
.
_batch_size
,
iterable
=
True
)
self
.
_data_loader
.
set_sample_list_generator
(
self
.
_dataset
.
generator
(
self
.
_batch_size
,
drop_last
=
True
),
self
.
_dataset
.
generator
(
self
.
_batch_size
,
drop_last
=
True
),
places
=
self
.
_place
)
def
_calculate_kl_threshold
(
self
):
...
...
@@ -235,10 +236,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
weight_threshold
.
append
(
abs_max_value
)
self
.
_quantized_var_kl_threshold
[
var_name
]
=
weight_threshold
end
=
time
.
time
()
logging
.
debug
(
'[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_weight_var_name
)),
str
(
end
-
start
)))
logging
.
debug
(
'[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_weight_var_name
)),
str
(
end
-
start
)))
ct
+=
1
ct
=
1
...
...
@@ -257,10 +260,12 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
self
.
_get_kl_scaling_factor
(
np
.
abs
(
sampling_data
))
end
=
time
.
time
()
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
ct
+=
1
else
:
for
var_name
in
self
.
_quantized_act_var_name
:
...
...
@@ -270,10 +275,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
self
.
_quantized_var_kl_threshold
[
var_name
]
=
\
self
.
_get_kl_scaling_factor
(
np
.
abs
(
self
.
_sampling_data
[
var_name
]))
end
=
time
.
time
()
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
logging
.
debug
(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.
format
(
str
(
ct
),
str
(
len
(
self
.
_quantized_act_var_name
)),
str
(
end
-
start
)))
ct
+=
1
\ No newline at end of file
paddlex/cv/models/yolo_v3.py
浏览文件 @
3cb494f3
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
...
@@ -313,10 +313,12 @@ class YOLOv3(BaseAPI):
images
=
np
.
array
([
d
[
0
]
for
d
in
data
])
im_sizes
=
np
.
array
([
d
[
1
]
for
d
in
data
])
feed_data
=
{
'image'
:
images
,
'im_size'
:
im_sizes
}
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
[
feed_data
],
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
)
res
=
{
'bbox'
:
(
np
.
array
(
outputs
[
0
]),
outputs
[
0
].
recursive_sequence_lengths
())
...
...
@@ -366,12 +368,13 @@ class YOLOv3(BaseAPI):
im
,
im_size
=
self
.
test_transforms
(
img_file
)
im
=
np
.
expand_dims
(
im
,
axis
=
0
)
im_size
=
np
.
expand_dims
(
im_size
,
axis
=
0
)
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
with
fluid
.
scope_guard
(
self
.
scope
):
outputs
=
self
.
exe
.
run
(
self
.
test_prog
,
feed
=
{
'image'
:
im
,
'im_size'
:
im_size
},
fetch_list
=
list
(
self
.
test_outputs
.
values
()),
return_numpy
=
False
,
use_program_cache
=
True
)
res
=
{
k
:
(
np
.
array
(
v
),
v
.
recursive_sequence_lengths
())
for
k
,
v
in
zip
(
list
(
self
.
test_outputs
.
keys
()),
outputs
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录