Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ee44bcdd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ee44bcdd
编写于
7月 13, 2020
作者:
Z
Zhen Wang
提交者:
GitHub
7月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add more unit tests for imperative qat. test=develop (#25486)
上级
fd961b0d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
166 addition
and
0 deletion
+166
-0
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
...on/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
+166
-0
未找到文件。
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
浏览文件 @
ee44bcdd
...
...
@@ -254,6 +254,172 @@ class TestImperativeQat(unittest.TestCase):
np
.
allclose
(
after_save
,
before_save
.
numpy
()),
msg
=
'Failed to save the inference quantized model.'
)
def
test_qat_acc
(
self
):
def
_build_static_lenet
(
main
,
startup
,
is_test
=
False
,
seed
=
1000
):
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
main
.
random_seed
=
seed
startup
.
random_seed
=
seed
img
=
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
label
=
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
)
prediction
=
StaticLenet
(
img
)
if
not
is_test
:
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
else
:
avg_loss
=
prediction
return
img
,
label
,
avg_loss
reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
32
,
drop_last
=
True
)
weight_quantize_type
=
'abs_max'
activation_quant_type
=
'moving_average_abs_max'
param_init_map
=
{}
seed
=
1000
lr
=
0.1
# imperative train
_logger
.
info
(
"--------------------------dynamic graph qat--------------------------"
)
imperative_qat
=
ImperativeQuantAware
(
weight_quantize_type
=
weight_quantize_type
,
activation_quantize_type
=
activation_quant_type
)
with
fluid
.
dygraph
.
guard
():
np
.
random
.
seed
(
seed
)
fluid
.
default_main_program
().
random_seed
=
seed
fluid
.
default_startup_program
().
random_seed
=
seed
lenet
=
ImperativeLenet
()
fixed_state
=
{}
for
name
,
param
in
lenet
.
named_parameters
():
p_shape
=
param
.
numpy
().
shape
p_value
=
param
.
numpy
()
if
name
.
endswith
(
"bias"
):
value
=
np
.
zeros_like
(
p_value
).
astype
(
'float32'
)
else
:
value
=
np
.
random
.
normal
(
loc
=
0.0
,
scale
=
0.01
,
size
=
np
.
product
(
p_shape
)).
reshape
(
p_shape
).
astype
(
'float32'
)
fixed_state
[
name
]
=
value
param_init_map
[
param
.
name
]
=
value
lenet
.
set_dict
(
fixed_state
)
imperative_qat
.
quantize
(
lenet
)
adam
=
AdamOptimizer
(
learning_rate
=
lr
,
parameter_list
=
lenet
.
parameters
())
dynamic_loss_rec
=
[]
lenet
.
train
()
for
batch_id
,
data
in
enumerate
(
reader
()):
x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
fluid
.
dygraph
.
to_variable
(
x_data
)
label
=
fluid
.
dygraph
.
to_variable
(
y_data
)
out
=
lenet
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
lenet
.
clear_gradients
()
dynamic_loss_rec
.
append
(
avg_loss
.
numpy
()[
0
])
if
batch_id
%
100
==
0
:
_logger
.
info
(
'{}: {}'
.
format
(
'loss'
,
avg_loss
.
numpy
()))
imperative_qat
.
save_quantized_model
(
dirname
=
"./dynamic_mnist"
,
model
=
lenet
,
input_shape
=
[(
1
,
28
,
28
)],
input_dtype
=
[
'float32'
],
feed
=
[
0
],
fetch
=
[
0
])
# static graph train
_logger
.
info
(
"--------------------------static graph qat--------------------------"
)
static_loss_rec
=
[]
if
core
.
is_compiled_with_cuda
():
place
=
core
.
CUDAPlace
(
0
)
else
:
place
=
core
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
main
=
fluid
.
Program
()
infer
=
fluid
.
Program
()
startup
=
fluid
.
Program
()
static_img
,
static_label
,
static_loss
=
_build_static_lenet
(
main
,
startup
,
False
,
seed
)
infer_img
,
_
,
infer_pre
=
_build_static_lenet
(
infer
,
startup
,
True
,
seed
)
with
fluid
.
unique_name
.
guard
():
with
fluid
.
program_guard
(
main
,
startup
):
opt
=
AdamOptimizer
(
learning_rate
=
lr
)
opt
.
minimize
(
static_loss
)
scope
=
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup
)
for
param
in
main
.
all_parameters
():
param_tensor
=
scope
.
var
(
param
.
name
).
get_tensor
()
param_tensor
.
set
(
param_init_map
[
param
.
name
],
place
)
main_graph
=
IrGraph
(
core
.
Graph
(
main
.
desc
),
for_test
=
False
)
infer_graph
=
IrGraph
(
core
.
Graph
(
infer
.
desc
),
for_test
=
True
)
transform_pass
=
QuantizationTransformPass
(
scope
=
scope
,
place
=
place
,
activation_quantize_type
=
activation_quant_type
,
weight_quantize_type
=
weight_quantize_type
,
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
])
transform_pass
.
apply
(
main_graph
)
transform_pass
.
apply
(
infer_graph
)
build_strategy
=
fluid
.
BuildStrategy
()
build_strategy
.
fuse_all_reduce_ops
=
False
binary
=
fluid
.
CompiledProgram
(
main_graph
.
graph
).
with_data_parallel
(
loss_name
=
static_loss
.
name
,
build_strategy
=
build_strategy
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
static_img
,
static_label
],
place
=
place
)
with
fluid
.
scope_guard
(
scope
):
for
batch_id
,
data
in
enumerate
(
reader
()):
loss_v
,
=
exe
.
run
(
binary
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
static_loss
])
static_loss_rec
.
append
(
loss_v
[
0
])
if
batch_id
%
100
==
0
:
_logger
.
info
(
'{}: {}'
.
format
(
'loss'
,
loss_v
))
save_program
=
infer_graph
.
to_program
()
with
fluid
.
scope_guard
(
scope
):
fluid
.
io
.
save_inference_model
(
"./static_mnist"
,
[
infer_img
.
name
],
[
infer_pre
],
exe
,
save_program
)
rtol
=
1e-05
atol
=
1e-08
for
i
,
(
loss_d
,
loss_s
)
in
enumerate
(
zip
(
dynamic_loss_rec
,
static_loss_rec
)):
diff
=
np
.
abs
(
loss_d
-
loss_s
)
if
diff
>
(
atol
+
rtol
*
np
.
abs
(
loss_s
)):
_logger
.
info
(
"diff({}) at {}, dynamic loss = {}, static loss = {}"
.
format
(
diff
,
i
,
loss_d
,
loss_s
))
break
self
.
assertTrue
(
np
.
allclose
(
np
.
array
(
dynamic_loss_rec
),
np
.
array
(
static_loss_rec
),
rtol
=
rtol
,
atol
=
atol
,
equal_nan
=
True
),
msg
=
'Failed to do the imperative qat.'
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录