magicwindyyd / mindspore
Forked from MindSpore / mindspore

Commit e32ea53d
Authored Jul 07, 2020 by mindspore-ci-bot; committed by Gitee on Jul 07, 2020

!2864 add mobilenetV2 export

Merge pull request !2864 from chenzhongming/master

Parents: b9679975, d383ade6

Showing 7 changed files with 147 additions and 100 deletions (+147 -100).
mindspore/ccsrc/pipeline/pipeline.cc      +2   -1
mindspore/nn/layer/activation.py          +1   -0
mindspore/nn/layer/quant.py               +39  -79
mindspore/ops/operations/nn_ops.py        +2   -0
mindspore/train/quant/quant.py            +15  -7
mindspore/train/quant/quant_utils.py      +34  -13
model_zoo/mobilenetv2_quant/export.py     +54  -0

mindspore/ccsrc/pipeline/pipeline.cc

```diff
@@ -289,7 +289,8 @@ std::map<std::string, std::pair<PrimitivePyPtr, std::string>> ExecutorPy::FetchInfoForQuantExport
   MS_LOG(DEBUG) << "FetchInfoForQuantExport func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   std::map<std::string, std::pair<PrimitivePyPtr, std::string>> fake_quant_table;
   auto filter = [](AnfNodePtr node) {
-    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul));
+    return !(IsPrimitiveCNode(node, prim::kPrimConv2D) || IsPrimitiveCNode(node, prim::kPrimMatMul) ||
+             IsPrimitiveCNode(node, prim::kPrimDepthwiseConv2dNative));
   };
   std::vector<AnfNodePtr> nodes = DeepScopedGraphSearchWithFilter(func_graph->get_return(), AlwaysInclude, filter);
   auto is_quant_cnode = [](AnfNodePtr node) {
```
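The effect of this hunk is that the quant-export graph search now also keeps depthwise convolution nodes. A conceptual Python mirror of the C++ filter, for illustration only (note the double negative: the lambda returns true to exclude a node):

```python
# Conceptual mirror of the C++ filter above, not MindSpore code: returning
# True EXCLUDES a node, so only Conv2D, MatMul, and (newly) the
# DepthwiseConv2dNative op survive the search.
QUANT_EXPORT_OPS = {"Conv2D", "MatMul", "DepthwiseConv2dNative"}

def filter_out(op_name: str) -> bool:
    return op_name not in QUANT_EXPORT_OPS
```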

mindspore/nn/layer/activation.py

```diff
@@ -530,6 +530,7 @@ _activation = {
     'relu6': ReLU6,
     'tanh': Tanh,
     'gelu': GELU,
+    'elu': ELU,
     'sigmoid': Sigmoid,
     'prelu': PReLU,
     'leakyrelu': LeakyReLU,
```
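With ELU registered in the `_activation` table, name-based lookup can construct it like any other entry. A minimal usage sketch, assuming the `nn.get_activation` helper of this era resolves names through this table:

```python
import mindspore.nn as nn

# 'elu' now resolves by name, the same way 'relu6' or 'gelu' do.
elu_cell = nn.get_activation('elu')
```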

mindspore/nn/layer/quant.py

```diff
@@ -17,6 +17,7 @@
 from functools import partial
 import numpy as np
+from mindspore import nn
 import mindspore.common.dtype as mstype
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -41,8 +42,7 @@ __all__ = [
     'Conv2dBatchNormQuant',
     'Conv2dQuant',
     'DenseQuant',
-    'ReLUQuant',
-    'ReLU6Quant',
+    'ActQuant',
     'HSwishQuant',
     'HSigmoidQuant',
     'TensorAddQuant',
@@ -375,9 +375,10 @@ class FakeQuantWithMinMax(Cell):
     def extend_repr(self):
         s = 'num_bits={}, symmetric={}, narrow_range={}, ema={}({}), per_channel={}({}, {}), ' \
-            'quant_delay={}, min_init={}, max_init={}'.format(
-                self.num_bits, self.symmetric, self.narrow_range, self.ema, self.ema_decay, self.per_channel,
-                self.channel_axis, self.num_channels, self.quant_delay, self.min_init, self.max_init)
+            'quant_delay={}, min_init={}, max_init={}'.format(self.num_bits, self.symmetric, self.narrow_range,
+                                                              self.ema, self.ema_decay, self.per_channel,
+                                                              self.channel_axis, self.num_channels,
+                                                              self.quant_delay, self.min_init, self.max_init)
         return s

     def construct(self, x):
@@ -540,10 +541,12 @@ class Conv2dBatchNormQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride, self.pad_mode, self.padding,
-                self.dilation, self.group, self.fake, self.freeze_bn, self.momentum, self.quant_delay)
+            'fake={}, freeze_bn={}, momentum={}, quant_delay={}'.format(self.in_channels, self.out_channels,
+                                                                        self.kernel_size, self.stride,
+                                                                        self.pad_mode, self.padding, self.dilation,
+                                                                        self.group, self.fake, self.freeze_bn,
+                                                                        self.momentum, self.quant_delay)
         return s

     def construct(self, x):
@@ -685,10 +688,9 @@ class Conv2dQuant(Cell):
     def extend_repr(self):
         s = 'in_channels={}, out_channels={}, kernel_size={}, stride={}, ' \
             'pad_mode={}, padding={}, dilation={}, group={}, ' \
-            'has_bias={}, quant_delay={}'.format(
-                self.in_channels, self.out_channels, self.kernel_size, self.stride,
-                self.pad_mode, self.padding, self.dilation, self.group,
-                self.has_bias, self.quant_delay)
+            'has_bias={}, quant_delay={}'.format(self.in_channels, self.out_channels, self.kernel_size, self.stride,
+                                                 self.pad_mode, self.padding, self.dilation, self.group,
+                                                 self.has_bias, self.quant_delay)
         return s
```
```diff
@@ -799,76 +801,23 @@ class DenseQuant(Cell):
 class _QuantActivation(Cell):
     r"""
-    Base class for Quant activation function. Add Fake Quant OP after activation OP.
+    Base class for quantization aware training activation function. Add Fake Quant OP after activation OP.
     """

     def get_origin(self):
         raise NotImplementedError


-class ReLUQuant(_QuantActivation):
+class ActQuant(_QuantActivation):
     r"""
-    ReLUQuant activation function. Add Fake Quant OP after Relu OP.
-
-    For a more Detailed overview of ReLU op.
-
-    Args:
-        ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
-        per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
-        num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
-        symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
-        narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
-        quant_delay (int): Quantization delay parameters according by global step. Default: 0.
-
-    Inputs:
-        - **x** (Tensor) - The input of ReLUQuant.
-
-    Outputs:
-        Tensor, with the same type and shape as the `x`.
-
-    Examples:
-        >>> relu_quant = nn.ReLUQuant()
-        >>> input_x = Tensor(np.array([[1, 2, 0], [-1, -2, 1]]), mindspore.float32)
-        >>> result = relu_quant(input_x)
-    """
-
-    def __init__(self,
-                 ema_decay=0.999,
-                 per_channel=False,
-                 num_bits=8,
-                 symmetric=False,
-                 narrow_range=False,
-                 quant_delay=0):
-        super(ReLUQuant, self).__init__()
-        self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
-                                                  max_init=6,
-                                                  ema=True,
-                                                  ema_decay=ema_decay,
-                                                  per_channel=per_channel,
-                                                  num_bits=num_bits,
-                                                  symmetric=symmetric,
-                                                  narrow_range=narrow_range,
-                                                  quant_delay=quant_delay)
-        self.relu = P.ReLU()
-
-    def construct(self, x):
-        x = self.relu(x)
-        x = self.fake_quant_act(x)
-        return x
-
-    def get_origin(self):
-        return self.relu
-
-
-class ReLU6Quant(_QuantActivation):
-    r"""
-    ReLU6Quant activation function.
+    Quantization aware training activation function.

-    Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op
+    Add Fake Quant OP after activation. Not Recommand to used these cell for Fake Quant Op
     Will climp the max range of the activation and the relu6 do the same operation.

     For a more Detailed overview of ReLU6 op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -883,19 +832,20 @@ class ReLU6Quant(_QuantActivation):
         Tensor, with the same type and shape as the `x`.

     Examples:
-        >>> relu6_quant = nn.ReLU6Quant(4, 1)
+        >>> act_quant = nn.ActQuant(4, 1)
         >>> input_x = Tensor(np.array([[1, 2, -1], [-2, 0, -1]]), mindspore.float32)
-        >>> result = relu6_quant(input_x)
+        >>> result = act_quant(input_x)
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
                  symmetric=False,
                  narrow_range=False,
                  quant_delay=0):
-        super(ReLU6Quant, self).__init__()
+        super(ActQuant, self).__init__()
         self.fake_quant_act = FakeQuantWithMinMax(min_init=0,
                                                   max_init=6,
                                                   ema=True,
@@ -905,15 +855,15 @@ class ReLU6Quant(_QuantActivation):
                                                   symmetric=symmetric,
                                                   narrow_range=narrow_range,
                                                   quant_delay=quant_delay)
-        self.relu6 = P.ReLU6()
+        self.act = activation

     def construct(self, x):
-        x = self.relu6(x)
+        x = self.act(x)
         x = self.fake_quant_act(x)
         return x

     def get_origin(self):
-        return self.relu6
+        return self.act


 class HSwishQuant(_QuantActivation):
@@ -923,6 +873,7 @@ class HSwishQuant(_QuantActivation):
     For a more Detailed overview of HSwish op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -943,6 +894,7 @@ class HSwishQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -968,7 +920,10 @@ class HSwishQuant(_QuantActivation):
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range,
                                                        quant_delay=quant_delay)
-        self.act = P.HSwish()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSwish`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
@@ -987,6 +942,7 @@ class HSigmoidQuant(_QuantActivation):
     For a more Detailed overview of HSigmoid op.

     Args:
+        activation (Cell): Activation cell class.
         ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
         num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
@@ -1007,6 +963,7 @@ class HSigmoidQuant(_QuantActivation):
     """

     def __init__(self,
+                 activation,
                  ema_decay=0.999,
                  per_channel=False,
                  num_bits=8,
@@ -1032,7 +989,10 @@ class HSigmoidQuant(_QuantActivation):
                                                        symmetric=symmetric,
                                                        narrow_range=narrow_range,
                                                        quant_delay=quant_delay)
-        self.act = P.HSigmoid()
+        if isinstance(activation, nn.HSwish):
+            self.act = activation
+        else:
+            raise ValueError("Activation should be `nn.HSigmoid`")

     def construct(self, x):
         x = self.fake_quant_act_before(x)
```
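The net effect of these hunks: ReLUQuant and ReLU6Quant are folded into a single ActQuant cell that wraps a caller-supplied activation and fake-quantizes its output, and get_origin exposes the wrapped activation so export can recover the original op. A standalone sketch of the pattern, using plain callables in place of MindSpore Cells:

```python
class ActQuantSketch:
    """Standalone sketch of the ActQuant pattern above, not MindSpore code."""

    def __init__(self, activation, fake_quant):
        self.act = activation             # the real activation, e.g. a ReLU6 callable
        self.fake_quant_act = fake_quant  # the simulated-quantization step

    def __call__(self, x):
        x = self.act(x)                   # run the activation
        return self.fake_quant_act(x)     # then fake-quantize its output

    def get_origin(self):
        return self.act                   # export uses this to recover the original op
```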

mindspore/ops/operations/nn_ops.py

```diff
@@ -1004,6 +1004,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer):
     def infer_dtype(self, x_dtype, w_dtype):
         args = {'x': x_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, mstype.number_type, self.name)
+        if x_dtype.element_type() == mstype.int8:
+            return mstype.tensor_type(mstype.int32)
         return x_dtype
```
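The new dtype rule reflects how quantized depthwise convolution accumulates: products of int8 values exceed the int8 range, so the op infers an int32 output. A standalone numpy illustration:

```python
import numpy as np

x = np.array([120, -100], dtype=np.int8)
w = np.array([90, 50], dtype=np.int8)

# int8 * int8 overflows int8 (120 * 90 = 10800), so accumulate in int32.
acc = x.astype(np.int32) * w.astype(np.int32)
print(acc.dtype, acc)  # int32 [10800 -5000]
```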

mindspore/train/quant/quant.py

```diff
@@ -33,8 +33,10 @@ from ...ops.operations import _inner_ops as inner
 from ...train import serialization
 from . import quant_utils

-_ACTIVATION_MAP = {nn.ReLU: quant.ReLUQuant,
-                   nn.ReLU6: quant.ReLU6Quant,
+_ACTIVATION_MAP = {nn.ReLU: quant.ActQuant,
+                   nn.ReLU6: quant.ActQuant,
+                   nn.LeakyReLU: quant.ActQuant,
+                   nn.Sigmoid: quant.ActQuant,
                    nn.HSigmoid: quant.HSigmoidQuant,
                    nn.HSwish: quant.HSwishQuant}
```
```diff
@@ -257,9 +259,9 @@ class ConvertToQuantNetwork:
     def _convert_activation(self, activation):
         act_class = activation.__class__
         if act_class not in _ACTIVATION_MAP:
-            raise ValueError("Unsupported activation in auto quant: ", act_class)
-        return _ACTIVATION_MAP[act_class](num_bits=self.act_bits,
+            raise ValueError("Unsupported activation in auto quant: ", act_class)
+        return _ACTIVATION_MAP[act_class](activation=act_class,
+                                          num_bits=self.act_bits,
                                           quant_delay=self.act_qdelay,
                                           per_channel=self.act_channel,
                                           symmetric=self.act_symmetric,
```
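After this change every mapped wrapper receives the `activation` argument, so the single ActQuant class can serve ReLU, ReLU6, LeakyReLU, and Sigmoid. A self-contained illustration of the dispatch, with hypothetical stand-in classes in place of the MindSpore ones (note the diff passes the class itself, `act_class`, through as `activation`):

```python
class ActQuant:                          # stand-in for quant.ActQuant
    def __init__(self, activation, num_bits=8):
        self.activation = activation
        self.num_bits = num_bits

class ReLU6:                             # stand-in for nn.ReLU6
    pass

_ACTIVATION_MAP = {ReLU6: ActQuant}

act = ReLU6()
# The activation's class selects the wrapper and is forwarded unchanged.
wrapper = _ACTIVATION_MAP[type(act)](activation=type(act), num_bits=8)
print(type(wrapper).__name__)            # ActQuant
```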
```diff
@@ -317,7 +319,7 @@ class ExportToQuantInferNetwork:
             minq = self.all_parameters[minq_name]
             scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, maxq, minq, np_type)
         else:
-            logger.warning(f"Do not find `fake_quant` from input with `fack_quant.minq` {w_minq_name}")
+            logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None

         # Build the `Quant` `Dequant` op.
```
```diff
@@ -325,7 +327,7 @@ class ExportToQuantInferNetwork:
         quant_op = inner.AscendQuant(float(scale_a_in), float(zp_a_in))
         sqrt_mode = False
         scale_deq = scale_a_out * scale_w
-        if scale_deq < 2 ** -14:
+        if (scale_deq < 2 ** -14).all():
            scale_deq = np.sqrt(scale_deq)
            sqrt_mode = True
         dequant_op = inner.AscendDequant(sqrt_mode)
```
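With per-channel quantization, `scale_deq` is a numpy array rather than a scalar, so the comparison is elementwise and `.all()` is needed to produce a single truth value. A numpy sketch of the guarded path (the `2 ** -14` threshold is taken from the diff; the underflow rationale is an assumption):

```python
import numpy as np

scale_deq = np.array([2.0 ** -15, 2.0 ** -16])  # per-channel dequant scales
sqrt_mode = False
if (scale_deq < 2 ** -14).all():
    # All channels fall below the threshold: split the tiny scale across
    # the Quant/Dequant pair by taking the square root on each side.
    scale_deq = np.sqrt(scale_deq)
    sqrt_mode = True
```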
```diff
@@ -404,11 +406,17 @@ def export(network, *inputs, file_name, file_format='GEIR'):
         file_format (str): MindSpore currently supports 'GEIR' format for exported quantization aware model.

             - GEIR: Graph Engine Intermediate Representation. An Intermediate representation format of Ascend model.
     """
+    supported_device = ["Ascend"]
+    supported_formats = ['GEIR']
+
+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+
+    if file_format not in supported_formats:
+        raise ValueError('Illegal file format {}.'.format(file_format))
+
     network.set_train(False)
     if file_format == 'GEIR':
         exporter = ExportToQuantInferNetwork(network, *inputs)
         deploy_net = exporter.run()
```

mindspore/train/quant/quant_utils.py

```diff
@@ -45,7 +45,7 @@ def cal_quantization_params(input_min,
         raise ValueError("input min shape should equal to input max.")
     if len(input_min.shape) > 1:
         raise ValueError("input min and max shape should be one dim.")
-    if input_min > input_max:
+    if (input_min > input_max).all():
         raise ValueError("input_min min should less than input max.")
     if (input_max == input_min).all():
         # scale = 1.0, zp = 0.0
```
```diff
@@ -85,9 +85,7 @@ def cal_quantization_params(input_min,
     return scale, zp


-def weight2int(data,
-               scale,
-               zero_point):
+def weight2int(data, scale, zero_point):
     r"""
     Calculate int8/uint8 weight from fp32. the formula is defined as:
@@ -103,12 +101,24 @@ def weight2int(data,
         weight (numpy.ndarray): The dimension of channel or 1.
     """
     if scale.shape != zero_point.shape:
-        raise ValueError("scale and zero_point should have the same shape.")
-    if scale.shape[0] > 0:
-        scale = scale.reshape(1, -1)
-        zero_point = zero_point.reshape(1, -1)
+        raise ValueError("`scale` and `zero_point` should have the same shape.")
+    if scale.shape[0] < 0:
+        raise ValueError("`scale` and `zero_point` shape should greater than zero.")
+
+    if scale.shape[0] == data.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(data.shape[1:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    elif scale.shape[0] == data.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(data.shape[2:])
+        scale = scale.reshape(shape_list)
+        zero_point = zero_point.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(data.shape))

     return np.round((data / scale) + zero_point)


 def scale_zp_from_fack_quant_cell(cell, data_type):
```
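weight2int now reshapes the per-channel scale so it broadcasts along the correct axis: axis 0 for Conv2d/Dense weights, axis 1 for depthwise weights. A standalone numpy check of the Conv2d branch:

```python
import numpy as np

data = np.ones((4, 8, 3, 3), dtype=np.float32)  # Conv2d weight: (out, in, kh, kw)
scale = np.full(4, 0.5, dtype=np.float32)       # one scale per output channel
zero_point = np.zeros(4, dtype=np.float32)

shape_list = [-1] + [1] * len(data.shape[1:])   # -> [-1, 1, 1, 1]
q = np.round(data / scale.reshape(shape_list) + zero_point.reshape(shape_list))
print(q.shape, q[0, 0, 0, 0])                   # (4, 8, 3, 3) 2.0
```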
```diff
@@ -183,9 +193,20 @@ def fold_batchnorm(weight, cell_quant):
     beta = cell_quant.beta.data.asnumpy()
     epsilon = cell_quant.eps
     sigma = np.sqrt(variance + epsilon)

-    gamma = gamma.reshape(-1, 1, 1, 1)
-    sigma = sigma.reshape(-1, 1, 1, 1)
-    mean = mean.reshape(-1, 1, 1, 1)
-    weight = weight * gamma / sigma
+    if gamma.shape[0] == weight.shape[0]:
+        # `Conv2d` or `Dense` op weight
+        shape_list = [-1] + [1] * len(weight.shape[1:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    elif gamma.shape[0] == weight.shape[1]:
+        # `DepthwiseConv2d` op weight
+        shape_list = [1, -1] + [1] * len(weight.shape[2:])
+        _gamma = gamma.reshape(shape_list)
+        _sigma = sigma.reshape(shape_list)
+    else:
+        raise ValueError("Unsupported weight shape({})".format(weight.shape))
+
+    weight = weight * _gamma / _sigma
     bias = beta - gamma * mean / sigma
     return weight, bias
```
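fold_batchnorm applies the usual folding identities, now with the same axis-aware reshape as weight2int: w' = w * gamma / sigma and b' = beta - gamma * mean / sigma, where sigma = sqrt(variance + epsilon). A standalone numpy sketch of the Conv2d branch:

```python
import numpy as np

weight = np.ones((4, 8, 3, 3), dtype=np.float32)  # Conv2d weight
gamma = np.full(4, 2.0, dtype=np.float32)         # BN scale
beta = np.zeros(4, dtype=np.float32)              # BN shift
mean = np.full(4, 0.5, dtype=np.float32)          # BN running mean
variance = np.ones(4, dtype=np.float32)           # BN running variance
sigma = np.sqrt(variance + 1e-5)

shape_list = [-1] + [1] * len(weight.shape[1:])   # broadcast along axis 0
folded_w = weight * gamma.reshape(shape_list) / sigma.reshape(shape_list)
folded_b = beta - gamma * mean / sigma            # stays 1-D: one bias per channel
```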

model_zoo/mobilenetv2_quant/export.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Export MobilenetV2 on ImageNet"""

import argparse
import numpy as np

import mindspore
from mindspore import Tensor
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant

from src.mobilenetV2 import mobilenetV2
from src.config import config_ascend

parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
parser.add_argument('--device_target', type=str, default=None, help='Run device target')
args_opt = parser.parse_args()

if __name__ == '__main__':
    cfg = None
    if args_opt.device_target == "Ascend":
        cfg = config_ascend
        context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    else:
        raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))

    # define fusion network
    network = mobilenetV2(num_classes=cfg.num_classes)
    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(network, param_dict)

    # export network
    print("============== Starting export ==============")
    inputs = Tensor(np.ones([1, 3, cfg.image_height, cfg.image_width]), mindspore.float32)
    quant.export(network, inputs, file_name="mobilenet_quant", file_format='GEIR')
    print("============== End export ==============")
```
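A typical invocation of the new script would be `python export.py --checkpoint_path ./mobilenetV2.ckpt --device_target Ascend` (checkpoint path hypothetical): it rebuilds the fusion network, converts it to its quantization-aware form, loads the trained parameters, and writes a GEIR model named mobilenet_quant.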