Commit 75144f13 (unverified)
Authored Jun 21, 2022 by Guanghua Yu; committed via GitHub on Jun 21, 2022.

Update quantization round and clip calculation rules (#42695)

Parent: ff7d2464

Showing 20 changed files with 653 additions and 253 deletions (+653, -253):
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc (+8, -0)
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc (+8, -0)
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc (+8, -0)
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc (+8, -0)
paddle/fluid/operators/fake_quantize_op.cc (+95, -32)
paddle/fluid/operators/fake_quantize_op.cu.h (+92, -44)
paddle/fluid/operators/fake_quantize_op.h (+69, -23)
paddle/fluid/operators/quantize_linear_op.cc (+14, -2)
paddle/fluid/operators/quantize_linear_op.h (+5, -4)
python/paddle/fluid/contrib/slim/quantization/adaround.py (+11, -1)
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py (+31, -15)
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py (+76, -17)
python/paddle/fluid/contrib/slim/quantization/utils.py (+22, -12)
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt (+1, -1)
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py (+1, -1)
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py (+8, -8)
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py (+49, -34)
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py (+17, -16)
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py (+4, -4)
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py (+126, -39)
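For orientation before the per-file diffs: the updated rule scales a value by bin_cnt / scale, rounds with a selectable mode (round_type 0 = ties to even, round_type 1 = ties away from zero), and clips the rounded result to [-bin_cnt - 1, bin_cnt]. A minimal NumPy sketch of that rule follows; the function name and signature are illustrative, not a Paddle API.

import numpy as np

def fake_quantize(x, scale, bit_length=8, round_type=0):
    # bin_cnt = 2^(bit_length - 1) - 1, e.g. 127 for 8 bits
    bin_cnt = 2 ** (bit_length - 1) - 1
    v = x / scale * bin_cnt
    # round_type 0: ties to even (np.round); round_type 1: ties away from zero
    v = np.round(v) if round_type == 0 else np.sign(v) * np.floor(np.abs(v) + 0.5)
    # clip the rounded value to the full signed integer range [-bin_cnt - 1, bin_cnt]
    return np.clip(v, -bin_cnt - 1, bin_cnt)

x = np.array([-1.2, -0.31, 0.29, 0.8])
print(fake_quantize(x, scale=np.abs(x).max()))  # [-127. -33. 31. 85.]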
paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc

@@ -45,6 +45,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
       .End()
       .AddAttr("bit_length")
       .IsIntIn({8, 16})
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsIntIn({0, 1})
+      .End();
   AddOpCompat(OpCompat("fake_channel_wise_quantize_dequantize_abs_max"))
       .AddInput("X")

@@ -61,6 +65,10 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() {
       .End()
       .AddAttr("quant_axis")
       .IsIntIn({0, 1})
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsIntIn({0, 1})
+      .End();
 }
 // Delete quant_dequant_op, then quantize and dequantize weight
paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc

@@ -54,6 +54,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
       .End()
       .AddAttr("quant_axis")
       .IsType<int>()
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsType<int>()
+      .End();
   AddOpCompat(OpCompat("dequantize_linear"))
       .AddInput("X")

@@ -74,6 +78,10 @@ DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
       .End()
       .AddAttr("quant_axis")
       .IsType<int>()
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsType<int>()
+      .End();
 }
 // Delete quantize_linear_op dequantize_linear_op, then add input_scales
paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc

@@ -52,6 +52,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
       .End()
       .AddAttr("quant_axis")
       .IsType<int>()
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsType<int>()
+      .End();
   AddOpCompat(OpCompat("dequantize_linear"))
       .AddInput("X")

@@ -72,6 +76,10 @@ DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
       .End()
       .AddAttr("quant_axis")
       .IsType<int>()
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsType<int>()
+      .End();
   AddOpCompat(OpCompat("conv2d"))
       .AddInput("Input")
paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.cc

@@ -49,6 +49,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .End()
       .AddAttr("bit_length")
       .IsIntIn({8, 16})
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsIntIn({0, 1})
+      .End();
   AddOpCompat(OpCompat("fake_quantize_moving_average_abs_max"))
       .AddInput("X")

@@ -85,6 +89,10 @@ QuantDequantFusePass::QuantDequantFusePass() {
       .End()
       .AddAttr("bit_length")
       .IsIntIn({8, 16})
       .End()
+      .AddAttr("round_type")
+      .IsOptional()
+      .IsIntIn({0, 1})
+      .End();
   AddOpCompat(OpCompat("fake_dequantize_max_abs"))
       .AddInput("X")
paddle/fluid/operators/fake_quantize_op.cc

@@ -88,14 +88,14 @@ template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, framework::Tensor* out) {
+                  const int bin_cnt, const int round_type,
+                  framework::Tensor* out) {
     T s = scale.data<T>()[0];
     T inv_s = inverse(s);
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
-          out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
-    auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
+          out->mutable_data<T>(ctx.GetPlace()),
+          QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type, inv_s));
   }
 };

@@ -105,16 +105,17 @@ template <typename T>
 struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, framework::Tensor* out) {
+                  const int bin_cnt, const int round_type,
+                  framework::Tensor* out) {
     T s = scale.data<T>()[0];
     T inv_s = inverse(s);
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
-          out->mutable_data<T>(ctx.GetPlace()), phi::ClipFunctor<T>(-s, s));
+          out->mutable_data<T>(ctx.GetPlace()),
+          QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type, inv_s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e.device(*ctx.eigen_device()) =
-        (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
+    out_e.device(*ctx.eigen_device()) = out_e * s / static_cast<T>(bin_cnt);
   }
 };
 template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,

@@ -124,7 +125,7 @@ template <typename T>
 struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, const int quant_axis,
+                  const int bin_cnt, const int round_type, const int quant_axis,
                   framework::Tensor* out) {
     // At present, channelwise quantization supports conv2d, depthwise_conv2d
     // conv2d_transpose and mul

@@ -145,15 +146,10 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
       T s = scale_data[i];
       auto* start = in_data + i * channel_size;
       auto* end = in_data + (i + 1) * channel_size;
-      trans(ctx, start, end, out_data + i * channel_size,
-            phi::ClipFunctor<T>(-s, s));
-    }
-    for (int64_t i = 0; i < channel; i++) {
-      T s = scale_data[i];
       T inv_s = inverse(s);
-      framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
-      out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
+      trans(ctx, start, end, out_data + i * channel_size,
+            QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type, inv_s));
     }
   } else if (quant_axis == 1) {
     const int64_t step_i = in.numel() / in_dims[0];

@@ -165,10 +161,9 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
         auto* start = in_data + i * step_i + j * step_j;
         auto* end = in_data + i * step_i + (j + 1) * step_j;
         auto* cur_out_data = out_data + i * step_i + j * step_j;
-        trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
-        for (int k = 0; k < step_j; k++) {
-          cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
-        }
+        trans(ctx, start, end, cur_out_data,
+              QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type,
+                                    inv_s));
       }
     }
   }

@@ -181,7 +176,7 @@ template <typename T>
 struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, const int quant_axis,
+                  const int bin_cnt, const int round_type, const int quant_axis,
                   framework::Tensor* out) {
     PADDLE_ENFORCE_EQ(
         quant_axis == 0 || quant_axis == 1, true,

@@ -201,16 +196,13 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
       T s = scale_data[i];
       auto* start = in_data + i * channel_size;
       auto* end = in_data + (i + 1) * channel_size;
-      trans(ctx, start, end, out_data + i * channel_size,
-            phi::ClipFunctor<T>(-s, s));
-    }
-    for (int i = 0; i < channel; i++) {
-      T s = scale_data[i];
       T inv_s = inverse(s);
+      trans(ctx, start, end, out_data + i * channel_size,
+            QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type, inv_s));
       framework::Tensor one_channel_out = out->Slice(i, i + 1);
       auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
-      out_e.device(*ctx.eigen_device()) =
-          (bin_cnt * inv_s * out_e).round() * s / static_cast<T>(bin_cnt);
+      out_e.device(*ctx.eigen_device()) = out_e * s / static_cast<T>(bin_cnt);
     }
   } else if (quant_axis == 1) {
     const int64_t step_i = in.numel() / in_dims[0];

@@ -222,10 +214,11 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
         auto* start = in_data + i * step_i + j * step_j;
         auto* end = in_data + i * step_i + (j + 1) * step_j;
         auto* cur_out_data = out_data + i * step_i + j * step_j;
-        trans(ctx, start, end, cur_out_data, phi::ClipFunctor<T>(-s, s));
+        trans(ctx, start, end, cur_out_data,
+              QuantTensorFunctor<T>(static_cast<T>(bin_cnt), round_type,
+                                    inv_s));
         for (int k = 0; k < step_j; k++) {
-          cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]) *
-                            s / static_cast<T>(bin_cnt);
+          cur_out_data[k] = cur_out_data[k] * s / static_cast<T>(bin_cnt);
         }
       }
     }

@@ -334,6 +327,20 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddComment(R"DOC(
 This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
 FakeQuantAbsMaxOp operator is used in the dynamic quantization.
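The examples embedded in the attribute string can be checked directly in NumPy: np.round already breaks ties to even (round_type 0), and a ties-away-from-zero rule (round_type 1) can be modeled with floor(|x| + 0.5) * sign(x). A small sketch, not Paddle code:

import numpy as np

def round_ties_away(x):
    # rounding to nearest, breaking ties away from zero
    return np.sign(x) * np.floor(np.abs(x) + 0.5)

assert np.round(1.5) == 2 and np.round(2.5) == 2              # ties to even
assert round_ties_away(1.5) == 2 and round_ties_away(2.5) == 3  # ties away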
@@ -407,6 +414,20 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddAttr<bool>("is_test",
                   "(bool, default false) Set to true for inference only, false "
                   "for training. Some layers may run faster when this is true.")

@@ -480,6 +501,20 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddComment(R"DOC(
 The scale of FakeChannelWiseQuantize operator is a vector.
 In detail, each channel of the input X has a scale value.

@@ -546,6 +581,20 @@ class FakeQuantizeRangeAbsMaxOpMaker
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddAttr<bool>("is_test",
                   "(bool, default false) Set to true for inference only, false "
                   "for training. Some layers may run faster when this is true.")

@@ -620,6 +669,20 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddAttr<bool>("is_test",
                   "(bool, default false) Set to true for inference only, false "
                   "for training. Some layers may run faster when this is true.")
paddle/fluid/operators/fake_quantize_op.cu.h

@@ -214,7 +214,8 @@ template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;
 template <typename T>
 __global__ void ClipAndQuantKernel(const T* in, const T* scale,
-                                   const int bin_cnt, const int n, T* out) {
+                                   const int bin_cnt, const int round_type,
+                                   const int n, T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;

@@ -226,16 +227,24 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     ComputeDataType x = static_cast<ComputeDataType>(in[i]);
-    ComputeDataType v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(round(v));
+    x = bin_cnt_t * inv_s * x;
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    ComputeDataType max_bound = bin_cnt_t;
+    ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
+    out[i] = static_cast<T>(x);
   }
 }

 template <typename T>
 __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
-                                          const int bin_cnt, const int n,
+                                          const int bin_cnt,
+                                          const int round_type, const int n,
                                           T* out) {
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;

@@ -248,10 +257,16 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     ComputeDataType x = static_cast<ComputeDataType>(in[i]);
     x = x > s ? s : x;
     x = x < -s ? -s : x;
     x = bin_cnt_t * inv_s * x;
-    x = round(x);
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    ComputeDataType max_bound = bin_cnt_t;
+    ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
     out[i] = static_cast<T>((x * s) / bin_cnt_t);
   }
 }
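A quick worked check of the bounds these kernels now clip to: bin_cnt = 2^(bit_length - 1) - 1, and the rounded value is kept in [-bin_cnt - 1, bin_cnt], the full signed integer range, whereas the old path (rounding a pre-clipped input) could only reach [-bin_cnt, bin_cnt]. A sketch with an assumed helper name:

def clip_bounds(bit_length):
    bin_cnt = 2 ** (bit_length - 1) - 1
    return -bin_cnt - 1, bin_cnt

assert clip_bounds(8) == (-128, 127)        # full int8 range
assert clip_bounds(16) == (-32768, 32767)   # full int16 range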
@@ -260,7 +275,8 @@ template <typename T>
 struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, framework::Tensor* out) {
+                  const int bin_cnt, const int round_type,
+                  framework::Tensor* out) {
     int num = in.numel();
     int block = 1024;
     int grid = (block - 1 + num) / block;

@@ -270,7 +286,7 @@ struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
     ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
-        in_data, scale_data, bin_cnt, num, out_data);
+        in_data, scale_data, bin_cnt, round_type, num, out_data);
   }
 };

@@ -280,7 +296,8 @@ template <typename T>
 struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, framework::Tensor* out) {
+                  const int bin_cnt, const int round_type,
+                  framework::Tensor* out) {
     int num = in.numel();
     int block = 1024;
     int grid = (block - 1 + num) / block;

@@ -290,7 +307,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
     T* out_data = out->mutable_data<T>(ctx.GetPlace());
     ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
-        in_data, scale_data, bin_cnt, num, out_data);
+        in_data, scale_data, bin_cnt, round_type, num, out_data);
   }
 };

@@ -298,6 +315,7 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 __global__ void ChannelClipAndQuantKernelQuantAxis0(
-    const T* in, const T* scale, const int bin_cnt, const int64_t n,
-    const int c, T* out) {
+    const T* in, const T* scale, const int bin_cnt, const int round_type,
+    const int64_t n, const int c, T* out) {
   int tid = threadIdx.x;

@@ -314,18 +332,25 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
   for (int64_t i = tid; i < channel_size; i += blockDim.x) {
     ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
-    ComputeDataType v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt_t * inv_s * v;
-    out_c[i] = static_cast<T>(round(v));
+    x = bin_cnt_t * inv_s * x;
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    ComputeDataType max_bound = bin_cnt_t;
+    ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
+    out_c[i] = static_cast<T>(x);
   }
 }

 // ChannelClipAndQuantKernel for quant_axis is N
 template <typename T>
 __global__ void ChannelClipAndQuantKernelQuantAxisN(
-    const T* in, const T* scale, const int bin_cnt, const int64_t n,
-    const int nScale, const int quant_stride, T* out) {
+    const T* in, const T* scale, const int bin_cnt, const int round_type,
+    const int64_t n, const int nScale, const int quant_stride, T* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
   using ComputeDataType = typename QuantizeDataType<T>::type;
   ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

@@ -334,10 +359,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
         static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
     ComputeDataType inv_s = inverse(s);
     ComputeDataType x = static_cast<ComputeDataType>(in[i]);
-    ComputeDataType v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(round(v));
+    x = bin_cnt_t * inv_s * x;
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    ComputeDataType max_bound = bin_cnt_t;
+    ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
+    out[i] = static_cast<T>(x);
   }
 }

@@ -345,7 +377,7 @@ template <typename T>
 struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, const int quant_axis,
+                  const int bin_cnt, const int round_type, const int quant_axis,
                   framework::Tensor* out) {
     PADDLE_ENFORCE_EQ(
         quant_axis == 0 || quant_axis == 1, true,

@@ -363,7 +395,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
       int grid = in_dims[0];
       int block = 1024;
       ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
-          in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
+          in_data, scale_data, bin_cnt, round_type, num, in_dims[0], out_data);
     } else {
       int quant_stride = 1;
       for (int i = quant_axis + 1; i < in_dims.size(); i++) {

@@ -380,8 +412,8 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
           std::min(max_blocks, (num + block_size - 1) / block_size);
       ChannelClipAndQuantKernelQuantAxisN<T><<<grid_size, block_size>>>(
-          in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride,
-          out_data);
+          in_data, scale_data, bin_cnt, round_type, num, in_dims[quant_axis],
+          quant_stride, out_data);
     }
   }
 };

@@ -485,8 +517,8 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
 // ChannelClipAndQuantDequantKernel for quant_axis is 0
 template <typename T>
 __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
-    const T* in, const T* scale, const int bin_cnt, const int n, const int c,
-    T* out) {
+    const T* in, const T* scale, const int bin_cnt, const int round_type,
+    const int n, const int c, T* out) {
   int tid = threadIdx.x;
   int channel_size = n / c;

@@ -498,18 +530,25 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
   for (int i = tid; i < channel_size; i += blockDim.x) {
     T x = in_c[i];
-    T v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out_c[i] = round(v) * s / bin_cnt;
+    x = bin_cnt * inv_s * x;
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    T max_bound = bin_cnt;
+    T min_bound = -bin_cnt - static_cast<T>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
+    out_c[i] = (x * s) / bin_cnt;
   }
 }

 // ChannelClipAndQuantDequantKernel for quant_axis is 1
 template <typename T>
 __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
-    const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
-    const int cout, T* out) {
+    const T* in, const T* scale, const int bin_cnt, const int round_type,
+    const int n, const int cin, const int cout, T* out) {
   T s = scale[blockIdx.x % cout];
   T inv_s = inverse(s);

@@ -519,10 +558,17 @@ __global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
   for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
     T x = in_c[i];
-    T v = x > s ? s : x;
-    v = v < -s ? -s : v;
-    v = bin_cnt * inv_s * v;
-    out_c[i] = round(v) * s / bin_cnt;
+    x = bin_cnt * inv_s * x;
+    if (round_type == 0) {
+      x = roundWithTiesToEven(x);
+    } else {
+      x = round(x);
+    }
+    T max_bound = bin_cnt;
+    T min_bound = -bin_cnt - static_cast<T>(1);
+    x = x > max_bound ? max_bound : x;
+    x = x < min_bound ? min_bound : x;
+    out_c[i] = (x * s) / bin_cnt;
   }
 }

@@ -530,7 +576,7 @@ template <typename T>
 struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
   void operator()(const platform::CUDADeviceContext& ctx,
                   const framework::Tensor& in, const framework::Tensor& scale,
-                  const int bin_cnt, const int quant_axis,
+                  const int bin_cnt, const int round_type, const int quant_axis,
                   framework::Tensor* out) {
     // At present, channelwise quantization supports conv2d, depthwise_conv2d
     // conv2d_transpose and mul

@@ -551,15 +597,17 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
       int grid = in_dims[0];
       int block = 1024;
       ChannelClipAndQuantDequantKernelQuantAxis0<T>
-          <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, num,
-                                             in_dims[0], out_data);
+          <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
+                                             round_type, num, in_dims[0],
+                                             out_data);
     } else if (quant_axis == 1) {
       int grid = in_dims[0] * in_dims[1];
       int block = 1024;
       ChannelClipAndQuantDequantKernelQuantAxis1<T>
-          <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt, num,
-                                             in_dims[0], in_dims[1], out_data);
+          <<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
+                                             round_type, num, in_dims[0],
+                                             in_dims[1], out_data);
     }
   }
 };
paddle/fluid/operators/fake_quantize_op.h

@@ -34,6 +34,46 @@ inline HOSTDEVICE T inverse(T s) {
   return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
 }

+template <typename T>
+inline HOSTDEVICE T roundWithTiesToEven(T x) {
+  T xLower = floor(x);
+  T xUpper = ceil(x);
+  // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to
+  // even.
+  T dLower = x - xLower;
+  T dUpper = xUpper - x;
+  return static_cast<T>(
+      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
+          ? xLower
+          : xUpper);
+}
+
+template <typename T>
+class QuantTensorFunctor {
+ public:
+  explicit QuantTensorFunctor(const T bin_cnt, const int round_type,
+                              const T inv_s)
+      : bin_cnt_(bin_cnt), round_type_(round_type), inv_s_(inv_s) {}
+  HOSTDEVICE T operator()(const T x) const {
+    T out = bin_cnt_ * inv_s_ * x;
+    if (round_type_ == 0) {
+      out = roundWithTiesToEven(out);
+    } else if (round_type_ == 1) {
+      out = std::round(out);
+    }
+    T max_bound = bin_cnt_;
+    T min_bound = -bin_cnt_ - static_cast<T>(1);
+    out = out > max_bound ? max_bound : out;
+    out = out < min_bound ? min_bound : out;
+    return out;
+  }
+
+ private:
+  T bin_cnt_;
+  int round_type_;
+  T inv_s_;
+};
+
 template <typename DeviceContext, typename T>
 struct FindAbsMaxFunctor {
   void operator()(const DeviceContext& ctx, const T* in, const int num,
                   T* out);
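For readers checking the helper above, here is a hedged Python mirror of roundWithTiesToEven; Python's built-in round() also breaks ties to even, so it serves as a reference oracle:

import math

def round_ties_to_even(x):
    # mirror of the C++ helper: pick the closer of floor/ceil, ties to even
    lo, hi = math.floor(x), math.ceil(x)
    d_lo, d_hi = x - lo, hi - x
    if d_lo == d_hi:
        return lo if lo % 2 == 0 else hi
    return lo if d_lo < d_hi else hi

for v in (0.5, 1.5, 2.5, -0.5, -1.5, 3.2):
    assert round_ties_to_even(v) == round(v)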
@@ -43,14 +83,14 @@ template <typename DeviceContext, typename T>
 struct ClipAndFakeQuantFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                   const framework::Tensor& scale, const int bin_cnt,
-                  framework::Tensor* out);
+                  const int round_type, framework::Tensor* out);
 };

 template <typename DeviceContext, typename T>
 struct ClipAndFakeQuantDequantFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                   const framework::Tensor& scale, const int bin_cnt,
-                  framework::Tensor* out);
+                  const int round_type, framework::Tensor* out);
 };

 template <typename DeviceContext, typename T>

@@ -71,14 +111,15 @@ template <typename DeviceContext, typename T>
 struct ChannelClipAndFakeQuantFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                   const framework::Tensor& scale, const int bin_cnt,
-                  const int quant_axis, framework::Tensor* out);
+                  const int round_type, const int quant_axis,
+                  framework::Tensor* out);
 };

 template <typename DeviceContext, typename T>
 struct ChannelClipFakeQuantDequantFunctor {
   void operator()(const DeviceContext& ctx, const framework::Tensor& in,
                   const framework::Tensor& scale, const int bin_cnt,
-                  const int quant_axis, framework::Tensor* out);
+                  const int round_type, const int quant_axis,
+                  framework::Tensor* out);
 };

 template <typename DeviceContext, typename T>

@@ -100,12 +141,13 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
     T* out_s = out_scale->mutable_data<T>(context.GetPlace());

     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     auto& dev_ctx = context.template device_context<DeviceContext>();
     const T* in_data = in->data<T>();
     FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
-    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
   }

   virtual ~FakeAbsMaxKernelBase() = default;

@@ -114,7 +156,7 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
   virtual void RunClipFunctor(const DeviceContext& dev_ctx,
                               const framework::Tensor& in,
                               const framework::Tensor& scale, int bin_cnt,
-                              framework::Tensor* out) const = 0;
+                              int round_type, framework::Tensor* out) const = 0;
 };

 template <typename DeviceContext, typename T>

@@ -122,9 +164,9 @@ class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
  protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& scale, int bin_cnt,
-                      framework::Tensor* out) const override {
+                      int round_type, framework::Tensor* out) const override {
     ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, scale, bin_cnt,
-                                                out);
+                                                round_type, out);
   }
 };

@@ -134,9 +176,9 @@ class FakeQuantizeDequantizeAbsMaxKernel
  protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& scale, int bin_cnt,
-                      framework::Tensor* out) const override {
+                      int round_type, framework::Tensor* out) const override {
     ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, scale,
-                                                       bin_cnt, out);
+                                                       bin_cnt, round_type, out);
   }
 };

@@ -151,6 +193,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
     out->mutable_data<T>(context.GetPlace());

     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     int quant_axis = context.Attr<int>("quant_axis");
     bool is_test = context.Attr<bool>("is_test");

@@ -162,7 +205,7 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
                                                    out_scale_data);
     }
     ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
-        dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
+        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
   }
 };

@@ -179,6 +222,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
     out->mutable_data<T>(dev_ctx.GetPlace());

     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     int quant_axis = context.Attr<int>("quant_axis");

@@ -186,7 +230,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
                                                  out_scale_data);
     ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
-        dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
+        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
   }
 };

@@ -202,13 +246,14 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
     bool is_test = context.Attr<bool>("is_test");
     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     auto& dev_ctx = context.template device_context<DeviceContext>();

     // testing
     if (is_test) {
       ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
-                                                  bin_cnt, out);
+                                                  bin_cnt, round_type, out);
       return;
     }

@@ -228,7 +273,7 @@ class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
                                                  *iter, window_size, out_scales,
                                                  out_scale);
     ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                bin_cnt, out);
+                                                bin_cnt, round_type, out);
   }
 };

@@ -243,12 +288,13 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
     bool is_test = context.Attr<bool>("is_test");
     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     auto& dev_ctx = context.template device_context<DeviceContext>();

     // testing
     if (is_test) {
-      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, out);
+      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
       return;
     }

@@ -273,7 +319,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
         dev_ctx, *in_accum, *in_state, cur_scale_data, moving_rate, out_state,
         out_accum, out_scale);
-    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, out);
+    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
   }

   virtual ~FakeMovingAverageAbsMaxKernelBase() = default;

@@ -282,7 +328,7 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
   virtual void RunClipFunctor(const DeviceContext& dev_ctx,
                               const framework::Tensor& in,
                               const framework::Tensor& in_scale, int bin_cnt,
-                              framework::Tensor* out) const = 0;
+                              int round_type, framework::Tensor* out) const = 0;
 };

 template <typename DeviceContext, typename T>

@@ -291,9 +337,9 @@ class FakeQuantizeMovingAverageAbsMaxKernel
  protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
-                      framework::Tensor* out) const override {
+                      int round_type, framework::Tensor* out) const override {
     ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale, bin_cnt,
-                                                out);
+                                                round_type, out);
   }
 };

@@ -303,9 +349,9 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
  protected:
   void RunClipFunctor(const DeviceContext& dev_ctx, const framework::Tensor& in,
                       const framework::Tensor& in_scale, int bin_cnt,
-                      framework::Tensor* out) const override {
+                      int round_type, framework::Tensor* out) const override {
     ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(dev_ctx, in, in_scale,
-                                                       bin_cnt, out);
+                                                       bin_cnt, round_type, out);
   }
 };
paddle/fluid/operators/quantize_linear_op.cc

@@ -69,8 +69,6 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
   }
 };

-template struct DequantizeFunctor<platform::CPUDeviceContext, float>;
-template struct DequantizeFunctor<platform::CPUDeviceContext, double>;
 template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, float>;
 template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;

@@ -135,6 +133,20 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
               "the received is %d",
               bit_length));
         });
+    AddAttr<int>("round_type",
+                 "(int, default 0) The round type of fp32 to int."
+                 "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
+                 "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
+                 "round(2.5)=3")
+        .SetDefault(0)
+        .AddCustomChecker([](const int& round_type) {
+          PADDLE_ENFORCE_EQ(
+              round_type >= 0 && round_type <= 1, true,
+              platform::errors::InvalidArgument(
+                  "'round_type' should be between 0 and 1, but "
+                  "the received is %d",
+                  round_type));
+        });
     AddAttr<bool>("is_test",
                   "(bool, default false) Set to true for inference only, false "
                   "for training. Some layers may run faster when this is true.")
paddle/fluid/operators/quantize_linear_op.h

@@ -45,6 +45,7 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
     auto* out = context.Output<framework::Tensor>("Y");
     out->mutable_data<T>(context.GetPlace());
     int bit_length = context.Attr<int>("bit_length");
+    int round_type = context.Attr<int>("round_type");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     int quant_axis = context.Attr<int>("quant_axis");
     bool is_test = context.Attr<bool>("is_test");

@@ -57,10 +58,10 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
         FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(),
                                               in->numel(), out_s);
         ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
-                                                    bin_cnt, out);
+                                                    bin_cnt, round_type, out);
       } else {
         ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
-                                                    bin_cnt, out);
+                                                    bin_cnt, round_type, out);
       }
     } else {
       if (!is_test) {

@@ -69,10 +70,10 @@ class QuantizeLinearKernel : public framework::OpKernel<T> {
         FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
                                                      out_scale_data);
         ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
-            dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
+            dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
       } else {
         ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
-            dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out);
+            dev_ctx, *in, *in_scale, bin_cnt, round_type, quant_axis, out);
       }
     }
   }
python/paddle/fluid/contrib/slim/quantization/adaround.py

@@ -20,7 +20,7 @@ import logging
 import paddle.fluid as fluid

 from ....log_helper import get_logger
-from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error
+from .utils import load_variable_data, set_variable_data, stable_sigmoid, quant_tensor, dequant_tensor, _channelwise_quant_axis1_ops, calculate_quant_cos_error, bias_correction_w

 _logger = get_logger(__name__, logging.INFO,

@@ -209,6 +209,7 @@ def run_adaround(data_loader,
                  scale_dict,
                  num_iterations=1000,
                  lr=0.001,
+                 bias_correction=False,
                  fast_mode=True):
     fetch_op_name = fetch_list[0].name
     final_weight_tensor_quant_dict = {}

@@ -307,6 +308,15 @@ def run_adaround(data_loader,
                 break
         final_weight_tensor_quant_dict[
             weight_var_name] = adaround.update_final_weights()

+        if bias_correction:
+            final_weight_tensor_quant_dict[weight_var_name] = bias_correction_w(
+                weight_var_tensor,
+                final_weight_tensor_quant_dict[weight_var_name],
+                scale,
+                adaround.quant_axis,
+                weight_bits=adaround.weight_bits)
+
         del adaround

         # update adarounded calibrated weights
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py

@@ -121,7 +121,8 @@ class PostTrainingQuantization(object):
                  algo="KL",
                  hist_percent=0.99999,
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                 round_type='round',
+                 weight_round_algo='round',
+                 round_type='TiesToEven',
                  learning_rate=0.001,
                  is_full_quantize=False,
                  bias_correction=False,

@@ -180,9 +181,14 @@ class PostTrainingQuantization(object):
             quantizable_op_type(list[str], optional): List the type of ops
                 that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                 "mul"].
-            round_type(str, optional): The method of converting the quantized weights
+            weight_round_algo(str, optional): The method of converting the quantized weights
                 value float->int. Currently supports ['round', 'adaround'] methods.
-                Default is `round`, which is rounding nearest to the nearest whole number.
+                Default is `round`, which is rounding nearest to the integer.
                 'adaround' is refer to https://arxiv.org/abs/2004.10568.
+            round_type(str, optional): The method of converting the tensor value float->int.
+                Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
+                Default is `TiesToEven`, which is rounding to nearest ties to even.
+                'TiesAwayFromZero' is rounding to nearest ties away from zero.
             learning_rate(float, optional): The learning rate of adaround method.
             is_full_quantized(bool, optional): If set is_full_quantized as True,
                 apply quantization to all supported quantizable op type. If set
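A hedged usage sketch of the renamed options (the executor, model path, and calibration reader below are placeholders, not values from this diff): weight_round_algo picks 'round' or 'adaround' for weights, while round_type picks the scalar rounding rule used everywhere.

from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

ptq = PostTrainingQuantization(
    executor=exe,                       # assumed executor
    model_dir='./fp32_model',           # placeholder path
    sample_generator=sample_generator,  # placeholder calibration reader
    weight_round_algo='round',
    round_type='TiesToEven')
ptq.quantize()
ptq.save_quantized_model('./int8_model')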
@@ -263,8 +269,10 @@ class PostTrainingQuantization(object):
         self._support_algo_type = [
             'KL', 'hist', 'avg', 'mse', 'emd', 'abs_max', 'min_max'
         ]
-        assert round_type in ['adaround', 'round']
+        assert round_type in ['TiesToEven', 'TiesAwayFromZero']
         self._round_type = round_type
+        assert weight_round_algo in ['adaround', 'round']
+        self._weight_round_algo = weight_round_algo
         self._learning_rate = learning_rate
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \

@@ -406,7 +414,7 @@ class PostTrainingQuantization(object):
         if self._algo in ["KL", "hist"]:
             self._calculate_kl_hist_threshold()

-        if self._round_type == 'adaround':
+        if self._weight_round_algo == 'adaround':
             self._adaround_apply()

         self._reset_activation_persistable()

@@ -459,6 +467,7 @@ class PostTrainingQuantization(object):
             self._weight_op_pairs,
             scale_dict,
             num_iterations=self._batch_nums,
+            bias_correction=self._bias_correction,
             lr=self._learning_rate)

     def save_quantized_model(self,

@@ -642,6 +651,7 @@ class PostTrainingQuantization(object):
                         float(np.max(np.abs(var_tensor[i]))))
             self._quantized_threshold[var_name] = abs_max_value
         _logger.info("MSE searching stage ...")
+        distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
             var_tensor = var_tensor.flatten()

@@ -654,9 +664,9 @@ class PostTrainingQuantization(object):
                 scale = s * abs_max_value
                 s += 0.02
                 bins = 2**(self._activation_bits - 1) - 1
-                quant_dequant_var = np.round(
-                    np.clip(var_tensor, 0.0, scale) / scale *
-                    bins) / bins * scale
+                quant_var = np.clip(distribution(var_tensor / scale * bins),
+                                    -bins - 1, bins)
+                quant_dequant_var = quant_var / bins * scale
                 mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
                 if mse_loss <= self._best_calibration_loss[var_name]:
                     self._best_calibration_loss[var_name] = mse_loss
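One step of the scale search above, extracted into a standalone sketch (the helper name is assumed; distribution stands for np.round or utils.round_c, as selected by round_type):

import numpy as np

def mse_for_scale(var_tensor, scale, activation_bits=8, distribution=np.round):
    bins = 2 ** (activation_bits - 1) - 1
    # quantize with the selected rounding rule, clip to [-bins - 1, bins]
    quant_var = np.clip(distribution(var_tensor / scale * bins), -bins - 1, bins)
    # dequantize and score the candidate scale by mean squared error
    quant_dequant_var = quant_var / bins * scale
    return ((var_tensor - quant_dequant_var) ** 2).mean()

x = np.random.randn(1000).astype('float32')
print(mse_for_scale(x, scale=np.abs(x).max()))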
@@ -681,6 +691,7 @@ class PostTrainingQuantization(object):
                         float(np.max(np.abs(var_tensor[i]))))
             self._quantized_threshold[var_name] = abs_max_value
         _logger.info("EMD searching stage ...")
+        distribution = np.round if self._round_type == 'TiesToEven' else utils.round_c
         for var_name in self._quantized_act_var_name:
             var_tensor = utils.load_variable_data(self._scope, var_name)
             var_tensor = var_tensor.flatten()

@@ -693,9 +704,9 @@ class PostTrainingQuantization(object):
                 scale = s * abs_max_value
                 s += 0.02
                 bins = 2**(self._activation_bits - 1) - 1
-                quant_dequant_var = np.round(
-                    np.clip(var_tensor, 0.0, scale) / scale *
-                    bins) / bins * scale
+                quant_var = np.clip(distribution(var_tensor / scale * bins),
+                                    -bins - 1, bins)
+                quant_dequant_var = quant_var / bins * scale
                 emd_loss = np.abs(
                     np.mean(var_tensor) - np.mean(quant_dequant_var)) + np.abs(
                         np.std(var_tensor) - np.std(quant_dequant_var))

@@ -907,7 +918,8 @@ class PostTrainingQuantization(object):
                 activation_bits=self._activation_bits,
                 activation_quantize_type=self._activation_quantize_type,
                 weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
+                quantizable_op_type=major_quantizable_op_types,
+                round_type=self._round_type)
         else:
             transform_pass = QuantizationTransformPassV2(
                 scope=self._scope,

@@ -916,7 +928,8 @@ class PostTrainingQuantization(object):
                 activation_bits=self._activation_bits,
                 activation_quantize_type=self._activation_quantize_type,
                 weight_quantize_type=self._weight_quantize_type,
-                quantizable_op_type=major_quantizable_op_types)
+                quantizable_op_type=major_quantizable_op_types,
+                round_type=self._round_type)

         for sub_graph in graph.all_sub_graphs():
             # Insert fake_quant/fake_dequantize op must in test graph, so

@@ -933,13 +946,15 @@ class PostTrainingQuantization(object):
             add_quant_dequant_pass = AddQuantDequantPass(
                 scope=self._scope,
                 place=self._place,
-                quantizable_op_type=minor_quantizable_op_types)
+                quantizable_op_type=minor_quantizable_op_types,
+                round_type=self._round_type)
         else:
             add_quant_dequant_pass = AddQuantDequantPassV2(
                 scope=self._scope,
                 place=self._place,
                 quantizable_op_type=minor_quantizable_op_types,
-                is_full_quantized=self._is_full_quantize)
+                is_full_quantized=self._is_full_quantize,
+                round_type=self._round_type)

         for sub_graph in graph.all_sub_graphs():
             sub_graph._for_test = True

@@ -964,6 +979,7 @@ class PostTrainingQuantization(object):
             place=self._place,
             bias_correction=self._bias_correction,
             weight_bits=self._weight_bits,
+            weight_round_algo=self._weight_round_algo,
             round_type=self._round_type,
             activation_bits=self._activation_bits,
             weight_quantize_type=self._weight_quantize_type,
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py

@@ -119,6 +119,7 @@ class QuantizationTransformPass(object):
                  moving_rate=0.9,
                  skip_pattern=['skip_quant'],
                  quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
+                 round_type='TiesToEven',
                  weight_quantize_func=None,
                  act_quantize_func=None,
                  weight_preprocess_func=None,

@@ -156,6 +157,10 @@ class QuantizationTransformPass(object):
             quantizable_op_type(list[str]): List the type of ops that will be quantized.
                 Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                 QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
+            round_type(str, optional): The method of converting the tensor value float->int.
+                Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
+                Default is `TiesToEven`, which is rounding to nearest ties to even.
+                'TiesAwayFromZero' is rounding to nearest ties away from zero.
             weight_quantize_func(function): Function that defines how to quantize weight.
                 Using this can quickly test if user's quantization method works or not.
                 In this function, user should both define quantization function and

@@ -206,6 +211,7 @@ class QuantizationTransformPass(object):
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
         self._skip_pattern = skip_pattern
+        self._round_type = round_type
         self._weight_quantize_func = weight_quantize_func
         self._act_quantize_func = act_quantize_func
         self._weight_preprocess_func = weight_preprocess_func

@@ -459,10 +465,12 @@ class QuantizationTransformPass(object):
         _init_var_node(scale_var_node,
                        np.zeros(scale_var_node.shape(), dtype=data_type),
                        self._scope, self._place)
+        round_type = 0 if self._round_type == 'TiesToEven' else 1
         quant_op_node = graph.create_op_node(
             op_type='fake_quantize_abs_max',
             attrs={
                 'bit_length': quant_bits,
+                'round_type': round_type,
                 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
             },
             inputs={'X': var_node},

@@ -517,9 +525,11 @@ class QuantizationTransformPass(object):
             inputs['Iter'] = self._global_step
             outputs['OutScales'] = scales_node
+        round_type = 0 if self._round_type == 'TiesToEven' else 1
         attrs = {
             'window_size': self._window_size,
             'bit_length': quant_bits,
+            'round_type': round_type,
             'is_test': self._is_test,
             'op_role': core.op_proto_and_checker_maker.OpRole.Forward
         }

@@ -590,8 +600,10 @@ class QuantizationTransformPass(object):
         outs['OutState'] = state_out_node
         outs['OutAccum'] = accum_out_node

+        round_type = 0 if self._round_type == 'TiesToEven' else 1
         attrs = {
             'bit_length': quant_bits,
+            'round_type': round_type,
             'moving_rate': self._moving_rate,
             'is_test': self._is_test,
             'op_role': core.op_proto_and_checker_maker.OpRole.Forward

@@ -638,10 +650,12 @@ class QuantizationTransformPass(object):
         _init_var_node(scale_var_node,
                        np.zeros(scale_var_node.shape(), dtype=data_type),
                        self._scope, self._place)
+        round_type = 0 if self._round_type == 'TiesToEven' else 1
         quant_op_node = graph.create_op_node(
             op_type='fake_channel_wise_quantize_abs_max',
             attrs={
                 'bit_length': quant_bits,
+                'round_type': round_type,
                 'quant_axis': quant_axis,
                 'is_test': self._is_test,
                 'op_role': core.op_proto_and_checker_maker.OpRole.Forward

@@ -935,7 +949,8 @@ class QuantizationFreezePass(object):
                  bias_correction=False,
                  weight_bits=8,
                  activation_bits=8,
-                 round_type='round',
+                 weight_round_algo='round',
+                 round_type='TiesToEven',
                  weight_quantize_type='abs_max',
                  quantizable_op_type=None):
         """

@@ -953,9 +968,14 @@ class QuantizationFreezePass(object):
                 https://arxiv.org/abs/1810.05723.
             weight_bits(int): quantization bit number for weights.
             activation_bits(int): quantization bit number for activation.
-            round_type(str, optional): The method of converting the quantized weights
-                value from float to int. Currently supports ['round', 'adaround'] methods.
-                Default is `round`, which is rounding nearest to the nearest whole number.
+            weight_round_algo(str, optional): The method of converting the quantized weights
+                value float->int. Currently supports ['round', 'adaround'] methods.
+                Default is `round`, which is rounding nearest to the integer.
+                'adaround' is refer to https://arxiv.org/abs/2004.10568.
+            round_type(str, optional): The method of converting the tensor value float->int.
+                Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
+                Default is `TiesToEven`, which is rounding to nearest ties to even.
+                'TiesAwayFromZero' is rounding to nearest ties away from zero.
             weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                 since weights are fixed once the model is well trained.

@@ -971,6 +991,7 @@ class QuantizationFreezePass(object):
         self._place = _get_paddle_place(place)
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
+        self._weight_round_algo = weight_round_algo
         self._round_type = round_type
         self._weight_quantize_type = weight_quantize_type
         self._fake_quant_op_names = _fake_quant_op_list

@@ -1018,8 +1039,8 @@ class QuantizationFreezePass(object):
                     scale_v = scale_v.tolist()
                 self._quant_var_scale_map[input_arg_name] = scale_v
                 # Quantize weight and restore
-                param_v = self._load_var(input_arg_name)
-                if self._round_type == 'round':
+                if self._weight_round_algo == 'round':
+                    param_v = self._load_var(input_arg_name)
                     if any(
                             _check_grandchild_op_node(op_node, op)
                             for op in utils._channelwise_quant_axis1_ops):

@@ -1028,8 +1049,8 @@ class QuantizationFreezePass(object):
                         quant_axis = 0
                     quantized_param_v = utils.quant_tensor(
-                        param_v.copy(), scale_v, quant_axis, self._weight_bits)
-                    quantized_param_v = np.round(quantized_param_v)
+                        param_v.copy(), scale_v, quant_axis,
+                        self._weight_bits, self._round_type)
                     # Weight bias correction
                     if self._bias_correction == True:
                         quantized_param_v = utils.bias_correction_w(
                             param_v,

@@ -1037,7 +1058,6 @@ class QuantizationFreezePass(object):
                             scale_v,
                             quant_axis,
                             weight_bits=self._weight_bits)
-                        quantized_param_v = np.round(quantized_param_v)
                     self._restore_var(input_arg_name, quantized_param_v)
                 self._remove_fake_quant_and_dequant_op(graph, op_node)

@@ -1580,7 +1600,8 @@ class AddQuantDequantPass(object):
                  quant_bits=8,
                  skip_pattern=["skip_quant"],
                  quantizable_op_type=["elementwise_add", "pool2d"],
-                 is_full_quantized=False):
+                 is_full_quantized=False,
+                 round_type='TiesToEven'):
         """
         Constructor.

@@ -1602,6 +1623,10 @@ class AddQuantDequantPass(object):
                 quantization to all supported quantizable op type. If set is_full_quantized
                 as False, only apply quantization to the op type according to the input
                 quantizable_op_type.
+            round_type(str, optional): The method of converting the tensor value float->int.
+                Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
+                Default is `TiesToEven`, which is rounding to nearest ties to even.
+                'TiesAwayFromZero' is rounding to nearest ties away from zero.
         """
         self._scope = scope
         self._place = _get_paddle_place(place)

@@ -1609,6 +1634,7 @@ class AddQuantDequantPass(object):
         self._quant_bits = quant_bits
         self._is_test = None
         self._skip_pattern = skip_pattern
+        self._round_type = round_type

         if is_full_quantized:
             self._quantizable_op_type = utils._act_supported_quantizable_op_type

@@ -1743,8 +1769,10 @@ class AddQuantDequantPass(object):
         outs['OutState'] = state_out_node
         outs['OutAccum'] = accum_out_node

+        round_type = 0 if self._round_type == 'TiesToEven' else 1
         attrs = {
             'bit_length': quant_bits,
+            'round_type': round_type,
             'moving_rate': self._moving_rate,
             'is_test': self._is_test,
             'op_role': core.op_proto_and_checker_maker.OpRole.Forward

@@ -1784,6 +1812,10 @@ class InsertQuantizeLinear(object):
             Default is -1.
         channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
         is_test(bool, optional): Whether quantization with training or not. Default is True.
+        round_type(str, optional): The method of converting the tensor value float->int.
+            Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
+            Default is `TiesToEven`, which is rounding to nearest ties to even.
+            'TiesAwayFromZero' is rounding to nearest ties away from zero.
     """

     def __init__(self,

@@ -1792,13 +1824,15 @@ class InsertQuantizeLinear(object):
                  quant_bits=8,
                  quant_axis=-1,
                  channel_wise=False,
-                 is_test=True):
+                 is_test=True,
+                 round_type='TiesToEven'):
         self._place = place
         self._scope = scope
         self.quant_bits = quant_bits
         self.quant_axis = quant_axis
         self.channel_wise = channel_wise
         self._is_test = is_test
+        self._round_type = round_type

     def insert_quant_op(self, graph, var_node):
         assert var_node.is_var(), '{} is not a var'.format(var_node.name())

@@ -1841,7 +1875,12 @@ class InsertQuantizeLinear(object):
         if zero_point_node is not None:
             inputs["ZeroPoint"] = zero_point_node

-        attrs = {"quant_axis": self.quant_axis, "bit_length": self.quant_bits}
+        round_type = 0 if self._round_type == 'TiesToEven' else 1
+        attrs = {
+            "quant_axis": self.quant_axis,
+            "bit_length": self.quant_bits,
+            "round_type": round_type
+        }
         outputs = {"Y": quant_var_node}
         if not self._is_test:
             attrs["is_test"] = self._is_test
@@ -1946,6 +1985,7 @@ class QuantizationTransformPassV2(object):
                 moving_rate=0.9,
                 skip_pattern=['skip_quant'],
                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
+                 round_type='TiesToEven',
                 weight_quantize_func=None,
                 act_quantize_func=None,
                 weight_preprocess_func=None,
...
...
@@ -1981,6 +2021,10 @@ class QuantizationTransformPassV2(object):
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
+            round_type(str, optional): The method used to convert the tensor value
+                from float to int. Currently supports ['TiesToEven', 'TiesAwayFromZero'].
+                Default is 'TiesToEven', rounding to nearest with ties to even;
+                'TiesAwayFromZero' rounds to nearest with ties away from zero.
            weight_quantize_func(function): Function that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this function, user should both define quantization function and
...
...
@@ -2030,6 +2074,7 @@ class QuantizationTransformPassV2(object):
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        self._skip_pattern = skip_pattern
+        self._round_type = round_type
        self._weight_quantize_func = weight_quantize_func
        self._act_quantize_func = act_quantize_func
        self._weight_preprocess_func = weight_preprocess_func
...
...
@@ -2153,7 +2198,8 @@ class QuantizationTransformPassV2(object):
            quant_bits=quant_bits,
            quant_axis=quant_axis,
            channel_wise=channel_wise,
-            is_test=self._is_test)
+            is_test=self._is_test,
+            round_type=self._round_type)
        quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
            graph, var_node)
        dequant_var_node = insert_quant_pass.insert_dequant_op(
...
...
@@ -2261,7 +2307,8 @@ class AddQuantDequantPassV2(object):
                 quant_bits=8,
                 skip_pattern=["skip_quant"],
                 quantizable_op_type=["elementwise_add", "pool2d"],
-                 is_full_quantized=False):
+                 is_full_quantized=False,
+                 round_type='TiesToEven'):
        """
        Args:
            scope(paddle.Scope): The scope is used to initialize these new parameters.
...
...
@@ -2281,6 +2328,10 @@ class AddQuantDequantPassV2(object):
            quantization to all supported quantizable op types. If is_full_quantized
            is False, only apply quantization to the op types listed in the input
            quantizable_op_type.
+            round_type(str, optional): The method used to convert the tensor value
+                from float to int. Currently supports ['TiesToEven', 'TiesAwayFromZero'].
+                Default is 'TiesToEven', rounding to nearest with ties to even;
+                'TiesAwayFromZero' rounds to nearest with ties away from zero.
        Examples:
        .. code-block:: python
...
...
@@ -2303,6 +2354,7 @@ class AddQuantDequantPassV2(object):
        self._quant_bits = quant_bits
        self._is_test = None
        self._skip_pattern = skip_pattern
+        self._round_type = round_type
        if is_full_quantized:
            self._quantizable_op_type = utils._act_supported_quantizable_op_type
...
...
@@ -2375,7 +2427,8 @@ class AddQuantDequantPassV2(object):
            quant_bits=self._quant_bits,
            quant_axis=-1,
            channel_wise=False,
-            is_test=self._is_test)
+            is_test=self._is_test,
+            round_type=self._round_type)
        quant_var_node, scale_var_node = insert_quant_pass.insert_quant_op(
            graph, in_node)
        dequant_var_node = insert_quant_pass.insert_dequant_op(
...
...
@@ -2458,6 +2511,8 @@ class ReplaceFakeQuantDequantPass(object):
            "quant_axis") else -1
        bit_length = op.op().attr("bit_length") if op.op().has_attr(
            "bit_length") else 8
+        round_type = op.op().attr("round_type") if op.op().has_attr(
+            "round_type") else 0
        zero_point_node = None
        quanted_node = x_node
...
...
@@ -2479,7 +2534,8 @@ class ReplaceFakeQuantDequantPass(object):
        quant_op_node = graph.create_op_node(
            op_type="quantize_linear",
            attrs={
                "quant_axis": quant_axis,
-                "bit_length": bit_length
+                "bit_length": bit_length,
+                "round_type": round_type
            },
            inputs={
                "X": x_node,
...
...
@@ -2598,8 +2654,11 @@ class QuantWeightPass(object):
                param_v = self._load_var(x_node.name())
                quant_axis = _op.op().attr("quant_axis")
                bits_length = _op.op().attr("bit_length")
+                round_type = _op.op().attr("round_type") if _op.op().has_attr(
+                    "round_type") else 0
                quantized_param_v = utils.quant_tensor(param_v.copy(), scale_v,
-                                                       quant_axis, bits_length)
+                                                       quant_axis, bits_length,
+                                                       round_type)
                if self._bias_correction == True:
                    quantized_param_v = utils.bias_correction_w(param_v,
...
...
python/paddle/fluid/contrib/slim/quantization/utils.py
...
...
@@ -321,29 +321,39 @@ def set_variable_data(scope, place, var_name, np_value):
        tensor.set(np_value, place)


-def quant_tensor(x, scale, quant_axis=0, weight_bits=8):
-    # symmetry quant
-    def _clip(x, scale):
-        x[x > scale] = scale
-        x[x < -scale] = -scale
-        return x
+def round_c_single_element(val):
+    dtype = type(val)
+    if val >= 0:
+        return dtype(np.floor(val + 0.5))
+    return dtype(np.ceil(val - 0.5))
+
+
+# rounding to nearest ties away from zero
+round_c = np.vectorize(round_c_single_element)
+
+
+def quant_tensor(x, scale, quant_axis=0, weight_bits=8,
+                 round_type='TiesToEven'):
    assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
+    distribution = np.round if round_type == 'TiesToEven' else round_c
    bnt = (1 << (weight_bits - 1)) - 1
    if isinstance(scale, list):
        for i, s in enumerate(scale):
            if s == 0.0:
                s = 1e-8
            if quant_axis == 0:
-                x[i] = _clip(x[i], s)
-                x[i] = x[i] / s * bnt
+                x[i] = distribution(x[i] / s * bnt)
+                x[i] = np.clip(x[i], -bnt - 1, bnt)
            else:
-                x[:, i] = _clip(x[:, i], s)
-                x[:, i] = x[:, i] / s * bnt
+                x[:, i] = distribution(x[:, i] / s * bnt)
+                x[:, i] = np.clip(x[:, i], -bnt - 1, bnt)
    else:
        scale = 1e-8 if scale == 0.0 else scale
-        x = _clip(x, scale)
-        x = x / scale * bnt
+        x = distribution(x / scale * bnt)
+        x = np.clip(x, -bnt - 1, bnt)
    return x
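For reference, a short usage sketch of the new quant_tensor (per-layer scale branch), showing the round-then-clip order and the asymmetric integer range [-bnt - 1, bnt], i.e. [-128, 127] for 8 bits, produced by the final np.clip. The import path assumes this Paddle build exposes the module at its source location:

    import numpy as np
    from paddle.fluid.contrib.slim.quantization import utils

    w = np.array([[0.30, -1.20], [1.20, 0.45]], dtype=np.float32)
    scale = float(np.abs(w).max())  # 1.2, so +/-1.2 maps to +/-127

    q = utils.quant_tensor(w.copy(), scale, weight_bits=8,
                           round_type='TiesToEven')
    # 0.30 / 1.2 * 127 = 31.75 -> 32; 0.45 / 1.2 * 127 = 47.625 -> 48
    print(q)  # [[  32. -127.] [ 127.   48.]]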
...
...
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
...
...
@@ -558,7 +558,7 @@ if(LINUX AND WITH_MKLDNN)
                       120)
  set_tests_properties(test_quant2_int8_ernie_mkldnn PROPERTIES TIMEOUT 120)
  set_tests_properties(test_quant_int8_googlenet_mkldnn PROPERTIES TIMEOUT 120)
-  set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_quant2_int8_resnet50_mkldnn PROPERTIES TIMEOUT 200)
  set_tests_properties(test_quant2_int8_lstm_mkldnn PROPERTIES TIMEOUT 120)
endif()
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_ptq.py
...
...
@@ -338,7 +338,7 @@ class TestImperativePTQKL(TestImperativePTQ):
        self.batch_num = 10
        self.batch_size = 10
-        self.eval_acc_top1 = 1.0
+        self.eval_acc_top1 = 0.98

        conv2d_1_wt_thresholds = [
            0.18116560578346252, 0.17079241573810577, 0.1702047884464264,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
...
...
@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  model_path,
                                  data_path,
                                  algo="KL",
-                                  round_type="round",
+                                  weight_round_algo="round",
                                  quantizable_op_type=["conv2d"],
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
...
@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
            batch_nums=batch_nums,
            algo=algo,
            quantizable_op_type=quantizable_op_type,
-            round_type=round_type,
+            weight_round_algo=weight_round_algo,
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            onnx_format=onnx_format,
...
@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                 data_url,
                 data_md5,
                 algo,
-                 round_type,
+                 weight_round_algo,
                 quantizable_op_type,
                 is_full_quantize,
                 is_use_cache_file,
...
@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start post training quantization for {0} on {1} samples ..."
              .format(model_name, quant_iterations))
        self.generate_quantized_model(fp32_model_path, data_path, algo,
-                                      round_type, quantizable_op_type,
+                                      weight_round_algo, quantizable_op_type,
                                      is_full_quantize, is_use_cache_file,
                                      is_optimize_model, quant_iterations,
                                      onnx_format)
...
@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
        data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
        data_md5 = "add84c754e9b792fea1fbd728d134ab7"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["mul", "lstm"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
        infer_iterations = 100
        quant_iterations = 10
        self.run_test(model_name, model_url, model_md5, data_name, data_url,
-                      data_md5, algo, round_type, quantizable_op_type,
+                      data_md5, algo, weight_round_algo, quantizable_op_type,
                      is_full_quantize, is_use_cache_file, is_optimize_model,
                      diff_threshold, infer_iterations, quant_iterations)
...
@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
        data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
        data_md5 = "add84c754e9b792fea1fbd728d134ab7"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["mul", "lstm"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
                      data_url,
                      data_md5,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      quantizable_op_type,
                      is_full_quantize,
                      is_use_cache_file,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
...
...
@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
    def generate_quantized_model(self,
                                 model_path,
                                 algo="KL",
-                                 round_type="round",
+                                 weight_round_algo="round",
                                 quantizable_op_type=["conv2d"],
                                 is_full_quantize=False,
                                 is_use_cache_file=False,
...
@@ -116,7 +116,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                 batch_size=10,
                                 batch_nums=10,
                                 onnx_format=False,
-                                 skip_tensor_list=None):
+                                 skip_tensor_list=None,
+                                 bias_correction=False):
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
...
@@ -129,9 +130,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
            batch_nums=batch_nums,
            algo=algo,
            quantizable_op_type=quantizable_op_type,
-            round_type=round_type,
+            weight_round_algo=weight_round_algo,
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
+            bias_correction=bias_correction,
            onnx_format=onnx_format,
            skip_tensor_list=skip_tensor_list,
            is_use_cache_file=is_use_cache_file)
...
@@ -143,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                 data_url,
                 data_md5,
                 algo,
-                 round_type,
+                 weight_round_algo,
                 quantizable_op_type,
                 is_full_quantize,
                 is_use_cache_file,
...
@@ -152,6 +154,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                 batch_size=10,
                 infer_iterations=10,
                 quant_iterations=5,
+                 bias_correction=False,
                 onnx_format=False,
                 skip_tensor_list=None):
...
@@ -166,11 +169,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start INT8 post training quantization for {0} on {1} images ..."
              .format(model_name, quant_iterations * batch_size))
-        self.generate_quantized_model(origin_model_path, algo, round_type,
-                                      quantizable_op_type, is_full_quantize,
-                                      is_use_cache_file, is_optimize_model,
-                                      batch_size, quant_iterations, onnx_format,
-                                      skip_tensor_list)
+        self.generate_quantized_model(origin_model_path, algo,
+                                      weight_round_algo, quantizable_op_type,
+                                      is_full_quantize, is_use_cache_file,
+                                      is_optimize_model, batch_size,
+                                      quant_iterations, onnx_format,
+                                      skip_tensor_list, bias_correction)

        print("Start INT8 inference for {0} on {1} images ..."
              .format(model_name, infer_iterations * batch_size))
...
@@ -200,7 +204,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "KL"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -209,7 +213,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -222,7 +226,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "hist"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -231,7 +235,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -244,7 +248,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "mse"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -253,7 +257,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -266,7 +270,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "emd"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -275,7 +279,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -288,7 +292,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -297,7 +301,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -310,7 +314,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "abs_max"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "mul"]
        is_full_quantize = True
        is_use_cache_file = False
...
@@ -319,7 +323,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 10
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -332,7 +336,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "mse"
-        round_type = "adaround"
+        weight_round_algo = "adaround"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -341,10 +345,21 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
-                      quantizable_op_type, is_full_quantize, is_use_cache_file,
-                      is_optimize_model, diff_threshold, batch_size,
-                      infer_iterations, quant_iterations)
+        bias_correction = True
+        self.run_test(model_name,
+                      data_url,
+                      data_md5,
+                      algo,
+                      weight_round_algo,
+                      quantizable_op_type,
+                      is_full_quantize,
+                      is_use_cache_file,
+                      is_optimize_model,
+                      diff_threshold,
+                      batch_size,
+                      infer_iterations,
+                      quant_iterations,
+                      bias_correction=bias_correction)


class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
...
@@ -354,7 +369,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "KL"
-        round_type = "adaround"
+        weight_round_algo = "adaround"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -363,7 +378,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
        batch_size = 10
        infer_iterations = 50
        quant_iterations = 5
-        self.run_test(model_name, data_url, data_md5, algo, round_type,
+        self.run_test(model_name, data_url, data_md5, algo, weight_round_algo,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold, batch_size,
                      infer_iterations, quant_iterations)
...
@@ -376,7 +391,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "mse"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -390,7 +405,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
                      data_url,
                      data_md5,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      quantizable_op_type,
                      is_full_quantize,
                      is_use_cache_file,
...
@@ -410,7 +425,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "mse"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = True
        is_use_cache_file = False
...
@@ -424,7 +439,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
                      data_url,
                      data_md5,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      quantizable_op_type,
                      is_full_quantize,
                      is_use_cache_file,
...
@@ -443,7 +458,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
        is_full_quantize = False
        is_use_cache_file = False
...
@@ -457,7 +472,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
                      data_url,
                      data_md5,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      quantizable_op_type,
                      is_full_quantize,
                      is_use_cache_file,
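The bias_correction flag threaded through these tests enables the weight bias-correction step seen earlier in QuantWeightPass (utils.bias_correction_w). As a rough, mean-only sketch of the underlying idea only — not Paddle's implementation, and bias_correct_weights is a hypothetical name:

    import numpy as np

    def bias_correct_weights(w_fp32, w_quant, scale, weight_bits=8):
        # Hypothetical illustration: shift each output channel of the
        # dequantized weights so its mean matches the fp32 weights, then
        # re-quantize with the same round-then-clip rule as quant_tensor.
        bnt = (1 << (weight_bits - 1)) - 1
        w_dequant = w_quant.astype(np.float32) * scale / bnt
        axes = tuple(range(1, w_fp32.ndim))  # everything but the channel axis
        mean_shift = w_fp32.mean(axis=axes, keepdims=True) - \
            w_dequant.mean(axis=axes, keepdims=True)
        q = np.round((w_dequant + mean_shift) / scale * bnt)
        return np.clip(q, -bnt - 1, bnt)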
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
...
...
@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                 model_path,
                                 quantizable_op_type,
                                 algo="KL",
-                                 round_type="round",
+                                 weight_round_algo="round",
                                 is_full_quantize=False,
                                 is_use_cache_file=False,
                                 is_optimize_model=False,
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
            model_dir=model_path,
            algo=algo,
            quantizable_op_type=quantizable_op_type,
-            round_type=round_type,
+            weight_round_algo=weight_round_algo,
            is_full_quantize=is_full_quantize,
            optimize_model=is_optimize_model,
            onnx_format=onnx_format,
...
@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
    def run_test(self,
                 model,
                 algo,
-                 round_type,
+                 weight_round_algo,
                 data_urls,
                 data_md5s,
                 quantizable_op_type,
...
@@ -299,9 +299,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
        print("Start INT8 post training quantization for {0} on {1} images ..."
              .format(model, sample_iterations * batch_size))
        self.generate_quantized_model(model_cache_folder + "/model",
-                                      quantizable_op_type, algo, round_type,
-                                      is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, onnx_format)
+                                      quantizable_op_type, algo,
+                                      weight_round_algo, is_full_quantize,
+                                      is_use_cache_file, is_optimize_model,
+                                      onnx_format)

        print("Start INT8 inference for {0} on {1} images ..."
              .format(model, infer_iterations * batch_size))
...
@@ -329,7 +330,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_kl_mobilenetv1(self):
        model = "MobileNet-V1"
        algo = "KL"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
...
@@ -344,7 +345,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.025
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
+        self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold)
...
@@ -354,7 +355,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_avg_mobilenetv1(self):
        model = "MobileNet-V1"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
...
@@ -368,7 +369,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.025
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
+        self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold)
...
@@ -378,7 +379,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_hist_mobilenetv1(self):
        model = "MobileNet-V1"
        algo = "hist"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
...
@@ -392,7 +393,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
        is_use_cache_file = False
        is_optimize_model = True
        diff_threshold = 0.03
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
+        self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold)
...
@@ -402,7 +403,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_abs_max_mobilenetv1(self):
        model = "MobileNet-V1"
        algo = "abs_max"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
...
@@ -416,7 +417,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
        is_optimize_model = False
        # The accuracy diff of post-training quantization (abs_max) maybe bigger
        diff_threshold = 0.05
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
+        self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold)
...
@@ -426,7 +427,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
    def test_post_training_onnx_format_mobilenetv1(self):
        model = "MobileNet-V1"
        algo = "avg"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
...
@@ -443,7 +444,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
        diff_threshold = 0.05
        self.run_test(model,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      data_urls,
                      data_md5s,
                      quantizable_op_type,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
...
...
@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
    def test_post_training_resnet50(self):
        model = "ResNet-50"
        algo = "min_max"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
        ]
...
@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
        is_use_cache_file = False
        is_optimize_model = False
        diff_threshold = 0.025
-        self.run_test(model, algo, round_type, data_urls, data_md5s,
+        self.run_test(model, algo, weight_round_algo, data_urls, data_md5s,
                      quantizable_op_type, is_full_quantize, is_use_cache_file,
                      is_optimize_model, diff_threshold)
...
@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
    def test_post_training_resnet50(self):
        model = "ResNet-50"
        algo = "min_max"
-        round_type = "round"
+        weight_round_algo = "round"
        data_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
        ]
...
@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
        onnx_format = True
        self.run_test(model,
                      algo,
-                      round_type,
+                      weight_round_algo,
                      data_urls,
                      data_md5s,
                      quantizable_op_type,
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
...
...
@@ -21,8 +21,6 @@ import math
from op_test import OpTest


-# numpy.round has different behavior in comparision to c++ round function
-# so we use round_c instead of numpy.round to align the output data
def round_c_single_element(val):
    dtype = type(val)
    if val >= 0:
...
@@ -30,6 +28,7 @@ def round_c_single_element(val):
    return dtype(np.ceil(val - 0.5))


+# rounding to nearest ties away from zero
round_c = np.vectorize(round_c_single_element)
...
@@ -46,13 +45,25 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
        self.op_type = 'fake_quantize_abs_max'
        self.attrs = {'bit_length': 8}

-    def _fake_quantize_abs_max(self, dtype, input_shape, distribution):
+    def _fake_quantize_abs_max(self,
+                               dtype,
+                               input_shape,
+                               distribution,
+                               round_type='TiesToEven'):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        scale = np.max(np.abs(input_data))
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
-        output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt)
+        if round_type == 'TiesToEven':
+            round_out = np.round(
+                input_data.astype(compute_type) * inv_scale * bnt)
+            self.attrs['round_type'] = 0
+        else:
+            round_out = round_c(
+                input_data.astype(compute_type) * inv_scale * bnt)
+            self.attrs['round_type'] = 1
+        output_data = np.clip(round_out, -bnt - 1, bnt)
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
...
@@ -61,6 +72,11 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
    def test_fake_quantize_abs_max(self):
        self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)

+    def test_fake_quantize_abs_max_round1(self):
+        self._fake_quantize_abs_max(np.float32, (124, 240),
+                                    np.random.random,
+                                    round_type='TiesAwayFromZero')
+
    def test_fake_quantize_abs_max_float16(self):
        self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)
...
@@ -78,8 +94,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
        self.op_type = 'fake_channel_wise_quantize_abs_max'
        self.attrs = {'bit_length': 8}

-    def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape,
-                                            quant_axis, distribution):
+    def _fake_channel_wise_quantize_abs_max(self,
+                                            dtype,
+                                            input_shape,
+                                            quant_axis,
+                                            distribution,
+                                            round_type='TiesToEven'):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
...
@@ -87,8 +107,15 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
-        output_data = round_c(bnt * input_data.astype(compute_type) /
-                              scale_broadcast)
+        if round_type == 'TiesToEven':
+            round_out = np.round(
+                input_data.astype(compute_type) / scale_broadcast * bnt)
+            self.attrs['round_type'] = 0
+        else:
+            round_out = round_c(
+                input_data.astype(compute_type) / scale_broadcast * bnt)
+            self.attrs['round_type'] = 1
+        output_data = np.clip(round_out, -bnt - 1, bnt)
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis)
...
@@ -102,16 +129,20 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
    def test_fake_channel_wise_quantize_abs_max(self):
        dtype_options = [np.float32, np.float16]
        input_shape_quant_axis_options = [[(20, 15, 6, 6), 0],
-                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
-                                          [(30, 15), 1]]
-        for dtype, input_shape_quant_axis in itertools.product(
-                dtype_options, input_shape_quant_axis_options):
+                                          [(20, 15, 6, 6), 1], [(30, 30), 0],
+                                          [(30, 30), 1]]
+        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
+        for dtype, input_shape_quant_axis, round_type in itertools.product(
+                dtype_options, input_shape_quant_axis_options,
+                round_type_options):
            input_shape, quant_axis = input_shape_quant_axis
-            with self.subTest(dtype=dtype, input_shape=input_shape,
-                              quant_axis=quant_axis):
+            with self.subTest(dtype=dtype,
+                              input_shape=input_shape,
+                              quant_axis=quant_axis,
+                              round_type=round_type):
                self._fake_channel_wise_quantize_abs_max(
-                    dtype, input_shape, quant_axis, np.random.random)
+                    dtype, input_shape, quant_axis, np.random.random,
+                    round_type)


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
...
@@ -124,7 +155,8 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
                                     dtype,
                                     input_shape,
                                     distribution,
-                                     is_test=False):
+                                     is_test=False,
+                                     round_type='TiesToEven'):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
...
@@ -133,11 +165,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
        out_scale[0] = np.max(np.abs(input_data))
        if is_test:
            out_scale[0] = in_scale[0] = out_scale[0] - 1.0
-            clip_data = np.clip(input_data, -in_scale, in_scale)
+        if round_type == 'TiesToEven':
+            round_out = np.round(
+                input_data.astype(compute_type) / out_scale[0] * bnt)
+            self.attrs['round_type'] = 0
        else:
-            clip_data = input_data
-        output_data = round_c(
-            clip_data.astype(compute_type) / out_scale[0] * bnt)
+            round_out = round_c(
+                input_data.astype(compute_type) / out_scale[0] * bnt)
+            self.attrs['round_type'] = 1
+        output_data = np.clip(round_out, -bnt - 1, bnt)
        self.inputs = {
            'X': input_data,
            'Iter': np.zeros(1).astype(np.int64),
...
@@ -153,15 +189,20 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
        self.check_output()

    def test_fake_quantize_range_abs_max(self):
-        dtype_options = [np.float32, np.float16]
+        dtype_options = [np.float16, np.float32]
        is_test_options = [False, True]
-        for dtype, is_test in itertools.product(dtype_options, is_test_options):
+        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
+        for dtype, is_test, round_type in itertools.product(
+                dtype_options, is_test_options, round_type_options):
            self.attrs['bit_length'] = 8 if is_test else 5
-            with self.subTest(dtype=dtype, is_test=is_test):
+            with self.subTest(dtype=dtype, is_test=is_test,
+                              round_type=round_type):
                self._fake_quantize_range_abs_max(
-                    dtype, (8, 16, 7, 7),
-                    lambda shape: (np.random.random(shape) - 0.5) * 10,
-                    is_test=is_test)
+                    dtype, (8, 16, 6, 6),
+                    lambda shape: (np.random.random(shape) - 0.4) * 10,
+                    is_test=is_test,
+                    round_type=round_type)


class TestMovingAverageAbsMaxScaleOp(OpTest):
...
@@ -208,7 +249,8 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
                                              input_shape,
                                              distribution,
                                              dequantize=False,
-                                              with_gradient=False):
+                                              with_gradient=False,
+                                              round_type='TiesToEven'):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
...
@@ -222,12 +264,20 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
            np.abs(input_data))
        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
-        round_data = round_c(input_data.astype(compute_type) / out_scale * bnt)
+        if round_type == 'TiesToEven':
+            round_out = np.round(
+                input_data.astype(compute_type) / out_scale * bnt)
+            self.attrs['round_type'] = 0
+        else:
+            round_out = round_c(
+                input_data.astype(compute_type) / out_scale * bnt)
+            self.attrs['round_type'] = 1
+        quant_data = np.clip(round_out, -bnt - 1, bnt)
        if dequantize:
-            output_data = (round_data * out_scale / bnt).astype(dtype)
+            output_data = (quant_data * out_scale / bnt).astype(dtype)
            self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
        else:
-            output_data = round_data.astype(dtype)
+            output_data = quant_data.astype(dtype)
        self.inputs = {
            'X': input_data,
            'InScale': in_scale,
...
@@ -256,6 +306,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
        self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7),
                                                   np.random.random)

+    def test_fake_quantize_moving_average_abs_max_round1(self):
+        self._fake_quantize_moving_average_abs_max(
+            np.float32, (8, 16, 7, 7),
+            np.random.random,
+            round_type='TiesAwayFromZero')
+
    def test_fake_quantize_dequantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
                                                   np.random.random,
...
@@ -269,12 +325,21 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
        self.op_type = 'fake_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

-    def _fake_quantize_dequantize_abs_max(self, dtype, input_shape,
-                                          distribution):
+    def _fake_quantize_dequantize_abs_max(self,
+                                          dtype,
+                                          input_shape,
+                                          distribution,
+                                          round_type='TiesToEven'):
        input_data = distribution(input_shape).astype(dtype)
        scale = np.max(np.abs(input_data)).astype(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
-        output_data = round_c(input_data / scale * bnt) * scale / bnt
+        if round_type == 'TiesToEven':
+            round_out = np.round(input_data / scale * bnt)
+            self.attrs['round_type'] = 0
+        else:
+            round_out = round_c(input_data / scale * bnt)
+            self.attrs['round_type'] = 1
+        output_data = np.clip(round_out, -bnt - 1, bnt) * scale / bnt
        self.inputs = {'X': input_data}
        self.outputs = {
            'Out': output_data,
...
@@ -289,6 +354,11 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
        self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
                                               np.random.random)

+    def test_fake_quantize_dequantize_abs_max_round1(self):
+        self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
+                                               np.random.random,
+                                               round_type='TiesAwayFromZero')
+

class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
...
@@ -296,9 +366,12 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
        self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

-    def _fake_channel_wise_quantize_dequantize_abs_max(self, dtype, input_shape,
-                                                       quant_axis, distribution):
+    def _fake_channel_wise_quantize_dequantize_abs_max(self,
+                                                       dtype,
+                                                       input_shape,
+                                                       quant_axis,
+                                                       distribution,
+                                                       round_type='TiesToEven'):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
...
@@ -307,8 +380,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True)
-        output_data = round_c(
-            bnt * output_data / scale_broadcast) * scale_broadcast / bnt
+        if round_type == 'TiesToEven':
+            round_out = np.round(bnt * output_data / scale_broadcast)
+            self.attrs['round_type'] = 0
+        else:
+            round_out = round_c(bnt * output_data / scale_broadcast)
+            self.attrs['round_type'] = 1
+        output_data = np.clip(round_out, -bnt - 1, bnt) * scale_broadcast / bnt
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast, (1, ) + compute_axis)
...
@@ -325,10 +403,19 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
        input_shape_quant_axis_options = [[(3, 4, 64, 64), 0],
                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
                                          [(30, 15), 1]]
-        for input_shape, quant_axis in input_shape_quant_axis_options:
-            with self.subTest(input_shape=input_shape, quant_axis=quant_axis):
+        round_type_options = ['TiesToEven', 'TiesAwayFromZero']
+        for input_shape_quant_axis, round_type in itertools.product(
+                input_shape_quant_axis_options, round_type_options):
+            input_shape, quant_axis = input_shape_quant_axis
+            with self.subTest(input_shape=input_shape,
+                              quant_axis=quant_axis,
+                              round_type=round_type):
                self._fake_channel_wise_quantize_dequantize_abs_max(
-                    np.float32, input_shape, quant_axis, np.random.random)
+                    np.float32, input_shape, quant_axis, np.random.random,
+                    round_type=round_type)


def quantize_max_abs(x, max_range):
...
...
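Taken together, the reference path these tests now check the kernels against is round-then-clip followed by rescaling. A self-contained sketch of that round trip, with the classic error bound of half a quantization step (the function name and error check are illustrative, not part of the test suite):

    import numpy as np

    def ref_fake_quant_dequant(x, bit_length=8, round_type='TiesToEven'):
        bnt = (1 << (bit_length - 1)) - 1
        scale = np.max(np.abs(x))
        if round_type == 'TiesToEven':
            round_out = np.round(x / scale * bnt)
        else:  # TiesAwayFromZero
            round_out = np.sign(x) * np.floor(np.abs(x) / scale * bnt + 0.5)
        return np.clip(round_out, -bnt - 1, bnt) * scale / bnt

    x = (np.random.random((4, 4)) - 0.5).astype(np.float32)
    y = ref_fake_quant_dequant(x)
    # one quantization step is scale / bnt, so the error is at most half a step
    assert np.abs(y - x).max() <= np.abs(x).max() / (2 * 127) + 1e-6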