Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
61eff67b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
61eff67b
编写于
8月 21, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 21, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4851 modify note
Merge pull request !4851 from caozhou/modified_indent
上级
104e70d3
c438b9d0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
42 addition
and
41 deletion
+42
-41
mindspore/train/serialization.py
mindspore/train/serialization.py
+42
-41
未找到文件。
mindspore/train/serialization.py
浏览文件 @
61eff67b
...
@@ -31,7 +31,8 @@ from mindspore.common.api import _executor
...
@@ -31,7 +31,8 @@ from mindspore.common.api import _executor
from
mindspore.common
import
dtype
as
mstype
from
mindspore.common
import
dtype
as
mstype
from
mindspore._checkparam
import
check_input_data
from
mindspore._checkparam
import
check_input_data
__all__
=
[
"save_checkpoint"
,
"load_checkpoint"
,
"load_param_into_net"
,
"export"
,
"parse_print"
]
__all__
=
[
"save_checkpoint"
,
"load_checkpoint"
,
"load_param_into_net"
,
"export"
,
"parse_print"
,
"build_searched_strategy"
,
"merge_sliced_parameter"
]
tensor_to_ms_type
=
{
"Int8"
:
mstype
.
int8
,
"Uint8"
:
mstype
.
uint8
,
"Int16"
:
mstype
.
int16
,
"Uint16"
:
mstype
.
uint16
,
tensor_to_ms_type
=
{
"Int8"
:
mstype
.
int8
,
"Uint8"
:
mstype
.
uint8
,
"Int16"
:
mstype
.
int16
,
"Uint16"
:
mstype
.
uint16
,
"Int32"
:
mstype
.
int32
,
"Uint32"
:
mstype
.
uint32
,
"Int64"
:
mstype
.
int64
,
"Uint64"
:
mstype
.
uint64
,
"Int32"
:
mstype
.
int32
,
"Uint32"
:
mstype
.
uint32
,
"Int64"
:
mstype
.
int64
,
"Uint64"
:
mstype
.
uint64
,
...
@@ -575,19 +576,19 @@ def parse_print(print_file_name):
...
@@ -575,19 +576,19 @@ def parse_print(print_file_name):
def
_merge_param_with_strategy
(
sliced_data
,
parameter_name
,
strategy
,
is_even
):
def
_merge_param_with_strategy
(
sliced_data
,
parameter_name
,
strategy
,
is_even
):
"""
"""
Merge data slices to one tensor with whole data when strategy is not None.
Merge data slices to one tensor with whole data when strategy is not None.
Args:
Args:
sliced_data (list[numpy.ndarray]): d
ata slices in order of rank_id.
sliced_data (list[numpy.ndarray]): D
ata slices in order of rank_id.
parameter_name (str): n
ame of parameter.
parameter_name (str): N
ame of parameter.
strategy (dict): p
arameter slice strategy.
strategy (dict): P
arameter slice strategy.
is_even (bool): s
lice manner that True represents slicing evenly and False represents slicing unevenly.
is_even (bool): S
lice manner that True represents slicing evenly and False represents slicing unevenly.
Returns:
Returns:
Tensor, the merged Tensor which has the whole data.
Tensor, the merged Tensor which has the whole data.
Raises:
Raises:
ValueError: f
ailed to merge.
ValueError: F
ailed to merge.
"""
"""
layout
=
strategy
.
get
(
parameter_name
)
layout
=
strategy
.
get
(
parameter_name
)
try
:
try
:
...
@@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
...
@@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
def
build_searched_strategy
(
strategy_filename
):
def
build_searched_strategy
(
strategy_filename
):
"""
"""
b
uild strategy of every parameter in network.
B
uild strategy of every parameter in network.
Args:
Args:
strategy_filename (str):
n
ame of strategy file.
strategy_filename (str):
N
ame of strategy file.
Returns:
Returns:
Dictionary, whose key is parameter name and value is slice strategy of this parameter.
Dictionary, whose key is parameter name and value is slice strategy of this parameter.
Raises:
Raises:
ValueError:
s
trategy file is incorrect.
ValueError:
S
trategy file is incorrect.
TypeError:
s
trategy_filename is not str.
TypeError:
S
trategy_filename is not str.
Examples:
Examples:
>>> strategy_filename = "./strategy_train.ckpt"
>>> strategy_filename = "./strategy_train.ckpt"
...
@@ -707,32 +708,32 @@ def build_searched_strategy(strategy_filename):
...
@@ -707,32 +708,32 @@ def build_searched_strategy(strategy_filename):
def
merge_sliced_parameter
(
sliced_parameters
,
strategy
=
None
):
def
merge_sliced_parameter
(
sliced_parameters
,
strategy
=
None
):
"""
"""
Merge parameter slices to one whole parameter.
Merge parameter slices to one whole parameter.
Args:
Args:
sliced_parameters (list[Parameter]): p
arameter slices in order of rank_id.
sliced_parameters (list[Parameter]): P
arameter slices in order of rank_id.
strategy (dict): p
arameter slice strategy. Default: None.
strategy (dict): P
arameter slice strategy. Default: None.
If strategy is None, just merge parameter slices in 0 axis order.
If strategy is None, just merge parameter slices in 0 axis order.
- key (str): p
arameter name.
- key (str): P
arameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): s
lice strategy of this parameter.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): S
lice strategy of this parameter.
Returns:
Returns:
Parameter, the merged parameter which has the whole data.
Parameter, the merged parameter which has the whole data.
Raises:
Raises:
ValueError: f
ailed to merge.
ValueError: F
ailed to merge.
TypeError: t
he sliced_parameters is incorrect or strategy is not dict.
TypeError: T
he sliced_parameters is incorrect or strategy is not dict.
KeyError: t
he parameter name is not in keys of strategy.
KeyError: T
he parameter name is not in keys of strategy.
Examples:
Examples:
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> sliced_parameters = [
\
>>> sliced_parameters = [
\
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"),
\
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"),
\
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"),
\
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"),
\
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"),
\
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"),
\
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
"""
"""
if
not
isinstance
(
sliced_parameters
,
list
):
if
not
isinstance
(
sliced_parameters
,
list
):
raise
TypeError
(
f
"The sliced_parameters should be list, but got
{
type
(
sliced_parameters
)
}
."
)
raise
TypeError
(
f
"The sliced_parameters should be list, but got
{
type
(
sliced_parameters
)
}
."
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录