Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ed2e84d4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed2e84d4
编写于
7月 25, 2020
作者:
C
chenzomi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add LeakReLUQuant OP for bug fix.
上级
b8686fc5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
118 addition
and
43 deletion
+118
-43
mindspore/nn/layer/quant.py
mindspore/nn/layer/quant.py
+85
-15
mindspore/ops/operations/_quant_ops.py
mindspore/ops/operations/_quant_ops.py
+2
-0
mindspore/train/quant/quant.py
mindspore/train/quant/quant.py
+31
-28
未找到文件。
mindspore/nn/layer/quant.py
浏览文件 @
ed2e84d4
...
...
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Quantization aware."""
"""Quantization aware
training
."""
from
functools
import
partial
import
numpy
as
np
...
...
@@ -43,6 +43,7 @@ __all__ = [
'Conv2dQuant'
,
'DenseQuant'
,
'ActQuant'
,
'LeakyReLUQuant'
,
'HSwishQuant'
,
'HSigmoidQuant'
,
'TensorAddQuant'
,
...
...
@@ -349,7 +350,7 @@ class FakeQuantWithMinMax(Cell):
self
.
maxq
=
Parameter
(
Tensor
(
max_array
),
name
=
'quant_max'
,
requires_grad
=
False
)
# init fake quant relative op
if
per_channel
:
if
self
.
per_channel
:
quant_fun
=
partial
(
Q
.
FakeQuantPerChannel
,
channel_axis
=
self
.
channel_axis
)
ema_fun
=
partial
(
Q
.
MinMaxUpdatePerChannel
,
channel_axis
=
self
.
channel_axis
)
else
:
...
...
@@ -369,7 +370,7 @@ class FakeQuantWithMinMax(Cell):
num_bits
=
self
.
num_bits
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
quant_delay
=
quant_delay
)
quant_delay
=
self
.
quant_delay
)
self
.
fake_quant_train
=
quant_fun
(
training
=
True
)
self
.
fake_quant_infer
=
quant_fun
(
training
=
False
)
...
...
@@ -832,7 +833,7 @@ class ActQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>> act_quant = nn.ActQuant(
4, 1
)
>>> act_quant = nn.ActQuant(
nn.ReLU
)
>>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
>>> result = act_quant(input_x)
"""
...
...
@@ -855,7 +856,7 @@ class ActQuant(_QuantActivation):
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
quant_delay
=
quant_delay
)
self
.
act
=
activation
()
self
.
act
=
activation
def
construct
(
self
,
x
):
x
=
self
.
act
(
x
)
...
...
@@ -865,6 +866,75 @@ class ActQuant(_QuantActivation):
def
get_origin
(
self
):
return
self
.
act
class
LeakyReLUQuant
(
_QuantActivation
):
r
"""
LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP.
For a more Detailed overview of HSwish op.
Args:
activation (Cell): Activation cell class.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
quant_delay (int): Quantization delay parameters according by global step. Default: 0.
Inputs:
- **x** (Tensor) - The input of HSwishQuant.
Outputs:
Tensor, with the same type and shape as the `x`.
Examples:
>>> activation = nn.LeakyReLUQuant(nn.LeakyReLU())
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = activation(input)
"""
def
__init__
(
self
,
activation
,
ema_decay
=
0.999
,
per_channel
=
False
,
num_bits
=
8
,
symmetric
=
False
,
narrow_range
=
False
,
quant_delay
=
0
):
super
(
LeakyReLUQuant
,
self
).
__init__
()
self
.
fake_quant_act_before
=
FakeQuantWithMinMax
(
min_init
=-
6
,
max_init
=
6
,
ema
=
True
,
ema_decay
=
ema_decay
,
per_channel
=
per_channel
,
num_bits
=
num_bits
,
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
quant_delay
=
quant_delay
)
self
.
fake_quant_act_after
=
FakeQuantWithMinMax
(
min_init
=-
6
,
max_init
=
6
,
ema
=
True
,
ema_decay
=
ema_decay
,
per_channel
=
per_channel
,
num_bits
=
num_bits
,
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
quant_delay
=
quant_delay
)
if
issubclass
(
activation
.
__class__
,
nn
.
LeakyReLU
):
self
.
act
=
activation
else
:
raise
ValueError
(
"Activation should be `nn.LeakyReLU`"
)
def
construct
(
self
,
x
):
x
=
self
.
fake_quant_act_before
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
fake_quant_act_after
(
x
)
return
x
def
get_origin
(
self
):
return
self
.
act
class
HSwishQuant
(
_QuantActivation
):
r
"""
...
...
@@ -888,9 +958,9 @@ class HSwishQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>>
hswish_quant = nn.HSwishQuant(4, 1
)
>>> input
_x
= Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result =
hswish_quant(input_x
)
>>>
activation = nn.HSwishQuant(nn.HSwish()
)
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result =
activation(input
)
"""
def
__init__
(
self
,
...
...
@@ -920,8 +990,8 @@ class HSwishQuant(_QuantActivation):
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
quant_delay
=
quant_delay
)
if
issubclass
(
activation
,
nn
.
HSwish
):
self
.
act
=
activation
()
if
issubclass
(
activation
.
__class__
,
nn
.
HSwish
):
self
.
act
=
activation
else
:
raise
ValueError
(
"Activation should be `nn.HSwish`"
)
...
...
@@ -957,9 +1027,9 @@ class HSigmoidQuant(_QuantActivation):
Tensor, with the same type and shape as the `x`.
Examples:
>>>
hsigmoid_quant = nn.HSigmoidQuant(4, 1
)
>>> input
_x
= Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result =
hsigmoid_quant(input_x
)
>>>
activation = nn.HSigmoidQuant(nn.HSigmoid()
)
>>> input = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result =
activation(input
)
"""
def
__init__
(
self
,
...
...
@@ -989,8 +1059,8 @@ class HSigmoidQuant(_QuantActivation):
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
quant_delay
=
quant_delay
)
if
issubclass
(
activation
,
nn
.
HSigmoid
):
self
.
act
=
activation
()
if
issubclass
(
activation
.
__class__
,
nn
.
HSigmoid
):
self
.
act
=
activation
else
:
raise
ValueError
(
"Activation should be `nn.HSigmoid`"
)
...
...
mindspore/ops/operations/_quant_ops.py
浏览文件 @
ed2e84d4
...
...
@@ -386,6 +386,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
raise
ValueError
(
f
"For '
{
self
.
name
}
' x rank should be in '
{
self
.
ascend_support_x_rank
}
'"
)
if
not
self
.
is_ascend
:
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
if
len
(
x_shape
)
==
1
:
self
.
channel_axis
=
0
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
min_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
...
...
mindspore/train/quant/quant.py
浏览文件 @
ed2e84d4
...
...
@@ -35,8 +35,8 @@ from . import quant_utils
_ACTIVATION_MAP
=
{
nn
.
ReLU
:
quant
.
ActQuant
,
nn
.
ReLU6
:
quant
.
ActQuant
,
nn
.
LeakyReLU
:
quant
.
ActQuant
,
nn
.
Sigmoid
:
quant
.
ActQuant
,
nn
.
LeakyReLU
:
quant
.
LeakyReLUQuant
,
nn
.
HSigmoid
:
quant
.
HSigmoidQuant
,
nn
.
HSwish
:
quant
.
HSwishQuant
}
...
...
@@ -167,32 +167,35 @@ class ConvertToQuantNetwork:
convert Conv2d cell to quant cell
"""
conv_inner
=
subcell
.
conv
if
subcell
.
has_bn
and
self
.
bn_fold
:
bn_inner
=
subcell
.
batchnorm
conv_inner
=
quant
.
Conv2dBatchNormQuant
(
conv_inner
.
in_channels
,
conv_inner
.
out_channels
,
kernel_size
=
conv_inner
.
kernel_size
,
stride
=
conv_inner
.
stride
,
pad_mode
=
conv_inner
.
pad_mode
,
padding
=
conv_inner
.
padding
,
dilation
=
conv_inner
.
dilation
,
group
=
conv_inner
.
group
,
eps
=
bn_inner
.
eps
,
quant_delay
=
self
.
weight_qdelay
,
freeze_bn
=
self
.
freeze_bn
,
per_channel
=
self
.
weight_channel
,
num_bits
=
self
.
weight_bits
,
fake
=
True
,
symmetric
=
self
.
weight_symmetric
,
narrow_range
=
self
.
weight_range
)
# change original network BatchNormal OP parameters to quant network
conv_inner
.
gamma
=
subcell
.
batchnorm
.
gamma
conv_inner
.
beta
=
subcell
.
batchnorm
.
beta
conv_inner
.
moving_mean
=
subcell
.
batchnorm
.
moving_mean
conv_inner
.
moving_variance
=
subcell
.
batchnorm
.
moving_variance
del
subcell
.
batchnorm
subcell
.
batchnorm
=
None
subcell
.
has_bn
=
False
if
subcell
.
has_bn
:
if
self
.
bn_fold
:
bn_inner
=
subcell
.
batchnorm
conv_inner
=
quant
.
Conv2dBatchNormQuant
(
conv_inner
.
in_channels
,
conv_inner
.
out_channels
,
kernel_size
=
conv_inner
.
kernel_size
,
stride
=
conv_inner
.
stride
,
pad_mode
=
conv_inner
.
pad_mode
,
padding
=
conv_inner
.
padding
,
dilation
=
conv_inner
.
dilation
,
group
=
conv_inner
.
group
,
eps
=
bn_inner
.
eps
,
quant_delay
=
self
.
weight_qdelay
,
freeze_bn
=
self
.
freeze_bn
,
per_channel
=
self
.
weight_channel
,
num_bits
=
self
.
weight_bits
,
fake
=
True
,
symmetric
=
self
.
weight_symmetric
,
narrow_range
=
self
.
weight_range
)
# change original network BatchNormal OP parameters to quant network
conv_inner
.
gamma
=
subcell
.
batchnorm
.
gamma
conv_inner
.
beta
=
subcell
.
batchnorm
.
beta
conv_inner
.
moving_mean
=
subcell
.
batchnorm
.
moving_mean
conv_inner
.
moving_variance
=
subcell
.
batchnorm
.
moving_variance
del
subcell
.
batchnorm
subcell
.
batchnorm
=
None
subcell
.
has_bn
=
False
else
:
raise
ValueError
(
"Only support Batchnorm fold mode."
)
else
:
conv_inner
=
quant
.
Conv2dQuant
(
conv_inner
.
in_channels
,
conv_inner
.
out_channels
,
...
...
@@ -259,7 +262,7 @@ class ConvertToQuantNetwork:
act_class
=
activation
.
__class__
if
act_class
not
in
_ACTIVATION_MAP
:
raise
ValueError
(
"Unsupported activation in auto quant: "
,
act_class
)
return
_ACTIVATION_MAP
[
act_class
](
activation
=
act
_class
,
return
_ACTIVATION_MAP
[
act_class
](
activation
=
act
ivation
,
num_bits
=
self
.
act_bits
,
quant_delay
=
self
.
act_qdelay
,
per_channel
=
self
.
act_channel
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录