Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3a40ac65
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3a40ac65
编写于
6月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2435 fix perchannel num_channels not set bug and adjust quant.py params order
Merge pull request !2435 from 王东旭/r0.3
上级
3e3cbbba
f110c761
变更
5
展开全部
显示空白变更内容
内联
并排
Showing
5 changed file
with
607 addition
and
407 deletion
+607
-407
mindspore/nn/layer/quant.py
mindspore/nn/layer/quant.py
+411
-126
mindspore/ops/_grad/grad_quant_ops.py
mindspore/ops/_grad/grad_quant_ops.py
+5
-5
mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py
...spore/ops/_op_impl/_custom_op/minmax_update_perchannel.py
+37
-46
mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py
mindspore/ops/_op_impl/_custom_op/minmax_update_perlayer.py
+28
-45
mindspore/ops/operations/_quant_ops.py
mindspore/ops/operations/_quant_ops.py
+126
-185
未找到文件。
mindspore/nn/layer/quant.py
浏览文件 @
3a40ac65
此差异已折叠。
点击以展开。
mindspore/ops/_grad/grad_quant_ops.py
浏览文件 @
3a40ac65
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Generate bprop for
aware quantization
ops"""
"""Generate bprop for
quantization aware
ops"""
from
..
import
operations
as
P
from
..operations
import
_quant_ops
as
Q
...
...
@@ -133,9 +133,9 @@ def get_bprop_batchnorm_fold2_(self):
return
bprop
@
bprop_getters
.
register
(
Q
.
FakeQuantMinMaxPerLayerUpdate
)
@
bprop_getters
.
register
(
Q
.
MinMaxUpdatePerLayer
)
def
get_bprop_fakequant_with_minmax_per_layer_update
(
self
):
"""Generate bprop for
FakeQuantMinMaxPerLayerUpdate
for Ascend"""
"""Generate bprop for
MinMaxUpdatePerLayer
for Ascend"""
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
return
zeros_like
(
x
),
zeros_like
(
x_min
),
zeros_like
(
x_max
)
...
...
@@ -143,9 +143,9 @@ def get_bprop_fakequant_with_minmax_per_layer_update(self):
return
bprop
@
bprop_getters
.
register
(
Q
.
FakeQuantMinMaxPerChannelUpdate
)
@
bprop_getters
.
register
(
Q
.
MinMaxUpdatePerChannel
)
def
get_bprop_fakequant_with_minmax_per_channel_update
(
self
):
"""Generate bprop for
FakeQuantMinMaxPerChannelUpdate
for Ascend"""
"""Generate bprop for
MinMaxUpdatePerChannel
for Ascend"""
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
return
zeros_like
(
x
),
zeros_like
(
x_min
),
zeros_like
(
x_max
)
...
...
mindspore/ops/_op_impl/_custom_op/
fake_quant_minmax_perchannel_update
.py
→
mindspore/ops/_op_impl/_custom_op/
minmax_update_perchannel
.py
浏览文件 @
3a40ac65
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -14,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
FakeQuantMinMaxPerChannelUpdate
op"""
"""
MinMaxUpdatePerChannel
op"""
import
te.lang.cce
from
te
import
tvm
from
te.platform.fusion_manager
import
fusion_manager
...
...
@@ -22,20 +21,15 @@ from topi import generic
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_min_max_per_channel_update_op_info
=
TBERegOp
(
"FakeQuantMinMaxPerChannelUpdate"
)
\
minmax_update_perchannel_op_info
=
TBERegOp
(
"MinMaxUpdatePerChannel"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"
fake_quant_min_max_per_channel_update
.so"
)
\
.
binfile_name
(
"
minmax_update_perchannel
.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"
fake_quant_min_max_per_channel_update
"
)
\
.
kernel_name
(
"
minmax_update_perchannel
"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"channel_axis"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
...
...
@@ -47,24 +41,27 @@ fake_quant_min_max_per_channel_update_op_info = TBERegOp("FakeQuantMinMaxPerChan
.
get_op_info
()
@
op_info_register
(
fake_quant_min_max_per_channel_update
_op_info
)
def
_
fake_quant_min_max_per_channel_update
_tbe
():
"""
FakeQuantPerChannelUpdate
TBE register"""
@
op_info_register
(
minmax_update_perchannel
_op_info
)
def
_
minmax_update_perchannel
_tbe
():
"""
MinMaxUpdatePerChannel
TBE register"""
return
@
fusion_manager
.
register
(
"fake_quant_min_max_per_channel_update"
)
def
fake_quant_min_max_per_channel_update_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis
,
kernel_name
=
"fake_quant_min_max_per_channel_update"
):
"""FakeQuantPerChannelUpdate compute"""
@
fusion_manager
.
register
(
"minmax_update_perchannel"
)
def
minmax_update_perchannel_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
channel_axis
):
"""MinMaxUpdatePerChannel compute"""
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
if
not
ema
:
ema_decay
=
0.0
if
training
:
# CalMinMax
if
channel_axis
==
0
:
axis
=
[
1
,
2
,
3
,
4
]
else
:
axis
=
[
0
,
2
,
3
]
x_min
=
te
.
lang
.
cce
.
reduce_min
(
x
,
axis
=
axis
)
x_max
=
te
.
lang
.
cce
.
reduce_max
(
x
,
axis
=
axis
)
x_min
=
te
.
lang
.
cce
.
broadcast
(
x_min
,
shape_min
)
...
...
@@ -79,11 +76,11 @@ def fake_quant_min_max_per_channel_update_compute(x, min_val, max_val,
return
[
min_val
,
max_val
]
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
int
,
str
)
def
fake_quant_min_max_per_channel_update
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
channel_axis
,
kernel_name
=
"fake_quant_min_max_per_channel_update
"
):
"""
FakeQuantPerLayer
op"""
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
int
,
str
)
def
minmax_update_perchannel
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
channel_axis
,
kernel_name
=
"minmax_update_perchannel
"
):
"""
MinMaxUpdatePerChannel
op"""
x_shape
=
x
.
get
(
"ori_shape"
)
x_format
=
x
.
get
(
"format"
)
x_dtype
=
x
.
get
(
"dtype"
)
...
...
@@ -112,21 +109,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
util
.
check_dtype_rule
(
min_dtype
,
check_list
)
util
.
check_dtype_rule
(
max_dtype
,
check_list
)
if
symmetric
:
quant_min
=
0
-
2
**
(
num_bits
-
1
)
quant_max
=
2
**
(
num_bits
-
1
)
-
1
if
channel_axis_
==
0
:
shape_c
=
min_val
.
get
(
"ori_shape"
)
else
:
quant_min
=
0
quant_max
=
2
**
num_bits
-
1
if
narrow_range
:
quant_min
=
quant_min
+
1
shape_c
=
[
min_val
.
get
(
"shape"
)[
1
],
min_val
.
get
(
"shape"
)[
-
1
]]
input_data
=
tvm
.
placeholder
(
x
.
get
(
"shape"
),
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
res_list
=
fake_quant_min_max_per_channel_update
_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis_
,
kernel_name
)
res_list
=
minmax_update_perchannel
_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
channel_axis_
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
...
...
mindspore/ops/_op_impl/_custom_op/
fake_quant_minmax_perlayer_update
.py
→
mindspore/ops/_op_impl/_custom_op/
minmax_update_perlayer
.py
浏览文件 @
3a40ac65
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
FakeQuantMinMaxPerLayerUpdate
op"""
"""
MinMaxUpdatePerLayer
op"""
from
functools
import
reduce
as
functools_reduce
import
te.lang.cce
from
te
import
tvm
...
...
@@ -22,20 +22,15 @@ from topi import generic
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_minmax_update_op_info
=
TBERegOp
(
"FakeQuantMinMaxPerLayerUpdate"
)
\
minmax_update_perlayer_op_info
=
TBERegOp
(
"MinMaxUpdatePerLayer"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"
fake_quant_minmax_update
.so"
)
\
.
binfile_name
(
"
minmax_update_perlayer
.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"
fake_quant_minmax_update
"
)
\
.
kernel_name
(
"
minmax_update_perlayer
"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
...
...
@@ -46,23 +41,22 @@ fake_quant_minmax_update_op_info = TBERegOp("FakeQuantMinMaxPerLayerUpdate") \
.
get_op_info
()
@
op_info_register
(
fake_quant_minmax_update
_op_info
)
def
_
fake_quant_minmax_update
_tbe
():
"""
FakeQuantMinMaxPerLayerUpdate
TBE register"""
@
op_info_register
(
minmax_update_perlayer
_op_info
)
def
_
minmax_update_perlayer
_tbe
():
"""
MinMaxUpdatePerLayer
TBE register"""
return
@
fusion_manager
.
register
(
"fake_quant_minmax_update"
)
def
fake_quant_minmax_update_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
kernel_name
=
"fake_quant_minmax_update"
):
"""FakeQuantMinMaxPerLayerUpdate compute"""
@
fusion_manager
.
register
(
"minmax_update_perlayer"
)
def
minmax_update_perlayer_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
):
"""MinMaxUpdatePerLayer compute"""
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
min_val
=
te
.
lang
.
cce
.
broadcast
(
min_val
,
shape_min
,
x
.
dtype
)
max_val
=
te
.
lang
.
cce
.
broadcast
(
max_val
,
shape_min
,
x
.
dtype
)
if
not
ema
:
ema_decay
=
0.0
if
training
:
# CalMinMax
axis
=
tuple
(
range
(
len
(
shape
)))
x_min
=
te
.
lang
.
cce
.
reduce_min
(
x
,
axis
=
axis
)
...
...
@@ -79,11 +73,10 @@ def fake_quant_minmax_update_compute(x, min_val, max_val, ema, ema_decay, quant_
return
[
min_val
,
max_val
]
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
str
)
def
fake_quant_minmax_update
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
kernel_name
=
"fake_quant_minmax_update"
):
"""FakeQuantPerLayer op"""
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
str
)
def
minmax_update_perlayer
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
kernel_name
=
"minmax_update_perlayer"
):
"""MinMaxUpdatePerLayer op"""
input_shape
=
x
.
get
(
"shape"
)
input_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
...
...
@@ -112,20 +105,10 @@ def fake_quant_minmax_update(x, min_val, max_val, min_up, max_up,
input_shape
=
(
functools_reduce
(
lambda
x
,
y
:
x
*
y
,
input_shape
[:]),)
shape_min
,
_
,
_
=
util
.
produce_shapes
(
min_shape
,
input_shape
)
if
symmetric
:
quant_min
=
0
-
2
**
(
num_bits
-
1
)
quant_max
=
2
**
(
num_bits
-
1
)
-
1
else
:
quant_min
=
0
quant_max
=
2
**
num_bits
-
1
if
narrow_range
:
quant_min
=
quant_min
+
1
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
res_list
=
fake_quant_minmax_update_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
kernel_name
)
res_list
=
minmax_update_perlayer_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
...
...
mindspore/ops/operations/_quant_ops.py
浏览文件 @
3a40ac65
...
...
@@ -21,12 +21,12 @@ from ..._checkparam import Rel
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
from
...common
import
dtype
as
mstype
__all__
=
[
"FakeQuantPerLayer"
,
__all__
=
[
"MinMaxUpdatePerLayer"
,
"MinMaxUpdatePerChannel"
,
"FakeQuantPerLayer"
,
"FakeQuantPerLayerGrad"
,
"FakeQuantPerChannel"
,
"FakeQuantPerChannelGrad"
,
"FakeQuantMinMaxPerLayerUpdate"
,
"FakeQuantMinMaxPerChannelUpdate"
,
"BatchNormFold"
,
"BatchNormFoldGrad"
,
"CorrectionMul"
,
...
...
@@ -36,23 +36,140 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFold2Grad"
,
"BatchNormFoldD"
,
"BatchNormFoldGradD"
,
"BNTrainingReduce"
,
"BatchNormFold2_D"
,
"BatchNormFold2GradD"
,
"BatchNormFold2GradReduce"
,
"BatchNormFold2GradReduce"
]
class
MinMaxUpdatePerLayer
(
PrimitiveWithInfer
):
r
"""
Update min and max per layer.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = MinMaxUpdatePerLayer(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
def
__init__
(
self
,
ema
=
False
,
ema_decay
=
0.999
):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
from
mindspore.ops._op_impl._custom_op
import
minmax_update_perlayer
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
class
MinMaxUpdatePerChannel
(
PrimitiveWithInfer
):
r
"""
Update min and max per channel.
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
def
__init__
(
self
,
ema
=
False
,
ema_decay
=
0.999
,
channel_axis
=
1
):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
from
mindspore.ops._op_impl._custom_op
import
minmax_update_perchannel
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
channel_axis
=
validator
.
check_integer
(
'channel axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GT
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
(
{
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
class
FakeQuantPerLayer
(
PrimitiveWithInfer
):
r
"""
Simulate the quantize and dequantize operations in training time.
Args:
num_bits (int) : Number bits for
aware quantilization
. Default: 8.
num_bits (int) : Number bits for
quantization aware
. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate
aware quantiz
e funcion. After delay step in training time begin simulate the aware
simulate
quantization awar
e funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
...
...
@@ -334,7 +451,7 @@ class BatchNormFold(PrimitiveWithInfer):
Batch normalization folded.
Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.
1
.
momentum (float): Momentum value should be [0, 1]. Default: 0.
9
.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True.
...
...
@@ -366,7 +483,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel_axis
=
1
@
prim_attr_register
def
__init__
(
self
,
momentum
=
0.
1
,
epsilon
=
1e-5
,
is_training
=
True
,
freeze_bn
=
0
):
def
__init__
(
self
,
momentum
=
0.
9
,
epsilon
=
1e-5
,
is_training
=
True
,
freeze_bn
=
0
):
"""init batch norm fold layer"""
self
.
momentum
=
validator
.
check_number_range
(
'momentum'
,
momentum
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
epsilon
=
validator
.
check_float_positive
(
'epsilon'
,
epsilon
,
self
.
name
)
...
...
@@ -731,32 +848,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
return
x_type
class
BNTrainingReduce
(
PrimitiveWithInfer
):
"""
reduce sum at axis [0, 2, 3].
Inputs:
- **x** (Tensor) - Tensor of shape :math:`(N, C)`.
Outputs:
- **x_sum** (Tensor) - Tensor has the same shape as x.
- **x_square_sum** (Tensor) - Tensor has the same shape as x.
"""
@
prim_attr_register
def
__init__
(
self
):
"""init _BNTrainingReduce layer"""
self
.
init_prim_io_names
(
inputs
=
[
'x'
],
outputs
=
[
'x_sum'
,
'x_square_sum'
])
def
infer_shape
(
self
,
x_shape
):
return
[
x_shape
[
1
]],
[
x_shape
[
1
]]
def
infer_dtype
(
self
,
x_type
):
return
x_type
,
x_type
class
BatchNormFold2_D
(
PrimitiveWithInfer
):
"""
Scale the bias with a correction factor to the long term statistics
...
...
@@ -859,153 +950,3 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
def
infer_dtype
(
self
,
dout_type
,
x_type
):
validator
.
check
(
"dout type"
,
dout_type
,
"x type"
,
x_type
)
return
dout_type
,
dout_type
class
FakeQuantMinMaxPerLayerUpdate
(
PrimitiveWithInfer
):
r
"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
):
"""init FakeQuantMinMaxPerLayerUpdate OP"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
from
mindspore.ops._op_impl._custom_op
import
fake_quant_minmax_perlayer_update
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
class
FakeQuantMinMaxPerChannelUpdate
(
PrimitiveWithInfer
):
r
"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
,
channel_axis
=
1
):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
from
mindspore.ops._op_impl._custom_op
import
fake_quant_minmax_perchannel_update
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
channel_axis
=
validator
.
check_integer
(
'channel axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GT
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
(
{
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录