Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
mindspore
提交
3c48de82
M
mindspore
项目概览
wmsofts
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3c48de82
编写于
6月 24, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2573 fix print file bug
Merge pull request !2573 from jinyaohui/print
上级
dd75ebfa
e893c701
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
11 addition
and
2 deletion
+11
-2
mindspore/context.py
mindspore/context.py
+3
-0
mindspore/train/serialization.py
mindspore/train/serialization.py
+8
-2
未找到文件。
mindspore/context.py
浏览文件 @
3c48de82
...
...
@@ -564,6 +564,8 @@ def set_context(**kwargs):
check_bprop (bool): Whether to check bprop. Default: False.
max_device_memory (str): Sets the maximum memory available for device, currently only supported on GPU.
The format is "xxGB". Default: "1024GB".
print_file_path (str): The path of print data to save. If this parameter is set, print data is saved to
a file by default,and turn off printing to the screen.
Raises:
ValueError: If input key is not an attribute in context.
...
...
@@ -584,6 +586,7 @@ def set_context(**kwargs):
>>> save_graphs_path="/mindspore")
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
>>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(print_file_path="print.pb")
"""
for
key
,
value
in
kwargs
.
items
():
if
not
hasattr
(
_context
(),
key
):
...
...
mindspore/train/serialization.py
浏览文件 @
3c48de82
...
...
@@ -29,8 +29,7 @@ from mindspore.common.api import _executor
from
mindspore.common
import
dtype
as
mstype
from
mindspore._checkparam
import
check_input_data
__all__
=
[
"save_checkpoint"
,
"load_checkpoint"
,
"load_param_into_net"
,
"export"
]
__all__
=
[
"save_checkpoint"
,
"load_checkpoint"
,
"load_param_into_net"
,
"export"
,
"parse_print"
]
tensor_to_ms_type
=
{
"Int8"
:
mstype
.
int8
,
"Uint8"
:
mstype
.
uint8
,
"Int16"
:
mstype
.
int16
,
"Uint16"
:
mstype
.
uint16
,
"Int32"
:
mstype
.
int32
,
"Uint32"
:
mstype
.
uint32
,
"Int64"
:
mstype
.
int64
,
"Uint64"
:
mstype
.
uint64
,
...
...
@@ -513,6 +512,13 @@ def parse_print(print_file_name):
tensor_list
.
append
(
Tensor
(
param_value
,
ms_type
))
# Scale type
else
:
data_type_
=
data_type
.
lower
()
if
'float'
in
data_type_
:
param_data
=
float
(
param_data
[
0
])
elif
'int'
in
data_type_
:
param_data
=
int
(
param_data
[
0
])
elif
'bool'
in
data_type_
:
param_data
=
bool
(
param_data
[
0
])
tensor_list
.
append
(
Tensor
(
param_data
,
ms_type
))
except
BaseException
as
e
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录