Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
04f20b83
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
04f20b83
编写于
4月 21, 2022
作者:
B
baoachun
提交者:
GitHub
4月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add mkldnn int8 pass [step1] (#41579) (#42045)
上级
d24a402e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
751 addition
and
0 deletion
+751
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
+77
-0
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
...le/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
+582
-0
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h
+91
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
04f20b83
...
...
@@ -140,6 +140,7 @@ if(WITH_MKLDNN)
pass_library
(
batch_norm_act_fuse_pass inference DIR mkldnn
)
pass_library
(
multi_gru_fuse_pass inference DIR mkldnn
)
pass_library
(
multi_gru_seq_fuse_pass inference DIR mkldnn
)
pass_library
(
quant_dequant_mkldnn_pass inference DIR mkldnn
)
endif
()
if
(
WITH_IPU
)
...
...
paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h
0 → 100644
浏览文件 @
04f20b83
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/graph_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
void
SaveInfoInTheFirstOp
(
ir
::
Graph
*
graph
,
const
std
::
string
&
flag
,
const
std
::
string
&
key_suffix
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>&
info_map
)
{
VLOG
(
3
)
<<
"save variables in the first op's attr"
;
const
std
::
string
suffix
=
"_"
+
key_suffix
+
"_"
+
flag
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
()
||
op_node
->
Op
()
->
Type
()
==
"feed"
||
op_node
->
Op
()
->
Type
()
==
"fetch"
)
continue
;
op_node
->
Op
()
->
SetAttr
(
flag
,
true
);
for
(
auto
iter
=
info_map
.
begin
();
iter
!=
info_map
.
end
();
++
iter
)
{
op_node
->
Op
()
->
SetAttr
(
iter
->
first
+
suffix
,
iter
->
second
);
}
break
;
}
}
static
void
GetInfoFromTheFirstOp
(
ir
::
Graph
*
graph
,
const
std
::
string
&
flag
,
const
std
::
string
&
key_suffix
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
info_map
)
{
VLOG
(
3
)
<<
"get variables from the first op's attr"
;
const
std
::
string
suffix
=
"_"
+
key_suffix
+
"_"
+
flag
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
()
||
op_node
->
Op
()
->
Type
()
==
"feed"
||
op_node
->
Op
()
->
Type
()
==
"fetch"
)
continue
;
auto
*
op_desc
=
op_node
->
Op
();
if
(
op_desc
->
GetAttrIfExists
<
bool
>
(
flag
))
{
op_desc
->
RemoveAttr
(
flag
);
std
::
vector
<
std
::
string
>
attr_names
=
op_desc
->
AttrNames
();
for
(
auto
fake_name
:
attr_names
)
{
size_t
pos
=
fake_name
.
find
(
suffix
);
if
(
pos
!=
std
::
string
::
npos
)
{
std
::
string
name
=
fake_name
.
substr
(
0
,
pos
);
auto
scales_vector
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
op_desc
->
GetAttr
(
fake_name
));
info_map
->
insert
(
std
::
make_pair
(
name
,
scales_vector
));
op_desc
->
RemoveAttr
(
fake_name
);
}
}
break
;
}
}
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.cc
0 → 100644
浏览文件 @
04f20b83
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
void
QuantDequantMkldnnPass
::
MarkSkipQuantizedOps
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
skip_ops
)
const
{
VLOG
(
3
)
<<
"mark skip quantized ops"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
skip_ops
.
count
(
op_node
->
Name
()))
{
auto
*
op_desc
=
op_node
->
Op
();
if
(
!
op_desc
->
HasAttr
(
"quantization_type"
))
{
bool
is_quantized_op
=
true
;
for
(
auto
*
node_input
:
op_node
->
inputs
)
{
for
(
auto
*
node_input_input
:
node_input
->
inputs
)
{
if
(
!
node_input_input
->
IsOp
())
continue
;
if
(
node_input_input
->
Name
().
find
(
"quantize_dequantize"
)
==
std
::
string
::
npos
)
{
is_quantized_op
=
false
;
break
;
}
}
if
(
!
is_quantized_op
)
break
;
}
if
(
!
is_quantized_op
)
{
op_node
->
Op
()
->
SetAttr
(
"skip_quant"
,
1
);
}
}
}
}
}
void
QuantDequantMkldnnPass
::
MarkSkipQuantizedPool2d
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"mark avg pool2d as skip quantized op"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
op_node
->
Name
()
==
"pool2d"
)
{
auto
*
op_desc
=
op_node
->
Op
();
auto
pool_type
=
BOOST_GET_CONST
(
std
::
string
,
op_desc
->
GetAttr
(
"pooling_type"
));
if
(
pool_type
==
"avg"
)
{
op_node
->
Op
()
->
SetAttr
(
"skip_quant"
,
1
);
}
}
}
}
void
QuantDequantMkldnnPass
::
CollectInfoFromFake
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_set
<
std
::
string
>&
fake_dequantize_types
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
weight_thresholds
)
const
{
VLOG
(
3
)
<<
"gather weight_thresholds from fake dequantized ops"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
fake_dequantize_types
.
count
(
op_node
->
Name
()))
{
auto
*
op_desc
=
op_node
->
Op
();
auto
x_var_name
=
op_desc
->
Input
(
"X"
)[
0
];
if
(
op_desc
->
HasAttr
(
"max_range"
))
{
const
float
max_range
=
BOOST_GET_CONST
(
float
,
op_desc
->
GetAttr
(
"max_range"
));
std
::
vector
<
float
>
thresholds
=
{
127
*
127
/
max_range
};
weight_thresholds
->
insert
(
std
::
make_pair
(
x_var_name
,
thresholds
));
}
else
{
auto
scale_name
=
op_desc
->
Input
(
"Scales"
)[
0
];
auto
*
var
=
scope
->
FindVar
(
scale_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"The Scales variable [%s] of dequantize op is not found."
,
var
));
auto
*
scale_tensor
=
var
->
GetMutable
<
LoDTensor
>
();
auto
*
scale_data
=
scale_tensor
->
data
<
float
>
();
std
::
vector
<
float
>
thresholds
{};
for
(
int
i
=
0
;
i
<
scale_tensor
->
numel
();
i
++
)
{
thresholds
.
push_back
(
scale_data
[
i
]);
}
weight_thresholds
->
insert
(
std
::
make_pair
(
x_var_name
,
thresholds
));
}
}
}
}
void
QuantDequantMkldnnPass
::
CollectInputScalesFromFake
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_types
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
var_quant_scales
)
const
{
VLOG
(
3
)
<<
"gather input scales from fake quantized ops"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
op_node
->
Name
()
==
"fake_quantize_dequantize_moving_average_abs_max"
||
fake_quantize_types
.
count
(
op_node
->
Name
()))
{
auto
*
op_desc
=
op_node
->
Op
();
const
int
bit_length
=
BOOST_GET_CONST
(
int
,
op_desc
->
GetAttr
(
"bit_length"
));
PADDLE_ENFORCE_EQ
(
bit_length
,
8
,
platform
::
errors
::
InvalidArgument
(
"Unsupported number quantization "
"bits: %d, only 8 is supported now."
,
bit_length
));
auto
x_var_name
=
op_desc
->
Input
(
"X"
)[
0
];
auto
scale_name
=
op_desc
->
Input
(
"InScale"
)[
0
];
auto
out_var_name
=
op_desc
->
Output
(
"Out"
)[
0
];
auto
*
var
=
scope
->
FindVar
(
scale_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"The InScale variable [%s] of quantize op is not found."
,
var
));
auto
*
scale_tensor
=
var
->
GetMutable
<
LoDTensor
>
();
auto
*
scale_data
=
scale_tensor
->
data
<
float
>
();
float
scale
=
1.0
/
scale_data
[
0
];
if
(
std
::
isinf
(
scale
)
||
std
::
isnan
(
scale
))
{
scale
=
0.0
;
}
if
(
!
var_quant_scales
->
count
(
x_var_name
))
{
std
::
vector
<
float
>
scale_v
=
{
scale
};
var_quant_scales
->
insert
(
std
::
make_pair
(
x_var_name
,
scale_v
));
}
if
(
!
var_quant_scales
->
count
(
out_var_name
))
{
std
::
vector
<
float
>
scale_v
=
{
scale
};
var_quant_scales
->
insert
(
std
::
make_pair
(
out_var_name
,
scale_v
));
}
}
}
}
void
QuantDequantMkldnnPass
::
CollectOutputScalesFromAttr
(
ir
::
Graph
*
graph
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
var_quant_scales
)
const
{
VLOG
(
3
)
<<
"gather output scales from op's attr"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
auto
*
op_desc
=
op_node
->
Op
();
if
(
op_desc
->
HasAttr
(
"out_threshold"
))
{
const
float
attr_scale
=
BOOST_GET_CONST
(
float
,
op_desc
->
GetAttr
(
"out_threshold"
));
if
(
attr_scale
==
0.0
)
continue
;
float
scale
=
1.0
/
attr_scale
;
std
::
vector
<
float
>
scale_v
=
{
scale
};
auto
var_name_map
=
op_desc
->
Outputs
();
for
(
auto
iter
=
var_name_map
.
begin
();
iter
!=
var_name_map
.
end
();
++
iter
)
{
for
(
auto
var_name
:
iter
->
second
)
{
var_quant_scales
->
insert
(
std
::
make_pair
(
var_name
,
scale_v
));
}
}
}
}
}
void
QuantDequantMkldnnPass
::
CollectFakeQuantizeOps
(
ir
::
Graph
*
graph
,
Node
*
op_node
,
std
::
unordered_set
<
const
Node
*>*
nodes2rm
)
const
{
auto
*
op_desc
=
op_node
->
Op
();
auto
x_var_name
=
op_desc
->
Input
(
"X"
)[
0
];
auto
in_scale_name
=
op_desc
->
Input
(
"InScale"
)[
0
];
auto
out_var_name
=
op_desc
->
Output
(
"Out"
)[
0
];
auto
out_scale_name
=
op_desc
->
Output
(
"OutScale"
)[
0
];
Node
*
fake_quant_in
=
nullptr
;
Node
*
fake_quant_in_scale
=
nullptr
;
for
(
auto
*
node_input
:
op_node
->
inputs
)
{
if
(
node_input
->
Name
()
==
x_var_name
)
{
fake_quant_in
=
node_input
;
break
;
}
else
if
(
node_input
->
Name
()
==
in_scale_name
)
{
fake_quant_in_scale
=
node_input
;
break
;
}
}
Node
*
fake_quant_out
=
nullptr
;
Node
*
fake_quant_out_scale
=
nullptr
;
for
(
auto
*
node_output
:
op_node
->
outputs
)
{
if
(
node_output
->
Name
()
==
out_var_name
)
{
fake_quant_out
=
node_output
;
break
;
}
else
if
(
node_output
->
Name
()
==
out_scale_name
)
{
fake_quant_out_scale
=
node_output
;
break
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
fake_quant_in
,
platform
::
errors
::
NotFound
(
"The input var [%s] of quantize op is not found."
,
x_var_name
));
PADDLE_ENFORCE_NOT_NULL
(
fake_quant_out
,
platform
::
errors
::
NotFound
(
"The output var [%s] of quantize op is not found."
,
out_var_name
));
std
::
string
input_act_name
=
fake_quant_in
->
Var
()
->
Name
();
std
::
string
output_act_name
=
fake_quant_out
->
Var
()
->
Name
();
auto
outlinks
=
fake_quant_out
->
outputs
;
for
(
auto
*
next_node
:
outlinks
)
{
if
(
!
next_node
->
IsOp
())
continue
;
next_node
->
Op
()
->
RenameInput
(
output_act_name
,
input_act_name
);
IR_NODE_LINK_TO
(
fake_quant_in
,
next_node
);
}
nodes2rm
->
insert
(
op_node
);
nodes2rm
->
insert
(
fake_quant_in_scale
);
nodes2rm
->
insert
(
fake_quant_out
);
nodes2rm
->
insert
(
fake_quant_out_scale
);
}
void
QuantDequantMkldnnPass
::
CollectFakeDequantizeOps
(
ir
::
Graph
*
graph
,
Node
*
op_node
,
std
::
unordered_set
<
const
Node
*>*
nodes2rm
)
const
{
auto
*
op_desc
=
op_node
->
Op
();
auto
x_var_name
=
op_desc
->
Input
(
"X"
)[
0
];
auto
out_var_name
=
op_desc
->
Output
(
"Out"
)[
0
];
Node
*
fake_dequant_in
=
nullptr
;
for
(
auto
*
node_input
:
op_node
->
inputs
)
{
if
(
node_input
->
Name
()
==
x_var_name
)
{
fake_dequant_in
=
node_input
;
break
;
}
}
Node
*
fake_dequant_out
=
nullptr
;
for
(
auto
*
node_output
:
op_node
->
outputs
)
{
if
(
node_output
->
Name
()
==
out_var_name
)
{
fake_dequant_out
=
node_output
;
break
;
}
}
PADDLE_ENFORCE_NOT_NULL
(
fake_dequant_in
,
platform
::
errors
::
NotFound
(
"The input var [%s] of dequantize op is not found."
,
x_var_name
));
PADDLE_ENFORCE_NOT_NULL
(
fake_dequant_out
,
platform
::
errors
::
NotFound
(
"The output var [%s] of dequantize op is not found."
,
out_var_name
));
std
::
string
input_act_name
=
fake_dequant_in
->
Var
()
->
Name
();
std
::
string
output_act_name
=
fake_dequant_out
->
Var
()
->
Name
();
auto
outlinks
=
fake_dequant_out
->
outputs
;
for
(
auto
*
next_node
:
outlinks
)
{
next_node
->
Op
()
->
RenameInput
(
output_act_name
,
input_act_name
);
IR_NODE_LINK_TO
(
fake_dequant_in
,
next_node
);
}
nodes2rm
->
insert
(
op_node
);
nodes2rm
->
insert
(
fake_dequant_out
);
}
void
QuantDequantMkldnnPass
::
RemoveFakeOps
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_types
,
const
std
::
unordered_set
<
std
::
string
>&
fake_dequantize_types
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_dequantize_types
)
const
{
VLOG
(
3
)
<<
"remove fake quantize and dequantize ops"
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{};
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
fake_quantize_types
.
count
(
op_node
->
Name
()))
{
CollectFakeQuantizeOps
(
graph
,
op_node
,
&
nodes2rm
);
}
else
if
(
fake_dequantize_types
.
count
(
op_node
->
Name
()))
{
CollectFakeDequantizeOps
(
graph
,
op_node
,
&
nodes2rm
);
}
else
if
(
fake_quantize_dequantize_types
.
count
(
op_node
->
Name
()))
{
CollectFakeDequantizeOps
(
graph
,
op_node
,
&
nodes2rm
);
}
}
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
void
QuantDequantMkldnnPass
::
TransposeWeight
(
Tensor
*
input
)
const
{
const
auto
in_dims
=
input
->
dims
();
std
::
vector
<
int
>
out_dim_v
;
std
::
vector
<
int
>
axis
;
for
(
int
i
=
in_dims
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
axis
.
push_back
(
i
);
out_dim_v
.
push_back
(
in_dims
[
i
]);
}
const
auto
out_dims
=
phi
::
make_ddim
(
out_dim_v
);
const
int
rank
=
axis
.
size
();
auto
in_stride
=
phi
::
stride
(
in_dims
);
auto
out_stride
=
phi
::
stride
(
out_dims
);
const
int
count
=
input
->
numel
();
Tensor
trans_tensor
;
trans_tensor
.
Resize
(
out_dims
);
float
*
trans_data
=
trans_tensor
.
mutable_data
<
float
>
(
platform
::
CPUPlace
());
float
*
in_data
=
input
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int64_t
out_idx
=
0
;
out_idx
<
count
;
++
out_idx
)
{
int64_t
in_idx
=
0
;
int64_t
tmp_idx
=
out_idx
;
for
(
int
i
=
0
;
i
<
rank
;
++
i
)
{
const
int64_t
coordinate
=
tmp_idx
/
out_stride
[
i
];
tmp_idx
-=
coordinate
*
out_stride
[
i
];
in_idx
+=
coordinate
*
in_stride
[
axis
[
i
]];
}
trans_data
[
out_idx
]
=
in_data
[
in_idx
];
}
input
->
Resize
(
out_dims
);
for
(
int
i
=
0
;
i
<
input
->
numel
();
i
++
)
{
in_data
[
i
]
=
trans_data
[
i
];
}
}
bool
QuantDequantMkldnnPass
::
IsInt8Weight
(
Node
*
op_node
,
Scope
*
scope
,
const
std
::
string
&
weight_name
)
const
{
auto
*
op_desc
=
op_node
->
Op
();
auto
var_name
=
op_desc
->
Input
(
weight_name
)[
0
];
auto
*
var
=
scope
->
FindVar
(
var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"The input persistable [%s] var of [%s] op is not found."
,
var_name
,
op_desc
->
Type
()));
auto
*
weight_tensor
=
var
->
GetMutable
<
LoDTensor
>
();
auto
*
weight_data
=
weight_tensor
->
data
<
float
>
();
bool
is_int8
=
true
;
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
if
(
weight_data
[
i
]
-
static_cast
<
int
>
(
weight_data
[
i
])
!=
0
)
{
is_int8
=
false
;
break
;
}
}
return
is_int8
;
}
void
QuantDequantMkldnnPass
::
DequantizeOpWeights
(
Node
*
op_node
,
Scope
*
scope
,
const
std
::
string
&
weight_name
,
const
std
::
string
&
output_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>&
weight_thresholds
)
const
{
auto
*
op_desc
=
op_node
->
Op
();
std
::
string
weight_var_name
=
op_desc
->
Input
(
weight_name
)[
0
];
std
::
string
output_var_name
=
op_desc
->
Output
(
output_name
)[
0
];
std
::
vector
<
float
>
scales
;
auto
iter
=
weight_thresholds
.
find
(
output_var_name
);
if
(
iter
!=
weight_thresholds
.
end
())
{
scales
=
iter
->
second
;
}
else
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Could not find threshold information for [%s] var, please check if "
"the model is correct."
,
output_var_name
));
}
auto
*
var
=
scope
->
FindVar
(
weight_var_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
NotFound
(
"The input persistable [%s] var of [%s] op is not found."
,
weight_var_name
,
op_desc
->
Type
()));
auto
*
weight_tensor
=
var
->
GetMutable
<
LoDTensor
>
();
const
auto
weight_dims
=
weight_tensor
->
dims
();
const
int
size
=
scales
.
size
();
if
(
size
==
1
||
size
==
weight_dims
[
0
])
{
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data
[
i
]
/=
127
;
}
TransposeWeight
(
weight_tensor
);
if
(
size
==
1
)
{
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data
[
i
]
*=
scales
[
0
];
}
}
else
{
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data
[
i
]
*=
scales
[
i
%
size
];
}
}
TransposeWeight
(
weight_tensor
);
}
else
if
(
weight_dims
.
size
()
>
1
&&
size
==
weight_dims
[
1
])
{
auto
*
weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data
[
i
]
/=
127
;
}
int
step_n
=
1
;
for
(
int
i
=
1
;
i
<
weight_dims
.
size
();
i
++
)
{
step_n
*=
weight_dims
[
i
];
}
int
step_c
=
step_n
/
size
;
for
(
int
i
=
0
;
i
<
weight_dims
[
0
];
i
++
)
{
int
begin_n
=
i
*
step_n
;
for
(
int
j
=
begin_n
;
j
<
begin_n
+
step_n
;
j
++
)
{
for
(
int
k
=
0
;
k
<
size
;
k
++
)
{
int
begin_c
=
k
*
step_c
;
for
(
int
m
=
begin_c
;
m
<
begin_c
+
step_c
;
m
++
)
{
weight_data
[
m
]
*=
scales
[
k
];
}
}
}
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The size of weight scales vector (%d) does not "
"match the dimensions (%d) of the weights tensor %s."
,
size
,
weight_tensor
->
dims
().
size
(),
weight_var_name
));
}
weight_tensor
->
Resize
(
weight_dims
);
}
void
QuantDequantMkldnnPass
::
DequantizeWeights
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>&
weight_thresholds
)
const
{
VLOG
(
3
)
<<
"dequantize weight for ops which has weight"
;
if
(
weight_thresholds
.
empty
())
{
VLOG
(
3
)
<<
"No need to dequantize weights because weight_thresholds is empty."
;
return
;
}
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
op_node
->
Name
()
==
"conv2d"
||
op_node
->
Name
()
==
"depthwise_conv2d"
)
{
if
(
IsInt8Weight
(
op_node
,
scope
,
"Filter"
))
{
DequantizeOpWeights
(
op_node
,
scope
,
"Filter"
,
"Output"
,
weight_thresholds
);
}
}
else
if
(
op_node
->
Name
()
==
"mul"
||
op_node
->
Name
()
==
"matmul"
||
op_node
->
Name
()
==
"matmul_v2"
)
{
if
(
IsInt8Weight
(
op_node
,
scope
,
"Y"
))
{
DequantizeOpWeights
(
op_node
,
scope
,
"Y"
,
"Out"
,
weight_thresholds
);
}
}
}
}
void
QuantDequantMkldnnPass
::
UpdateActivations
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"update conv2d or depthwise_conv2d fused activation"
;
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
!
op_node
->
IsOp
())
continue
;
if
(
op_node
->
Name
()
==
"conv2d"
||
op_node
->
Name
()
==
"depthwise_conv2d"
)
{
auto
*
op_desc
=
op_node
->
Op
();
if
(
!
op_desc
->
HasAttr
(
"fuse_activation"
))
{
std
::
string
activation
;
if
(
op_desc
->
GetAttrIfExists
<
bool
>
(
"fuse_relu"
))
{
activation
=
"relu"
;
}
else
if
(
op_desc
->
GetAttrIfExists
<
bool
>
(
"fuse_brelu"
))
{
activation
=
"relu6"
;
float
alpha
=
6.0
;
if
(
op_desc
->
HasAttr
(
"fuse_brelu_threshold"
))
{
alpha
=
BOOST_GET_CONST
(
float
,
op_desc
->
GetAttr
(
"fuse_brelu_threshold"
));
}
op_node
->
Op
()
->
SetAttr
(
"fuse_alpha"
,
alpha
);
}
op_node
->
Op
()
->
SetAttr
(
"fuse_activation"
,
activation
);
}
}
}
}
void
QuantDequantMkldnnPass
::
RemoveCtrlVars
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"remove control flow variable"
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{};
for
(
auto
*
op_node
:
ir
::
TopologyVarientSort
(
*
graph
,
static_cast
<
ir
::
SortKind
>
(
0
)))
{
if
(
op_node
->
IsCtrlVar
())
{
nodes2rm
.
insert
(
op_node
);
}
}
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
}
void
QuantDequantMkldnnPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
3
)
<<
"Convert paddle slim quantized model to mkldnn quantized model."
;
const
std
::
string
pattern_name
=
"quant_dequant_mkldnn_pass"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
const
std
::
unordered_set
<
std
::
string
>
skip_ops
=
{
"conv2d"
,
"depthwise_conv2d"
,
"mul"
,
"matmul"
,
"matmul_v2"
};
const
std
::
unordered_set
<
std
::
string
>
fake_quantize_types
=
{
"fake_quantize_moving_average_abs_max"
,
"fake_quantize_range_abs_max"
};
const
std
::
unordered_set
<
std
::
string
>
fake_dequantize_types
=
{
"fake_dequantize_max_abs"
,
"fake_channel_wise_dequantize_max_abs"
};
const
std
::
unordered_set
<
std
::
string
>
fake_quantize_dequantize_types
=
{
"fake_quantize_dequantize_abs_max"
,
"fake_quantize_dequantize_moving_average_abs_max"
,
"fake_channel_wise_quantize_dequantize_abs_max"
};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>
weight_thresholds
{};
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>
var_quant_scales
{};
auto
*
scope
=
param_scope
();
MarkSkipQuantizedOps
(
graph
,
skip_ops
);
MarkSkipQuantizedPool2d
(
graph
);
CollectInfoFromFake
(
graph
,
scope
,
fake_dequantize_types
,
&
weight_thresholds
);
CollectInputScalesFromFake
(
graph
,
scope
,
fake_quantize_types
,
&
var_quant_scales
);
CollectOutputScalesFromAttr
(
graph
,
&
var_quant_scales
);
RemoveFakeOps
(
graph
,
fake_quantize_types
,
fake_dequantize_types
,
fake_quantize_dequantize_types
);
DequantizeWeights
(
graph
,
scope
,
weight_thresholds
);
UpdateActivations
(
graph
);
RemoveCtrlVars
(
graph
);
// save var_quant_scales in the first op's attr
// for compute_propagate_scales_mkldnn_pass
SaveInfoInTheFirstOp
(
graph
,
"has_quant_info"
,
"var_quant_scales"
,
var_quant_scales
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
quant_dequant_mkldnn_pass
,
paddle
::
framework
::
ir
::
QuantDequantMkldnnPass
);
REGISTER_PASS_CAPABILITY
(
quant_dequant_mkldnn_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"conv2d"
,
1
)
.
EQ
(
"fc"
,
0
)
.
LE
(
"conv2d_transpose"
,
2
)
.
EQ
(
"fake_quantize_abs_max"
,
0
)
.
EQ
(
"fake_quantize_range_abs_max"
,
0
)
.
EQ
(
"fake_quantize_moving_average_abs_max"
,
0
)
.
LE
(
"fake_channel_wise_quantize_abs_max"
,
1
)
.
EQ
(
"fake_dequantize_max_abs"
,
0
));
paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_pass.h
0 → 100644
浏览文件 @
04f20b83
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
QuantDequantMkldnnPass
:
public
FusePassBase
{
public:
QuantDequantMkldnnPass
()
=
default
;
virtual
~
QuantDequantMkldnnPass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
MarkSkipQuantizedOps
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
skip_ops
)
const
;
void
MarkSkipQuantizedPool2d
(
ir
::
Graph
*
graph
)
const
;
void
CollectInfoFromFake
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_set
<
std
::
string
>&
fake_dequantize_types
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
weight_thresholds
)
const
;
void
CollectInputScalesFromFake
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_types
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
var_quant_scales
)
const
;
void
CollectOutputScalesFromAttr
(
ir
::
Graph
*
graph
,
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>*
var_quant_scales
)
const
;
void
CollectFakeQuantizeOps
(
ir
::
Graph
*
graph
,
Node
*
op_node
,
std
::
unordered_set
<
const
Node
*>*
nodes2rm
)
const
;
void
CollectFakeDequantizeOps
(
ir
::
Graph
*
graph
,
Node
*
op_node
,
std
::
unordered_set
<
const
Node
*>*
nodes2rm
)
const
;
void
RemoveFakeOps
(
ir
::
Graph
*
graph
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_types
,
const
std
::
unordered_set
<
std
::
string
>&
fake_dequantize_types
,
const
std
::
unordered_set
<
std
::
string
>&
fake_quantize_dequantize_types
)
const
;
bool
IsInt8Weight
(
Node
*
op_node
,
Scope
*
scope
,
const
std
::
string
&
weight_name
)
const
;
void
TransposeWeight
(
Tensor
*
input
)
const
;
void
DequantizeOpWeights
(
Node
*
op_node
,
Scope
*
scope
,
const
std
::
string
&
weight_name
,
const
std
::
string
&
output_name
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>&
weight_thresholds
)
const
;
void
DequantizeWeights
(
ir
::
Graph
*
graph
,
Scope
*
scope
,
const
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
float
>>&
weight_thresholds
)
const
;
void
UpdateActivations
(
ir
::
Graph
*
graph
)
const
;
void
RemoveCtrlVars
(
ir
::
Graph
*
graph
)
const
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录