Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7a724ddb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7a724ddb
编写于
10月 11, 2021
作者:
Y
yaoxuefeng
提交者:
GitHub
10月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix multi-node (#36329)
上级
414c252a
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
15 addition
and
5 deletion
+15
-5
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
+9
-1
paddle/fluid/platform/collective_helper.cc
paddle/fluid/platform/collective_helper.cc
+4
-4
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+2
-0
未找到文件。
paddle/fluid/framework/fleet/ps_gpu_wrapper.h
浏览文件 @
7a724ddb
...
@@ -117,6 +117,15 @@ class PSGPUWrapper {
...
@@ -117,6 +117,15 @@ class PSGPUWrapper {
resource_
=
std
::
make_shared
<
HeterPsResource
>
(
dev_ids
);
resource_
=
std
::
make_shared
<
HeterPsResource
>
(
dev_ids
);
resource_
->
enable_p2p
();
resource_
->
enable_p2p
();
keys_tensor
.
resize
(
resource_
->
total_gpu
());
keys_tensor
.
resize
(
resource_
->
total_gpu
());
#ifdef PADDLE_WITH_GLOO
auto
gloo
=
paddle
::
framework
::
GlooWrapper
::
GetInstance
();
if
(
gloo
->
Size
()
>
1
)
{
multi_node_
=
1
;
}
#else
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"heter ps need compile with GLOO"
));
#endif
if
(
multi_node_
)
{
if
(
multi_node_
)
{
int
dev_size
=
dev_ids
.
size
();
int
dev_size
=
dev_ids
.
size
();
// init inner comm
// init inner comm
...
@@ -127,7 +136,6 @@ class PSGPUWrapper {
...
@@ -127,7 +136,6 @@ class PSGPUWrapper {
// init inter comm
// init inter comm
#ifdef PADDLE_WITH_GLOO
#ifdef PADDLE_WITH_GLOO
inter_comms_
.
resize
(
dev_size
);
inter_comms_
.
resize
(
dev_size
);
auto
gloo
=
paddle
::
framework
::
GlooWrapper
::
GetInstance
();
if
(
gloo
->
Rank
()
==
0
)
{
if
(
gloo
->
Rank
()
==
0
)
{
for
(
int
i
=
0
;
i
<
dev_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
dev_size
;
++
i
)
{
platform
::
dynload
::
ncclGetUniqueId
(
&
inter_ncclids_
[
i
]);
platform
::
dynload
::
ncclGetUniqueId
(
&
inter_ncclids_
[
i
]);
...
...
paddle/fluid/platform/collective_helper.cc
浏览文件 @
7a724ddb
...
@@ -148,7 +148,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
...
@@ -148,7 +148,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
paddle
::
platform
::
errors
::
InvalidArgument
(
paddle
::
platform
::
errors
::
InvalidArgument
(
"dev ids = [%d], it should greater than 0."
,
dev_ids
.
size
()));
"dev ids = [%d], it should greater than 0."
,
dev_ids
.
size
()));
const
int
kDevices
=
dev_ids
.
size
();
const
int
kDevices
=
dev_ids
.
size
();
VLOG
(
3
)
<<
"Begin CreateNCCLCommMultiTrainer. device number: "
<<
kDevices
VLOG
(
1
)
<<
"Begin CreateNCCLCommMultiTrainer. device number: "
<<
kDevices
<<
", ntrainers: "
<<
ntrainers
<<
", train_id: "
<<
train_id
<<
", ntrainers: "
<<
ntrainers
<<
", train_id: "
<<
train_id
<<
", rind_id: "
<<
ring_id
;
<<
", rind_id: "
<<
ring_id
;
ncclComm_t
comms
[
kDevices
];
ncclComm_t
comms
[
kDevices
];
...
@@ -162,10 +162,10 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
...
@@ -162,10 +162,10 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
#endif
#endif
platform
::
dynload
::
ncclCommInitRank
(
comms
+
i
,
kDevices
*
ntrainers
,
platform
::
dynload
::
ncclCommInitRank
(
comms
+
i
,
kDevices
*
ntrainers
,
*
nccl_id
,
train_id
*
kDevices
+
i
);
*
nccl_id
,
train_id
*
kDevices
+
i
);
VLOG
(
3
)
<<
"ncclCommInitRank: "
<<
i
;
VLOG
(
1
)
<<
"ncclCommInitRank: "
<<
i
;
}
}
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
ncclGroupEnd
());
PADDLE_ENFORCE_CUDA_SUCCESS
(
dynload
::
ncclGroupEnd
());
VLOG
(
3
)
<<
"nccl group end seccessss"
;
VLOG
(
1
)
<<
"nccl group end seccessss"
;
}
}
PADDLE_ENFORCE_EQ
(
comm_map_
.
count
(
ring_id
),
0
,
PADDLE_ENFORCE_EQ
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -174,7 +174,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
...
@@ -174,7 +174,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer(
for
(
int
i
=
0
;
i
<
kDevices
;
++
i
)
{
for
(
int
i
=
0
;
i
<
kDevices
;
++
i
)
{
AssignNCCLComm
(
comms
[
i
],
kDevices
*
ntrainers
,
train_id
*
kDevices
+
i
,
AssignNCCLComm
(
comms
[
i
],
kDevices
*
ntrainers
,
train_id
*
kDevices
+
i
,
dev_ids
[
i
],
ring_id
);
dev_ids
[
i
],
ring_id
);
VLOG
(
3
)
<<
"nccl communicator of train_id "
<<
train_id
*
kDevices
+
i
VLOG
(
1
)
<<
"nccl communicator of train_id "
<<
train_id
*
kDevices
+
i
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_ids
[
i
];
<<
dev_ids
[
i
];
}
}
...
...
python/paddle/fluid/dataset.py
浏览文件 @
7a724ddb
...
@@ -396,6 +396,8 @@ class InMemoryDataset(DatasetBase):
...
@@ -396,6 +396,8 @@ class InMemoryDataset(DatasetBase):
Set data_feed_desc
Set data_feed_desc
"""
"""
self
.
proto_desc
.
name
=
data_feed_type
self
.
proto_desc
.
name
=
data_feed_type
if
(
self
.
proto_desc
.
name
==
"SlotRecordInMemoryDataFeed"
):
self
.
dataset
=
core
.
Dataset
(
"SlotRecordDataset"
)
@
deprecated
(
@
deprecated
(
since
=
"2.0.0"
,
since
=
"2.0.0"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录