Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
5d422287
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
5d422287
编写于
4月 19, 2022
作者:
L
Leo Chen
提交者:
GitHub
4月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add float16 to fake quantize/dequantize OP (#40664)
上级
7ce0ee69
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
222 addition
and
76 deletion
+222
-76
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+5
-2
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+10
-5
paddle/fluid/operators/fake_quantize_op.cu.h
paddle/fluid/operators/fake_quantize_op.cu.h
+34
-12
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
...n/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+58
-17
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+115
-40
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
5d422287
...
...
@@ -17,10 +17,13 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
float16
=
paddle
::
platform
::
float16
;
REGISTER_OP_CUDA_KERNEL
(
fake_dequantize_max_abs
,
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
float
>
,
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
double
>
);
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
double
>
,
ops
::
FakeDequantizeMaxAbsKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_dequantize_max_abs
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
float
>
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
double
>
);
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
double
>
,
ops
::
FakeChannelWiseDequantizeMaxAbsKernel
<
CUDA
,
float16
>
);
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
5d422287
...
...
@@ -19,17 +19,22 @@ namespace ops = paddle::operators;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
using
float16
=
paddle
::
platform
::
float16
;
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_abs_max
,
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_dequantize_abs_max
,
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeDequantizeAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CUDA
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_channel_wise_quantize_abs_max
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeChannelWiseQuantizeAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_range_abs_max
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float
>
);
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeRangeAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fake_quantize_moving_average_abs_max
,
ops
::
FakeQuantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
);
ops
::
FakeQuantizeMovingAverageAbsMaxKernel
<
CUDA
,
float
>
,
ops
::
FakeQuantizeMovingAverageAbsMaxKernel
<
CUDA
,
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
moving_average_abs_max_scale
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float
>
,
ops
::
MovingAverageAbsMaxScaleKernel
<
CUDA
,
float16
>
);
...
...
paddle/fluid/operators/fake_quantize_op.cu.h
浏览文件 @
5d422287
...
...
@@ -24,6 +24,16 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
QuantizeDataType
{
using
type
=
T
;
};
template
<
>
struct
QuantizeDataType
<
paddle
::
platform
::
float16
>
{
using
type
=
float
;
};
template
<
typename
T
>
__global__
void
FindAbsMaxKernel
(
const
T
*
in
,
const
int
n
,
T
*
out
)
{
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
...
...
@@ -87,10 +97,12 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
extern
__shared__
T
shared_max_data
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
T
local_max_data
=
T
(
0
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
tmp
=
fabs
(
in_c
[
i
]);
T
tmp
=
static_cast
<
T
>
(
fabs
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
in_c
[
i
])));
if
(
tmp
>
local_max_data
)
{
local_max_data
=
tmp
;
}
...
...
@@ -112,7 +124,8 @@ template <typename T>
__global__
void
FindChannelAbsMaxKernelQuantAxis1
(
const
T
*
in
,
const
int
n
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
extern
__shared__
T
shared_max_data
[];
extern
__shared__
char
*
shared_max_data_tmp
[];
auto
shared_max_data
=
reinterpret_cast
<
T
*>
(
shared_max_data_tmp
);
int
cout_wh_size
=
n
/
cin
;
int
wh_size
=
n
/
(
cin
*
cout
);
...
...
@@ -121,7 +134,8 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
const
T
*
in_current
=
in
+
tid
*
cout_wh_size
+
bid
*
wh_size
;
T
local_max_data
=
T
(
0
);
for
(
int
i
=
0
;
i
<
wh_size
;
i
++
)
{
T
tmp
=
fabs
(
in_current
[
i
]);
T
tmp
=
static_cast
<
T
>
(
fabs
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
in_current
[
i
])));
if
(
tmp
>
local_max_data
)
{
local_max_data
=
tmp
;
}
...
...
@@ -205,12 +219,14 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
T
s
=
scale
[
0
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out
[
i
]
=
round
(
v
);
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
}
}
...
...
@@ -230,7 +246,8 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
x
=
x
>
s
?
s
:
x
;
x
=
x
<
-
s
?
-
s
:
x
;
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
static_cast
<
T
>
(
round
(
static_cast
<
float
>
(
x
)));
x
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
x
)));
out
[
i
]
=
(
x
*
s
)
/
bin_cnt_t
;
}
}
...
...
@@ -287,13 +304,15 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
T
s
=
scale
[
blockIdx
.
x
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
for
(
int64_t
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out_c
[
i
]
=
round
(
v
);
v
=
bin_cnt_t
*
inv_s
*
v
;
out_c
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
}
}
...
...
@@ -303,14 +322,16 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int64_t
n
,
const
int
nScale
,
const
int
quant_stride
,
T
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
for
(
int64_t
i
=
idx
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
scale
[(
i
/
quant_stride
)
%
nScale
];
T
inv_s
=
inverse
(
s
);
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out
[
i
]
=
round
(
v
);
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
}
}
...
...
@@ -376,7 +397,8 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
scale_arr
[
idx
]
=
cur
;
T
max
=
last_scale
[
0
];
out_scale
[
0
]
=
max
<
cur
?
cur
:
max
;
if
(
fabs
(
removed
-
max
)
<
1e-6
)
{
if
(
fabs
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
removed
-
max
))
<
1e-6
)
{
need_find_max
[
0
]
=
1
;
out_size
[
0
]
=
it
>
window_size
?
window_size
:
it
;
}
else
{
...
...
python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
浏览文件 @
5d422287
...
...
@@ -18,6 +18,7 @@ import unittest
import
numpy
as
np
import
math
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
def
quantize_max_abs
(
x
,
max_range
):
...
...
@@ -76,22 +77,25 @@ def channel_wise_dequantize_max_abs(x,
class
TestFakeChannelWiseDequantizeMaxAbsOpTwoScales
(
OpTest
):
def
set_args
(
self
):
self
.
quant_bits
=
[
8
,
8
]
self
.
data_type
=
"float32"
self
.
activation_scale
=
0.7861
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
setUp
(
self
):
self
.
set_args
()
self
.
set_dtype
()
self
.
op_type
=
"fake_channel_wise_dequantize_max_abs"
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
d
ata_
type
)
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
dtype
)
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
],
1
)
ydq
=
channel_wise_dequantize_max_abs
(
yq
,
scales
,
self
.
quant_bits
,
1
,
self
.
activation_scale
)
self
.
inputs
=
{
'X'
:
yq
,
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
d
ata_
type
)),
(
"scales1"
,
np
.
array
(
[
self
.
activation_scale
]).
astype
(
self
.
data_
type
))]
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
dtype
)),
(
"scales1"
,
np
.
array
([
self
.
activation_scale
]).
astype
(
self
.
d
type
))]
}
self
.
attrs
=
{
'quant_bits'
:
self
.
quant_bits
}
self
.
outputs
=
{
'Out'
:
ydq
}
...
...
@@ -100,16 +104,28 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
self
.
check_output
()
class
TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16
(
TestFakeChannelWiseDequantizeMaxAbsOpTwoScales
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
class
TestFakeChannelWiseDequantizeMaxAbsOpOneScale
(
OpTest
):
def
set_args
(
self
):
self
.
quant_bits
=
[
8
]
self
.
data_type
=
"float32"
self
.
quant_axis
=
0
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
setUp
(
self
):
self
.
set_args
()
self
.
set_dtype
()
self
.
op_type
=
"fake_channel_wise_dequantize_max_abs"
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
d
ata_
type
)
x
=
np
.
random
.
randn
(
4
,
3
,
64
,
64
).
astype
(
self
.
dtype
)
yq
,
scales
=
channel_wise_quantize_max_abs
(
x
,
self
.
quant_bits
[
0
],
self
.
quant_axis
)
ydq
=
channel_wise_dequantize_max_abs
(
yq
,
scales
,
self
.
quant_bits
,
...
...
@@ -117,7 +133,7 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest):
self
.
inputs
=
{
'X'
:
yq
,
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
d
ata_
type
))]
'Scales'
:
[(
"scales0"
,
np
.
array
(
scales
).
astype
(
self
.
dtype
))]
}
self
.
attrs
=
{
'quant_bits'
:
self
.
quant_bits
,
...
...
@@ -133,24 +149,44 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale
):
def
set_args
(
self
):
self
.
quant_bits
=
[
8
]
self
.
data_type
=
"float32"
self
.
quant_axis
=
1
class
TestFakeChannelWiseDequantizeMaxAbsOpOneScaleFloat16
(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
class
TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16
(
TestFakeChannelWiseDequantizeMaxAbsOpOneScale1
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
class
TestFakeDequantizeMaxAbsOp
(
OpTest
):
def
set_args
(
self
):
self
.
num_bits
=
8
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float32"
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
setUp
(
self
):
self
.
set_args
()
self
.
set_dtype
()
self
.
op_type
=
"fake_dequantize_max_abs"
x
=
np
.
random
.
randn
(
31
,
65
).
astype
(
self
.
d
ata_
type
)
x
=
np
.
random
.
randn
(
31
,
65
).
astype
(
self
.
dtype
)
yq
,
scale
=
quantize_max_abs
(
x
,
self
.
max_range
)
ydq
=
dequantize_max_abs
(
yq
,
scale
,
self
.
max_range
)
self
.
inputs
=
{
'X'
:
yq
,
'Scale'
:
np
.
array
(
scale
).
astype
(
self
.
d
ata_
type
)}
self
.
inputs
=
{
'X'
:
yq
,
'Scale'
:
np
.
array
(
scale
).
astype
(
self
.
dtype
)}
self
.
attrs
=
{
'max_range'
:
self
.
max_range
}
self
.
outputs
=
{
'Out'
:
ydq
}
...
...
@@ -159,17 +195,22 @@ class TestFakeDequantizeMaxAbsOp(OpTest):
class
TestFakeDequantizeMaxAbsOpDouble
(
TestFakeDequantizeMaxAbsOp
):
def
set_args
(
self
):
self
.
num_bits
=
8
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float64"
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float64
class
TestFakeDequantizeMaxAbsOp5Bits
(
TestFakeDequantizeMaxAbsOp
):
def
set_args
(
self
):
self
.
num_bits
=
5
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"float32"
class
TestFakeDequantizeMaxAbsOpFloat16
(
TestFakeDequantizeMaxAbsOp
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
class
TestChannelWiseDequantizeOp
(
OpTest
):
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
5d422287
...
...
@@ -15,28 +15,51 @@
from
__future__
import
print_function
import
unittest
import
math
import
numpy
as
np
import
math
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
# numpy.round has different behavior in comparision to c++ round function
# so we use round_c instead of numpy.round to align the output data
def
round_c_single_element
(
x
):
dtype
=
type
(
x
)
if
x
>=
0
:
return
dtype
(
np
.
floor
(
x
+
0.5
))
else
:
return
dtype
(
np
.
ceil
(
x
-
0.5
))
round_c
=
np
.
vectorize
(
round_c_single_element
)
class
TestFakeQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
op_type
=
"fake_quantize_abs_max"
self
.
attrs
=
{
'bit_length'
:
8
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
124
,
240
)).
astype
(
"float32"
),
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
124
,
240
)).
astype
(
self
.
dtype
),
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
self
.
dtype
)
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
scale
*
(
'Out'
:
round_c
(
self
.
inputs
[
'X'
]
/
scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'OutScale'
:
np
.
array
(
scale
).
astype
(
"float32"
),
'OutScale'
:
np
.
array
(
scale
).
astype
(
self
.
dtype
),
}
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeOpFloat16
(
TestFakeQuantizeOp
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestFakeQuantizeOp1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fake_quantize_abs_max"
...
...
@@ -73,6 +96,7 @@ class TestFakeQuantizeOp2(OpTest):
class
TestFakeChannelWiseQuantizeOp
(
OpTest
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
set_arg
()
assert
self
.
quant_axis
in
[
0
,
1
],
"quant_axis should be 0 or 1."
...
...
@@ -84,53 +108,70 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
bnt
=
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
if
self
.
quant_axis
==
0
:
for
i
in
range
(
self
.
inputs
[
'X'
].
shape
[
0
]):
scale_v
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
][
i
])).
astype
(
"float32"
)
scale_v
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
][
i
])).
astype
(
self
.
dtype
)
scales
.
append
(
scale_v
)
outputs
[
i
]
=
np
.
round
(
outputs
[
i
]
/
scale_v
*
bnt
)
outputs
[
i
]
=
round_c
(
self
.
dtype
(
bnt
)
*
(
self
.
dtype
(
1.0
)
/
scale_v
)
*
outputs
[
i
])
elif
self
.
quant_axis
==
1
:
for
i
in
range
(
self
.
inputs
[
'X'
].
shape
[
1
]):
scale_v
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
][:,
i
])).
astype
(
"float32"
)
self
.
dtype
)
scales
.
append
(
scale_v
)
outputs
[:,
i
]
=
np
.
round
(
outputs
[:,
i
]
/
scale_v
*
bnt
)
outputs
[:,
i
]
=
round_c
(
self
.
dtype
(
bnt
)
*
(
self
.
dtype
(
1.0
)
/
scale_v
)
*
outputs
[:,
i
])
self
.
outputs
=
{
'Out'
:
outputs
,
'OutScale'
:
np
.
array
(
scales
).
astype
(
"float32"
),
'OutScale'
:
np
.
array
(
scales
).
astype
(
self
.
dtype
),
}
def
set_arg
(
self
):
self
.
quant_axis
=
0
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
20
,
15
,
6
,
6
)).
astype
(
"float32"
),
'X'
:
np
.
random
.
random
((
20
,
15
,
6
,
6
)).
astype
(
self
.
dtype
),
}
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeChannelWiseQuantizeOpFloat16
(
TestFakeChannelWiseQuantizeOp
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestFakeChannelWiseQuantizeOp1
(
TestFakeChannelWiseQuantizeOp
):
def
set_quant_axis
(
self
):
self
.
quant_axis
=
1
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
15
,
20
,
5
,
5
)).
astype
(
"float32"
),
'X'
:
np
.
random
.
random
((
15
,
20
,
5
,
5
)).
astype
(
self
.
dtype
),
}
class
TestFakeChannelWiseQuantizeOp1Float16
(
TestFakeChannelWiseQuantizeOp1
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestFakeChannelWiseQuantizeOp2
(
TestFakeChannelWiseQuantizeOp
):
def
set_quant_axis
(
self
):
self
.
quant_axis
=
0
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
30
,
15
)).
astype
(
"float32"
),
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
30
,
15
)).
astype
(
self
.
dtype
),
}
class
TestFakeChannelWiseQuantizeOp3
(
TestFakeChannelWiseQuantizeOp
):
def
set_quant_axis
(
self
):
self
.
quant_axis
=
1
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
30
,
15
)).
astype
(
"float32"
),
}
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
30
,
15
)).
astype
(
self
.
dtype
),
}
class
TestFakeQuantizeRangeAbsMaxOp
(
OpTest
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
op_type
=
"fake_quantize_range_abs_max"
self
.
attrs
=
{
'bit_length'
:
int
(
5
),
...
...
@@ -138,27 +179,36 @@ class TestFakeQuantizeRangeAbsMaxOp(OpTest):
'is_test'
:
False
}
x
=
(
np
.
random
.
random
((
8
,
16
,
7
,
7
))
-
0.5
)
*
10
x
=
x
.
astype
(
"float32"
)
x
=
x
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
x
,
'Iter'
:
np
.
zeros
(
1
).
astype
(
"int64"
),
'InScale'
:
np
.
zeros
(
1
).
astype
(
"float32"
)
'InScale'
:
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
}
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
scale
=
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
self
.
dtype
)
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
)
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
self
.
dtype
)
out_scales
[
0
]
=
scale
self
.
outputs
=
{
'Out'
:
np
.
round
(
self
.
inputs
[
'X'
]
/
scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)),
'Out'
:
round_c
(
self
.
dtype
((
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)
*
(
self
.
dtype
(
1.0
)
/
scale
)
*
self
.
inputs
[
'X'
]),
'OutScale'
:
scale
,
'OutScales'
:
out_scales
,
}
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFakeQuantizeRangeAbsMaxOpFloat16
(
TestFakeQuantizeRangeAbsMaxOp
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestMovingAverageAbsMaxScaleOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"moving_average_abs_max_scale"
...
...
@@ -195,6 +245,7 @@ class TestMovingAverageAbsMaxScaleOp(OpTest):
class
TestFakeQuantizeRangeAbsMaxOp2
(
OpTest
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
op_type
=
"fake_quantize_range_abs_max"
self
.
attrs
=
{
'bit_length'
:
int
(
8
),
...
...
@@ -202,55 +253,68 @@ class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
'is_test'
:
True
}
x
=
(
np
.
random
.
random
((
8
,
16
,
7
,
7
))
-
0.5
)
*
10
x
=
x
.
astype
(
"float32"
)
scale
=
np
.
array
([
np
.
max
(
np
.
abs
(
x
)).
astype
(
"float32"
)
-
1.0
])
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
"float32"
)
out_scales
[
0
]
=
scale
x
=
x
.
astype
(
self
.
dtype
)
scale
=
np
.
array
([
np
.
max
(
np
.
abs
(
x
)).
astype
(
self
.
dtype
)
-
1.0
])
out_scales
=
np
.
zeros
(
self
.
attrs
[
'window_size'
]).
astype
(
self
.
dtype
)
out_scales
[
0
]
=
scale
.
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
x
,
'Iter'
:
np
.
zeros
(
1
).
astype
(
"int64"
),
'InScale'
:
scale
.
astype
(
"float32"
)
'InScale'
:
scale
.
astype
(
self
.
dtype
)
}
xs
=
np
.
clip
(
x
,
-
scale
,
scale
)
qs
=
np
.
round
(
xs
/
scale
*
((
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
))
xs
=
np
.
clip
(
x
,
-
scale
,
scale
).
astype
(
self
.
dtype
)
qs
=
round_c
(
self
.
dtype
(
self
.
dtype
((
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
)
*
(
self
.
dtype
(
1.0
)
/
scale
)
*
xs
))
self
.
outputs
=
{
'Out'
:
qs
,
'OutScale'
:
scale
.
astype
(
"float32"
),
'OutScale'
:
scale
.
astype
(
self
.
dtype
),
'OutScales'
:
out_scales
,
}
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output
(
no_check_set
=
set
([
'OutScale'
,
'OutScales'
]))
class
TestFakeQuantizeRangeAbsMaxOp2Float16
(
TestFakeQuantizeRangeAbsMaxOp2
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
class
TestMovingOpBase
(
OpTest
):
def
setUp
(
self
):
self
.
set_dtype
()
self
.
init_type
()
self
.
attrs
=
{
'bit_length'
:
int
(
5
),
'moving_rate'
:
float
(
0.9
),
'is_test'
:
False
}
accum
=
np
.
zeros
(
1
).
astype
(
"float32"
)
accum
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
accum
[
0
]
=
1
state
=
np
.
zeros
(
1
).
astype
(
"float32"
)
state
[
0
]
=
1
scale
=
np
.
zeros
(
1
).
astype
(
"float32"
)
state
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
state
[
0
]
=
self
.
dtype
(
1.0
)
scale
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
scale
[
0
]
=
0.001
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
8
,
16
,
7
,
7
)).
astype
(
"float32"
),
'X'
:
np
.
random
.
random
((
8
,
16
,
7
,
7
)).
astype
(
self
.
dtype
),
'InScale'
:
scale
,
'InAccum'
:
accum
,
'InState'
:
state
,
}
out_accum
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_state
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_scale
=
np
.
zeros
(
1
).
astype
(
"float32"
)
out_accum
[
0
]
=
self
.
attrs
[
'moving_rate'
]
*
accum
[
0
]
+
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
"float32"
)
out_state
[
0
]
=
self
.
attrs
[
'moving_rate'
]
*
state
[
0
]
+
1
out_scale
=
out_accum
/
out_state
out_accum
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
out_state
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
out_scale
=
np
.
zeros
(
1
).
astype
(
self
.
dtype
)
out_accum
[
0
]
=
self
.
dtype
(
self
.
attrs
[
'moving_rate'
])
*
self
.
dtype
(
accum
[
0
])
+
np
.
max
(
np
.
abs
(
self
.
inputs
[
'X'
])).
astype
(
self
.
dtype
)
out_state
[
0
]
=
self
.
dtype
(
self
.
attrs
[
'moving_rate'
])
*
self
.
dtype
(
state
[
0
])
+
self
.
dtype
(
1.0
)
out_scale
=
self
.
dtype
(
self
.
dtype
(
out_accum
)
/
self
.
dtype
(
out_state
))
out_data
=
self
.
calc_output
(
out_scale
)
self
.
outputs
=
{
'Out'
:
out_data
,
...
...
@@ -259,17 +323,28 @@ class TestMovingOpBase(OpTest):
'OutScale'
:
out_scale
,
}
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
init_type
(
self
):
self
.
op_type
=
"fake_quantize_moving_average_abs_max"
def
calc_output
(
self
,
out_scale
):
return
np
.
round
(
self
.
inputs
[
'X'
]
/
out_scale
*
(
return
round_c
(
self
.
inputs
[
'X'
]
/
out_scale
*
(
(
1
<<
(
self
.
attrs
[
'bit_length'
]
-
1
))
-
1
))
def
test_check_output
(
self
):
self
.
check_output
()
class
TestMovingOpBaseFloat16
(
TestMovingOpBase
):
def
set_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-2
)
class
TestFakeQuantDequantMovingOp
(
TestMovingOpBase
):
def
init_type
(
self
):
self
.
op_type
=
"fake_quantize_dequantize_moving_average_abs_max"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录