Unverified commit 38faed7f, authored Jan 14, 2021 by alncat; committed by GitHub on Jan 14, 2021.
Added support for inference using quantization-aware trained dygraph (#30288) (#30402)
Parent: 5d30d072

Showing 17 changed files with 534 additions and 10 deletions (+534 −10)
paddle/fluid/framework/ir/CMakeLists.txt                          +1    −0
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc       +8    −0
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc  +237  −0
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h   +37   −0
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc         +2    −2
paddle/fluid/framework/ir/fc_fuse_pass.cc                         +12   −0
paddle/fluid/framework/ir/graph_pattern_detector.cc               +58   −0
paddle/fluid/framework/ir/graph_pattern_detector.h                +30   −0
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc               +101  −3
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h                +8    −0
paddle/fluid/framework/scope.cc                                   +7    −0
paddle/fluid/framework/scope.h                                    +4    −0
paddle/fluid/inference/api/analysis_config.cc                     +1    −1
paddle/fluid/inference/api/paddle_pass_builder.cc                 +5    −1
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc              +12   −1
paddle/fluid/inference/tensorrt/convert/fc_op.cc                  +9    −2
paddle/fluid/inference/tensorrt/op_teller.cc                      +2    −0
paddle/fluid/framework/ir/CMakeLists.txt

@@ -85,6 +85,7 @@ pass_library(runtime_context_cache_pass base)
 pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
+pass_library(delete_quant_dequant_filter_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
 pass_library(skip_layernorm_fuse_pass base)
paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.cc

@@ -62,6 +62,14 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
   new_op_desc.SetOutput("Output", {output_name});
   new_op_desc.SetAttr("is_test", true);
   new_op_desc.SetAttr("use_cudnn", false);
+  auto* elementwise_add_op_desc = elementwise_add_op->Op();
+  auto out_threshold_attr =
+      elementwise_add_op_desc->GetNullableAttr("out_threshold");
+  // set the out_threshold of the elementwise add op to be the out_threshold
+  // of the conv2d_fusion
+  if (out_threshold_attr.which()) {
+    new_op_desc.SetAttr("out_threshold", out_threshold_attr);
+  }
   new_op_desc.Flush();
   // Create a new node for the fused op.
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"

#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                          \
  GET_IR_NODE(quant_dequant_op_x);         \
  GET_IR_NODE(quant_dequant_op);           \
  GET_IR_NODE(quant_dequant_op_out);       \
  GET_IR_NODE(quant_dequant_op_outscale);  \
  GET_IR_NODE(any_op2);

// Delete quant_dequant_op, then quantize and dequantize weight
void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
  FusePassBase::Init(pattern_name, graph);
  GraphPatternDetector gpd;

  // Create pattern
  patterns::DeleteQuantDequantFilterOpPattern pattern(gpd.mutable_pattern(),
                                                      pattern_name);
  pattern();
  auto* scope = param_scope();
  int found_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length =
        BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);
    std::vector<float> weight_scale;
    std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();

    auto* any_op2_desc = any_op2->Op();
    auto var_map = any_op2_desc->Inputs();
    std::string arg_name = "";
    for (auto& name_m : var_map) {
      if (std::find(name_m.second.begin(), name_m.second.end(),
                    quant_dequant_op_out_name) != name_m.second.end()) {
        arg_name = name_m.first;
        break;
      }
    }
    PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
                                              "can not find the input %s.",
                                              quant_dequant_op_out_name));
    any_op2_desc->SetAttr("enable_int8", true);
    any_op2_desc->SetAttr("bit_length", bit_length);
    // modify the any_op2's inputs
    any_op2_desc->Flush();
    auto dequant_type = quant_dequant_op->Op()->Type();
    auto quantized_op_type = any_op2_desc->Type();

    // Get weight scale
    if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
      auto scales_name = quant_dequant_op->Op()->Output("OutScale");
      PADDLE_ENFORCE_EQ(scales_name.size(), 1,
                        platform::errors::InvalidArgument(
                            "Scales size in channel-wise quant dequantize op "
                            "should be 1, got %d.",
                            scales_name.size()));
      const LoDTensor& channel_scale_tensor =
          scope->GetVar(scales_name[0])->Get<LoDTensor>();
      PADDLE_ENFORCE(
          paddle::platform::is_cpu_place(channel_scale_tensor.place()),
          platform::errors::InvalidArgument(
              "Channel scale tensor's place should be CPU."));
      const float* channel_scale_data = channel_scale_tensor.data<float>();
      for (int i = 0; i < channel_scale_tensor.numel(); i++) {
        weight_scale.push_back(range / channel_scale_data[i]);
      }
    } else {
      auto scale_name = quant_dequant_op_outscale->Name();
      const LoDTensor& scale_tensor =
          scope->GetVar(scale_name)->Get<LoDTensor>();
      const float* scale_data = scale_tensor.data<float>();
      weight_scale.push_back((range * range) / scale_data[0] / range);
    }
    nodes2rm.insert(quant_dequant_op_outscale);

    // perform quantize dequantize operations
    auto* weight_tensor =
        scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
    auto w_dims = weight_tensor->dims();
    float* quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    // If quantized op is fc, weight scale size = 1;
    // If quantized op is conv2d, weight scale size = weight dims[0]
    // If quantized op is conv2d_transpose, weight scale size = weight dims[1]
    if (dequant_type == "fake_quantize_dequantize_abs_max") {
      PADDLE_ENFORCE_EQ(
          weight_scale.size(), 1,
          platform::errors::InvalidArgument(
              "%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
              "requires weight scale size = 1, but got %d.",
              quantized_op_type, weight_scale.size()));
      PADDLE_ENFORCE_NE(weight_scale[0], 0,
                        platform::errors::InvalidArgument(
                            "Weight scale should be nonzero, but get zero"));
      for (int j = 0; j < weight_tensor->numel(); j++) {
        // quantized
        quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
        quantized_weight_data[j] = std::round(quantized_weight_data[j]);
        // dequantized
        quantized_weight_data[j] /= weight_scale[0];
      }
    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
               quantized_op_type == "fc") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[1]),
            platform::errors::InvalidArgument(
                "mul op weight dequantized by "
                "[fake_channel_wise_quantize_dequantize_abs_max] requires "
                "weight scale "
                "size = 2nd dim of mul's weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(
              weight_scale[j % w_dims[1]], 0,
              platform::errors::InvalidArgument(
                  "fc op weight scale should be nonzero, but get zero"));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d" ||
               quantized_op_type == "depthwise_conv2d") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[0]),
            platform::errors::InvalidArgument(
                "conv2d op requires weight scale size = channel size of the "
                "weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[0]), weight_scale.size()));
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(
              weight_scale[j / inner_size], 0,
              platform::errors::InvalidArgument(
                  "conv2d op weight scale should be nonzero, but get zero"));
          quantized_weight_data[j] =
              quantized_weight_data[j] * weight_scale[j / inner_size];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /= weight_scale[j / inner_size];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else if (quantized_op_type == "conv2d_transpose") {
      if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
        PADDLE_ENFORCE_EQ(
            weight_scale.size(), static_cast<size_t>(w_dims[0]),
            platform::errors::InvalidArgument(
                "conv2d_transpose op requires weight scale size = channel size "
                "of the "
                "weight, which is %zu, but got %zu.",
                static_cast<size_t>(w_dims[1]), weight_scale.size()));
        int inner_size = w_dims[2] * w_dims[3];
        for (int j = 0; j < weight_tensor->numel(); j++) {
          // quantized
          PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
                            platform::errors::InvalidArgument(
                                "conv2d_transpose op weight scale should be "
                                "nonzero, but get zero"));
          quantized_weight_data[j] =
              quantized_weight_data[j] *
              weight_scale[(j / inner_size) % w_dims[1]];
          quantized_weight_data[j] = std::round(quantized_weight_data[j]);
          // dequantized
          quantized_weight_data[j] /=
              weight_scale[(j / inner_size) % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Unsupported quantized op type: %s", quantized_op_type));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported quantized op type: %s", quantized_op_type));
    }
    nodes2rm.insert(quant_dequant_op_out);

    // link weight in quant_dequant_op_x to any_op2
    any_op2_desc->RenameInput(quant_dequant_op_out->Var()->Name(),
                              quant_dequant_op_x->Var()->Name());
    any_op2_desc->SetAttr("weight_scale", weight_scale);
    any_op2_desc->Flush();
    IR_NODE_LINK_TO(quant_dequant_op_x, any_op2);
    nodes2rm.insert(quant_dequant_op);
    GraphSafeRemoveNodes(graph, nodes2rm);
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_quant_dequant_filter_op_pass,
              paddle::framework::ir::DeleteQuantDequantFilterOpPass);
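The weight rewrite above bakes the quantization error directly into the FP32 tensor: each weight is scaled onto the int8 grid, rounded, and scaled back, so any converter that later reads the weights sees values that already behave as int8. A minimal standalone sketch of that per-tensor round trip (hypothetical values; abs_max stands in for the recorded OutScale entry and bit_length for the op attribute):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int bit_length = 8;                       // the "bit_length" attribute
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for int8
  const float abs_max = 2.0f;                     // hypothetical OutScale value
  const float scale = range / abs_max;  // as in range / channel_scale_data[i]

  std::vector<float> w = {0.013f, -1.734f, 0.5f};
  for (float& v : w) {
    v = std::round(v * scale) / scale;  // quantize, then dequantize in place
  }
  // w is now {~0.01575, ~-1.73228, ~0.50394}: FP32 values snapped to the
  // int8 grid, exactly the state the pass leaves in the weight tensor.
  for (float v : w) std::printf("%f\n", v);
}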
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h (new file, mode 100644)

// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

class Graph;

class DeleteQuantDequantFilterOpPass : public FusePassBase {
 public:
  virtual ~DeleteQuantDequantFilterOpPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc

@@ -49,10 +49,10 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
   std::string input_scale_var_name =
       quant_dequant_op->Op()->Input("InScale").front();
   const LoDTensor& input_scale_tensor =
-      scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+      scope->GetVar(input_scale_var_name)->Get<LoDTensor>();
   const float* input_scale_data = input_scale_tensor.data<float>();
-  float input_scale = input_scale_data[0];
+  float input_scale = input_scale_data[0] / 127.;
   auto* any_op2_desc = any_op2->Op();
   // auto input_args_names = any_op2_desc->InputArgumentNames();
   auto var_map = any_op2_desc->Inputs();
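Two things change here: GetVar makes a missing InScale variable fail immediately with its name, and the recorded scale is divided by 127 so downstream consumers receive a per-step int8 scale rather than the raw absolute maximum. A quick numeric check of that convention (assuming, as in the filter pass above, that the scale tensor stores the activation's abs-max):

#include <cstdio>

int main() {
  const float abs_max = 2.54f;                // hypothetical InScale contents
  const float input_scale = abs_max / 127.f;  // input_scale_data[0] / 127.
  // One int8 step now equals 0.02 in FP32; multiplying back by 127 (as the
  // TensorRT fc converter does with its "_scale" attributes) recovers abs-max.
  std::printf("per-step scale = %f, dynamic range = %f\n", input_scale,
              input_scale * 127.f);
}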
paddle/fluid/framework/ir/fc_fuse_pass.cc

@@ -149,6 +149,18 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
     desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
   }
+
+  auto* elementwise_add_op_desc = elementwise_add->Op();
+  // if we can find out_threshold in elementwise_add, then set it as the
+  // out_threshold of fc
+  auto out_threshold_attr =
+      elementwise_add_op_desc->GetNullableAttr("out_threshold");
+  if (out_threshold_attr.which()) {
+    VLOG(4) << "setting out_threshold: "
+            << BOOST_GET_CONST(float, out_threshold_attr);
+    desc.SetAttr("out_threshold", out_threshold_attr);
+  }
+
   desc.Flush();
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
   if (with_relu) {
     GraphSafeRemoveNodes(
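GetNullableAttr returns Paddle's Attribute variant without throwing when the attribute is absent, and the which() test works because the variant's first alternative is boost::blank, so index 0 reads as "unset". A small sketch of that idiom with a stand-in variant (the real Attribute has many more alternatives):

#include <boost/blank.hpp>
#include <boost/variant.hpp>
#include <iostream>

int main() {
  // Stand-in for paddle::framework::Attribute: blank must come first so a
  // default-constructed (missing) attribute yields which() == 0.
  boost::variant<boost::blank, int, float> attr;
  std::cout << attr.which() << "\n";  // 0 -> treated as "attribute not set"
  attr = 0.75f;                       // e.g. an out_threshold value
  std::cout << attr.which() << "\n";  // 2 -> if (attr.which()) now passes
}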
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -1634,6 +1634,27 @@ PDNode *patterns::MatmulWithInputOps::operator()() {
   return matmul_out;
 }

+PDNode *patterns::Flatten2Matmul::operator()() {
+  auto flatten2_in_x = pattern->NewNode(flatten2_in_x_repr())
+                           ->assert_is_op_input("flatten2", "X")
+                           ->AsInput();
+  auto flatten2_op =
+      pattern->NewNode(flatten2_op_repr())->assert_is_op("flatten2");
+  auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
+                         ->assert_is_op_output("flatten2", "Out")
+                         ->assert_is_op_input("matmul", "X");
+  auto matmul_in_y =
+      pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y");
+  auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
+  auto matmul_out = pattern->NewNode(matmul_out_repr())
+                        ->AsOutput()
+                        ->assert_is_op_output("matmul", "Out");
+
+  flatten2_op->LinksFrom({flatten2_in_x}).LinksTo({matmul_in_x});
+  matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
+  return matmul_out;
+}
+
 PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
   auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");

@@ -2495,6 +2516,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
   any_op2->LinksFrom({quant_dequant_out});
 }

+void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
+  auto quant_dequant_op_x =
+      pattern->NewNode(quant_dequant_op_x_repr())
+          ->assert_is_ops_input(
+              {"fake_channel_wise_quantize_dequantize_abs_max",
+               "fake_quantize_dequantize_abs_max"},
+              "X")
+          ->AsInput();
+
+  auto quant_dequant_op =
+      pattern->NewNode(quant_dequant_op_repr())
+          ->assert_is_ops({"fake_channel_wise_quantize_dequantize_abs_max",
+                           "fake_quantize_dequantize_abs_max"});
+
+  auto quant_dequant_out =
+      pattern->NewNode(quant_dequant_op_out_repr())
+          ->assert_is_ops_output(
+              {"fake_channel_wise_quantize_dequantize_abs_max",
+               "fake_quantize_dequantize_abs_max"},
+              "Out")
+          ->AsIntermediate();
+
+  auto quant_dequant_op_outscale =
+      pattern->NewNode(quant_dequant_op_outscale_repr())
+          ->assert_is_ops_output(
+              {"fake_channel_wise_quantize_dequantize_abs_max",
+               "fake_quantize_dequantize_abs_max"},
+              "OutScale")
+          ->AsOutput();
+  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
+
+  quant_dequant_op->LinksFrom({quant_dequant_op_x});
+  quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
+  quant_dequant_out->LinksFrom({quant_dequant_op});
+  any_op2->LinksFrom({quant_dequant_out});
+}
+
 PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
     bool with_reshape_xshape, bool with_transpose_xshape) {
   auto reshape_op =
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -996,6 +996,21 @@ struct MatmulWithInputOps : public PatternBase {
   PATTERN_DECL_NODE(matmul_out);
 };

+// Flatten2 + Matmul
+// Forward pass.
+struct Flatten2Matmul : public PatternBase {
+  Flatten2Matmul(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "flatten2_matmul") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(flatten2_in_x);
+  PATTERN_DECL_NODE(flatten2_op);
+  PATTERN_DECL_NODE(matmul_in_x);
+  PATTERN_DECL_NODE(matmul_in_y);
+  PATTERN_DECL_NODE(matmul_op);
+  PATTERN_DECL_NODE(matmul_out);
+};
+
 // Concat op
 // Forward pass for concat.
 // concat_out is a result of the operator.

@@ -1426,6 +1441,21 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
   PATTERN_DECL_NODE(any_op2);
 };

+struct DeleteQuantDequantFilterOpPattern : public PatternBase {
+  DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
+                                    const std::string& name_scope)
+      : PatternBase(pattern, name_scope,
+                    "delete_quantdequant_filter_op_pattern") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(quant_dequant_op_x);
+  PATTERN_DECL_NODE(quant_dequant_op);
+  PATTERN_DECL_NODE(quant_dequant_op_outscale);
+  PATTERN_DECL_NODE(quant_dequant_op_out);
+  PATTERN_DECL_NODE(any_op2);
+};
+
 // Reshape + Transpose + Matmul
 // named nodes:
 // reshape_op, reshape_out, reshape_xshape,
paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc

@@ -71,7 +71,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
       desc.SetOutput("Out", {matmul_out->Name()});
       desc.SetAttr("x_num_col_dims", 1);
       desc.SetAttr("y_num_col_dims", 1);
+      if (matmul_op->Op()->HasAttr("enable_int8")) {
+        desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
+        desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
+        desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      }
       auto mul_node = g->CreateOpNode(&desc);
       IR_NODE_LINK_TO(matmul_in_x, mul_node);
       IR_NODE_LINK_TO(matmul_in_y, mul_node);

@@ -137,7 +141,11 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
       desc.SetOutput("Out", {matmul_out->Name()});
       desc.SetAttr("x_num_col_dims", 1);
       desc.SetAttr("y_num_col_dims", 1);
+      if (matmul_op->Op()->HasAttr("enable_int8")) {
+        desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
+        desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
+        desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      }
       auto mul_node = g->CreateOpNode(&desc);
       IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
       IR_NODE_LINK_TO(matmul_in_y, mul_node);

@@ -205,7 +213,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
      desc.SetOutput("Out", {matmul_out->Name()});
      desc.SetAttr("x_num_col_dims", 1);
      desc.SetAttr("y_num_col_dims", 1);
+     if (matmul_op->Op()->HasAttr("enable_int8")) {
+       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
+       desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
+       desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+     }
      auto mul_node = g->CreateOpNode(&desc);
      IR_NODE_LINK_TO(reshape2_in_x, mul_node);
      IR_NODE_LINK_TO(matmul_in_y, mul_node);

@@ -219,6 +231,83 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_count);
 }

+void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  std::string name_scope = "flatten2_matmul_fuse_pass";
+  FusePassBase::Init(name_scope, graph);
+
+  GraphPatternDetector gpd;
+  patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
+  fuse_pattern();
+
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "fuse flatten2+matmul to mul";
+    GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
+
+    bool pattern_found = true;
+    size_t flatten2_in_nums = flatten2_op->inputs.size();
+    auto flatten2_in_x_shape = flatten2_in_x->Var()->GetShape();
+    size_t flatten2_in_x_rank = flatten2_in_x_shape.size();
+    int flatten2_axis =
+        BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis"));
+    // only convert matmul to mul when the flatten2 has a single input
+    // and the rank of input is 4 and the size of the output of matmul
+    // is 1.
+    pattern_found = pattern_found && flatten2_in_nums == 1 &&
+                    flatten2_in_x_rank == 4 &&
+                    (matmul_in_x->outputs).size() == 1;
+
+    bool transpose_X =
+        BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
+    bool transpose_Y =
+        BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
+    float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
+    size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
+    size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
+    pattern_found = pattern_found && !transpose_X && !transpose_Y &&
+                    std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
+                    matmul_in_y_rank == 2;
+
+    std::vector<Node*>& next_ops = matmul_out->outputs;
+    // we further require the matmul op is followed by one elementwise
+    // add op.
+    pattern_found = pattern_found && next_ops.size() == 1 &&
+                    next_ops[0]->Name() == "elementwise_add";
+
+    if (pattern_found) {
+      OpDesc desc;
+      desc.SetType("mul");
+      desc.SetInput("X", {flatten2_in_x->Name()});
+      desc.SetInput("Y", {matmul_in_y->Name()});
+      desc.SetOutput("Out", {matmul_out->Name()});
+      desc.SetAttr("x_num_col_dims", flatten2_axis);
+      desc.SetAttr("y_num_col_dims", 1);
+      if (matmul_op->Op()->HasAttr("enable_int8")) {
+        desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
+        desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
+        desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      }
+
+      auto mul_node = g->CreateOpNode(&desc);
+      IR_NODE_LINK_TO(flatten2_in_x, mul_node);
+      IR_NODE_LINK_TO(matmul_in_y, mul_node);
+      IR_NODE_LINK_TO(mul_node, matmul_out);
+      GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
+      ++found_count;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle

@@ -247,3 +336,12 @@ REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
             .LE("matmul", 1)
             .EQ("reshape2", 0)
             .EQ("mul", 0));
+REGISTER_PASS(flatten2_matmul_fuse_pass,
+              paddle::framework::ir::Flatten2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .LE("matmul", 1)
+            .EQ("flatten2", 0)
+            .EQ("mul", 0));
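The fusion can drop the flatten2 because mul performs the same flattening itself: x_num_col_dims splits X's shape into a row block and a column block before the 2-D multiply, so setting it to the flatten2 axis reproduces the flattened view. A shape-only sketch of that equivalence (hypothetical dimensions):

#include <cstdio>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// mul's effective 2-D view of X: dims[0..k) become rows, dims[k..) columns.
static std::pair<long, long> MulView(const std::vector<long>& dims, int k) {
  long rows = std::accumulate(dims.begin(), dims.begin() + k, 1L,
                              std::multiplies<long>());
  long cols = std::accumulate(dims.begin() + k, dims.end(), 1L,
                              std::multiplies<long>());
  return {rows, cols};
}

int main() {
  std::vector<long> x = {8, 3, 16, 16};  // hypothetical NCHW input
  // flatten2(axis=1) -> [8, 768], then matmul with a [768, 10] weight.
  // mul with x_num_col_dims = 1 sees the same [8, 768] view directly:
  auto v = MulView(x, /*x_num_col_dims=*/1);
  std::printf("[%ld, %ld]\n", v.first, v.second);  // prints [8, 768]
}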
paddle/fluid/framework/ir/map_matmul_to_mul_pass.h

@@ -101,6 +101,14 @@ class Reshape2MatmulFusePass : public FusePassBase {
   void ApplyImpl(Graph* graph) const override;
 };

+class Flatten2MatmulFusePass : public FusePassBase {
+ public:
+  virtual ~Flatten2MatmulFusePass() {}
+
+ protected:
+  void ApplyImpl(Graph* graph) const override;
+};
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
paddle/fluid/framework/scope.cc

@@ -83,6 +83,13 @@ Variable* Scope::FindVar(const std::string& name) const {
   return FindVarInternal(name);
 }

+Variable* Scope::GetVar(const std::string& name) const {
+  auto* var = FindVar(name);
+  PADDLE_ENFORCE_NOT_NULL(
+      var, platform::errors::NotFound("Cannot find %s in scope.", name));
+  return var;
+}
+
 Variable* Scope::FindLocalVar(const std::string& name) const {
   SCOPE_VARS_READER_LOCK
   return FindVarLocally(name);
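GetVar is a thin wrapper over FindVar that turns a missing variable into an immediate, named error instead of a nullptr dereference at some later use site; the quant passes above rely on this when they look up scale and weight variables. A toy sketch of the contract (stand-in types, not the real Scope):

#include <map>
#include <stdexcept>
#include <string>

struct Variable {};

struct Scope {
  std::map<std::string, Variable> vars;
  Variable* FindVar(const std::string& name) {  // may return nullptr
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
  Variable* GetVar(const std::string& name) {   // never returns nullptr
    auto* var = FindVar(name);
    if (var == nullptr)
      throw std::runtime_error("Cannot find " + name + " in scope.");
    return var;
  }
};

int main() {
  Scope scope;
  scope.vars["conv2d_0.w_0"] = Variable{};  // hypothetical weight name
  scope.GetVar("conv2d_0.w_0");             // ok
  // scope.GetVar("missing");  // would throw with the variable's name,
  //                           // mirroring PADDLE_ENFORCE_NOT_NULL above.
}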
paddle/fluid/framework/scope.h

@@ -81,6 +81,10 @@ class Scope {
   /// Caller doesn't own the returned Variable.
   Variable* FindVar(const std::string& name) const;

+  /// Get a variable in the scope or any of its ancestors. Enforce
+  /// the returned Variable is not nullptr.
+  Variable* GetVar(const std::string& name) const;
+
   /// Find a variable in the current scope.
   /// Return nullptr if cannot find.
   /// Caller doesn't own the returned Variable.
paddle/fluid/inference/api/analysis_config.cc

@@ -345,7 +345,7 @@ void AnalysisConfig::Update() {
       pass_builder()->ClearPasses();
       for (const auto &pass : kTRTSubgraphPasses) {
         if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
-            (pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) {
+            (pass == "conv_bn_fuse_pass")) {
           continue;
         }
         pass_builder()->AppendPass(pass);
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -77,6 +77,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "shuffle_channel_detect_pass",            //
       "quant_conv2d_dequant_fuse_pass",         //
       "delete_quant_dequant_op_pass",           //
+      "delete_quant_dequant_filter_op_pass",    //
       // "fc_fuse_pass",                        //
       "simplify_with_basic_ops_pass",           //
       "embedding_eltwise_layernorm_fuse_pass",  //

@@ -86,15 +87,16 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "conv_bn_fuse_pass",                      //
       "squeeze2_matmul_fuse_pass",              //
       "reshape2_matmul_fuse_pass",              //
+      "flatten2_matmul_fuse_pass",              //
       "map_matmul_to_mul_pass",                 //
       "fc_fuse_pass",                           //
+      "conv_elementwise_add_fuse_pass",         //
       "tensorrt_subgraph_pass",                 //
       "conv_bn_fuse_pass",                      //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                            // guaranteed at least v7
       "conv_elementwise_add_act_fuse_pass",     //
       "conv_elementwise_add2_act_fuse_pass",    //
-      "conv_elementwise_add_fuse_pass",         //
 #endif                                          //
       "transpose_flatten_concat_fuse_pass",
 });

@@ -118,6 +120,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "multihead_matmul_fuse_pass_v2",        //
         "squeeze2_matmul_fuse_pass",            //
         "reshape2_matmul_fuse_pass",            //
+        "flatten2_matmul_fuse_pass",            //
         "map_matmul_to_mul_pass",               //
         "fc_fuse_pass",                         //
         "fc_elementwise_layernorm_fuse_pass",   //

@@ -172,6 +175,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
         "seq_concat_fc_fuse_pass",              //
         "squeeze2_matmul_fuse_pass",            //
         "reshape2_matmul_fuse_pass",            //
+        "flatten2_matmul_fuse_pass",            //
         "map_matmul_to_mul_pass",               //
         "fc_fuse_pass",                         //
         "repeated_fc_relu_fuse_pass",           //
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

@@ -105,8 +105,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                 static_cast<void*>(weight_data),
                                 static_cast<size_t>(Y_t->numel())};
+  float* bias_data = nullptr;
+  size_t bias_size = 0;
+  if (op_desc.Type() == "conv2d_fusion") {
+    auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
+    auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
+    bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
+                                         bias_tensor_data, false);
+    bias_size = static_cast<size_t>(bias_tensor_data->numel());
+  }

-  TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
+  TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
+                              static_cast<void*>(bias_data), bias_size};
   auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output,
                            n_input, nv_ksize, weight, bias);
   PADDLE_ENFORCE_NOT_NULL(layer,

@@ -184,4 +194,5 @@ class Deconv2dOpConverter : public OpConverter {
 }  // namespace paddle

 REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
+REGISTER_TRT_OP_CONVERTER(conv2d_fusion, Conv2dOpConverter);
 REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter);
paddle/fluid/inference/tensorrt/convert/fc_op.cc

@@ -67,10 +67,11 @@ class FcOpConverter : public OpConverter {
     // assigned from CPU memory, which can't be avoided.
     float* weight_data = nullptr;
     bool enable_int8 = op_desc.HasAttr("enable_int8");
+    float in_scale = 0.;
     if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
       CHECK(op_desc.HasAttr(i_name + "_scale"));
-      float in_scale =
+      in_scale =
           BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
       auto weight_scale =
           BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));

@@ -131,7 +132,7 @@ class FcOpConverter : public OpConverter {
     float* bias_data = nullptr;
     int bias_num = 0;
     if (with_bias) {
-      auto* b_v = scope.FindVar(op_desc.Input("Bias").front());
+      auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
       auto* b_t = b_v->GetMutable<framework::LoDTensor>();
       bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t,
                                             false);

@@ -183,6 +184,9 @@ class FcOpConverter : public OpConverter {
       auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
       reshape_layer->setReshapeDimensions(reshape_dim);
       reshape_itensor = reshape_layer->getOutput(0);
+      if (enable_int8) {
+        engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
+      }
     } else {
       PADDLE_ENFORCE_NE(input_dims, 1,
                         platform::errors::InvalidArgument(

@@ -200,6 +204,9 @@ class FcOpConverter : public OpConverter {
       auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
       reshape_layer->setReshapeDimensions(reshape_dim);
       reshape_itensor = reshape_layer->getOutput(0);
+      if (enable_int8) {
+        engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
+      }
     }
     regist_fc(reshape_itensor, n_output, weight, bias);
   }
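Hoisting in_scale out of the #if block is what lets the two new SetTensorDynamicRange calls see it, and the * 127 multiplier is the inverse of the / 127. applied in delete_quant_dequant_op_pass: the attribute carries a per-step scale, while TensorRT wants the tensor's absolute maximum as a symmetric dynamic range. A sketch of that round trip (assuming the "_scale" attribute stores abs_max / 127):

#include <cstdio>

int main() {
  const float x_scale_attr = 0.02f;             // hypothetical "X_scale" value
  const float in_scale = x_scale_attr * 127.f;  // abs-max, here 2.54
  // SetTensorDynamicRange(reshape_itensor, in_scale) hands TensorRT the
  // symmetric range [-2.54, 2.54] for the reshaped input tensor.
  std::printf("dynamic range = [-%g, %g]\n", in_scale, in_scale);
}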
paddle/fluid/inference/tensorrt/op_teller.cc

@@ -58,6 +58,7 @@ struct SimpleOpTypeSetTeller : public Teller {
   // use this set for no calib int8.
   std::unordered_set<std::string> int8_teller_set{"mul",
                                                   "conv2d",
+                                                  "conv2d_fusion",
                                                   "pool2d",
                                                   "relu",
                                                   "depthwise_conv2d",

@@ -76,6 +77,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "mul",
       "matmul",
       "conv2d",
+      "conv2d_fusion",
       "pool2d",
       "relu",
       "softmax",