Commit 1b58ce14 (unverified)
Authored by Wangzheee on Apr 02, 2022
Committed by GitHub on Apr 02, 2022
[Paddle inference] support new quant_model (#41049)
* paddle inference support new quant_model
Parent commit: 4a09da02
Showing 40 changed files with 1146 additions and 285 deletions (+1146 -285).
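All of the passes touched by this commit revolve around the same symmetric linear int8 scheme: with bit_length bits the integer range is (1 << (bit_length - 1)) - 1 (127 for int8), a float value is quantized against a per-tensor or per-channel scale and rounded, and dequantized by multiplying back. The sketch below is illustrative only and is not taken from the Paddle sources; the helper names are made up, and the exact semantics of Paddle's quantize_linear/dequantize_linear kernels are defined by the operators themselves.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Hypothetical helpers illustrating symmetric linear quantization; not Paddle API.
int8_t QuantizeLinear(float x, float scale, int bit_length) {
  const int range = (1 << (bit_length - 1)) - 1;  // 127 for bit_length == 8
  float q = std::round(x / scale * range);
  q = std::min(std::max(q, -static_cast<float>(range)), static_cast<float>(range));
  return static_cast<int8_t>(q);
}

float DequantizeLinear(int8_t q, float scale, int bit_length) {
  const int range = (1 << (bit_length - 1)) - 1;
  return static_cast<float>(q) * scale / range;
}

int main() {
  const float scale = 2.5f;  // per-tensor dynamic range (max absolute value), assumed
  for (float x : {-2.5f, -1.0f, 0.3f, 2.5f}) {
    int8_t q = QuantizeLinear(x, scale, 8);
    std::cout << x << " -> " << static_cast<int>(q) << " -> "
              << DequantizeLinear(q, scale, 8) << "\n";
  }
  return 0;
}

Under this reading, the repeated "scale / range" expressions in the new passes convert a stored dynamic range into the per-step size that multiplies an int8 value.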
paddle/fluid/framework/ir/CMakeLists.txt                                              +2    -0
paddle/fluid/framework/ir/add_support_int8_pass.cc                                    +52   -9
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc                      +148  -0
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h                       +35   -0
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc                             +2    -8
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc                     +415  -0
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h                      +35   -0
paddle/fluid/framework/ir/fc_fuse_pass.cc                                             +20   -9
paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc                           +6    -13
paddle/fluid/framework/ir/graph_pattern_detector.cc                                   +84   -17
paddle/fluid/framework/ir/graph_pattern_detector.h                                    +34   -2
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc                               +19   -32
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc                           +4    -7
paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc                               +88   -13
paddle/fluid/inference/api/paddle_pass_builder.cc                                     +9    -7
paddle/fluid/inference/tensorrt/convert/activation_op.cc                              +0    -6
paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc                          +2    -2
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc                                  +3    -10
paddle/fluid/inference/tensorrt/convert/conv3d_op.cc                                  +2    -9
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc                         +1    -2
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc                             +1    -19
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc                      +1    -1
paddle/fluid/inference/tensorrt/convert/fc_op.cc                                      +33   -27
paddle/fluid/inference/tensorrt/convert/group_norm_op.cc                              +1    -1
paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc                              +2    -2
paddle/fluid/inference/tensorrt/convert/matmul_op.cc                                  +3    -1
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc                        +20   -26
paddle/fluid/inference/tensorrt/convert/op_converter.h                                +57   -31
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc                                  +3    -4
paddle/fluid/inference/tensorrt/convert/pool3d_op.cc                                  +3    -2
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc                +1    -1
paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc                       +1    -1
paddle/fluid/inference/tensorrt/convert/prelu_op.cc                                   +2    -2
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc                             +1    -1
paddle/fluid/inference/tensorrt/engine.cc                                             +1    -3
paddle/fluid/inference/tensorrt/engine.h                                              +1    -2
paddle/fluid/operators/compat/dequantize_linear.pbtxt                                 +25   -0
paddle/fluid/operators/compat/mul.pbtxt                                               +1    -9
paddle/fluid/operators/compat/quantize_linear.pbtxt                                   +25   -0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py +3    -6
paddle/fluid/framework/ir/CMakeLists.txt

@@ -86,6 +86,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
 pass_library(delete_quant_dequant_filter_op_pass inference)
+pass_library(delete_weight_dequant_linear_op_pass inference)
+pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
paddle/fluid/framework/ir/add_support_int8_pass.cc

-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.

@@ -19,11 +19,7 @@ namespace framework {
 namespace ir {

 #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES          \
-  GET_IR_NODE(prev_op);    \
-  GET_IR_NODE(prev_out);   \
-  GET_IR_NODE(quant_op);   \
-  GET_IR_NODE(quant_out);
+#define GET_NODES GET_IR_NODE(quant_op);

 void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "add_support_int8";

@@ -37,10 +33,57 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     GET_NODES;
-    if (prev_op->Op()->HasAttr("out_threshold") &&
-        quant_op->Op()->HasAttr("out_threshold")) {
-      quant_op->Op()->SetAttr("support_int8", true);
-    }
+    bool inscale_flag = false;
+    bool outscale_flag = false;
+    auto* quanted_op_desc = quant_op->Op();
+    // If inputs'tensors have the inputs_scale, then save it's index in
+    // input_quant_tensor_index
+    // OP'Attr hasn't std::vector<std::pair< >>. To do: Support multi-tensor
+    // scale for one input
+    for (size_t i = 0; i < quanted_op_desc->InputNames().size(); i++) {
+      if (quanted_op_desc->Input(quanted_op_desc->InputNames()[i]).size() > 0 &&
+          quanted_op_desc->HasAttr(
+              "Input_scale_" +
+              quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])) {
+        inscale_flag = true;
+        quanted_op_desc->SetAttr(
+            quanted_op_desc->InputNames()[i],
+            quanted_op_desc->GetAttr(
+                "Input_scale_" +
+                quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0]));
+      }
+    }
+    // If outputs'tensors have the outputs_scale, then save it's index in
+    // output_quant_tensor_index
+    // OP'Attr hasn't std::vector<std::pair< >>. To do: Support multi-tensor
+    // scale for one output
+    for (auto out_node : quant_op->outputs) {
+      for (auto out_op_node : out_node->outputs) {
+        for (auto name : out_op_node->Op()->InputNames()) {
+          for (auto input_name : out_op_node->Op()->Input(name)) {
+            if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) {
+              for (size_t i = 0; i < quanted_op_desc->OutputNames().size();
+                   i++) {
+                if (quanted_op_desc->Output(quanted_op_desc->OutputNames()[i])
+                            .size() > 0 &&
+                    input_name ==
+                        quanted_op_desc->Output(
+                            quanted_op_desc->OutputNames()[i])[0]) {
+                  outscale_flag = true;
+                  quanted_op_desc->SetAttr(
+                      quanted_op_desc->OutputNames()[i],
+                      out_op_node->Op()->GetAttr("Input_scale_" + input_name));
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    quanted_op_desc->SetAttr("support_int8", inscale_flag && outscale_flag);
+    quanted_op_desc->Flush();
     found_count++;
   };
   gpd(graph, handler);
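The rewritten pass copies scale attributes named "Input_scale_" + <variable name> (recorded earlier by delete_quant_dequant_linear_op_pass) onto attributes named after the op's own input and output slots, and sets support_int8 only when both an input scale and an output scale were found. The following is a toy sketch of that bookkeeping using plain std::map stand-ins; ToyOpDesc and MarkSupportInt8 are invented for illustration and are not Paddle's OpDesc API.

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an op description: slot name -> variable names, plus float attrs.
struct ToyOpDesc {
  std::map<std::string, std::vector<std::string>> inputs;   // e.g. {"Input": {"x0"}}
  std::map<std::string, std::vector<std::string>> outputs;  // e.g. {"Out": {"y0"}}
  std::map<std::string, float> attrs;                       // e.g. {"Input_scale_x0": 0.05f}
};

bool MarkSupportInt8(ToyOpDesc* op,
                     const std::map<std::string, float>& consumer_in_scales) {
  bool inscale_flag = false, outscale_flag = false;
  // Copy "Input_scale_<var>" attrs onto attrs named after the input slot.
  for (auto& kv : op->inputs) {
    if (kv.second.empty()) continue;
    auto it = op->attrs.find("Input_scale_" + kv.second[0]);
    if (it != op->attrs.end()) {
      op->attrs[kv.first] = it->second;
      inscale_flag = true;
    }
  }
  // An output scale comes from a consumer op that recorded a scale for our output var.
  for (auto& kv : op->outputs) {
    if (kv.second.empty()) continue;
    auto it = consumer_in_scales.find("Input_scale_" + kv.second[0]);
    if (it != consumer_in_scales.end()) {
      op->attrs[kv.first] = it->second;
      outscale_flag = true;
    }
  }
  return inscale_flag && outscale_flag;  // becomes the "support_int8" attribute
}

int main() {
  ToyOpDesc fc{{{"Input", {"x0"}}}, {{"Out", {"y0"}}}, {{"Input_scale_x0", 0.05f}}};
  std::map<std::string, float> consumer = {{"Input_scale_y0", 0.08f}};
  std::cout << "support_int8 = " << MarkSupportInt8(&fc, consumer) << "\n";  // prints 1
  return 0;
}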
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                        \
  GET_IR_NODE(quantize_linear_op_x);     \
  GET_IR_NODE(quantize_linear_op_scale); \
  GET_IR_NODE(quantize_linear_op);       \
  GET_IR_NODE(quantize_linear_op_out);   \
  GET_IR_NODE(dequantize_linear_op);     \
  GET_IR_NODE(dequantize_linear_op_out); \
  GET_IR_NODE(any_op2);

DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
  AddOpCompat(OpCompat("quantize_linear"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scale").IsTensor().End()
      .AddInput("ZeroPoint").IsTensor().IsOptional().End()
      .AddOutput("Y").IsTensor().End()
      .AddAttr("bit_length").IsType<int>().End()
      .AddAttr("quant_axis").IsType<int>().End();
  AddOpCompat(OpCompat("dequantize_linear"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scale").IsTensor().End()
      .AddInput("ZeroPoint").IsTensor().IsOptional().End()
      .AddOutput("Y").IsTensor().End()
      .AddAttr("bit_length").IsType<int>().End()
      .AddAttr("quant_axis").IsType<int>().End();
}

// Delete quantize_linear_op dequantize_linear_op, then add input_scales
void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "delete_quantdequant_linear_op_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::InvalidArgument(
          "Scope in DeleteQuantDequantLinearOpPass should not be null."));
  // Create pattern
  patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(),
                                                      pattern_name);
  pattern();
  int found_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    /*
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "delete_quant_dequant_linear_op_pass "
                      "compat check failed.";
      return;
    }
    */
    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length =
        BOOST_GET_CONST(int, quantize_linear_op->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);

    // Get input scale from tensor
    const LoDTensor& input_scale_tensor =
        scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
    PADDLE_ENFORCE_EQ(
        paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
        platform::errors::InvalidArgument(
            "Input scale tensor's place should be CPU."));
    const float* input_scale_data = input_scale_tensor.data<float>();
    float input_scale = input_scale_data[0] / range;

    auto* any_op2_desc = any_op2->Op();
    any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
                          input_scale);

    nodes2rm.insert(quantize_linear_op_scale);
    nodes2rm.insert(quantize_linear_op);
    nodes2rm.insert(quantize_linear_op_out);
    nodes2rm.insert(dequantize_linear_op);
    nodes2rm.insert(dequantize_linear_op_out);

    // link x to any_op2
    any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
                              quantize_linear_op_x->Var()->Name());
    any_op2_desc->Flush();
    IR_NODE_LINK_TO(quantize_linear_op_x, any_op2);
    GraphSafeRemoveNodes(graph, nodes2rm);
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_quant_dequant_linear_op_pass,
              paddle::framework::ir::DeleteQuantDequantLinearOpPass);
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class DeleteQuantDequantLinearOpPass : public FusePassBase {
 public:
  DeleteQuantDequantLinearOpPass();
  virtual ~DeleteQuantDequantLinearOpPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc

@@ -61,7 +61,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     GET_NODES;
     int bit_length =
         BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
-    int range = ((1 << (bit_length - 1)) - 1);

     // Get input scale from tensor
     std::string input_scale_var_name =

@@ -76,7 +75,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
         platform::errors::InvalidArgument(
             "Input scale tensor's place should be CPU."));
     const float* input_scale_data = input_scale_tensor.data<float>();
-    float input_scale = input_scale_data[0] / range;
+    float input_scale = input_scale_data[0];

     // Set input scale in attr, and relink nodes
     std::string input_name = input->Var()->Name();

@@ -85,12 +84,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
     for (auto* quantized_node : outlinks) {
       auto op_desc = quantized_node->Op();
       std::string quantized_op_type = op_desc->Type();
-      if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-          quantized_op_type == "matmul_v2") {
-        op_desc->SetAttr("X_scale", input_scale);
-      } else {
-        op_desc->SetAttr("Input_scale", input_scale);
-      }
+      op_desc->SetAttr("Input_scale", input_scale);
       op_desc->SetAttr("bit_length", bit_length);
       op_desc->RenameInput(quant_dequant_output_name, input_name);
       op_desc->Flush();
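The old fake-quant pass used to store input_scale_data[0] / range and the TensorRT converters multiplied the attribute back by 127; after this change the pass forwards the stored value unchanged and the converters (see conv2d_op.cc and conv3d_op.cc below) drop the matching "* 127". Assuming symmetric int8 quantization, the two conventions carry the same information, as the small standalone check below illustrates; it is not Paddle code.

#include <cassert>
#include <cmath>

int main() {
  const int bit_length = 8;
  const int range = (1 << (bit_length - 1)) - 1;  // 127
  const float stored = 6.35f;                     // value held in the scale tensor

  // Old convention: pass divides by range, converter multiplies by 127.
  float old_attr = stored / range;
  float old_dynamic_range = old_attr * 127;

  // New convention: pass forwards the value unchanged, converter uses it directly.
  float new_attr = stored;
  float new_dynamic_range = new_attr;

  // Both paths hand the engine the same dynamic range (up to float rounding).
  assert(std::fabs(old_dynamic_range - new_dynamic_range) < 1e-4f);
  return 0;
}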
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace paddle {
namespace framework {
namespace ir {

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                                 \
  GET_IR_NODE(weight_dequantize_linear_op_x);     \
  GET_IR_NODE(weight_dequantize_linear_op_scale); \
  GET_IR_NODE(weight_dequantize_linear_op);       \
  GET_IR_NODE(weight_dequantize_linear_op_out);   \
  GET_IR_NODE(any_op2);

DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
  AddOpCompat(OpCompat("quantize_linear"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scale").IsTensor().End()
      .AddInput("ZeroPoint").IsTensor().IsOptional().End()
      .AddOutput("Y").IsTensor().End()
      .AddAttr("bit_length").IsType<int>().End()
      .AddAttr("quant_axis").IsType<int>().End();
  AddOpCompat(OpCompat("dequantize_linear"))
      .AddInput("X").IsTensor().End()
      .AddInput("Scale").IsTensor().End()
      .AddInput("ZeroPoint").IsTensor().IsOptional().End()
      .AddOutput("Y").IsTensor().End()
      .AddAttr("bit_length").IsType<int>().End()
      .AddAttr("quant_axis").IsType<int>().End();
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input").IsTensor().End()
      .AddInput("Filter").IsTensor().End()
      .AddInput("Bias").IsTensor().IsOptional().End()
      .AddInput("ResidualData").IsTensor().IsOptional().End()
      .AddOutput("Output").IsTensor().End()
      .AddAttr("strides").IsType<std::vector<int>>().End()
      .AddAttr("paddings").IsType<std::vector<int>>().End()
      .AddAttr("padding_algorithm").IsOptional().IsStringIn({"EXPLICIT", "SAME", "VALID"}).End()
      .AddAttr("groups").IsNumGE(1).End()
      .AddAttr("dilations").IsType<std::vector<int>>().End()
      .AddAttr("data_format").IsStringIn({"NCHW", "NHWC", "AnyLayout"}).End();
  AddOpCompat(OpCompat("depthwise_conv2d"))
      .AddInput("Input").IsTensor().End()
      .AddInput("Filter").IsTensor().End()
      .AddInput("Bias").IsTensor().IsOptional().End()
      .AddInput("ResidualData").IsTensor().IsOptional().End()
      .AddOutput("Output").IsTensor().End()
      .AddAttr("strides").IsType<std::vector<int>>().End()
      .AddAttr("paddings").IsType<std::vector<int>>().End()
      .AddAttr("padding_algorithm").IsOptional().IsStringIn({"EXPLICIT", "SAME", "VALID"}).End()
      .AddAttr("groups").IsNumGE(1).End()
      .AddAttr("dilations").IsType<std::vector<int>>().End()
      .AddAttr("data_format").IsStringIn({"NCHW", "NHWC", "AnyLayout"}).End();
  AddOpCompat(OpCompat("mul"))
      .AddInput("X").IsTensor().End()
      .AddInput("Y").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("x_num_col_dims").IsNumGE(1).End()
      .AddAttr("y_num_col_dims").IsNumEQ(1).End();
  AddOpCompat(OpCompat("matmul_v2"))
      .AddInput("X").IsTensor().End()
      .AddInput("Y").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("trans_x").IsBoolEQ(false).End()
      .AddAttr("trans_y").IsBoolEQ(false).End();
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X").IsTensor().End()
      .AddInput("Y").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("alpha").IsNumGE(0.99f).IsNumLE(1.01f).End()
      .AddAttr("transpose_X").IsBoolEQ(false).End()
      .AddAttr("transpose_Y").IsBoolEQ(false).End();
  AddOpCompat(OpCompat("fc"))
      .AddInput("Input").IsTensor().End()
      .AddInput("W").IsTensor().End()
      .AddInput("Bias").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("in_num_col_dims").IsNumGE(1).End()
      .AddAttr("activation_type").IsStringIn({"relu", ""}).End();
  AddOpCompat(OpCompat("conv2d_transpose"))
      .AddInput("Input").IsTensor().End()
      .AddInput("Filter").IsTensor().End()
      .AddInput("Bias").IsTensor().IsOptional().End()
      .AddOutput("Output").IsTensor().End()
      .AddAttr("output_padding").IsType<std::vector<int>>().IsOptional().End()
      .AddAttr("output_size").IsType<std::vector<int>>().IsOptional().End()
      .AddAttr("groups").IsNumGE(1).End()
      .AddAttr("dilations").IsType<std::vector<int>>().End()
      .AddAttr("strides").IsType<std::vector<int>>().End()
      .AddAttr("paddings").IsType<std::vector<int>>().End()
      .AddAttr("padding_algorithm").IsOptional().IsStringIn({"EXPLICIT", "SAME", "VALID"}).End()
      .AddAttr("data_format").IsStringIn({"NCHW", "NHWC", "AnyLayout"}).End();
}

// Delete dequantize_linear_op, then dequantize weight
void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name =
      "delete_weight_quantdequant_linear_op_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::InvalidArgument(
          "Scope in DeleteWeightQuantDequantLinearOpPass should not be null."));
  // Create pattern
  patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
      gpd.mutable_pattern(), pattern_name);
  pattern();
  int found_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    /*
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
                      "compat check failed.";
      return;
    }
    */
    std::unordered_set<const Node*> nodes2rm = {};
    int bit_length = BOOST_GET_CONST(
        int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
    int range = ((1 << (bit_length - 1)) - 1);

    auto* any_op2_desc = any_op2->Op();

    // get weight tensor
    auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
                              ->GetMutable<LoDTensor>();
    int8_t* quantized_weight_data =
        weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
    auto w_dims = weight_tensor->dims();

    // Get weight scale
    std::vector<float> weight_scale;
    auto* weight_scale_tensor =
        scope->GetVar(weight_dequantize_linear_op_scale->Name())
            ->GetMutable<LoDTensor>();
    float* weight_scale_data =
        weight_scale_tensor->mutable_data<float>(platform::CPUPlace());
    auto weight_scale_nums = weight_scale_tensor->numel();
    for (int i = 0; i < weight_scale_nums; i++) {
      weight_scale.push_back(weight_scale_data[i] / range);
    }

    // dequant weight
    std::vector<float> weight_data_tmp;
    weight_data_tmp.reserve(weight_tensor->numel());

    int quant_axis = BOOST_GET_CONST(
        int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
    if (quant_axis == -1) {  // per_layer quant_dequant: all OP
      PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
                        platform::errors::InvalidArgument(
                            "When quant_axis == -1 means use per_layer "
                            "quant_dequant, weight_scale'number should be 1."));

      //  float(weight) * scale
      for (int i = 0; i < weight_tensor->numel(); i++) {
        weight_data_tmp[i] =
            static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
      }
    } else if (quant_axis == 0) {  // per_channel quant_dequant: conv2d,
                                   // depthwise_conv2d, conv2d_fusion
      PADDLE_ENFORCE_EQ(
          weight_scale_nums, w_dims[quant_axis],
          platform::errors::InvalidArgument(
              "When quant_axis == 0 means use per_channel quant_dequant, "
              "weight_scale'numbers should be equal channels."));
      PADDLE_ENFORCE_EQ(w_dims.size(), 4,
                        platform::errors::InvalidArgument(
                            "When quant_axis == 0 means use per_channel "
                            "quant_dequant, (conv2d, depthwise_conv2d, "
                            "conv2d_fusion)'s weight dims should be 4."));

      for (int i = 0; i < weight_tensor->numel(); i++) {
        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
        weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
                             weight_scale[i / inner_size];
      }
    } else if (quant_axis == 1) {
      PADDLE_ENFORCE_EQ(
          weight_scale_nums, w_dims[quant_axis],
          platform::errors::InvalidArgument(
              "When quant_axis == 1 means use per_channel quant_dequant, "
              "weight_scale'numbers should be equal channels."));

      if (w_dims.size() == 4) {  // conv2d_transpose
        std::string quantized_op_type = any_op2->Op()->Type();
        PADDLE_ENFORCE_EQ(
            quantized_op_type, "conv2d_transpose",
            platform::errors::InvalidArgument(
                "When quant_axis == 1 means use per_channel quant_dequant, "
                "only conv2d_transpose weight dims equal 4."));
        for (int i = 0; i < weight_tensor->numel(); i++) {
          int inner_size = w_dims[2] * w_dims[3];
          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
                               weight_scale[(i / inner_size) % w_dims[1]];
        }
      } else if (w_dims.size() == 2) {
        for (int i = 0; i < weight_tensor->numel(); i++) {
          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
                               weight_scale[i % w_dims[1]];
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When quant_axis == 1 , weight dims should be 2 or 4, please check "
            "your model "));
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "quant_axis should be -1 or 0 or 1, please check your model "
          "OP'attribute "));
    }
    weight_tensor->clear();  // clear int weight
    weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
    float* new_quantized_weight_data =
        weight_tensor->mutable_data<float>(platform::CPUPlace());
    memcpy(new_quantized_weight_data, weight_data_tmp.data(),
           weight_tensor->numel() * sizeof(float));

    nodes2rm.insert(weight_dequantize_linear_op_scale);
    nodes2rm.insert(weight_dequantize_linear_op);
    nodes2rm.insert(weight_dequantize_linear_op_out);

    // relink weight to any_op2
    any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
                              weight_dequantize_linear_op_x->Var()->Name());
    any_op2_desc->Flush();
    IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
    GraphSafeRemoveNodes(graph, nodes2rm);
    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(delete_weight_dequant_linear_op_pass,
              paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass);
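The indexing above maps each flattened weight element to its channel scale: for a 4-D [C_out, C_in, H, W] filter quantized along axis 0 the stride is C_in * H * W, and along axis 1 (the conv2d_transpose case) it is H * W modulo C_in. Below is a standalone sketch of the axis-0 case with a made-up 2x2x1x1 filter; the values are invented for illustration.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical [C_out=2, C_in=2, H=1, W=1] int8 filter and per-output-channel scales.
  const int c_in = 2, h = 1, w = 1;
  std::vector<int8_t> quantized = {127, -64, 32, 0};
  std::vector<float> scale = {0.5f, 0.25f};  // already divided by range, one per output channel

  const int inner_size = c_in * h * w;  // stride of one output channel (quant_axis == 0)
  std::vector<float> dequantized(quantized.size());
  for (size_t i = 0; i < quantized.size(); ++i) {
    dequantized[i] = static_cast<float>(quantized[i]) * scale[i / inner_size];
  }
  for (float v : dequantized) std::cout << v << " ";  // 63.5 -32 8 0
  std::cout << "\n";
  return 0;
}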
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {
namespace framework {
namespace ir {

class DeleteWeightQuantDequantLinearOpPass : public FusePassBase {
 public:
  DeleteWeightQuantDequantLinearOpPass();
  virtual ~DeleteWeightQuantDequantLinearOpPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/ir/fc_fuse_pass.cc

@@ -226,23 +226,34 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   // For anakin subgraph int8
   // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
   // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
-  // will add "input_scale", "weight_scale" which are extracted from
+  // will add "input_scale" which are extracted from
   // fake_quant op and fake_dequant op to mul op, and then delete the
   // fake_quant op and fake_dequant op in the graph. If the mul op has the
   // scale info, we should add those to the fused fc.
   auto* mul_op_desc = mul->Op();
+  auto* elementwise_add_op_desc = elementwise_add->Op();
   if (mul_op_desc->HasAttr("enable_int8")) {
     desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
-    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale"));
-    desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
-    if (mul_op_desc->HasAttr("out_scale"))
-      desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
-    auto elementwise_desc = elementwise_add->Op();
-    if (elementwise_desc->HasAttr("out_scale"))
-      desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
   }
-  auto* elementwise_add_op_desc = elementwise_add->Op();
+
+  if (mul_op_desc->HasAttr("Input_scale")) {
+    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("Input_scale"));
+  }
+  bool inscale_flag = false;
+  bool outscale_flag = false;
+  if (mul_op_desc->HasAttr("X")) {
+    desc.SetAttr("X", mul_op_desc->GetAttr("X"));
+    inscale_flag = true;
+  }
+  if (elementwise_add_op_desc->HasAttr("Out")) {
+    desc.SetAttr("Out", elementwise_add_op_desc->GetAttr("Out"));
+    outscale_flag = true;
+  }
+  desc.SetAttr("support_int8", inscale_flag && outscale_flag);

   // if we can find out_threshold in elementwise_add, then set it as the
   // out_thrshold of fc
   auto out_threshold_attr =
paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc

@@ -298,8 +298,7 @@ void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }

@@ -372,9 +371,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold"));
     }

@@ -451,8 +448,7 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
     }
     if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold"));
     }

@@ -532,8 +528,7 @@ void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }

@@ -677,8 +672,7 @@ void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }

@@ -765,8 +759,7 @@ void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -2949,6 +2949,84 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
   any_op2->LinksFrom({quant_dequant_out});
 }

+void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() {
+  auto weight_dequantize_linear_op_x =
+      pattern->NewNode(weight_dequantize_linear_op_x_repr())
+          ->AsInput()
+          ->assert_is_op_input("dequantize_linear", "X")
+          ->assert_is_persistable_var();
+
+  auto weight_dequantize_linear_op_scale =
+      pattern->NewNode(weight_dequantize_linear_op_scale_repr())
+          ->AsInput()
+          ->assert_is_op_input("dequantize_linear", "Scale")
+          ->assert_is_persistable_var();
+
+  auto weight_dequantize_linear_op =
+      pattern->NewNode(weight_dequantize_linear_op_repr())
+          ->assert_is_op("dequantize_linear");
+
+  auto weight_dequantize_linear_op_out =
+      pattern->NewNode(weight_dequantize_linear_op_out_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("dequantize_linear", "Y");
+
+  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
+
+  weight_dequantize_linear_op
+      ->LinksFrom(
+          {weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
+      .LinksTo({weight_dequantize_linear_op_out});
+  any_op2->LinksFrom({weight_dequantize_linear_op_out});
+}
+
+void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
+  auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
+                                  ->AsInput()
+                                  ->assert_is_op_input("quantize_linear", "X");
+
+  auto quantize_linear_op_scale =
+      pattern->NewNode(quantize_linear_op_scale_repr())
+          ->AsInput()
+          ->assert_is_op_input("quantize_linear", "Scale")
+          ->assert_is_persistable_var();
+
+  auto quantize_linear_op = pattern->NewNode(quantize_linear_op_repr())
+                                ->assert_is_op("quantize_linear");
+
+  auto quantize_linear_op_out =
+      pattern->NewNode(quantize_linear_op_out_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("quantize_linear", "Y")
+          ->assert_is_op_input("dequantize_linear", "X")
+          ->assert_var_not_persistable();
+
+  // Can not add this node. Todo: Wangzheee
+  /*
+  auto dequantize_linear_op_scale =
+      pattern->NewNode(dequantize_linear_op_scale_repr())
+          ->assert_is_op_input("dequantize_linear", "Scale")
+          ->AsIntermediate();
+  */
+
+  auto dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr())
+                                  ->assert_is_op("dequantize_linear");
+
+  auto dequantize_linear_op_out =
+      pattern->NewNode(dequantize_linear_op_out_repr())
+          ->AsIntermediate()
+          ->assert_is_op_output("dequantize_linear", "Y");
+
+  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
+
+  quantize_linear_op
+      ->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
+      .LinksTo({quantize_linear_op_out});
+  dequantize_linear_op->LinksFrom({quantize_linear_op_out})
+      .LinksTo({dequantize_linear_op_out});
+  any_op2->LinksFrom({dequantize_linear_op_out});
+}
+
 PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
     const std::string &op_name, bool with_reshape_xshape,
     bool with_transpose_xshape) {

@@ -3311,25 +3389,14 @@ PDNode *patterns::LayerNorm::operator()() {
   return shift_out;
 }

-// Add support int8 flag
+// Add support int8 flag and out_threshold
 PDNode *patterns::AddSupportInt8::operator()() {
-  auto prev_op = pattern->NewNode(prev_op_repr())
-                     ->assert_is_op()
-                     ->assert_more([&](Node *node) {
-                       return node->Op()->HasAttr("out_threshold") ? true : false;
-                     });
-  auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var();
-  auto quant_op = pattern->NewNode(quant_op_repr())
-                      ->assert_is_op()
-                      ->assert_more([&](Node *node) {
-                        return node->Op()->HasAttr("out_threshold") ? true : false;
-                      });
-  auto quant_out = pattern->NewNode(quant_out_repr())
-                       ->assert_is_var()
-                       ->AsOutput();
-  prev_op->LinksTo({prev_out});
-  prev_out->LinksTo({quant_op});
+  auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op();
+  auto quant_out =
+      pattern->NewNode(quant_out_repr())
+          ->assert_is_var()
+          ->assert_more([&](Node *node) { return node->outputs.size() > 0; })
+          ->AsOutput();
   quant_op->LinksTo({quant_out});
   return quant_out;
 }
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -1702,6 +1702,40 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
   PATTERN_DECL_NODE(any_op2);
 };

+struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
+  DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
+                                          const std::string& name_scope)
+      : PatternBase(pattern, name_scope,
+                    "delete_weight_quant_dequant_linear_op_pattern") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
+  PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
+  PATTERN_DECL_NODE(weight_dequantize_linear_op);
+  PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
+  PATTERN_DECL_NODE(any_op2);
+};
+
+struct DeleteQuantDequantLinearOpPattern : public PatternBase {
+  DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
+                                    const std::string& name_scope)
+      : PatternBase(pattern, name_scope,
+                    "delete_quant_dequant_linear_op_pattern") {}
+
+  void operator()();
+
+  PATTERN_DECL_NODE(quantize_linear_op_x);
+  PATTERN_DECL_NODE(quantize_linear_op_scale);
+  PATTERN_DECL_NODE(quantize_linear_op);
+  PATTERN_DECL_NODE(quantize_linear_op_out);
+  PATTERN_DECL_NODE(dequantize_linear_op);
+  // PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node.
+  // Todo: Wangzheee
+  PATTERN_DECL_NODE(dequantize_linear_op_out);
+  PATTERN_DECL_NODE(any_op2);
+};
+
 // Reshape + Transpose + Matmul
 // named nodes:
 // reshape_op, reshape_out, reshape_xshape,

@@ -1887,8 +1921,6 @@ struct AddSupportInt8 : public PatternBase {
       : PatternBase(pattern, name_scope, "Add_support_int8") {}

   PDNode* operator()();
-  PATTERN_DECL_NODE(prev_op);
-  PATTERN_DECL_NODE(prev_out);
   PATTERN_DECL_NODE(quant_op);
   PATTERN_DECL_NODE(quant_out);
 };
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc

@@ -862,43 +862,30 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
     multihead_op_desc.SetAttr("head_number", head_number);

     auto* mul0_op_desc = mul0->Op();
-    auto* mul1_op_desc = mul1->Op();
-    auto* mul2_op_desc = mul2->Op();
-    if (mul0_op_desc->HasAttr("enable_int8")) {
-      multihead_op_desc.SetAttr("enable_int8",
-                                mul0_op_desc->GetAttr("enable_int8"));
-      // all mul op has same input.
-      multihead_op_desc.SetAttr("Input_scale",
-                                mul0_op_desc->GetAttr("X_scale"));
-      auto weight_scale0 = BOOST_GET_CONST(
-          std::vector<float>, mul0_op_desc->GetAttr("weight_scale"));
-      auto weight_scale1 = BOOST_GET_CONST(
-          std::vector<float>, mul1_op_desc->GetAttr("weight_scale"));
-      auto weight_scale2 = BOOST_GET_CONST(
-          std::vector<float>, mul2_op_desc->GetAttr("weight_scale"));
-      auto weight_max = std::max(weight_scale0, weight_scale1);
-      weight_max = std::max(weight_max, weight_scale2);
-      multihead_op_desc.SetAttr("weight_scale", weight_max);
-
-      auto* add0_op_desc = eltadd0->Op();
-      auto* add1_op_desc = eltadd1->Op();
-      auto* add2_op_desc = eltadd2->Op();
-      if (add0_op_desc->HasAttr("out_threshold")) {
-        auto out_scale0 =
-            BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
-        auto out_scale1 =
-            BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
-        auto out_scale2 =
-            BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
-        auto out_scale_max = std::max(out_scale0, out_scale1);
-        out_scale_max = std::max(out_scale_max, out_scale2);
-        multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
-      }
+    // all mul op has same input.
+    if (multihead_op_desc.HasAttr("Input_scale")) {
+      multihead_op_desc.SetAttr("Input_scale",
+                                mul0_op_desc->GetAttr("Input_scale"));
+    }
+    auto* add0_op_desc = eltadd0->Op();
+    auto* add1_op_desc = eltadd1->Op();
+    auto* add2_op_desc = eltadd2->Op();
+    if (add0_op_desc->HasAttr("out_threshold")) {
+      auto out_scale0 =
+          BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
+      auto out_scale1 =
+          BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
+      auto out_scale2 =
+          BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
+      auto out_scale_max = std::max(out_scale0, out_scale1);
+      out_scale_max = std::max(out_scale_max, out_scale2);
+      multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
     }

     auto* softmax_qk_op_desc = softmax_qk->Op();
     auto* matmul_qk_op_desc = matmul_qk->Op();
-    if (matmul_qk_op_desc->HasAttr("X_scale")) {
+    if (matmul_qk_op_desc->HasAttr("Input_scale")) {
       multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
       if (softmax_qk_op_desc->HasAttr("out_threshold")) {
         auto qkv_plugin_scale = BOOST_GET_CONST(
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc

@@ -341,7 +341,6 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
   Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
   Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
   int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
-  int range = ((1 << (bit_length - 1)) - 1);

   // Get input scale from tensor
   std::string input_scale_var_name = quant->Op()->Input("InScale").front();

@@ -356,7 +355,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
           "Input scale tensor's place should be CPU."));
   const float* input_scale_data = input_scale_tensor.data<float>();
   float in_scale = input_scale_data[0];
-  float scale_value = in_scale / range;
+  float scale_value = in_scale;

   // Set input scale in attr, and relink nodes
   std::string input_act_name = input_act->Var()->Name();

@@ -369,11 +368,10 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
     if (quantized_op_type == "conv2d" ||
         quantized_op_type == "conv2d_fusion" ||
         quantized_op_type == "depthwise_conv2d" ||
         quantized_op_type == "fc" ||
-        quantized_op_type == "conv2d_transpose") {
+        quantized_op_type == "conv2d_transpose" ||
+        quantized_op_type == "mul" || quantized_op_type == "matmul" ||
+        quantized_op_type == "matmul_v2") {
       op_desc->SetAttr("Input_scale", scale_value);
-    } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-               quantized_op_type == "matmul_v2") {
-      op_desc->SetAttr("X_scale", scale_value);
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported quantized op type %s.", quantized_op_type));

@@ -619,7 +617,6 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
     new_op_desc.SetInput("X", {new_input});
     new_op_desc.SetOutput("Out", {new_output});
   }
-  new_op_desc.SetAttr("weight_scale", weight_scale);
   new_op_desc.Flush();
   auto* new_op = graph->CreateOpNode(&new_op_desc);
   IR_NODE_LINK_TO(quantized_op_input_node, new_op);
paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc

@@ -297,11 +297,24 @@ void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y"));
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag = false;
+    bool outscale_flag = false;
+    if (matmul_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", matmul_op->Op()->GetAttr("X"));
+      inscale_flag = true;
+    }
+    if (matmul_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag && outscale_flag);
+
     auto mul_node = g->CreateOpNode(&desc);
     IR_NODE_LINK_TO(matmul_in_x, mul_node);
     IR_NODE_LINK_TO(matmul_in_y, mul_node);

@@ -370,12 +383,23 @@ void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
     if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag = false;
+    bool outscale_flag = false;
+    if (matmul_v2_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
+      inscale_flag = true;
+    }
+    if (matmul_v2_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag && outscale_flag);
+
     auto mul_node = g->CreateOpNode(&desc);
     IR_NODE_LINK_TO(matmul_v2_in_x, mul_node);
     IR_NODE_LINK_TO(matmul_v2_in_y, mul_node);

@@ -448,11 +472,23 @@ void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
     }
     if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_v2_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag = false;
+    bool outscale_flag = false;
+    if (matmul_v2_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
+      inscale_flag = true;
+    }
+    if (matmul_v2_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag && outscale_flag);
+
     auto matmul_node = g->CreateOpNode(&desc);
     IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node);
     IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node);

@@ -530,11 +566,24 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag_x = false;
+    bool outscale_flag = false;
+    if (squeeze2_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", squeeze2_op->Op()->GetAttr("X"));
+      inscale_flag_x = true;
+    }
+    if (matmul_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
+
     auto mul_node = g->CreateOpNode(&desc);
     IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
     IR_NODE_LINK_TO(matmul_in_y, mul_node);

@@ -675,11 +724,24 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag_x = false;
+    bool outscale_flag = false;
+    if (reshape2_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", reshape2_op->Op()->GetAttr("X"));
+      inscale_flag_x = true;
+    }
+    if (matmul_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
+
     if (!IsCompat(desc)) {
       LOG(WARNING) << "TrtReshape2MatmulFusePass in out mul op compat failed.";

@@ -763,11 +825,24 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
     desc.SetAttr("y_num_col_dims", 1);
     if (matmul_op->Op()->HasAttr("enable_int8")) {
       desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-      desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-      desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
+      desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
       desc.SetAttr("out_threshold", matmul_op->Op()->GetAttr("out_threshold"));
     }
+
+    bool inscale_flag_x = false;
+    bool outscale_flag = false;
+    if (flatten2_op->Op()->HasAttr("X")) {
+      desc.SetAttr("X", flatten2_op->Op()->GetAttr("X"));
+      inscale_flag_x = true;
+    }
+    if (matmul_op->Op()->HasAttr("Out")) {
+      desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
+      outscale_flag = true;
+    }
+    desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
+
    auto mul_node = g->CreateOpNode(&desc);
    IR_NODE_LINK_TO(flatten2_in_x, mul_node);
    IR_NODE_LINK_TO(matmul_in_y, mul_node);
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -76,10 +76,13 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
 const std::vector<std::string> kTRTSubgraphPasses({
   "adaptive_pool2d_convert_global_pass",
   "shuffle_channel_detect_pass",                //
   "quant_conv2d_dequant_fuse_pass",             //
   "delete_quant_dequant_op_pass",               //
   "delete_quant_dequant_filter_op_pass",        //
+  "delete_weight_dequant_linear_op_pass",       //
+  "delete_quant_dequant_linear_op_pass",        //
+  "add_support_int8_pass",                      //
+  // "fc_fuse_pass",                            //
   "simplify_with_basic_ops_pass",               //
   "embedding_eltwise_layernorm_fuse_pass",      //

@@ -98,9 +101,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
   "trt_map_matmul_to_mul_pass",                 //
   "fc_fuse_pass",                               //
   "conv_elementwise_add_fuse_pass",             //
-  "add_support_int8_pass",
   "tensorrt_subgraph_pass",                     //
   "conv_bn_fuse_pass",                          //
 #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we
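With this ordering, the two new delete-*-linear passes and add_support_int8_pass run before tensorrt_subgraph_pass, so scale attributes are already attached to the op descs when TensorRT subgraphs are carved out. The snippet below is a hedged sketch of how one might print the active pass list to confirm that order; it assumes the paddle_infer::Config alias, its pass_builder() accessor, and PaddlePassBuilder::AllPasses() are available with these signatures, and the header name is an assumption as well, so treat it as illustrative rather than verified.

#include <iostream>
#include "paddle_inference_api.h"  // assumed header name for the Paddle Inference C++ API

int main() {
  paddle_infer::Config config;
  config.EnableUseGpu(100, 0);
  config.EnableTensorRtEngine();  // assumed to select the TensorRT subgraph pass list
  // Assumption: pass_builder() exposes the active pass list via AllPasses().
  for (const auto& pass : config.pass_builder()->AllPasses()) {
    std::cout << pass << "\n";  // expect the delete_*_linear_op passes before tensorrt_subgraph_pass
  }
  return 0;
}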
paddle/fluid/inference/tensorrt/convert/activation_op.cc

@@ -68,12 +68,6 @@ class ActivationOpConverter : public OpConverter {
     auto output_name = op_desc.Output("Out")[0];

     RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
-    if (op_desc.HasAttr("out_scale")) {
-#if IS_TRT_VERSION_GE(5130)
-      float out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_scale"));
-      engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
-#endif
-    }
   }

  protected:
paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc

@@ -49,11 +49,11 @@ class AffineChannelOpConverter : public OpConverter {
     auto* scale_v = scope.FindVar(scale_name);
     auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
-    float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);
+    float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t);

     auto* bias_v = scope.FindVar(bias_name);
     auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
-    float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
+    float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t);

     // tensorrt scalend layer only support spatial dims >= 2,
     // so nhwc is not availabe (spatial dims == 0)
paddle/fluid/inference/tensorrt/convert/conv2d_op.cc

@@ -49,18 +49,11 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   if (enable_int8) {
 #if IS_TRT_VERSION_GE(5000)
-    float in_scale =
-        BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
-    auto weight_scale =
-        BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
-    weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(),
-                                           Y_t, true, weight_scale);
+    float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
     engine->SetTensorDynamicRange(X, in_scale);
 #endif
-  } else {
-    weight_data =
-        engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
   }
+  weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);

   PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL,
                     platform::errors::InvalidArgument(

@@ -115,7 +108,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
     auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
     auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
     bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
-                                         bias_tensor_data, false);
+                                         bias_tensor_data);
     bias_size = static_cast<size_t>(bias_tensor_data->numel());
   }
paddle/fluid/inference/tensorrt/convert/conv3d_op.cc

@@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
   bool enable_int8 = op_desc.HasAttr("enable_int8");

   if (enable_int8) {
-    float in_scale =
-        BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
-    auto weight_scale =
-        BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
-    weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(),
-                                           Y_t, true, weight_scale);
+    float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
     engine->SetTensorDynamicRange(X, in_scale);
-  } else {
-    weight_data =
-        engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
   }
+  weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);

   PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL,
                     platform::errors::InvalidArgument(
paddle/fluid/inference/tensorrt/convert/deformable_conv_op.cc
@@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter {
     auto* filter_var = scope.FindVar(filter_name);
     auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>();
-    float* filter_data =
-        engine_->GetWeightCPUData(filter_name, filter_tensor, false);
+    float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor);

     const int c_o = filter_tensor->dims()[0];
     const int c_i = filter_tensor->dims()[1];
paddle/fluid/inference/tensorrt/convert/elementwise_op.cc
@@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
     auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
     float* weight_data = nullptr;
     auto output_name = op_desc.Output("Out")[0];
-    weight_data =
-        engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false);
+    weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t);
     nvinfer1::Dims dims_x = X->getDimensions();
     auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) {
@@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter {
       RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name},
                                test_mode);
     }
-    if (op_desc.HasAttr("enable_int8")) {
-#if IS_TRT_VERSION_GE(5000)
-      CHECK(op_desc.HasAttr("X_scale"));
-      float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
-      engine_->SetTensorDynamicRange(X, x_scale);
-#endif
-    }
   };

   if (engine_->with_dynamic_shape()) {
@@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
   auto common_func = [&](nvinfer1::ILayer* layer) {
     RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
-    if (op_desc.HasAttr("enable_int8")) {
-#if IS_TRT_VERSION_GE(5000)
-      CHECK(op_desc.HasAttr("X_scale"));
-      CHECK(op_desc.HasAttr("Y_scale"));
-      float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
-      float y_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Y_scale"));
-      engine_->SetTensorDynamicRange(X, x_scale);
-      engine_->SetTensorDynamicRange(Y, y_scale);
-#endif
-    }
   };

   if (dims_x.nbDims == dims_y.nbDims) {
paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc
@@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
       (*dims) = temp_tensor->dims();
-      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
       return temp_data;
     };
paddle/fluid/inference/tensorrt/convert/fc_op.cc
@@ -113,22 +113,20 @@ class FcOpConverter : public OpConverter {
     // assigned from CPU memory, which can't be avoided.
     float* weight_data = nullptr;
     bool enable_int8 = op_desc.HasAttr("enable_int8");
-    float in_scale = 0.;
-    if (enable_int8) {
-#if IS_TRT_VERSION_GE(5000)
-      CHECK(op_desc.HasAttr(i_name + "_scale"));
-      in_scale =
-          BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
-      auto weight_scale =
-          BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
-      weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),
-                                              Y_t, true, weight_scale);
+    bool support_int8 = false;
+    if (op_desc.HasAttr("support_int8")) {
+      support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8"));
+    }
+    float in_scale = 0;
+    if (enable_int8 || support_int8) {
+      if (enable_int8) {
+        in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
+      } else {
+        in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X"));
+      }
       engine_->SetTensorDynamicRange(X, in_scale);
-#endif
-    } else {
-      weight_data =
-          engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t, false);
     }
+    weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t);
     PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL,
                       platform::errors::InvalidArgument(
@@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter {
     auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
                          TensorRTEngine::Weight& weight,
                          TensorRTEngine::Weight& bias) {
-      if (enable_int8) {
+      if (enable_int8 || support_int8) {
         // add conv layer
-        PADDLE_ENFORCE_EQ(
-            op_desc.HasAttr("out_threshold"), true,
-            platform::errors::InvalidArgument(
-                "must have out threshold in fc layers in int8 mode"));
-        float out_scale =
-            BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+        float out_scale = 0;
+        if (enable_int8) {
+          PADDLE_ENFORCE_EQ(
+              op_desc.HasAttr("out_threshold"), true,
+              platform::errors::InvalidArgument(
+                  "must have out threshold in fc layers in int8 mode"));
+          out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+        } else {
+          out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
+        }
         nvinfer1::DimsHW nv_ksize(1, 1);
         auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, Convolution,
                                                    *inputs, n_output,
@@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter {
     if (with_bias) {
       auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
       auto* b_t = b_v->GetMutable<framework::LoDTensor>();
-      bias_data =
-          engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false);
+      bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t);
       bias_num = b_t->numel();
     }
     TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
@@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter {
     // not add Shuffle layer in ernie's multihead.
     if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
         x_dim.d[3] == 1 && x_num_col_dims == 2) {
-      if (enable_int8) {
+      if (enable_int8 || support_int8) {
         // add conv1x1 layer
         nvinfer1::DimsHW nv_ksize(1, 1);
         auto* fc_layer_int8 =
@@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter {
             op_desc.HasAttr("out_threshold"), true,
             platform::errors::InvalidArgument(
                 "must have out threshold in fc layers in int8 mode"));
-        float out_scale =
-            BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+        float out_scale = 0;
+        if (enable_int8) {
+          out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+        } else {
+          out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
+        }
         engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale);
         nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
@@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter {
     auto* reshape_before_fc_layer =
         reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
     auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
-    if (enable_int8) {
+    if (enable_int8 || support_int8) {
       engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
     }
     regist_fc(reshape_itensor, n_output, weight, bias);
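The pattern above, where the input scale comes from the op-level "Input_scale" attribute in the old quantization format and from a per-input attribute named after the input slot (here "X") in the new quant_dequant_linear format, repeats across converters. A stand-alone sketch of that selection logic, using a plain std::map as a stand-in for the real framework::OpDesc (the names and values below are illustrative assumptions):

    #include <iostream>
    #include <map>
    #include <string>

    // Stand-in for the converter's attribute lookup; the real code reads
    // these from framework::OpDesc via HasAttr()/GetAttr().
    float PickInputScale(const std::map<std::string, float>& attrs,
                         bool enable_int8) {
      if (enable_int8) {
        // old quantization format: a single op-level "Input_scale"
        return attrs.at("Input_scale");
      }
      // new quant_dequant_linear format: the scale is attached under the
      // name of the input slot itself (e.g. "X" for fc)
      return attrs.at("X");
    }

    int main() {
      std::map<std::string, float> old_fmt{{"Input_scale", 0.5f}};
      std::map<std::string, float> new_fmt{{"X", 0.25f}};
      std::cout << PickInputScale(old_fmt, /*enable_int8=*/true) << "\n";   // 0.5
      std::cout << PickInputScale(new_fmt, /*enable_int8=*/false) << "\n";  // 0.25
      return 0;
    }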
paddle/fluid/inference/tensorrt/convert/group_norm_op.cc
@@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter {
       auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
       (*dims) = temp_tensor->dims();
-      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
       return temp_data;
     };
paddle/fluid/inference/tensorrt/convert/leaky_relu_op.cc
@@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter {
   bool enable_int8 = op_desc.HasAttr("enable_int8");
   if (enable_int8) {
-    CHECK(op_desc.HasAttr("X_scale"));
-    float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
+    CHECK(op_desc.HasAttr("Input_scale"));
+    float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
     engine_->SetTensorDynamicRange(input, in_scale);
   }
 #else
paddle/fluid/inference/tensorrt/convert/matmul_op.cc
@@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter {
                              : nvinfer1::MatrixOperation::kNONE;
     if (op_desc.HasAttr("support_int8") &&
-        engine_->precision() == AnalysisConfig::Precision::kInt8) {
+        BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")) &&
+        engine_->precision() == AnalysisConfig::Precision::kInt8 &&
+        platform::GetGPUComputeCapability(0) >= 75) {
       if (engine_->with_dynamic_shape()) {
         VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT "
                    "MatmulPluginLayer";
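That is, the int8 matmul plugin path is now taken only when the op actually carries support_int8 == true, the engine runs at Precision::kInt8, and the GPU reports compute capability 7.5 or newer (platform::GetGPUComputeCapability(0) >= 75).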
paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter {
     auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();

     float* weight_data = nullptr;
     bool enable_int8 = op_desc.HasAttr("enable_int8");
     bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8");
     float in_scale = 0.;
-    if (enable_int8) {
-      in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
-      auto weight_scale =
-          BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
-      weight_data =
-          engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale);
-    } else {
-      weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false);
-    }
+    if (op_desc.HasAttr("Input_scale")) {
+      in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
+      engine_->SetTensorDynamicRange(input, in_scale);
+    }
+    weight_data = engine_->GetWeightCPUData(weight_name, weight_t);

-    float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false);
+    float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t);
     std::vector<float> weight_data_tmp;
     weight_data_tmp.reserve(weight_t->numel());
     memcpy(weight_data_tmp.data(), weight_data,
@@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
     if (engine_->with_dynamic_shape()) {
       if (engine_->use_oss()) {
+        if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
+          PADDLE_THROW(platform::errors::Fatal(
+              "use use_oss must be int8 or half, not float32."));
+        }
         nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
                                  static_cast<void*>(weight_data),
                                  static_cast<int32_t>(weight_t->numel())};
@@ -93,7 +91,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
                                 static_cast<int32_t>(bias_t->numel())};
         if (engine_->with_interleaved()) {
           VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
-          if (!enable_int8) {
+          if (!op_desc.HasAttr("Input_scale")) {
             PADDLE_THROW(
                 platform::errors::Fatal("use with_interleaved must be int8."));
           }
@@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
           nvinfer1::ILayer* fc_layer = nullptr;
           float dp_probs = 1.0 / 127.0;
-          if (enable_int8) {
+          if (op_desc.HasAttr("Input_scale")) {
             nvinfer1::DimsHW nv_ksize(1, 1);
             fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
                                             nv_ksize, weight, bias);
@@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
                                             weight, bias);
           }
-          if (enable_int8) {
+          if (op_desc.HasAttr("fc_out_threshold")) {
             PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true,
                               platform::errors::InvalidArgument(
                                   "must have out threshold in multihead layers "
@@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
           auto creator = GetPluginRegistry()->getPluginCreator(
               "CustomQKVToContextPluginDynamic", "2");
           assert(creator != nullptr);
-          int type = static_cast<int>((engine_->WithFp16() == 1)
-                                          ? nvinfer1::DataType::kHALF
-                                          : nvinfer1::DataType::kFLOAT);
-          if (enable_int8) {
-            type = static_cast<int>(nvinfer1::DataType::kHALF);
-            if (qkv2context_plugin_int8) {
-              type = static_cast<int>(nvinfer1::DataType::kINT8);
-            }
-          }
+          int type = static_cast<int>(nvinfer1::DataType::kHALF);
+          if (qkv2context_plugin_int8 &&
+              (engine_->precision() == AnalysisConfig::Precision::kInt8)) {
+            type = static_cast<int>(nvinfer1::DataType::kINT8);
+          }
           bool has_mask = true;
           int var_seqlen = 1;
@@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
           reshape_before_fc_dim.d[4] = 1;
           auto* reshape_before_fc_layer =
               TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
-          if (enable_int8) {
+          if (op_desc.HasAttr("Input_scale")) {
             engine_->SetTensorDynamicRange(
                 reshape_before_fc_layer->getOutput(0), in_scale);
           }
@@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
           // add layer fc
           nvinfer1::ILayer* fc_layer = nullptr;
-          if (enable_int8) {
+          if (op_desc.HasAttr("Input_scale")) {
             nvinfer1::DimsHW nv_ksize(1, 1);
             fc_layer = TRT_ENGINE_ADD_LAYER(
                 engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n,
@@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
                 n, weight.get(), bias.get());
           }
-          if (enable_int8) {
+          if (op_desc.HasAttr("fc_out_threshold")) {
             PADDLE_ENFORCE_EQ(
                 op_desc.HasAttr("fc_out_threshold"), true,
                 platform::errors::InvalidArgument(
@@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
           bool with_fp16 =
               engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-          if (enable_int8) {
-            with_fp16 = 1;
+          if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
+            with_fp16 = true;
           }
           plugin::DynamicPluginTensorRT* plugin =
               new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -145,42 +145,68 @@ class OpConverter {
     (*it)(op, scope, test_mode);

     size_t output_num = op_desc.OutputNames().size();
-    if (output_num == 1) {  // The number of output is 1
-      if (op_desc.HasAttr("out_threshold")) {
-        float out_scale =
-            BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
-        std::string output_name = "";
-        if (op_desc.HasOutput("Output")) {
-          output_name = op_desc.Output("Output").front();
-        } else if (op_desc.HasOutput("Out")) {
-          output_name = op_desc.Output("Out").front();
-        } else if (op_desc.HasOutput("Y")) {
-          output_name = op_desc.Output("Y").front();
-        } else {
-          PADDLE_THROW(
-              platform::errors::NotFound("Op %s has out threshold but doesn't "
-                                         "have an output named \"Output\", "
-                                         "\"Out\" or \"Y\".",
-                                         op_desc.Type()));
-        }
+    // only one out settensordynamicRange
+    if (op_desc.HasAttr("out_threshold")) {
+      float out_scale =
+          BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
+      std::string output_name = "";
+      if (op_desc.HasOutput("Output")) {
+        output_name = op_desc.Output("Output").front();
+      } else if (op_desc.HasOutput("Out")) {
+        output_name = op_desc.Output("Out").front();
+      } else if (op_desc.HasOutput("Y")) {
+        output_name = op_desc.Output("Y").front();
+      } else {
+        PADDLE_THROW(
+            platform::errors::NotFound("Op %s has out threshold but doesn't "
+                                       "have an output named \"Output\", "
+                                       "\"Out\" or \"Y\".",
+                                       op_desc.Type()));
+      }
       auto* output_itensor = engine->GetITensor(output_name);
       engine->SetTensorDynamicRange(output_itensor, out_scale);
       VLOG(1) << "Set out scale = " << out_scale << " for tensor "
               << output_name << ".";
     }
+    // outs settensordynamicRange
+    for (size_t i = 0; i < output_num; ++i) {
+      if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) {
+        float out_scale = BOOST_GET_CONST(
+            float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold"));
+        std::string output_name =
+            op_desc.Output(op_desc.OutputNames()[i]).front();
+        auto* output_itensor = engine->GetITensor(output_name);
+        engine->SetTensorDynamicRange(output_itensor, out_scale);
+        VLOG(1) << "Set out scale = " << out_scale << " for tensor "
+                << output_name << ".";
+      }
+    }
-    } else if (output_num > 1) {  // The number of outputs greater than 1
-      for (size_t i = 0; i < output_num; ++i) {
-        if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) {
-          float out_scale = BOOST_GET_CONST(
-              float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold"));
-          std::string output_name =
-              op_desc.Output(op_desc.OutputNames()[i]).front();
-          auto* output_itensor = engine->GetITensor(output_name);
-          engine->SetTensorDynamicRange(output_itensor, out_scale);
-          VLOG(1) << "Set out scale = " << out_scale << " for tensor "
-                  << output_name << ".";
-        }
-      }
-    }
+    // quant_dequant_linear support for paddle trt
+    std::vector<std::string> inputs_name = op_desc.InputNames();
+    std::vector<std::string> outputs_name = op_desc.OutputNames();
+    for (size_t i = 0; i < inputs_name.size(); i++) {
+      if (op_desc.HasAttr(inputs_name[i])) {
+        std::string input_tensor_name = op_desc.Input(inputs_name[i])[0];
+        auto* input_itensor = engine->GetITensor(input_tensor_name);
+        float input_scale =
+            BOOST_GET_CONST(float, op_desc.GetAttr(inputs_name[i]));
+        engine->SetTensorDynamicRange(input_itensor, input_scale);
+        VLOG(1) << "Set input tensor scale = " << input_scale
+                << " for tensor: " << input_tensor_name << ".";
+      }
+    }
+    for (size_t i = 0; i < outputs_name.size(); i++) {
+      if (op_desc.HasAttr(outputs_name[i])) {
+        std::string output_tensor_name = op_desc.Output(outputs_name[i])[0];
+        auto* output_itensor = engine->GetITensor(output_tensor_name);
+        float output_scale =
+            BOOST_GET_CONST(float, op_desc.GetAttr(outputs_name[i]));
+        engine->SetTensorDynamicRange(output_itensor, output_scale);
+        VLOG(1) << "Set output tensor scale = " << output_scale
+                << " for tensor: " << output_tensor_name << ".";
+      }
+    }
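The new blocks above make scale propagation generic: any op attribute named after one of the op's input or output slots is treated as that tensor's scale and forwarded via SetTensorDynamicRange. A hedged sketch of what such a call amounts to at the TensorRT API level, assuming a symmetric range (the exact conversion Paddle applies between its stored scale and the range is not shown in this hunk):

    #include <NvInfer.h>

    // Illustrative helper only (not Paddle code): mark a tensor's
    // representable float range so an explicit-precision int8 engine
    // can quantize against it.
    inline void ApplySymmetricRange(nvinfer1::ITensor* tensor, float max_abs) {
      tensor->setDynamicRange(-max_abs, max_abs);
    }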
paddle/fluid/inference/tensorrt/convert/pool2d_op.cc
@@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter {
   }

   if (op_desc.HasAttr("enable_int8")) {
 #if IS_TRT_VERSION_GE(5000)
-    CHECK(op_desc.HasAttr("X_scale"));
-    float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
+    CHECK(op_desc.HasAttr("Input_scale"));
+    float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
     engine_->SetTensorDynamicRange(input1, input_scale);
 #endif
   }

   std::vector<int> real_paddings = paddings;
paddle/fluid/inference/tensorrt/convert/pool3d_op.cc
@@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter {
   nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
   nvinfer1::ILayer* layer = nullptr;
   if (op_desc.HasAttr("enable_int8")) {
-    CHECK(op_desc.HasAttr("X_scale"));
-    float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
+    CHECK(op_desc.HasAttr("Input_scale"));
+    float input_scale =
+        BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
     engine_->SetTensorDynamicRange(input1, input_scale);
   }
paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc
@@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
       auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
       (*dims) = temp_tensor->dims();
-      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
       return temp_data;
     };
paddle/fluid/inference/tensorrt/convert/preln_skip_layernorm.cc
@@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
       auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
       (*dims) = temp_tensor->dims();
-      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
       return temp_data;
     };
paddle/fluid/inference/tensorrt/convert/prelu_op.cc
@@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter {
     layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
   } else {
 #if IS_TRT_VERSION_GE(7000)
-    float* alpha_weight_data = engine_->GetWeightCPUData(
-        op_desc.Input("Alpha")[0], alpha_tensor, false);
+    float* alpha_weight_data =
+        engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor);
     TensorRTEngine::Weight alpha_weight{
         nvinfer1::DataType::kFLOAT, static_cast<void*>(alpha_weight_data),
         static_cast<size_t>(alpha_tensor->numel())};
paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc
@@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter {
       auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
       (*dims) = temp_tensor->dims();
-      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
+      auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
       return temp_data;
     };
paddle/fluid/inference/tensorrt/engine.cc
@@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
 }

 float* TensorRTEngine::GetWeightCPUData(const std::string& name,
-                                        framework::Tensor* weight_tensor,
-                                        bool enable_int8,
-                                        const std::vector<float>& scale) {
+                                        framework::Tensor* weight_tensor) {
   static int name_suffix_counter = 0;
   std::string name_suffix = std::to_string(name_suffix_counter);
   std::string splitter = "__";
paddle/fluid/inference/tensorrt/engine.h
@@ -389,8 +389,7 @@ class TensorRTEngine {
   }

   float* GetWeightCPUData(const std::string& name,
-                          framework::Tensor* weight_tensor, bool enable_int8,
-                          const std::vector<float>& scale = {});
+                          framework::Tensor* weight_tensor);

   // A pointer to CPU memory is needed of the TRT weight.
   // Before TRT runs, fluid loads weight into GPU storage.
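Taken together with engine.cc above, GetWeightCPUData no longer takes the int8 flag or per-channel scale vector; scale information now reaches TensorRT only through op attributes consumed in OpConverter and SetTensorDynamicRange. Call shape before and after, as seen in the converter hunks above:

    // before: engine_->GetWeightCPUData(name, weight_tensor, /*enable_int8=*/true, weight_scale);
    // after:  engine_->GetWeightCPUData(name, weight_tensor);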
paddle/fluid/operators/compat/dequantize_linear.pbtxt (new file, mode 100644)
type: "dequantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
paddle/fluid/operators/compat/mul.pbtxt
@@ -60,15 +60,7 @@ extra {
     type: BOOLEAN
   }
   attrs {
-    name: "X_scale"
-    type: FLOAT
-  }
-  attrs {
-    name: "weight_scale"
-    type: FLOAT
-  }
-  attrs {
-    name: "out_scale"
+    name: "Input_scale"
     type: FLOAT
   }
   attrs {
paddle/fluid/operators/compat/quantize_linear.pbtxt (new file, mode 100644)
type: "quantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
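The two new compat definitions describe the quantize_linear / dequantize_linear pair (inputs X, Scale, ZeroPoint; attributes bit_length and quant_axis). As a reference for the per-tensor case, a self-contained sketch of generic linear quantization follows; this is the conventional formulation only, not a copy of Paddle's kernels, and Paddle's exact treatment of Scale and bit_length is not shown in this diff:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <iostream>

    // Conventional per-tensor linear quantization, bit_length == 8.
    int8_t QuantizeLinear(float x, float scale, int zero_point) {
      float q = std::round(x / scale) + zero_point;
      q = std::max(-128.0f, std::min(127.0f, q));  // clamp to int8 range
      return static_cast<int8_t>(q);
    }

    float DequantizeLinear(int8_t q, float scale, int zero_point) {
      return (static_cast<int>(q) - zero_point) * scale;
    }

    int main() {
      float scale = 0.05f;  // illustrative values only
      int zero_point = 0;
      int8_t q = QuantizeLinear(1.23f, scale, zero_point);
      std::cout << static_cast<int>(q) << " -> "
                << DequantizeLinear(q, scale, zero_point) << "\n";
      return 0;
    }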
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_multihead_matmul.py
@@ -491,8 +491,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
                 "x_num_col_dims": 2,
                 "y_num_col_dims": 1,
                 "enable_int8": True,
-                "X_scale": 1.0,
-                "weight_scale": [1.0],
+                "Input_scale": 1.0,
             }, {
                 "axis": 2,
                 "out_threshold": 1.0,
@@ -504,8 +503,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
                 "x_num_col_dims": 2,
                 "y_num_col_dims": 1,
                 "enable_int8": True,
-                "X_scale": 1.0,
-                "weight_scale": [1.0],
+                "Input_scale": 1.0,
             }, {
                 "axis": 2,
                 "out_threshold": 1.0,
@@ -517,8 +515,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
                 "x_num_col_dims": 2,
                 "y_num_col_dims": 1,
                 "enable_int8": True,
-                "X_scale": 1.0,
-                "weight_scale": [1.0],
+                "Input_scale": 1.0,
             }, {
                 "axis": 2,
                 "out_threshold": 1.0,