Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
38faed7f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
38faed7f
编写于
1月 14, 2021
作者:
A
alncat
提交者:
GitHub
1月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added support for inference using quantization aware trained dygraph (#30288) (#30402)
上级
5d30d072
变更
17
隐藏空白更改
内联
并排
Showing
17 changed files
with
534 additions
and
10 deletions
+534
-10
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
+8
-0
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
...fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
+237
-0
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
.../fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
+37
-0
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
+2
-2
paddle/fluid/framework/ir/fc_fuse_pass.cc
paddle/fluid/framework/ir/fc_fuse_pass.cc
+12
-0
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+58
-0
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+30
-0
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
+101
-3
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
+8
-0
paddle/fluid/framework/scope.cc
paddle/fluid/framework/scope.cc
+7
-0
paddle/fluid/framework/scope.h
paddle/fluid/framework/scope.h
+4
-0
paddle/fluid/inference/api/analysis_config.cc
paddle/fluid/inference/api/analysis_config.cc
+1
-1
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+5
-1
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
+12
-1
paddle/fluid/inference/tensorrt/convert/fc_op.cc
paddle/fluid/inference/tensorrt/convert/fc_op.cc
+9
-2
paddle/fluid/inference/tensorrt/op_teller.cc
paddle/fluid/inference/tensorrt/op_teller.cc
+2
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
38faed7f
...
...
@@ -85,6 +85,7 @@ pass_library(runtime_context_cache_pass base)
pass_library
(
quant_conv2d_dequant_fuse_pass inference
)
pass_library
(
shuffle_channel_detect_pass inference
)
pass_library
(
delete_quant_dequant_op_pass inference
)
pass_library
(
delete_quant_dequant_filter_op_pass inference
)
pass_library
(
simplify_with_basic_ops_pass base
)
pass_library
(
fc_elementwise_layernorm_fuse_pass base
)
pass_library
(
skip_layernorm_fuse_pass base
)
...
...
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc
浏览文件 @
38faed7f
...
...
@@ -62,6 +62,14 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
new_op_desc
.
SetOutput
(
"Output"
,
{
output_name
});
new_op_desc
.
SetAttr
(
"is_test"
,
true
);
new_op_desc
.
SetAttr
(
"use_cudnn"
,
false
);
auto
*
elementwise_add_op_desc
=
elementwise_add_op
->
Op
();
auto
out_threshold_attr
=
elementwise_add_op_desc
->
GetNullableAttr
(
"out_threshold"
);
// set the out_threshold of the elementwise add op to be the out_threshold
// of the conv2d_fusion
if
(
out_threshold_attr
.
which
())
{
new_op_desc
.
SetAttr
(
"out_threshold"
,
out_threshold_attr
);
}
new_op_desc
.
Flush
();
// Create a new node for the fused op.
...
...
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc
0 → 100644
浏览文件 @
38faed7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(quant_dequant_op_x); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(any_op2);
// Delete quant_dequant_op, then quantize and dequantize weight
// Matches a fake (channel-wise) quantize-dequantize op applied to a weight
// variable, folds the quantize -> round -> dequantize computation into the
// weight tensor in place, records the resulting weight scales on the
// consuming op (as attribute "weight_scale"), and removes the fake op and
// its outputs from the graph.
void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
  FusePassBase::Init(pattern_name, graph);
  GraphPatternDetector gpd;

  // Create pattern
  patterns::DeleteQuantDequantFilterOpPattern pattern(gpd.mutable_pattern(),
                                                      pattern_name);
  pattern();
  auto* scope = param_scope();
  int found_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;

    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length =
        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
    // Symmetric quantization range, e.g. 127 for 8-bit.
    int range = ((1 << (bit_length - 1)) - 1);
    std::vector<float> weight_scale;
    std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();

    auto* any_op2_desc = any_op2->Op();
    auto var_map = any_op2_desc->Inputs();
    std::string arg_name = "";
    // Find which input argument of the consuming op receives the
    // quant/dequant output.
    for (auto& name_m : var_map) {
      if (std::find(name_m.second.begin(), name_m.second.end(),
                    quant_dequant_op_out_name) != name_m.second.end()) {
        arg_name = name_m.first;
        break;
      }
    }
    PADDLE_ENFORCE_GT(
        arg_name.size(), 0,
        platform::errors::InvalidArgument("can not find the input %s.",
                                          quant_dequant_op_out_name));
    any_op2_desc->SetAttr("enable_int8", true);
    any_op2_desc->SetAttr("bit_length", bit_length);
    // modify the any_op2's inputs
    any_op2_desc->Flush();
    auto dequant_type = quant_dequant_op->Op()->Type();
    auto quantized_op_type = any_op2_desc->Type();

    // Get weight scale. The fake op stores abs_max / range; we invert it so
    // that multiplying the weight by the stored value performs the quantize
    // step below.
    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
      auto scales_name = quant_dequant_op->Op()->Output("OutScale");
      PADDLE_ENFORCE_EQ(scales_name.size(), 1,
                        platform::errors::InvalidArgument(
                            "Scales size in channel-wise quant dequantize op "
                            "should be 1, got %d.",
                            scales_name.size()));
      const LoDTensor& channel_scale_tensor =
          scope->GetVar(scales_name[0])->Get<LoDTensor>();
      PADDLE_ENFORCE(
          paddle::platform::is_cpu_place(channel_scale_tensor.place()),
          platform::errors::InvalidArgument(
              "Channel scale tensor's place should be CPU."));
      const float* channel_scale_data = channel_scale_tensor.data<float>();
      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
        weight_scale.push_back(range / channel_scale_data[i]);
      }
    } else {
      auto scale_name = quant_dequant_op_outscale->Name();
      const LoDTensor& scale_tensor =
          scope->GetVar(scale_name)->Get<LoDTensor>();
      const float* scale_data = scale_tensor.data<float>();
      weight_scale.push_back((range * range) / scale_data[0] / range);
    }
    nodes2rm.insert(quant_dequant_op_outscale);

    // perform quantize dequantize operations on the weight tensor in place
    auto* weight_tensor =
        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    // If quantized op is fc, weight scale size = 1;
    // If quantized op is conv2d, weight scale size = weight dims[0]
    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
    if (dequant_type == "fake_quantize_dequantize_abs_max") {
      // BUGFIX(review): error message previously named the nonexistent op
      // "fake_quantize_dequantize_max_abs"; the actual op type is
      // "fake_quantize_dequantize_abs_max".
      PADDLE_ENFORCE_EQ(
          weight_scale.size(), 1,
          platform::errors::InvalidArgument(
              "%s op weight dequantized by [fake_quantize_dequantize_abs_max] "
              "requires weight scale size = 1, but got %d.",
              quantized_op_type, weight_scale.size()));
      PADDLE_ENFORCE_NE(weight_scale[0], 0,
                        platform::errors::InvalidArgument(
                            "Weight scale should be nonzero, but get zero"));
      for (int j = 0; j < weight_tensor->numel(); j++) {
        // quantized
        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
        // dequantized
        quantized_weight_data[j] /= weight_scale[0];
      }
    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
               quantized_op_type == "fc") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        // mul/fc weights are [in, out]; channel-wise quantization is along
        // the output (2nd) dimension.
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "mul op weight dequantized by "
                "[fake_channel_wise_quantize_dequantize_abs_max] requires "
                "weight scale "
                "size = 2nd dim of mul's weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(
              weight_scale[j % w_dims[1]], 0,
              platform::errors::InvalidArgument(
                  "fc op weight scale should be nonzero, but get zero"));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d" ||
               quantized_op_type == "depthwise_conv2d") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[0]),
            platform::errors::InvalidArgument(
                "conv2d op requires weight scale size = channel size of the "
                "weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[0]), weight_scale.size()));
        // Number of weight elements per output channel.
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(
              weight_scale[j / inner_size], 0,
              platform::errors::InvalidArgument(
                  "conv2d op weight scale should be nonzero, but get zero"));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j / inner_size];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /= weight_scale[j / inner_size];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d_transpose") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        // BUGFIX(review): conv2d_transpose weights are [in, out, h, w] and
        // are quantized channel-wise along dim 1, so the scale count must
        // equal w_dims[1] -- the original compared against w_dims[0] while
        // the error message and the per-element index `(j / inner_size) %
        // w_dims[1]` below both use w_dims[1].
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "conv2d_transpose op requires weight scale size = channel size "
                "of the "
                "weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        // Number of weight elements per (in, out) channel pair.
        int inner_size = w_dims[2] * w_dims[3];
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
                            platform::errors::InvalidArgument(
                                "conv2d_transpose op weight scale should be "
                                "nonzero, but get zero"));
          quantized_weight_data[j] = quantized_weight_data[j] *
                                     weight_scale[(j / inner_size) % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /=
              weight_scale[(j / inner_size) % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported quantized op type: %s", quantized_op_type));
    }
    nodes2rm.insert(quant_dequant_op_out);

    // link weight in quant_dequant_op_x to any_op2
    any_op2_desc->RenameInput(quant_dequant_op_out->Var()->Name(),
                              quant_dequant_op_x->Var()->Name());
    any_op2_desc->SetAttr("weight_scale", weight_scale);
    any_op2_desc->Flush();
    IR_NODE_LINK_TO(quant_dequant_op_x, any_op2);
    nodes2rm.insert(quant_dequant_op);

    GraphSafeRemoveNodes(graph, nodes2rm);
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_quant_dequant_filter_op_pass
,
paddle
::
framework
::
ir
::
DeleteQuantDequantFilterOpPass
);
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h
0 → 100644
浏览文件 @
38faed7f
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
Graph
;
// Inference pass that removes fake quantize/dequantize ops applied to
// weights (filters): the quantize -> round -> dequantize computation is
// folded into the weight tensor itself and the resulting scales are attached
// to the consuming op (see the matching .cc for details).
class DeleteQuantDequantFilterOpPass : public FusePassBase {
 public:
  virtual ~DeleteQuantDequantFilterOpPass() {}

 protected:
  // Matches the quant/dequant-on-weight pattern and rewrites the graph.
  void ApplyImpl(ir::Graph* graph) const override;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc
浏览文件 @
38faed7f
...
...
@@ -49,10 +49,10 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
std
::
string
input_scale_var_name
=
quant_dequant_op
->
Op
()
->
Input
(
"InScale"
).
front
();
const
LoDTensor
&
input_scale_tensor
=
scope
->
Find
Var
(
input_scale_var_name
)
->
Get
<
LoDTensor
>
();
scope
->
Get
Var
(
input_scale_var_name
)
->
Get
<
LoDTensor
>
();
const
float
*
input_scale_data
=
input_scale_tensor
.
data
<
float
>
();
float
input_scale
=
input_scale_data
[
0
];
float
input_scale
=
input_scale_data
[
0
]
/
127.
;
auto
*
any_op2_desc
=
any_op2
->
Op
();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto
var_map
=
any_op2_desc
->
Inputs
();
...
...
paddle/fluid/framework/ir/fc_fuse_pass.cc
浏览文件 @
38faed7f
...
...
@@ -149,6 +149,18 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
desc
.
SetAttr
(
"out_scale"
,
elementwise_desc
->
GetAttr
(
"out_scale"
));
}
auto
*
elementwise_add_op_desc
=
elementwise_add
->
Op
();
// if we can find out_threshold in elementwise_add, then set it as the
// out_thrshold of fc
auto
out_threshold_attr
=
elementwise_add_op_desc
->
GetNullableAttr
(
"out_threshold"
);
if
(
out_threshold_attr
.
which
())
{
VLOG
(
4
)
<<
"setting out_threshold: "
<<
BOOST_GET_CONST
(
float
,
out_threshold_attr
);
desc
.
SetAttr
(
"out_threshold"
,
out_threshold_attr
);
}
desc
.
Flush
();
auto
fc_node
=
g
->
CreateOpNode
(
&
desc
);
// OpDesc will be copied.
if
(
with_relu
)
{
GraphSafeRemoveNodes
(
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
38faed7f
...
...
@@ -1634,6 +1634,27 @@ PDNode *patterns::MatmulWithInputOps::operator()() {
return
matmul_out
;
}
// Builds the flatten2 + matmul sub-graph pattern:
//   flatten2(X) -> Out, where Out feeds matmul as its X input,
//   together with matmul's Y input and its Out output.
// Returns the matmul output node.
PDNode *patterns::Flatten2Matmul::operator()() {
  auto flatten2_in_x = pattern->NewNode(flatten2_in_x_repr());
  flatten2_in_x->assert_is_op_input("flatten2", "X")->AsInput();

  auto flatten2_op = pattern->NewNode(flatten2_op_repr());
  flatten2_op->assert_is_op("flatten2");

  // The intermediate tensor: output of flatten2, X input of matmul.
  auto matmul_in_x = pattern->NewNode(matmul_in_x_repr());
  matmul_in_x->assert_is_op_output("flatten2", "Out")
      ->assert_is_op_input("matmul", "X");

  auto matmul_in_y = pattern->NewNode(matmul_in_y_repr());
  matmul_in_y->assert_is_op_input("matmul", "Y");

  auto matmul_op = pattern->NewNode(matmul_op_repr());
  matmul_op->assert_is_op("matmul");

  auto matmul_out = pattern->NewNode(matmul_out_repr());
  matmul_out->AsOutput()->assert_is_op_output("matmul", "Out");

  // Wire the two ops into the sub-graph.
  flatten2_op->LinksFrom({flatten2_in_x}).LinksTo({matmul_in_x});
  matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});

  return matmul_out;
}
PDNode
*
patterns
::
ConvResidual
::
operator
()(
bool
with_residual_data
)
{
auto
conv_op
=
pattern
->
NewNode
(
conv_op_repr
())
->
assert_is_op
(
"conv2d"
);
...
...
@@ -2495,6 +2516,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
any_op2
->
LinksFrom
({
quant_dequant_out
});
}
// Builds the pattern matched by delete_quant_dequant_filter_op_pass:
//   weight (X) -> fake (channel-wise) quantize_dequantize op -> any op,
// with the fake op's OutScale kept as a separate node so the pass can read
// the recorded scales before deleting the fake op.
void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
  // Both supported fake quant/dequant op types.
  const std::unordered_set<std::string> fake_quant_types{
      "fake_channel_wise_quantize_dequantize_abs_max",
      "fake_quantize_dequantize_abs_max"};

  auto quant_dequant_op_x = pattern->NewNode(quant_dequant_op_x_repr());
  quant_dequant_op_x->assert_is_ops_input(fake_quant_types, "X")->AsInput();

  auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr());
  quant_dequant_op->assert_is_ops(fake_quant_types);

  // The quantized-dequantized tensor is consumed only inside the pattern.
  auto quant_dequant_out = pattern->NewNode(quant_dequant_op_out_repr());
  quant_dequant_out->assert_is_ops_output(fake_quant_types, "Out")
      ->AsIntermediate();

  auto quant_dequant_op_outscale =
      pattern->NewNode(quant_dequant_op_outscale_repr());
  quant_dequant_op_outscale->assert_is_ops_output(fake_quant_types, "OutScale")
      ->AsOutput();

  auto any_op2 = pattern->NewNode(any_op2_repr());
  any_op2->assert_is_op()->AsOutput();

  quant_dequant_op->LinksFrom({quant_dequant_op_x});
  quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
  quant_dequant_out->LinksFrom({quant_dequant_op});
  any_op2->LinksFrom({quant_dequant_out});
}
PDNode
*
patterns
::
ReshapeTransposeMatmulPattern
::
operator
()(
bool
with_reshape_xshape
,
bool
with_transpose_xshape
)
{
auto
reshape_op
=
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
38faed7f
...
...
@@ -996,6 +996,21 @@ struct MatmulWithInputOps : public PatternBase {
PATTERN_DECL_NODE
(
matmul_out
);
};
// Flatten2 + Matmul
// Forward pass.
// Pattern nodes (see the matching operator() in the .cc):
//   flatten2_in_x -> flatten2_op -> matmul_in_x -+
//                                   matmul_in_y -+-> matmul_op -> matmul_out
struct Flatten2Matmul : public PatternBase {
  Flatten2Matmul(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "flatten2_matmul") {}

  // Builds the pattern and returns the matmul output node.
  PDNode* operator()();
  PATTERN_DECL_NODE(flatten2_in_x);
  PATTERN_DECL_NODE(flatten2_op);
  PATTERN_DECL_NODE(matmul_in_x);
  PATTERN_DECL_NODE(matmul_in_y);
  PATTERN_DECL_NODE(matmul_op);
  PATTERN_DECL_NODE(matmul_out);
};
// Concat op
// Forward pass for concat.
// concat_out is a result of the operator.
...
...
@@ -1426,6 +1441,21 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
PATTERN_DECL_NODE
(
any_op2
);
};
// Matches a fake (channel-wise) quantize-dequantize op applied to a weight
// together with the op consuming its output. Used by
// delete_quant_dequant_filter_op_pass to fold the fake op into the weight.
struct DeleteQuantDequantFilterOpPattern : public PatternBase {
  DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
                                    const std::string& name_scope)
      : PatternBase(pattern, name_scope,
                    "delete_quantdequant_filter_op_pattern") {}

  // Builds the pattern; nodes are retrieved via the PATTERN_DECL_NODEs below.
  void operator()();

  PATTERN_DECL_NODE(quant_dequant_op_x);
  PATTERN_DECL_NODE(quant_dequant_op);
  PATTERN_DECL_NODE(quant_dequant_op_outscale);
  PATTERN_DECL_NODE(quant_dequant_op_out);
  PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
// named nodes:
// reshape_op, reshape_out, reshape_xshape,
...
...
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc
浏览文件 @
38faed7f
...
...
@@ -71,7 +71,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
if
(
matmul_op
->
Op
()
->
HasAttr
(
"enable_int8"
))
{
desc
.
SetAttr
(
"enable_int8"
,
matmul_op
->
Op
()
->
GetAttr
(
"enable_int8"
));
desc
.
SetAttr
(
"X_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"X_scale"
));
desc
.
SetAttr
(
"weight_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"weight_scale"
));
}
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
matmul_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
...
...
@@ -137,7 +141,11 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
if
(
matmul_op
->
Op
()
->
HasAttr
(
"enable_int8"
))
{
desc
.
SetAttr
(
"enable_int8"
,
matmul_op
->
Op
()
->
GetAttr
(
"enable_int8"
));
desc
.
SetAttr
(
"X_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"X_scale"
));
desc
.
SetAttr
(
"weight_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"weight_scale"
));
}
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
squeeze2_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
...
...
@@ -205,7 +213,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc
.
SetOutput
(
"Out"
,
{
matmul_out
->
Name
()});
desc
.
SetAttr
(
"x_num_col_dims"
,
1
);
desc
.
SetAttr
(
"y_num_col_dims"
,
1
);
if
(
matmul_op
->
Op
()
->
HasAttr
(
"enable_int8"
))
{
desc
.
SetAttr
(
"enable_int8"
,
matmul_op
->
Op
()
->
GetAttr
(
"enable_int8"
));
desc
.
SetAttr
(
"X_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"X_scale"
));
desc
.
SetAttr
(
"weight_scale"
,
matmul_op
->
Op
()
->
GetAttr
(
"weight_scale"
));
}
auto
mul_node
=
g
->
CreateOpNode
(
&
desc
);
IR_NODE_LINK_TO
(
reshape2_in_x
,
mul_node
);
IR_NODE_LINK_TO
(
matmul_in_y
,
mul_node
);
...
...
@@ -219,6 +231,83 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis
(
found_count
);
}
// Replaces a flatten2 + matmul pair with a single mul op when the shapes and
// matmul attributes make the two computations equivalent. The flatten2 axis
// becomes mul's x_num_col_dims; int8 attributes are carried over when present.
void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  std::string name_scope = "flatten2_matmul_fuse_pass";
  FusePassBase::Init(name_scope, graph);

  GraphPatternDetector gpd;
  patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
  fuse_pattern();

  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "fuse flatten2+matmul to mul";
    GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);

    bool can_fuse = true;

    // only convert matmul to mul when the flatten2 has a single input
    // and the rank of input is 4 and the size of the output of matmul
    // is 1.
    size_t flatten2_input_count = flatten2_op->inputs.size();
    auto flatten2_x_shape = flatten2_in_x->Var()->GetShape();
    size_t flatten2_x_rank = flatten2_x_shape.size();
    int flatten2_axis =
        BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis"));
    can_fuse = can_fuse && flatten2_input_count == 1 &&
               flatten2_x_rank == 4 &&
               (matmul_in_x->outputs).size() == 1;

    // The matmul must be a plain 2-D product: no transposes, alpha == 1.
    bool transpose_X =
        BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
    bool transpose_Y =
        BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
    float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
    size_t x_rank = (matmul_in_x->Var()->GetShape()).size();
    size_t y_rank = (matmul_in_y->Var()->GetShape()).size();
    can_fuse = can_fuse && !transpose_X && !transpose_Y &&
               std::abs(alpha - 1.0) < 1e-5 && x_rank == 2 && y_rank == 2;

    // we further require the matmul op is followed by one elementwise
    // add op.
    std::vector<Node*>& consumers = matmul_out->outputs;
    can_fuse = can_fuse && consumers.size() == 1 &&
               consumers[0]->Name() == "elementwise_add";

    if (can_fuse) {
      OpDesc desc;
      desc.SetType("mul");
      desc.SetInput("X", {flatten2_in_x->Name()});
      desc.SetInput("Y", {matmul_in_y->Name()});
      desc.SetOutput("Out", {matmul_out->Name()});
      desc.SetAttr("x_num_col_dims", flatten2_axis);
      desc.SetAttr("y_num_col_dims", 1);
      // Propagate int8 quantization attributes from the matmul, if any.
      if (matmul_op->Op()->HasAttr("enable_int8")) {
        desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
        desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
        desc.SetAttr("weight_scale",
                     matmul_op->Op()->GetAttr("weight_scale"));
      }

      auto mul_node = g->CreateOpNode(&desc);
      IR_NODE_LINK_TO(flatten2_in_x, mul_node);
      IR_NODE_LINK_TO(matmul_in_y, mul_node);
      IR_NODE_LINK_TO(mul_node, matmul_out);
      GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
      ++found_count;
    }
  };

  gpd(graph, handler);
  AddStatis(found_count);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
...
...
@@ -247,3 +336,12 @@ REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
.
LE
(
"matmul"
,
1
)
.
EQ
(
"reshape2"
,
0
)
.
EQ
(
"mul"
,
0
));
REGISTER_PASS
(
flatten2_matmul_fuse_pass
,
paddle
::
framework
::
ir
::
Flatten2MatmulFusePass
);
REGISTER_PASS_CAPABILITY
(
flatten2_matmul_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
()
.
LE
(
"matmul"
,
1
)
.
EQ
(
"flatten2"
,
0
)
.
EQ
(
"mul"
,
0
));
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h
浏览文件 @
38faed7f
...
...
@@ -101,6 +101,14 @@ class Reshape2MatmulFusePass : public FusePassBase {
void
ApplyImpl
(
Graph
*
graph
)
const
override
;
};
// Fuses a flatten2 + matmul pair into a single mul op when the matmul is a
// plain 2-D product (no transposes, alpha == 1) and the shapes permit it.
class Flatten2MatmulFusePass : public FusePassBase {
 public:
  virtual ~Flatten2MatmulFusePass() {}

 protected:
  void ApplyImpl(Graph* graph) const override;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/scope.cc
浏览文件 @
38faed7f
...
...
@@ -83,6 +83,13 @@ Variable* Scope::FindVar(const std::string& name) const {
return
FindVarInternal
(
name
);
}
// Like FindVar, but the lookup must succeed: raises NotFound instead of
// returning nullptr when `name` is absent from this scope and its ancestors.
Variable* Scope::GetVar(const std::string& name) const {
  Variable* found = FindVar(name);
  PADDLE_ENFORCE_NOT_NULL(
      found,
      platform::errors::NotFound("Cannot find %s in scope.", name));
  return found;
}
Variable
*
Scope
::
FindLocalVar
(
const
std
::
string
&
name
)
const
{
SCOPE_VARS_READER_LOCK
return
FindVarLocally
(
name
);
...
...
paddle/fluid/framework/scope.h
浏览文件 @
38faed7f
...
...
@@ -81,6 +81,10 @@ class Scope {
/// Caller doesn't own the returned Variable.
Variable
*
FindVar
(
const
std
::
string
&
name
)
const
;
// Get a variable in the scope or any of its ancestors. Enforce
/// the returned Variable is not nullptr
Variable
*
GetVar
(
const
std
::
string
&
name
)
const
;
/// Find a variable in the current scope.
/// Return nullptr if cannot find.
/// Caller doesn't own the returned Variable.
...
...
paddle/fluid/inference/api/analysis_config.cc
浏览文件 @
38faed7f
...
...
@@ -345,7 +345,7 @@ void AnalysisConfig::Update() {
pass_builder
()
->
ClearPasses
();
for
(
const
auto
&
pass
:
kTRTSubgraphPasses
)
{
if
(
tensorrt_precision_mode_
==
AnalysisConfig
::
Precision
::
kInt8
&&
(
pass
==
"conv_bn_fuse_pass"
||
pass
==
"fc_fuse_pass"
))
{
(
pass
==
"conv_bn_fuse_pass"
))
{
continue
;
}
pass_builder
()
->
AppendPass
(
pass
);
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
38faed7f
...
...
@@ -77,6 +77,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"shuffle_channel_detect_pass"
,
//
"quant_conv2d_dequant_fuse_pass"
,
//
"delete_quant_dequant_op_pass"
,
//
"delete_quant_dequant_filter_op_pass"
,
//
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass"
,
//
"embedding_eltwise_layernorm_fuse_pass"
,
//
...
...
@@ -86,15 +87,16 @@ const std::vector<std::string> kTRTSubgraphPasses({
"conv_bn_fuse_pass"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"flatten2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
"tensorrt_subgraph_pass"
,
//
"conv_bn_fuse_pass"
,
//
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
"conv_elementwise_add_act_fuse_pass"
,
//
"conv_elementwise_add2_act_fuse_pass"
,
//
"conv_elementwise_add_fuse_pass"
,
//
#endif //
"transpose_flatten_concat_fuse_pass"
,
});
...
...
@@ -118,6 +120,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"multihead_matmul_fuse_pass_v2"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"flatten2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"fc_elementwise_layernorm_fuse_pass"
,
//
...
...
@@ -172,6 +175,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"seq_concat_fc_fuse_pass"
,
//
"squeeze2_matmul_fuse_pass"
,
//
"reshape2_matmul_fuse_pass"
,
//
"flatten2_matmul_fuse_pass"
,
//
"map_matmul_to_mul_pass"
,
//
"fc_fuse_pass"
,
//
"repeated_fc_relu_fuse_pass"
,
//
...
...
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
浏览文件 @
38faed7f
...
...
@@ -105,8 +105,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
TensorRTEngine
::
Weight
weight
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
weight_data
),
static_cast
<
size_t
>
(
Y_t
->
numel
())};
float
*
bias_data
=
nullptr
;
size_t
bias_size
=
0
;
if
(
op_desc
.
Type
()
==
"conv2d_fusion"
)
{
auto
*
bias_tensor
=
scope
.
GetVar
(
op_desc
.
Input
(
"Bias"
).
front
());
auto
*
bias_tensor_data
=
bias_tensor
->
GetMutable
<
framework
::
LoDTensor
>
();
bias_data
=
engine
->
GetWeightCPUData
(
op_desc
.
Input
(
"Bias"
).
front
(),
bias_tensor_data
,
false
);
bias_size
=
static_cast
<
size_t
>
(
bias_tensor_data
->
numel
());
}
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
nullptr
,
0
};
TensorRTEngine
::
Weight
bias
{
nvinfer1
::
DataType
::
kFLOAT
,
static_cast
<
void
*>
(
bias_data
),
bias_size
};
auto
*
layer
=
fadd_layer
(
const_cast
<
nvinfer1
::
ITensor
*>
(
X
),
n_output
,
n_input
,
nv_ksize
,
weight
,
bias
);
PADDLE_ENFORCE_NOT_NULL
(
layer
,
...
...
@@ -184,4 +194,5 @@ class Deconv2dOpConverter : public OpConverter {
}
// namespace paddle
REGISTER_TRT_OP_CONVERTER
(
conv2d
,
Conv2dOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
conv2d_fusion
,
Conv2dOpConverter
);
REGISTER_TRT_OP_CONVERTER
(
conv2d_transpose
,
Deconv2dOpConverter
);
paddle/fluid/inference/tensorrt/convert/fc_op.cc
浏览文件 @
38faed7f
...
...
@@ -67,10 +67,11 @@ class FcOpConverter : public OpConverter {
// assigned from CPU memory, which can't be avoided.
float
*
weight_data
=
nullptr
;
bool
enable_int8
=
op_desc
.
HasAttr
(
"enable_int8"
);
float
in_scale
=
0.
;
if
(
enable_int8
)
{
#if IS_TRT_VERSION_GE(5000)
CHECK
(
op_desc
.
HasAttr
(
i_name
+
"_scale"
));
float
in_scale
=
in_scale
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
i_name
+
"_scale"
))
*
127
;
auto
weight_scale
=
BOOST_GET_CONST
(
std
::
vector
<
float
>
,
op_desc
.
GetAttr
(
"weight_scale"
));
...
...
@@ -131,7 +132,7 @@ class FcOpConverter : public OpConverter {
float
*
bias_data
=
nullptr
;
int
bias_num
=
0
;
if
(
with_bias
)
{
auto
*
b_v
=
scope
.
Find
Var
(
op_desc
.
Input
(
"Bias"
).
front
());
auto
*
b_v
=
scope
.
Get
Var
(
op_desc
.
Input
(
"Bias"
).
front
());
auto
*
b_t
=
b_v
->
GetMutable
<
framework
::
LoDTensor
>
();
bias_data
=
engine_
->
GetWeightCPUData
(
op_desc
.
Input
(
"Bias"
).
front
(),
b_t
,
false
);
...
...
@@ -183,6 +184,9 @@ class FcOpConverter : public OpConverter {
auto
*
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
X
);
reshape_layer
->
setReshapeDimensions
(
reshape_dim
);
reshape_itensor
=
reshape_layer
->
getOutput
(
0
);
if
(
enable_int8
)
{
engine_
->
SetTensorDynamicRange
(
reshape_itensor
,
in_scale
);
}
}
else
{
PADDLE_ENFORCE_NE
(
input_dims
,
1
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -200,6 +204,9 @@ class FcOpConverter : public OpConverter {
auto
*
reshape_layer
=
TRT_ENGINE_ADD_LAYER
(
engine_
,
Shuffle
,
*
X
);
reshape_layer
->
setReshapeDimensions
(
reshape_dim
);
reshape_itensor
=
reshape_layer
->
getOutput
(
0
);
if
(
enable_int8
)
{
engine_
->
SetTensorDynamicRange
(
reshape_itensor
,
in_scale
);
}
}
regist_fc
(
reshape_itensor
,
n_output
,
weight
,
bias
);
}
...
...
paddle/fluid/inference/tensorrt/op_teller.cc
浏览文件 @
38faed7f
...
...
@@ -58,6 +58,7 @@ struct SimpleOpTypeSetTeller : public Teller {
// use this set for no calib int8.
std
::
unordered_set
<
std
::
string
>
int8_teller_set
{
"mul"
,
"conv2d"
,
"conv2d_fusion"
,
"pool2d"
,
"relu"
,
"depthwise_conv2d"
,
...
...
@@ -76,6 +77,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"mul"
,
"matmul"
,
"conv2d"
,
"conv2d_fusion"
,
"pool2d"
,
"relu"
,
"softmax"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录