机器未来 / Paddle · forked from PaddlePaddle/Paddle
Unverified commit b72a7ebb
Authored Apr 05, 2022 by Guanghua Yu; committed via GitHub on Apr 05, 2022
add new format of quantization (#41041)
Parent: b9ee846e
Showing 24 changed files with 3,034 additions and 1,215 deletions (+3034 / -1215)
paddle/fluid/operators/CMakeLists.txt  +2  -1
paddle/fluid/operators/fake_dequantize_op.cu  +1  -134
paddle/fluid/operators/fake_dequantize_op.cu.h  +151  -0
paddle/fluid/operators/fake_quantize_op.cu  +1  -524
paddle/fluid/operators/fake_quantize_op.cu.h  +543  -0
paddle/fluid/operators/quantize_linear_op.cc  +173  -0
paddle/fluid/operators/quantize_linear_op.cu  +70  -0
paddle/fluid/operators/quantize_linear_op.h  +119  -0
paddle/phi/kernels/cpu/cast_kernel.cc  +1  -0
paddle/phi/kernels/gpu/cast_kernel.cu  +1  -0
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py  +22  -2
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py  +1  -4
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py  +122  -91
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py  +910  -430
python/paddle/fluid/contrib/slim/quantization/utils.py  +321  -0
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py  +27  -10
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py  +11  -0
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py  +61  -9
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py  +73  -3
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py  +52  -7
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py  +29  -0
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py  +125  -0
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py  +78  -0
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py  +140  -0
paddle/fluid/operators/CMakeLists.txt  (+2 / -1)

...
@@ -102,10 +102,11 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils gather_scatter_kernel)

register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op
    quantize_linear_op
    recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})

op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
op_library(quantize_linear_op DEPS cast_kernel)
op_library(save_combine_op DEPS string_array)
op_library(load_combine_op DEPS string_array)
...
paddle/fluid/operators/fake_dequantize_op.cu  (+1 / -134)

...
@@ -12,142 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_dequantize_op.cu.h"
#include "paddle/fluid/operators/fake_dequantize_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int num, T* out) {
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  if (idx < num) {
    out[idx] = in[idx] * scale[0] / max_range;
  }
}

template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  T max_range, framework::Tensor* out) {
    const T* in_data = in->data<T>();
    const T* scale_factor = scale->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());

    int num = in->numel();
    int block = 512;
    int grid = (num + block - 1) / block;

    KeDequantize<T><<<grid, block, 0, dev_ctx.stream()>>>(
        in_data, scale_factor, max_range, num, out_data);
  }
};

template <typename T>
__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, T max_range,
                                             int num, int channel, T* out) {
  int tid = threadIdx.x;
  int channel_size = num / channel;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;
  for (int i = tid; i < channel_size; i += blockDim.x) {
    out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
  }
}

template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale,
                                             const T max_range, const int64_t num,
                                             const int n_scales, const int quant_stride,
                                             T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    T s = scale[(i / quant_stride) % n_scales];
    out[i] = in[i] * s / max_range;
  }
}

template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one, const T* scale_two,
                                   T max_range, int num, int iter_size, int channel,
                                   T* out) {
  int tid = threadIdx.x;
  int channel_size = num / (iter_size * channel);
  int scale_index = blockIdx.x % channel;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;
  for (int i = tid; i < channel_size; i += blockDim.x) {
    out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
  }
}

template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor** scales,
                  const int scale_num, T max_range, const int quant_axis,
                  const int x_num_col_dims, framework::Tensor* out) {
    auto in_dims = in->dims();
    const T* in_data = in->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
    if (scale_num == 1) {
      int64_t num = in->numel();
      const T* scale_factor = scales[0]->data<T>();
      if (quant_axis == 0) {
        int grid = in_dims[0];
        int block = 1024;
        DequantizeOneScaleQuantAxis0<T><<<grid, block, 0, dev_ctx.stream()>>>(
            in_data, scale_factor, max_range, num, in_dims[0], out_data);
      } else {
        int quant_stride = 1;
        for (int i = quant_axis + 1; i < in_dims.size(); i++) {
          quant_stride *= in_dims[i];
        }
        int64_t block_size =
            std::min(num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
        int64_t max_threads =
            dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
        const int64_t max_blocks =
            std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
        const int64_t grid_size =
            std::min(max_blocks, (num + block_size - 1) / block_size);
        DequantizeOneScaleQuantAxisN<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
            in_data, scale_factor, max_range, num, in_dims[quant_axis],
            quant_stride, out_data);
      }
    } else if (scale_num == 2) {
      // Not need to consider quant_axis
      int num = in->numel();
      int iter_size = 1;
      for (int i = 0; i < x_num_col_dims; i++) {
        iter_size *= in->dims()[i];
      }
      int channel = in->dims()[x_num_col_dims];
      const T* scale_one = scales[0]->data<T>();
      const T* scale_two = scales[1]->data<T>();
      int block = 1024;
      int grid = iter_size * channel;
      DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
          in_data, scale_one, scale_two, max_range, num, iter_size, channel, out_data);
    }
  }
};

template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs,
...
paddle/fluid/operators/fake_dequantize_op.cu.h  (new file, mode 100644, +151 / -0)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_
#define PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_
#endif  // PADDLE_FLUID_OPERATORS_FAKE_DEQUANTIZE_OP_CU_H_

#include "paddle/fluid/operators/fake_dequantize_op.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void KeDequantize(const T* in, const T* scale, T max_range, int64_t num,
                             T* out) {
  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    out[i] = in[i] * scale[0] / max_range;
  }
}

template <typename T>
struct DequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  T max_range, framework::Tensor* out) {
    const T* in_data = in->data<T>();
    const T* scale_factor = scale->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());

    int64_t num = in->numel();
    int64_t block_size =
        std::min(num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
    int64_t max_threads =
        dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
    const int64_t max_blocks =
        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
    const int64_t grid_size =
        std::min(max_blocks, (num + block_size - 1) / block_size);
    KeDequantize<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
        in_data, scale_factor, max_range, num, out_data);
  }
};

template <typename T>
__global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale, T max_range,
                                             int num, int channel, T* out) {
  int tid = threadIdx.x;
  int channel_size = num / channel;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;
  for (int i = tid; i < channel_size; i += blockDim.x) {
    out_c[i] = in_c[i] * scale[blockIdx.x] / max_range;
  }
}

template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in, const T* scale,
                                             const T max_range, const int64_t num,
                                             const int n_scales, const int quant_stride,
                                             T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    T s = scale[(i / quant_stride) % n_scales];
    out[i] = in[i] * s / max_range;
  }
}

template <typename T>
__global__ void DequantizeTwoScale(const T* in, const T* scale_one, const T* scale_two,
                                   T max_range, int num, int iter_size, int channel,
                                   T* out) {
  int tid = threadIdx.x;
  int channel_size = num / (iter_size * channel);
  int scale_index = blockIdx.x % channel;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;
  for (int i = tid; i < channel_size; i += blockDim.x) {
    out_c[i] = in_c[i] * scale_one[scale_index] * scale_two[0] / max_range;
  }
}

template <typename T>
struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor** scales,
                  const int scale_num, T max_range, const int quant_axis,
                  const int x_num_col_dims, framework::Tensor* out) {
    auto in_dims = in->dims();
    const T* in_data = in->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
    if (scale_num == 1) {
      int64_t num = in->numel();
      const T* scale_factor = scales[0]->data<T>();
      int64_t block_size =
          std::min(num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
      int64_t max_threads =
          dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
      const int64_t max_blocks =
          std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
      const int64_t grid_size =
          std::min(max_blocks, (num + block_size - 1) / block_size);

      int quant_stride = 1;
      for (int i = quant_axis + 1; i < in_dims.size(); i++) {
        quant_stride *= in_dims[i];
      }

      DequantizeOneScaleQuantAxisN<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
          in_data, scale_factor, max_range, num, in_dims[quant_axis],
          quant_stride, out_data);
    } else if (scale_num == 2) {
      // Not need to consider quant_axis
      int num = in->numel();
      int iter_size = 1;
      for (int i = 0; i < x_num_col_dims; i++) {
        iter_size *= in->dims()[i];
      }
      int channel = in->dims()[x_num_col_dims];
      const T* scale_one = scales[0]->data<T>();
      const T* scale_two = scales[1]->data<T>();
      int block = 1024;
      int grid = iter_size * channel;
      DequantizeTwoScale<T><<<grid, block, 0, dev_ctx.stream()>>>(
          in_data, scale_one, scale_two, max_range, num, iter_size, channel, out_data);
    }
  }
};

template struct DequantizeFunctor<platform::CUDADeviceContext, float>;
template struct DequantizeFunctor<platform::CUDADeviceContext, double>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctor<platform::CUDADeviceContext, double>;

}  // namespace operators
}  // namespace paddle
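The per-channel kernels above recover the channel of a flat element index with `(i / quant_stride) % n_scales`, where `quant_stride` is the product of the dimensions after `quant_axis`. A minimal host-side sketch of that mapping follows; the shape, scales, and input values are made up for illustration and are not part of this diff.

#include <cstdio>
#include <vector>

// Illustrative only: mirrors the index arithmetic of DequantizeOneScaleQuantAxisN
// for a hypothetical tensor of shape [2, 3, 4] quantized along quant_axis = 1.
int main() {
  const int dims[3] = {2, 3, 4};
  const int quant_axis = 1;
  int quant_stride = 1;
  for (int i = quant_axis + 1; i < 3; ++i) quant_stride *= dims[i];  // 4
  const int n_scales = dims[quant_axis];                             // 3
  const float scale[3] = {0.5f, 1.0f, 2.0f};                         // made-up per-channel scales
  const float max_range = 127.0f;
  std::vector<float> in(2 * 3 * 4, 64.0f), out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    float s = scale[(i / quant_stride) % n_scales];  // same channel lookup as the CUDA kernel
    out[i] = in[i] * s / max_range;                  // same dequantize formula
  }
  std::printf("out[0]=%f out[4]=%f out[8]=%f\n", out[0], out[4], out[8]);
  return 0;
}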
paddle/fluid/operators/fake_quantize_op.cu  (+1 / -524)

...
@@ -12,531 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.cu.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {
// Removed by this hunk: the FindAbsMax*, FindChannelAbsMax*, ClipAndQuant*,
// ChannelClipAndQuant*, FindRangeAbsMax* and FindMovingAverageAbsMax* kernels
// and functors, which are moved unchanged into
// paddle/fluid/operators/fake_quantize_op.cu.h (next file in this diff).
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
...
paddle/fluid/operators/fake_quantize_op.cu.h  (new file, mode 100644, +543 / -0)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#define PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#endif  // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
  if (gridDim.x > 1) {
    T local_max_data = T(0);
    for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
      T tmp = abs(in[i]);
      if (tmp > local_max_data) {
        local_max_data = tmp;
      }
    }
    shared_max_data[tid] = local_max_data;
  } else {
    if (bid < n) {
      shared_max_data[tid] = abs(in[bid]);
    } else {
      shared_max_data[tid] = T(0);
    }
  }
  __syncthreads();

  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    int block = 1024;
    int grid = (block - 1 + num) / block;
    grid = (grid > block) ? block : grid;

    framework::Tensor max;
    T* max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, 1024 * sizeof(T), ctx.stream()>>>(
        in, num, max_data);
    FindAbsMaxKernel<T><<<1, block, 1024 * sizeof(T), ctx.stream()>>>(
        max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template struct FindAbsMaxFunctor<platform::CUDADeviceContext, paddle::platform::float16>;

template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
                                                  const int c, T* out) {
  int tid = threadIdx.x;
  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  extern __shared__ T shared_max_data[];
  T local_max_data = T(0);
  for (int i = tid; i < channel_size; i += blockDim.x) {
    T tmp = fabs(in_c[i]);
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                  const int cin, const int cout,
                                                  T* out) {
  extern __shared__ T shared_max_data[];
  int cout_wh_size = n / cin;
  int wh_size = n / (cin * cout);

  int tid = threadIdx.x;
  int bid = blockIdx.x;
  const T* in_current = in + tid * cout_wh_size + bid * wh_size;
  T local_max_data = T(0);
  for (int i = 0; i < wh_size; i++) {
    T tmp = fabs(in_current[i]);
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();

  int len = blockDim.x;
  for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) {
    if (tid < i && tid + i < len &&
        shared_max_data[tid] < shared_max_data[tid + i]) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    if (i == 1) {
      i = 0;  // break the loop
    }
    __syncthreads();
  }
  if (tid == 0 && shared_max_data[0] > out[bid]) {
    out[bid] = shared_max_data[0];
  }
}

template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_tensor, const int quant_axis,
                  T* out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T* in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
          in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;

#ifdef PADDLE_WITH_HIP
      hipMemset(out_abs_max, 0, sizeof(T) * cout);
#else
      cudaMemset(out_abs_max, 0, sizeof(T) * cout);
#endif  // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
        in_data += num / cin;
      }

      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, in_dims[0], in_dims[1], out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  T s = scale[0];
  T inv_s = inverse(s);
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T x = in[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out[i] = round(v);
  }
}

template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  T s = scale[0];
  T inv_s = inverse(s);
  T bin_cnt_t = static_cast<T>(bin_cnt);

  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T x = in[i];
    x = x > s ? s : x;
    x = x < -s ? -s : x;
    x = bin_cnt_t * inv_s * x;
    x = static_cast<T>(round(static_cast<float>(x)));
    out[i] = (x * s) / bin_cnt_t;
  }
}

template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

// ChannelClipAndQuantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int64_t n, const int c,
                                                    T* out) {
  int tid = threadIdx.x;

  int64_t channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  T s = scale[blockIdx.x];
  T inv_s = inverse(s);

  for (int64_t i = tid; i < channel_size; i += blockDim.x) {
    T x = in_c[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out_c[i] = round(v);
  }
}

// ChannelClipAndQuantKernel for quant_axis is N
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN(
    const T* in, const T* scale, const int bin_cnt, const int64_t n,
    const int nScale, const int quant_stride, T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
    T s = scale[(i / quant_stride) % nScale];
    T inv_s = 1.0 / s;
    T x = in[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out[i] = round(v);
  }
}

template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int64_t num = in.numel();
    auto in_dims = in.dims();
    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
    } else {
      int quant_stride = 1;
      for (int i = quant_axis + 1; i < in_dims.size(); i++) {
        quant_stride *= in_dims[i];
      }
      int64_t block_size =
          std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
      int64_t max_threads =
          ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
      const int64_t max_blocks =
          std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
      const int64_t grid_size =
          std::min(max_blocks, (num + block_size - 1) / block_size);
      ChannelClipAndQuantKernelQuantAxisN<T><<<grid_size, block_size>>>(
          in_data, scale_data, bin_cnt, num, in_dims[quant_axis], quant_stride,
          out_data);
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  int it = iter[0];
  int idx = it % window_size;
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];
  out_scale[0] = max < cur ? cur : max;
  if (fabs(removed - max) < 1e-6) {
    need_find_max[0] = 1;
    out_size[0] = it > window_size ? window_size : it;
  } else {
    need_find_max[0] = 0;
  }
}

template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    framework::Tensor need_find_max, out_size;
    int* find_max = need_find_max.mutable_data<int>({1}, gpu_place);
    int* out_size_data = out_size.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, find_max, out_size_data);

    int g_find_max;
    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
                 sizeof(int), ctx.stream());
    ctx.Wait();
    if (g_find_max) {
      int len;
      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
                   sizeof(int), ctx.stream());
      ctx.Wait();
      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
                                                          out_scale_data);
    }
  }
};

template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T* in_state, const T* in_accum,
                                              const T* cur_scale, const T rate,
                                              T* out_state, T* out_accum,
                                              T* out_scale) {
  T state = rate * (*in_state) + T(1.0f);
  T accum = rate * (*in_accum) + (*cur_scale);
  *out_state = state;
  *out_accum = accum;
  *out_scale = accum / state;
}

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T rate_t = static_cast<T>(rate);
    T* out_state_data = out_state->mutable_data<T>(gpu_place);
    T* out_accum_data = out_accum->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>(
        in_state.data<T>(), in_accum.data<T>(), cur_scale, rate_t,
        out_state_data, out_accum_data, out_scale_data);
  }
};

// ChannelClipAndQuantDequantKernel for quant_axis is 0
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
    const T* in, const T* scale, const int bin_cnt, const int n, const int c,
    T* out) {
  int tid = threadIdx.x;

  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  T s = scale[blockIdx.x];
  T inv_s = inverse(s);

  for (int i = tid; i < channel_size; i += blockDim.x) {
    T x = in_c[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out_c[i] = round(v) * s / bin_cnt;
  }
}

// ChannelClipAndQuantDequantKernel for quant_axis is 1
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
    const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
    const int cout, T* out) {
  T s = scale[blockIdx.x % cout];
  T inv_s = inverse(s);

  int wh_size = n / (cin * cout);
  const T* in_c = in + blockIdx.x * wh_size;
  T* out_c = out + blockIdx.x * wh_size;

  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
    T x = in_c[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt * inv_s * v;
    out_c[i] = round(v) * s / bin_cnt;
  }
}

template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantDequantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
    } else if (quant_axis == 1) {
      int grid = in_dims[0] * in_dims[1];
      int block = 1024;

      ChannelClipAndQuantDequantKernelQuantAxis1<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, float>;

}  // namespace operators
}  // namespace paddle
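ClipAndQuantKernel and its channelwise variants all apply the same scalar rule: clip x to [-s, s], scale by bin_cnt / s, and round; the dequant variants multiply back by s / bin_cnt. A plain C++ restatement of that rule follows as a sketch; the scale and input values are invented for illustration.

#include <cmath>
#include <cstdio>

// Illustrative scalar form of the clip-and-(de)quantize rule used by the kernels above.
float fake_quant(float x, float s, int bin_cnt) {
  float v = x > s ? s : x;
  v = v < -s ? -s : v;
  return std::round(bin_cnt * v / s);  // integer level in [-bin_cnt, bin_cnt]
}

float fake_quant_dequant(float x, float s, int bin_cnt) {
  return fake_quant(x, s, bin_cnt) * s / bin_cnt;  // map the level back to float range
}

int main() {
  const float s = 2.5f;     // hypothetical abs-max scale
  const int bin_cnt = 127;  // 8-bit signed range
  std::printf("%f %f\n", fake_quant(1.3f, s, bin_cnt),
              fake_quant_dequant(1.3f, s, bin_cnt));  // 66, ~1.299
  return 0;
}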
paddle/fluid/operators/quantize_linear_op.cc
0 → 100644
浏览文件 @
b72a7ebb
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/quantize_linear_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
ChannelDequantizeFunctorV2
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
*
in
,
const
framework
::
Tensor
*
scale
,
T
max_range
,
const
int
quant_axis
,
framework
::
Tensor
*
out
)
{
// Dequant op is before quantized op
// Dequantize the weight of quantized op
auto
in_dims
=
in
->
dims
();
const
int64_t
channel
=
in_dims
[
quant_axis
];
const
T
*
scale_factor
=
scale
->
data
<
T
>
();
if
(
quant_axis
==
0
)
{
for
(
int64_t
i
=
0
;
i
<
channel
;
i
++
)
{
T
s
=
scale_factor
[
i
];
framework
::
Tensor
one_channel_in
=
in
->
Slice
(
i
,
i
+
1
);
framework
::
Tensor
one_channel_out
=
out
->
Slice
(
i
,
i
+
1
);
auto
in_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_in
);
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
one_channel_out
);
auto
&
dev
=
*
dev_ctx
.
eigen_device
();
out_e
.
device
(
dev
)
=
in_e
*
s
/
max_range
;
}
}
else
if
(
quant_axis
==
1
)
{
int64_t
out_iter
=
1
;
for
(
int
i
=
0
;
i
<
quant_axis
;
i
++
)
{
out_iter
*=
in_dims
[
i
];
}
int64_t
step_i
=
in
->
numel
()
/
out_iter
;
int64_t
step_j
=
in
->
numel
()
/
(
out_iter
*
channel
);
auto
*
in_data
=
in
->
data
<
T
>
();
auto
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
out_iter
;
i
++
)
{
for
(
int64_t
j
=
0
;
j
<
channel
;
j
++
)
{
auto
*
cur_in
=
in_data
+
i
*
step_i
+
j
*
step_j
;
auto
*
cur_out
=
out_data
+
i
*
step_i
+
j
*
step_j
;
T
s
=
scale_factor
[
j
];
for
(
int64_t
k
=
0
;
k
<
step_j
;
k
++
)
{
*
cur_out
=
(
*
cur_in
)
*
s
/
max_range
;
++
cur_in
;
++
cur_out
;
}
}
}
}
}
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
struct
ChannelDequantizeFunctorV2
<
platform
::
CPUDeviceContext
,
float
>;
template
struct
ChannelDequantizeFunctorV2
<
platform
::
CPUDeviceContext
,
double
>;
class
QuantizeLinearOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
    OP_INOUT_CHECK(ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint",
                   "QuantizeLinear");
    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    int quant_axis = ctx->Attrs().Get<int>("quant_axis");
    if (ctx->HasOutput("OutScale")) {
      if (quant_axis < 0) {
        ctx->SetOutputDim("OutScale", {1});
      } else {
        ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
      }
    }
    ctx->ShareLoD("X", /*->*/ "Y");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("Scale", "(Tensor) Input is float data type.");
    AddInput("ZeroPoint", "(Tensor) Input is float data type.");
    AddOutput("Y",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current scale").AsDispensable().AsExtra();
    AddAttr<int>("quant_axis",
                 "(int, default 0) The axis for quantization. "
                 "For conv2d, depthwise_conv2d, conv2d_transpose "
                 "and mul, the quant_axis is equal to the cout axis.")
        .SetDefault(0)
        .AddCustomChecker([](const int& quant_axis) {
          PADDLE_ENFORCE_EQ(
              quant_axis == 0 || quant_axis == 1 || quant_axis == -1, true,
              platform::errors::InvalidArgument(
                  "'quant_axis' should be 0, 1 or -1, but "
                  "the received is %d",
                  quant_axis));
        });
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
        });
    AddAttr<bool>(
        "is_test",
        "(bool, default true) Set to true for inference only, false "
        "for training. Some layers may run faster when this is true.")
        .SetDefault(true);
    AddComment(R"DOC(
The scale of QuantizeLinear operator is a vector.
In detail, each channel of the input X has a scale value.

$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$

In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(
    quantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(quantize_linear,
                       ops::QuantizeLinearKernel<CPU, float>);

REGISTER_OPERATOR(
    dequantize_linear, ops::QuantizeLinearOp, ops::QuantizeLinearOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(dequantize_linear,
                       ops::DeQuantizeLinearKernel<CPU, float, float>,
                       ops::DeQuantizeLinearKernel<CPU, int8_t, float>,
                       ops::DeQuantizeLinearKernel<CPU, double, double>);
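As a quick illustration of the per-channel formula in the op comment above, here is a small NumPy sketch. It is not part of the commit; the helper name and tensor shapes are made up for the example, and it only reproduces the scale/round math, not the operator itself.

import numpy as np

def channel_wise_quantize(x, bit_length=8, quant_axis=0):
    # scale_c = max(abs(X_c)); range = 2^(bit_length-1) - 1;
    # Out_c = round(X_c * range / scale_c), per the op comment above.
    qrange = 2 ** (bit_length - 1) - 1
    xt = np.moveaxis(x, quant_axis, 0)
    scales = np.abs(xt).reshape(xt.shape[0], -1).max(axis=1)
    out = np.round(xt * qrange / scales.reshape((-1,) + (1,) * (xt.ndim - 1)))
    return np.moveaxis(out, 0, quant_axis), scales

x = np.random.randn(4, 3, 8, 8).astype("float32")
y, scales = channel_wise_quantize(x, bit_length=8, quant_axis=0)
assert np.abs(y).max() <= 2 ** 7 - 1  # quantized values stay inside the int8 range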
paddle/fluid/operators/quantize_linear_op.cu
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_dequantize_op.cu.h"
#include "paddle/fluid/operators/fake_quantize_op.cu.h"
#include "paddle/fluid/operators/quantize_linear_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

template <typename T>
struct ChannelDequantizeFunctorV2<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& dev_ctx,
                  const framework::Tensor* in, const framework::Tensor* scale,
                  T max_range, const int quant_axis, framework::Tensor* out) {
    auto in_dims = in->dims();
    const T* in_data = in->data<T>();
    T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
    int64_t num = in->numel();
    const T* scale_factor = scale->data<T>();
    int64_t block_size = std::min(
        num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
    int64_t max_threads =
        dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
    const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                        static_cast<int64_t>(1));
    const int64_t grid_size =
        std::min(max_blocks, (num + block_size - 1) / block_size);

    int quant_stride = 1;
    for (int i = quant_axis + 1; i < in_dims.size(); i++) {
      quant_stride *= in_dims[i];
    }

    DequantizeOneScaleQuantAxisN<
        T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
        in_data, scale_factor, max_range, num, in_dims[quant_axis],
        quant_stride, out_data);
  }
};

template struct ChannelDequantizeFunctorV2<platform::CUDADeviceContext, float>;
template struct ChannelDequantizeFunctorV2<platform::CUDADeviceContext,
                                           double>;

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(dequantize_linear,
                        ops::DeQuantizeLinearKernel<CUDA, float, float>,
                        ops::DeQuantizeLinearKernel<CUDA, int8_t, float>,
                        ops::DeQuantizeLinearKernel<CUDA, double, double>);

REGISTER_OP_CUDA_KERNEL(quantize_linear,
                        ops::QuantizeLinearKernel<CUDA, float>);
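The launched kernel (DequantizeOneScaleQuantAxisN, declared in fake_dequantize_op.cu.h) maps each flat element index to its channel through the quant_stride computed above. Below is a NumPy sketch of that same index arithmetic; the function name, shapes and scales are hypothetical and only illustrate the mapping, not the CUDA launch.

import numpy as np

def channel_dequantize(x_int, scales, max_range, quant_axis):
    # quant_stride = product of the dims after quant_axis; for a flat index i
    # the channel is (i // quant_stride) % channel_num, as in the functor above.
    in_dims = x_int.shape
    quant_stride = int(np.prod(in_dims[quant_axis + 1:], dtype=np.int64))
    channel_num = in_dims[quant_axis]
    flat = x_int.reshape(-1).astype("float32")
    idx = np.arange(flat.size)
    channel = (idx // quant_stride) % channel_num
    return (flat * scales[channel] / max_range).reshape(in_dims)

x_int = np.random.randint(-127, 128, size=(4, 3, 8, 8)).astype("int8")
scales = np.array([0.5, 1.0, 2.0], dtype="float32")
out = channel_dequantize(x_int, scales, max_range=127.0, quant_axis=1)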
paddle/fluid/operators/quantize_linear_op.h
0 → 100644
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/fake_dequantize_op.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/cast_kernel.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctorV2 {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
                  const framework::Tensor** scales, const int scale_num,
                  T max_range, const int quant_axis, framework::Tensor* out);
};

template <typename DeviceContext, typename T>
class QuantizeLinearKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* in = context.Input<framework::Tensor>("X");
    auto* in_scale = context.Input<framework::Tensor>("Scale");
    auto* out = context.Output<framework::Tensor>("Y");
    out->mutable_data<T>(context.GetPlace());
    int bit_length = context.Attr<int>("bit_length");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");
    bool is_test = context.Attr<bool>("is_test");
    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (quant_axis < 0) {
      if (!is_test) {
        auto* out_scale = context.Output<framework::Tensor>("OutScale");
        T* out_s = out_scale->mutable_data<T>(context.GetPlace());
        FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in->data<T>(),
                                              in->numel(), out_s);
        ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *out_scale,
                                                    bin_cnt, out);
      } else {
        ClipAndFakeQuantFunctor<DeviceContext, T>()(dev_ctx, *in, *in_scale,
                                                    bin_cnt, out);
      }
    } else {
      if (!is_test) {
        auto* out_scale = context.Output<framework::Tensor>("OutScale");
        T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
        FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
                                                     out_scale_data);
        ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
      } else {
        ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
            dev_ctx, *in, *in_scale, bin_cnt, quant_axis, out);
      }
    }
  }
};

template <typename DeviceContext, typename T, typename D>
class DeQuantizeLinearKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto& dev_ctx = context.template device_context<DeviceContext>();
    auto* in = context.Input<framework::Tensor>("X");
    auto in_tmp = phi::Cast<T>(
        static_cast<const typename paddle::framework::ConvertToPhiContext<
            DeviceContext>::TYPE&>(dev_ctx),
        *in, experimental::CppTypeToDataType<D>::Type());

    auto* scale = context.Input<framework::Tensor>("Scale");
    auto* out = context.Output<framework::Tensor>("Y");
    int bit_length = context.Attr<int>("bit_length");
    auto quant_axis = context.Attr<int>("quant_axis");
    out->mutable_data<D>(dev_ctx.GetPlace());

    if (quant_axis < 0) {
      float max_range = (std::pow(2, bit_length - 1) - 1);
      DequantizeFunctor<DeviceContext, D>()(dev_ctx, &in_tmp, scale,
                                            static_cast<D>(max_range), out);
    } else {
      PADDLE_ENFORCE_EQ(
          scale->numel(), in_tmp.dims()[quant_axis],
          platform::errors::PreconditionNotMet(
              "The number of first scale values must be the same with "
              "quant_axis dimension value of Input(X) when the `scale` has "
              "only one element, but %ld != %ld here.",
              scale->numel(), in_tmp.dims()[quant_axis]));
      int max_range = (std::pow(2, bit_length - 1) - 1);
      ChannelDequantizeFunctorV2<DeviceContext, D>()(
          dev_ctx, &in_tmp, scale, static_cast<D>(max_range), quant_axis, out);
    }
  }
};

}  // namespace operators
}  // namespace paddle
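For the per-tensor path (quant_axis < 0) the two kernels above reduce to a simple scale/round/rescale round trip. The following NumPy lines illustrate that behaviour and its error bound; this is an illustration of the math, not the kernel implementation.

import numpy as np

bit_length = 8
bin_cnt = 2 ** (bit_length - 1) - 1           # same bin_cnt as the kernels above

x = np.random.randn(31, 65).astype("float32")
scale = np.abs(x).max()                       # what FindAbsMaxFunctor computes
q = np.round(np.clip(x / scale, -1.0, 1.0) * bin_cnt)   # quantize_linear values
x_hat = q * scale / bin_cnt                   # dequantize_linear values

# Each element is reconstructed to within half a quantization step.
assert np.abs(x - x_hat).max() <= 0.5 * scale / bin_cnt + 1e-6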
paddle/phi/kernels/cpu/cast_kernel.cc
...
@@ -41,6 +41,7 @@ PD_REGISTER_KERNEL(cast,
                    int64_t,
                    int16_t,
                    bool,
+                   int8_t,
                    uint8_t,
                    phi::dtype::float16,
                    phi::dtype::bfloat16,
...
paddle/phi/kernels/gpu/cast_kernel.cu
...
@@ -41,6 +41,7 @@ void CastKernel(const Context& dev_ctx,
                     int64_t,                    \
                     int16_t,                    \
                     bool,                       \
+                    int8_t,                     \
                     uint8_t,                    \
                     phi::dtype::float16,        \
                     phi::dtype::complex<float>, \
...
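The extra cast registrations matter because dequantize_linear first casts its INT8 input back to a float tensor before scaling. A dygraph sketch of that first step, with a random placeholder weight and an arbitrary scale value:

import numpy as np
import paddle

w_int8 = paddle.to_tensor(
    np.random.randint(-127, 128, size=(8, 4), dtype=np.int8))
w_float = paddle.cast(w_int8, 'float32')   # relies on the int8_t cast kernel added here
w_dequant = w_float * 0.02                 # scale / max_range would be applied here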
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
...
@@ -28,6 +28,7 @@ from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.fluid.io import load_inference_model, save_inference_model
+from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
 from paddle.fluid.log_helper import get_logger
 from .. import quantization_pass
 from . import utils
...
@@ -431,7 +432,12 @@ class ImperativeQuantizeOutputs(object):
         setattr(parent_layer, sub_name, cur_quant_layer)

-    def save_quantized_model(self, model, path, input_spec=None, **config):
+    def save_quantized_model(self,
+                             model,
+                             path,
+                             input_spec=None,
+                             onnx_format=False,
+                             **config):
         """
         Save the quantized model for the inference.
...
@@ -444,6 +450,8 @@ class ImperativeQuantizeOutputs(object):
             InputSpec or example Tensor. If None, all input variables of
             the original Layer's forward method would be the inputs of
             the saved model. Default None.
+            onnx_format (bool, optional): Whether to export the quantized model
+                with format of ONNX. Default is False.
             **configs (dict, optional): Other save configuration options for
                 compatibility. We do not recommend using these configurations,
                 they may be removed in the future. If not necessary, DO NOT use
...
@@ -498,6 +506,18 @@ class ImperativeQuantizeOutputs(object):
         self._set_skip_quant_attr(infer_program)

+        clip_extra = False
+        if onnx_format:
+            graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
+            transform_pass = ReplaceFakeQuantDequantPass(scope, place)
+            transform_pass.apply(graph)
+
+            quant_weight_pass = QuantWeightPass(scope, place)
+            quant_weight_pass.apply(graph)
+            infer_program = graph.to_program()
+
+            clip_extra = True
+
         save_inference_model(
             dirname=dirname,
             feeded_var_names=feed_target_names,
...
@@ -506,7 +526,7 @@ class ImperativeQuantizeOutputs(object):
             main_program=infer_program.clone(),
             model_filename=model_filename,
             params_filename=params_filename,
-            clip_extra=False)
+            clip_extra=clip_extra)

         if is_dynamic_mode:
             paddle.disable_static()
...
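A hedged usage sketch of the new flag. The LeNet model, output path and training loop are placeholders; the call pattern follows the imperative QAT tests later in this diff, which pass onnx_format through to the save path shown above.

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.vision.models import LeNet

imperative_qat = ImperativeQuantAware(
    weight_quantize_type='abs_max',
    activation_quantize_type='moving_average_abs_max')

model = LeNet()
imperative_qat.quantize(model)
# ... quantization-aware training would happen here ...

imperative_qat.save_quantized_model(
    layer=model,
    path="./lenet_int8/lenet",          # placeholder export path
    input_spec=[
        paddle.static.InputSpec(shape=[None, 1, 28, 28], dtype='float32')
    ],
    onnx_format=True)  # routes through ReplaceFakeQuantDequantPass + QuantWeightPass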
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
...
@@ -18,10 +18,7 @@ import numpy as np
 import paddle
 import paddle.nn.quant.quant_layers as quant_layers

-from ..quantization_pass import _get_op_input_var_names
-from ..quantization_pass import _get_op_output_var_names
-from ..quantization_pass import _get_output_name_index
-from ..quantization_pass import _get_input_name_index
+from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index

 layer_name_map = {
     'Conv2DTranspose': paddle.nn.Conv2DTranspose,
...
python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py
(This diff is collapsed on the original page; the changes to this file are not shown.)
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
(This diff is collapsed on the original page; the changes to this file are not shown.)
python/paddle/fluid/contrib/slim/quantization/utils.py
...
@@ -13,11 +13,292 @@
 # limitations under the License.

 import numpy as np
+from ....framework import IrNode
+from ....framework import Operator
+
+_weight_supported_quantizable_op_type = [
+    'conv2d', 'depthwise_conv2d', 'conv2d_transpose', 'mul', 'matmul',
+    'matmul_v2'
+]
+
+_act_supported_quantizable_op_type = [
+    "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
+    "equal", "gather", "greater_equal", "greater_than", "less_equal",
+    "less_than", "mean", "not_equal", "reshape", "reshape2", "dropout",
+    "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
+    "squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
+    "leaky_relu", "tanh", "swish", "scale", "transpose", "transpose2",
+    "sigmoid", "pad2d", "flatten", "flatten2", "batch_norm", "layer_norm",
+    "matmul_v2", "split", "flatten_contiguous_range", "squeeze2",
+    "nearest_interp_v2", "bilinear_interp", "bilinear_interp_v2",
+    "fill_constant_batch_size_like", "arg_max", "abs", "assign", "cast",
+    "clip", "box_coder", "crop", "cumsum", "elementwise_mul",
+    "elementwise_pow", "expand_v2", "fill_any_like", "fill_constant", "gelu",
+    "hard_sigmoid", "hard_swish", "instance_norm", "lookup_table",
+    "lookup_table_v2", "norm", "p_norm", "pad3d", "pow", "prelu",
+    "reduce_mean", "unsqueeze", "unsqueeze2", "logical_and", "logical_not",
+    "meshgrid", "roi_align", "strided_slice", "where", "grid_sampler", "tile",
+    "group_norm", "reduce_sum", "square", "softplus", "shuffle_channel",
+]
+
+_out_scale_op_list = list(
+    set(_weight_supported_quantizable_op_type +
+        _act_supported_quantizable_op_type))
+
+_channelwise_quant_axis1_ops = [
+    'conv2d_transpose', 'mul', 'matmul', 'matmul_v2'
+]
+
+# list op real input and output names, to avoid processing input such as AxisTensor.
+_op_real_in_out_name = {
+    "conv2d": [["Input", "Filter"], ["Output"]],
+    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
+    "conv2d_transpose": [["Input", "Filter"], ["Output"]],
+    "mul": [["X", "Y"], ["Out"]],
+    "matmul": [["X", "Y"], ["Out"]],
+    "matmul_v2": [["X", "Y"], ["Out"]],
+    "pool2d": [["X"], ["Out"]],
+    "elementwise_add": [["X", "Y"], ["Out"]],
+    "concat": [["X"], ["Out"]],
+    "softmax": [["X"], ["Out"]],
+    "argmax": [["X"], ["Out"]],
+    "transpose": [["X"], ["Out"]],
+    "equal": [["X", "Y"], ["Out"]],
+    "gather": [["X"], ["Out"]],
+    "greater_equal": [["X", "Y"], ["Out"]],
+    "greater_than": [["X", "Y"], ["Out"]],
+    "less_equal": [["X", "Y"], ["Out"]],
+    "less_than": [["X", "Y"], ["Out"]],
+    "mean": [["X"], ["Out"]],
+    "not_equal": [["X", "Y"], ["Out"]],
+    "reshape": [["X"], ["Out"]],
+    "reshape2": [["X"], ["Out"]],
+    "transpose2": [["X"], ["Out"]],
+    "bilinear_interp": [["X"], ["Out"]],
+    "nearest_interp": [["X"], ["Out"]],
+    "trilinear_interp": [["X"], ["Out"]],
+    "slice": [["Input"], ["Out"]],
+    "squeeze": [["X"], ["Out"]],
+    "elementwise_sub": [["X", "Y"], ["Out"]],
+    "relu": [["X"], ["Out"]],
+    "relu6": [["X"], ["Out"]],
+    "leaky_relu": [["X"], ["Out"]],
+    "prelu": [["X", "Alpha"], ["Out"]],
+    "tanh": [["X"], ["Out"]],
+    "swish": [["X"], ["Out"]],
+    "dropout": [["X"], ["Out"]],
+    "batch_norm": [["X"], ["Y"]],
+    "layer_norm": [["X"], ["Y"]],
+    "sigmoid": [["X"], ["Out"]],
+    "elementwise_mul": [["X", "Y"], ["Out"]],
+    "elementwise_pow": [["X", "Y"], ["Out"]],
+    "scale": [["X"], ["Out"]],
+    "hard_swish": [["X"], ["Out"]],
+    "hard_sigmoid": [["X"], ["Out"]],
+    "gru": [["Input", "Weight"], ["Hidden"]],
+    "lstm": [["Input", "Weight"], ["Hidden"]],
+    "pad2d": [["X"], ["Out"]],
+    "pad3d": [["X"], ["Out"]],
+    "flatten": [["X"], ["Out"]],
+    "flatten2": [["X"], ["Out"]],
+    "unsqueeze2": [["X"], ["Out"]],
+    "flatten_contiguous_range": [["X"], ["Out"]],
+    "split": [["X"], ["Out"]],
+    "squeeze2": [["X"], ["Out"]],
+    "nearest_interp_v2": [["X"], ["Out"]],
+    "bilinear_interp_v2": [["X"], ["Out"]],
+    "fill_constant_batch_size_like": [["Input"], ["Out"]],
+    "arg_max": [["X"], ["Out"]],
+    "abs": [["X"], ["Out"]],
+    "assign": [["X"], ["Out"]],
+    "cast": [["X"], ["Out"]],
+    "clip": [["X"], ["Out"]],
+    "box_coder": [["PriorBox"], ["OutputBox"]],
+    "crop": [["X"], ["Out"]],
+    "cumsum": [["X"], ["Out"]],
+    "expand_v2": [["X"], ["Out"]],
+    "fill_any_like": [["X"], ["Out"]],
+    "fill_constant": [[], ["Out"]],
+    "gelu": [["X"], ["Out"]],
+    "instance_norm": [["X"], ["Out"]],
+    "lookup_table": [["W", "Ids"], ["Out"]],
+    "lookup_table_v2": [["W", "Ids"], ["Out"]],
+    "norm": [["X"], ["Norm"]],
+    "p_norm": [["X"], ["Out"]],
+    "pow": [["X"], ["Out"]],
+    "reduce_mean": [["X"], ["Out"]],
+    "stack": [["X"], ["Y"]],
+    "top_k_v2": [["X"], ["Out", "Indices"]],
+    "logical_and": [["X", "Y"], ["Out"]],
+    "logical_not": [["X"], ["Out"]],
+    "meshgrid": [["X"], ["Out"]],
+    "roi_align": [["X", "ROIs"], ["Out"]],
+    "strided_slice": [["Input"], ["Out"]],
+    "where": [["Condition", "X", "Y"], ["Out"]],
+    "grid_sampler": [["X", "Grid"], ["Output"]],
+    "tile": [["X"], ["Out"]],
+    "group_norm": [["X"], ["Y", "Mean", "Variance"]],
+    "reduce_sum": [["X"], ["Out"]],
+    "square": [["X"], ["Out"]],
+    "softplus": [["X"], ["Out"]],
+    "shuffle_channel": [["X"], ["Out"]],
+}
+
+
+def _get_op_input_var_names(op):
+    """
+    Get the input var names of the op.
+    Args:
+        op(IrNode, Operator): the input op.
+    Returns:
+        input_var_names or None.
+    """
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    var_names = []
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    if op_name not in _op_real_in_out_name:
+        return []
+
+    name_list = _op_real_in_out_name[op_name][0]
+    for name in name_list:
+        var_name = op.input(name)
+        if isinstance(var_name, list):
+            var_names.extend(var_name)
+        else:
+            var_names.append(var_name)
+    return var_names
+
+
+def _get_op_output_var_names(op):
+    """ """
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    var_names = []
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    if op_name not in _op_real_in_out_name:
+        return []
+
+    name_list = _op_real_in_out_name[op_name][1]
+    for name in name_list:
+        var_name = op.output(name)
+        if isinstance(var_name, list):
+            var_names.extend(var_name)
+        else:
+            var_names.append(var_name)
+    return var_names
+
+
+def _get_input_name_index(op, input_var_name):
+    """Get the input name and index of the var_name in the op"""
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    if op_name not in _op_real_in_out_name:
+        return None
+
+    res = None
+    for argname in _op_real_in_out_name[op_name][0]:
+        var_names = op.input(argname)
+        for index, name in enumerate(var_names):
+            if name == input_var_name:
+                res = (argname, index)
+    return res
+
+
+def _get_output_name_index(op, output_var_name):
+    """Get the output name and index of the var_name in the op"""
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    if op_name not in _op_real_in_out_name:
+        return None
+
+    name_list = _op_real_in_out_name[op_name][1]
+    res = None
+    for name in name_list:
+        var_name = op.output(name)
+        for index, val in enumerate(var_name):
+            if val == output_var_name:
+                res = (name, index)
+    return res
+
+
 def load_variable_data(scope, var_name):
     '''
...
@@ -84,6 +365,46 @@ def dequant_tensor(x, scale, quant_axis=0, weight_bits=8):
     return x


+def bias_correction_w(x, x_quant, scale_v, quant_axis, weight_bits=8):
+    '''
+    Bias correction for weight
+    '''
+    eps = 1e-8
+    bnt = (1 << (weight_bits - 1)) - 1
+    x_dequant = x_quant.copy()
+    if isinstance(scale_v, list):
+        if quant_axis == 0:
+            for i, s in enumerate(scale_v):
+                x_dequant[i] = x_dequant[i] * s / bnt
+            quant_bias = x - x_dequant
+            mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1)
+            std_orig = x.reshape(x.shape[0], -1).std(-1)
+            std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1)
+            std_bias = std_orig / (std_quant + eps)
+        else:
+            for i, s in enumerate(scale_v):
+                x_dequant[:, i] = x_quant[:, i] * s / bnt
+            quant_bias = x - x_dequant
+            mean_bias = np.array(
+                [quant_bias[:, i].mean() for i in range(quant_bias.shape[1])])
+            std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
+            std_quant = np.array(
+                [x_dequant[:, i].std() for i in range(x_dequant.shape[1])])
+            std_bias = std_orig / (std_quant + eps)
+    else:
+        x_dequant = x_quant * scale_v / bnt
+        mean_bias = (x - x_dequant).mean()
+        std_bias = x.std() / (x_dequant.std() + eps)
+    if mean_bias.ndim == 1:
+        std_bias = np.resize(std_bias, x.shape)
+        mean_bias = np.resize(mean_bias, x.shape)
+
+    x_dequant = (mean_bias + x_dequant) * std_bias
+    quantized_param_v = quant_tensor(x_dequant, scale_v, quant_axis,
+                                     weight_bits)
+    return quantized_param_v
+
+
+def stable_sigmoid(x):
+    sig = np.where(x < 0, np.exp(x) / (1 + np.exp(x)), 1 / (1 + np.exp(-x)))
+    return sig
...
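A hedged usage sketch of the helpers added above: it builds a tiny static program, then asks each quantizable op for its real input and output variable names. The import path assumes the module location used in this diff; the fc layer is only a convenient way to get a "mul" op into the program.

import paddle
paddle.enable_static()

from paddle.fluid.contrib.slim.quantization.utils import (
    _get_op_input_var_names, _get_op_output_var_names, _out_scale_op_list)

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[None, 16], dtype='float32')
    y = paddle.static.nn.fc(x, size=10)

for op in main_prog.global_block().ops:
    if op.type in _out_scale_op_list:
        # Prints e.g. the X/Y inputs and Out output of the underlying mul op.
        print(op.type, _get_op_input_var_names(op), _get_op_output_var_names(op))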
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
...
@@ -53,7 +53,9 @@ class TestImperativeQat(unittest.TestCase):
     def set_vars(self):
         self.weight_quantize_type = 'abs_max'
         self.activation_quantize_type = 'moving_average_abs_max'
         print('weight_quantize_type', self.weight_quantize_type)
+        self.onnx_format = False
+        self.check_export_model_accuracy = True
         self.diff_threshold = 0.01

     def func_qat(self):
         self.set_vars()
...
@@ -159,9 +161,13 @@ class TestImperativeQat(unittest.TestCase):
             data = next(test_reader())
             test_data = np.array([x[0].reshape(1, 28, 28)
                                   for x in data]).astype('float32')
+            y_data = np.array(
+                [x[1] for x in data]).astype('int64').reshape(-1, 1)
             test_img = fluid.dygraph.to_variable(test_data)
+            label = fluid.dygraph.to_variable(y_data)
             lenet.eval()
-            before_save = lenet(test_img)
+            fp32_out = lenet(test_img)
+            fp32_acc = fluid.layers.accuracy(fp32_out, label).numpy()

             with tempfile.TemporaryDirectory(prefix="qat_save_path_") as tmpdir:
                 # save inference quantized model
...
@@ -171,7 +177,8 @@ class TestImperativeQat(unittest.TestCase):
                     input_spec=[
                         paddle.static.InputSpec(
                             shape=[None, 1, 28, 28], dtype='float32')
-                    ])
+                    ],
+                    onnx_format=self.onnx_format)
                 print('Quantized model saved in %s' % tmpdir)

                 if core.is_compiled_with_cuda():
...
@@ -185,13 +192,15 @@ class TestImperativeQat(unittest.TestCase):
                     executor=exe,
                     model_filename="lenet" + INFER_MODEL_SUFFIX,
                     params_filename="lenet" + INFER_PARAMS_SUFFIX)
-                after_save, = exe.run(inference_program,
-                                      feed={feed_target_names[0]: test_data},
-                                      fetch_list=fetch_targets)
-                # check
-                self.assertTrue(
-                    np.allclose(after_save, before_save.numpy()),
-                    msg='Failed to save the inference quantized model.')
+                quant_out, = exe.run(inference_program,
+                                     feed={feed_target_names[0]: test_data},
+                                     fetch_list=fetch_targets)
+                paddle.disable_static()
+                quant_out = fluid.dygraph.to_variable(quant_out)
+                quant_acc = fluid.layers.accuracy(quant_out, label).numpy()
+                paddle.enable_static()
+                delta_value = fp32_acc - quant_acc
+                self.assertLess(delta_value, self.diff_threshold)

     def test_qat(self):
         with _test_eager_guard():
...
@@ -199,5 +208,13 @@
             self.func_qat()


+class TestImperativeQatONNXFormat(unittest.TestCase):
+    def set_vars(self):
+        self.weight_quantize_type = 'abs_max'
+        self.activation_quantize_type = 'moving_average_abs_max'
+        self.onnx_format = True
+        self.diff_threshold = 0.025
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py
...
@@ -41,6 +41,17 @@ class TestImperativeQatChannelWise(TestImperativeQat):
     def set_vars(self):
         self.weight_quantize_type = 'channel_wise_abs_max'
         self.activation_quantize_type = 'moving_average_abs_max'
         self.diff_threshold = 0.01
+        self.onnx_format = False
         print('weight_quantize_type', self.weight_quantize_type)


+class TestImperativeQatChannelWiseONNXFormat(TestImperativeQat):
+    def set_vars(self):
+        self.weight_quantize_type = 'channel_wise_abs_max'
+        self.activation_quantize_type = 'moving_average_abs_max'
+        self.onnx_format = True
+        self.diff_threshold = 0.025
+        print('weight_quantize_type', self.weight_quantize_type)
+
...
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_lstm_model.py
...
@@ -173,7 +173,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                   is_use_cache_file=False,
                                   is_optimize_model=False,
                                   batch_size=10,
-                                  batch_nums=10):
+                                  batch_nums=10,
+                                  onnx_format=False):

         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
...
@@ -190,14 +191,28 @@ class TestPostTrainingQuantization(unittest.TestCase):
             round_type=round_type,
             is_full_quantize=is_full_quantize,
             optimize_model=is_optimize_model,
+            onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model_path)

-    def run_test(self, model_name, model_url, model_md5, data_name, data_url,
-                 data_md5, algo, round_type, quantizable_op_type,
-                 is_full_quantize, is_use_cache_file, is_optimize_model,
-                 diff_threshold, infer_iterations, quant_iterations):
+    def run_test(self,
+                 model_name,
+                 model_url,
+                 model_md5,
+                 data_name,
+                 data_url,
+                 data_md5,
+                 algo,
+                 round_type,
+                 quantizable_op_type,
+                 is_full_quantize,
+                 is_use_cache_file,
+                 is_optimize_model,
+                 diff_threshold,
+                 infer_iterations,
+                 quant_iterations,
+                 onnx_format=False):
         fp32_model_path = self.download_model(model_url, model_md5, model_name)
         fp32_model_path = os.path.join(fp32_model_path, model_name)
...
@@ -211,10 +226,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
         print("Start post training quantization for {0} on {1} samples ..."
               .format(model_name, quant_iterations))
-        self.generate_quantized_model(fp32_model_path, data_path, algo,
-                                      round_type, quantizable_op_type,
-                                      is_full_quantize, is_use_cache_file,
-                                      is_optimize_model, quant_iterations)
+        self.generate_quantized_model(fp32_model_path, data_path, algo,
+                                      round_type, quantizable_op_type,
+                                      is_full_quantize, is_use_cache_file,
+                                      is_optimize_model, quant_iterations,
+                                      onnx_format)

         print("Start INT8 inference for {0} on {1} samples ..."
               .format(model_name, infer_iterations))
...
@@ -278,5 +293,42 @@ class TestPostTrainingKLForMnistAdaround(TestPostTrainingQuantization):
                       diff_threshold, infer_iterations, quant_iterations)


+class TestPostTrainingKLForMnistONNXFormat(TestPostTrainingQuantization):
+    def test_post_training_kl_onnx_format(self):
+        model_name = "nlp_lstm_fp32_model"
+        model_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/nlp_lstm_fp32_model.tar.gz"
+        model_md5 = "519b8eeac756e7b4b7bcb2868e880452"
+        data_name = "quant_lstm_input_data"
+        data_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/quant_lstm_input_data.tar.gz"
+        data_md5 = "add84c754e9b792fea1fbd728d134ab7"
+        algo = "KL"
+        round_type = "round"
+        quantizable_op_type = ["mul", "lstm"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = False
+        diff_threshold = 0.01
+        infer_iterations = 100
+        quant_iterations = 10
+        onnx_format = True
+        self.run_test(model_name, model_url, model_md5, data_name, data_url,
+                      data_md5, algo, round_type, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, infer_iterations, quant_iterations,
+                      onnx_format=onnx_format)
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mnist.py
...
@@ -116,7 +116,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                   is_use_cache_file=False,
                                   is_optimize_model=False,
                                   batch_size=10,
-                                  batch_nums=10):
+                                  batch_nums=10,
+                                  onnx_format=False):

         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
...
@@ -134,6 +135,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
             round_type=round_type,
             is_full_quantize=is_full_quantize,
             optimize_model=is_optimize_model,
+            onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model_path)
...
@@ -151,7 +153,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                  diff_threshold,
                  batch_size=10,
                  infer_iterations=10,
-                 quant_iterations=5):
+                 quant_iterations=5,
+                 onnx_format=False):

         origin_model_path = self.download_model(data_url, data_md5, model_name)
         origin_model_path = os.path.join(origin_model_path, model_name)
...
@@ -166,7 +169,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.generate_quantized_model(origin_model_path, algo, round_type,
                                       quantizable_op_type, is_full_quantize,
                                       is_use_cache_file, is_optimize_model,
-                                      batch_size, quant_iterations)
+                                      batch_size, quant_iterations, onnx_format)

         print("Start INT8 inference for {0} on {1} images ..."
               .format(model_name, infer_iterations * batch_size))
...
@@ -335,5 +338,72 @@ class TestPostTrainingmseAdaroundForMnist(TestPostTrainingQuantization):
                       infer_iterations, quant_iterations)


+class TestPostTrainingmseForMnistONNXFormat(TestPostTrainingQuantization):
+    def test_post_training_mse_onnx_format(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "mse"
+        round_type = "round"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        onnx_format = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, round_type,
+                      quantizable_op_type, is_full_quantize, is_use_cache_file,
+                      is_optimize_model, diff_threshold, batch_size,
+                      infer_iterations, quant_iterations,
+                      onnx_format=onnx_format)
+
+
+class TestPostTrainingmseForMnistONNXFormatFullQuant(
+        TestPostTrainingQuantization):
+    def test_post_training_mse_onnx_format_full_quant(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "mse"
+        round_type = "round"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = True
+        is_use_cache_file = False
+        is_optimize_model = False
+        onnx_format = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, round_type,
+                      quantizable_op_type, is_full_quantize, is_use_cache_file,
+                      is_optimize_model, diff_threshold, batch_size,
+                      infer_iterations, quant_iterations,
+                      onnx_format=onnx_format)
+
+
 if __name__ == '__main__':
     unittest.main()
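For reference, this is roughly the post-training-quantization flow the tests above exercise, written as a hedged standalone sketch. The model directory is a placeholder (the sketch will not run without an actual FP32 inference model there), and the calibration reader is a dummy random-data generator standing in for real samples.

import numpy as np
import paddle
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

paddle.enable_static()
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)

def sample_reader():
    # Placeholder calibration data; real use would read the MNIST samples.
    for _ in range(100):
        yield [np.random.random((1, 28, 28)).astype("float32")]

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./mnist_model",          # placeholder FP32 inference model path
    sample_generator=sample_reader,
    batch_size=10,
    batch_nums=10,
    algo="mse",
    round_type="round",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    onnx_format=True,                   # the new-format flag added in this commit
    is_use_cache_file=False)
ptq.quantize()
ptq.save_quantized_model("./mnist_int8_model")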
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py
...
@@ -243,7 +243,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
                                  round_type="round",
                                  is_full_quantize=False,
                                  is_use_cache_file=False,
-                                 is_optimize_model=False):
+                                 is_optimize_model=False,
+                                 onnx_format=False):
        try:
            os.system("mkdir " + self.int8_model)
        except Exception as e:
...
@@ -265,13 +266,23 @@ class TestPostTrainingQuantization(unittest.TestCase):
             round_type=round_type,
             is_full_quantize=is_full_quantize,
             optimize_model=is_optimize_model,
+            onnx_format=onnx_format,
             is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model)

-    def run_test(self, model, algo, round_type, data_urls, data_md5s,
-                 quantizable_op_type, is_full_quantize, is_use_cache_file,
-                 is_optimize_model, diff_threshold):
+    def run_test(self,
+                 model,
+                 algo,
+                 round_type,
+                 data_urls,
+                 data_md5s,
+                 quantizable_op_type,
+                 is_full_quantize,
+                 is_use_cache_file,
+                 is_optimize_model,
+                 diff_threshold,
+                 onnx_format=False):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
...
@@ -285,9 +296,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
         print("Start INT8 post training quantization for {0} on {1} images ..."
               .format(model, sample_iterations * batch_size))
-        self.generate_quantized_model(model_cache_folder + "/model",
-                                      quantizable_op_type, algo, round_type,
-                                      is_full_quantize, is_use_cache_file,
-                                      is_optimize_model)
+        self.generate_quantized_model(model_cache_folder + "/model",
+                                      quantizable_op_type, algo, round_type,
+                                      is_full_quantize, is_use_cache_file,
+                                      is_optimize_model, onnx_format)

         print("Start INT8 inference for {0} on {1} images ..."
               .format(model, infer_iterations * batch_size))
...
@@ -517,5 +529,38 @@ class TestPostTrainingEMDForMobilenetv1(TestPostTrainingQuantization):
                       is_optimize_model, diff_threshold)


+class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_onnx_format_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "avg"
+        round_type = "round"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        onnx_format = True
+        diff_threshold = 0.05
+        self.run_test(model, algo, round_type, data_urls, data_md5s,
+                      quantizable_op_type, is_full_quantize, is_use_cache_file,
+                      is_optimize_model, diff_threshold,
+                      onnx_format=onnx_format)
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_resnet50.py
...
@@ -39,5 +39,34 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
                       is_optimize_model, diff_threshold)


+class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
+    def test_post_training_resnet50(self):
+        model = "ResNet-50"
+        algo = "min_max"
+        round_type = "round"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
+        ]
+        data_md5s = ['4a5194524823d9b76da6e738e1367881']
+        quantizable_op_type = ["conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = False
+        diff_threshold = 0.025
+        onnx_format = True
+        self.run_test(model, algo, round_type, data_urls, data_md5s,
+                      quantizable_op_type, is_full_quantize, is_use_cache_file,
+                      is_optimize_model, diff_threshold,
+                      onnx_format=onnx_format)
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...
@@ -21,6 +21,7 @@ import six
 import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
 from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
...
@@ -686,5 +687,129 @@ class TestAddQuantDequantPass(unittest.TestCase):
             for_ci=True)


+class TestQuantizationTransformPassV2(unittest.TestCase):
+    def setUp(self):
+        self.quantizable_op_and_inputs = {
+            'conv2d': ['Input', 'Filter'],
+            'depthwise_conv2d': ['Input', 'Filter'],
+            'mul': ['X', 'Y']
+        }
+        self.quantizable_grad_op_inputs = {
+            'conv2d_grad': ['Input', 'Filter'],
+            'depthwise_conv2d_grad': ['Input', 'Filter'],
+            'mul_grad': ['X', 'Y']
+        }
+
+    def check_program(self, program):
+        quantized_ops = set()
+        for block in program.blocks:
+            for op in block.ops:
+                # check forward
+                if op.type in self.quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        quantized_ops.add(arg_name)
+
+            for op in block.ops:
+                # check backward
+                if op.type in self.quantizable_grad_op_inputs:
+                    for pname in self.quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        self.assertTrue(arg_name in quantized_ops)
+
+    def linear_fc_quant(self,
+                        activation_quant_type,
+                        weight_quantize_type,
+                        for_ci=True):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = linear_fc(3)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        place = fluid.CPUPlace()
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPassV2(
+            scope=fluid.global_scope(),
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quantize_type)
+        transform_pass.apply(graph)
+        if not for_ci:
+            marked_nodes = set()
+            for op in graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            graph.draw('.', 'quantize_fc_' + activation_quant_type,
+                       marked_nodes)
+        program = graph.to_program()
+        self.check_program(program)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        if not for_ci:
+            val_marked_nodes = set()
+            for op in val_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    val_marked_nodes.add(op)
+            val_graph.draw('.', 'val_fc_' + activation_quant_type,
+                           val_marked_nodes)
+
+    def test_linear_fc_quant_abs_max(self):
+        self.linear_fc_quant('abs_max', 'abs_max', for_ci=True)
+
+    def test_linear_fc_quant_channel_wise_abs_max(self):
+        self.linear_fc_quant('abs_max', 'channel_wise_abs_max', for_ci=True)
+
+    def residual_block_quant(self,
+                             activation_quant_type,
+                             weight_quantize_type,
+                             quantizable_op_type,
+                             for_ci=True):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = residual_block(2)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        place = fluid.CPUPlace()
+        graph = IrGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quantize_type,
+            quantizable_op_type=quantizable_op_type)
+        transform_pass.apply(graph)
+        if not for_ci:
+            marked_nodes = set()
+            for op in graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            graph.draw('.', 'quantize_residual_' + activation_quant_type,
+                       marked_nodes)
+        program = graph.to_program()
+        self.check_program(program)
+        val_graph = IrGraph(core.Graph(program.desc), for_test=False)
+        if not for_ci:
+            val_marked_nodes = set()
+            for op in val_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    val_marked_nodes.add(op)
+            val_graph.draw('.', 'val_residual_' + activation_quant_type,
+                           val_marked_nodes)
+
+    def test_residual_block_abs_max(self):
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            'abs_max', 'abs_max', quantizable_op_type, for_ci=True)
+
+    def test_residual_block_channel_wise_abs_max(self):
+        quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul', 'matmul']
+        self.residual_block_quant(
+            'abs_max', 'channel_wise_abs_max', quantizable_op_type, for_ci=True)
+
+
 if __name__ == '__main__':
     unittest.main()
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
...
@@ -172,5 +172,83 @@ class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp):
         self.data_type = "float32"


+class TestChannelWiseDequantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 0
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "dequantize_linear"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
+                                                  self.quant_axis)
+        ydq = channel_wise_dequantize_max_abs(yq, scale, self.bit_length,
+                                              self.quant_axis)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        print('TestChannelWiseDequantizeOp:')
+        self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestChannelWiseDequantizeOp1(TestChannelWiseDequantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 1
+
+
+class TestDequantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.quant_axis = -1
+        self.max_range = math.pow(2, self.bit_length - 1) - 1
+        self.data_type = "float32"
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "dequantize_linear"
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+        ydq = dequantize_max_abs(yq, scale, self.max_range)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        self.inputs = {'X': yq, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': ydq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDequantizeOpDouble(TestDequantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.max_range = math.pow(2, self.bit_length - 1) - 1
+        self.data_type = "float64"
+        self.quant_axis = -1
+
+
+class TestDequantizeOp5Bits(TestDequantizeOp):
+    def set_args(self):
+        self.bit_length = 5
+        self.max_range = math.pow(2, self.bit_length - 1) - 1
+        self.data_type = "float32"
+        self.quant_axis = -1
+
+
 if __name__ == "__main__":
     unittest.main()
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
...
@@ -16,6 +16,7 @@ from __future__ import print_function

 import unittest
 import numpy as np
+import math
 from op_test import OpTest
 import paddle.fluid.core as core
...
@@ -374,5 +375,144 @@ class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
         self.inputs = {
             'X': np.random.random((30, 15)).astype("float32"),
         }


+def quantize_max_abs(x, max_range):
+    scale = np.max(np.abs(x).flatten())
+    y = np.round(x / scale * max_range)
+    return y, scale
+
+
+def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
+    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
+    scales = []
+    y = x.copy()
+    max_range = math.pow(2, quant_bit - 1) - 1
+    if quant_axis == 0:
+        for i in range(x.shape[0]):
+            scale = np.max(np.abs(x[i])).astype("float32")
+            scales.append(scale)
+            y[i] = np.round(x[i] * max_range / scale)
+    elif quant_axis == 1:
+        for i in range(x.shape[1]):
+            scale = np.max(np.abs(x[:, i])).astype("float32")
+            scales.append(scale)
+            y[:, i] = np.round(x[:, i] * max_range / scale)
+    return y, scales
+
+
+class TestChannelWiseQuantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 0
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
+                                                  self.quant_axis)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis
+        }
+        self.outputs = {'Y': yq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 1
+
+
+class TestChannelWiseQuantizeOpTrain(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.data_type = "float32"
+        self.quant_axis = 0
+        self.is_test = False
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
+        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
+                                                  self.quant_axis)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis,
+            'is_test': self.is_test
+        }
+        self.outputs = {'Y': yq, 'OutScale': scale}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestquantizeOp(OpTest):
+    def set_args(self):
+        self.bit_length = 8
+        self.quant_axis = -1
+        self.max_range = math.pow(2, self.bit_length - 1) - 1
+        self.data_type = "float32"
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis,
+        }
+        self.outputs = {'Y': yq}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestquantizeOpTrain(TestquantizeOp):
+    def set_args(self):
+        self.bit_length = 8
+        self.quant_axis = -1
+        self.max_range = math.pow(2, self.bit_length - 1) - 1
+        self.data_type = "float32"
+        self.is_test = False
+
+    def setUp(self):
+        self.set_args()
+        self.op_type = "quantize_linear"
+        x = np.random.randn(31, 65).astype(self.data_type)
+        yq, scale = quantize_max_abs(x, self.max_range)
+        scale = np.array(scale).astype(self.data_type)
+        zero_point = np.zeros(scale.shape, dtype="int32")
+        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
+        self.attrs = {
+            'bit_length': self.bit_length,
+            'quant_axis': self.quant_axis,
+            'is_test': self.is_test
+        }
+        self.outputs = {'Y': yq, 'OutScale': scale}
+
+    def test_check_output(self):
+        self.check_output()
+
+
 if __name__ == "__main__":
     unittest.main()
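A quick sanity check for the channel_wise_quantize_max_abs reference helper defined above; it is not part of the test file, and it assumes it runs in the same module so the helper (and numpy/math) are in scope. It verifies that the per-channel reconstruction error stays within half a quantization step.

x = np.random.randn(4, 3, 16, 16).astype("float32")
y, scales = channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0)
max_range = math.pow(2, 8 - 1) - 1
recon = y * (np.array(scales).reshape(-1, 1, 1, 1) / max_range)
assert np.abs(x - recon).max() <= 0.5 * max(scales) / max_range + 1e-6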