Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
5b3f209e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5b3f209e
编写于
8月 12, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 12, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4269 change export from geir to air
Merge pull request !4269 from fary86/change_export_interface
上级
6e3c87be
73325e0f
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
44 addition
and
35 deletion
+44
-35
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+6
-5
mindspore/common/api.py
mindspore/common/api.py
+1
-1
mindspore/train/quant/quant.py
mindspore/train/quant/quant.py
+6
-6
mindspore/train/serialization.py
mindspore/train/serialization.py
+12
-9
model_zoo/official/cv/googlenet/README.md
model_zoo/official/cv/googlenet/README.md
+3
-3
model_zoo/official/cv/googlenet/export.py
model_zoo/official/cv/googlenet/export.py
+2
-2
model_zoo/official/cv/googlenet/src/config.py
model_zoo/official/cv/googlenet/src/config.py
+1
-1
model_zoo/official/cv/inceptionv3/export.py
model_zoo/official/cv/inceptionv3/export.py
+2
-2
model_zoo/official/cv/lenet_quant/export.py
model_zoo/official/cv/lenet_quant/export.py
+2
-2
model_zoo/official/cv/mobilenetv2_quant/export.py
model_zoo/official/cv/mobilenetv2_quant/export.py
+1
-1
tests/st/tbe_networks/export_geir.py
tests/st/tbe_networks/export_geir.py
+1
-1
tests/ut/python/utils/test_serialize.py
tests/ut/python/utils/test_serialize.py
+7
-2
未找到文件。
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
5b3f209e
...
@@ -396,13 +396,13 @@ void ExecutorPy::GetGeBackendPolicy() const {
...
@@ -396,13 +396,13 @@ void ExecutorPy::GetGeBackendPolicy() const {
}
}
}
}
bool
IsPhaseExport
Ge
ir
(
const
std
::
string
&
phase_s
)
{
bool
IsPhaseExport
A
ir
(
const
std
::
string
&
phase_s
)
{
auto
phase_to_export
=
"export.
ge
ir"
;
auto
phase_to_export
=
"export.
a
ir"
;
return
phase_s
.
rfind
(
phase_to_export
)
!=
std
::
string
::
npos
;
return
phase_s
.
rfind
(
phase_to_export
)
!=
std
::
string
::
npos
;
}
}
std
::
vector
<
ActionItem
>
GetPipline
(
const
ResourcePtr
&
resource
,
const
std
::
string
&
phase_s
,
bool
use_vm
)
{
std
::
vector
<
ActionItem
>
GetPipline
(
const
ResourcePtr
&
resource
,
const
std
::
string
&
phase_s
,
bool
use_vm
)
{
bool
is_
geir
=
IsPhaseExportGe
ir
(
phase_s
);
bool
is_
air
=
IsPhaseExportA
ir
(
phase_s
);
std
::
string
backend
=
MsContext
::
GetInstance
()
->
backend_policy
();
std
::
string
backend
=
MsContext
::
GetInstance
()
->
backend_policy
();
...
@@ -419,7 +419,7 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin
...
@@ -419,7 +419,7 @@ std::vector<ActionItem> GetPipline(const ResourcePtr &resource, const std::strin
}
}
#endif
#endif
if
(
use_vm
&&
backend
!=
"ge"
&&
!
is_
ge
ir
)
{
if
(
use_vm
&&
backend
!=
"ge"
&&
!
is_
a
ir
)
{
// Create backend and session
// Create backend and session
auto
backend_ptr
=
compile
::
CreateBackend
();
auto
backend_ptr
=
compile
::
CreateBackend
();
// Connect session to debugger
// Connect session to debugger
...
@@ -938,8 +938,9 @@ void FinalizeHccl() {
...
@@ -938,8 +938,9 @@ void FinalizeHccl() {
void
ExportGraph
(
const
std
::
string
&
file_name
,
const
std
::
string
&
,
const
std
::
string
&
phase
)
{
void
ExportGraph
(
const
std
::
string
&
file_name
,
const
std
::
string
&
,
const
std
::
string
&
phase
)
{
#if (ENABLE_GE || ENABLE_D)
#if (ENABLE_GE || ENABLE_D)
ExportDFGraph
(
file_name
,
phase
);
ExportDFGraph
(
file_name
,
phase
);
#else
MS_EXCEPTION
(
ValueError
)
<<
"Only MindSpore with Ascend backend support exporting file in 'AIR' format."
;
#endif
#endif
MS_LOG
(
WARNING
)
<<
"In ut test no export_graph"
;
}
}
void
ReleaseGeTsd
()
{
void
ReleaseGeTsd
()
{
...
...
mindspore/common/api.py
浏览文件 @
5b3f209e
...
@@ -515,7 +515,7 @@ class _Executor:
...
@@ -515,7 +515,7 @@ class _Executor:
graph_id (str): id of graph to be exported
graph_id (str): id of graph to be exported
"""
"""
from
.._c_expression
import
export_graph
from
.._c_expression
import
export_graph
export_graph
(
file_name
,
'
GE
IR'
,
graph_id
)
export_graph
(
file_name
,
'
A
IR'
,
graph_id
)
def
fetch_info_for_quant_export
(
self
,
exec_id
):
def
fetch_info_for_quant_export
(
self
,
exec_id
):
"""Get graph proto from pipeline."""
"""Get graph proto from pipeline."""
...
...
mindspore/train/quant/quant.py
浏览文件 @
5b3f209e
...
@@ -435,9 +435,9 @@ class ExportToQuantInferNetwork:
...
@@ -435,9 +435,9 @@ class ExportToQuantInferNetwork:
return
network
return
network
def
export
(
network
,
*
inputs
,
file_name
,
mean
=
127.5
,
std_dev
=
127.5
,
file_format
=
'
GE
IR'
):
def
export
(
network
,
*
inputs
,
file_name
,
mean
=
127.5
,
std_dev
=
127.5
,
file_format
=
'
A
IR'
):
"""
"""
Exports MindSpore quantization predict model to deploy with
GE
IR.
Exports MindSpore quantization predict model to deploy with
A
IR.
Args:
Args:
network (Cell): MindSpore network produced by `convert_quant_network`.
network (Cell): MindSpore network produced by `convert_quant_network`.
...
@@ -445,17 +445,17 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
...
@@ -445,17 +445,17 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
file_name (str): File name of model to export.
file_name (str): File name of model to export.
mean (int): Input data mean. Default: 127.5.
mean (int): Input data mean. Default: 127.5.
std_dev (int, float): Input data variance. Default: 127.5.
std_dev (int, float): Input data variance. Default: 127.5.
file_format (str): MindSpore currently supports '
GE
IR', 'ONNX' and 'MINDIR' format for exported
file_format (str): MindSpore currently supports '
A
IR', 'ONNX' and 'MINDIR' format for exported
quantization aware model. Default: '
GE
IR'.
quantization aware model. Default: '
A
IR'.
-
GE
IR: Graph Engine Intermidiate Representation. An intermidiate representation format of
-
A
IR: Graph Engine Intermidiate Representation. An intermidiate representation format of
Ascend model.
Ascend model.
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
for MindSpore models.
for MindSpore models.
Recommended suffix for output file is '.mindir'.
Recommended suffix for output file is '.mindir'.
"""
"""
supported_device
=
[
"Ascend"
,
"GPU"
]
supported_device
=
[
"Ascend"
,
"GPU"
]
supported_formats
=
[
'
GE
IR'
,
'MINDIR'
]
supported_formats
=
[
'
A
IR'
,
'MINDIR'
]
mean
=
validator
.
check_type
(
"mean"
,
mean
,
(
int
,
float
))
mean
=
validator
.
check_type
(
"mean"
,
mean
,
(
int
,
float
))
std_dev
=
validator
.
check_type
(
"std_dev"
,
std_dev
,
(
int
,
float
))
std_dev
=
validator
.
check_type
(
"std_dev"
,
std_dev
,
(
int
,
float
))
...
...
mindspore/train/serialization.py
浏览文件 @
5b3f209e
...
@@ -445,7 +445,7 @@ def _fill_param_into_net(net, parameter_list):
...
@@ -445,7 +445,7 @@ def _fill_param_into_net(net, parameter_list):
load_param_into_net
(
net
,
parameter_dict
)
load_param_into_net
(
net
,
parameter_dict
)
def
export
(
net
,
*
inputs
,
file_name
,
file_format
=
'
GE
IR'
):
def
export
(
net
,
*
inputs
,
file_name
,
file_format
=
'
A
IR'
):
"""
"""
Exports MindSpore predict model to file in specified format.
Exports MindSpore predict model to file in specified format.
...
@@ -453,11 +453,12 @@ def export(net, *inputs, file_name, file_format='GEIR'):
...
@@ -453,11 +453,12 @@ def export(net, *inputs, file_name, file_format='GEIR'):
net (Cell): MindSpore network.
net (Cell): MindSpore network.
inputs (Tensor): Inputs of the `net`.
inputs (Tensor): Inputs of the `net`.
file_name (str): File name of model to export.
file_name (str): File name of model to export.
file_format (str): MindSpore currently supports '
GE
IR', 'ONNX' and 'MINDIR' format for exported model.
file_format (str): MindSpore currently supports '
A
IR', 'ONNX' and 'MINDIR' format for exported model.
-
GEIR: Graph Engine Intermidiate Representation. An intermidiate representation format of
-
AIR: Ascend Intermidiate Representation. An intermidiate representation format of Ascend model.
Ascend model
.
Recommended suffix for output file is '.air'
.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
- ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
Recommended suffix for output file is '.onnx'.
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
- MINDIR: MindSpore Native Intermidiate Representation for Anf. An intermidiate representation format
for MindSpore models.
for MindSpore models.
Recommended suffix for output file is '.mindir'.
Recommended suffix for output file is '.mindir'.
...
@@ -465,7 +466,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
...
@@ -465,7 +466,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
logger
.
info
(
"exporting model file:%s format:%s."
,
file_name
,
file_format
)
logger
.
info
(
"exporting model file:%s format:%s."
,
file_name
,
file_format
)
check_input_data
(
*
inputs
,
data_class
=
Tensor
)
check_input_data
(
*
inputs
,
data_class
=
Tensor
)
supported_formats
=
[
'GEIR'
,
'ONNX'
,
'MINDIR'
]
if
file_format
==
'GEIR'
:
logger
.
warning
(
f
"Format 'GEIR' is deprecated, it would be removed in future release, use 'AIR' instead."
)
file_format
=
'AIR'
supported_formats
=
[
'AIR'
,
'ONNX'
,
'MINDIR'
]
if
file_format
not
in
supported_formats
:
if
file_format
not
in
supported_formats
:
raise
ValueError
(
f
'Illegal file format
{
file_format
}
, it must be one of
{
supported_formats
}
'
)
raise
ValueError
(
f
'Illegal file format
{
file_format
}
, it must be one of
{
supported_formats
}
'
)
# switch network mode to infer when it is training
# switch network mode to infer when it is training
...
@@ -474,13 +479,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
...
@@ -474,13 +479,11 @@ def export(net, *inputs, file_name, file_format='GEIR'):
net
.
set_train
(
mode
=
False
)
net
.
set_train
(
mode
=
False
)
# export model
# export model
net
.
init_parameters_data
()
net
.
init_parameters_data
()
if
file_format
==
'
GE
IR'
:
if
file_format
==
'
A
IR'
:
phase_name
=
'export.
ge
ir'
phase_name
=
'export.
a
ir'
graph_id
,
_
=
_executor
.
compile
(
net
,
*
inputs
,
phase
=
phase_name
)
graph_id
,
_
=
_executor
.
compile
(
net
,
*
inputs
,
phase
=
phase_name
)
_executor
.
export
(
file_name
,
graph_id
)
_executor
.
export
(
file_name
,
graph_id
)
elif
file_format
==
'ONNX'
:
# file_format is 'ONNX'
elif
file_format
==
'ONNX'
:
# file_format is 'ONNX'
# NOTICE: the pahse name `export_onnx` is used for judging whether is exporting onnx in the compile pipeline,
# do not change it to other values.
phase_name
=
'export.onnx'
phase_name
=
'export.onnx'
graph_id
,
_
=
_executor
.
compile
(
net
,
*
inputs
,
phase
=
phase_name
,
do_convert
=
False
)
graph_id
,
_
=
_executor
.
compile
(
net
,
*
inputs
,
phase
=
phase_name
,
do_convert
=
False
)
onnx_stream
=
_executor
.
_get_func_graph_proto
(
graph_id
)
onnx_stream
=
_executor
.
_get_func_graph_proto
(
graph_id
)
...
...
model_zoo/official/cv/googlenet/README.md
浏览文件 @
5b3f209e
...
@@ -108,7 +108,7 @@ python eval.py > eval.log 2>&1 & OR sh run_eval.sh
...
@@ -108,7 +108,7 @@ python eval.py > eval.log 2>&1 & OR sh run_eval.sh
│ ├──config.py // parameter configuration
│ ├──config.py // parameter configuration
├── train.py // training script
├── train.py // training script
├── eval.py // evaluation script
├── eval.py // evaluation script
├── export.py // export checkpoint files into
ge
ir/onnx
├── export.py // export checkpoint files into
a
ir/onnx
```
```
## [Script Parameters](#contents)
## [Script Parameters](#contents)
...
@@ -133,7 +133,7 @@ Major parameters in train.py and config.py are:
...
@@ -133,7 +133,7 @@ Major parameters in train.py and config.py are:
--
checkpoint_path
:
The
absolute
full
path
to
the
checkpoint
file
saved
--
checkpoint_path
:
The
absolute
full
path
to
the
checkpoint
file
saved
after
training
.
after
training
.
--
onnx_filename
:
File
name
of
the
onnx
model
used
in
export
.
py
.
--
onnx_filename
:
File
name
of
the
onnx
model
used
in
export
.
py
.
--
geir_filename
:
File
name
of
the
ge
ir
model
used
in
export
.
py
.
--
air_filename
:
File
name
of
the
a
ir
model
used
in
export
.
py
.
```
```
...
@@ -226,7 +226,7 @@ accuracy: {'acc': 0.9217}
...
@@ -226,7 +226,7 @@ accuracy: {'acc': 0.9217}
| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins |
| Total time | 1pc: 63.85 mins; 8pcs: 11.28 mins |
| Parameters (M) | 13.0 |
| Parameters (M) | 13.0 |
| Checkpoint for Fine tuning | 43.07M (.ckpt file) |
| Checkpoint for Fine tuning | 43.07M (.ckpt file) |
| Model for inference | 21.50M (.onnx file), 21.60M(.
ge
ir file) |
| Model for inference | 21.50M (.onnx file), 21.60M(.
a
ir file) |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet |
...
...
model_zoo/official/cv/googlenet/export.py
浏览文件 @
5b3f209e
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""
"""
##############export checkpoint file into
ge
ir and onnx models#################
##############export checkpoint file into
a
ir and onnx models#################
python export.py
python export.py
"""
"""
import
numpy
as
np
import
numpy
as
np
...
@@ -33,4 +33,4 @@ if __name__ == '__main__':
...
@@ -33,4 +33,4 @@ if __name__ == '__main__':
input_arr
=
Tensor
(
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
1
,
3
,
224
,
224
]),
ms
.
float32
)
input_arr
=
Tensor
(
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
1
,
3
,
224
,
224
]),
ms
.
float32
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
onnx_filename
,
file_format
=
"ONNX"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
onnx_filename
,
file_format
=
"ONNX"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
geir_filename
,
file_format
=
"GE
IR"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
air_filename
,
file_format
=
"A
IR"
)
model_zoo/official/cv/googlenet/src/config.py
浏览文件 @
5b3f209e
...
@@ -34,5 +34,5 @@ cifar_cfg = edict({
...
@@ -34,5 +34,5 @@ cifar_cfg = edict({
'keep_checkpoint_max'
:
10
,
'keep_checkpoint_max'
:
10
,
'checkpoint_path'
:
'./train_googlenet_cifar10-125_390.ckpt'
,
'checkpoint_path'
:
'./train_googlenet_cifar10-125_390.ckpt'
,
'onnx_filename'
:
'googlenet.onnx'
,
'onnx_filename'
:
'googlenet.onnx'
,
'
geir_filename'
:
'googlenet.ge
ir'
'
air_filename'
:
'googlenet.a
ir'
})
})
model_zoo/official/cv/inceptionv3/export.py
浏览文件 @
5b3f209e
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""
"""
##############export checkpoint file into
ge
ir and onnx models#################
##############export checkpoint file into
a
ir and onnx models#################
"""
"""
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
...
@@ -37,4 +37,4 @@ if __name__ == '__main__':
...
@@ -37,4 +37,4 @@ if __name__ == '__main__':
input_arr
=
Tensor
(
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
1
,
3
,
299
,
299
]),
ms
.
float32
)
input_arr
=
Tensor
(
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
1
,
3
,
299
,
299
]),
ms
.
float32
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
onnx_filename
,
file_format
=
"ONNX"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
onnx_filename
,
file_format
=
"ONNX"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
geir_filename
,
file_format
=
"GE
IR"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
air_filename
,
file_format
=
"A
IR"
)
model_zoo/official/cv/lenet_quant/export.py
浏览文件 @
5b3f209e
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""
"""
export quantization aware training network to infer `
GE
IR` backend.
export quantization aware training network to infer `
A
IR` backend.
"""
"""
import
argparse
import
argparse
...
@@ -53,4 +53,4 @@ if __name__ == "__main__":
...
@@ -53,4 +53,4 @@ if __name__ == "__main__":
# export network
# export network
inputs
=
Tensor
(
np
.
ones
([
1
,
1
,
cfg
.
image_height
,
cfg
.
image_width
]),
mindspore
.
float32
)
inputs
=
Tensor
(
np
.
ones
([
1
,
1
,
cfg
.
image_height
,
cfg
.
image_width
]),
mindspore
.
float32
)
quant
.
export
(
network
,
inputs
,
file_name
=
"lenet_quant"
,
file_format
=
'
GE
IR'
)
quant
.
export
(
network
,
inputs
,
file_name
=
"lenet_quant"
,
file_format
=
'
A
IR'
)
model_zoo/official/cv/mobilenetv2_quant/export.py
浏览文件 @
5b3f209e
...
@@ -50,5 +50,5 @@ if __name__ == '__main__':
...
@@ -50,5 +50,5 @@ if __name__ == '__main__':
# export network
# export network
print
(
"============== Starting export =============="
)
print
(
"============== Starting export =============="
)
inputs
=
Tensor
(
np
.
ones
([
1
,
3
,
cfg
.
image_height
,
cfg
.
image_width
]),
mindspore
.
float32
)
inputs
=
Tensor
(
np
.
ones
([
1
,
3
,
cfg
.
image_height
,
cfg
.
image_width
]),
mindspore
.
float32
)
quant
.
export
(
network
,
inputs
,
file_name
=
"mobilenet_quant"
,
file_format
=
'
GE
IR'
)
quant
.
export
(
network
,
inputs
,
file_name
=
"mobilenet_quant"
,
file_format
=
'
A
IR'
)
print
(
"============== End export =============="
)
print
(
"============== End export =============="
)
tests/st/tbe_networks/export_geir.py
浏览文件 @
5b3f209e
...
@@ -24,4 +24,4 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
...
@@ -24,4 +24,4 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def
test_resnet50_export
(
batch_size
=
1
,
num_classes
=
5
):
def
test_resnet50_export
(
batch_size
=
1
,
num_classes
=
5
):
input_np
=
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
batch_size
,
3
,
224
,
224
]).
astype
(
np
.
float32
)
input_np
=
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
batch_size
,
3
,
224
,
224
]).
astype
(
np
.
float32
)
net
=
resnet50
(
batch_size
,
num_classes
)
net
=
resnet50
(
batch_size
,
num_classes
)
export
(
net
,
Tensor
(
input_np
),
file_name
=
"./me_resnet50.pb"
,
file_format
=
"
GE
IR"
)
export
(
net
,
Tensor
(
input_np
),
file_name
=
"./me_resnet50.pb"
,
file_format
=
"
A
IR"
)
tests/ut/python/utils/test_serialize.py
浏览文件 @
5b3f209e
...
@@ -87,8 +87,12 @@ def test_save_graph():
...
@@ -87,8 +87,12 @@ def test_save_graph():
x
=
Tensor
(
np
.
random
.
rand
(
2
,
1
,
2
,
3
).
astype
(
np
.
float32
))
x
=
Tensor
(
np
.
random
.
rand
(
2
,
1
,
2
,
3
).
astype
(
np
.
float32
))
y
=
Tensor
(
np
.
array
([
1.2
]).
astype
(
np
.
float32
))
y
=
Tensor
(
np
.
array
([
1.2
]).
astype
(
np
.
float32
))
out_put
=
net
(
x
,
y
)
out_put
=
net
(
x
,
y
)
_save_graph
(
network
=
net
,
file_name
=
"net-graph.meta"
)
output_file
=
"net-graph.meta"
_save_graph
(
network
=
net
,
file_name
=
output_file
)
out_me_list
.
append
(
out_put
)
out_me_list
.
append
(
out_put
)
assert
os
.
path
.
exists
(
output_file
)
os
.
chmod
(
output_file
,
stat
.
S_IWRITE
)
os
.
remove
(
output_file
)
def
test_save_checkpoint
():
def
test_save_checkpoint
():
...
@@ -318,7 +322,8 @@ class MYNET(nn.Cell):
...
@@ -318,7 +322,8 @@ class MYNET(nn.Cell):
def
test_export
():
def
test_export
():
net
=
MYNET
()
net
=
MYNET
()
input_data
=
Tensor
(
np
.
random
.
randint
(
0
,
255
,
[
1
,
3
,
224
,
224
]).
astype
(
np
.
float32
))
input_data
=
Tensor
(
np
.
random
.
randint
(
0
,
255
,
[
1
,
3
,
224
,
224
]).
astype
(
np
.
float32
))
export
(
net
,
input_data
,
file_name
=
"./me_export.pb"
,
file_format
=
"GEIR"
)
with
pytest
.
raises
(
ValueError
):
export
(
net
,
input_data
,
file_name
=
"./me_export.pb"
,
file_format
=
"AIR"
)
@
non_graph_engine
@
non_graph_engine
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录