Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9fd90674
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9fd90674
编写于
4月 01, 2020
作者:
W
Wojciech Uss
提交者:
GitHub
4月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
handle conv2d activations in older QAT models (#23202)
上级
21d95be0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
148 addition
and
0 deletion
+148
-0
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
.../fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
+18
-0
python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py
...le/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py
+130
-0
未找到文件。
python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py
浏览文件 @
9fd90674
...
...
@@ -294,6 +294,23 @@ class Qat2Int8MkldnnPass(object):
tensor
=
self
.
_scope
.
find_var
(
name
).
get_tensor
()
tensor
.
set
(
array
,
self
.
_place
)
def
_update_activations
(
self
,
graph
):
for
op
in
graph
.
all_op_nodes
():
if
op
.
name
()
in
self
.
_conv_ops
and
not
op
.
op
().
has_attr
(
"fuse_activation"
):
activation
=
""
if
op
.
op
().
has_attr
(
"fuse_relu"
)
and
op
.
op
().
attr
(
"fuse_relu"
):
activation
=
"relu"
elif
op
.
op
().
has_attr
(
"fuse_brelu"
)
and
op
.
op
().
attr
(
"fuse_brelu"
):
activation
=
"relu6"
alpha
=
6.0
if
op
.
op
().
has_attr
(
"fuse_brelu_threshold"
):
alpha
=
op
.
op
().
attr
(
"fuse_brelu_threshold"
)
op
.
set_attr
(
"fuse_alpha"
,
alpha
)
op
.
set_attr
(
"fuse_activation"
,
activation
)
return
graph
def
_remove_ctrl_vars
(
self
,
graph
):
remove_ctr_vars
=
set
()
for
node
in
graph
.
all_var_nodes
():
...
...
@@ -303,6 +320,7 @@ class Qat2Int8MkldnnPass(object):
return
graph
def
_optimize_fp32_graph
(
self
,
graph
):
graph
=
self
.
_update_activations
(
graph
)
graph
=
self
.
_remove_ctrl_vars
(
graph
)
graph
=
self
.
_apply_pass
(
graph
,
'mkldnn_placement_pass'
,
[
'mkldnn_enabled_op_types'
],
[
set
()])
...
...
python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py
0 → 100644
浏览文件 @
9fd90674
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.framework
import
IrGraph
from
paddle.fluid.contrib.slim.quantization
import
Qat2Int8MkldnnPass
class
TestQat2Int8MkldnnPass
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
scope
=
fluid
.
Scope
()
self
.
place
=
fluid
.
CPUPlace
()
self
.
dtype
=
np
.
float32
self
.
use_cudnn
=
False
self
.
use_mkldnn
=
True
self
.
data_format
=
"ANYLAYOUT"
self
.
pad
=
[
0
,
0
]
self
.
stride
=
[
1
,
1
]
self
.
dilations
=
[
1
,
1
]
self
.
groups
=
1
self
.
input_size
=
[
1
,
3
,
5
,
5
]
self
.
filter_size
=
[
16
,
3
,
3
,
3
]
self
.
filter_size2
=
[
1
,
16
,
2
,
2
]
self
.
conv_output_size
=
[
1
,
16
,
3
,
3
]
self
.
conv_output2_size
=
[
1
,
1
,
2
,
2
]
self
.
input
=
np
.
random
.
random
(
self
.
input_size
).
astype
(
self
.
dtype
)
self
.
filter
=
np
.
random
.
random
(
self
.
filter_size
).
astype
(
self
.
dtype
)
self
.
filter2
=
np
.
random
.
random
(
self
.
filter_size2
).
astype
(
self
.
dtype
)
self
.
conv_output
=
np
.
ndarray
(
self
.
conv_output_size
).
astype
(
self
.
dtype
)
self
.
conv_output2
=
np
.
ndarray
(
self
.
conv_output2_size
).
astype
(
self
.
dtype
)
self
.
quantized_ops
=
'conv2d'
self
.
variables
=
{
"input"
:
self
.
input
,
"filter"
:
self
.
filter
,
"filter2"
:
self
.
filter2
,
"conv_output"
:
self
.
conv_output
,
"conv_output2"
:
self
.
conv_output2
,
}
def
prepare_program
(
self
,
program
):
block
=
program
.
global_block
()
for
name
in
self
.
variables
:
block
.
create_var
(
name
=
name
,
dtype
=
"float32"
,
shape
=
self
.
variables
[
name
].
shape
)
conv2d_op1
=
block
.
append_op
(
type
=
"conv2d"
,
inputs
=
{
"Input"
:
block
.
var
(
'input'
),
'Filter'
:
block
.
var
(
'filter'
)
},
outputs
=
{
"Output"
:
block
.
var
(
'conv_output'
)},
attrs
=
{
'strides'
:
self
.
stride
,
'paddings'
:
self
.
pad
,
'groups'
:
self
.
groups
,
'dilations'
:
self
.
dilations
,
'use_cudnn'
:
self
.
use_cudnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'data_format'
:
self
.
data_format
,
'fuse_relu'
:
True
})
conv2d_op2
=
block
.
append_op
(
type
=
"conv2d"
,
inputs
=
{
"Input"
:
block
.
var
(
'conv_output'
),
'Filter'
:
block
.
var
(
'filter2'
)
},
outputs
=
{
"Output"
:
block
.
var
(
'conv_output2'
)},
attrs
=
{
'strides'
:
self
.
stride
,
'paddings'
:
self
.
pad
,
'groups'
:
self
.
groups
,
'dilations'
:
self
.
dilations
,
'use_cudnn'
:
self
.
use_cudnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'data_format'
:
self
.
data_format
,
'fuse_brelu'
:
True
})
def
remove_fuse_activation_attribute
(
self
,
graph
):
for
op
in
graph
.
all_op_nodes
():
op
.
op
().
remove_attr
(
"fuse_activation"
)
return
graph
def
check_graph_before_pass
(
self
,
graph
):
for
op
in
graph
.
all_op_nodes
():
self
.
assertFalse
(
op
.
op
().
has_attr
(
"fuse_activation"
))
def
check_graph_after_pass
(
self
,
graph
):
for
op
in
graph
.
all_op_nodes
():
self
.
assertTrue
(
op
.
op
().
has_attr
(
"fuse_activation"
))
if
op
.
op
().
has_attr
(
"fuse_relu"
)
and
op
.
op
().
attr
(
"fuse_relu"
):
self
.
assertTrue
(
op
.
op
().
attr
(
"fuse_activation"
)
==
"relu"
)
if
op
.
op
().
has_attr
(
"fuse_brelu"
)
and
op
.
op
().
attr
(
"fuse_brelu"
):
self
.
assertTrue
(
op
.
op
().
attr
(
"fuse_activation"
)
==
"relu6"
)
def
test_qat_update_activation
(
self
):
program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
program
):
self
.
prepare_program
(
program
)
graph
=
IrGraph
(
core
.
Graph
(
program
.
desc
),
for_test
=
True
)
graph
=
self
.
remove_fuse_activation_attribute
(
graph
)
self
.
check_graph_before_pass
(
graph
)
qat2_int8_mkldnn_pass
=
Qat2Int8MkldnnPass
(
self
.
quantized_ops
,
_scope
=
self
.
scope
,
_place
=
self
.
place
,
_core
=
core
,
_debug
=
False
)
graph
=
qat2_int8_mkldnn_pass
.
_update_activations
(
graph
)
self
.
check_graph_after_pass
(
graph
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录