Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
d92c220c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d92c220c
编写于
9月 01, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
9月 01, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5590 Fixbugfix for server shard range computation
Merge pull request !5590 from ZPaC/master-fix-server-shard-method
上级
e17eea34
be63f8b8
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
33 addition
and
10 deletion
+33
-10
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
+6
-4
mindspore/ccsrc/frontend/parallel/ps/util.cc
mindspore/ccsrc/frontend/parallel/ps/util.cc
+26
-6
mindspore/ccsrc/frontend/parallel/ps/util.h
mindspore/ccsrc/frontend/parallel/ps/util.h
+1
-0
未找到文件。
mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc
浏览文件 @
d92c220c
...
...
@@ -149,10 +149,12 @@ void SparseOptimInfo::ComputeMean(const std::shared_ptr<std::vector<std::shared_
size_t
original_row_count
=
input_shapes
->
front
();
if
(
original_row_count
>
0
)
{
size_t
offset
=
0
;
if
((
original_row_count
%
server_num
)
==
0
)
{
offset
=
original_row_count
/
server_num
*
rank_id
;
}
else
{
offset
=
std
::
round
((
static_cast
<
float
>
(
original_row_count
))
/
server_num
)
*
rank_id
;
std
::
map
<
int
,
int
>
rank_dims
=
Util
::
AllRankLocalShard
(
original_row_count
,
rank_id
,
server_num
);
for
(
size_t
i
=
0
;
i
<
rank_id
;
i
++
)
{
if
(
rank_dims
.
count
(
i
)
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"No local shard number for rank "
<<
i
;
}
offset
+=
rank_dims
[
i
];
}
for
(
size_t
i
=
0
;
i
<
indices_size
;
i
++
)
{
indices_data
[
i
]
-=
offset
;
...
...
mindspore/ccsrc/frontend/parallel/ps/util.cc
浏览文件 @
d92c220c
...
...
@@ -134,13 +134,33 @@ std::string Util::optimizer_node_name(int id) {
bool
Util
::
is_optimizer
(
std
::
string
name
)
{
return
optimizer_to_ids
.
count
(
name
)
>
0
;
}
int
Util
::
LocalShard
(
int
first_dim
,
int
rank_id
,
int
server_num
)
{
int
shard_size
=
std
::
round
((
static_cast
<
float
>
(
first_dim
))
/
server_num
);
int
remain_size
=
first_dim
%
server_num
;
if
(
remain_size
==
0
||
rank_id
<
server_num
-
1
)
{
return
shard_size
;
}
else
{
return
first_dim
-
(
shard_size
*
(
server_num
-
1
));
std
::
map
<
int
,
int
>
shard_dims
=
AllRankLocalShard
(
first_dim
,
rank_id
,
server_num
);
if
(
shard_dims
.
count
(
rank_id
)
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid rank id "
<<
rank_id
;
}
return
shard_dims
[
rank_id
];
}
std
::
map
<
int
,
int
>
Util
::
AllRankLocalShard
(
int
first_dim
,
int
rank_id
,
int
server_num
)
{
if
(
rank_id
>=
server_num
)
{
MS_LOG
(
EXCEPTION
)
<<
"The rank ID "
<<
rank_id
<<
" should be less than the number of servers "
<<
server_num
;
}
std
::
map
<
int
,
int
>
shard_dims
;
for
(
int
i
=
0
;
i
<
server_num
;
i
++
)
{
shard_dims
[
i
]
=
0
;
}
if
(
server_num
!=
static_cast
<
int
>
(
shard_dims
.
size
()))
{
MS_LOG
(
EXCEPTION
)
<<
"Inconsistent server num "
<<
server_num
<<
" shard dims counter size "
<<
shard_dims
.
size
();
}
int
server_index
=
-
1
;
for
(
int
i
=
0
;
i
<
first_dim
;
i
++
)
{
server_index
=
(
server_index
+
1
)
%
server_num
;
shard_dims
[
server_index
]
=
shard_dims
[
server_index
]
+
1
;
}
if
(
shard_dims
.
count
(
rank_id
)
==
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid rank id "
<<
rank_id
<<
", total server num "
<<
server_num
;
}
return
shard_dims
;
}
void
Util
::
SetRankId
(
int
rank_id
)
{
rank_id_
=
rank_id
;
}
...
...
mindspore/ccsrc/frontend/parallel/ps/util.h
浏览文件 @
d92c220c
...
...
@@ -39,6 +39,7 @@ class Util {
static
std
::
string
optimizer_node_name
(
int
id
);
static
bool
is_optimizer
(
std
::
string
name
);
static
int
LocalShard
(
int
first_dim
,
int
rank_id
,
int
server_num
);
static
std
::
map
<
int
,
int
>
AllRankLocalShard
(
int
first_dim
,
int
rank_id
,
int
server_num
);
static
void
SetRankId
(
int
rank_id
);
static
int
GetRankId
();
static
void
ReduceSparseGradient
(
float
*
gradients
,
int
*
indices
,
const
size_t
indices_size
,
size_t
segment_size
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录