Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c71025eb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c71025eb
编写于
8月 26, 2021
作者:
X
XGZhang
提交者:
GitHub
8月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the bug of channel-wise quantization for ernie (#34948)
上级
0efda9d9
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
92 addition
and
33 deletion
+92
-33
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
+4
-0
paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
...erators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
+4
-0
paddle/fluid/operators/fake_dequantize_op.cc
paddle/fluid/operators/fake_dequantize_op.cc
+61
-20
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+10
-7
paddle/fluid/operators/fake_dequantize_op.h
paddle/fluid/operators/fake_dequantize_op.h
+7
-5
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+6
-1
未找到文件。
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc
浏览文件 @
c71025eb
...
@@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
...
@@ -115,6 +115,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
.
AddAttr
(
"quant_axis"
)
.
AddAttr
(
"quant_axis"
)
.
IsIntIn
({
0
,
1
})
.
IsIntIn
({
0
,
1
})
.
IsOptional
()
.
IsOptional
()
.
End
()
.
AddAttr
(
"x_num_col_dims"
)
.
IsType
<
int
>
()
.
IsOptional
()
.
End
();
.
End
();
AddOpCompat
(
OpCompat
(
"conv2d"
))
AddOpCompat
(
OpCompat
(
"conv2d"
))
.
AddInput
(
"Input"
)
.
AddInput
(
"Input"
)
...
...
paddle/fluid/operators/compat/fake_channel_wise_dequantize_max_abs.pbtxt
浏览文件 @
c71025eb
...
@@ -17,4 +17,8 @@ def {
...
@@ -17,4 +17,8 @@ def {
name: "quant_axis"
name: "quant_axis"
type: INT
type: INT
}
}
attrs {
name: "x_num_col_dims"
type: INT
}
}
}
paddle/fluid/operators/fake_dequantize_op.cc
浏览文件 @
c71025eb
...
@@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
...
@@ -39,7 +39,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
const
int
quant_axis
,
const
int
scale_num
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
const
int
x_num_col_dims
,
framework
::
Tensor
*
out
)
{
if
(
scale_num
==
1
)
{
if
(
scale_num
==
1
)
{
// Dequant op is before quantized op
// Dequant op is before quantized op
// Dequantize the weight of quantized op
// Dequantize the weight of quantized op
...
@@ -81,6 +81,32 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
...
@@ -81,6 +81,32 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
}
else
if
(
scale_num
==
2
)
{
}
else
if
(
scale_num
==
2
)
{
// Dequant op is after quantized op
// Dequant op is after quantized op
// Dequantize the output tensor of quantized op
// Dequantize the output tensor of quantized op
if
(
x_num_col_dims
>
1
)
{
auto
in_dims
=
in
->
dims
();
const
int64_t
channel
=
in_dims
[
x_num_col_dims
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
int64_t
out_iter
=
1
;
for
(
int
i
=
0
;
i
<
x_num_col_dims
;
i
++
)
{
out_iter
*=
in_dims
[
i
];
}
int64_t
step_i
=
in
->
numel
()
/
out_iter
;
int64_t
step_j
=
in
->
numel
()
/
(
out_iter
*
channel
);
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
out_iter
;
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
channel
;
j
++
)
{
auto
*
cur_in
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out
=
out_data
+
i
*
step_i
+
j
*
step_j
;
T
s
=
scale_one
[
j
];
for
(
int64_t
k
=
0
;
k
<
step_j
;
k
++
)
{
*
cur_out
=
(
*
cur_in
)
*
s
*
scale_two
[
0
]
/
max_range
;
++
cur_in
;
++
cur_out
;
}
}
}
}
else
{
int
batch_size
=
in
->
dims
()[
0
];
int
batch_size
=
in
->
dims
()[
0
];
int
channel
=
in
->
dims
()[
1
];
int
channel
=
in
->
dims
()[
1
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
...
@@ -102,6 +128,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
...
@@ -102,6 +128,7 @@ struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
}
}
}
}
}
}
}
};
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
...
@@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
...
@@ -199,7 +226,16 @@ class FakeChannelWiseDequantizeMaxAbsOpMaker
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
});
});
AddAttr
<
int
>
(
"x_num_col_dims"
,
"The x_num_col_dims of mul. Only used for mul or matmul."
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
x_num_col_dims
)
{
PADDLE_ENFORCE_EQ
(
x_num_col_dims
==
0
,
false
,
platform
::
errors
::
InvalidArgument
(
"'x_num_col_dims' should be larger than 0, but "
"the received is %d"
,
x_num_col_dims
));
});
AddComment
(
R"DOC(
AddComment
(
R"DOC(
FakeChannelWiseDequantizeMaxAbsOp operator.
FakeChannelWiseDequantizeMaxAbsOp operator.
...
@@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
...
@@ -245,4 +281,9 @@ REGISTER_OP_VERSION(fake_channel_wise_dequantize_max_abs)
R"ROC(add new attributes [quant_axis] for applying per-channel "
R"ROC(add new attributes [quant_axis] for applying per-channel "
"dequantization to conv2d_tranpose and mul ops.)ROC"
,
"dequantization to conv2d_tranpose and mul ops.)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
().
NewAttr
(
paddle
::
framework
::
compatible
::
OpVersionDesc
().
NewAttr
(
"quant_axis"
,
"The axis for dequantization."
,
0
));
"quant_axis"
,
"The axis for dequantization."
,
0
))
.
AddCheckpoint
(
R"ROC(add new attributes [x_num_col_dims] for applying per-channel "
"dequantization to mul ops.)ROC"
,
paddle
::
framework
::
compatible
::
OpVersionDesc
().
NewAttr
(
"x_num_col_dims"
,
"The x_num_col_dims for dequantization."
,
1
));
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
c71025eb
...
@@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
...
@@ -77,9 +77,9 @@ __global__ void DequantizeOneScaleQuantAxis1(const T* in, const T* scale,
template
<
typename
T
>
template
<
typename
T
>
__global__
void
DequantizeTwoScale
(
const
T
*
in
,
const
T
*
scale_one
,
__global__
void
DequantizeTwoScale
(
const
T
*
in
,
const
T
*
scale_one
,
const
T
*
scale_two
,
T
max_range
,
int
num
,
const
T
*
scale_two
,
T
max_range
,
int
num
,
int
batch
_size
,
int
channel
,
T
*
out
)
{
int
iter
_size
,
int
channel
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
channel_size
=
num
/
(
batch
_size
*
channel
);
int
channel_size
=
num
/
(
iter
_size
*
channel
);
int
scale_index
=
blockIdx
.
x
%
channel
;
int
scale_index
=
blockIdx
.
x
%
channel
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
...
@@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
...
@@ -93,7 +93,7 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
const
int
quant_axis
,
const
int
scale_num
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
const
int
x_num_col_dims
,
framework
::
Tensor
*
out
)
{
auto
in_dims
=
in
->
dims
();
auto
in_dims
=
in
->
dims
();
const
T
*
in_data
=
in
->
data
<
T
>
();
const
T
*
in_data
=
in
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
...
@@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
...
@@ -116,14 +116,17 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
}
else
if
(
scale_num
==
2
)
{
}
else
if
(
scale_num
==
2
)
{
// Not need to consider quant_axis
// Not need to consider quant_axis
int
num
=
in
->
numel
();
int
num
=
in
->
numel
();
int
batch_size
=
in
->
dims
()[
0
];
int
iter_size
=
1
;
int
channel
=
in
->
dims
()[
1
];
for
(
int
i
=
0
;
i
<
x_num_col_dims
;
i
++
)
{
iter_size
*=
in
->
dims
()[
i
];
}
int
channel
=
in
->
dims
()[
x_num_col_dims
];
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_one
=
scales
[
0
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
const
T
*
scale_two
=
scales
[
1
]
->
data
<
T
>
();
int
block
=
1024
;
int
block
=
1024
;
int
grid
=
batch
_size
*
channel
;
int
grid
=
iter
_size
*
channel
;
DequantizeTwoScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
DequantizeTwoScale
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_one
,
scale_two
,
max_range
,
num
,
batch
_size
,
channel
,
in_data
,
scale_one
,
scale_two
,
max_range
,
num
,
iter
_size
,
channel
,
out_data
);
out_data
);
}
}
}
}
...
...
paddle/fluid/operators/fake_dequantize_op.h
浏览文件 @
c71025eb
...
@@ -33,7 +33,8 @@ template <typename DeviceContext, typename T>
...
@@ -33,7 +33,8 @@ template <typename DeviceContext, typename T>
struct
ChannelDequantizeFunctor
{
struct
ChannelDequantizeFunctor
{
void
operator
()(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
void
operator
()(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
const
framework
::
Tensor
**
scales
,
const
int
scale_num
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
);
T
max_range
,
const
int
quant_axis
,
const
int
x_num_col_dims
,
framework
::
Tensor
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
...
@@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
...
@@ -64,6 +65,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
auto
quant_bits
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"quant_bits"
);
auto
quant_bits
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"quant_bits"
);
auto
quant_axis
=
ctx
.
Attr
<
int
>
(
"quant_axis"
);
auto
quant_axis
=
ctx
.
Attr
<
int
>
(
"quant_axis"
);
auto
x_num_col_dims
=
ctx
.
Attr
<
int
>
(
"x_num_col_dims"
);
int
max_range
=
1
;
int
max_range
=
1
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
...
@@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
...
@@ -80,11 +82,11 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
);
max_range
*=
(
std
::
pow
(
2
,
quant_bits
[
0
]
-
1
)
-
1
);
}
else
if
(
scale_num
==
2
)
{
}
else
if
(
scale_num
==
2
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
scales
[
0
]
->
numel
(),
in
->
dims
()[
1
],
scales
[
0
]
->
numel
(),
in
->
dims
()[
x_num_col_dims
],
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"The number of first scale values must be the same with "
"The number of first scale values must be the same with "
"
second dimension value of Input(X) when the `Scales` has two
"
"
corresponding dimension value of Input(X) when the `Scales`
"
"elements, but %ld != %ld here."
,
"
has two
elements, but %ld != %ld here."
,
scales
[
0
]
->
numel
(),
in
->
dims
()[
1
]));
scales
[
0
]
->
numel
(),
in
->
dims
()[
1
]));
PADDLE_ENFORCE_EQ
(
scales
[
1
]
->
numel
(),
1
,
PADDLE_ENFORCE_EQ
(
scales
[
1
]
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
...
@@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
...
@@ -96,7 +98,7 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
}
}
ChannelDequantizeFunctor
<
DeviceContext
,
T
>
()(
ChannelDequantizeFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scales
.
data
(),
scale_num
,
static_cast
<
T
>
(
max_range
),
dev_ctx
,
in
,
scales
.
data
(),
scale_num
,
static_cast
<
T
>
(
max_range
),
quant_axis
,
out
);
quant_axis
,
x_num_col_dims
,
out
);
}
}
};
};
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
c71025eb
...
@@ -1273,12 +1273,17 @@ class QuantizationFreezePass(object):
...
@@ -1273,12 +1273,17 @@ class QuantizationFreezePass(object):
var_type
=
output_var_node
.
type
(),
var_type
=
output_var_node
.
type
(),
shape
=
output_var_node
.
shape
(),
shape
=
output_var_node
.
shape
(),
var_dtype
=
output_var_node
.
dtype
())
var_dtype
=
output_var_node
.
dtype
())
if
op_node
.
op
().
has_attr
(
"x_num_col_dims"
):
x_num_col_dims
=
op_node
.
op
().
attr
(
"x_num_col_dims"
)
else
:
x_num_col_dims
=
1
dequant_op_node
=
graph
.
create_op_node
(
dequant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
op_type
=
'fake_channel_wise_dequantize_max_abs'
,
attrs
=
{
attrs
=
{
'quant_bits'
:
[
self
.
_weight_bits
,
self
.
_activation_bits
],
'quant_bits'
:
[
self
.
_weight_bits
,
self
.
_activation_bits
],
'quant_axis'
:
quant_axis
,
'quant_axis'
:
quant_axis
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
,
'x_num_col_dims'
:
x_num_col_dims
},
},
inputs
=
{
inputs
=
{
'X'
:
output_var_node
,
'X'
:
output_var_node
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录