Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
491b87b4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
491b87b4
编写于
6月 24, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix quantization clip and round Attribute (#43764)
上级
2739bd73
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
997 addition
and
726 deletion
+997
-726
paddle/fluid/operators/fake_quantize_op.cc
paddle/fluid/operators/fake_quantize_op.cc
+358
-192
paddle/fluid/operators/fake_quantize_op.cu.h
paddle/fluid/operators/fake_quantize_op.cu.h
+268
-168
paddle/fluid/operators/fake_quantize_op.h
paddle/fluid/operators/fake_quantize_op.h
+181
-136
paddle/fluid/operators/quantize_linear_op.cc
paddle/fluid/operators/quantize_linear_op.cc
+39
-26
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
...d/contrib/slim/quantization/post_training_quantization.py
+24
-28
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+18
-72
python/paddle/fluid/contrib/slim/quantization/utils.py
python/paddle/fluid/contrib/slim/quantization/utils.py
+24
-22
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
.../slim/tests/test_post_training_quantization_lstm_model.py
+8
-8
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
...ntrib/slim/tests/test_post_training_quantization_mnist.py
+29
-30
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
...slim/tests/test_post_training_quantization_mobilenetv1.py
+16
-17
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
...ib/slim/tests/test_post_training_quantization_resnet50.py
+4
-4
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+28
-23
未找到文件。
paddle/fluid/operators/fake_quantize_op.cc
浏览文件 @
491b87b4
...
@@ -33,8 +33,10 @@ struct Compare {
...
@@ -33,8 +33,10 @@ struct Compare {
template
<
typename
T
>
template
<
typename
T
>
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
T
*
in
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
int
num
,
T
*
out
)
{
const
T
*
in
,
const
int
num
,
T
*
out
)
{
*
out
=
std
::
abs
(
*
(
std
::
max_element
(
in
+
0
,
in
+
num
,
Compare
<
T
>
())));
*
out
=
std
::
abs
(
*
(
std
::
max_element
(
in
+
0
,
in
+
num
,
Compare
<
T
>
())));
}
}
};
};
...
@@ -43,24 +45,26 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
...
@@ -43,24 +45,26 @@ template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
FindChannelAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_tensor
,
const
int
quant_axis
,
const
framework
::
Tensor
&
in_tensor
,
T
*
out_abs_max
)
{
const
int
quant_axis
,
T
*
out_abs_max
)
{
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
auto
*
in_data
=
in_tensor
.
data
<
T
>
();
auto
*
in_data
=
in_tensor
.
data
<
T
>
();
auto
in_dims
=
in_tensor
.
dims
();
auto
in_dims
=
in_tensor
.
dims
();
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
int64_t
channel
=
in_dims
[
quant_axis
];
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
const
int64_t
channel_size
=
in_tensor
.
numel
()
/
channel
;
const
int64_t
channel_size
=
in_tensor
.
numel
()
/
channel
;
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
out_abs_max
[
i
]
=
out_abs_max
[
i
]
=
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
}
}
...
@@ -72,8 +76,8 @@ struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
...
@@ -72,8 +76,8 @@ struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
const
int64_t
step_j
=
in_tensor
.
numel
()
/
(
in_dims
[
0
]
*
in_dims
[
1
]);
const
int64_t
step_j
=
in_tensor
.
numel
()
/
(
in_dims
[
0
]
*
in_dims
[
1
]);
for
(
int64_t
i
=
0
;
i
<
in_dims
[
0
];
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
in_dims
[
0
];
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
for
(
int64_t
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
T
abs_max
=
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
T
abs_max
=
std
::
abs
(
*
(
std
::
max_element
(
start
,
end
,
Compare
<
T
>
())));
out_abs_max
[
j
]
=
std
::
max
(
out_abs_max
[
j
],
abs_max
);
out_abs_max
[
j
]
=
std
::
max
(
out_abs_max
[
j
],
abs_max
);
}
}
...
@@ -86,16 +90,30 @@ template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
...
@@ -86,16 +90,30 @@ template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
ClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
framework
::
Tensor
*
out
)
{
T
s
=
scale
.
data
<
T
>
()[
0
];
T
s
=
scale
.
data
<
T
>
()[
0
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
if
(
round_type
==
0
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
trans
(
ctx
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
inv_s
));
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
}
else
{
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
*
inv_s
*
out_e
).
round
();
}
}
}
};
};
...
@@ -103,19 +121,34 @@ template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
...
@@ -103,19 +121,34 @@ template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
framework
::
Tensor
*
out
)
{
T
s
=
scale
.
data
<
T
>
()[
0
];
T
s
=
scale
.
data
<
T
>
()[
0
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
if
(
round_type
==
0
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
trans
(
ctx
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
inv_s
));
in
.
data
<
T
>
(),
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
in
.
data
<
T
>
()
+
in
.
numel
(),
out_e
.
device
(
*
ctx
.
eigen_device
())
=
out_e
*
s
/
static_cast
<
T
>
(
bin_cnt
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
out_e
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
else
{
trans
(
ctx
,
in
.
data
<
T
>
(),
in
.
data
<
T
>
()
+
in
.
numel
(),
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
*
inv_s
*
out_e
).
round
()
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
}
}
};
};
template
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
template
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
...
@@ -123,20 +156,24 @@ template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
...
@@ -123,20 +156,24 @@ template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
template
<
typename
T
>
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
in_dims
=
in
.
dims
();
auto
in_dims
=
in
.
dims
();
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
int64_t
channel
=
in_dims
[
quant_axis
];
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
...
@@ -144,12 +181,31 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
...
@@ -144,12 +181,31 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
const
int64_t
channel_size
=
in
.
numel
()
/
channel
;
const
int64_t
channel_size
=
in
.
numel
()
/
channel
;
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
T
s
=
scale_data
[
i
];
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
trans
(
if
(
round_type
==
0
)
{
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
trans
(
ctx
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
inv_s
));
start
,
end
,
out_data
+
i
*
channel_size
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
}
else
{
trans
(
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
}
}
if
(
round_type
==
1
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
T
inv_s
=
inverse
(
s
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
*
inv_s
*
out_e
).
round
();
}
}
}
}
else
if
(
quant_axis
==
1
)
{
}
else
if
(
quant_axis
==
1
)
{
const
int64_t
step_i
=
in
.
numel
()
/
in_dims
[
0
];
const
int64_t
step_i
=
in
.
numel
()
/
in_dims
[
0
];
...
@@ -158,12 +214,21 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
...
@@ -158,12 +214,21 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
for
(
int
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
for
(
int
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
T
s
=
scale_data
[
j
];
T
s
=
scale_data
[
j
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
auto
*
cur_out_data
=
out_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out_data
=
out_data
+
i
*
step_i
+
j
*
step_j
;
trans
(
ctx
,
start
,
end
,
cur_out_data
,
if
(
round_type
==
0
)
{
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
trans
(
ctx
,
inv_s
));
start
,
end
,
cur_out_data
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
}
else
{
trans
(
ctx
,
start
,
end
,
cur_out_data
,
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
for
(
int
k
=
0
;
k
<
step_j
;
k
++
)
{
cur_out_data
[
k
]
=
std
::
round
(
bin_cnt
*
inv_s
*
cur_out_data
[
k
]);
}
}
}
}
}
}
}
}
...
@@ -174,19 +239,23 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
...
@@ -174,19 +239,23 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float
>;
float
>;
template
<
typename
T
>
template
<
typename
T
>
struct
ChannelClipFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
ChannelClipFakeQuantDequantFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
scale_data
=
scale
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
in_data
=
in
.
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
in_dims
=
in
.
dims
();
auto
in_dims
=
in
.
dims
();
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
int64_t
channel
=
in_dims
[
quant_axis
];
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
platform
::
Transform
<
platform
::
CPUDeviceContext
>
trans
;
...
@@ -194,15 +263,35 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
...
@@ -194,15 +263,35 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
const
int64_t
channel_size
=
in
.
numel
()
/
channel
;
const
int64_t
channel_size
=
in
.
numel
()
/
channel
;
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
T
s
=
scale_data
[
i
];
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
start
=
in_data
+
i
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
auto
*
end
=
in_data
+
(
i
+
1
)
*
channel_size
;
T
inv_s
=
inverse
(
s
);
if
(
round_type
==
0
)
{
trans
(
T
inv_s
=
inverse
(
s
);
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
trans
(
ctx
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
inv_s
));
start
,
end
,
out_data
+
i
*
channel_size
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
}
else
{
trans
(
ctx
,
start
,
end
,
out_data
+
i
*
channel_size
,
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
}
}
for
(
int
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_data
[
i
];
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
out_e
*
s
/
static_cast
<
T
>
(
bin_cnt
);
if
(
round_type
==
0
)
{
out_e
.
device
(
*
ctx
.
eigen_device
())
=
out_e
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
else
{
T
inv_s
=
inverse
(
s
);
out_e
.
device
(
*
ctx
.
eigen_device
())
=
(
bin_cnt
*
inv_s
*
out_e
).
round
()
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
}
}
}
else
if
(
quant_axis
==
1
)
{
}
else
if
(
quant_axis
==
1
)
{
const
int64_t
step_i
=
in
.
numel
()
/
in_dims
[
0
];
const
int64_t
step_i
=
in
.
numel
()
/
in_dims
[
0
];
...
@@ -211,14 +300,25 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
...
@@ -211,14 +300,25 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
for
(
int
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
for
(
int
j
=
0
;
j
<
in_dims
[
1
];
j
++
)
{
T
s
=
scale_data
[
j
];
T
s
=
scale_data
[
j
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
start
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
auto
*
end
=
in_data
+
i
*
step_i
+
(
j
+
1
)
*
step_j
;
auto
*
cur_out_data
=
out_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out_data
=
out_data
+
i
*
step_i
+
j
*
step_j
;
trans
(
ctx
,
start
,
end
,
cur_out_data
,
if
(
round_type
==
0
)
{
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
round_type
,
trans
(
ctx
,
inv_s
));
start
,
end
,
cur_out_data
,
QuantTensorFunctor
<
T
>
(
static_cast
<
T
>
(
bin_cnt
),
inv_s
));
}
else
{
trans
(
ctx
,
start
,
end
,
cur_out_data
,
phi
::
ClipFunctor
<
T
>
(
-
s
,
s
));
}
for
(
int
k
=
0
;
k
<
step_j
;
k
++
)
{
for
(
int
k
=
0
;
k
<
step_j
;
k
++
)
{
cur_out_data
[
k
]
=
cur_out_data
[
k
]
*
s
/
static_cast
<
T
>
(
bin_cnt
);
if
(
round_type
==
0
)
{
cur_out_data
[
k
]
=
cur_out_data
[
k
]
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
else
{
cur_out_data
[
k
]
=
std
::
round
(
bin_cnt
*
inv_s
*
cur_out_data
[
k
])
*
s
/
static_cast
<
T
>
(
bin_cnt
);
}
}
}
}
}
}
}
...
@@ -230,12 +330,14 @@ template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
...
@@ -230,12 +330,14 @@ template struct ChannelClipFakeQuantDequantFunctor<platform::CPUDeviceContext,
float
>;
float
>;
template
<
typename
T
>
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
FindRangeAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
const
framework
::
Tensor
&
iter
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
const
int
window_size
,
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int64_t
it
=
iter
.
data
<
int64_t
>
()[
0
];
int64_t
it
=
iter
.
data
<
int64_t
>
()[
0
];
int
idx
=
it
%
window_size
;
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
removed
=
scale_arr
[
idx
];
...
@@ -247,8 +349,8 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
...
@@ -247,8 +349,8 @@ struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
max
=
cur
;
max
=
cur
;
}
else
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
}
else
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
int
size
=
(
it
>
window_size
)
?
window_size
:
it
;
int
size
=
(
it
>
window_size
)
?
window_size
:
it
;
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
ctx
,
scale_arr
,
size
,
FindAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
()(
&
max
);
ctx
,
scale_arr
,
size
,
&
max
);
}
}
out_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
max
;
out_scale
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())[
0
]
=
max
;
}
}
...
@@ -258,11 +360,14 @@ template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
...
@@ -258,11 +360,14 @@ template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_accum
,
const
framework
::
Tensor
&
in_accum
,
const
framework
::
Tensor
&
in_state
,
const
T
*
cur_scale
,
const
framework
::
Tensor
&
in_state
,
const
float
rate
,
framework
::
Tensor
*
out_state
,
const
T
*
cur_scale
,
framework
::
Tensor
*
out_accum
,
framework
::
Tensor
*
out_scale
)
{
const
float
rate
,
framework
::
Tensor
*
out_state
,
framework
::
Tensor
*
out_accum
,
framework
::
Tensor
*
out_scale
)
{
T
accum
=
in_accum
.
data
<
T
>
()[
0
];
T
accum
=
in_accum
.
data
<
T
>
()[
0
];
T
state
=
in_state
.
data
<
T
>
()[
0
];
T
state
=
in_state
.
data
<
T
>
()[
0
];
T
scale
=
cur_scale
[
0
];
T
scale
=
cur_scale
[
0
];
...
@@ -282,18 +387,22 @@ template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
...
@@ -282,18 +387,22 @@ template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
class
FakeQuantOrWithDequantAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
class
FakeQuantOrWithDequantAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
FakeQuantOrWithDequantAbsMaxOp
(
const
std
::
string
&
type
,
FakeQuantOrWithDequantAbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
OP_INOUT_CHECK
(
"FakeQuantOrWithDequantAbsMaxOp"
);
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeQuantOrWithDequantAbsMaxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FakeQuantOrWithDequantAbsMaxOp"
);
"FakeQuantOrWithDequantAbsMaxOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"FakeQuantOrWithDequantAbsMaxOp"
);
"FakeQuantOrWithDequantAbsMaxOp"
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
ctx
->
SetOutputDim
(
"OutScale"
,
{
1
});
...
@@ -302,7 +411,7 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -302,7 +411,7 @@ class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
ctx
.
device_context
());
...
@@ -320,8 +429,9 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
...
@@ -320,8 +429,9 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
AddOutput
(
"OutScale"
,
"(Tensor) Current scale"
);
AddOutput
(
"OutScale"
,
"(Tensor) Current scale"
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -329,18 +439,22 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
...
@@ -329,18 +439,22 @@ class FakeQuantOrWithDequantAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
AddAttr
<
int
>
(
"round_type"
,
"round_type"
,
"(int, default
0
) The round type of fp32 to int."
"(int, default
1
) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddComment
(
R"DOC(
AddComment
(
R"DOC(
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
...
@@ -363,12 +477,16 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -363,12 +477,16 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
OP_INOUT_CHECK
(
"FakeChannelWiseQuantizeAbsMax"
);
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeChannelWiseQuantizeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FakeChannelWiseQuantizeAbsMax"
);
"FakeChannelWiseQuantizeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"FakeChannelWiseQuantizeAbsMax"
);
"FakeChannelWiseQuantizeAbsMax"
);
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
...
@@ -378,7 +496,7 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -378,7 +496,7 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
...
@@ -398,8 +516,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
...
@@ -398,8 +516,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
"For conv2d, depthwise_conv2d, conv2d_transpose "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis."
)
"and mul, the quant_axis is equal to the cout axis."
)
.
SetDefault
(
0
)
.
SetDefault
(
0
)
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -407,8 +526,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
...
@@ -407,8 +526,9 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -416,18 +536,22 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
...
@@ -416,18 +536,22 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
AddAttr
<
int
>
(
"round_type"
,
"round_type"
,
"(int, default
0
) The round type of fp32 to int."
"(int, default
1
) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddAttr
<
bool
>
(
"is_test"
,
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
...
@@ -450,12 +574,18 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
...
@@ -450,12 +574,18 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
"FakeChannelWiseQuantizeDequantizeAbsMax"
);
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
...
@@ -465,7 +595,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
...
@@ -465,7 +595,7 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOp
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
...
@@ -485,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
...
@@ -485,8 +615,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
"For conv2d, depthwise_conv2d, conv2d_transpose "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis."
)
"and mul, the quant_axis is equal to the cout axis."
)
.
SetDefault
(
0
)
.
SetDefault
(
0
)
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -494,8 +625,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
...
@@ -494,8 +625,9 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -503,18 +635,22 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
...
@@ -503,18 +635,22 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
AddAttr
<
int
>
(
"round_type"
,
"round_type"
,
"(int, default
0
) The round type of fp32 to int."
"(int, default
1
) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddComment
(
R"DOC(
AddComment
(
R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.
In detail, each channel of the input X has a scale value.
...
@@ -530,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
...
@@ -530,17 +666,19 @@ $$0 \leq c \lt \ the\ channel\ number\ of\ X$$
class
FakeQuantizeRangeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
class
FakeQuantizeRangeAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
FakeQuantizeRangeAbsMaxOp
(
const
std
::
string
&
type
,
FakeQuantizeRangeAbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeQuantizeRangeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeQuantizeRangeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
OP_INOUT_CHECK
(
"FakeQuantizeRangeAbsMax"
);
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FakeQuantizeRangeAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"FakeQuantizeRangeAbsMax"
);
"FakeQuantizeRangeAbsMax"
);
if
(
ctx
->
HasOutput
(
"OutScales"
))
{
if
(
ctx
->
HasOutput
(
"OutScales"
))
{
int
window_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"window_size"
);
int
window_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"window_size"
);
...
@@ -553,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
...
@@ -553,7 +691,7 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
ctx
.
device_context
());
...
@@ -574,8 +712,9 @@ class FakeQuantizeRangeAbsMaxOpMaker
...
@@ -574,8 +712,9 @@ class FakeQuantizeRangeAbsMaxOpMaker
.
SetDefault
(
10000
);
.
SetDefault
(
10000
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -583,18 +722,22 @@ class FakeQuantizeRangeAbsMaxOpMaker
...
@@ -583,18 +722,22 @@ class FakeQuantizeRangeAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
AddAttr
<
int
>
(
"round_type"
,
"round_type"
,
"(int, default
0
) The round type of fp32 to int."
"(int, default
1
) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddAttr
<
bool
>
(
"is_test"
,
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
...
@@ -614,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
...
@@ -614,17 +757,24 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
:
public
framework
::
OperatorWithKernel
{
:
public
framework
::
OperatorWithKernel
{
public:
public:
FakeQuantOrWithDequantMovingAverageAbsMaxOp
(
FakeQuantOrWithDequantMovingAverageAbsMaxOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
AttributeMap
&
attrs
)
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
:
OperatorWithKernel
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
"FakeQuantOrWithDequantMovingAverageAbsMax"
);
if
(
ctx
->
HasOutput
(
"OutState"
))
{
if
(
ctx
->
HasOutput
(
"OutState"
))
{
ctx
->
SetOutputDim
(
"OutState"
,
{
1
});
ctx
->
SetOutputDim
(
"OutState"
,
{
1
});
...
@@ -639,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
...
@@ -639,7 +789,7 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
ctx
.
device_context
());
...
@@ -662,8 +812,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
...
@@ -662,8 +812,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
.
SetDefault
(
0.9
);
.
SetDefault
(
0.9
);
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8), quantization bit number."
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -671,18 +822,22 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
...
@@ -671,18 +822,22 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
});
});
AddAttr
<
int
>
(
AddAttr
<
int
>
(
"round_type"
,
"round_type"
,
"(int, default
0
) The round type of fp32 to int."
"(int, default
1
) The round type of fp32 to int."
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2"
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
1
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddAttr
<
bool
>
(
"is_test"
,
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
...
@@ -709,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
...
@@ -709,10 +864,12 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
OP_INOUT_CHECK
(
"MovingAverageAbsMaxScale"
);
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"MovingAverageAbsMaxScale"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"OutScale"
),
"Output"
,
"OutScale"
,
"MovingAverageAbsMaxScale"
);
"MovingAverageAbsMaxScale"
);
if
(
ctx
->
HasOutput
(
"OutState"
))
{
if
(
ctx
->
HasOutput
(
"OutState"
))
{
...
@@ -730,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
...
@@ -730,7 +887,7 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
...
@@ -770,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
...
@@ -770,19 +927,23 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel {
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
auto
out_grad_name
=
framework
::
GradVarName
(
"Out"
);
auto
out_grad_name
=
framework
::
GradVarName
(
"Out"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
out_grad_name
),
"Input"
,
out_grad_name
,
OP_INOUT_CHECK
(
ctx
->
HasInput
(
out_grad_name
),
"Input"
,
out_grad_name
,
"StrightThroughEstimatorGradOp"
);
"StrightThroughEstimatorGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
x_grad_name
),
"Output"
,
x_grad_name
,
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
x_grad_name
),
"Output"
,
x_grad_name
,
"StrightThroughEstimatorGradOp"
);
"StrightThroughEstimatorGradOp"
);
ctx
->
SetOutputDim
(
x_grad_name
,
ctx
->
GetInputDim
(
out_grad_name
));
ctx
->
SetOutputDim
(
x_grad_name
,
ctx
->
GetInputDim
(
out_grad_name
));
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
));
ctx
,
framework
::
GradVarName
(
"Out"
));
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
...
@@ -810,7 +971,8 @@ namespace ops = paddle::operators;
...
@@ -810,7 +971,8 @@ namespace ops = paddle::operators;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
fake_quantize_abs_max
,
ops
::
FakeQuantOrWithDequantAbsMaxOp
,
fake_quantize_abs_max
,
ops
::
FakeQuantOrWithDequantAbsMaxOp
,
ops
::
FakeQuantOrWithDequantAbsMaxOpMaker
,
ops
::
FakeQuantOrWithDequantAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
...
@@ -818,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
...
@@ -818,7 +980,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
ops
::
FakeQuantizeAbsMaxKernel
<
CPU
,
float
>
);
ops
::
FakeQuantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
fake_quantize_dequantize_abs_max
,
ops
::
FakeQuantOrWithDequantAbsMaxOp
,
fake_quantize_dequantize_abs_max
,
ops
::
FakeQuantOrWithDequantAbsMaxOp
,
ops
::
FakeQuantOrWithDequantAbsMaxOpMaker
,
ops
::
FakeQuantOrWithDequantAbsMaxOpMaker
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
StrightThroughEstimatorMaker
<
paddle
::
imperative
::
OpBase
>
);
...
@@ -826,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
...
@@ -826,7 +989,8 @@ REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CPU
,
float
>
);
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxOp
,
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxOp
,
ops
::
FakeQuantizeRangeAbsMaxOpMaker
,
ops
::
FakeQuantizeRangeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
...
@@ -853,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL(
...
@@ -853,7 +1017,8 @@ REGISTER_OP_CPU_KERNEL(
ops
::
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
<
CPU
,
float
>
);
ops
::
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxOp
,
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxOp
,
ops
::
FakeChannelWiseQuantizeAbsMaxOpMaker
,
ops
::
FakeChannelWiseQuantizeAbsMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
...
@@ -861,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
...
@@ -861,7 +1026,8 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CPU
,
float
>
);
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleOp
,
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleOp
,
ops
::
MovingAverageAbsMaxScaleOpMaker
,
ops
::
MovingAverageAbsMaxScaleOpMaker
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
StrightThroughEstimatorMaker
<
paddle
::
imperative
::
OpBase
>
);
ops
::
StrightThroughEstimatorMaker
<
paddle
::
imperative
::
OpBase
>
);
...
...
paddle/fluid/operators/fake_quantize_op.cu.h
浏览文件 @
491b87b4
...
@@ -36,12 +36,12 @@ struct QuantizeDataType<paddle::platform::float16> {
...
@@ -36,12 +36,12 @@ struct QuantizeDataType<paddle::platform::float16> {
};
};
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
T
*
out
)
{
__global__
void
FindAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
char
*
shared_max_data_tmp
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
if
(
gridDim
.
x
>
1
)
{
if
(
gridDim
.
x
>
1
)
{
T
local_max_data
=
T
(
0
);
T
local_max_data
=
T
(
0
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
@@ -73,14 +73,16 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
...
@@ -73,14 +73,16 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
template
<
typename
T
>
template
<
typename
T
>
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
T
*
in
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
int
num
,
T
*
out
)
{
const
T
*
in
,
const
int
num
,
T
*
out
)
{
int
block
=
1024
;
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
framework
::
Tensor
max
;
framework
::
Tensor
max
;
T
*
max_data
=
max
.
mutable_data
<
T
>
(
phi
::
make_ddim
({
grid
}),
ctx
.
GetPlace
());
T
*
max_data
=
max
.
mutable_data
<
T
>
(
phi
::
make_ddim
({
grid
}),
ctx
.
GetPlace
());
FindAbsMaxKernel
<
T
>
FindAbsMaxKernel
<
T
>
<<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in
,
num
,
max_data
);
<<<
grid
,
block
,
1024
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in
,
num
,
max_data
);
FindAbsMaxKernel
<
T
>
FindAbsMaxKernel
<
T
>
...
@@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
...
@@ -93,13 +95,15 @@ template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
paddle
::
platform
::
float16
>;
paddle
::
platform
::
float16
>;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindChannelAbsMaxKernelQuantAxis0
(
const
T
*
in
,
const
int
n
,
__global__
void
FindChannelAbsMaxKernelQuantAxis0
(
const
T
*
in
,
const
int
c
,
T
*
out
)
{
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
extern
__shared__
char
*
shared_max_data_tmp
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
T
local_max_data
=
T
(
0
);
T
local_max_data
=
T
(
0
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
tmp
=
static_cast
<
T
>
(
T
tmp
=
static_cast
<
T
>
(
...
@@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
...
@@ -122,17 +126,16 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindChannelAbsMaxKernelQuantAxis1
(
const
T
*
in
,
const
int
n
,
__global__
void
FindChannelAbsMaxKernelQuantAxis1
(
const
int
cin
,
const
int
cout
,
const
T
*
in
,
const
int
n
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
T
*
out
)
{
extern
__shared__
char
*
shared_max_data_tmp
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
int
cout_wh_size
=
n
/
cin
;
int
cout_wh_size
=
n
/
cin
;
int
wh_size
=
n
/
(
cin
*
cout
);
int
wh_size
=
n
/
(
cin
*
cout
);
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
bid
=
blockIdx
.
x
;
int
bid
=
blockIdx
.
x
;
const
T
*
in_current
=
in
+
tid
*
cout_wh_size
+
bid
*
wh_size
;
const
T
*
in_current
=
in
+
tid
*
cout_wh_size
+
bid
*
wh_size
;
T
local_max_data
=
T
(
0
);
T
local_max_data
=
T
(
0
);
for
(
int
i
=
0
;
i
<
wh_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
wh_size
;
i
++
)
{
T
tmp
=
static_cast
<
T
>
(
T
tmp
=
static_cast
<
T
>
(
...
@@ -162,24 +165,26 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
...
@@ -162,24 +165,26 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
template
<
typename
T
>
template
<
typename
T
>
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_tensor
,
const
int
quant_axis
,
const
framework
::
Tensor
&
in_tensor
,
T
*
out_abs_max
)
{
const
int
quant_axis
,
T
*
out_abs_max
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
const
int
num
=
in_tensor
.
numel
();
const
int
num
=
in_tensor
.
numel
();
auto
in_dims
=
in_tensor
.
dims
();
auto
in_dims
=
in_tensor
.
dims
();
const
T
*
in_data
=
in_tensor
.
data
<
T
>
();
const
T
*
in_data
=
in_tensor
.
data
<
T
>
();
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
int
cout
=
in_dims
[
0
];
int
cout
=
in_dims
[
0
];
int
grid
=
cout
;
int
grid
=
cout
;
int
block
=
1024
;
int
block
=
1024
;
FindChannelAbsMaxKernelQuantAxis0
<
T
>
FindChannelAbsMaxKernelQuantAxis0
<
T
>
<<<
grid
,
block
,
block
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
in_data
,
num
,
cout
,
<<<
grid
,
block
,
block
*
sizeof
(
T
),
ctx
.
stream
()
>>>
(
out_abs_max
);
in_data
,
num
,
cout
,
out_abs_max
);
}
else
if
(
quant_axis
==
1
)
{
}
else
if
(
quant_axis
==
1
)
{
int
cin
=
in_dims
[
0
];
int
cin
=
in_dims
[
0
];
int
cout
=
in_dims
[
1
];
int
cout
=
in_dims
[
1
];
...
@@ -213,9 +218,12 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
...
@@ -213,9 +218,12 @@ struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
FindChannelAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
T
*
scale
,
__global__
void
ClipAndQuantKernel
(
const
T
*
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
T
*
scale
,
const
int
n
,
T
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
...
@@ -227,25 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
...
@@ -227,25 +235,30 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
x
=
bin_cnt_t
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
(
x
);
}
else
{
}
else
{
x
=
round
(
x
);
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
(
x
);
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ClipAndQuantDequantKernel
(
const
T
*
in
,
const
T
*
scale
,
__global__
void
ClipAndQuantDequantKernel
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
n
,
const
int
round_type
,
T
*
out
)
{
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
...
@@ -257,33 +270,39 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
...
@@ -257,33 +270,39 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
x
=
bin_cnt_t
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
((
x
*
s
)
/
bin_cnt_t
);
}
else
{
}
else
{
x
=
x
>
s
?
s
:
x
;
x
=
x
<
-
s
?
-
s
:
x
;
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
round
(
x
);
x
=
round
(
x
);
out
[
i
]
=
static_cast
<
T
>
((
x
*
s
)
/
bin_cnt_t
);
}
}
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
((
x
*
s
)
/
bin_cnt_t
);
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
ClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
num
=
in
.
numel
();
int
block
=
1024
;
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
ClipAndQuantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
round_type
,
num
,
out_data
);
in_data
,
scale_data
,
bin_cnt
,
round_type
,
num
,
out_data
);
...
@@ -294,17 +313,19 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
...
@@ -294,17 +313,19 @@ template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
framework
::
Tensor
*
out
)
{
int
num
=
in
.
numel
();
int
num
=
in
.
numel
();
int
block
=
1024
;
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ClipAndQuantDequantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
ClipAndQuantDequantKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
round_type
,
num
,
out_data
);
in_data
,
scale_data
,
bin_cnt
,
round_type
,
num
,
out_data
);
...
@@ -313,16 +334,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
...
@@ -313,16 +334,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
// ChannelClipAndQuantKernel for quant_axis is 0
// ChannelClipAndQuantKernel for quant_axis is 0
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernelQuantAxis0
(
const
T
*
in
,
const
T
*
scale
,
__global__
void
ChannelClipAndQuantKernelQuantAxis0
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
round_type
,
const
int64_t
n
,
const
int64_t
n
,
const
int
c
,
T
*
out
)
{
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int64_t
channel_size
=
n
/
c
;
int64_t
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
...
@@ -332,25 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
...
@@ -332,25 +355,33 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
for
(
int64_t
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
for
(
int64_t
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in_c
[
i
]);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in_c
[
i
]);
x
=
bin_cnt_t
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
static_cast
<
T
>
(
x
);
}
else
{
}
else
{
x
=
round
(
x
);
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out_c
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
static_cast
<
T
>
(
x
);
}
}
}
}
// ChannelClipAndQuantKernel for quant_axis is N
// ChannelClipAndQuantKernel for quant_axis is N
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernelQuantAxisN
(
__global__
void
ChannelClipAndQuantKernelQuantAxisN
(
const
T
*
in
,
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
round_type
,
const
T
*
scale
,
const
int64_t
n
,
const
int
nScale
,
const
int
quant_stride
,
T
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int64_t
n
,
const
int
nScale
,
const
int
quant_stride
,
T
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
...
@@ -359,37 +390,44 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
...
@@ -359,37 +390,44 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
static_cast
<
ComputeDataType
>
(
scale
[(
i
/
quant_stride
)
%
nScale
]);
static_cast
<
ComputeDataType
>
(
scale
[(
i
/
quant_stride
)
%
nScale
]);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
x
=
bin_cnt_t
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
(
x
);
}
else
{
}
else
{
x
=
round
(
x
);
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
ComputeDataType
max_bound
=
bin_cnt_t
;
ComputeDataType
min_bound
=
-
bin_cnt_t
-
static_cast
<
ComputeDataType
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out
[
i
]
=
static_cast
<
T
>
(
x
);
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
ChannelClipAndFakeQuantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
int64_t
num
=
in
.
numel
();
int64_t
num
=
in
.
numel
();
auto
in_dims
=
in
.
dims
();
auto
in_dims
=
in
.
dims
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
int
grid
=
in_dims
[
0
];
int
grid
=
in_dims
[
0
];
...
@@ -411,9 +449,15 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
...
@@ -411,9 +449,15 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
const
int64_t
grid_size
=
const
int64_t
grid_size
=
std
::
min
(
max_blocks
,
(
num
+
block_size
-
1
)
/
block_size
);
std
::
min
(
max_blocks
,
(
num
+
block_size
-
1
)
/
block_size
);
ChannelClipAndQuantKernelQuantAxisN
<
T
><<<
grid_size
,
block_size
>>>
(
ChannelClipAndQuantKernelQuantAxisN
<
T
>
in_data
,
scale_data
,
bin_cnt
,
round_type
,
num
,
in_dims
[
quant_axis
],
<<<
grid_size
,
block_size
>>>
(
in_data
,
quant_stride
,
out_data
);
scale_data
,
bin_cnt
,
round_type
,
num
,
in_dims
[
quant_axis
],
quant_stride
,
out_data
);
}
}
}
}
};
};
...
@@ -422,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
...
@@ -422,12 +466,14 @@ template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
float
>;
float
>;
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
__global__
void
FindRangeAbsMaxAndFillArray
(
const
T
*
cur_scale
,
const
T
*
last_scale
,
const
T
*
last_scale
,
const
int64_t
*
iter
,
const
int64_t
*
iter
,
const
int
window_size
,
T
*
scale_arr
,
const
int
window_size
,
T
*
out_scale
,
int
*
need_find_max
,
T
*
scale_arr
,
int
*
out_size
)
{
T
*
out_scale
,
int
*
need_find_max
,
int
*
out_size
)
{
int
it
=
iter
[
0
];
int
it
=
iter
[
0
];
int
idx
=
it
%
window_size
;
int
idx
=
it
%
window_size
;
T
removed
=
scale_arr
[
idx
];
T
removed
=
scale_arr
[
idx
];
...
@@ -446,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
...
@@ -446,45 +492,63 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
template
<
typename
T
>
template
<
typename
T
>
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
FindRangeAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
const
framework
::
Tensor
&
iter
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
)
{
const
auto
gpu_place
=
ctx
.
GetPlace
();
const
auto
gpu_place
=
ctx
.
GetPlace
();
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
gpu_place
);
T
*
scale_arr
=
scales_arr
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
framework
::
Tensor
need_find_max
,
out_size
;
framework
::
Tensor
need_find_max
,
out_size
;
int
*
find_max
=
need_find_max
.
mutable_data
<
int
>
({
1
},
gpu_place
);
int
*
find_max
=
need_find_max
.
mutable_data
<
int
>
({
1
},
gpu_place
);
int
*
out_size_data
=
out_size
.
mutable_data
<
int
>
({
1
},
gpu_place
);
int
*
out_size_data
=
out_size
.
mutable_data
<
int
>
({
1
},
gpu_place
);
FindRangeAbsMaxAndFillArray
<
T
><<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
FindRangeAbsMaxAndFillArray
<
T
>
cur_scale
.
data
<
T
>
(),
last_scale
.
data
<
T
>
(),
iter
.
data
<
int64_t
>
(),
<<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
cur_scale
.
data
<
T
>
(),
window_size
,
scale_arr
,
out_scale_data
,
find_max
,
out_size_data
);
last_scale
.
data
<
T
>
(),
iter
.
data
<
int64_t
>
(),
window_size
,
scale_arr
,
out_scale_data
,
find_max
,
out_size_data
);
int
g_find_max
;
int
g_find_max
;
memory
::
Copy
(
platform
::
CPUPlace
(),
&
g_find_max
,
gpu_place
,
find_max
,
memory
::
Copy
(
platform
::
CPUPlace
(),
sizeof
(
int
),
ctx
.
stream
());
&
g_find_max
,
gpu_place
,
find_max
,
sizeof
(
int
),
ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
if
(
g_find_max
)
{
if
(
g_find_max
)
{
int
len
;
int
len
;
memory
::
Copy
(
platform
::
CPUPlace
(),
&
len
,
gpu_place
,
out_size_data
,
memory
::
Copy
(
platform
::
CPUPlace
(),
sizeof
(
int
),
ctx
.
stream
());
&
len
,
gpu_place
,
out_size_data
,
sizeof
(
int
),
ctx
.
stream
());
ctx
.
Wait
();
ctx
.
Wait
();
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
ctx
,
scale_arr
,
len
,
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
()(
out_scale_data
);
ctx
,
scale_arr
,
len
,
out_scale_data
);
}
}
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
__global__
void
FindMovingAverageAbsMaxKernel
(
const
T
*
in_state
,
__global__
void
FindMovingAverageAbsMaxKernel
(
const
T
*
in_state
,
const
T
*
in_accum
,
const
T
*
in_accum
,
const
T
*
cur_scale
,
const
T
rate
,
const
T
*
cur_scale
,
T
*
out_state
,
T
*
out_accum
,
const
T
rate
,
T
*
out_scale
)
{
T
*
out_state
,
T
*
out_accum
,
T
*
out_scale
)
{
T
state
=
rate
*
(
*
in_state
)
+
T
(
1.0
f
);
T
state
=
rate
*
(
*
in_state
)
+
T
(
1.0
f
);
T
accum
=
rate
*
(
*
in_accum
)
+
(
*
cur_scale
);
T
accum
=
rate
*
(
*
in_accum
)
+
(
*
cur_scale
);
*
out_state
=
state
;
*
out_state
=
state
;
...
@@ -496,92 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
...
@@ -496,92 +560,119 @@ template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
template
<
typename
T
>
template
<
typename
T
>
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
FindMovingAverageAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_accum
,
const
framework
::
Tensor
&
in_accum
,
const
framework
::
Tensor
&
in_state
,
const
T
*
cur_scale
,
const
framework
::
Tensor
&
in_state
,
const
float
rate
,
framework
::
Tensor
*
out_state
,
const
T
*
cur_scale
,
framework
::
Tensor
*
out_accum
,
framework
::
Tensor
*
out_scale
)
{
const
float
rate
,
framework
::
Tensor
*
out_state
,
framework
::
Tensor
*
out_accum
,
framework
::
Tensor
*
out_scale
)
{
const
auto
gpu_place
=
ctx
.
GetPlace
();
const
auto
gpu_place
=
ctx
.
GetPlace
();
T
rate_t
=
static_cast
<
T
>
(
rate
);
T
rate_t
=
static_cast
<
T
>
(
rate
);
T
*
out_state_data
=
out_state
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_state_data
=
out_state
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_accum_data
=
out_accum
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_accum_data
=
out_accum
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
gpu_place
);
FindMovingAverageAbsMaxKernel
<
T
><<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
FindMovingAverageAbsMaxKernel
<
T
>
in_state
.
data
<
T
>
(),
in_accum
.
data
<
T
>
(),
cur_scale
,
rate_t
,
<<<
1
,
1
,
0
,
ctx
.
stream
()
>>>
(
in_state
.
data
<
T
>
(),
out_state_data
,
out_accum_data
,
out_scale_data
);
in_accum
.
data
<
T
>
(),
cur_scale
,
rate_t
,
out_state_data
,
out_accum_data
,
out_scale_data
);
}
}
};
};
// ChannelClipAndQuantDequantKernel for quant_axis is 0
// ChannelClipAndQuantDequantKernel for quant_axis is 0
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ChannelClipAndQuantDequantKernelQuantAxis0
(
__global__
void
ChannelClipAndQuantDequantKernelQuantAxis0
(
const
T
*
in
,
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
round_type
,
const
T
*
scale
,
const
int
n
,
const
int
c
,
T
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
s
=
scale
[
blockIdx
.
x
];
T
s
=
scale
[
blockIdx
.
x
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
x
=
in_c
[
i
];
x
=
bin_cnt
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
T
max_bound
=
bin_cnt
;
T
min_bound
=
-
bin_cnt
-
static_cast
<
T
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
(
x
*
s
)
/
bin_cnt
;
}
else
{
}
else
{
x
=
round
(
x
);
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out_c
[
i
]
=
round
(
v
)
*
s
/
bin_cnt
;
}
}
T
max_bound
=
bin_cnt
;
T
min_bound
=
-
bin_cnt
-
static_cast
<
T
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
(
x
*
s
)
/
bin_cnt
;
}
}
}
}
// ChannelClipAndQuantDequantKernel for quant_axis is 1
// ChannelClipAndQuantDequantKernel for quant_axis is 1
template
<
typename
T
>
template
<
typename
T
>
__global__
void
ChannelClipAndQuantDequantKernelQuantAxis1
(
__global__
void
ChannelClipAndQuantDequantKernelQuantAxis1
(
const
T
*
in
,
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
round_type
,
const
T
*
scale
,
const
int
n
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
n
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
T
s
=
scale
[
blockIdx
.
x
%
cout
];
T
s
=
scale
[
blockIdx
.
x
%
cout
];
T
inv_s
=
inverse
(
s
);
T
inv_s
=
inverse
(
s
);
int
wh_size
=
n
/
(
cin
*
cout
);
int
wh_size
=
n
/
(
cin
*
cout
);
const
T
*
in_c
=
in
+
blockIdx
.
x
*
wh_size
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
wh_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
wh_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
wh_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
wh_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
wh_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
x
=
in_c
[
i
];
x
=
bin_cnt
*
inv_s
*
x
;
if
(
round_type
==
0
)
{
if
(
round_type
==
0
)
{
x
=
bin_cnt
*
inv_s
*
x
;
x
=
roundWithTiesToEven
(
x
);
x
=
roundWithTiesToEven
(
x
);
T
max_bound
=
bin_cnt
;
T
min_bound
=
-
bin_cnt
-
static_cast
<
T
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
(
x
*
s
)
/
bin_cnt
;
}
else
{
}
else
{
x
=
round
(
x
);
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out_c
[
i
]
=
round
(
v
)
*
s
/
bin_cnt
;
}
}
T
max_bound
=
bin_cnt
;
T
min_bound
=
-
bin_cnt
-
static_cast
<
T
>
(
1
);
x
=
x
>
max_bound
?
max_bound
:
x
;
x
=
x
<
min_bound
?
min_bound
:
x
;
out_c
[
i
]
=
(
x
*
s
)
/
bin_cnt
;
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
ChannelClipFakeQuantDequantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
struct
ChannelClipFakeQuantDequantFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
const
framework
::
Tensor
&
in
,
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
)
{
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
quant_axis
));
quant_axis
));
...
@@ -589,25 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
...
@@ -589,25 +680,34 @@ struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
int
num
=
in
.
numel
();
int
num
=
in
.
numel
();
auto
in_dims
=
in
.
dims
();
auto
in_dims
=
in
.
dims
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
out_data
=
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
int
grid
=
in_dims
[
0
];
int
grid
=
in_dims
[
0
];
int
block
=
1024
;
int
block
=
1024
;
ChannelClipAndQuantDequantKernelQuantAxis0
<
T
>
ChannelClipAndQuantDequantKernelQuantAxis0
<
T
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
round_type
,
num
,
in_dims
[
0
],
scale_data
,
bin_cnt
,
round_type
,
num
,
in_dims
[
0
],
out_data
);
out_data
);
}
else
if
(
quant_axis
==
1
)
{
}
else
if
(
quant_axis
==
1
)
{
int
grid
=
in_dims
[
0
]
*
in_dims
[
1
];
int
grid
=
in_dims
[
0
]
*
in_dims
[
1
];
int
block
=
1024
;
int
block
=
1024
;
ChannelClipAndQuantDequantKernelQuantAxis1
<
T
>
ChannelClipAndQuantDequantKernelQuantAxis1
<
T
>
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
<<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
round_type
,
num
,
in_dims
[
0
],
scale_data
,
in_dims
[
1
],
out_data
);
bin_cnt
,
round_type
,
num
,
in_dims
[
0
],
in_dims
[
1
],
out_data
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/fake_quantize_op.h
浏览文件 @
491b87b4
...
@@ -51,16 +51,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) {
...
@@ -51,16 +51,11 @@ inline HOSTDEVICE T roundWithTiesToEven(T x) {
template
<
typename
T
>
template
<
typename
T
>
class
QuantTensorFunctor
{
class
QuantTensorFunctor
{
public:
public:
explicit
QuantTensorFunctor
(
const
T
bin_cnt
,
const
int
round_type
,
explicit
QuantTensorFunctor
(
const
T
bin_cnt
,
const
T
inv_s
)
const
T
inv_s
)
:
bin_cnt_
(
bin_cnt
),
inv_s_
(
inv_s
)
{}
:
bin_cnt_
(
bin_cnt
),
round_type_
(
round_type
),
inv_s_
(
inv_s
)
{}
HOSTDEVICE
T
operator
()(
const
T
x
)
const
{
HOSTDEVICE
T
operator
()(
const
T
x
)
const
{
T
out
=
bin_cnt_
*
inv_s_
*
x
;
T
out
=
bin_cnt_
*
inv_s_
*
x
;
if
(
round_type_
==
0
)
{
out
=
roundWithTiesToEven
(
out
);
out
=
roundWithTiesToEven
(
out
);
}
else
if
(
round_type_
==
1
)
{
out
=
std
::
round
(
out
);
}
T
max_bound
=
bin_cnt_
;
T
max_bound
=
bin_cnt_
;
T
min_bound
=
-
bin_cnt_
-
static_cast
<
T
>
(
1
);
T
min_bound
=
-
bin_cnt_
-
static_cast
<
T
>
(
1
);
out
=
out
>
max_bound
?
max_bound
:
out
;
out
=
out
>
max_bound
?
max_bound
:
out
;
...
@@ -70,82 +65,101 @@ class QuantTensorFunctor {
...
@@ -70,82 +65,101 @@ class QuantTensorFunctor {
private:
private:
T
bin_cnt_
;
T
bin_cnt_
;
int
round_type_
;
T
inv_s_
;
T
inv_s_
;
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
FindAbsMaxFunctor
{
struct
FindAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
);
void
operator
()(
const
DeviceContext
&
ctx
,
const
T
*
in
,
const
int
num
,
T
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
ClipAndFakeQuantFunctor
{
struct
ClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
const
int
round_type
,
framework
::
Tensor
*
out
);
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
int
round_type
,
framework
::
Tensor
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
ClipAndFakeQuantDequantFunctor
{
struct
ClipAndFakeQuantDequantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
framework
::
Tensor
*
out
);
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
int
round_type
,
framework
::
Tensor
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
FindRangeAbsMaxFunctor
{
struct
FindRangeAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
cur_scale
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
last_scale
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
const
framework
::
Tensor
&
last_scale
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
const
framework
::
Tensor
&
iter
,
const
int
window_size
,
framework
::
Tensor
*
scales_arr
,
framework
::
Tensor
*
out_scale
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
FindChannelAbsMaxFunctor
{
struct
FindChannelAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_tensor
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
int
quant_axis
,
T
*
out_abs_max
);
const
framework
::
Tensor
&
in_tensor
,
const
int
quant_axis
,
T
*
out_abs_max
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelClipAndFakeQuantFunctor
{
struct
ChannelClipAndFakeQuantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
const
int
round_type
,
const
int
quant_axis
,
const
framework
::
Tensor
&
scale
,
framework
::
Tensor
*
out
);
const
int
bin_cnt
,
const
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
ChannelClipFakeQuantDequantFunctor
{
struct
ChannelClipFakeQuantDequantFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
);
const
framework
::
Tensor
&
scale
,
const
int
bin_cnt
,
int
round_type
,
const
int
quant_axis
,
framework
::
Tensor
*
out
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
struct
FindMovingAverageAbsMaxFunctor
{
struct
FindMovingAverageAbsMaxFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_accum
,
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in_state
,
const
framework
::
Tensor
&
in_accum
,
const
framework
::
Tensor
&
cur_scale
,
const
framework
::
Tensor
&
in_state
,
framework
::
Tensor
*
out_state
,
framework
::
Tensor
*
out_accum
,
const
framework
::
Tensor
&
cur_scale
,
framework
::
Tensor
*
out_scale
);
framework
::
Tensor
*
out_state
,
framework
::
Tensor
*
out_accum
,
framework
::
Tensor
*
out_scale
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeAbsMaxKernelBase
:
public
framework
::
OpKernel
<
T
>
{
class
FakeAbsMaxKernelBase
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
T
*
out_s
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
out_s
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
const
T
*
in_data
=
in
->
data
<
T
>
();
const
T
*
in_data
=
in
->
data
<
T
>
();
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in_data
,
in
->
numel
(),
out_s
);
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in_data
,
in
->
numel
(),
out_s
);
RunClipFunctor
(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
out
);
RunClipFunctor
(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
out
);
}
}
...
@@ -153,20 +167,25 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
...
@@ -153,20 +167,25 @@ class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
virtual
~
FakeAbsMaxKernelBase
()
=
default
;
virtual
~
FakeAbsMaxKernelBase
()
=
default
;
protected:
protected:
virtual
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
virtual
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
scale
,
int
round_type
,
framework
::
Tensor
*
out
)
const
=
0
;
int
bin_cnt
,
int
round_type
,
framework
::
Tensor
*
out
)
const
=
0
;
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeAbsMaxKernel
:
public
FakeAbsMaxKernelBase
<
DeviceContext
,
T
>
{
class
FakeQuantizeAbsMaxKernel
:
public
FakeAbsMaxKernelBase
<
DeviceContext
,
T
>
{
protected:
protected:
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
const
framework
::
Tensor
&
scale
,
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scale
,
bin_cnt
,
int
bin_cnt
,
round_type
,
out
);
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scale
,
bin_cnt
,
round_type
,
out
);
}
}
};
};
...
@@ -174,9 +193,12 @@ template <typename DeviceContext, typename T>
...
@@ -174,9 +193,12 @@ template <typename DeviceContext, typename T>
class
FakeQuantizeDequantizeAbsMaxKernel
class
FakeQuantizeDequantizeAbsMaxKernel
:
public
FakeAbsMaxKernelBase
<
DeviceContext
,
T
>
{
:
public
FakeAbsMaxKernelBase
<
DeviceContext
,
T
>
{
protected:
protected:
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
const
framework
::
Tensor
&
scale
,
int
bin_cnt
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
ClipAndFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
ClipAndFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
scale
,
bin_cnt
,
round_type
,
out
);
dev_ctx
,
in
,
scale
,
bin_cnt
,
round_type
,
out
);
}
}
...
@@ -185,11 +207,11 @@ class FakeQuantizeDequantizeAbsMaxKernel
...
@@ -185,11 +207,11 @@ class FakeQuantizeDequantizeAbsMaxKernel
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeChannelWiseQuantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FakeChannelWiseQuantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
...
@@ -198,11 +220,11 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
...
@@ -198,11 +220,11 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
int
quant_axis
=
context
.
Attr
<
int
>
(
"quant_axis"
);
int
quant_axis
=
context
.
Attr
<
int
>
(
"quant_axis"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
if
(
!
is_test
)
{
if
(
!
is_test
)
{
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
quant_axis
,
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
out_scale_data
);
dev_ctx
,
*
in
,
quant_axis
,
out_scale_data
);
}
}
ChannelClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
ChannelClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
quant_axis
,
out
);
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
quant_axis
,
out
);
...
@@ -213,12 +235,12 @@ template <typename DeviceContext, typename T>
...
@@ -213,12 +235,12 @@ template <typename DeviceContext, typename T>
class
FakeChannelWiseQuantizeDequantizeAbsMaxKernel
class
FakeChannelWiseQuantizeDequantizeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
out_scale_data
=
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
...
@@ -226,8 +248,8 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
...
@@ -226,8 +248,8 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
quant_axis
=
context
.
Attr
<
int
>
(
"quant_axis"
);
int
quant_axis
=
context
.
Attr
<
int
>
(
"quant_axis"
);
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
quant_axis
,
FindChannelAbsMaxFunctor
<
DeviceContext
,
T
>
()(
out_scale_data
);
dev_ctx
,
*
in
,
quant_axis
,
out_scale_data
);
ChannelClipFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
ChannelClipFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
quant_axis
,
out
);
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
quant_axis
,
out
);
...
@@ -237,60 +259,64 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
...
@@ -237,60 +259,64 @@ class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeRangeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FakeQuantizeRangeAbsMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
// testing
// testing
if
(
is_test
)
{
if
(
is_test
)
{
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
in_scale
,
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
bin_cnt
,
round_type
,
out
);
dev_ctx
,
*
in
,
*
in_scale
,
bin_cnt
,
round_type
,
out
);
return
;
return
;
}
}
// training
// training
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scales
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
out_scales
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScales"
);
auto
*
iter
=
context
.
Input
<
framework
::
Tensor
>
(
"Iter"
);
auto
*
iter
=
context
.
Input
<
framework
::
Tensor
>
(
"Iter"
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
int
window_size
=
context
.
Attr
<
int
>
(
"window_size"
);
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
Tensor
cur_scale
;
framework
::
Tensor
cur_scale
;
T
*
cur_scale_data
=
cur_scale
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
T
*
cur_scale_data
=
cur_scale
.
mutable_data
<
T
>
({
1
},
context
.
GetPlace
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
cur_scale_data
);
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
FindRangeAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
cur_scale
,
*
in_scale
,
FindRangeAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
iter
,
window_size
,
out_scales
,
cur_scale
,
*
in_scale
,
*
iter
,
window_size
,
out_scales
,
out_scale
);
out_scale
);
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
in
,
*
out_scale
,
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
bin_cnt
,
round_type
,
out
);
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
out
);
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeMovingAverageAbsMaxKernelBase
:
public
framework
::
OpKernel
<
T
>
{
class
FakeMovingAverageAbsMaxKernelBase
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
auto
*
in_scale
=
context
.
Input
<
framework
::
Tensor
>
(
"InScale"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
bit_length
=
context
.
Attr
<
int
>
(
"bit_length"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
round_type
=
context
.
Attr
<
int
>
(
"round_type"
);
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
int
bin_cnt
=
std
::
pow
(
2
,
bit_length
-
1
)
-
1
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
// testing
// testing
if
(
is_test
)
{
if
(
is_test
)
{
...
@@ -299,25 +325,30 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
...
@@ -299,25 +325,30 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
}
}
// training
// training
auto
*
in_accum
=
context
.
Input
<
framework
::
Tensor
>
(
"InAccum"
);
auto
*
in_accum
=
context
.
Input
<
framework
::
Tensor
>
(
"InAccum"
);
auto
*
in_state
=
context
.
Input
<
framework
::
Tensor
>
(
"InState"
);
auto
*
in_state
=
context
.
Input
<
framework
::
Tensor
>
(
"InState"
);
auto
cur_scale
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
));
auto
cur_scale
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
));
T
*
cur_scale_data
=
static_cast
<
T
*>
(
cur_scale
->
ptr
());
T
*
cur_scale_data
=
static_cast
<
T
*>
(
cur_scale
->
ptr
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
cur_scale_data
);
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
auto
*
out_state
=
context
.
Output
<
framework
::
Tensor
>
(
"OutState"
);
auto
*
out_state
=
context
.
Output
<
framework
::
Tensor
>
(
"OutState"
);
auto
*
out_accum
=
context
.
Output
<
framework
::
Tensor
>
(
"OutAccum"
);
auto
*
out_accum
=
context
.
Output
<
framework
::
Tensor
>
(
"OutAccum"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
out_state
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_state
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_accum
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_accum
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
moving_rate
=
context
.
Attr
<
float
>
(
"moving_rate"
);
float
moving_rate
=
context
.
Attr
<
float
>
(
"moving_rate"
);
FindMovingAverageAbsMaxFunctor
<
DeviceContext
,
T
>
()(
FindMovingAverageAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
dev_ctx
,
*
in_accum
,
*
in_state
,
cur_scale_data
,
moving_rate
,
out_state
,
*
in_accum
,
out_accum
,
out_scale
);
*
in_state
,
cur_scale_data
,
moving_rate
,
out_state
,
out_accum
,
out_scale
);
RunClipFunctor
(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
out
);
RunClipFunctor
(
dev_ctx
,
*
in
,
*
out_scale
,
bin_cnt
,
round_type
,
out
);
}
}
...
@@ -325,21 +356,26 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
...
@@ -325,21 +356,26 @@ class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
virtual
~
FakeMovingAverageAbsMaxKernelBase
()
=
default
;
virtual
~
FakeMovingAverageAbsMaxKernelBase
()
=
default
;
protected:
protected:
virtual
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
virtual
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
in_scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
in_scale
,
int
round_type
,
framework
::
Tensor
*
out
)
const
=
0
;
int
bin_cnt
,
int
round_type
,
framework
::
Tensor
*
out
)
const
=
0
;
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
FakeQuantizeMovingAverageAbsMaxKernel
class
FakeQuantizeMovingAverageAbsMaxKernel
:
public
FakeMovingAverageAbsMaxKernelBase
<
DeviceContext
,
T
>
{
:
public
FakeMovingAverageAbsMaxKernelBase
<
DeviceContext
,
T
>
{
protected:
protected:
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in_scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
const
framework
::
Tensor
&
in_scale
,
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
in_scale
,
bin_cnt
,
int
bin_cnt
,
round_type
,
out
);
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
ClipAndFakeQuantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
in_scale
,
bin_cnt
,
round_type
,
out
);
}
}
};
};
...
@@ -347,9 +383,12 @@ template <typename DeviceContext, typename T>
...
@@ -347,9 +383,12 @@ template <typename DeviceContext, typename T>
class
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
class
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
:
public
FakeMovingAverageAbsMaxKernelBase
<
DeviceContext
,
T
>
{
:
public
FakeMovingAverageAbsMaxKernelBase
<
DeviceContext
,
T
>
{
protected:
protected:
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
void
RunClipFunctor
(
const
DeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in_scale
,
int
bin_cnt
,
const
framework
::
Tensor
&
in
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
const
framework
::
Tensor
&
in_scale
,
int
bin_cnt
,
int
round_type
,
framework
::
Tensor
*
out
)
const
override
{
ClipAndFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
ClipAndFakeQuantDequantFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
,
in_scale
,
bin_cnt
,
round_type
,
out
);
dev_ctx
,
in
,
in_scale
,
bin_cnt
,
round_type
,
out
);
}
}
...
@@ -358,12 +397,12 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
...
@@ -358,12 +397,12 @@ class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
MovingAverageAbsMaxScaleKernel
:
public
framework
::
OpKernel
<
T
>
{
class
MovingAverageAbsMaxScaleKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
if
(
context
.
HasOutput
(
"Out"
))
{
if
(
context
.
HasOutput
(
"Out"
))
{
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
framework
::
TensorCopy
(
*
in
,
context
.
GetPlace
(),
dev_ctx
,
out
);
framework
::
TensorCopy
(
*
in
,
context
.
GetPlace
(),
dev_ctx
,
out
);
}
}
...
@@ -375,40 +414,46 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
...
@@ -375,40 +414,46 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
}
}
// training
// training
auto
*
in_accum
=
context
.
Input
<
framework
::
Tensor
>
(
"InAccum"
);
auto
*
in_accum
=
context
.
Input
<
framework
::
Tensor
>
(
"InAccum"
);
auto
*
in_state
=
context
.
Input
<
framework
::
Tensor
>
(
"InState"
);
auto
*
in_state
=
context
.
Input
<
framework
::
Tensor
>
(
"InState"
);
auto
cur_scale
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
));
auto
cur_scale
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
));
T
*
cur_scale_data
=
static_cast
<
T
*>
(
cur_scale
->
ptr
());
T
*
cur_scale_data
=
static_cast
<
T
*>
(
cur_scale
->
ptr
());
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
FindAbsMaxFunctor
<
DeviceContext
,
T
>
()(
cur_scale_data
);
dev_ctx
,
in
->
data
<
T
>
(),
in
->
numel
(),
cur_scale_data
);
auto
*
out_state
=
context
.
Output
<
framework
::
Tensor
>
(
"OutState"
);
auto
*
out_state
=
context
.
Output
<
framework
::
Tensor
>
(
"OutState"
);
auto
*
out_accum
=
context
.
Output
<
framework
::
Tensor
>
(
"OutAccum"
);
auto
*
out_accum
=
context
.
Output
<
framework
::
Tensor
>
(
"OutAccum"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
auto
*
out_scale
=
context
.
Output
<
framework
::
Tensor
>
(
"OutScale"
);
out_state
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_state
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_accum
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_accum
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
out_scale
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
moving_rate
=
context
.
Attr
<
float
>
(
"moving_rate"
);
float
moving_rate
=
context
.
Attr
<
float
>
(
"moving_rate"
);
FindMovingAverageAbsMaxFunctor
<
DeviceContext
,
T
>
()(
FindMovingAverageAbsMaxFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
dev_ctx
,
*
in_accum
,
*
in_state
,
cur_scale_data
,
moving_rate
,
out_state
,
*
in_accum
,
out_accum
,
out_scale
);
*
in_state
,
cur_scale_data
,
moving_rate
,
out_state
,
out_accum
,
out_scale
);
}
}
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
StrightThroughEstimatorGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
StrightThroughEstimatorGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
d_out
=
auto
*
d_out
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
auto
*
d_x
=
context
.
Output
<
framework
::
LoDTensor
>
(
x_grad_name
);
auto
*
d_x
=
context
.
Output
<
framework
::
LoDTensor
>
(
x_grad_name
);
PADDLE_ENFORCE_NOT_NULL
(
d_x
,
platform
::
errors
::
PreconditionNotMet
(
PADDLE_ENFORCE_NOT_NULL
(
d_x
,
"StrightThroughEstimatorGradKernel "
platform
::
errors
::
PreconditionNotMet
(
"doesn't have the output named %s."
,
"StrightThroughEstimatorGradKernel "
x_grad_name
));
"doesn't have the output named %s."
,
x_grad_name
));
// Initialize dx as same as d_out
// Initialize dx as same as d_out
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/fluid/operators/quantize_linear_op.cc
浏览文件 @
491b87b4
...
@@ -26,14 +26,17 @@ namespace operators {
...
@@ -26,14 +26,17 @@ namespace operators {
template
<
typename
T
>
template
<
typename
T
>
struct
ChannelDequantizeFunctorV2
<
platform
::
CPUDeviceContext
,
T
>
{
struct
ChannelDequantizeFunctorV2
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
*
scale
,
const
framework
::
Tensor
*
in
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
const
framework
::
Tensor
*
scale
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
// Dequant op is before quantized op
// Dequant op is before quantized op
// Dequantize the weight of quantized op
// Dequantize the weight of quantized op
auto
in_dims
=
in
->
dims
();
auto
in_dims
=
in
->
dims
();
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
T
*
scale_factor
=
scale
->
data
<
T
>
();
const
T
*
scale_factor
=
scale
->
data
<
T
>
();
if
(
quant_axis
==
0
)
{
if
(
quant_axis
==
0
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_factor
[
i
];
T
s
=
scale_factor
[
i
];
...
@@ -41,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
...
@@ -41,7 +44,7 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
in_e
*
s
/
max_range
;
out_e
.
device
(
dev
)
=
in_e
*
s
/
max_range
;
}
}
}
else
if
(
quant_axis
==
1
)
{
}
else
if
(
quant_axis
==
1
)
{
...
@@ -51,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
...
@@ -51,12 +54,12 @@ struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, T> {
}
}
int64_t
step_i
=
in
->
numel
()
/
out_iter
;
int64_t
step_i
=
in
->
numel
()
/
out_iter
;
int64_t
step_j
=
in
->
numel
()
/
(
out_iter
*
channel
);
int64_t
step_j
=
in
->
numel
()
/
(
out_iter
*
channel
);
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
out_iter
;
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
out_iter
;
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
channel
;
j
++
)
{
for
(
int64_t
j
=
0
;
j
<
channel
;
j
++
)
{
auto
*
cur_in
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_in
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out
=
out_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out
=
out_data
+
i
*
step_i
+
j
*
step_j
;
T
s
=
scale_factor
[
j
];
T
s
=
scale_factor
[
j
];
for
(
int64_t
k
=
0
;
k
<
step_j
;
k
++
)
{
for
(
int64_t
k
=
0
;
k
<
step_j
;
k
++
)
{
*
cur_out
=
(
*
cur_in
)
*
s
/
max_range
;
*
cur_out
=
(
*
cur_in
)
*
s
/
max_range
;
...
@@ -75,11 +78,11 @@ template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
...
@@ -75,11 +78,11 @@ template struct ChannelDequantizeFunctorV2<platform::CPUDeviceContext, double>;
class
QuantizeLinearOp
:
public
framework
::
OperatorWithKernel
{
class
QuantizeLinearOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale"
),
"Input"
,
"Scale"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale"
),
"Input"
,
"Scale"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ZeroPoint"
),
"Input"
,
"ZeroPoint"
,
OP_INOUT_CHECK
(
"QuantizeLinear"
);
ctx
->
HasInput
(
"ZeroPoint"
),
"Input"
,
"ZeroPoint"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"QuantizeLinear"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"QuantizeLinear"
);
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
int
quant_axis
=
ctx
->
Attrs
().
Get
<
int
>
(
"quant_axis"
);
...
@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
...
@@ -95,7 +98,7 @@ class QuantizeLinearOp : public framework::OperatorWithKernel {
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
...
@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -116,9 +119,10 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"For conv2d, depthwise_conv2d, conv2d_transpose "
"For conv2d, depthwise_conv2d, conv2d_transpose "
"and mul, the quant_axis is equal to the cout axis."
)
"and mul, the quant_axis is equal to the cout axis."
)
.
SetDefault
(
0
)
.
SetDefault
(
0
)
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
.
AddCustomChecker
([](
const
int
&
quant_axis
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
quant_axis
==
0
||
quant_axis
==
1
||
quant_axis
==
-
1
,
true
,
quant_axis
==
0
||
quant_axis
==
1
||
quant_axis
==
-
1
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'quant_axis' should be 0 or 1, but "
"'quant_axis' should be 0 or 1, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -126,8 +130,9 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -126,8 +130,9 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
});
});
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
AddAttr
<
int
>
(
"bit_length"
,
"(int, default 8)"
)
.
SetDefault
(
8
)
.
SetDefault
(
8
)
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
.
AddCustomChecker
([](
const
int
&
bit_length
)
{
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
PADDLE_ENFORCE_EQ
(
bit_length
>=
1
&&
bit_length
<=
16
,
true
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"'bit_length' should be between 1 and 16, but "
"'bit_length' should be between 1 and 16, but "
"the received is %d"
,
"the received is %d"
,
...
@@ -140,13 +145,17 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -140,13 +145,17 @@ class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
"round(2.5)=3"
)
"round(2.5)=3"
)
.
SetDefault
(
0
)
.
SetDefault
(
0
)
.
AddCustomChecker
([](
const
int
&
round_type
)
{
.
AddCustomChecker
([](
const
int
&
round_type
)
{
PADDLE_ENFORCE_EQ
(
round_type
>=
0
&&
round_type
<=
1
,
true
,
PADDLE_ENFORCE_EQ
(
platform
::
errors
::
InvalidArgument
(
round_type
==
0
||
round_type
==
1
,
"'round_type' should be between 0 and 1, but "
true
,
"the received is %d"
,
platform
::
errors
::
InvalidArgument
(
round_type
));
"'round_type' should be 0 or 1, 0 rounding to "
});
"nearest ties to even and 1 is rounding to nearest "
"ties away from zero.but the received is %d"
,
round_type
));
})
.
AsExtra
();
AddAttr
<
bool
>
(
"is_test"
,
AddAttr
<
bool
>
(
"is_test"
,
"(bool, default false) Set to true for inference only, false "
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true."
)
"for training. Some layers may run faster when this is true."
)
...
@@ -170,14 +179,18 @@ namespace ops = paddle::operators;
...
@@ -170,14 +179,18 @@ namespace ops = paddle::operators;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
using
CPU
=
paddle
::
platform
::
CPUDeviceContext
;
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
quantize_linear
,
ops
::
QuantizeLinearOp
,
ops
::
QuantizeLinearOpMaker
,
quantize_linear
,
ops
::
QuantizeLinearOp
,
ops
::
QuantizeLinearOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
quantize_linear
,
ops
::
QuantizeLinearKernel
<
CPU
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
quantize_linear
,
ops
::
QuantizeLinearKernel
<
CPU
,
float
>
);
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
dequantize_linear
,
ops
::
QuantizeLinearOp
,
ops
::
QuantizeLinearOpMaker
,
dequantize_linear
,
ops
::
QuantizeLinearOp
,
ops
::
QuantizeLinearOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
...
...
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
浏览文件 @
491b87b4
...
@@ -121,8 +121,7 @@ class PostTrainingQuantization(object):
...
@@ -121,8 +121,7 @@ class PostTrainingQuantization(object):
algo
=
"KL"
,
algo
=
"KL"
,
hist_percent
=
0.99999
,
hist_percent
=
0.99999
,
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
],
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
],
weight_round_algo
=
'round'
,
round_type
=
'round'
,
round_type
=
'TiesToEven'
,
learning_rate
=
0.001
,
learning_rate
=
0.001
,
is_full_quantize
=
False
,
is_full_quantize
=
False
,
bias_correction
=
False
,
bias_correction
=
False
,
...
@@ -181,14 +180,10 @@ class PostTrainingQuantization(object):
...
@@ -181,14 +180,10 @@ class PostTrainingQuantization(object):
quantizable_op_type(list[str], optional): List the type of ops
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
"mul"].
weight_round_algo
(str, optional): The method of converting the quantized weights
round_type
(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
Default is `round`, which is rounding nearest to the integer.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
learning_rate(float, optional): The learning rate of adaround method.
learning_rate(float, optional): The learning rate of adaround method.
is_full_quantized(bool, optional): If set is_full_quantized as True,
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
apply quantization to all supported quantizable op type. If set
...
@@ -269,10 +264,8 @@ class PostTrainingQuantization(object):
...
@@ -269,10 +264,8 @@ class PostTrainingQuantization(object):
self
.
_support_algo_type
=
[
self
.
_support_algo_type
=
[
'KL'
,
'hist'
,
'avg'
,
'mse'
,
'emd'
,
'abs_max'
,
'min_max'
'KL'
,
'hist'
,
'avg'
,
'mse'
,
'emd'
,
'abs_max'
,
'min_max'
]
]
assert
round_type
in
[
'
TiesToEven'
,
'TiesAwayFromZero
'
]
assert
round_type
in
[
'
adaround'
,
'round
'
]
self
.
_round_type
=
round_type
self
.
_round_type
=
round_type
assert
weight_round_algo
in
[
'adaround'
,
'round'
]
self
.
_weight_round_algo
=
weight_round_algo
self
.
_learning_rate
=
learning_rate
self
.
_learning_rate
=
learning_rate
self
.
_dynamic_quantize_op_type
=
[
'lstm'
]
self
.
_dynamic_quantize_op_type
=
[
'lstm'
]
self
.
_support_quantize_op_type
=
\
self
.
_support_quantize_op_type
=
\
...
@@ -414,7 +407,7 @@ class PostTrainingQuantization(object):
...
@@ -414,7 +407,7 @@ class PostTrainingQuantization(object):
if
self
.
_algo
in
[
"KL"
,
"hist"
]:
if
self
.
_algo
in
[
"KL"
,
"hist"
]:
self
.
_calculate_kl_hist_threshold
()
self
.
_calculate_kl_hist_threshold
()
if
self
.
_
weight_round_algo
==
'adaround'
:
if
self
.
_
round_type
==
'adaround'
:
self
.
_adaround_apply
()
self
.
_adaround_apply
()
self
.
_reset_activation_persistable
()
self
.
_reset_activation_persistable
()
...
@@ -651,7 +644,6 @@ class PostTrainingQuantization(object):
...
@@ -651,7 +644,6 @@ class PostTrainingQuantization(object):
float
(
np
.
max
(
np
.
abs
(
var_tensor
[
i
]))))
float
(
np
.
max
(
np
.
abs
(
var_tensor
[
i
]))))
self
.
_quantized_threshold
[
var_name
]
=
abs_max_value
self
.
_quantized_threshold
[
var_name
]
=
abs_max_value
_logger
.
info
(
"MSE searching stage ..."
)
_logger
.
info
(
"MSE searching stage ..."
)
distribution
=
np
.
round
if
self
.
_round_type
==
'TiesToEven'
else
utils
.
round_c
for
var_name
in
self
.
_quantized_act_var_name
:
for
var_name
in
self
.
_quantized_act_var_name
:
var_tensor
=
utils
.
load_variable_data
(
self
.
_scope
,
var_name
)
var_tensor
=
utils
.
load_variable_data
(
self
.
_scope
,
var_name
)
var_tensor
=
var_tensor
.
flatten
()
var_tensor
=
var_tensor
.
flatten
()
...
@@ -664,9 +656,14 @@ class PostTrainingQuantization(object):
...
@@ -664,9 +656,14 @@ class PostTrainingQuantization(object):
scale
=
s
*
abs_max_value
scale
=
s
*
abs_max_value
s
+=
0.02
s
+=
0.02
bins
=
2
**
(
self
.
_activation_bits
-
1
)
-
1
bins
=
2
**
(
self
.
_activation_bits
-
1
)
-
1
quant_var
=
np
.
clip
(
distribution
(
var_tensor
/
scale
*
bins
),
if
self
.
_onnx_format
:
-
bins
-
1
,
bins
)
quant_var
=
np
.
clip
(
distribution
(
var_tensor
/
scale
*
bins
),
quant_dequant_var
=
quant_var
/
bins
*
scale
-
bins
-
1
,
bins
)
quant_dequant_var
=
quant_var
/
bins
*
scale
else
:
quant_dequant_var
=
np
.
round
(
np
.
clip
(
var_tensor
,
0.0
,
scale
)
/
scale
*
bins
)
/
bins
*
scale
mse_loss
=
((
var_tensor
-
quant_dequant_var
)
**
2
).
mean
()
mse_loss
=
((
var_tensor
-
quant_dequant_var
)
**
2
).
mean
()
if
mse_loss
<=
self
.
_best_calibration_loss
[
var_name
]:
if
mse_loss
<=
self
.
_best_calibration_loss
[
var_name
]:
self
.
_best_calibration_loss
[
var_name
]
=
mse_loss
self
.
_best_calibration_loss
[
var_name
]
=
mse_loss
...
@@ -691,7 +688,6 @@ class PostTrainingQuantization(object):
...
@@ -691,7 +688,6 @@ class PostTrainingQuantization(object):
float
(
np
.
max
(
np
.
abs
(
var_tensor
[
i
]))))
float
(
np
.
max
(
np
.
abs
(
var_tensor
[
i
]))))
self
.
_quantized_threshold
[
var_name
]
=
abs_max_value
self
.
_quantized_threshold
[
var_name
]
=
abs_max_value
_logger
.
info
(
"EMD searching stage ..."
)
_logger
.
info
(
"EMD searching stage ..."
)
distribution
=
np
.
round
if
self
.
_round_type
==
'TiesToEven'
else
utils
.
round_c
for
var_name
in
self
.
_quantized_act_var_name
:
for
var_name
in
self
.
_quantized_act_var_name
:
var_tensor
=
utils
.
load_variable_data
(
self
.
_scope
,
var_name
)
var_tensor
=
utils
.
load_variable_data
(
self
.
_scope
,
var_name
)
var_tensor
=
var_tensor
.
flatten
()
var_tensor
=
var_tensor
.
flatten
()
...
@@ -704,9 +700,14 @@ class PostTrainingQuantization(object):
...
@@ -704,9 +700,14 @@ class PostTrainingQuantization(object):
scale
=
s
*
abs_max_value
scale
=
s
*
abs_max_value
s
+=
0.02
s
+=
0.02
bins
=
2
**
(
self
.
_activation_bits
-
1
)
-
1
bins
=
2
**
(
self
.
_activation_bits
-
1
)
-
1
quant_var
=
np
.
clip
(
distribution
(
var_tensor
/
scale
*
bins
),
if
self
.
_onnx_format
:
-
bins
-
1
,
bins
)
quant_var
=
np
.
clip
(
distribution
(
var_tensor
/
scale
*
bins
),
quant_dequant_var
=
quant_var
/
bins
*
scale
-
bins
-
1
,
bins
)
quant_dequant_var
=
quant_var
/
bins
*
scale
else
:
quant_dequant_var
=
np
.
round
(
np
.
clip
(
var_tensor
,
0.0
,
scale
)
/
scale
*
bins
)
/
bins
*
scale
emd_loss
=
np
.
abs
(
emd_loss
=
np
.
abs
(
np
.
mean
(
var_tensor
)
-
np
.
mean
(
quant_dequant_var
))
+
np
.
abs
(
np
.
mean
(
var_tensor
)
-
np
.
mean
(
quant_dequant_var
))
+
np
.
abs
(
np
.
std
(
var_tensor
)
-
np
.
std
(
quant_dequant_var
))
np
.
std
(
var_tensor
)
-
np
.
std
(
quant_dequant_var
))
...
@@ -918,8 +919,7 @@ class PostTrainingQuantization(object):
...
@@ -918,8 +919,7 @@ class PostTrainingQuantization(object):
activation_bits
=
self
.
_activation_bits
,
activation_bits
=
self
.
_activation_bits
,
activation_quantize_type
=
self
.
_activation_quantize_type
,
activation_quantize_type
=
self
.
_activation_quantize_type
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
quantizable_op_type
=
major_quantizable_op_types
,
quantizable_op_type
=
major_quantizable_op_types
)
round_type
=
self
.
_round_type
)
else
:
else
:
transform_pass
=
QuantizationTransformPassV2
(
transform_pass
=
QuantizationTransformPassV2
(
scope
=
self
.
_scope
,
scope
=
self
.
_scope
,
...
@@ -928,8 +928,7 @@ class PostTrainingQuantization(object):
...
@@ -928,8 +928,7 @@ class PostTrainingQuantization(object):
activation_bits
=
self
.
_activation_bits
,
activation_bits
=
self
.
_activation_bits
,
activation_quantize_type
=
self
.
_activation_quantize_type
,
activation_quantize_type
=
self
.
_activation_quantize_type
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
quantizable_op_type
=
major_quantizable_op_types
,
quantizable_op_type
=
major_quantizable_op_types
)
round_type
=
self
.
_round_type
)
for
sub_graph
in
graph
.
all_sub_graphs
():
for
sub_graph
in
graph
.
all_sub_graphs
():
# Insert fake_quant/fake_dequantize op must in test graph, so
# Insert fake_quant/fake_dequantize op must in test graph, so
...
@@ -946,15 +945,13 @@ class PostTrainingQuantization(object):
...
@@ -946,15 +945,13 @@ class PostTrainingQuantization(object):
add_quant_dequant_pass
=
AddQuantDequantPass
(
add_quant_dequant_pass
=
AddQuantDequantPass
(
scope
=
self
.
_scope
,
scope
=
self
.
_scope
,
place
=
self
.
_place
,
place
=
self
.
_place
,
quantizable_op_type
=
minor_quantizable_op_types
,
quantizable_op_type
=
minor_quantizable_op_types
)
round_type
=
self
.
_round_type
)
else
:
else
:
add_quant_dequant_pass
=
AddQuantDequantPassV2
(
add_quant_dequant_pass
=
AddQuantDequantPassV2
(
scope
=
self
.
_scope
,
scope
=
self
.
_scope
,
place
=
self
.
_place
,
place
=
self
.
_place
,
quantizable_op_type
=
minor_quantizable_op_types
,
quantizable_op_type
=
minor_quantizable_op_types
,
is_full_quantized
=
self
.
_is_full_quantize
,
is_full_quantized
=
self
.
_is_full_quantize
)
round_type
=
self
.
_round_type
)
for
sub_graph
in
graph
.
all_sub_graphs
():
for
sub_graph
in
graph
.
all_sub_graphs
():
sub_graph
.
_for_test
=
True
sub_graph
.
_for_test
=
True
...
@@ -979,7 +976,6 @@ class PostTrainingQuantization(object):
...
@@ -979,7 +976,6 @@ class PostTrainingQuantization(object):
place
=
self
.
_place
,
place
=
self
.
_place
,
bias_correction
=
self
.
_bias_correction
,
bias_correction
=
self
.
_bias_correction
,
weight_bits
=
self
.
_weight_bits
,
weight_bits
=
self
.
_weight_bits
,
weight_round_algo
=
self
.
_weight_round_algo
,
round_type
=
self
.
_round_type
,
round_type
=
self
.
_round_type
,
activation_bits
=
self
.
_activation_bits
,
activation_bits
=
self
.
_activation_bits
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
weight_quantize_type
=
self
.
_weight_quantize_type
,
...
...
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
491b87b4
...
@@ -119,7 +119,6 @@ class QuantizationTransformPass(object):
...
@@ -119,7 +119,6 @@ class QuantizationTransformPass(object):
moving_rate
=
0.9
,
moving_rate
=
0.9
,
skip_pattern
=
[
'skip_quant'
],
skip_pattern
=
[
'skip_quant'
],
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
],
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
],
round_type
=
'TiesToEven'
,
weight_quantize_func
=
None
,
weight_quantize_func
=
None
,
act_quantize_func
=
None
,
act_quantize_func
=
None
,
weight_preprocess_func
=
None
,
weight_preprocess_func
=
None
,
...
@@ -157,10 +156,6 @@ class QuantizationTransformPass(object):
...
@@ -157,10 +156,6 @@ class QuantizationTransformPass(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight.
weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not.
Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and
In this function, user should both define quantization function and
...
@@ -211,7 +206,6 @@ class QuantizationTransformPass(object):
...
@@ -211,7 +206,6 @@ class QuantizationTransformPass(object):
self
.
_weight_bits
=
weight_bits
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
self
.
_activation_bits
=
activation_bits
self
.
_skip_pattern
=
skip_pattern
self
.
_skip_pattern
=
skip_pattern
self
.
_round_type
=
round_type
self
.
_weight_quantize_func
=
weight_quantize_func
self
.
_weight_quantize_func
=
weight_quantize_func
self
.
_act_quantize_func
=
act_quantize_func
self
.
_act_quantize_func
=
act_quantize_func
self
.
_weight_preprocess_func
=
weight_preprocess_func
self
.
_weight_preprocess_func
=
weight_preprocess_func
...
@@ -465,12 +459,10 @@ class QuantizationTransformPass(object):
...
@@ -465,12 +459,10 @@ class QuantizationTransformPass(object):
_init_var_node
(
scale_var_node
,
_init_var_node
(
scale_var_node
,
np
.
zeros
(
scale_var_node
.
shape
(),
dtype
=
data_type
),
np
.
zeros
(
scale_var_node
.
shape
(),
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
self
.
_scope
,
self
.
_place
)
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_quantize_abs_max'
,
op_type
=
'fake_quantize_abs_max'
,
attrs
=
{
attrs
=
{
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'round_type'
:
round_type
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
},
},
inputs
=
{
'X'
:
var_node
},
inputs
=
{
'X'
:
var_node
},
...
@@ -525,11 +517,9 @@ class QuantizationTransformPass(object):
...
@@ -525,11 +517,9 @@ class QuantizationTransformPass(object):
inputs
[
'Iter'
]
=
self
.
_global_step
inputs
[
'Iter'
]
=
self
.
_global_step
outputs
[
'OutScales'
]
=
scales_node
outputs
[
'OutScales'
]
=
scales_node
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
attrs
=
{
attrs
=
{
'window_size'
:
self
.
_window_size
,
'window_size'
:
self
.
_window_size
,
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'round_type'
:
round_type
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
}
}
...
@@ -600,10 +590,8 @@ class QuantizationTransformPass(object):
...
@@ -600,10 +590,8 @@ class QuantizationTransformPass(object):
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutAccum'
]
=
accum_out_node
outs
[
'OutAccum'
]
=
accum_out_node
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
attrs
=
{
attrs
=
{
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'round_type'
:
round_type
,
'moving_rate'
:
self
.
_moving_rate
,
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
...
@@ -650,12 +638,10 @@ class QuantizationTransformPass(object):
...
@@ -650,12 +638,10 @@ class QuantizationTransformPass(object):
_init_var_node
(
scale_var_node
,
_init_var_node
(
scale_var_node
,
np
.
zeros
(
scale_var_node
.
shape
(),
dtype
=
data_type
),
np
.
zeros
(
scale_var_node
.
shape
(),
dtype
=
data_type
),
self
.
_scope
,
self
.
_place
)
self
.
_scope
,
self
.
_place
)
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
quant_op_node
=
graph
.
create_op_node
(
quant_op_node
=
graph
.
create_op_node
(
op_type
=
'fake_channel_wise_quantize_abs_max'
,
op_type
=
'fake_channel_wise_quantize_abs_max'
,
attrs
=
{
attrs
=
{
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'round_type'
:
round_type
,
'quant_axis'
:
quant_axis
,
'quant_axis'
:
quant_axis
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
...
@@ -949,8 +935,7 @@ class QuantizationFreezePass(object):
...
@@ -949,8 +935,7 @@ class QuantizationFreezePass(object):
bias_correction
=
False
,
bias_correction
=
False
,
weight_bits
=
8
,
weight_bits
=
8
,
activation_bits
=
8
,
activation_bits
=
8
,
weight_round_algo
=
'round'
,
round_type
=
'round'
,
round_type
=
'TiesToEven'
,
weight_quantize_type
=
'abs_max'
,
weight_quantize_type
=
'abs_max'
,
quantizable_op_type
=
None
):
quantizable_op_type
=
None
):
"""
"""
...
@@ -968,14 +953,10 @@ class QuantizationFreezePass(object):
...
@@ -968,14 +953,10 @@ class QuantizationFreezePass(object):
https://arxiv.org/abs/1810.05723.
https://arxiv.org/abs/1810.05723.
weight_bits(int): quantization bit number for weights.
weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
activation_bits(int): quantization bit number for activation.
weight_round_algo
(str, optional): The method of converting the quantized weights
round_type
(str, optional): The method of converting the quantized weights
value float->int. Currently supports ['round', 'adaround'] methods.
value float->int. Currently supports ['round', 'adaround'] methods.
Default is `round`, which is rounding nearest to the integer.
Default is `round`, which is rounding nearest to the integer.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
'adaround' is refer to https://arxiv.org/abs/2004.10568.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
since weights are fixed once the model is well trained.
...
@@ -991,7 +972,6 @@ class QuantizationFreezePass(object):
...
@@ -991,7 +972,6 @@ class QuantizationFreezePass(object):
self
.
_place
=
_get_paddle_place
(
place
)
self
.
_place
=
_get_paddle_place
(
place
)
self
.
_weight_bits
=
weight_bits
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
self
.
_activation_bits
=
activation_bits
self
.
_weight_round_algo
=
weight_round_algo
self
.
_round_type
=
round_type
self
.
_round_type
=
round_type
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_weight_quantize_type
=
weight_quantize_type
self
.
_fake_quant_op_names
=
_fake_quant_op_list
self
.
_fake_quant_op_names
=
_fake_quant_op_list
...
@@ -1039,7 +1019,7 @@ class QuantizationFreezePass(object):
...
@@ -1039,7 +1019,7 @@ class QuantizationFreezePass(object):
scale_v
=
scale_v
.
tolist
()
scale_v
=
scale_v
.
tolist
()
self
.
_quant_var_scale_map
[
input_arg_name
]
=
scale_v
self
.
_quant_var_scale_map
[
input_arg_name
]
=
scale_v
# Quantize weight and restore
# Quantize weight and restore
if
self
.
_
weight_round_algo
==
'round'
:
if
self
.
_
round_type
==
'round'
:
param_v
=
self
.
_load_var
(
input_arg_name
)
param_v
=
self
.
_load_var
(
input_arg_name
)
if
any
(
if
any
(
_check_grandchild_op_node
(
op_node
,
op
)
_check_grandchild_op_node
(
op_node
,
op
)
...
@@ -1049,7 +1029,8 @@ class QuantizationFreezePass(object):
...
@@ -1049,7 +1029,8 @@ class QuantizationFreezePass(object):
quant_axis
=
0
quant_axis
=
0
quantized_param_v
=
utils
.
quant_tensor
(
quantized_param_v
=
utils
.
quant_tensor
(
param_v
.
copy
(),
scale_v
,
quant_axis
,
param_v
.
copy
(),
scale_v
,
quant_axis
,
self
.
_weight_bits
,
self
.
_round_type
)
self
.
_weight_bits
)
quantized_param_v
=
np
.
round
(
quantized_param_v
)
# Weight bias correction
# Weight bias correction
if
self
.
_bias_correction
==
True
:
if
self
.
_bias_correction
==
True
:
quantized_param_v
=
utils
.
bias_correction_w
(
quantized_param_v
=
utils
.
bias_correction_w
(
...
@@ -1058,6 +1039,7 @@ class QuantizationFreezePass(object):
...
@@ -1058,6 +1039,7 @@ class QuantizationFreezePass(object):
scale_v
,
scale_v
,
quant_axis
,
quant_axis
,
weight_bits
=
self
.
_weight_bits
)
weight_bits
=
self
.
_weight_bits
)
quantized_param_v
=
np
.
round
(
quantized_param_v
)
self
.
_restore_var
(
input_arg_name
,
quantized_param_v
)
self
.
_restore_var
(
input_arg_name
,
quantized_param_v
)
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
self
.
_remove_fake_quant_and_dequant_op
(
graph
,
op_node
)
...
@@ -1600,8 +1582,7 @@ class AddQuantDequantPass(object):
...
@@ -1600,8 +1582,7 @@ class AddQuantDequantPass(object):
quant_bits
=
8
,
quant_bits
=
8
,
skip_pattern
=
[
"skip_quant"
],
skip_pattern
=
[
"skip_quant"
],
quantizable_op_type
=
[
"elementwise_add"
,
"pool2d"
],
quantizable_op_type
=
[
"elementwise_add"
,
"pool2d"
],
is_full_quantized
=
False
,
is_full_quantized
=
False
):
round_type
=
'TiesToEven'
):
"""
"""
Constructor.
Constructor.
...
@@ -1623,10 +1604,6 @@ class AddQuantDequantPass(object):
...
@@ -1623,10 +1604,6 @@ class AddQuantDequantPass(object):
quantization to all supported quantizable op type. If set is_full_quantized
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
as False, only apply quantization to the op type according to the input
quantizable_op_type.
quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
"""
"""
self
.
_scope
=
scope
self
.
_scope
=
scope
self
.
_place
=
_get_paddle_place
(
place
)
self
.
_place
=
_get_paddle_place
(
place
)
...
@@ -1634,7 +1611,6 @@ class AddQuantDequantPass(object):
...
@@ -1634,7 +1611,6 @@ class AddQuantDequantPass(object):
self
.
_quant_bits
=
quant_bits
self
.
_quant_bits
=
quant_bits
self
.
_is_test
=
None
self
.
_is_test
=
None
self
.
_skip_pattern
=
skip_pattern
self
.
_skip_pattern
=
skip_pattern
self
.
_round_type
=
round_type
if
is_full_quantized
:
if
is_full_quantized
:
self
.
_quantizable_op_type
=
utils
.
_act_supported_quantizable_op_type
self
.
_quantizable_op_type
=
utils
.
_act_supported_quantizable_op_type
...
@@ -1769,10 +1745,8 @@ class AddQuantDequantPass(object):
...
@@ -1769,10 +1745,8 @@ class AddQuantDequantPass(object):
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutState'
]
=
state_out_node
outs
[
'OutAccum'
]
=
accum_out_node
outs
[
'OutAccum'
]
=
accum_out_node
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
attrs
=
{
attrs
=
{
'bit_length'
:
quant_bits
,
'bit_length'
:
quant_bits
,
'round_type'
:
round_type
,
'moving_rate'
:
self
.
_moving_rate
,
'moving_rate'
:
self
.
_moving_rate
,
'is_test'
:
self
.
_is_test
,
'is_test'
:
self
.
_is_test
,
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
'op_role'
:
core
.
op_proto_and_checker_maker
.
OpRole
.
Forward
...
@@ -1812,10 +1786,6 @@ class InsertQuantizeLinear(object):
...
@@ -1812,10 +1786,6 @@ class InsertQuantizeLinear(object):
Default is -1.
Default is -1.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
channel_wise(bool, optional): Whether quantization with per channel or not. Default is False.
is_test(bool, optional): Whether quantization with training or not. Default is True.
is_test(bool, optional): Whether quantization with training or not. Default is True.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -1824,15 +1794,13 @@ class InsertQuantizeLinear(object):
...
@@ -1824,15 +1794,13 @@ class InsertQuantizeLinear(object):
quant_bits
=
8
,
quant_bits
=
8
,
quant_axis
=-
1
,
quant_axis
=-
1
,
channel_wise
=
False
,
channel_wise
=
False
,
is_test
=
True
,
is_test
=
True
):
round_type
=
'TiesToEven'
):
self
.
_place
=
place
self
.
_place
=
place
self
.
_scope
=
scope
self
.
_scope
=
scope
self
.
quant_bits
=
quant_bits
self
.
quant_bits
=
quant_bits
self
.
quant_axis
=
quant_axis
self
.
quant_axis
=
quant_axis
self
.
channel_wise
=
channel_wise
self
.
channel_wise
=
channel_wise
self
.
_is_test
=
is_test
self
.
_is_test
=
is_test
self
.
_round_type
=
round_type
def
insert_quant_op
(
self
,
graph
,
var_node
):
def
insert_quant_op
(
self
,
graph
,
var_node
):
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
assert
var_node
.
is_var
(),
'{} is not a var'
.
format
(
var_node
.
name
())
...
@@ -1875,12 +1843,7 @@ class InsertQuantizeLinear(object):
...
@@ -1875,12 +1843,7 @@ class InsertQuantizeLinear(object):
if
zero_point_node
is
not
None
:
if
zero_point_node
is
not
None
:
inputs
[
"ZeroPoint"
]
=
zero_point_node
inputs
[
"ZeroPoint"
]
=
zero_point_node
round_type
=
0
if
self
.
_round_type
==
'TiesToEven'
else
1
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
}
attrs
=
{
"quant_axis"
:
self
.
quant_axis
,
"bit_length"
:
self
.
quant_bits
,
"round_type"
:
round_type
}
outputs
=
{
"Y"
:
quant_var_node
}
outputs
=
{
"Y"
:
quant_var_node
}
if
not
self
.
_is_test
:
if
not
self
.
_is_test
:
attrs
[
"is_test"
]
=
self
.
_is_test
attrs
[
"is_test"
]
=
self
.
_is_test
...
@@ -1985,7 +1948,6 @@ class QuantizationTransformPassV2(object):
...
@@ -1985,7 +1948,6 @@ class QuantizationTransformPassV2(object):
moving_rate
=
0.9
,
moving_rate
=
0.9
,
skip_pattern
=
[
'skip_quant'
],
skip_pattern
=
[
'skip_quant'
],
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
],
quantizable_op_type
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
],
round_type
=
'TiesToEven'
,
weight_quantize_func
=
None
,
weight_quantize_func
=
None
,
act_quantize_func
=
None
,
act_quantize_func
=
None
,
weight_preprocess_func
=
None
,
weight_preprocess_func
=
None
,
...
@@ -2021,10 +1983,6 @@ class QuantizationTransformPassV2(object):
...
@@ -2021,10 +1983,6 @@ class QuantizationTransformPassV2(object):
quantizable_op_type(list[str]): List the type of ops that will be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
weight_quantize_func(function): Function that defines how to quantize weight.
weight_quantize_func(function): Function that defines how to quantize weight.
Using this can quickly test if user's quantization method works or not.
Using this can quickly test if user's quantization method works or not.
In this function, user should both define quantization function and
In this function, user should both define quantization function and
...
@@ -2074,7 +2032,6 @@ class QuantizationTransformPassV2(object):
...
@@ -2074,7 +2032,6 @@ class QuantizationTransformPassV2(object):
self
.
_weight_bits
=
weight_bits
self
.
_weight_bits
=
weight_bits
self
.
_activation_bits
=
activation_bits
self
.
_activation_bits
=
activation_bits
self
.
_skip_pattern
=
skip_pattern
self
.
_skip_pattern
=
skip_pattern
self
.
_round_type
=
round_type
self
.
_weight_quantize_func
=
weight_quantize_func
self
.
_weight_quantize_func
=
weight_quantize_func
self
.
_act_quantize_func
=
act_quantize_func
self
.
_act_quantize_func
=
act_quantize_func
self
.
_weight_preprocess_func
=
weight_preprocess_func
self
.
_weight_preprocess_func
=
weight_preprocess_func
...
@@ -2198,8 +2155,7 @@ class QuantizationTransformPassV2(object):
...
@@ -2198,8 +2155,7 @@ class QuantizationTransformPassV2(object):
quant_bits
=
quant_bits
,
quant_bits
=
quant_bits
,
quant_axis
=
quant_axis
,
quant_axis
=
quant_axis
,
channel_wise
=
channel_wise
,
channel_wise
=
channel_wise
,
is_test
=
self
.
_is_test
,
is_test
=
self
.
_is_test
)
round_type
=
self
.
_round_type
)
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
graph
,
var_node
)
graph
,
var_node
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
...
@@ -2307,8 +2263,7 @@ class AddQuantDequantPassV2(object):
...
@@ -2307,8 +2263,7 @@ class AddQuantDequantPassV2(object):
quant_bits
=
8
,
quant_bits
=
8
,
skip_pattern
=
[
"skip_quant"
],
skip_pattern
=
[
"skip_quant"
],
quantizable_op_type
=
[
"elementwise_add"
,
"pool2d"
],
quantizable_op_type
=
[
"elementwise_add"
,
"pool2d"
],
is_full_quantized
=
False
,
is_full_quantized
=
False
):
round_type
=
'TiesToEven'
):
"""
"""
Args:
Args:
scope(paddle.Scope): The scope is used to initialize these new parameters.
scope(paddle.Scope): The scope is used to initialize these new parameters.
...
@@ -2328,10 +2283,6 @@ class AddQuantDequantPassV2(object):
...
@@ -2328,10 +2283,6 @@ class AddQuantDequantPassV2(object):
quantization to all supported quantizable op type. If set is_full_quantized
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
as False, only apply quantization to the op type according to the input
quantizable_op_type.
quantizable_op_type.
round_type(str, optional): The method of converting the tensor value float->int.
Currently supports ['TiesToEven', 'TiesAwayFromZero'] methods.
Default is `TiesToEven`, which is rounding to nearest ties to even.
'TiesAwayFromZero' is rounding to nearest ties away from zero.
Examples:
Examples:
.. code-block:: python
.. code-block:: python
...
@@ -2354,7 +2305,6 @@ class AddQuantDequantPassV2(object):
...
@@ -2354,7 +2305,6 @@ class AddQuantDequantPassV2(object):
self
.
_quant_bits
=
quant_bits
self
.
_quant_bits
=
quant_bits
self
.
_is_test
=
None
self
.
_is_test
=
None
self
.
_skip_pattern
=
skip_pattern
self
.
_skip_pattern
=
skip_pattern
self
.
_round_type
=
round_type
if
is_full_quantized
:
if
is_full_quantized
:
self
.
_quantizable_op_type
=
utils
.
_act_supported_quantizable_op_type
self
.
_quantizable_op_type
=
utils
.
_act_supported_quantizable_op_type
...
@@ -2427,8 +2377,7 @@ class AddQuantDequantPassV2(object):
...
@@ -2427,8 +2377,7 @@ class AddQuantDequantPassV2(object):
quant_bits
=
self
.
_quant_bits
,
quant_bits
=
self
.
_quant_bits
,
quant_axis
=-
1
,
quant_axis
=-
1
,
channel_wise
=
False
,
channel_wise
=
False
,
is_test
=
self
.
_is_test
,
is_test
=
self
.
_is_test
)
round_type
=
self
.
_round_type
)
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
quant_var_node
,
scale_var_node
=
insert_quant_pass
.
insert_quant_op
(
graph
,
in_node
)
graph
,
in_node
)
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
dequant_var_node
=
insert_quant_pass
.
insert_dequant_op
(
...
@@ -2511,8 +2460,6 @@ class ReplaceFakeQuantDequantPass(object):
...
@@ -2511,8 +2460,6 @@ class ReplaceFakeQuantDequantPass(object):
"quant_axis"
)
else
-
1
"quant_axis"
)
else
-
1
bit_length
=
op
.
op
().
attr
(
"bit_length"
)
if
op
.
op
().
has_attr
(
bit_length
=
op
.
op
().
attr
(
"bit_length"
)
if
op
.
op
().
has_attr
(
"bit_length"
)
else
8
"bit_length"
)
else
8
round_type
=
op
.
op
().
attr
(
"round_type"
)
if
op
.
op
().
has_attr
(
"round_type"
)
else
0
zero_point_node
=
None
zero_point_node
=
None
quanted_node
=
x_node
quanted_node
=
x_node
...
@@ -2534,8 +2481,7 @@ class ReplaceFakeQuantDequantPass(object):
...
@@ -2534,8 +2481,7 @@ class ReplaceFakeQuantDequantPass(object):
quant_op_node
=
graph
.
create_op_node
(
op_type
=
"quantize_linear"
,
quant_op_node
=
graph
.
create_op_node
(
op_type
=
"quantize_linear"
,
attrs
=
{
attrs
=
{
"quant_axis"
:
quant_axis
,
"quant_axis"
:
quant_axis
,
"bit_length"
:
bit_length
,
"bit_length"
:
bit_length
"round_type"
:
round_type
},
},
inputs
=
{
inputs
=
{
"X"
:
x_node
,
"X"
:
x_node
,
...
@@ -2654,11 +2600,11 @@ class QuantWeightPass(object):
...
@@ -2654,11 +2600,11 @@ class QuantWeightPass(object):
param_v
=
self
.
_load_var
(
x_node
.
name
())
param_v
=
self
.
_load_var
(
x_node
.
name
())
quant_axis
=
_op
.
op
().
attr
(
"quant_axis"
)
quant_axis
=
_op
.
op
().
attr
(
"quant_axis"
)
bits_length
=
_op
.
op
().
attr
(
"bit_length"
)
bits_length
=
_op
.
op
().
attr
(
"bit_length"
)
round_type
=
_op
.
op
().
attr
(
"round_type"
)
if
_op
.
op
().
has_attr
(
quantized_param_v
=
utils
.
quant_tensor
(
param_v
.
copy
(),
"round_type"
)
else
0
scale_v
,
quantized_param_v
=
utils
.
quant_tensor
(
param_v
.
copy
(),
scale_v
,
quant_axis
,
quant_axis
,
bits_length
,
bits_length
,
round_typ
e
)
onnx_format
=
Tru
e
)
if
self
.
_bias_correction
==
True
:
if
self
.
_bias_correction
==
True
:
quantized_param_v
=
utils
.
bias_correction_w
(
quantized_param_v
=
utils
.
bias_correction_w
(
param_v
,
param_v
,
...
...
python/paddle/fluid/contrib/slim/quantization/utils.py
浏览文件 @
491b87b4
...
@@ -321,39 +321,41 @@ def set_variable_data(scope, place, var_name, np_value):
...
@@ -321,39 +321,41 @@ def set_variable_data(scope, place, var_name, np_value):
tensor
.
set
(
np_value
,
place
)
tensor
.
set
(
np_value
,
place
)
def
round_c_single_element
(
val
):
def
quant_tensor
(
x
,
scale
,
quant_axis
=
0
,
weight_bits
=
8
,
onnx_format
=
False
):
dtype
=
type
(
val
)
# symmetry quant
if
val
>=
0
:
def
_clip
(
x
,
scale
):
return
dtype
(
np
.
floor
(
val
+
0.5
))
x
[
x
>
scale
]
=
scale
return
dtype
(
np
.
ceil
(
val
-
0.5
))
x
[
x
<
-
scale
]
=
-
scale
return
x
# rounding to nearest ties away from zero
round_c
=
np
.
vectorize
(
round_c_single_element
)
def
quant_tensor
(
x
,
scale
,
quant_axis
=
0
,
weight_bits
=
8
,
round_type
=
'TiesToEven'
):
assert
quant_axis
in
[
0
,
1
],
'quant_axis should be 0 or 1 for now.'
assert
quant_axis
in
[
0
,
1
],
'quant_axis should be 0 or 1 for now.'
distribution
=
np
.
round
if
round_type
==
'TiesToEven'
else
round_c
bnt
=
(
1
<<
(
weight_bits
-
1
))
-
1
bnt
=
(
1
<<
(
weight_bits
-
1
))
-
1
if
isinstance
(
scale
,
list
):
if
isinstance
(
scale
,
list
):
for
i
,
s
in
enumerate
(
scale
):
for
i
,
s
in
enumerate
(
scale
):
if
s
==
0.0
:
if
s
==
0.0
:
s
=
1e-8
s
=
1e-8
if
quant_axis
==
0
:
if
quant_axis
==
0
:
x
[
i
]
=
distribution
(
x
[
i
]
/
s
*
bnt
)
if
onnx_format
:
x
[
i
]
=
np
.
clip
(
x
[
i
],
-
bnt
-
1
,
bnt
)
x
[
i
]
=
np
.
round
(
x
[
i
]
/
s
*
bnt
)
x
[
i
]
=
np
.
clip
(
x
[
i
],
-
bnt
-
1
,
bnt
)
else
:
x
[
i
]
=
_clip
(
x
[
i
],
s
)
x
[
i
]
=
x
[
i
]
/
s
*
bnt
else
:
else
:
x
[:,
i
]
=
distribution
(
x
[:,
i
]
/
s
*
bnt
)
if
onnx_format
:
x
[:,
i
]
=
np
.
clip
(
x
[:,
i
],
-
bnt
-
1
,
bnt
)
x
[:,
i
]
=
np
.
round
(
x
[:,
i
]
/
s
*
bnt
)
x
[:,
i
]
=
np
.
clip
(
x
[:,
i
],
-
bnt
-
1
,
bnt
)
else
:
x
[:,
i
]
=
_clip
(
x
[:,
i
],
s
)
x
[:,
i
]
=
x
[:,
i
]
/
s
*
bnt
else
:
else
:
scale
=
1e-8
if
scale
==
0.0
else
scale
scale
=
1e-8
if
scale
==
0.0
else
scale
x
=
distribution
(
x
/
scale
*
bnt
)
if
onnx_format
:
x
=
np
.
clip
(
x
,
-
bnt
-
1
,
bnt
)
x
=
np
.
round
(
x
/
scale
*
bnt
)
x
=
np
.
clip
(
x
,
-
bnt
-
1
,
bnt
)
else
:
x
=
_clip
(
x
,
scale
)
x
=
x
/
scale
*
bnt
return
x
return
x
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
浏览文件 @
491b87b4
...
@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -165,7 +165,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path
,
model_path
,
data_path
,
data_path
,
algo
=
"KL"
,
algo
=
"KL"
,
weight_round_algo
=
"round"
,
round_type
=
"round"
,
quantizable_op_type
=
[
"conv2d"
],
quantizable_op_type
=
[
"conv2d"
],
is_full_quantize
=
False
,
is_full_quantize
=
False
,
is_use_cache_file
=
False
,
is_use_cache_file
=
False
,
...
@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -185,7 +185,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums
=
batch_nums
,
batch_nums
=
batch_nums
,
algo
=
algo
,
algo
=
algo
,
quantizable_op_type
=
quantizable_op_type
,
quantizable_op_type
=
quantizable_op_type
,
weight_round_algo
=
weight_round_algo
,
round_type
=
round_type
,
is_full_quantize
=
is_full_quantize
,
is_full_quantize
=
is_full_quantize
,
optimize_model
=
is_optimize_model
,
optimize_model
=
is_optimize_model
,
onnx_format
=
onnx_format
,
onnx_format
=
onnx_format
,
...
@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -201,7 +201,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -224,7 +224,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
print
(
"Start post training quantization for {0} on {1} samples ..."
.
print
(
"Start post training quantization for {0} on {1} samples ..."
.
format
(
model_name
,
quant_iterations
))
format
(
model_name
,
quant_iterations
))
self
.
generate_quantized_model
(
fp32_model_path
,
data_path
,
algo
,
self
.
generate_quantized_model
(
fp32_model_path
,
data_path
,
algo
,
weight_round_algo
,
quantizable_op_type
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
quant_iterations
,
is_optimize_model
,
quant_iterations
,
onnx_format
)
onnx_format
)
...
@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
...
@@ -255,7 +255,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
data_url
=
"https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_url
=
"https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5
=
"add84c754e9b792fea1fbd728d134ab7"
data_md5
=
"add84c754e9b792fea1fbd728d134ab7"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"mul"
,
"lstm"
]
quantizable_op_type
=
[
"mul"
,
"lstm"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
...
@@ -264,7 +264,7 @@ class TestPostTrainingAvgForLSTM(TestPostTrainingQuantization):
infer_iterations
=
100
infer_iterations
=
100
quant_iterations
=
10
quant_iterations
=
10
self
.
run_test
(
model_name
,
model_url
,
model_md5
,
data_name
,
data_url
,
self
.
run_test
(
model_name
,
model_url
,
model_md5
,
data_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
quantizable_op_type
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
infer_iterations
,
quant_iterations
)
diff_threshold
,
infer_iterations
,
quant_iterations
)
...
@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
...
@@ -279,7 +279,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url
=
"https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_url
=
"https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
data_md5
=
"add84c754e9b792fea1fbd728d134ab7"
data_md5
=
"add84c754e9b792fea1fbd728d134ab7"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"mul"
,
"lstm"
]
quantizable_op_type
=
[
"mul"
,
"lstm"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
...
@@ -295,7 +295,7 @@ class TestPostTrainingAvgForLSTMONNXFormat(TestPostTrainingQuantization):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
浏览文件 @
491b87b4
...
@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -108,7 +108,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def
generate_quantized_model
(
self
,
def
generate_quantized_model
(
self
,
model_path
,
model_path
,
algo
=
"KL"
,
algo
=
"KL"
,
weight_round_algo
=
"round"
,
round_type
=
"round"
,
quantizable_op_type
=
[
"conv2d"
],
quantizable_op_type
=
[
"conv2d"
],
is_full_quantize
=
False
,
is_full_quantize
=
False
,
is_use_cache_file
=
False
,
is_use_cache_file
=
False
,
...
@@ -130,7 +130,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -130,7 +130,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
batch_nums
=
batch_nums
,
batch_nums
=
batch_nums
,
algo
=
algo
,
algo
=
algo
,
quantizable_op_type
=
quantizable_op_type
,
quantizable_op_type
=
quantizable_op_type
,
weight_round_algo
=
weight_round_algo
,
round_type
=
round_type
,
is_full_quantize
=
is_full_quantize
,
is_full_quantize
=
is_full_quantize
,
optimize_model
=
is_optimize_model
,
optimize_model
=
is_optimize_model
,
bias_correction
=
bias_correction
,
bias_correction
=
bias_correction
,
...
@@ -145,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -145,7 +145,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
@@ -169,11 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -169,11 +169,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
print
(
"Start INT8 post training quantization for {0} on {1} images ..."
.
print
(
"Start INT8 post training quantization for {0} on {1} images ..."
.
format
(
model_name
,
quant_iterations
*
batch_size
))
format
(
model_name
,
quant_iterations
*
batch_size
))
self
.
generate_quantized_model
(
origin_model_path
,
algo
,
self
.
generate_quantized_model
(
origin_model_path
,
algo
,
round_type
,
weight_round_algo
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
is_optimize_model
,
is_optimize_model
,
batch_size
,
batch_size
,
quant_iterations
,
onnx_format
,
quant_iterations
,
onnx_format
,
skip_tensor_list
,
bias_correction
)
skip_tensor_list
,
bias_correction
)
print
(
"Start INT8 inference for {0} on {1} images ..."
.
format
(
print
(
"Start INT8 inference for {0} on {1} images ..."
.
format
(
...
@@ -204,7 +203,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
...
@@ -204,7 +203,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"KL"
algo
=
"KL"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -213,7 +212,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
...
@@ -213,7 +212,7 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -226,7 +225,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
...
@@ -226,7 +225,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"hist"
algo
=
"hist"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -235,7 +234,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
...
@@ -235,7 +234,7 @@ class TestPostTraininghistForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -248,7 +247,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
...
@@ -248,7 +247,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"mse"
algo
=
"mse"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -257,7 +256,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
...
@@ -257,7 +256,7 @@ class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -270,7 +269,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
...
@@ -270,7 +269,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"emd"
algo
=
"emd"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -279,7 +278,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
...
@@ -279,7 +278,7 @@ class TestPostTrainingemdForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -292,7 +291,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
...
@@ -292,7 +291,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -301,7 +300,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
...
@@ -301,7 +300,7 @@ class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -314,7 +313,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
...
@@ -314,7 +313,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"abs_max"
algo
=
"abs_max"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"mul"
]
is_full_quantize
=
True
is_full_quantize
=
True
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -323,7 +322,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
...
@@ -323,7 +322,7 @@ class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
10
quant_iterations
=
10
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -336,7 +335,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
...
@@ -336,7 +335,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"mse"
algo
=
"mse"
weight_round_algo
=
"adaround"
round_type
=
"adaround"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -350,7 +349,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
...
@@ -350,7 +349,7 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
@@ -369,7 +368,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
...
@@ -369,7 +368,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"KL"
algo
=
"KL"
weight_round_algo
=
"adaround"
round_type
=
"adaround"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -378,7 +377,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
...
@@ -378,7 +377,7 @@ class TestPostTrainingKLAdaroundForMnist(TestPostTrainingQuantization):
batch_size
=
10
batch_size
=
10
infer_iterations
=
50
infer_iterations
=
50
quant_iterations
=
5
quant_iterations
=
5
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
weight_round_algo
,
self
.
run_test
(
model_name
,
data_url
,
data_md5
,
algo
,
round_type
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
,
batch_size
,
is_optimize_model
,
diff_threshold
,
batch_size
,
infer_iterations
,
quant_iterations
)
infer_iterations
,
quant_iterations
)
...
@@ -391,7 +390,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
...
@@ -391,7 +390,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"mse"
algo
=
"mse"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -405,7 +404,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
...
@@ -405,7 +404,7 @@ class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
@@ -425,7 +424,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
...
@@ -425,7 +424,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"mse"
algo
=
"mse"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
True
is_full_quantize
=
True
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -439,7 +438,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
...
@@ -439,7 +438,7 @@ class TestPostTrainingmseForMnistONNXFormatFullQuant(
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
@@ -458,7 +457,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
...
@@ -458,7 +457,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_url
=
"http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
data_md5
=
"be71d3997ec35ac2a65ae8a145e2887c"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
quantizable_op_type
=
[
"conv2d"
,
"depthwise_conv2d"
,
"mul"
]
is_full_quantize
=
False
is_full_quantize
=
False
is_use_cache_file
=
False
is_use_cache_file
=
False
...
@@ -472,7 +471,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
...
@@ -472,7 +471,7 @@ class TestPostTrainingavgForMnistSkipOP(TestPostTrainingQuantization):
data_url
,
data_url
,
data_md5
,
data_md5
,
algo
,
algo
,
weight_round_algo
,
round_type
,
quantizable_op_type
,
quantizable_op_type
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
浏览文件 @
491b87b4
...
@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -242,7 +242,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path
,
model_path
,
quantizable_op_type
,
quantizable_op_type
,
algo
=
"KL"
,
algo
=
"KL"
,
weight_round_algo
=
"round"
,
round_type
=
"round"
,
is_full_quantize
=
False
,
is_full_quantize
=
False
,
is_use_cache_file
=
False
,
is_use_cache_file
=
False
,
is_optimize_model
=
False
,
is_optimize_model
=
False
,
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_dir
=
model_path
,
model_dir
=
model_path
,
algo
=
algo
,
algo
=
algo
,
quantizable_op_type
=
quantizable_op_type
,
quantizable_op_type
=
quantizable_op_type
,
weight_round_algo
=
weight_round_algo
,
round_type
=
round_type
,
is_full_quantize
=
is_full_quantize
,
is_full_quantize
=
is_full_quantize
,
optimize_model
=
is_optimize_model
,
optimize_model
=
is_optimize_model
,
onnx_format
=
onnx_format
,
onnx_format
=
onnx_format
,
...
@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -275,7 +275,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
def
run_test
(
self
,
def
run_test
(
self
,
model
,
model
,
algo
,
algo
,
weight_round_algo
,
round_type
,
data_urls
,
data_urls
,
data_md5s
,
data_md5s
,
quantizable_op_type
,
quantizable_op_type
,
...
@@ -299,10 +299,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
...
@@ -299,10 +299,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
print
(
"Start INT8 post training quantization for {0} on {1} images ..."
.
print
(
"Start INT8 post training quantization for {0} on {1} images ..."
.
format
(
model
,
sample_iterations
*
batch_size
))
format
(
model
,
sample_iterations
*
batch_size
))
self
.
generate_quantized_model
(
model_cache_folder
+
"/model"
,
self
.
generate_quantized_model
(
model_cache_folder
+
"/model"
,
quantizable_op_type
,
algo
,
quantizable_op_type
,
algo
,
round_type
,
weight_round_algo
,
is_full_quantize
,
is_full_quantize
,
is_use_cache_file
,
is_use_cache_file
,
is_optimize_model
,
is_optimize_model
,
onnx_format
)
onnx_format
)
print
(
"Start INT8 inference for {0} on {1} images ..."
.
format
(
print
(
"Start INT8 inference for {0} on {1} images ..."
.
format
(
model
,
infer_iterations
*
batch_size
))
model
,
infer_iterations
*
batch_size
))
...
@@ -330,7 +329,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
...
@@ -330,7 +329,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def
test_post_training_kl_mobilenetv1
(
self
):
def
test_post_training_kl_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"KL"
algo
=
"KL"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
]
...
@@ -345,7 +344,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
...
@@ -345,7 +344,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file
=
False
is_use_cache_file
=
False
is_optimize_model
=
True
is_optimize_model
=
True
diff_threshold
=
0.025
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
weight_round_algo
,
data_urls
,
data_md5s
,
self
.
run_test
(
model
,
algo
,
round_type
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
)
is_optimize_model
,
diff_threshold
)
...
@@ -355,7 +354,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
...
@@ -355,7 +354,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
def
test_post_training_avg_mobilenetv1
(
self
):
def
test_post_training_avg_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
]
...
@@ -369,7 +368,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
...
@@ -369,7 +368,7 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file
=
False
is_use_cache_file
=
False
is_optimize_model
=
True
is_optimize_model
=
True
diff_threshold
=
0.025
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
weight_round_algo
,
data_urls
,
data_md5s
,
self
.
run_test
(
model
,
algo
,
round_type
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
)
is_optimize_model
,
diff_threshold
)
...
@@ -379,7 +378,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
...
@@ -379,7 +378,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
def
test_post_training_hist_mobilenetv1
(
self
):
def
test_post_training_hist_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"hist"
algo
=
"hist"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
]
...
@@ -393,7 +392,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
...
@@ -393,7 +392,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file
=
False
is_use_cache_file
=
False
is_optimize_model
=
True
is_optimize_model
=
True
diff_threshold
=
0.03
diff_threshold
=
0.03
self
.
run_test
(
model
,
algo
,
weight_round_algo
,
data_urls
,
data_md5s
,
self
.
run_test
(
model
,
algo
,
round_type
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
)
is_optimize_model
,
diff_threshold
)
...
@@ -403,7 +402,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
...
@@ -403,7 +402,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def
test_post_training_abs_max_mobilenetv1
(
self
):
def
test_post_training_abs_max_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"abs_max"
algo
=
"abs_max"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
]
...
@@ -417,7 +416,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
...
@@ -417,7 +416,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model
=
False
is_optimize_model
=
False
# The accuracy diff of post-training quantization (abs_max) maybe bigger
# The accuracy diff of post-training quantization (abs_max) maybe bigger
diff_threshold
=
0.05
diff_threshold
=
0.05
self
.
run_test
(
model
,
algo
,
weight_round_algo
,
data_urls
,
data_md5s
,
self
.
run_test
(
model
,
algo
,
round_type
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
)
is_optimize_model
,
diff_threshold
)
...
@@ -427,7 +426,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
...
@@ -427,7 +426,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
def
test_post_training_onnx_format_mobilenetv1
(
self
):
def
test_post_training_onnx_format_mobilenetv1
(
self
):
model
=
"MobileNet-V1"
model
=
"MobileNet-V1"
algo
=
"avg"
algo
=
"avg"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
]
...
@@ -444,7 +443,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
...
@@ -444,7 +443,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
diff_threshold
=
0.05
diff_threshold
=
0.05
self
.
run_test
(
model
,
self
.
run_test
(
model
,
algo
,
algo
,
weight_round_algo
,
round_type
,
data_urls
,
data_urls
,
data_md5s
,
data_md5s
,
quantizable_op_type
,
quantizable_op_type
,
...
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
浏览文件 @
491b87b4
...
@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
...
@@ -25,7 +25,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def
test_post_training_resnet50
(
self
):
def
test_post_training_resnet50
(
self
):
model
=
"ResNet-50"
model
=
"ResNet-50"
algo
=
"min_max"
algo
=
"min_max"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
]
...
@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
...
@@ -35,7 +35,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
is_use_cache_file
=
False
is_use_cache_file
=
False
is_optimize_model
=
False
is_optimize_model
=
False
diff_threshold
=
0.025
diff_threshold
=
0.025
self
.
run_test
(
model
,
algo
,
weight_round_algo
,
data_urls
,
data_md5s
,
self
.
run_test
(
model
,
algo
,
round_type
,
data_urls
,
data_md5s
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
quantizable_op_type
,
is_full_quantize
,
is_use_cache_file
,
is_optimize_model
,
diff_threshold
)
is_optimize_model
,
diff_threshold
)
...
@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
...
@@ -45,7 +45,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
def
test_post_training_resnet50
(
self
):
def
test_post_training_resnet50
(
self
):
model
=
"ResNet-50"
model
=
"ResNet-50"
algo
=
"min_max"
algo
=
"min_max"
weight_round_algo
=
"round"
round_type
=
"round"
data_urls
=
[
data_urls
=
[
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
]
...
@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
...
@@ -58,7 +58,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
onnx_format
=
True
onnx_format
=
True
self
.
run_test
(
model
,
self
.
run_test
(
model
,
algo
,
algo
,
weight_round_algo
,
round_type
,
data_urls
,
data_urls
,
data_md5s
,
data_md5s
,
quantizable_op_type
,
quantizable_op_type
,
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
491b87b4
...
@@ -49,7 +49,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
...
@@ -49,7 +49,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
dtype
,
dtype
,
input_shape
,
input_shape
,
distribution
,
distribution
,
round_type
=
'Ties
ToEven
'
):
round_type
=
'Ties
AwayFromZero
'
):
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
compute_type
=
get_compute_type
(
dtype
)
compute_type
=
get_compute_type
(
dtype
)
scale
=
np
.
max
(
np
.
abs
(
input_data
))
scale
=
np
.
max
(
np
.
abs
(
input_data
))
...
@@ -58,12 +58,12 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
...
@@ -58,12 +58,12 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
if
round_type
==
'TiesToEven'
:
if
round_type
==
'TiesToEven'
:
round_out
=
np
.
round
(
round_out
=
np
.
round
(
input_data
.
astype
(
compute_type
)
*
inv_scale
*
bnt
)
input_data
.
astype
(
compute_type
)
*
inv_scale
*
bnt
)
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
else
:
else
:
round_out
=
round_c
(
output_data
=
round_c
(
input_data
.
astype
(
compute_type
)
*
inv_scale
*
bnt
)
input_data
.
astype
(
compute_type
)
*
inv_scale
*
bnt
)
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
self
.
inputs
=
{
'X'
:
input_data
}
self
.
inputs
=
{
'X'
:
input_data
}
self
.
outputs
=
{
'Out'
:
output_data
,
'OutScale'
:
scale
}
self
.
outputs
=
{
'Out'
:
output_data
,
'OutScale'
:
scale
}
self
.
dtype
=
dtype
self
.
dtype
=
dtype
...
@@ -75,7 +75,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
...
@@ -75,7 +75,7 @@ class TestFakeQuantizeAbsMaxOp(OpTest):
def
test_fake_quantize_abs_max_round1
(
self
):
def
test_fake_quantize_abs_max_round1
(
self
):
self
.
_fake_quantize_abs_max
(
np
.
float32
,
(
124
,
240
),
self
.
_fake_quantize_abs_max
(
np
.
float32
,
(
124
,
240
),
np
.
random
.
random
,
np
.
random
.
random
,
round_type
=
'Ties
AwayFromZero
'
)
round_type
=
'Ties
ToEven
'
)
def
test_fake_quantize_abs_max_float16
(
self
):
def
test_fake_quantize_abs_max_float16
(
self
):
self
.
_fake_quantize_abs_max
(
np
.
float16
,
(
124
,
240
),
np
.
random
.
random
)
self
.
_fake_quantize_abs_max
(
np
.
float16
,
(
124
,
240
),
np
.
random
.
random
)
...
@@ -110,12 +110,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
...
@@ -110,12 +110,12 @@ class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
if
round_type
==
'TiesToEven'
:
if
round_type
==
'TiesToEven'
:
round_out
=
np
.
round
(
round_out
=
np
.
round
(
input_data
.
astype
(
compute_type
)
/
scale_broadcast
*
bnt
)
input_data
.
astype
(
compute_type
)
/
scale_broadcast
*
bnt
)
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
else
:
else
:
round_out
=
round_c
(
output_data
=
round_c
(
bnt
*
input_data
.
astype
(
compute_type
)
/
input_data
.
astype
(
compute_type
)
/
scale_broadcast
*
bn
t
)
scale_broadcas
t
)
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
if
quant_axis
==
1
:
if
quant_axis
==
1
:
scale_broadcast
=
np
.
transpose
(
scale_broadcast
,
scale_broadcast
=
np
.
transpose
(
scale_broadcast
,
(
1
,
)
+
compute_axis
)
(
1
,
)
+
compute_axis
)
...
@@ -169,11 +169,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
...
@@ -169,11 +169,15 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
round_out
=
np
.
round
(
round_out
=
np
.
round
(
input_data
.
astype
(
compute_type
)
/
out_scale
[
0
]
*
bnt
)
input_data
.
astype
(
compute_type
)
/
out_scale
[
0
]
*
bnt
)
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
else
:
else
:
round_out
=
round_c
(
if
is_test
:
input_data
.
astype
(
compute_type
)
/
out_scale
[
0
]
*
bnt
)
clip_data
=
np
.
clip
(
input_data
,
-
in_scale
,
in_scale
)
else
:
clip_data
=
input_data
output_data
=
round_c
(
clip_data
.
astype
(
compute_type
)
/
out_scale
[
0
]
*
bnt
)
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
self
.
inputs
=
{
self
.
inputs
=
{
'X'
:
input_data
,
'X'
:
input_data
,
'Iter'
:
np
.
zeros
(
1
).
astype
(
np
.
int64
),
'Iter'
:
np
.
zeros
(
1
).
astype
(
np
.
int64
),
...
@@ -250,7 +254,7 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
...
@@ -250,7 +254,7 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
distribution
,
distribution
,
dequantize
=
False
,
dequantize
=
False
,
with_gradient
=
False
,
with_gradient
=
False
,
round_type
=
'Ties
ToEven
'
):
round_type
=
'Ties
AwayFromZero
'
):
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
compute_type
=
get_compute_type
(
dtype
)
compute_type
=
get_compute_type
(
dtype
)
bnt
=
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
bnt
=
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
...
@@ -267,12 +271,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
...
@@ -267,12 +271,12 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
if
round_type
==
'TiesToEven'
:
if
round_type
==
'TiesToEven'
:
round_out
=
np
.
round
(
round_out
=
np
.
round
(
input_data
.
astype
(
compute_type
)
/
out_scale
*
bnt
)
input_data
.
astype
(
compute_type
)
/
out_scale
*
bnt
)
quant_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
else
:
else
:
round_out
=
round_c
(
quant_data
=
round_c
(
input_data
.
astype
(
compute_type
)
/
out_scale
*
bnt
)
input_data
.
astype
(
compute_type
)
/
out_scale
*
bnt
)
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
quant_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
if
dequantize
:
if
dequantize
:
output_data
=
(
quant_data
*
out_scale
/
bnt
).
astype
(
dtype
)
output_data
=
(
quant_data
*
out_scale
/
bnt
).
astype
(
dtype
)
self
.
op_type
=
'fake_quantize_dequantize_moving_average_abs_max'
self
.
op_type
=
'fake_quantize_dequantize_moving_average_abs_max'
...
@@ -307,10 +311,9 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
...
@@ -307,10 +311,9 @@ class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
np
.
random
.
random
)
np
.
random
.
random
)
def
test_fake_quantize_moving_average_abs_max_round1
(
self
):
def
test_fake_quantize_moving_average_abs_max_round1
(
self
):
self
.
_fake_quantize_moving_average_abs_max
(
self
.
_fake_quantize_moving_average_abs_max
(
np
.
float32
,
(
8
,
16
,
7
,
7
),
np
.
float32
,
(
8
,
16
,
7
,
7
),
np
.
random
.
random
,
np
.
random
.
random
,
round_type
=
'TiesToEven'
)
round_type
=
'TiesAwayFromZero'
)
def
test_fake_quantize_dequantize_moving_average_abs_max
(
self
):
def
test_fake_quantize_dequantize_moving_average_abs_max
(
self
):
self
.
_fake_quantize_moving_average_abs_max
(
np
.
float32
,
(
8
,
16
,
7
,
7
),
self
.
_fake_quantize_moving_average_abs_max
(
np
.
float32
,
(
8
,
16
,
7
,
7
),
...
@@ -329,17 +332,17 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
...
@@ -329,17 +332,17 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
dtype
,
dtype
,
input_shape
,
input_shape
,
distribution
,
distribution
,
round_type
=
'Ties
ToEven
'
):
round_type
=
'Ties
AwayFromZero
'
):
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
input_data
=
distribution
(
input_shape
).
astype
(
dtype
)
scale
=
np
.
max
(
np
.
abs
(
input_data
)).
astype
(
dtype
)
scale
=
np
.
max
(
np
.
abs
(
input_data
)).
astype
(
dtype
)
bnt
=
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
bnt
=
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
if
round_type
==
'TiesToEven'
:
if
round_type
==
'TiesToEven'
:
round_out
=
np
.
round
(
input_data
/
scale
*
bnt
)
round_out
=
np
.
round
(
input_data
/
scale
*
bnt
)
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
*
scale
/
bnt
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
else
:
else
:
round_out
=
round_c
(
input_data
/
scale
*
bnt
)
output_data
=
round_c
(
input_data
/
scale
*
bnt
)
*
scale
/
bnt
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
*
scale
/
bnt
self
.
inputs
=
{
'X'
:
input_data
}
self
.
inputs
=
{
'X'
:
input_data
}
self
.
outputs
=
{
self
.
outputs
=
{
'Out'
:
output_data
,
'Out'
:
output_data
,
...
@@ -357,7 +360,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
...
@@ -357,7 +360,7 @@ class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
def
test_fake_quantize_dequantize_abs_max_round1
(
self
):
def
test_fake_quantize_dequantize_abs_max_round1
(
self
):
self
.
_fake_quantize_dequantize_abs_max
(
np
.
float32
,
(
124
,
240
),
self
.
_fake_quantize_dequantize_abs_max
(
np
.
float32
,
(
124
,
240
),
np
.
random
.
random
,
np
.
random
.
random
,
round_type
=
'Ties
AwayFromZero
'
)
round_type
=
'Ties
ToEven
'
)
class
TestChannelWiseFakeQuantizeDequantizeAbsMaxOp
(
OpTest
):
class
TestChannelWiseFakeQuantizeDequantizeAbsMaxOp
(
OpTest
):
...
@@ -382,11 +385,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
...
@@ -382,11 +385,13 @@ class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
scale_broadcast
=
np
.
amax
(
input_data
,
axis
=
compute_axis
,
keepdims
=
True
)
scale_broadcast
=
np
.
amax
(
input_data
,
axis
=
compute_axis
,
keepdims
=
True
)
if
round_type
==
'TiesToEven'
:
if
round_type
==
'TiesToEven'
:
round_out
=
np
.
round
(
bnt
*
output_data
/
scale_broadcast
)
round_out
=
np
.
round
(
bnt
*
output_data
/
scale_broadcast
)
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
*
scale_broadcast
/
bnt
self
.
attrs
[
'round_type'
]
=
0
self
.
attrs
[
'round_type'
]
=
0
else
:
else
:
round_out
=
round_c
(
bnt
*
output_data
/
scale_broadcast
)
output_data
=
round_c
(
bnt
*
output_data
/
scale_broadcast
)
*
scale_broadcast
/
bnt
self
.
attrs
[
'round_type'
]
=
1
self
.
attrs
[
'round_type'
]
=
1
output_data
=
np
.
clip
(
round_out
,
-
bnt
-
1
,
bnt
)
*
scale_broadcast
/
bnt
if
quant_axis
==
1
:
if
quant_axis
==
1
:
scale_broadcast
=
np
.
transpose
(
scale_broadcast
,
scale_broadcast
=
np
.
transpose
(
scale_broadcast
,
(
1
,
)
+
compute_axis
)
(
1
,
)
+
compute_axis
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录