Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f88af205
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f88af205
编写于
6月 21, 2021
作者:
C
cc
提交者:
GitHub
6月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Combine amp and qat (#33484)
* Combine amp and qat * add unit test
上级
0905deec
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
267 addition
and
19 deletion
+267
-19
paddle/fluid/imperative/amp_auto_cast.cc
paddle/fluid/imperative/amp_auto_cast.cc
+15
-2
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+27
-17
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
+1
-0
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py
...addle/fluid/contrib/slim/tests/test_imperative_qat_amp.py
+222
-0
python/paddle/fluid/dygraph/amp/auto_cast.py
python/paddle/fluid/dygraph/amp/auto_cast.py
+2
-0
未找到文件。
paddle/fluid/imperative/amp_auto_cast.cc
浏览文件 @
f88af205
...
...
@@ -141,7 +141,7 @@ static inline std::shared_ptr<imperative::VarBase> CastToFP32(
}
static
inline
framework
::
proto
::
VarType
::
Type
GetPromoteType
(
const
NameVarBaseMap
&
ins
)
{
const
std
::
string
&
op_type
,
const
NameVarBaseMap
&
ins
)
{
auto
dst_type
=
framework
::
proto
::
VarType
::
FP16
;
for
(
const
auto
&
pair
:
ins
)
{
for
(
const
auto
&
var
:
pair
.
second
)
{
...
...
@@ -151,6 +151,18 @@ static inline framework::proto::VarType::Type GetPromoteType(
}
}
}
// NOTE(juncai): moving_average_abs_max_scale only consider the
// dtype of input(X)
if
(
op_type
==
"moving_average_abs_max_scale"
)
{
for
(
const
auto
&
pair
:
ins
)
{
if
(
pair
.
first
==
"X"
&&
pair
.
second
.
front
()
->
DataType
()
==
framework
::
proto
::
VarType
::
FP16
)
{
dst_type
=
framework
::
proto
::
VarType
::
FP16
;
}
}
}
return
dst_type
;
}
...
...
@@ -183,7 +195,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type,
}
return
new_ins
;
}
else
{
auto
dst_type
=
GetPromoteType
(
ins
);
auto
dst_type
=
GetPromoteType
(
op_type
,
ins
);
// NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
if
(
dst_type
==
framework
::
proto
::
VarType
::
FP16
&&
AmpOperators
::
Instance
().
GetMutableUnsupportedFp16Ops
()
->
count
(
...
...
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
f88af205
...
...
@@ -25,18 +25,19 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
extern
__shared__
T
shared_max_data
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
if
(
gridDim
.
x
>
1
)
{
shared_max_data
[
tid
]
=
T
(
0
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
tmp
=
f
abs
(
in
[
i
]);
T
tmp
=
abs
(
in
[
i
]);
if
(
tmp
>
shared_max_data
[
tid
])
{
shared_max_data
[
tid
]
=
tmp
;
}
}
}
else
{
if
(
bid
<
n
)
{
shared_max_data
[
tid
]
=
f
abs
(
in
[
bid
]);
shared_max_data
[
tid
]
=
abs
(
in
[
bid
]);
}
else
{
shared_max_data
[
tid
]
=
T
(
0
);
}
...
...
@@ -73,6 +74,8 @@ struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
};
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
struct
FindAbsMaxFunctor
<
platform
::
CUDADeviceContext
,
paddle
::
platform
::
float16
>;
template
<
typename
T
>
__global__
void
FindChannelAbsMaxKernelQuantAxis0
(
const
T
*
in
,
const
int
n
,
...
...
@@ -213,13 +216,16 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
int
tid
=
threadIdx
.
x
;
T
s
=
scale
[
0
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out
[
i
]
=
round
(
v
)
*
s
/
bin_cnt
;
x
=
x
>
s
?
s
:
x
;
x
=
x
<
-
s
?
-
s
:
x
;
x
=
(
bin_cnt_t
/
s
)
*
x
;
x
=
static_cast
<
T
>
(
round
(
static_cast
<
float
>
(
x
)));
out
[
i
]
=
(
x
*
s
)
/
bin_cnt_t
;
}
}
...
...
@@ -261,9 +267,6 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
}
};
template
struct
ClipAndFakeQuantDequantFunctor
<
platform
::
CUDADeviceContext
,
float
>;
// ChannelClipAndQuantKernel for quant_axis is 0
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernelQuantAxis0
(
const
T
*
in
,
const
T
*
scale
,
...
...
@@ -423,8 +426,10 @@ struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
memory
::
Copy
(
platform
::
CPUPlace
(),
&
scale
,
gpu_place
,
cur_scale
,
sizeof
(
T
),
ctx
.
stream
());
ctx
.
Wait
();
state
=
rate
*
state
+
1
;
accum
=
rate
*
accum
+
scale
;
T
rate_t
=
static_cast
<
T
>
(
rate
);
state
=
rate_t
*
state
+
static_cast
<
T
>
(
1.0
);
accum
=
rate_t
*
accum
+
scale
;
scale
=
accum
/
state
;
memory
::
Copy
(
gpu_place
,
out_accum
->
mutable_data
<
T
>
(
gpu_place
),
...
...
@@ -527,10 +532,12 @@ template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
float16
=
paddle
::
platform
::
float16
;
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_dequantize_abs_max
,
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float
>
);
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_range_abs_max
,
...
...
@@ -539,12 +546,15 @@ REGISTER_OP_CUDA_KERNEL(
fake_quantize_moving_average_abs_max
,
ops
::
FakeQuantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float
>
);
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float
>
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_dequantize_moving_average_abs_max
,
ops
::
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
);
ops
::
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeDequantizeMovingAverageAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
stright_throuth_estimator_grad
,
ops
::
StrightThroughEstimatorGradKernel
<
CUDA
,
float
>
);
ops
::
StrightThroughEstimatorGradKernel
<
CUDA
,
float
>
,
ops
::
StrightThroughEstimatorGradKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_quantize_dequantize_abs_max
,
ops
::
FakeChannelWiseQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float
>
);
python/paddle/fluid/contrib/slim/tests/CMakeLists.txt
浏览文件 @
f88af205
...
...
@@ -127,6 +127,7 @@ if(WIN32)
list
(
REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model
)
list
(
REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1
)
list
(
REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2
)
list
(
REMOVE_ITEM TEST_OPS test_imperative_qat_amp
)
endif
()
if
(
LINUX AND WITH_MKLDNN
)
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py
0 → 100644
浏览文件 @
f88af205
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from
__future__
import
print_function
import
os
import
numpy
as
np
import
random
import
shutil
import
time
import
unittest
import
logging
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.contrib.slim.quantization
import
ImperativeQuantAware
from
paddle.fluid.log_helper
import
get_logger
from
paddle.dataset.common
import
download
from
imperative_test_utils
import
fix_model_dict
,
ImperativeLenet
os
.
environ
[
"CPU_NUM"
]
=
"1"
if
paddle
.
is_compiled_with_cuda
():
fluid
.
set_flags
({
"FLAGS_cudnn_deterministic"
:
True
})
_logger
=
get_logger
(
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
class
TestImperativeQatAmp
(
unittest
.
TestCase
):
"""
Test the combination of qat and amp.
"""
@
classmethod
def
setUpClass
(
cls
):
timestamp
=
time
.
strftime
(
'%Y-%m-%d-%H-%M-%S'
,
time
.
localtime
())
cls
.
root_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"imperative_qat_amp_"
+
timestamp
)
cls
.
save_path
=
os
.
path
.
join
(
cls
.
root_path
,
"model"
)
cls
.
download_path
=
'dygraph_int8/download'
cls
.
cache_folder
=
os
.
path
.
expanduser
(
'~/.cache/paddle/dataset/'
+
cls
.
download_path
)
cls
.
lenet_url
=
"https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz"
cls
.
lenet_md5
=
"953b802fb73b52fae42896e3c24f0afb"
seed
=
1
np
.
random
.
seed
(
seed
)
paddle
.
static
.
default_main_program
().
random_seed
=
seed
paddle
.
static
.
default_startup_program
().
random_seed
=
seed
@
classmethod
def
tearDownClass
(
cls
):
try
:
shutil
.
rmtree
(
cls
.
root_path
)
except
Exception
as
e
:
print
(
"Failed to delete {} due to {}"
.
format
(
cls
.
root_path
,
str
(
e
)))
def
cache_unzipping
(
self
,
target_folder
,
zip_path
):
if
not
os
.
path
.
exists
(
target_folder
):
cmd
=
'mkdir {0} && tar xf {1} -C {0}'
.
format
(
target_folder
,
zip_path
)
os
.
system
(
cmd
)
def
download_model
(
self
,
data_url
,
data_md5
,
folder_name
):
download
(
data_url
,
self
.
download_path
,
data_md5
)
file_name
=
data_url
.
split
(
'/'
)[
-
1
]
zip_path
=
os
.
path
.
join
(
self
.
cache_folder
,
file_name
)
print
(
'Data is downloaded at {0}'
.
format
(
zip_path
))
data_cache_folder
=
os
.
path
.
join
(
self
.
cache_folder
,
folder_name
)
self
.
cache_unzipping
(
data_cache_folder
,
zip_path
)
return
data_cache_folder
def
set_vars
(
self
):
self
.
qat
=
ImperativeQuantAware
()
self
.
train_batch_num
=
30
self
.
train_batch_size
=
32
self
.
test_batch_num
=
100
self
.
test_batch_size
=
32
self
.
eval_acc_top1
=
0.99
def
model_train
(
self
,
model
,
batch_num
=-
1
,
batch_size
=
32
,
use_amp
=
False
):
model
.
train
()
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
batch_size
)
adam
=
paddle
.
optimizer
.
Adam
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
500
)
for
batch_id
,
data
in
enumerate
(
train_reader
()):
x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
paddle
.
to_tensor
(
x_data
)
label
=
paddle
.
to_tensor
(
y_data
)
if
use_amp
:
with
paddle
.
amp
.
auto_cast
():
out
=
model
(
img
)
acc
=
fluid
.
layers
.
accuracy
(
out
,
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
scaled_loss
=
scaler
.
scale
(
avg_loss
)
scaled_loss
.
backward
()
scaler
.
minimize
(
adam
,
scaled_loss
)
adam
.
clear_gradients
()
else
:
out
=
model
(
img
)
acc
=
fluid
.
layers
.
accuracy
(
out
,
label
)
loss
=
fluid
.
layers
.
cross_entropy
(
out
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
model
.
clear_gradients
()
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Train | step {}: loss = {:}, acc= {:}"
.
format
(
batch_id
,
avg_loss
.
numpy
(),
acc
.
numpy
()))
if
batch_num
>
0
and
batch_id
+
1
>=
batch_num
:
break
def
model_test
(
self
,
model
,
batch_num
=-
1
,
batch_size
=
32
,
use_amp
=
False
):
model
.
eval
()
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
)
acc_top1_list
=
[]
for
batch_id
,
data
in
enumerate
(
test_reader
()):
x_data
=
np
.
array
([
x
[
0
].
reshape
(
1
,
28
,
28
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
(
[
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
-
1
,
1
)
img
=
paddle
.
to_tensor
(
x_data
)
label
=
paddle
.
to_tensor
(
y_data
)
with
paddle
.
amp
.
auto_cast
(
use_amp
):
out
=
model
(
img
)
acc_top1
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
1
)
acc_top5
=
fluid
.
layers
.
accuracy
(
input
=
out
,
label
=
label
,
k
=
5
)
acc_top1_list
.
append
(
float
(
acc_top1
.
numpy
()))
if
batch_id
%
100
==
0
:
_logger
.
info
(
"Test | At step {}: acc1 = {:}, acc5 = {:}"
.
format
(
batch_id
,
acc_top1
.
numpy
(),
acc_top5
.
numpy
()))
if
batch_num
>
0
and
batch_id
+
1
>=
batch_num
:
break
acc_top1
=
sum
(
acc_top1_list
)
/
len
(
acc_top1_list
)
return
acc_top1
def
test_ptq
(
self
):
start_time
=
time
.
time
()
self
.
set_vars
()
params_path
=
self
.
download_model
(
self
.
lenet_url
,
self
.
lenet_md5
,
"lenet"
)
params_path
+=
"/lenet_pretrained/lenet.pdparams"
with
fluid
.
dygraph
.
guard
():
model
=
ImperativeLenet
()
model_state_dict
=
paddle
.
load
(
params_path
)
model
.
set_state_dict
(
model_state_dict
)
_logger
.
info
(
"Test fp32 model"
)
fp32_acc_top1
=
self
.
model_test
(
model
,
self
.
test_batch_num
,
self
.
test_batch_size
)
self
.
qat
.
quantize
(
model
)
use_amp
=
True
self
.
model_train
(
model
,
self
.
train_batch_num
,
self
.
train_batch_size
,
use_amp
)
_logger
.
info
(
"Test int8 model"
)
int8_acc_top1
=
self
.
model_test
(
model
,
self
.
test_batch_num
,
self
.
test_batch_size
,
use_amp
)
_logger
.
info
(
'fp32_acc_top1: %f, int8_acc_top1: %f'
%
(
fp32_acc_top1
,
int8_acc_top1
))
self
.
assertTrue
(
int8_acc_top1
>
fp32_acc_top1
-
0.01
,
msg
=
'fp32_acc_top1: %f, int8_acc_top1: %f'
%
(
fp32_acc_top1
,
int8_acc_top1
))
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
1
,
28
,
28
],
dtype
=
'float32'
)
]
paddle
.
jit
.
save
(
layer
=
model
,
path
=
self
.
save_path
,
input_spec
=
input_spec
)
print
(
'Quantized model saved in {%s}'
%
self
.
save_path
)
end_time
=
time
.
time
()
print
(
"total time: %ss"
%
(
end_time
-
start_time
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/dygraph/amp/auto_cast.py
浏览文件 @
f88af205
...
...
@@ -29,6 +29,8 @@ WHITE_LIST = {
'matmul'
,
'matmul_v2'
,
'mul'
,
'fake_quantize_dequantize_abs_max'
,
'fake_quantize_dequantize_moving_average_abs_max'
,
}
# The set of ops that support fp16 calculation and are considered numerically-
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录