Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
95332bef
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95332bef
编写于
12月 08, 2022
作者:
R
RichardWooSJTU
提交者:
GitHub
12月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rewrite delete_weight_dequant_linear_op_encoder/decoder pass (#48650)
* rewrite delete_weight_deqquant_linear_op_encoder/decoder pass
上级
a14ae84b
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
442 addition
and
893 deletion
+442
-893
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-2
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc
...mework/ir/delete_weight_dequant_linear_op_decoder_pass.cc
+0
-373
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h
...amework/ir/delete_weight_dequant_linear_op_encoder_pass.h
+0
-34
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
...luid/framework/ir/delete_weight_dequant_linear_op_pass.cc
+142
-408
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
...fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
+16
-19
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc
...amework/ir/delete_weight_dequant_linear_op_pass_tester.cc
+141
-0
paddle/fluid/framework/ir/pass.cc
paddle/fluid/framework/ir/pass.cc
+2
-2
paddle/fluid/framework/ir/pass_tester_helper.h
paddle/fluid/framework/ir/pass_tester_helper.h
+17
-0
paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc
.../framework/ir/trt_delete_weight_dequant_linear_op_pass.cc
+103
-38
paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h
...d/framework/ir/trt_delete_weight_dequant_linear_op_pass.h
+4
-3
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+12
-14
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
95332bef
...
...
@@ -95,9 +95,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library
(
shuffle_channel_detect_pass inference
)
pass_library
(
delete_quant_dequant_op_pass inference
)
pass_library
(
delete_quant_dequant_filter_op_pass inference
)
pass_library
(
trt_delete_weight_dequant_linear_op_pass inference
)
pass_library
(
delete_weight_dequant_linear_op_pass inference
)
pass_library
(
delete_weight_dequant_linear_op_encoder_pass inference
)
pass_library
(
delete_weight_dequant_linear_op_decoder_pass inference
)
pass_library
(
delete_quant_dequant_linear_op_pass inference
)
pass_library
(
delete_dropout_op_pass inference
)
pass_library
(
delete_c_identity_op_pass inference
)
...
...
@@ -359,6 +358,10 @@ cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
DEPS delete_dropout_op_pass
)
cc_test
(
test_delete_dequant_weight_linear_op_pass
SRCS delete_weight_dequant_linear_op_pass_tester.cc
DEPS delete_weight_dequant_linear_op_pass
)
if
(
WITH_GPU OR WITH_ROCM
)
cc_test
(
test_embedding_eltwise_layernorm_fuse_pass
...
...
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc
已删除
100644 → 0
浏览文件 @
a14ae84b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpDecoderPass
::
DeleteWeightDequantLinearOpDecoderPass
()
{
AddOpCompat
(
OpCompat
(
"quantize_linear"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"ZeroPoint"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"bit_length"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"quant_axis"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"round_type"
)
.
IsOptional
()
.
IsType
<
int
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"dequantize_linear"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"ZeroPoint"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"bit_length"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"quant_axis"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"round_type"
)
.
IsOptional
()
.
IsType
<
int
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"conv2d"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ResidualData"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
AddOpCompat
(
OpCompat
(
"depthwise_conv2d"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ResidualData"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
AddOpCompat
(
OpCompat
(
"mul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"x_num_col_dims"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"y_num_col_dims"
)
.
IsNumEQ
(
1
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.99
f
)
.
IsNumLE
(
1.01
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"fc"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"in_num_col_dims"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"activation_type"
)
.
IsStringIn
({
"relu"
,
""
})
.
End
();
AddOpCompat
(
OpCompat
(
"conv2d_transpose"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"output_padding"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
IsOptional
()
.
End
()
.
AddAttr
(
"output_size"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
IsOptional
()
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
}
// Delete dequantize_linear_op, then dequantize weight
void
DeleteWeightDequantLinearOpDecoderPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"delete_weight_dequant_linear_op_decoder_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope in DeleteWeightDequantLinearOpDecoderPass "
"should not be null."
));
// Create pattern
patterns
::
DeleteWeightDequantLinearOpDecoderPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
();
int
found_count
=
0
;
bool
is_int8
=
false
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_NODES
;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8
=
true
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{};
auto
*
any_op2_desc
=
any_op2
->
Op
();
// Get weight scale
std
::
vector
<
float
>
weight_scale
;
auto
*
weight_scale_tensor
=
scope
->
GetVar
(
weight_dequantize_linear_op_scale
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight_scale_nums
=
weight_scale_tensor
->
numel
();
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT32
)
{
float
*
weight_scale_data
=
weight_scale_tensor
->
data
<
float
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
weight_scale_data
[
i
]);
}
}
else
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
phi
::
dtype
::
float16
*
weight_scale_data
=
weight_scale_tensor
->
data
<
phi
::
dtype
::
float16
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
static_cast
<
float
>
(
weight_scale_data
[
i
]));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"%d is not supported."
,
weight_scale_tensor
->
dtype
()));
}
int
quant_axis
=
PADDLE_GET_CONST
(
int
,
weight_dequantize_linear_op
->
Op
()
->
GetAttr
(
"quant_axis"
));
if
(
quant_axis
==
-
1
)
{
// per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
1
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."
));
// Add attr to anyop 2
any_op2_desc
->
SetAttr
(
"weight_scale"
,
weight_scale
[
0
]);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Delete Weight Dequant Linear Op Encoder Pass is not supported for "
"per-channel quantization"
));
}
nodes2rm
.
insert
(
weight_dequantize_linear_op_scale
);
nodes2rm
.
insert
(
weight_dequantize_linear_op
);
nodes2rm
.
insert
(
weight_dequantize_linear_op_out
);
// relink weight to any_op2
any_op2_desc
->
RenameInput
(
weight_dequantize_linear_op_out
->
Var
()
->
Name
(),
weight_dequantize_linear_op_x
->
Var
()
->
Name
());
any_op2_desc
->
Flush
();
IR_NODE_LINK_TO
(
weight_dequantize_linear_op_x
,
any_op2
);
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
found_count
++
;
};
gpd
(
graph
,
handler
);
if
(
is_int8
)
{
auto
&
enable_int8
=
graph
->
Get
<
bool
>
(
"enable_int8"
);
enable_int8
=
true
;
}
AddStatis
(
found_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_weight_dequant_linear_op_decoder_pass
,
paddle
::
framework
::
ir
::
DeleteWeightDequantLinearOpDecoderPass
);
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h
已删除
100644 → 0
浏览文件 @
a14ae84b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
DeleteWeightDequantLinearOpEncoderPass
:
public
FusePassBase
{
public:
DeleteWeightDequantLinearOpEncoderPass
();
virtual
~
DeleteWeightDequantLinearOpEncoderPass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
}
// namespace ir
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc
浏览文件 @
95332bef
/
/
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
// limitations under the License.
/
*
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "glog/logging.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightQuantDequantLinearOpPass
::
DeleteWeightQuantDequantLinearOpPass
()
{
AddOpCompat
(
OpCompat
(
"quantize_linear"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"ZeroPoint"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"bit_length"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"quant_axis"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"round_type"
)
.
IsOptional
()
.
IsType
<
int
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"dequantize_linear"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Scale"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"ZeroPoint"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"bit_length"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"quant_axis"
)
.
IsType
<
int
>
()
.
End
()
.
AddAttr
(
"round_type"
)
.
IsOptional
()
.
IsType
<
int
>
()
.
End
();
AddOpCompat
(
OpCompat
(
"conv2d"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ResidualData"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
AddOpCompat
(
OpCompat
(
"depthwise_conv2d"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddInput
(
"ResidualData"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
AddOpCompat
(
OpCompat
(
"mul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"x_num_col_dims"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"y_num_col_dims"
)
.
IsNumEQ
(
1
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul_v2"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"trans_x"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"trans_y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"matmul"
))
.
AddInput
(
"X"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Y"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"alpha"
)
.
IsNumGE
(
0.99
f
)
.
IsNumLE
(
1.01
f
)
.
End
()
.
AddAttr
(
"transpose_X"
)
.
IsBoolEQ
(
false
)
.
End
()
.
AddAttr
(
"transpose_Y"
)
.
IsBoolEQ
(
false
)
.
End
();
AddOpCompat
(
OpCompat
(
"fc"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"W"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
End
()
.
AddOutput
(
"Out"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"in_num_col_dims"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"activation_type"
)
.
IsStringIn
({
"relu"
,
""
})
.
End
();
AddOpCompat
(
OpCompat
(
"conv2d_transpose"
))
.
AddInput
(
"Input"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Filter"
)
.
IsTensor
()
.
End
()
.
AddInput
(
"Bias"
)
.
IsTensor
()
.
IsOptional
()
.
End
()
.
AddOutput
(
"Output"
)
.
IsTensor
()
.
End
()
.
AddAttr
(
"output_padding"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
IsOptional
()
.
End
()
.
AddAttr
(
"output_size"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
IsOptional
()
.
End
()
.
AddAttr
(
"groups"
)
.
IsNumGE
(
1
)
.
End
()
.
AddAttr
(
"dilations"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"strides"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"paddings"
)
.
IsType
<
std
::
vector
<
int
>>
()
.
End
()
.
AddAttr
(
"padding_algorithm"
)
.
IsOptional
()
.
IsStringIn
({
"EXPLICIT"
,
"SAME"
,
"VALID"
})
.
End
()
.
AddAttr
(
"data_format"
)
.
IsStringIn
({
"NCHW"
,
"NHWC"
,
"AnyLayout"
})
.
End
();
}
// Delete dequantize_linear_op, then dequantize weight
void
DeleteWeightQuantDequantLinearOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"delete_weight_quantdequant_linear_op_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope in DeleteWeightQuantDequantLinearOpPass should not be null."
));
// Create pattern
patterns
::
DeleteWeightQuantDequantLinearOpPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
();
int
found_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_NODES
;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{};
int
bit_length
=
PADDLE_GET_CONST
(
int
,
weight_dequantize_linear_op
->
Op
()
->
GetAttr
(
"bit_length"
));
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
any_op2_desc
=
any_op2
->
Op
();
// get weight tensor
auto
*
weight_tensor
=
scope
->
GetVar
(
weight_dequantize_linear_op_x
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
int8_t
*
quantized_weight_data
=
weight_tensor
->
mutable_data
<
int8_t
>
(
platform
::
CPUPlace
());
auto
w_dims
=
weight_tensor
->
dims
();
// Get weight scale
std
::
vector
<
float
>
weight_scale
;
auto
*
weight_scale_tensor
=
scope
->
GetVar
(
weight_dequantize_linear_op_scale
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
float
*
weight_scale_data
=
weight_scale_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
auto
weight_scale_nums
=
weight_scale_tensor
->
numel
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
weight_scale_data
[
i
]
/
range
);
}
// dequant weight
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
weight_tensor
->
numel
());
int
quant_axis
=
PADDLE_GET_CONST
(
int
,
weight_dequantize_linear_op
->
Op
()
->
GetAttr
(
"quant_axis"
));
if
(
quant_axis
==
-
1
)
{
// per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
class
Graph
;
void
DeleteWeightDequantLinearOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
std
::
unordered_set
<
std
::
string
>
op_list
=
{
"matmul_v2"
,
"matmul"
,
"mul"
,
"fc"
,
"depthwise_conv2d"
,
"conv2d"
,
"conv2d_transpose"
};
PADDLE_ENFORCE_EQ
(
graph
->
Has
(
kParamScopeAttr
),
true
,
platform
::
errors
::
InvalidArgument
(
"Graph must have kParamScopeAttr attribute."
));
auto
&
scope
=
graph
->
Get
<
framework
::
Scope
>
(
kParamScopeAttr
);
bool
is_int8
=
false
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
;
for
(
const
Node
*
n
:
graph
->
Nodes
())
{
if
(
n
->
IsOp
())
{
auto
*
op
=
n
->
Op
();
if
(
op
->
Type
()
==
"dequantize_linear"
)
{
Node
*
weight_var_node
,
*
dequantized_weight_var_node
,
*
scale_var_node
,
*
calcu_op_node
,
*
while_op_node
;
// 1. Judge whether for dequant weight and find
// weight_var_node/scale_var_node
for
(
auto
*
input_node
:
n
->
inputs
)
{
if
(
input_node
->
IsVar
()
&&
input_node
->
Var
()
->
Persistable
())
{
is_int8
=
true
;
if
(
input_node
->
Var
()
->
Name
()
==
op
->
Input
(
"X"
)[
0
])
{
weight_var_node
=
input_node
;
}
else
if
(
input_node
->
Var
()
->
Name
()
==
op
->
Input
(
"Scale"
)[
0
])
{
scale_var_node
=
input_node
;
}
}
else
{
return
;
}
}
// 2. Find next_op_node
// For while op: delete its input which is related to dequantized
// For calculation op: set weight scale as their attributes
for
(
auto
*
output_node
:
n
->
outputs
)
{
if
(
output_node
->
IsVar
()
&&
output_node
->
Var
()
->
Name
()
==
op
->
Output
(
"Y"
)[
0
])
{
dequantized_weight_var_node
=
output_node
;
for
(
auto
*
next_op_node
:
output_node
->
outputs
)
{
if
(
next_op_node
->
IsOp
())
{
if
(
next_op_node
->
Op
()
->
Type
()
==
"while"
)
{
while_op_node
=
next_op_node
;
auto
while_op_desc
=
while_op_node
->
Op
();
auto
while_Xs
=
while_op_desc
->
Input
(
"X"
);
while_Xs
.
erase
(
std
::
remove
(
std
::
begin
(
while_Xs
),
std
::
end
(
while_Xs
),
output_node
->
Var
()
->
Name
()),
std
::
end
(
while_Xs
));
while_op_node
->
Op
()
->
SetInput
(
"X"
,
while_Xs
);
}
else
if
(
op_list
.
count
(
next_op_node
->
Op
()
->
Type
())
!=
0
)
{
calcu_op_node
=
next_op_node
;
auto
*
calcu_op_desc
=
calcu_op_node
->
Op
();
std
::
vector
<
float
>
weight_scale
;
auto
*
weight_scale_tensor
=
scope
.
GetVar
(
scale_var_node
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight_scale_nums
=
weight_scale_tensor
->
numel
();
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT32
)
{
float
*
weight_scale_data
=
weight_scale_tensor
->
data
<
float
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
weight_scale_data
[
i
]);
}
}
else
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
phi
::
dtype
::
float16
*
weight_scale_data
=
weight_scale_tensor
->
data
<
phi
::
dtype
::
float16
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
static_cast
<
float
>
(
weight_scale_data
[
i
]));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"The dtype of quantization scale must be FP32/16, "
"but received %d, which is not supported."
,
weight_scale_tensor
->
dtype
()));
}
int
quant_axis
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"quant_axis"
));
if
(
quant_axis
==
-
1
)
{
// per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
1
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."
));
// float(weight) * scale
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
0
];
}
}
else
if
(
quant_axis
==
0
)
{
// per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
w_dims
[
quant_axis
],
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."
));
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
"conv2d_fusion)'s weight dims should be 4."
));
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
int
inner_size
=
w_dims
[
1
]
*
w_dims
[
2
]
*
w_dims
[
3
];
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
i
/
inner_size
];
}
}
else
if
(
quant_axis
==
1
)
{
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
w_dims
[
quant_axis
],
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."
));
if
(
w_dims
.
size
()
==
4
)
{
// conv2d_transpose
std
::
string
quantized_op_type
=
any_op2
->
Op
()
->
Type
();
PADDLE_ENFORCE_EQ
(
quantized_op_type
,
"conv2d_transpose"
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."
));
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
int
inner_size
=
w_dims
[
2
]
*
w_dims
[
3
];
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[(
i
/
inner_size
)
%
w_dims
[
1
]];
"When quant_axis == -1, it means using per_layer "
"dequantization. In this situation, the number of "
"weight_scale should be 1, but received %d."
,
weight_scale_nums
));
calcu_op_desc
->
SetAttr
(
"weight_scale"
,
weight_scale
[
0
]);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Delete Weight Dequant Linear Op Pass is not supported "
"for "
"per-channel quantization"
));
}
calcu_op_desc
->
RenameInput
(
dequantized_weight_var_node
->
Var
()
->
Name
(),
weight_var_node
->
Var
()
->
Name
());
}
}
}
}
}
}
else
if
(
w_dims
.
size
()
==
2
)
{
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
i
%
w_dims
[
1
]];
// 3. Delete dequant op
IR_NODE_LINK_TO
(
weight_var_node
,
calcu_op_node
);
std
::
vector
<
const
Node
*>
nodes2rm_local
{
dequantized_weight_var_node
,
scale_var_node
,
n
};
for
(
auto
*
node2rm
:
nodes2rm_local
)
{
if
(
node2rm
)
{
nodes2rm
.
insert
(
node2rm
);
}
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 , weight dims should be 2 or 4, please check "
"your model "
));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"quant_axis should be -1 or 0 or 1, please check your model "
"OP'attribute "
));
}
weight_tensor
->
clear
();
// clear int weight
weight_tensor
->
Resize
(
phi
::
make_ddim
(
phi
::
vectorize
(
w_dims
)));
float
*
new_quantized_weight_data
=
weight_tensor
->
mutable_data
<
float
>
(
platform
::
CPUPlace
());
memcpy
(
new_quantized_weight_data
,
weight_data_tmp
.
data
(),
weight_tensor
->
numel
()
*
sizeof
(
float
));
nodes2rm
.
insert
(
weight_dequantize_linear_op_scale
);
nodes2rm
.
insert
(
weight_dequantize_linear_op
);
nodes2rm
.
insert
(
weight_dequantize_linear_op_out
);
}
// relink weight to any_op2
any_op2_desc
->
RenameInput
(
weight_dequantize_linear_op_out
->
Var
()
->
Name
(),
weight_dequantize_linear_op_x
->
Var
()
->
Name
());
any_op2_desc
->
Flush
();
IR_NODE_LINK_TO
(
weight_dequantize_linear_op_x
,
any_op2
);
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
found_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_count
);
GraphSafeRemoveNodes
(
graph
,
nodes2rm
);
graph
->
Set
(
"enable_int8"
,
new
bool
(
is_int8
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_weight_dequant_linear_op_pass
,
paddle
::
framework
::
ir
::
DeleteWeight
Quant
DequantLinearOpPass
);
paddle
::
framework
::
ir
::
DeleteWeightDequantLinearOpPass
);
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h
浏览文件 @
95332bef
/
/
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
//
Licensed under the Apache License, Version 2.0 (the "License");
//
you may not use this file except in compliance with the License.
//
You may obtain a copy of the License at
//
//
http://www.apache.org/licenses/LICENSE-2.0
//
//
Unless required by applicable law or agreed to in writing, software
//
distributed under the License is distributed on an "AS IS" BASIS,
//
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//
See the License for the specific language governing permissions and
// limitations under the License.
/
*
Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/
fuse_pass_base
.h"
#include "paddle/fluid/framework/ir/
pass
.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
class
DeleteWeightQuantDequantLinearOpPass
:
public
FusePassBase
{
public:
DeleteWeightQuantDequantLinearOpPass
();
virtual
~
DeleteWeightQuantDequantLinearOpPass
()
{}
class
Graph
;
class
DeleteWeightDequantLinearOpPass
:
public
Pass
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
};
...
...
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc
0 → 100644
浏览文件 @
95332bef
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
template
<
typename
T
>
void
AddVarToScope
(
Scope
*
param_scope
,
const
std
::
string
&
name
,
const
DDim
&
dims
)
{
auto
*
tensor
=
param_scope
->
Var
(
name
)
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor
->
Resize
(
dims
);
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()));
dev_ctx
->
HostAlloc
<
T
>
(
tensor
,
tensor
->
numel
()
*
sizeof
(
T
));
}
template
<
typename
T
>
Scope
*
CreateParamScope
()
{
auto
param_scope
=
new
Scope
();
AddVarToScope
<
T
>
(
param_scope
,
"scale"
,
{
1
});
return
param_scope
;
}
TEST
(
DeleteWeightDequantLinearOpPass
,
basic
)
{
// inputs operator output
// --------------------------------------------------------------------
// (weight, scale) dequantize_linear -> dequantized_weight
// (x, dequantized_weight) matmul/fc/conv -> matmul_out
// (dequantized_weight) while -> [optional]
Layers
layers
;
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
768
});
auto
*
weight
=
layers
.
data
(
"weight"
,
{
768
,
768
},
true
);
auto
*
scale
=
layers
.
data
(
"scale"
,
{
1
},
true
);
auto
*
zero_point
=
layers
.
data
(
"zero_point"
,
{
1
},
true
);
auto
*
dequantized_weight
=
layers
.
dequantize_linear
(
weight
,
scale
,
zero_point
);
layers
.
matmul_v2
(
x
,
dequantized_weight
);
layers
.
while_loop
({
dequantized_weight
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
<
float
>
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"delete_weight_dequant_linear_op_pass"
);
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
int
num_dequant_nodes_after
=
GetNumOpNodes
(
graph
,
"dequantize_linear"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
3
,
platform
::
errors
::
InvalidArgument
(
"After pass, the number of nodes should be reduced by 3, but the "
"number before pass is %d, after pass is %d."
,
num_nodes_before
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_dequant_nodes_after
,
0
,
platform
::
errors
::
InvalidArgument
(
"After pass, the number of nodes of type "
"'dequantize_linear' should be 1, not %d."
,
num_dequant_nodes_after
));
}
TEST
(
DeleteWeightDequantLinearOpPass
,
basic_fp16
)
{
// inputs operator output
// --------------------------------------------------------------------
// (weight, scale) dequantize_linear -> dequantized_weight
// (x, dequantized_weight) matmul/fc/conv -> matmul_out
// (dequantized_weight) while -> [optional]
Layers
layers
;
auto
*
x
=
layers
.
data
(
"x"
,
{
1
,
128
,
768
});
auto
*
weight
=
layers
.
data
(
"weight"
,
{
768
,
768
},
true
);
auto
*
scale
=
layers
.
data
(
"scale"
,
{
1
},
true
);
auto
*
zero_point
=
layers
.
data
(
"zero_point"
,
{
1
},
true
);
auto
*
dequantized_weight
=
layers
.
dequantize_linear
(
weight
,
scale
,
zero_point
);
layers
.
matmul_v2
(
x
,
dequantized_weight
);
layers
.
while_loop
({
dequantized_weight
});
std
::
unique_ptr
<
ir
::
Graph
>
graph
(
new
ir
::
Graph
(
layers
.
main_program
()));
graph
->
Set
(
"__param_scope__"
,
CreateParamScope
<
phi
::
dtype
::
float16
>
());
auto
pass
=
PassRegistry
::
Instance
().
Get
(
"delete_weight_dequant_linear_op_pass"
);
int
num_nodes_before
=
graph
->
Nodes
().
size
();
VLOG
(
3
)
<<
DebugString
(
graph
);
graph
.
reset
(
pass
->
Apply
(
graph
.
release
()));
int
num_nodes_after
=
graph
->
Nodes
().
size
();
int
num_dequant_nodes_after
=
GetNumOpNodes
(
graph
,
"dequantize_linear"
);
VLOG
(
3
)
<<
DebugString
(
graph
);
PADDLE_ENFORCE_EQ
(
num_nodes_before
,
num_nodes_after
+
3
,
platform
::
errors
::
InvalidArgument
(
"After pass, the number of nodes should be reduced by 3, but the "
"number before pass is %d, after pass is %d."
,
num_nodes_before
,
num_nodes_after
));
PADDLE_ENFORCE_EQ
(
num_dequant_nodes_after
,
0
,
platform
::
errors
::
InvalidArgument
(
"After pass, the number of nodes of type "
"'dequantize_linear' should be 1, not %d."
,
num_dequant_nodes_after
));
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
delete_weight_dequant_linear_op_pass
);
paddle/fluid/framework/ir/pass.cc
浏览文件 @
95332bef
...
...
@@ -48,8 +48,8 @@ static const std::vector<std::string> support_subgraph_passes = {
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass"
,
"fuse_multi_transformer_layer_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_
encoder_
pass"
,
"delete_weight_dequant_linear_op_decoder_pass"
};
"delete_weight_dequant_linear_op_pass"
,
};
Graph
*
Pass
::
Apply
(
Graph
*
graph
)
const
{
VLOG
(
10
)
<<
"start to apply pass "
<<
Type
()
<<
" to graph"
;
...
...
paddle/fluid/framework/ir/pass_tester_helper.h
浏览文件 @
95332bef
...
...
@@ -641,6 +641,23 @@ struct Layers {
return
out
;
}
VarDesc
*
dequantize_linear
(
VarDesc
*
x
,
VarDesc
*
scale
,
VarDesc
*
zero_point
,
int
bit_length
=
8
,
int
quant_axis
=
-
1
)
{
VarDesc
*
out
=
lod_tensor
(
unique_name
());
OpDesc
*
op
=
program_
.
MutableBlock
(
0
)
->
AppendOp
();
op
->
SetType
(
"dequantize_linear"
);
op
->
SetInput
(
"X"
,
{
x
->
Name
()});
op
->
SetInput
(
"Scale"
,
{
scale
->
Name
()});
op
->
SetInput
(
"ZeroPoint"
,
{
zero_point
->
Name
()});
op
->
SetAttr
(
"bit_length"
,
bit_length
);
op
->
SetAttr
(
"quant_axis"
,
quant_axis
);
op
->
SetOutput
(
"Y"
,
{
out
->
Name
()});
return
out
;
}
void
backward
(
std
::
vector
<
VarDesc
*>
targets
)
{
// This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program.
...
...
paddle/fluid/framework/ir/
delete_weight_dequant_linear_op_encoder
_pass.cc
→
paddle/fluid/framework/ir/
trt_delete_weight_dequant_linear_op
_pass.cc
浏览文件 @
95332bef
...
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/
delete_weight_dequant_linear_op_encoder
_pass.h"
#include "paddle/fluid/framework/ir/
trt_delete_weight_dequant_linear_op
_pass.h"
#include <algorithm>
#include <memory>
...
...
@@ -32,8 +32,8 @@ namespace ir {
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpEncoder
Pass
::
DeleteWeightDequantLinearOpEncoder
Pass
()
{
TrtDeleteWeightQuantDequantLinearOp
Pass
::
TrtDeleteWeightQuantDequantLinearOp
Pass
()
{
AddOpCompat
(
OpCompat
(
"quantize_linear"
))
.
AddInput
(
"X"
)
.
IsTensor
()
...
...
@@ -270,64 +270,69 @@ DeleteWeightDequantLinearOpEncoderPass::
.
End
();
}
// Delete dequantize_linear_op, then dequantize weight
void
DeleteWeightDequantLinearOpEncoderPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
void
TrtDeleteWeightQuantDequantLinearOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"delete_weight_
dequant_linear_op_encoder
_pattern"
;
"delete_weight_
quantdequant_linear_op
_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
GraphPatternDetector
gpd
;
auto
*
scope
=
param_scope
();
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope in DeleteWeightDequantLinearOpEncoderPass "
"should not be null."
));
PADDLE_ENFORCE_NOT_NULL
(
scope
,
platform
::
errors
::
InvalidArgument
(
"Scope in TrtDeleteWeightQuantDequantLinearOpPass should not be "
"null."
));
// Create pattern
patterns
::
DeleteWeight
DequantLinearOpEncoder
Pattern
pattern
(
patterns
::
DeleteWeight
QuantDequantLinearOp
Pattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
();
int
found_count
=
0
;
bool
is_int8
=
false
;
// Device context
auto
*
dev_ctx
=
static_cast
<
phi
::
CPUContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
platform
::
CPUPlace
()));
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_NODES
;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
LOG(WARNING) << "
trt_
delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8
=
true
;
std
::
unordered_set
<
const
Node
*>
nodes2rm
=
{};
int
bit_length
=
PADDLE_GET_CONST
(
int
,
weight_dequantize_linear_op
->
Op
()
->
GetAttr
(
"bit_length"
));
int
range
=
((
1
<<
(
bit_length
-
1
))
-
1
);
auto
*
any_op2_desc
=
any_op2
->
Op
();
// get weight tensor
auto
*
weight_tensor
=
scope
->
GetVar
(
weight_dequantize_linear_op_x
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
int8_t
*
quantized_weight_data
=
weight_tensor
->
data
<
int8_t
>
();
auto
w_dims
=
weight_tensor
->
dims
();
// Get weight scale
std
::
vector
<
float
>
weight_scale
;
auto
*
weight_scale_tensor
=
scope
->
GetVar
(
weight_dequantize_linear_op_scale
->
Name
())
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
weight_scale_nums
=
weight_scale_tensor
->
numel
();
float
*
weight_scale_data
=
weight_scale_tensor
->
data
<
float
>
();
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT32
)
{
float
*
weight_scale_data
=
weight_scale_tensor
->
data
<
float
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
weight_scale_data
[
i
]);
}
}
else
if
(
weight_scale_tensor
->
dtype
()
==
paddle
::
experimental
::
DataType
::
FLOAT16
)
{
phi
::
dtype
::
float16
*
weight_scale_data
=
weight_scale_tensor
->
data
<
phi
::
dtype
::
float16
>
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
static_cast
<
float
>
(
weight_scale_data
[
i
]));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"%d is not supported."
,
weight_scale_tensor
->
dtype
()));
auto
weight_scale_nums
=
weight_scale_tensor
->
numel
();
for
(
int
i
=
0
;
i
<
weight_scale_nums
;
i
++
)
{
weight_scale
.
push_back
(
weight_scale_data
[
i
]
/
range
);
}
// dequant weight
std
::
vector
<
float
>
weight_data_tmp
;
weight_data_tmp
.
reserve
(
weight_tensor
->
numel
());
int
quant_axis
=
PADDLE_GET_CONST
(
int
,
weight_dequantize_linear_op
->
Op
()
->
GetAttr
(
"quant_axis"
));
if
(
quant_axis
==
-
1
)
{
// per_layer quant_dequant: all OP
...
...
@@ -337,13 +342,74 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."
));
// Add attr to anyop 2
any_op2_desc
->
SetAttr
(
"weight_scale"
,
weight_scale
[
0
]);
// float(weight) * scale
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
0
];
}
}
else
if
(
quant_axis
==
0
)
{
// per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
w_dims
[
quant_axis
],
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."
));
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
"conv2d_fusion)'s weight dims should be 4."
));
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
int
inner_size
=
w_dims
[
1
]
*
w_dims
[
2
]
*
w_dims
[
3
];
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
i
/
inner_size
];
}
}
else
if
(
quant_axis
==
1
)
{
PADDLE_ENFORCE_EQ
(
weight_scale_nums
,
w_dims
[
quant_axis
],
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."
));
if
(
w_dims
.
size
()
==
4
)
{
// conv2d_transpose
std
::
string
quantized_op_type
=
any_op2
->
Op
()
->
Type
();
PADDLE_ENFORCE_EQ
(
quantized_op_type
,
"conv2d_transpose"
,
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."
));
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
int
inner_size
=
w_dims
[
2
]
*
w_dims
[
3
];
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[(
i
/
inner_size
)
%
w_dims
[
1
]];
}
}
else
if
(
w_dims
.
size
()
==
2
)
{
for
(
int
i
=
0
;
i
<
weight_tensor
->
numel
();
i
++
)
{
weight_data_tmp
[
i
]
=
static_cast
<
float
>
(
quantized_weight_data
[
i
])
*
weight_scale
[
i
%
w_dims
[
1
]];
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"When quant_axis == 1 , weight dims should be 2 or 4, please check "
"your model "
));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"
Delete Weight Dequant Linear Op Encoder Pass is not supported for
"
"
per-channel quantization
"
));
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"
quant_axis should be -1 or 0 or 1, please check your model
"
"
OP'attribute
"
));
}
weight_tensor
->
clear
();
// clear int weight
weight_tensor
->
Resize
(
phi
::
make_ddim
(
phi
::
vectorize
(
w_dims
)));
float
*
new_quantized_weight_data
=
dev_ctx
->
HostAlloc
<
float
>
(
weight_tensor
,
weight_tensor
->
numel
()
*
sizeof
(
float
));
memcpy
(
new_quantized_weight_data
,
weight_data_tmp
.
data
(),
weight_tensor
->
numel
()
*
sizeof
(
float
));
nodes2rm
.
insert
(
weight_dequantize_linear_op_scale
);
nodes2rm
.
insert
(
weight_dequantize_linear_op
);
...
...
@@ -358,7 +424,6 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
found_count
++
;
};
gpd
(
graph
,
handler
);
graph
->
Set
(
"enable_int8"
,
new
bool
(
is_int8
));
AddStatis
(
found_count
);
}
...
...
@@ -366,5 +431,5 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
delete_weight_dequant_linear_op_encoder
_pass
,
paddle
::
framework
::
ir
::
DeleteWeightDequantLinearOpEncoder
Pass
);
REGISTER_PASS
(
trt_delete_weight_dequant_linear_op
_pass
,
paddle
::
framework
::
ir
::
TrtDeleteWeightQuantDequantLinearOp
Pass
);
paddle/fluid/framework/ir/
delete_weight_dequant_linear_op_decoder
_pass.h
→
paddle/fluid/framework/ir/
trt_delete_weight_dequant_linear_op
_pass.h
浏览文件 @
95332bef
...
...
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
...
...
@@ -20,10 +21,10 @@ namespace paddle {
namespace
framework
{
namespace
ir
{
class
DeleteWeightDequantLinearOpDecoder
Pass
:
public
FusePassBase
{
class
TrtDeleteWeightQuantDequantLinearOp
Pass
:
public
FusePassBase
{
public:
DeleteWeightDequantLinearOpDecoder
Pass
();
virtual
~
DeleteWeightDequantLinearOpDecoder
Pass
()
{}
TrtDeleteWeightQuantDequantLinearOp
Pass
();
virtual
~
TrtDeleteWeightQuantDequantLinearOp
Pass
()
{}
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
95332bef
...
...
@@ -84,16 +84,16 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) {
void
PaddlePassBuilder
::
ClearPasses
()
{
passes_
.
clear
();
}
const
std
::
vector
<
std
::
string
>
kTRTSubgraphPasses
({
"adaptive_pool2d_convert_global_pass"
,
//
"shuffle_channel_detect_pass"
,
//
"quant_conv2d_dequant_fuse_pass"
,
//
"delete_fill_constant_op_pass"
,
//
"delete_quant_dequant_op_pass"
,
//
"delete_quant_dequant_filter_op_pass"
,
//
"delete_weight_dequant_linear_op_pass"
,
//
"delete_quant_dequant_linear_op_pass"
,
//
"identity_scale_op_clean_pass"
,
//
"add_support_int8_pass"
,
//
"adaptive_pool2d_convert_global_pass"
,
//
"shuffle_channel_detect_pass"
,
//
"quant_conv2d_dequant_fuse_pass"
,
//
"delete_fill_constant_op_pass"
,
//
"delete_quant_dequant_op_pass"
,
//
"delete_quant_dequant_filter_op_pass"
,
//
"
trt_
delete_weight_dequant_linear_op_pass"
,
//
"delete_quant_dequant_linear_op_pass"
,
//
"identity_scale_op_clean_pass"
,
//
"add_support_int8_pass"
,
//
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass"
,
//
"trt_embedding_eltwise_layernorm_fuse_pass"
,
//
...
...
@@ -161,8 +161,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
const
std
::
vector
<
std
::
string
>
kGpuLowerPrecisionPasses
{
"simplify_with_basic_ops_pass"
,
"delete_quant_dequant_linear_op_pass"
,
"delete_weight_dequant_linear_op_encoder_pass"
,
"delete_weight_dequant_linear_op_decoder_pass"
,
"delete_weight_dequant_linear_op_pass"
,
"map_depthwise_conv_to_conv_pass"
,
"conv_bn_fuse_pass"
,
"conv_eltwiseadd_bn_fuse_pass"
,
...
...
@@ -210,8 +209,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"is_test_pass"
,
//
"simplify_with_basic_ops_pass"
,
//
"delete_quant_dequant_linear_op_pass"
,
//
"delete_weight_dequant_linear_op_encoder_pass"
,
//
"delete_weight_dequant_linear_op_decoder_pass"
,
//
"delete_weight_dequant_linear_op_pass"
,
//
"map_depthwise_conv_to_conv_pass"
,
//
"conv_bn_fuse_pass"
,
//
"conv_eltwiseadd_bn_fuse_pass"
,
//
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录