Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5cda6b2b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5cda6b2b
编写于
9月 22, 2021
作者:
W
Wangzheee
提交者:
GitHub
9月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: delete_quant_dequant_filter_op_pass, delete_quant_dequant_op_pass (#35879)
上级
1238115e
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
87 addition
and
187 deletion
+87
-187
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
...fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
+24
-120
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+51
-42
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+11
-22
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+1
-3
未找到文件。
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
浏览文件 @
5cda6b2b
...
@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
std
::
vector
<
float
>
weight_scale
;
std
::
vector
<
float
>
weight_scale
;
std
::
string
quant_dequant_op_out_name
=
quant_dequant_op_out
->
Var
()
->
Name
();
std
::
string
quant_dequant_op_out_name
=
quant_dequant_op_out
->
Var
()
->
Name
();
auto
*
any_op2_desc
=
any_op2
->
Op
();
auto
*
any_op2_desc
=
any_op2
->
Op
();
auto
var_map
=
any_op2_desc
->
Inputs
();
auto
var_map
=
any_op2_desc
->
Inputs
();
std
::
string
arg_name
=
""
;
std
::
string
arg_name
=
""
;
...
@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_GT
(
arg_name
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
arg_name
.
size
(),
0
,
platform
::
errors
::
InvalidArgument
(
"can not find the input %s."
,
"can not find the input %s."
,
quant_dequant_op_out_name
));
quant_dequant_op_out_name
));
any_op2_desc
->
SetAttr
(
"enable_int8"
,
true
);
//
any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc
->
SetAttr
(
"bit_length"
,
bit_length
);
any_op2_desc
->
SetAttr
(
"bit_length"
,
bit_length
);
// modify the any_op2's inputs
// modify the any_op2's inputs
any_op2_desc
->
Flush
();
auto
dequant_type
=
quant_dequant_op
->
Op
()
->
Type
();
auto
dequant_type
=
quant_dequant_op
->
Op
()
->
Type
();
auto
quantized_op_type
=
any_op2_desc
->
Type
();
// get weight tensor
// get weight tensor
auto
*
weight_tensor
=
auto
*
weight_tensor
=
scope
->
GetVar
(
quant_dequant_op_x
->
Name
())
->
GetMutable
<
LoDTensor
>
();
scope
->
GetVar
(
quant_dequant_op_x
->
Name
())
->
GetMutable
<
LoDTensor
>
();
auto
w_dims
=
weight_tensor
->
dims
();
auto
w_dims
=
weight_tensor
->
dims
();
float
*
quantized_weight_data
=
float
*
quantized_weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
// Get weight scale
// Get weight scale
if
(
dequant_type
==
"fake_channel_wise_quantize_dequantize_abs_max"
)
{
if
(
dequant_type
==
"fake_channel_wise_quantize_dequantize_abs_max"
)
{
auto
scales_name
=
quant_dequant_op
->
Op
()
->
Output
(
"OutScale"
);
int
quant_axis
=
BOOST_GET_CONST
(
int
,
quant_dequant_op
->
Op
()
->
GetAttr
(
"quant_axis"
));
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
quant_axis
));
// To Do @Wangzheee: use "OutScale" to quantdequant
/*auto scales_name = quant_dequant_op->Op()->Output("OutScale");
PADDLE_ENFORCE_EQ(scales_name.size(), 1,
PADDLE_ENFORCE_EQ(scales_name.size(), 1,
platform::errors::InvalidArgument(
platform::errors::InvalidArgument(
"Scales size in channel-wise quant dequantize op "
"Scales size in channel-wise quant dequantize op "
"should be 1, got %d.",
"should be 1, got %d.",
scales_name.size()));
scales_name.size()));
const LoDTensor& channel_scale_tensor =
const LoDTensor& channel_scale_tensor =
scope
->
Get
Var
(
scales_name
[
0
])
->
Get
<
LoDTensor
>
();
scope->
Find
Var(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE(
PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument(
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
"Channel scale tensor's place should be CPU."));
// compute the channel wise abs max of the weight tensor
// compute the channel wise abs max of the weight tensor
int
quant_axis
=
BOOST_GET_CONST
(
int
,
quant_dequant_op
->
Op
()
->
GetAttr
(
"quant_axis"
));
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
const float* channel_scale_data = channel_scale_tensor.data<float>();
platform
::
errors
::
InvalidArgument
(
for (int i = 0; i < channel_scale_tensor.numel(); i++) {
"'quant_axis' should be 0 or 1, but "
weight_scale.push_back(channel_scale_data[i] );
"the received is %d"
,
}*/
quant_axis
));
// Implement channel_wise_quantize_dequantize_abs_max quantization
// algorithm
const
int64_t
channel
=
w_dims
[
quant_axis
];
const
int64_t
channel
=
w_dims
[
quant_axis
];
weight_scale
.
resize
(
channel
,
0
);
weight_scale
.
resize
(
channel
,
0
);
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
...
@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE
(
weight_scale
[
i
],
0
,
PADDLE_ENFORCE_NE
(
weight_scale
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Weight scale should be nonzero, but get zero."
));
"Weight scale should be nonzero, but get zero."
));
weight_scale
[
i
]
=
range
/
weight_scale
[
i
]
;
weight_scale
[
i
]
=
weight_scale
[
i
]
/
range
;
}
}
}
else
{
}
else
{
auto
scale_name
=
quant_dequant_op_outscale
->
Name
();
// Implement quantize_dequantize_abs_max quantization algorithm
// compute the abs max of the weight tensor
float
abs_max_weight
=
0.
;
float
abs_max_weight
=
0.
;
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
abs_max_weight
=
abs_max_weight
=
...
@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
...
@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE
(
abs_max_weight
,
0
,
PADDLE_ENFORCE_NE
(
abs_max_weight
,
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"Weight scale should be nonzero, but get zero"
));
"Weight scale should be nonzero, but get zero"
));
weight_scale
.
push_back
(
(
range
*
range
)
/
abs_max_weight
/
range
);
weight_scale
.
push_back
(
abs_max_weight
/
range
);
}
}
nodes2rm
.
insert
(
quant_dequant_op_outscale
);
nodes2rm
.
insert
(
quant_dequant_op_outscale
);
// perform quantize dequantize operations
// If quantized op is not channel wise, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if
(
dequant_type
==
"fake_quantize_dequantize_abs_max"
)
{
PADDLE_ENFORCE_EQ
(
weight_scale
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d."
,
quantized_op_type
,
weight_scale
.
size
()));
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
// quantized
quantized_weight_data
[
j
]
=
quantized_weight_data
[
j
]
*
weight_scale
[
0
];
quantized_weight_data
[
j
]
=
std
::
round
(
quantized_weight_data
[
j
]);
// dequantized
quantized_weight_data
[
j
]
/=
weight_scale
[
0
];
}
}
else
if
(
quantized_op_type
==
"mul"
||
quantized_op_type
==
"matmul"
||
quantized_op_type
==
"fc"
)
{
if
(
dequant_type
==
"fake_channel_wise_quantize_dequantize_abs_max"
)
{
PADDLE_ENFORCE_EQ
(
weight_scale
.
size
(),
static_cast
<
size_t
>
(
w_dims
[
1
]),
platform
::
errors
::
InvalidArgument
(
"mul op weight dequantized by "
"[fake_channel_wise_quantize_dequantize_abs_max] requires "
"weight scale "
"size = 2nd dim of mul's weight, which is %zu, but got %zu."
,
static_cast
<
size_t
>
(
w_dims
[
1
]),
weight_scale
.
size
()));
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
// quantized
PADDLE_ENFORCE_NE
(
weight_scale
[
j
%
w_dims
[
1
]],
0
,
platform
::
errors
::
InvalidArgument
(
"fc op weight scale should be nonzero, but get zero"
));
quantized_weight_data
[
j
]
=
quantized_weight_data
[
j
]
*
weight_scale
[
j
%
w_dims
[
1
]];
quantized_weight_data
[
j
]
=
std
::
round
(
quantized_weight_data
[
j
]);
// dequantized
quantized_weight_data
[
j
]
/=
weight_scale
[
j
%
w_dims
[
1
]];
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported quantized op type: %s"
,
quantized_op_type
));
}
}
else
if
(
quantized_op_type
==
"conv2d"
||
quantized_op_type
==
"depthwise_conv2d"
)
{
if
(
dequant_type
==
"fake_channel_wise_quantize_dequantize_abs_max"
)
{
PADDLE_ENFORCE_EQ
(
weight_scale
.
size
(),
static_cast
<
size_t
>
(
w_dims
[
0
]),
platform
::
errors
::
InvalidArgument
(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %zu, but got %zu."
,
static_cast
<
size_t
>
(
w_dims
[
0
]),
weight_scale
.
size
()));
int
inner_size
=
w_dims
[
1
]
*
w_dims
[
2
]
*
w_dims
[
3
];
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
// quantized
PADDLE_ENFORCE_NE
(
weight_scale
[
j
/
inner_size
],
0
,
platform
::
errors
::
InvalidArgument
(
"conv2d op weight scale should be nonzero, but get zero"
));
quantized_weight_data
[
j
]
=
quantized_weight_data
[
j
]
*
weight_scale
[
j
/
inner_size
];
quantized_weight_data
[
j
]
=
std
::
round
(
quantized_weight_data
[
j
]);
// dequantized
quantized_weight_data
[
j
]
/=
weight_scale
[
j
/
inner_size
];
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported quantized op type: %s"
,
quantized_op_type
));
}
}
else
if
(
quantized_op_type
==
"conv2d_transpose"
)
{
if
(
dequant_type
==
"fake_channel_wise_quantize_dequantize_abs_max"
)
{
PADDLE_ENFORCE_EQ
(
weight_scale
.
size
(),
static_cast
<
size_t
>
(
w_dims
[
0
]),
platform
::
errors
::
InvalidArgument
(
"conv2d_transpose op requires weight scale size = channel size "
"of the "
"weight, which is %zu, but got %zu."
,
static_cast
<
size_t
>
(
w_dims
[
1
]),
weight_scale
.
size
()));
int
inner_size
=
w_dims
[
2
]
*
w_dims
[
3
];
for
(
int
j
=
0
;
j
<
weight_tensor
->
numel
();
j
++
)
{
// quantized
PADDLE_ENFORCE_NE
(
weight_scale
[(
j
/
inner_size
)
%
w_dims
[
1
]],
0
,
platform
::
errors
::
InvalidArgument
(
"conv2d_transpose op weight scale should be "
"nonzero, but get zero"
));
quantized_weight_data
[
j
]
=
quantized_weight_data
[
j
]
*
weight_scale
[(
j
/
inner_size
)
%
w_dims
[
1
]];
quantized_weight_data
[
j
]
=
std
::
round
(
quantized_weight_data
[
j
]);
// dequantized
quantized_weight_data
[
j
]
/=
weight_scale
[(
j
/
inner_size
)
%
w_dims
[
1
]];
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported quantized op type: %s"
,
quantized_op_type
));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported quantized op type: %s"
,
quantized_op_type
));
}
nodes2rm
.
insert
(
quant_dequant_op_out
);
nodes2rm
.
insert
(
quant_dequant_op_out
);
// link weight in quant_dequant_op_x to any_op2
// link weight in quant_dequant_op_x to any_op2
...
...
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
浏览文件 @
5cda6b2b
...
@@ -28,76 +28,85 @@ namespace ir {
...
@@ -28,76 +28,85 @@ namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(quant_dequant_op_out);
GET_IR_NODE(any_op2);
void
DeleteQuantDequantOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
DeleteQuantDequantOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"delete_quantdequant_op_pattern"
;
const
std
::
string
pattern_name
=
"delete_quantdequant_op_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
GraphPatternDetector
gpd
;
std
::
string
quantdequant_types
=
"fake_quantize_dequantize_moving_average_abs_max"
;
auto
*
input_node
=
gpd
.
mutable_pattern
()
->
NewNode
(
"input_node"
)
->
assert_is_op_input
(
quantdequant_types
,
"X"
)
->
AsInput
();
patterns
::
DeleteQuantDequantOpPattern
pattern
(
gpd
.
mutable_pattern
(),
patterns
::
DeleteQuantDequantOpPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern_name
);
pattern
();
pattern
(
input_node
,
quantdequant_types
);
auto
*
scope
=
param_scope
();
auto
*
scope
=
param_scope
();
int
found_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
Graph
*
g
)
{
PADDLE_ENFORCE_EQ
(
subgraph
.
count
(
input_node
),
true
,
platform
::
errors
::
NotFound
(
"Input act node(%s) not found in QuantDequantFuse pass."
,
input_node
->
name
()));
Node
*
input
=
subgraph
.
at
(
input_node
);
GET_NODES
;
GET_NODES
;
IR_NODE_LINK_TO
(
any_op_out
,
any_op2
);
int
bit_length
=
std
::
string
any_op_out_name
=
any_op_out
->
Var
()
->
Name
(
);
BOOST_GET_CONST
(
int
,
quant_dequant_op
->
Op
()
->
GetAttr
(
"bit_length"
)
);
std
::
string
quant_dequant_op_out_name
=
quant_dequant_op_out
->
Var
()
->
Name
(
);
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
// Get input scale from tensor
std
::
string
input_scale_var_name
=
std
::
string
input_scale_var_name
=
quant_dequant_op
->
Op
()
->
Input
(
"InScale"
).
front
();
quant_dequant_op
->
Op
()
->
Input
(
"InScale"
).
front
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope in DeleteQuantDequantOpPass should not be null."
));
const
LoDTensor
&
input_scale_tensor
=
const
LoDTensor
&
input_scale_tensor
=
scope
->
GetVar
(
input_scale_var_name
)
->
Get
<
LoDTensor
>
();
scope
->
FindVar
(
input_scale_var_name
)
->
Get
<
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
paddle
::
platform
::
is_cpu_place
(
input_scale_tensor
.
place
()),
true
,
platform
::
errors
::
InvalidArgument
(
"Input scale tensor's place should be CPU."
));
const
float
*
input_scale_data
=
input_scale_tensor
.
data
<
float
>
();
const
float
*
input_scale_data
=
input_scale_tensor
.
data
<
float
>
();
float
input_scale
=
input_scale_data
[
0
]
/
127.
;
float
input_scale
=
input_scale_data
[
0
]
/
range
;
auto
*
any_op2_desc
=
any_op2
->
Op
();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto
var_map
=
any_op2_desc
->
Inputs
();
std
::
string
arg_name
=
""
;
for
(
auto
&
name_m
:
var_map
)
{
if
(
std
::
find
(
name_m
.
second
.
begin
(),
name_m
.
second
.
end
(),
quant_dequant_op_out_name
)
!=
name_m
.
second
.
end
())
{
arg_name
=
name_m
.
first
;
}
}
CHECK
(
arg_name
.
size
()
>
0
)
<<
"can not find the input "
<<
quant_dequant_op_out_name
;
any_op2_desc
->
SetAttr
(
"enable_int8"
,
true
);
any_op2_desc
->
SetAttr
(
arg_name
+
"_scale"
,
input_scale
);
// modify the any_op2's inputs
// Set input scale in attr, and relink nodes
for
(
auto
&
name_m
:
var_map
)
{
std
::
string
input_name
=
input
->
Var
()
->
Name
();
if
(
std
::
find
(
name_m
.
second
.
begin
(),
name_m
.
second
.
end
(),
std
::
string
quant_dequant_output_name
=
quant_dequant_op_out
->
Var
()
->
Name
();
quant_dequant_op_out_name
)
!=
name_m
.
second
.
end
())
{
auto
outlinks
=
quant_dequant_op_out
->
outputs
;
std
::
vector
<
std
::
string
>
new_inputs
;
for
(
auto
*
quantized_node
:
outlinks
)
{
for
(
auto
&
i_n
:
name_m
.
second
)
{
auto
op_desc
=
quantized_node
->
Op
();
if
(
i_n
!=
quant_dequant_op_out_name
)
{
std
::
string
quantized_op_type
=
op_desc
->
Type
();
new_inputs
.
push_back
(
i_n
);
if
(
quantized_op_type
==
"mul"
||
quantized_op_type
==
"matmul"
||
}
quantized_op_type
==
"matmul_v2"
)
{
}
op_desc
->
SetAttr
(
"X_scale"
,
input_scale
);
new_inputs
.
push_back
(
any_op_out_name
);
}
else
{
any_op2_desc
->
SetInput
(
name_m
.
first
,
new_inputs
);
op_desc
->
SetAttr
(
"Input_scale"
,
input_scale
);
any_op2_desc
->
Flush
();
}
}
op_desc
->
SetAttr
(
"bit_length"
,
bit_length
);
op_desc
->
RenameInput
(
quant_dequant_output_name
,
input_name
);
op_desc
->
Flush
();
IR_NODE_LINK_TO
(
input
,
quantized_node
);
}
}
any_op2_desc
->
Flush
();
// Delete the unneeded nodes.
// Delete the unneeded nodes.
GraphSafeRemoveNodes
(
graph
,
GraphSafeRemoveNodes
(
graph
,
{
quant_dequant_op
,
quant_dequant_op_out
,
{
quant_dequant_op_inscale
,
quant_dequant_op
,
quant_dequant_op_inscale
,
quant_dequant_op_outscale
});
quant_dequant_op_outscale
,
quant_dequant_op_out
});
found_count
++
;
};
};
gpd
(
graph
,
handler
);
gpd
(
graph
,
handler
);
AddStatis
(
found_count
);
}
}
}
// namespace ir
}
// namespace ir
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
5cda6b2b
...
@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
...
@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out
->
LinksFrom
({
reshape2_op
});
reshape2_out
->
LinksFrom
({
reshape2_op
});
}
}
void
patterns
::
DeleteQuantDequantOpPattern
::
operator
()()
{
void
patterns
::
DeleteQuantDequantOpPattern
::
operator
()(
auto
any_op_out
=
PDNode
*
input_node
,
const
std
::
string
&
quantdequant_types
)
{
pattern
->
NewNode
(
any_op_out_repr
())
->
assert_is_op_input
(
"fake_quantize_dequantize_moving_average_abs_max"
,
"X"
)
->
AsInput
();
auto
quant_dequant_op_inscale
=
auto
quant_dequant_op_inscale
=
pattern
->
NewNode
(
quant_dequant_op_inscale_repr
())
pattern
->
NewNode
(
quant_dequant_op_inscale_repr
())
->
assert_is_op_input
(
->
assert_is_op_input
(
quantdequant_types
,
"InScale"
)
"fake_quantize_dequantize_moving_average_abs_max"
,
"InScale"
)
->
AsInput
();
->
AsInput
();
auto
quant_dequant_op
=
auto
quant_dequant_op
=
pattern
->
NewNode
(
quant_dequant_op_repr
())
pattern
->
NewNode
(
quant_dequant_op_repr
())
->
assert_is_op
(
quantdequant_types
);
->
assert_is_op
(
"fake_quantize_dequantize_moving_average_abs_max"
);
auto
quant_dequant_out
=
auto
quant_dequant_o
p_o
ut
=
pattern
->
NewNode
(
quant_dequant_op_out_repr
())
pattern
->
NewNode
(
quant_dequant_op_out_repr
())
->
assert_is_op_output
(
->
assert_is_op_output
(
quantdequant_types
,
"Out"
)
"fake_quantize_dequantize_moving_average_abs_max"
,
"Out"
)
->
AsOutput
();
->
AsIntermediate
();
auto
quant_dequant_op_outscale
=
auto
quant_dequant_op_outscale
=
pattern
->
NewNode
(
quant_dequant_op_outscale_repr
())
pattern
->
NewNode
(
quant_dequant_op_outscale_repr
())
->
assert_is_op_output
(
->
assert_is_op_output
(
quantdequant_types
,
"OutScale"
)
"fake_quantize_dequantize_moving_average_abs_max"
,
"OutScale"
)
->
AsOutput
();
->
AsOutput
();
auto
any_op2
=
pattern
->
NewNode
(
any_op2_repr
())
->
assert_is_op
()
->
AsOutput
();
quant_dequant_op
->
LinksFrom
({
any_op_out
,
quant_dequant_op_inscal
e
});
quant_dequant_op
->
LinksFrom
({
quant_dequant_op_inscale
,
input_nod
e
});
quant_dequant_op_outscale
->
LinksFrom
({
quant_dequant_op
});
quant_dequant_op_outscale
->
LinksFrom
({
quant_dequant_op
});
quant_dequant_out
->
LinksFrom
({
quant_dequant_op
});
quant_dequant_op_out
->
LinksFrom
({
quant_dequant_op
});
any_op2
->
LinksFrom
({
quant_dequant_out
});
}
}
void
patterns
::
DeleteQuantDequantFilterOpPattern
::
operator
()()
{
void
patterns
::
DeleteQuantDequantFilterOpPattern
::
operator
()()
{
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
5cda6b2b
...
@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
...
@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
DeleteQuantDequantOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"delete_quantdequant_op_pattern"
)
{}
:
PatternBase
(
pattern
,
name_scope
,
"delete_quantdequant_op_pattern"
)
{}
void
operator
()();
void
operator
()(
PDNode
*
input_node
,
const
std
::
string
&
quantdequant_types
);
PATTERN_DECL_NODE
(
any_op_out
);
PATTERN_DECL_NODE
(
quant_dequant_op_inscale
);
PATTERN_DECL_NODE
(
quant_dequant_op_inscale
);
PATTERN_DECL_NODE
(
quant_dequant_op
);
PATTERN_DECL_NODE
(
quant_dequant_op
);
PATTERN_DECL_NODE
(
quant_dequant_op_outscale
);
PATTERN_DECL_NODE
(
quant_dequant_op_outscale
);
PATTERN_DECL_NODE
(
quant_dequant_op_out
);
PATTERN_DECL_NODE
(
quant_dequant_op_out
);
PATTERN_DECL_NODE
(
any_op2
);
};
};
struct
DeleteQuantDequantFilterOpPattern
:
public
PatternBase
{
struct
DeleteQuantDequantFilterOpPattern
:
public
PatternBase
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录