Commit 1089c908 — magicwindyyd/mindspore (forked from MindSpore/mindspore)
Authored on June 30, 2020 by chenzomi
Parent: 746ecc2e

cherry-pick r0.5 to master for quantization aware training

Showing 28 changed files with 322 additions and 142 deletions (+322 / -142)
mindspore/ccsrc/utils/checkpoint.proto                              +0    -1
mindspore/nn/cell.py                                                +9    -0
mindspore/nn/layer/conv.py                                          +149  -1
mindspore/nn/layer/quant.py                                         +44   -30
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py                +0    -1
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py           +0    -1
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py    +0    -1
mindspore/ops/_op_impl/_custom_op/correction_mul.py                 +0    -1
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py            +0    -2
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py          +9    -5
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py     +9    -5
mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py       +9    -5
mindspore/ops/operations/_quant_ops.py                              +28   -13
mindspore/train/callback/_checkpoint.py                             +7    -18
mindspore/train/callback/_loss_monitor.py                           +1    -1
mindspore/train/quant/quant.py                                      +28   -15
mindspore/train/serialization.py                                    +4    -17
model_zoo/lenet_quant/README.md                                     +3    -3
model_zoo/lenet_quant/eval.py                                       +1    -1
model_zoo/lenet_quant/eval_quant.py                                 +2    -2
model_zoo/lenet_quant/train.py                                      +2    -3
model_zoo/lenet_quant/train_quant.py                                +6    -5
model_zoo/mobilenetv2/scripts/run_infer.sh                          +1    -1
model_zoo/mobilenetv2/scripts/run_train.sh                          +1    -1
model_zoo/mobilenetv3/scripts/run_infer.sh                          +1    -1
model_zoo/mobilenetv3/scripts/run_train.sh                          +1    -1
tests/ut/python/train/quant/mobilenetv2_combined.py                 +6    -6
tests/ut/python/train/quant/test_quant.py                           +1    -1
mindspore/ccsrc/utils/checkpoint.proto
@@ -22,7 +22,6 @@ message Checkpoint {
     required TensorProto tensor = 2;
   }
   repeated Value value = 1;
-  required string model_type = 2;
 }
mindspore/nn/cell.py
@@ -81,6 +81,7 @@ class Cell:
         self.enable_hook = False
         self._bprop_debug = False
         self._is_run = False
+        self.cell_type = None

     @property
     def is_run(self):
@@ -140,6 +141,14 @@ class Cell:
         for cell_name, cell in cells_name:
             cell._param_prefix = cell_name

+    def update_cell_type(self, cell_type):
+        """
+        Update current cell type mainly identify if quantization aware training network.
+
+        After invoked, can set the cell type to 'cell_type'.
+        """
+        self.cell_type = cell_type
+
     @cell_init_args.setter
     def cell_init_args(self, value):
         if not isinstance(value, str):
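As a reference for how the new attribute is consumed later in this commit (ConvertToQuantNetwork calls update_cell_type("quant") on the converted network), here is a minimal sketch; the TinyNet cell is hypothetical and only illustrates the API:

import mindspore.nn as nn

class TinyNet(nn.Cell):
    """Hypothetical cell, defined only to show the new attribute."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.dense = nn.Dense(4, 2)

    def construct(self, x):
        return self.dense(x)

net = TinyNet()
print(net.cell_type)           # None by default after this change
net.update_cell_type("quant")  # mark the cell as a quantization aware training network
print(net.cell_type)           # 'quant'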
mindspore/nn/layer/conv.py
@@ -17,11 +17,12 @@ from mindspore import log as logger
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
+from mindspore._checkparam import ParamValidator as validator, Rel
 from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
 from mindspore._extends import cell_attr_register
 from ..cell import Cell

-__all__ = ['Conv2d', 'Conv2dTranspose']
+__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']


 class _Conv(Cell):
     """
@@ -397,3 +398,150 @@ class Conv2dTranspose(_Conv):
                       self.weight,
                       self.bias)
         return s
+
+
+class DepthwiseConv2d(Cell):
+    r"""
+    2D depthwise convolution layer.
+
+    Applies a 2D depthwise convolution over an input tensor which is typically of shape:
+    math:`(N, C_{in}, H_{in}, W_{in})`, where :math:`N` is batch size and :math:`C_{in}` is channel number.
+    For each batch of shape:math:`(C_{in}, H_{in}, W_{in})`, the formula is defined as:
+
+    .. math::
+
+        out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
+
+    where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
+    from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
+    filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
+    of kernel and it has shape :math:`(\text{ks_h}, \text{ks_w})`, where :math:`\text{ks_h}` and
+    :math:`\text{ks_w}` are height and width of the convolution kernel. The full kernel has shape
+    :math:`(C_{out}, C_{in} // \text{group}, \text{ks_h}, \text{ks_w})`, where group is the group number
+    to split the input in the channel dimension.
+
+    If the 'pad_mode' is set to be "valid", the output height and width will be
+    :math:`\left \lfloor{1 + \frac{H_{in} + 2 \times \text{padding} - \text{ks_h} -
+    (\text{ks_h} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` and
+    :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
+    (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
+
+    The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
+    <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
+
+    Args:
+        in_channels (int): The number of input channel :math:`C_{in}`.
+        out_channels (int): The number of output channel :math:`C_{out}`.
+        kernel_size (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the height
+            and width of the 2D convolution window. Single int means the value if for both height and width of
+            the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
+            width of the kernel.
+        stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents
+            the height and width of movement are both strides, or a tuple of two int numbers that
+            represent height and width of movement respectively. Default: 1.
+        pad_mode (str): Specifies padding mode. The optional values are
+            "same", "valid", "pad". Default: "same".
+
+            - same: Adopts the way of completion. Output height and width will be the same as the input.
+              Total number of padding will be calculated for horizontal and vertical
+              direction and evenly distributed to top and bottom, left and right if possible. Otherwise, the
+              last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
+              must be 0.
+            - valid: Adopts the way of discarding. The possibly largest height and width of output will be return
+              without padding. Extra pixels will be discarded. If this mode is set, `padding`
+              must be 0.
+            - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
+              Tensor borders. `padding` should be greater than or equal to 0.
+
+        padding (int): Implicit paddings on both sides of the input. Default: 0.
+        dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
+            to use for dilated convolution. If set to be :math:`k > 1`, there will
+            be :math:`k - 1` pixels skipped for each sampling location. Its value should
+            be greater or equal to 1 and bounded by the height and width of the
+            input. Default: 1.
+        group (int): Split filter into groups, `in_ channels` and `out_channels` should be
+            divisible by the number of groups. Default: 1.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
+            It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
+            values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
+            as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
+            and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
+            Initializer for more details. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
+            Initializer and string are the same as 'weight_init'. Refer to the values of
+            Initializer for more details. Default: 'zeros'.
+
+    Inputs:
+        - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
+
+    Outputs:
+        Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
+
+    Examples:
+        >>> net = nn.DepthwiseConv2d(120, 240, 4, has_bias=False, weight_init='normal')
+        >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
+        >>> net(input).shape
+        (1, 240, 1024, 640)
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 pad_mode='same',
+                 padding=0,
+                 dilation=1,
+                 group=1,
+                 has_bias=False,
+                 weight_init='normal',
+                 bias_init='zeros'):
+        super(DepthwiseConv2d, self).__init__()
+        self.kernel_size = twice(kernel_size)
+        self.stride = twice(stride)
+        self.dilation = twice(dilation)
+        self.in_channels = check_int_positive(in_channels)
+        self.out_channels = check_int_positive(out_channels)
+        validator.check_integer('group', group, in_channels, Rel.EQ)
+        validator.check_integer('group', group, out_channels, Rel.EQ)
+        validator.check_integer('group', group, 1, Rel.GE)
+        self.pad_mode = pad_mode
+        self.padding = padding
+        self.dilation = dilation
+        self.group = group
+        self.has_bias = has_bias
+        self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                            kernel_size=self.kernel_size,
+                                            pad_mode=self.pad_mode,
+                                            pad=self.padding,
+                                            stride=self.stride,
+                                            dilation=self.dilation)
+        self.bias_add = P.BiasAdd()
+        weight_shape = [1, in_channels, *self.kernel_size]
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
+        if check_bool(has_bias):
+            self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
+        else:
+            if bias_init != 'zeros':
+                logger.warning("value of `has_bias` is False, value of `bias_init` will be ignore.")
+            self.bias = None
+
+    def construct(self, x):
+        out = self.conv(x, self.weight)
+        if self.has_bias:
+            out = self.bias_add(out, self.bias)
+        return out
+
+    def extend_repr(self):
+        s = 'input_channels={}, output_channels={}, kernel_size={}, stride={}, ' \
+            'pad_mode={}, padding={}, dilation={}, group={},' \
+            'has_bias={}, weight_init={}, bias_init={}'.format(
+                self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                self.pad_mode, self.padding, self.dilation, self.group,
+                self.has_bias, self.weight_init, self.bias_init)
+        if self.has_bias:
+            s += ', bias={}'.format(self.bias)
+        return s
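Note that the validator calls in __init__ above require group == in_channels == out_channels, so a call that satisfies those checks looks like the sketch below; the channel and spatial sizes are illustrative assumptions, not values taken from the commit:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# group must equal both in_channels and out_channels for the checks above to pass
net = nn.DepthwiseConv2d(32, 32, kernel_size=3, group=32)
x = Tensor(np.ones([1, 32, 56, 56]), mindspore.float32)
# with the default pad_mode='same' and stride=1 the spatial size is preserved
print(net(x).shape)  # expected: (1, 32, 56, 56)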
mindspore/nn/layer/quant.py
@@ -16,6 +16,7 @@
 from functools import partial
 import numpy as np
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -23,10 +24,9 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Rel
-from mindspore.nn.cell import Cell
-from mindspore.nn.layer.activation import get_activation
+from mindspore._checkparam import Validator as validator, Rel
 import mindspore.context as context
 from .normalization import BatchNorm2d
 from .activation import get_activation
 from ..cell import Cell
@@ -82,7 +82,7 @@ class Conv2dBnAct(Cell):
         bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
             Initializer and string are the same as 'weight_init'. Refer to the values of
             Initializer for more details. Default: 'zeros'.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies to used batchnorm or not. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -94,7 +94,7 @@ class Conv2dBnAct(Cell):
         Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.

     Examples:
-        >>> net = Conv2dBnAct(120, 240, 4, batchnorm=True, activation='ReLU')
+        >>> net = Conv2dBnAct(120, 240, 4, has_bn=True, activation='ReLU')
         >>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
         >>> net(input).shape
         (1, 240, 1024, 640)
@@ -112,28 +112,39 @@ class Conv2dBnAct(Cell):
                  has_bias=False,
                  weight_init='normal',
                  bias_init='zeros',
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(Conv2dBnAct, self).__init__()
-        self.conv = conv.Conv2d(in_channels,
-                                out_channels,
-                                kernel_size,
-                                stride,
-                                pad_mode,
-                                padding,
-                                dilation,
-                                group,
-                                has_bias,
-                                weight_init,
-                                bias_init)
+        if context.get_context('device_target') == "Ascend" and group > 1:
+            self.conv = conv.DepthwiseConv2d(in_channels,
+                                             out_channels,
+                                             kernel_size=kernel_size,
+                                             stride=stride,
+                                             pad_mode=pad_mode,
+                                             padding=padding,
+                                             dilation=dilation,
+                                             group=group,
+                                             has_bias=has_bias,
+                                             weight_init=weight_init,
+                                             bias_init=bias_init)
+        else:
+            self.conv = conv.Conv2d(in_channels,
+                                    out_channels,
+                                    kernel_size=kernel_size,
+                                    stride=stride,
+                                    pad_mode=pad_mode,
+                                    padding=padding,
+                                    dilation=dilation,
+                                    group=group,
+                                    has_bias=has_bias,
+                                    weight_init=weight_init,
+                                    bias_init=bias_init)
-        self.has_bn = batchnorm is not None
+        self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
-        self.batchnorm = batchnorm
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)

     def construct(self, x):
@@ -160,7 +171,7 @@ class DenseBnAct(Cell):
             same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
         has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
         activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
-        batchnorm (bool): Specifies to used batchnorm or not. Default: None.
+        has_bn (bool): Specifies to used batchnorm or not. Default: False.
         activation (string): Specifies activation type. The optional values are as following:
             'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
             'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
@@ -183,7 +194,7 @@ class DenseBnAct(Cell):
                  weight_init='normal',
                  bias_init='zeros',
                  has_bias=True,
-                 batchnorm=None,
+                 has_bn=False,
                  activation=None):
         super(DenseBnAct, self).__init__()
         self.dense = basic.Dense(
@@ -192,12 +203,10 @@ class DenseBnAct(Cell):
             weight_init,
             bias_init,
             has_bias)
-        self.has_bn = batchnorm is not None
+        self.has_bn = validator.check_bool("has_bn", has_bn)
         self.has_act = activation is not None
-        if batchnorm is True:
+        if has_bn:
             self.batchnorm = BatchNorm2d(out_channels)
-        elif batchnorm is not None:
-            validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
         self.activation = get_activation(activation)

     def construct(self, x):
@@ -312,6 +321,10 @@ class FakeQuantWithMinMax(Cell):
                  quant_delay=0):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
+        validator.check_type("min_init", min_init, [int, float])
+        validator.check_type("max_init", max_init, [int, float])
+        validator.check("min_init", min_init, "max_init", max_init, rel=Rel.LT)
+        validator.check_integer('quant_delay', quant_delay, 0, Rel.GE)
         self.min_init = min_init
         self.max_init = max_init
         self.num_bits = num_bits
@@ -1183,12 +1196,13 @@ class QuantBlock(Cell):
         self.has_bias = bias is None
         self.activation = activation
         self.has_act = activation is None
+        self.bias_add = P.BiasAdd()

     def construct(self, x):
         x = self.quant(x)
         x = self.core_op(x, self.weight)
         if self.has_bias:
-            output = self.bias_add(output, self.bias)
+            x = self.bias_add(x, self.bias)
         if self.has_act:
             x = self.activation(x)
         x = self.dequant(x, self.dequant_scale)
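A short before/after sketch of the Conv2dBnAct interface change above: the boolean has_bn replaces the old batchnorm argument, and on Ascend a group > 1 convolution is now built from DepthwiseConv2d. The shapes follow the docstring example; the call itself is illustrative:

import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

# before this commit: nn.Conv2dBnAct(120, 240, 4, batchnorm=True, activation='relu')
net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
print(net(x).shape)  # (1, 240, 1024, 640)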
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py
@@ -30,7 +30,6 @@ batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "x", None, "required", None) \
     .input(1, "beta", None, "required", None) \
     .input(2, "gamma", None, "required", None) \
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py
@@ -30,7 +30,6 @@ batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "dout_reduce", None, "required", None) \
     .input(2, "dout_x_reduce", None, "required", None) \
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py
@@ -31,7 +31,6 @@ batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
     .compute_cost(10) \
     .kernel_name("batchnorm_fold2_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
     .output(0, "dout_reduce", True, "required", "all") \
mindspore/ops/_op_impl/_custom_op/correction_mul.py
@@ -30,7 +30,6 @@ correction_mul_op_info = TBERegOp("CorrectionMul") \
     .compute_cost(10) \
     .kernel_name("correction_mul") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "x", None, "required", None) \
     .input(1, "batch_std", None, "required", None) \
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
@@ -30,7 +30,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .input(1, "x", None, "required", None) \
@@ -128,7 +127,6 @@ correction_mul_grad_reduce_op_info = TBERegOp("CorrectionMulGradReduce") \
     .compute_cost(10) \
     .kernel_name("correction_mul_grad_reduce") \
     .partial_flag(True) \
-    .op_pattern("formatAgnostic") \
     .attr("channel_axis", "optional", "int", "all") \
     .input(0, "dout", None, "required", None) \
     .output(0, "d_batch_std", True, "required", "all") \
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
         quant_min = quant_min + 1

     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
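The channel_axis_ remapping added above can be read in isolation: a Dense weight of shape [co, ci] is laid out as the 4-D tensor [1, co, ci, 1] for this kernel, so a requested channel_axis of 0 no longer points at the per-channel min/max dimension and is shifted to 1. A plain-Python sketch of the same rule (the function name and shapes here are illustrative only):

def remap_channel_axis(x_shape, min_shape, channel_axis):
    # mirrors the check added above: for Dense weight quant,
    # 2d[co, ci] -> 4d[1, co, ci, 1], so per-channel stats sit on axis 1
    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
        return 1
    return channel_axis

print(remap_channel_axis([1, 8, 16, 1], [8], 0))   # Dense weight case -> 1
print(remap_channel_axis([8, 16, 3, 3], [8], 0))   # Conv weight case  -> 0 (unchanged)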
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape_[0] != min_shape[0] and x_shape_[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape_[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape_[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
         quant_min = quant_min + 1

     shape_c = [1] * len(x_shape)
-    shape_c[channel_axis] = min_val.get("ori_shape")[0]
-    if x_format == "NC1HWC0" and channel_axis == 1:
+    shape_c[channel_axis_] = min_val.get("ori_shape")[0]
+    if x_format == "NC1HWC0" and channel_axis_ == 1:
         shape_c = min_val.get("shape")
     dout_data = tvm.placeholder(x_shape, name="dout", dtype=x_dtype)
     input_data = tvm.placeholder(x_shape, name="x", dtype=x_dtype)
mindspore/ops/_op_impl/_custom_op/minmax_update_perchannel.py
@@ -88,11 +88,15 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     min_dtype = min_val.get("dtype")
     max_shape = max_val.get("ori_shape")
     max_dtype = max_val.get("dtype")
+    # for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
+    if channel_axis == 0 and x_shape[0] != min_shape[0] and x_shape[1] == min_shape[0]:
+        channel_axis_ = 1
+    else:
+        channel_axis_ = channel_axis
     util.check_kernel_name(kernel_name)
     util.check_shape_rule(x_shape)
-    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis])
-    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis])
+    util.check_shape_rule(min_shape, 1, 1, x_shape[channel_axis_])
+    util.check_shape_rule(max_shape, 1, 1, x_shape[channel_axis_])
     util.check_tensor_shape_size(x_shape)
     util.check_tensor_shape_size(min_shape)
     util.check_tensor_shape_size(max_shape)
@@ -105,7 +109,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     util.check_dtype_rule(min_dtype, check_list)
     util.check_dtype_rule(max_dtype, check_list)

-    if channel_axis == 0:
+    if channel_axis_ == 0:
         shape_c = min_val.get("ori_shape")
     else:
         shape_c = [min_val.get("shape")[1], min_val.get("shape")[-1]]
@@ -113,7 +117,7 @@ def minmax_update_perchannel(x, min_val, max_val, min_up, max_up,
     min_data = tvm.placeholder(shape_c, name="min_val", dtype=x_dtype)
     max_data = tvm.placeholder(shape_c, name="max_val", dtype=x_dtype)
     res_list = minmax_update_perchannel_compute(input_data, min_data, max_data,
-                                                ema, ema_decay, channel_axis)
+                                                ema, ema_decay, channel_axis_)
     with tvm.target.cce():
         sch = generic.auto_schedule(res_list)
mindspore/ops/operations/_quant_ops.py
@@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
     Args:
         ema (bool): Use EMA algorithm update value min and max. Default: False.
         ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
-        channel_axis (int): Channel asis for per channel compute. Default: 1.
+        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

     Inputs:
         - **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
@@ -123,11 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         >>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
     """
     support_quant_bit = [4, 7, 8]
+    ascend_support_x_rank = [2, 4]

     @prim_attr_register
     def __init__(self, ema=False, ema_decay=0.999, channel_axis=1):
         """init FakeQuantPerChannelUpdate OP for Ascend"""
-        if context.get_context('device_target') == "Ascend":
+        self.is_ascend = context.get_context('device_target') == "Ascend"
+        if self.is_ascend:
             from mindspore.ops._op_impl._custom_op import minmax_update_perchannel
         if ema and not ema_decay:
             raise ValueError(
@@ -136,13 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
         self.ema = validator.check_value_type('ema', ema, (bool,), self.name)
         self.ema_decay = validator.check_number_range(
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
-        self.channel_axis = validator.check_integer(
-            'channel axis', channel_axis, 0, Rel.GE, self.name)
+        if self.is_ascend:
+            self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
+        else:
+            self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
         self.init_prim_io_names(
             inputs=['x', 'min', 'max'], outputs=['min_up', 'max_up'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name)
+        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
+        if not self.is_ascend:
+            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape",
                         max_shape, Rel.EQ, self.name)
         validator.check_integer("min shape", len(
@@ -221,8 +228,8 @@ class FakeQuantPerLayer(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'],
                                 outputs=['out'])
@@ -314,6 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
         narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
         training (bool): Training the network or not. Default: True.
+        channel_axis (int): Quantization by channel axis. Ascend backend only supports 0 or 1. Default: 1.

     Inputs:
         - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
@@ -331,6 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
         >>> result = fake_quant(input_x, _min, _max)
     """
     support_quant_bit = [4, 7, 8]
+    ascend_support_x_rank = [2, 4]

     @prim_attr_register
     def __init__(self,
@@ -343,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
                  training=True,
                  channel_axis=1):
         """init FakeQuantPerChannel OP"""
-        if context.get_context('device_target') == "Ascend":
+        self.is_ascend = context.get_context('device_target') == "Ascend"
+        if self.is_ascend:
             from mindspore.ops._op_impl._custom_op import fake_quant_perchannel
         if num_bits not in self.support_quant_bit:
             raise ValueError(
@@ -363,14 +373,19 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
             'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name)
         self.num_bits = validator.check_integer(
             'num_bits', num_bits, 0, Rel.GT, self.name)
-        self.quant_delay = validator.check_value_type(
-            'quant_delay', quant_delay, (int,), self.name)
-        self.channel_axis = validator.check_integer(
-            'channel_axis', channel_axis, 0, Rel.GE, self.name)
+        self.quant_delay = validator.check_integer(
+            'quant_delay', quant_delay, 0, Rel.GE, self.name)
+        if self.is_ascend:
+            self.channel_axis = validator.check_int_range('channel_axis', channel_axis, 0, 1, Rel.INC_BOTH, self.name)
+        else:
+            self.channel_axis = validator.check_integer('channel_axis', channel_axis, 0, Rel.GE, self.name)
         self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out'])

     def infer_shape(self, x_shape, min_shape, max_shape):
-        validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
+        if self.is_ascend and len(x_shape) not in self.ascend_support_x_rank:
+            raise ValueError(f"For '{self.name}' x rank should be in '{self.ascend_support_x_rank}'")
+        if not self.is_ascend:
+            validator.check_integer("x rank", len(x_shape), 1, Rel.GE, self.name)
         validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name)
         validator.check_integer(
             "min shape", min_shape[0], x_shape[self.channel_axis], Rel.EQ, self.name)
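Taken together, the changes above make the per-channel primitives backend-aware: on Ascend, channel_axis must be 0 or 1 and the input rank must be 2 or 4, while other backends keep the original looser checks (channel_axis >= 0, rank >= 1). A hedged sketch of the Ascend-side usage follows; running it requires an Ascend environment, and the shapes are assumptions for illustration:

import numpy as np
from mindspore import Tensor, context
from mindspore.ops.operations import _quant_ops as Q

context.set_context(device_target="Ascend")  # assumption: Ascend backend available

# a rank-2 input (e.g. a Dense weight) is now explicitly allowed on Ascend
fake_quant = Q.FakeQuantPerChannel(num_bits=8, channel_axis=0)
x = Tensor(np.random.rand(6, 16).astype(np.float32))
_min = Tensor(np.array([-6.0] * 6, dtype=np.float32))
_max = Tensor(np.array([6.0] * 6, dtype=np.float32))
out = fake_quant(x, _min, _max)
# channel_axis=2 or a rank-3 input would now raise ValueError on Ascend.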
mindspore/train/callback/_checkpoint.py
@@ -21,7 +21,7 @@ import time
 import mindspore.context as context
 from mindspore import log as logger
-from mindspore._checkparam import check_bool, check_string, check_int_non_negative
+from mindspore._checkparam import check_bool, check_int_non_negative
 from mindspore.train._utils import _make_directory
 from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
 from ._callback import Callback, set_cur_net
@@ -86,7 +86,6 @@ class CheckpointConfig:
             Can't be used with keep_checkpoint_max at the same time.
         integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
             Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
-        model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".

     Raises:
         ValueError: If the input_param is None or 0.
@@ -101,8 +100,7 @@ class CheckpointConfig:
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
                  keep_checkpoint_per_n_minutes=0,
-                 integrated_save=True,
-                 model_type="normal"):
+                 integrated_save=True):

         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -116,8 +114,6 @@ class CheckpointConfig:
             keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
         if keep_checkpoint_per_n_minutes:
             keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
-        if model_type:
-            model_type = check_string(model_type, ["normal", "fusion", "quant"])

         self._save_checkpoint_steps = save_checkpoint_steps
         self._save_checkpoint_seconds = save_checkpoint_seconds
@@ -132,7 +128,6 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1

-        self._model_type = model_type
         self._integrated_save = check_bool(integrated_save)

     @property
@@ -160,18 +155,12 @@ class CheckpointConfig:
         """Get the value of _integrated_save."""
         return self._integrated_save

-    @property
-    def model_type(self):
-        """Get the value of model_type."""
-        return self._model_type
-
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
                              'save_checkpoint_seconds': self._save_checkpoint_seconds,
                              'keep_checkpoint_max': self._keep_checkpoint_max,
-                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
-                             'model_type': self._model_type}
+                             'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}

         return checkpoint_policy
@@ -236,7 +225,7 @@ class ModelCheckpoint(Callback):
                 graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
                 _save_graph(cb_params.train_network, graph_file_name)
                 self._graph_saved = True
-        self._save_ckpt(cb_params, self._config.model_type)
+        self._save_ckpt(cb_params)

     def end(self, run_context):
         """
@@ -247,7 +236,7 @@ class ModelCheckpoint(Callback):
         """
         cb_params = run_context.original_args()
         _to_save_last_ckpt = True
-        self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
+        self._save_ckpt(cb_params, _to_save_last_ckpt)

         from mindspore.parallel._cell_wrapper import destroy_allgather_cell
         destroy_allgather_cell()
@@ -266,7 +255,7 @@ class ModelCheckpoint(Callback):
             return False
         return False

-    def _save_ckpt(self, cb_params, model_type, force_to_save=False):
+    def _save_ckpt(self, cb_params, force_to_save=False):
         """Save checkpoint files."""
         if cb_params.cur_step_num == self._last_triggered_step:
             return
@@ -302,7 +291,7 @@ class ModelCheckpoint(Callback):
                 set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()

-            _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)

             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
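With model_type removed, the checkpoint callback is configured as in the sketch below; the prefix and directory values are placeholders, not values from the commit:

from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

# passing model_type="quant" here would now be a TypeError; the argument is gone
config = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix="lenet_quant", directory="./ckpt", config=config)
# model.train(..., callbacks=[ckpt_cb]) then saves checkpoints as before,
# just without a model_type field in the checkpoint policy.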
mindspore/train/callback/_loss_monitor.py
@@ -86,7 +86,7 @@ class LossMonitor(Callback):
         if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
             print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
-                  "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
+                  "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
                       cb_params.cur_epoch_num, cb_params.epoch_num,
                       cur_step_in_epoch, int(cb_params.batch_num),
                       step_loss, np.mean(self.losses),
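For reference, with a step loss of 0.35, a running mean of 0.42 and a step time of 0.0123 s (made-up numbers), the reworked format string above prints:

step_loss, avg_loss, step_time = 0.35, 0.42, 0.0123
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
      "loss: [{:5.4f}], avg los: [{:5.4f}], time: [{:5.4f}]".format(
          1, 10, 100, 1875, step_loss, avg_loss, step_time))
# Epoch: [  1/ 10], step: [  100/ 1875], loss: [0.3500], avg los: [0.4200], time: [0.0123]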
mindspore/train/quant/quant.py
浏览文件 @
1089c908
...
@@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
...
@@ -42,15 +42,14 @@ _ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
class
_AddFakeQuantInput
(
nn
.
Cell
):
class
_AddFakeQuantInput
(
nn
.
Cell
):
"""
"""
Add FakeQuant
at input and output of the Network. Only support one input and one out
put case.
Add FakeQuant
OP at input of the network. Only support one in
put case.
"""
"""
def
__init__
(
self
,
network
,
quant_delay
=
0
):
def
__init__
(
self
,
network
,
quant_delay
=
0
):
super
(
_AddFakeQuantInput
,
self
).
__init__
(
auto_prefix
=
False
)
super
(
_AddFakeQuantInput
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
fake_quant_input
=
quant
.
FakeQuantWithMinMax
(
min_init
=-
6
,
max_init
=
6
,
quant_delay
=
quant_delay
,
ema
=
True
)
self
.
fake_quant_input
.
update_parameters_name
(
'fake_quant_input.'
)
self
.
network
=
network
self
.
network
=
network
self
.
fake_quant_input
=
quant
.
FakeQuantWithMinMax
(
min_init
=-
6
,
max_init
=
6
,
quant_delay
=
quant_delay
,
ema
=
True
)
self
.
fake_quant_input
.
update_parameters_name
(
'fake_quant_input'
)
def
construct
(
self
,
data
):
def
construct
(
self
,
data
):
data
=
self
.
fake_quant_input
(
data
)
data
=
self
.
fake_quant_input
(
data
)
...
@@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
...
@@ -60,7 +59,7 @@ class _AddFakeQuantInput(nn.Cell):
class
_AddFakeQuantAfterSubCell
(
nn
.
Cell
):
class
_AddFakeQuantAfterSubCell
(
nn
.
Cell
):
"""
"""
Add FakeQuant after of the sub Cell.
Add FakeQuant
OP
after of the sub Cell.
"""
"""
def
__init__
(
self
,
subcell
,
**
kwargs
):
def
__init__
(
self
,
subcell
,
**
kwargs
):
...
@@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
...
@@ -115,11 +114,12 @@ class ConvertToQuantNetwork:
self
.
network
.
update_cell_prefix
()
self
.
network
.
update_cell_prefix
()
network
=
self
.
_convert_subcells2quant
(
self
.
network
)
network
=
self
.
_convert_subcells2quant
(
self
.
network
)
network
=
_AddFakeQuantInput
(
network
)
network
=
_AddFakeQuantInput
(
network
)
self
.
network
.
update_cell_type
(
"quant"
)
return
network
return
network
def
_convert_subcells2quant
(
self
,
network
):
def
_convert_subcells2quant
(
self
,
network
):
"""
"""
conve
t sub cell
to quant cell
conve
rt sub cell like `Conv2dBnAct` and `DenseBnAct`
to quant cell
"""
"""
cells
=
network
.
name_cells
()
cells
=
network
.
name_cells
()
change
=
False
change
=
False
...
@@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
...
@@ -138,13 +138,13 @@ class ConvertToQuantNetwork:
if
        if isinstance(network, nn.SequentialCell) and change:
            network.cell_list = list(network.cells())

-       # tensoradd to tensoradd quant
+       # add FakeQuant OP after OP in while list
        add_list = []
        for name in network.__dict__:
            if name[0] == '_':
                continue
            attr = network.__dict__[name]
-           if isinstance(attr, ops.Primitive) and attr.name in ConvertToQuantNetwork.__quant_op_name__:
+           if isinstance(attr, ops.Primitive) and attr.name in self.__quant_op_name__:
                add_list.append((name, attr))
        for name, prim_op in add_list:
            prefix = name
...
@@ -164,11 +164,11 @@ class ConvertToQuantNetwork:
    def _convert_conv(self, subcell):
        """
-       convet conv cell to quant cell
+       convert Conv2d cell to quant cell
        """
        conv_inner = subcell.conv
-       bn_inner = subcell.batchnorm
-       if subcell.has_bn and self.bn_fold:
+       if subcell.batchnorm is not None and self.bn_fold:
+           bn_inner = subcell.batchnorm
            conv_inner = quant.Conv2dBatchNormQuant(conv_inner.in_channels,
                                                    conv_inner.out_channels,
                                                    kernel_size=conv_inner.kernel_size,
...
@@ -178,7 +178,7 @@ class ConvertToQuantNetwork:
                                                    dilation=conv_inner.dilation,
                                                    group=conv_inner.group,
                                                    eps=bn_inner.eps,
-                                                   momentum=bn_inner.momentum,
+                                                   momentum=1 - bn_inner.momentum,
                                                    quant_delay=self.weight_qdelay,
                                                    freeze_bn=self.freeze_bn,
                                                    per_channel=self.weight_channel,
...
@@ -186,6 +186,11 @@ class ConvertToQuantNetwork:
                                                    fake=True,
                                                    symmetric=self.weight_symmetric,
                                                    narrow_range=self.weight_range)
+           # change original network BatchNormal OP parameters to quant network
+           conv_inner.gamma = subcell.batchnorm.gamma
+           conv_inner.beta = subcell.batchnorm.beta
+           conv_inner.moving_mean = subcell.batchnorm.moving_mean
+           conv_inner.moving_variance = subcell.batchnorm.moving_variance
            del subcell.batchnorm
            subcell.batchnorm = None
            subcell.has_bn = False
...
@@ -204,6 +209,10 @@ class ConvertToQuantNetwork:
                                              num_bits=self.weight_bits,
                                              symmetric=self.weight_symmetric,
                                              narrow_range=self.weight_range)
+           # change original network Conv2D OP parameters to quant network
+           conv_inner.weight = subcell.conv.weight
+           if subcell.conv.has_bias:
+               conv_inner.bias = subcell.conv.bias
        subcell.conv = conv_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
...
@@ -230,6 +239,10 @@ class ConvertToQuantNetwork:
                                        per_channel=self.weight_channel,
                                        symmetric=self.weight_symmetric,
                                        narrow_range=self.weight_range)
+       # change original network Dense OP parameters to quant network
+       dense_inner.weight = subcell.dense.weight
+       if subcell.dense.has_bias:
+           dense_inner.bias = subcell.dense.bias
        subcell.dense = dense_inner
        if subcell.has_act and subcell.activation is not None:
            subcell.activation = self._convert_activation(subcell.activation)
...
@@ -247,12 +260,12 @@ class ConvertToQuantNetwork:
        act_class = activation.__class__
        if act_class not in _ACTIVATION_MAP:
            raise ValueError(
-               "Unsupported activation in auto Quant: ", act_class)
+               "Unsupported activation in auto quant: ", act_class)
        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
                                          quant_delay=self.act_qdelay,
                                          per_channel=self.act_channel,
-                                         symmetric=self.weight_symmetric,
-                                         narrow_range=self.weight_range)
+                                         symmetric=self.act_symmetric,
+                                         narrow_range=self.act_range)


class ExportQuantNetworkDeploy:
...
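One behavioural detail in the hunk above is the switch from `momentum=bn_inner.momentum` to `momentum=1 - bn_inner.momentum` when the BatchNorm is folded into `Conv2dBatchNormQuant`. A minimal sketch, assuming the two cells apply `momentum` to opposite sides of the running-statistics update (an assumption, not stated in the diff), shows why passing the complement keeps the update unchanged:

```python
def running_update_unfused(running, batch_stat, momentum):
    """Assumed convention of the standalone BatchNorm cell."""
    return (1 - momentum) * running + momentum * batch_stat

def running_update_fused(running, batch_stat, momentum):
    """Assumed convention of Conv2dBatchNormQuant."""
    return momentum * running + (1 - momentum) * batch_stat

m = 0.9
# Passing 1 - m to the fused cell reproduces the unfused update exactly.
assert abs(running_update_unfused(1.0, 2.0, m)
           - running_update_fused(1.0, 2.0, 1 - m)) < 1e-9
```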
mindspore/train/serialization.py View file @ 1089c908
...
@@ -40,8 +40,6 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
                     "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
                     "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}

-ModelType = ["normal", "fusion", "quant"]

def _special_process_par(par, new_par):
    """
...
@@ -103,7 +101,7 @@ def _update_param(param, new_param):
    param.set_parameter_data(type(param.data)(new_param.data))


-def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
+def save_checkpoint(parameter_list, ckpt_file_name):
    """
    Saves checkpoint info to a specified file.
...
@@ -111,14 +109,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
        ckpt_file_name (str): Checkpoint file name.
-       model_type (str): The name of model type. Default: "normal".

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    logger.info("Execute save checkpoint process.")
    checkpoint_list = Checkpoint()
-   checkpoint_list.model_type = model_type
    try:
        for param in parameter_list:
...
@@ -147,13 +143,12 @@ def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
    logger.info("Save checkpoint process finish.")


-def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
+def load_checkpoint(ckpt_file_name, net=None):
    """
    Loads checkpoint info from a specified file.

    Args:
        ckpt_file_name (str): Checkpoint file name.
-       model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
        net (Cell): Cell network. Default: None

    Returns:
...
@@ -165,9 +160,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
    if not isinstance(ckpt_file_name, str):
        raise ValueError("The ckpt_file_name must be string.")

-   if model_type not in ModelType:
-       raise ValueError(f"The model_type is not in {ModelType}.")

    if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
        raise ValueError("Please input the correct checkpoint file name.")
...
@@ -186,10 +178,6 @@ def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
        raise ValueError(e.__str__())

    parameter_dict = {}
-   if checkpoint_list.model_type:
-       if model_type != checkpoint_list.model_type:
-           raise KeyError("Checkpoint file model type({}) is not equal to input model type({})."
-                          .format(checkpoint_list.model_type, model_type))
    try:
        for element in checkpoint_list.value:
            data = element.tensor.tensor_content
...
@@ -314,14 +302,13 @@ def _save_graph(network, file_name):
    os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)


-def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
+def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
    """
    Saves checkpoint for 'ms' backend.

    Args:
        train_network (Network): The train network for training.
        ckpt_file_name (str): The name of checkpoint file.
-       model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
        integrated_save (bool): Whether to integrated save in automatic model parallel scene.
    """
...
@@ -346,7 +333,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", in
            each_param["data"] = param_data
            param_list.append(each_param)

-   save_checkpoint(param_list, ckpt_file_name, model_type)
+   save_checkpoint(param_list, ckpt_file_name)


def _get_merged_param_data(net, param_name, param_data):
...
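With the `model_type` argument removed, a quantization-aware checkpoint is saved and restored through the same calls as any other checkpoint. A minimal usage sketch, assuming `network` is an already constructed Cell and the checkpoint file name is a placeholder rather than something taken from the diff:

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# `network` is assumed to exist; the file name below is a placeholder.
param_dict = load_checkpoint("checkpoint_lenet-10_1875.ckpt")
load_param_into_net(network, param_dict)
```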
model_zoo/lenet_quant/README.md View file @ 1089c908
...
@@ -33,7 +33,7 @@ Then you will get the following display
 ```bash
 >>> Found existing installation: mindspore-ascend
 >>> Uninstalling mindspore-ascend:
 >>> Successfully uninstalled mindspore-ascend.
 ```

 ### Prepare Dataset
...
@@ -186,7 +186,7 @@ model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
 ### train quantization aware model

-Also, you can just run this command instread.
+Also, you can just run this command instead.

 ```python
 python train_quant.py --data_path MNIST_Data --device_target Ascend --ckpt_path checkpoint_lenet.ckpt
...
@@ -235,7 +235,7 @@ The top1 accuracy would display on shell.
 Here are some optional parameters:

 ```bash
---device_target {Ascend,GPU,CPU}
+--device_target {Ascend,GPU}
                       device where the code will be implemented (default: Ascend)
 --data_path DATA_PATH
                       path where the dataset is saved
...
model_zoo/lenet_quant/eval.py View file @ 1089c908
...
@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
-                   choices=['Ascend', 'GPU', 'CPU'],
+                   choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
...
model_zoo/lenet_quant/eval_quant.py View file @ 1089c908
...
@@ -32,7 +32,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
-                   choices=['Ascend', 'GPU', 'CPU'],
+                   choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
...
@@ -61,7 +61,7 @@ if __name__ == "__main__":
    model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

    # load quantization aware network checkpoint
-   param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
+   param_dict = load_checkpoint(args.ckpt_path)
    load_param_into_net(network, param_dict)

    print("============== Starting Testing ==============")
...
model_zoo/lenet_quant/train.py View file @ 1089c908
...
@@ -31,7 +31,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
-                   choices=['Ascend', 'GPU', 'CPU'],
+                   choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
...
@@ -56,8 +56,7 @@ if __name__ == "__main__":
    # call back and monitor
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                  keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                  model_type=network.type)
+                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

    # define model
...
model_zoo/lenet_quant/train_quant.py View file @ 1089c908
...
@@ -33,7 +33,7 @@ from src.lenet_fusion import LeNet5 as LeNet5Fusion
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend",
-                   choices=['Ascend', 'GPU', 'CPU'],
+                   choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
                    help='path where the dataset is saved')
...
@@ -50,11 +50,13 @@ if __name__ == "__main__":
    # define fusion network
    network = LeNet5Fusion(cfg.num_classes)
-   # convert fusion network to quantization aware network
-   network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)

    # load quantization aware network checkpoint
    param_dict = load_checkpoint(args.ckpt_path, network.type)
    load_param_into_net(network, param_dict)

+   # convert fusion network to quantization aware network
+   network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
+
    # define network loss
    net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
...
@@ -64,8 +66,7 @@ if __name__ == "__main__":
    # call back and monitor
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
-                                  keep_checkpoint_max=cfg.keep_checkpoint_max,
-                                  model_type="quant")
+                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)

    # define model
...
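The key change in train_quant.py above is the order of operations: the pre-trained checkpoint is now loaded into the fusion network before `quant.convert_quant_network` rewrites it, so the parameter names in the checkpoint still match the network at load time. A condensed sketch of that flow; `LeNet5Fusion` and `cfg` are assumed to come from the model_zoo lenet_quant sources, and the checkpoint path is a placeholder:

```python
from mindspore.train.quant import quant
from mindspore.train.serialization import load_checkpoint, load_param_into_net
# LeNet5Fusion and cfg are assumed to be imported from model_zoo/lenet_quant/src.

network = LeNet5Fusion(cfg.num_classes)

# 1. Restore the pre-trained fusion weights while the parameter names still match.
param_dict = load_checkpoint("checkpoint_lenet.ckpt")  # placeholder path
load_param_into_net(network, param_dict)

# 2. Only then rewrite the graph into its quantization-aware form.
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
```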
model_zoo/mobilenetv2/scripts/run_infer.sh View file @ 1089c908
...
@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
then
    rm -rf ../eval
fi
...
model_zoo/mobilenetv2/scripts/run_train.sh View file @ 1089c908
...
@@ -62,7 +62,7 @@ run_gpu()
    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-   if [ -d "train" ];
+   if [ -d "../train" ];
    then
        rm -rf ../train
    fi
...
model_zoo/mobilenetv3/scripts/run_infer.sh View file @ 1089c908
...
@@ -40,7 +40,7 @@ export PYTHONPATH=${BASEPATH}:$PYTHONPATH
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
-if [ -d "eval" ];
+if [ -d "../eval" ];
then
    rm -rf ../eval
fi
...
model_zoo/mobilenetv3/scripts/run_train.sh View file @ 1089c908
...
@@ -60,7 +60,7 @@ run_gpu()
    BASEPATH=$(cd "`dirname $0`" || exit; pwd)
    export PYTHONPATH=${BASEPATH}:$PYTHONPATH
-   if [ -d "train" ];
+   if [ -d "../train" ];
    then
        rm -rf ../train
    fi
...
浏览文件 @
1089c908
...
@@ -31,7 +31,7 @@ def _conv_bn(in_channel,
...
@@ -31,7 +31,7 @@ def _conv_bn(in_channel,
out_channel
,
out_channel
,
kernel_size
=
ksize
,
kernel_size
=
ksize
,
stride
=
stride
,
stride
=
stride
,
batchnorm
=
True
)])
has_bn
=
True
)])
class
InvertedResidual
(
nn
.
Cell
):
class
InvertedResidual
(
nn
.
Cell
):
...
@@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell):
...
@@ -49,25 +49,25 @@ class InvertedResidual(nn.Cell):
3
,
3
,
stride
,
stride
,
group
=
hidden_dim
,
group
=
hidden_dim
,
batchnorm
=
True
,
has_bn
=
True
,
activation
=
'relu6'
),
activation
=
'relu6'
),
nn
.
Conv2dBnAct
(
hidden_dim
,
oup
,
1
,
1
,
nn
.
Conv2dBnAct
(
hidden_dim
,
oup
,
1
,
1
,
batchnorm
=
True
)
has_bn
=
True
)
])
])
else
:
else
:
self
.
conv
=
nn
.
SequentialCell
([
self
.
conv
=
nn
.
SequentialCell
([
nn
.
Conv2dBnAct
(
inp
,
hidden_dim
,
1
,
1
,
nn
.
Conv2dBnAct
(
inp
,
hidden_dim
,
1
,
1
,
batchnorm
=
True
,
has_bn
=
True
,
activation
=
'relu6'
),
activation
=
'relu6'
),
nn
.
Conv2dBnAct
(
hidden_dim
,
nn
.
Conv2dBnAct
(
hidden_dim
,
hidden_dim
,
hidden_dim
,
3
,
3
,
stride
,
stride
,
group
=
hidden_dim
,
group
=
hidden_dim
,
batchnorm
=
True
,
has_bn
=
True
,
activation
=
'relu6'
),
activation
=
'relu6'
),
nn
.
Conv2dBnAct
(
hidden_dim
,
oup
,
1
,
1
,
nn
.
Conv2dBnAct
(
hidden_dim
,
oup
,
1
,
1
,
batchnorm
=
True
)
has_bn
=
True
)
])
])
self
.
add
=
P
.
TensorAdd
()
self
.
add
=
P
.
TensorAdd
()
...
...
tests/ut/python/train/quant/test_quant.py View file @ 1089c908
...
@@ -42,7 +42,7 @@ class LeNet5(nn.Cell):
    def __init__(self, num_class=10):
        super(LeNet5, self).__init__()
        self.num_class = num_class
-       self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, batchnorm=True, activation='relu6', pad_mode="valid")
+       self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, has_bn=True, activation='relu6', pad_mode="valid")
        self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu', pad_mode="valid")
        self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
        self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
...
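Both test files reflect the same API rename: the `batchnorm` keyword of `nn.Conv2dBnAct` becomes `has_bn`. A one-line sketch of updating a call site (the channel sizes are arbitrary, chosen only for illustration):

```python
import mindspore.nn as nn

# before this commit: nn.Conv2dBnAct(3, 16, kernel_size=3, batchnorm=True, activation='relu6')
conv = nn.Conv2dBnAct(3, 16, kernel_size=3, has_bn=True, activation='relu6')
```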