Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d03bbefa
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d03bbefa
编写于
5月 12, 2023
作者:
R
ronnywang
提交者:
GitHub
5月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomDevice] add inference MP support, PART0 (#53719)
* [CustomDevice] add inference MP support, PART0 * update
上级
eb97f4f0
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
791 addition
and
6 deletion
+791
-6
paddle/fluid/imperative/CMakeLists.txt
paddle/fluid/imperative/CMakeLists.txt
+18
-2
paddle/fluid/imperative/xccl_context.cc
paddle/fluid/imperative/xccl_context.cc
+282
-0
paddle/fluid/imperative/xccl_context.h
paddle/fluid/imperative/xccl_context.h
+71
-0
paddle/fluid/platform/collective_helper.cc
paddle/fluid/platform/collective_helper.cc
+230
-0
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+108
-0
paddle/fluid/platform/gen_comm_id_helper.cc
paddle/fluid/platform/gen_comm_id_helper.cc
+59
-1
paddle/fluid/platform/gen_comm_id_helper.h
paddle/fluid/platform/gen_comm_id_helper.h
+1
-1
paddle/phi/backends/c_comm_lib.h
paddle/phi/backends/c_comm_lib.h
+2
-2
paddle/phi/backends/custom/custom_context.cc
paddle/phi/backends/custom/custom_context.cc
+13
-0
paddle/phi/backends/custom/custom_context.h
paddle/phi/backends/custom/custom_context.h
+7
-0
未找到文件。
paddle/fluid/imperative/CMakeLists.txt
浏览文件 @
d03bbefa
...
...
@@ -123,10 +123,26 @@ if(NOT WIN32)
SRCS reducer.cc
DEPS layer
)
endif
()
if
(
WITH_CUSTOM_DEVICE
)
cc_library
(
xccl_context
SRCS xccl_context.cc
DEPS collective_helper device_context tensor var_type_traits
)
if
(
NOT
(
WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL
OR WITH_GLOO
))
cc_library
(
reducer
SRCS reducer.cc
DEPS layer
)
endif
()
endif
()
if
(
WITH_NCCL
OR WITH_RCCL
OR WITH_XPU_BKCL
)
OR WITH_XPU_BKCL
OR WITH_CUSTOM_DEVICE
)
cc_library
(
heter_ccl_context
SRCS heter_ccl_context.cc
...
...
paddle/fluid/imperative/xccl_context.cc
0 → 100644
浏览文件 @
d03bbefa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/xccl_context.h"
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#endif
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
framework
{
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
imperative
{
static
void
XcclAllReduce
(
const
phi
::
DenseTensor
&
src
,
phi
::
DenseTensor
*
dst
,
const
phi
::
stream
::
Stream
&
stream
,
const
phi
::
ccl
::
CCLComm
&
comm
)
{
const
auto
&
place
=
src
.
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_custom_place
(
place
),
true
,
platform
::
errors
::
Unimplemented
(
"Dynamic graph mode does not support multi-CPU training yet."
));
void
*
src_ptr
=
const_cast
<
void
*>
(
src
.
data
());
dst
->
Resize
(
src
.
dims
());
auto
*
dst_ptr
=
phi
::
DeviceContextPool
::
Instance
()
.
Get
(
src
.
place
())
->
Alloc
(
dst
,
src
.
dtype
());
auto
xccl_dtype
=
phi
::
ccl
::
ToCCLDataType
(
src
.
dtype
());
phi
::
DeviceManager
::
CCLAllReduce
(
place
.
GetDeviceType
(),
src_ptr
,
dst_ptr
,
src
.
numel
(),
xccl_dtype
,
phi
::
ccl
::
CCLReduceOp
::
SUM
,
comm
,
stream
);
}
void
XCCLParallelContext
::
BcastXCCLId
(
std
::
vector
<
phi
::
ccl
::
CCLRootId
>
&
xccl_ids
,
// NOLINT
int
root
,
int
server_fd
)
{
if
(
strategy_
.
local_rank_
==
root
)
{
std
::
vector
<
std
::
string
>
other_trainers
;
for
(
auto
&
ep
:
strategy_
.
trainer_endpoints_
)
{
if
(
ep
!=
strategy_
.
current_endpoint_
)
{
other_trainers
.
push_back
(
ep
);
}
}
platform
::
SendBroadCastCommID
(
other_trainers
,
&
xccl_ids
);
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
strategy_
.
current_endpoint_
,
&
xccl_ids
);
}
}
void
XCCLParallelContext
::
Init
()
{
int
server_fd
=
-
1
;
std
::
vector
<
phi
::
ccl
::
CCLRootId
>
xccl_ids
;
xccl_ids
.
resize
(
strategy_
.
nrings_
);
if
(
strategy_
.
local_rank_
==
0
)
{
// generate the unique ncclid on the root worker
for
(
size_t
i
=
0
;
i
<
xccl_ids
.
size
();
++
i
)
{
phi
::
DeviceManager
::
CCLGetUniqueId
(
place_
.
GetDeviceType
(),
&
xccl_ids
[
i
]);
}
}
else
{
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd
=
platform
::
SocketServer
::
GetInstance
(
strategy_
.
current_endpoint_
)
.
socket
();
}
BcastXCCLId
(
xccl_ids
,
0
,
server_fd
);
int
dev_id
=
place_
.
device
;
for
(
int
ring_id
=
0
;
ring_id
<
strategy_
.
nrings_
;
ring_id
++
)
{
VLOG
(
0
)
<<
"init nccl context nranks: "
<<
strategy_
.
nranks_
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" dev id: "
<<
dev_id
<<
" ring id: "
<<
ring_id
;
// it will assign nccl_comm in phi::CustomContext within ring_id
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
CreateComm
(
&
xccl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
dev_id
,
ring_id
);
auto
compute_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
auto
comm_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
compute_event
->
Init
(
place_
);
comm_event
->
Init
(
place_
);
compute_events_
.
emplace_back
(
compute_event
);
comm_events_
.
emplace_back
(
comm_event
);
}
}
void
XCCLParallelContext
::
InitWithRingID
(
int
ring_id
)
{
int
server_fd
=
-
1
;
std
::
vector
<
phi
::
ccl
::
CCLRootId
>
xccl_ids
;
xccl_ids
.
resize
(
1
);
if
(
strategy_
.
local_rank_
==
0
)
{
// generate the unique ncclid on the root worker
phi
::
DeviceManager
::
CCLGetUniqueId
(
place_
.
GetDeviceType
(),
&
xccl_ids
[
0
]);
}
else
{
// FIXME(wangxi): gloo will use rank0 endpoint, so not create socket server
// on rank0.
server_fd
=
platform
::
SocketServer
::
GetInstance
(
strategy_
.
current_endpoint_
)
.
socket
();
}
BcastXCCLId
(
xccl_ids
,
0
,
server_fd
);
int
dev_id
=
place_
.
device
;
VLOG
(
0
)
<<
"init xccl context nranks: "
<<
strategy_
.
nranks_
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" dev id: "
<<
dev_id
<<
" ring id: "
<<
ring_id
;
// it will assign xccl_comm in phi::CustomContext within ring_id
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
CreateComm
(
&
xccl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
dev_id
,
ring_id
);
auto
compute_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
auto
comm_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
compute_event
->
Init
(
place_
);
comm_event
->
Init
(
place_
);
compute_events_
.
emplace_back
(
compute_event
);
comm_events_
.
emplace_back
(
comm_event
);
}
void
XCCLParallelContext
::
AllReduceByStream
(
const
framework
::
Variable
&
src
,
framework
::
Variable
*
dst
,
int
ring_id
,
bool
use_calc_stream
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_custom_place
(
place_
),
true
,
platform
::
errors
::
Unimplemented
(
"Dynamic graph mode does not support multi-CPU training yet."
));
auto
place
=
place_
;
auto
*
dev_ctx
=
static_cast
<
platform
::
CustomDeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
platform
::
XCCLComm
*
comm
=
platform
::
XCCLCommContext
::
Instance
(
place
.
GetDeviceType
())
.
Get
(
ring_id
,
place
);
auto
stream
=
use_calc_stream
?
dev_ctx
->
GetStream
()
:
comm
->
stream
();
if
(
src
.
IsType
<
phi
::
DenseTensor
>
())
{
if
(
!
dst
->
IsType
<
phi
::
DenseTensor
>
())
{
dst
->
Clear
();
}
XcclAllReduce
(
src
.
Get
<
phi
::
DenseTensor
>
(),
dst
->
GetMutable
<
phi
::
DenseTensor
>
(),
*
stream
,
comm
->
comm
());
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"custom device unsupported variable type %s for imperative allreduce, "
"only "
"LoDTensor are supported."
,
platform
::
demangle
(
framework
::
ToTypeName
(
src
.
Type
()))));
}
}
void
XCCLParallelContext
::
Broadcast
(
framework
::
Variable
*
src
,
int
ring_id
)
{
VLOG
(
3
)
<<
"/// DEBUG /// start inter broadcast with ring_id: "
<<
ring_id
;
phi
::
DenseTensor
*
src_tensor
=
src
->
GetMutable
<
phi
::
DenseTensor
>
();
const
auto
&
place
=
src_tensor
->
place
();
platform
::
XCCLComm
*
comm
=
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
Get
(
ring_id
,
place
);
auto
stream
=
comm
->
stream
();
void
*
src_ptr
=
src_tensor
->
data
();
auto
xccl_dtype
=
phi
::
ccl
::
ToCCLDataType
(
src_tensor
->
dtype
());
phi
::
DeviceManager
::
CCLBroadcast
(
place_
.
GetDeviceType
(),
src_ptr
,
src_tensor
->
numel
(),
xccl_dtype
,
0
,
comm
->
comm
(),
*
stream
);
}
paddle
::
platform
::
DeviceContext
*
XCCLParallelContext
::
GetDeviceContext
(
int
ring_id
)
{
return
static_cast
<
platform
::
DeviceContext
*>
(
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
Get
(
ring_id
,
place_
)
->
dev_context
());
}
void
XCCLParallelContext
::
WaitCompute
(
int
ring_id
)
{
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
OutOfRange
(
"ring id must >= 0, but got %d"
,
ring_id
));
PADDLE_ENFORCE_LT
(
ring_id
,
compute_events_
.
size
(),
platform
::
errors
::
OutOfRange
(
"ring id must < compute events size,"
"but got ring id = %d, compute events size = %d"
,
ring_id
,
compute_events_
.
size
()));
auto
compute_stream
=
static_cast
<
phi
::
CustomContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
))
->
GetStream
();
auto
comm_stream
=
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
Get
(
ring_id
,
place_
)
->
stream
();
auto
event
=
compute_events_
[
ring_id
].
get
();
// compute_stream-->event-->comm_stream
event
->
Record
(
compute_stream
.
get
());
comm_stream
->
WaitEvent
(
event
);
}
void
XCCLParallelContext
::
WaitComm
(
int
ring_id
)
{
PADDLE_ENFORCE_GE
(
ring_id
,
0
,
platform
::
errors
::
OutOfRange
(
"ring id must >= 0, but got %d"
,
ring_id
));
PADDLE_ENFORCE_LT
(
ring_id
,
comm_events_
.
size
(),
platform
::
errors
::
OutOfRange
(
"ring id must < comm events size,"
"but got ring id = %d, comm events size = %d"
,
ring_id
,
comm_events_
.
size
()));
auto
compute_stream
=
static_cast
<
phi
::
CustomContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
))
->
GetStream
();
auto
comm_stream
=
platform
::
XCCLCommContext
::
Instance
(
place_
.
GetDeviceType
())
.
Get
(
ring_id
,
place_
)
->
stream
();
auto
event
=
comm_events_
[
ring_id
].
get
();
// comm_stream-->event-->compute_stream
event
->
Record
(
comm_stream
.
get
());
compute_stream
->
WaitEvent
(
event
);
}
void
XCCLParallelContext
::
SynchronizeCompute
()
{
auto
*
compute_dev_ctx
=
static_cast
<
phi
::
CustomContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
));
compute_dev_ctx
->
Wait
();
}
}
// namespace imperative
}
// namespace paddle
paddle/fluid/imperative/xccl_context.h
0 → 100644
浏览文件 @
d03bbefa
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/imperative/parallel_context.h"
namespace
paddle
{
namespace
framework
{
class
Variable
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
imperative
{
#ifdef PADDLE_WITH_CUSTOM_DEVICE
class
XCCLParallelContext
:
public
ParallelContext
{
public:
explicit
XCCLParallelContext
(
const
ParallelStrategy
&
strategy
,
const
platform
::
Place
&
place
)
:
ParallelContext
(
strategy
,
place
)
{}
~
XCCLParallelContext
()
override
=
default
;
void
BcastXCCLId
(
std
::
vector
<
phi
::
ccl
::
CCLRootId
>&
xccl_ids
,
// NOLINT
int
root
,
int
server_fd
);
void
Init
()
override
;
void
InitWithRingID
(
int
ring_id
)
override
;
void
AllReduceByStream
(
const
framework
::
Variable
&
src
,
framework
::
Variable
*
dst
,
int
ring_id
,
bool
use_calc_stream
)
override
;
void
Broadcast
(
framework
::
Variable
*
src
,
int
ring_id
)
override
;
paddle
::
platform
::
DeviceContext
*
GetDeviceContext
(
int
ring_id
)
override
;
void
WaitCompute
(
int
ring_id
)
override
;
void
WaitComm
(
int
ring_id
)
override
;
void
SynchronizeCompute
()
override
;
private:
// used for comm wait compute, compute_stream-->event-->comm_stream[ring_id]
std
::
vector
<
std
::
shared_ptr
<
phi
::
event
::
Event
>>
compute_events_
;
// used for compute wait comm, comm_stream[ring_id]-->event-->compute_stream
std
::
vector
<
std
::
shared_ptr
<
phi
::
event
::
Event
>>
comm_events_
;
};
#endif
}
// namespace imperative
}
// namespace paddle
paddle/fluid/platform/collective_helper.cc
浏览文件 @
d03bbefa
...
...
@@ -404,5 +404,235 @@ void BKCLCommContext::ReleaseBKCLComms() {
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
class
XCCLCommImpl
:
public
XCCLComm
{
public:
void
set_ring_id
(
int
ring_id
)
{
ring_id_
=
ring_id
;
}
int
ring_id
()
const
override
{
return
ring_id_
;
}
void
set_nranks
(
int
nranks
)
{
nranks_
=
nranks
;
}
int
nranks
()
const
override
{
return
nranks_
;
}
void
set_rank
(
int
rank
)
{
rank_
=
rank
;
}
int
rank
()
const
override
{
return
rank_
;
}
int
device_id
()
const
override
{
return
dev_ctx_
->
GetPlace
().
device
;
}
void
set_comm
(
phi
::
ccl
::
CCLComm
comm
)
{
comm_
=
comm
;
}
phi
::
ccl
::
CCLComm
comm
()
const
override
{
return
comm_
;
}
std
::
shared_ptr
<
phi
::
stream
::
Stream
>
stream
()
const
override
{
return
dev_ctx_
->
GetStream
();
}
void
set_dev_ctx
(
std
::
unique_ptr
<
phi
::
CustomContext
>&&
dev_ctx
)
{
dev_ctx_
=
std
::
move
(
dev_ctx
);
}
phi
::
CustomContext
*
dev_context
()
const
override
{
return
dev_ctx_
.
get
();
}
std
::
shared_ptr
<
phi
::
event
::
Event
>
compute_event
()
const
override
{
return
compute_event_
;
}
std
::
shared_ptr
<
phi
::
event
::
Event
>
comm_event
()
const
override
{
return
comm_event_
;
}
void
set_compute_event
(
std
::
shared_ptr
<
phi
::
event
::
Event
>&&
compute_event
)
{
compute_event_
=
std
::
move
(
compute_event
);
}
void
set_comm_event
(
std
::
shared_ptr
<
phi
::
event
::
Event
>&&
comm_event
)
{
comm_event_
=
std
::
move
(
comm_event
);
}
private:
int
ring_id_
;
int
nranks_
;
int
rank_
;
phi
::
ccl
::
CCLComm
comm_
;
std
::
unique_ptr
<
phi
::
CustomContext
>
dev_ctx_
;
// used for comm wait compute, compute_stream-->event-->comm_stream
std
::
shared_ptr
<
phi
::
event
::
Event
>
compute_event_
;
// used for compute wait comm, comm_stream-->event-->compute_stream
std
::
shared_ptr
<
phi
::
event
::
Event
>
comm_event_
;
};
static
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
XCCLCommContext
>>
g_xccl_comm_ctx_map
;
void
XCCLCommContext
::
Release
()
{
for
(
auto
&
it
:
g_xccl_comm_ctx_map
)
{
it
.
second
->
ReleaseXCCLComms
();
}
g_xccl_comm_ctx_map
.
clear
();
}
XCCLCommContext
&
XCCLCommContext
::
Instance
(
const
std
::
string
&
device_type
)
{
if
(
g_xccl_comm_ctx_map
.
find
(
device_type
)
==
g_xccl_comm_ctx_map
.
end
())
{
g_xccl_comm_ctx_map
.
insert
(
{
device_type
,
std
::
unique_ptr
<
XCCLCommContext
>
(
new
XCCLCommContext
(
device_type
))});
}
return
*
g_xccl_comm_ctx_map
[
device_type
];
}
XCCLComm
*
XCCLCommContext
::
CreateComm
(
phi
::
ccl
::
CCLRootId
*
xccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
xccl_id
,
platform
::
errors
::
InvalidArgument
(
"The xccl unique id should not be null."
));
PADDLE_ENFORCE_GT
(
nranks
,
1
,
platform
::
errors
::
InvalidArgument
(
"Expected nranks > 1. But received nranks is %d."
,
nranks
));
PADDLE_ENFORCE_GE
(
rank
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected rank >= 0. But received rank is %d."
,
rank
));
PADDLE_ENFORCE_LT
(
rank
,
nranks
,
platform
::
errors
::
InvalidArgument
(
"Expected rank < nranks. But received rank is %d, nranks is %d."
,
rank
,
nranks
));
PADDLE_ENFORCE_GE
(
dev_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"Expected dev_id >= 0. But received dev_id is %d."
,
dev_id
));
phi
::
ccl
::
CCLComm
comm
=
nullptr
;
phi
::
DeviceManager
::
SetDevice
(
device_type_
,
dev_id
);
phi
::
DeviceManager
::
CCLCommInitRank
(
device_type_
,
nranks
,
xccl_id
,
rank
,
&
comm
);
auto
*
comm_wrapper
=
AssignXCCLComm
(
comm
,
nranks
,
rank
,
dev_id
,
ring_id
);
VLOG
(
1
)
<<
"xccl communicator of rank "
<<
rank
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_id
;
return
comm_wrapper
;
}
void
XCCLCommContext
::
CreateXCCLCommMultiTrainer
(
const
std
::
vector
<
int
>&
dev_ids
,
phi
::
ccl
::
CCLRootId
*
xccl_id
,
int
ntrainers
,
int
train_id
,
int
ring_id
)
{
PADDLE_ENFORCE_GT
(
dev_ids
.
size
(),
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"dev ids = [%d], it should greater than 0."
,
dev_ids
.
size
()));
const
int
kDevices
=
dev_ids
.
size
();
VLOG
(
1
)
<<
"Begin CreateXCCLCommMultiTrainer. device number: "
<<
kDevices
<<
", ntrainers: "
<<
ntrainers
<<
", train_id: "
<<
train_id
<<
", rind_id: "
<<
ring_id
;
phi
::
ccl
::
CCLComm
comms
[
kDevices
];
{
for
(
int
i
=
0
;
i
<
kDevices
;
i
++
)
{
phi
::
DeviceManager
::
SetDevice
(
device_type_
,
i
);
phi
::
DeviceManager
::
CCLCommInitRank
(
device_type_
,
kDevices
*
ntrainers
,
xccl_id
,
train_id
*
kDevices
+
i
,
comms
+
i
);
VLOG
(
1
)
<<
"CCLCommInitRank: "
<<
i
;
}
}
PADDLE_ENFORCE_EQ
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"comm_map_ of ring_id: %s should be 0. %s is provided"
,
ring_id
,
comm_map_
.
count
(
ring_id
)));
for
(
int
i
=
0
;
i
<
kDevices
;
++
i
)
{
AssignXCCLComm
(
comms
[
i
],
kDevices
*
ntrainers
,
train_id
*
kDevices
+
i
,
dev_ids
[
i
],
ring_id
);
VLOG
(
1
)
<<
"xccl communicator of train_id "
<<
train_id
*
kDevices
+
i
<<
" in ring "
<<
ring_id
<<
" has been created on device "
<<
dev_ids
[
i
];
}
}
XCCLComm
*
XCCLCommContext
::
AssignXCCLComm
(
phi
::
ccl
::
CCLComm
comm
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
auto
place
=
CustomPlace
(
device_type_
,
dev_id
);
std
::
unique_ptr
<
phi
::
CustomContext
>
dev_ctx
(
new
phi
::
CustomContext
(
place
));
dev_ctx
->
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
place
)
.
get
());
dev_ctx
->
SetHostAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
dev_ctx
->
SetZeroAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetZeroAllocator
(
place
)
.
get
());
dev_ctx
->
SetHostZeroAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetZeroAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
dev_ctx
->
SetPinnedAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
// dev_ctx->PartialInitWithAllocator();
auto
compute_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
auto
comm_event
=
std
::
make_shared
<
phi
::
event
::
Event
>
();
compute_event
->
Init
(
place
);
comm_event
->
Init
(
place
);
auto
*
c
=
new
XCCLCommImpl
;
c
->
set_ring_id
(
ring_id
);
c
->
set_nranks
(
nranks
);
c
->
set_rank
(
rank
);
c
->
set_comm
(
comm
);
c
->
set_dev_ctx
(
std
::
move
(
dev_ctx
));
c
->
set_compute_event
(
std
::
move
(
compute_event
));
c
->
set_comm_event
(
std
::
move
(
comm_event
));
comm_map_mutex_
.
lock
();
if
(
comm_map_
.
count
(
ring_id
)
==
0
)
{
comm_map_
.
emplace
(
ring_id
,
std
::
map
<
int
,
std
::
unique_ptr
<
XCCLComm
>>
());
}
auto
&
dev2comm
=
comm_map_
[
ring_id
];
dev2comm
.
emplace
(
dev_id
,
std
::
unique_ptr
<
XCCLComm
>
(
c
));
comm_map_mutex_
.
unlock
();
if
(
ring_id
==
0
)
{
auto
*
dev_ctx
=
static_cast
<
phi
::
CustomContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
dev_ctx
->
set_xccl_comm
(
comm
);
}
VLOG
(
4
)
<<
"add xccl comm: "
<<
comm_map_
[
ring_id
][
dev_id
].
get
()
<<
", ring_id:"
<<
ring_id
<<
", dev_id:"
<<
dev_id
;
return
comm_map_
[
ring_id
][
dev_id
].
get
();
}
void
XCCLCommContext
::
ReleaseXCCLComms
()
{
for
(
auto
&
p
:
comm_map_
)
{
for
(
auto
&
q
:
p
.
second
)
{
q
.
second
.
reset
();
}
}
}
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/collective_helper.h
浏览文件 @
d03bbefa
...
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/utils/variant.h"
namespace
paddle
{
...
...
@@ -243,5 +244,112 @@ class BKCLCommContext {
DISABLE_COPY_AND_ASSIGN
(
BKCLCommContext
);
};
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
class
XCCLComm
{
public:
virtual
int
ring_id
()
const
=
0
;
virtual
int
nranks
()
const
=
0
;
virtual
int
rank
()
const
=
0
;
virtual
int
device_id
()
const
=
0
;
virtual
phi
::
ccl
::
CCLComm
comm
()
const
=
0
;
virtual
std
::
shared_ptr
<
phi
::
stream
::
Stream
>
stream
()
const
=
0
;
virtual
std
::
shared_ptr
<
phi
::
event
::
Event
>
compute_event
()
const
=
0
;
virtual
std
::
shared_ptr
<
phi
::
event
::
Event
>
comm_event
()
const
=
0
;
virtual
phi
::
CustomContext
*
dev_context
()
const
=
0
;
virtual
~
XCCLComm
()
=
default
;
};
// A singleton XCCL communicator context reserves communication ring ids
class
XCCLCommContext
{
public:
static
XCCLCommContext
&
Instance
(
const
std
::
string
&
device_type
);
static
void
Release
();
XCCLComm
*
CreateComm
(
phi
::
ccl
::
CCLRootId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
void
CreateAllXCCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
void
CreateXCCLCommMultiTrainer
(
const
std
::
vector
<
int
>&
dev_ids
,
phi
::
ccl
::
CCLRootId
*
xccl_id
,
int
nranks
,
int
rank
,
int
ring_id
);
// a latter comm with the same dev_id and the same ring_id
// will override the former
XCCLComm
*
AssignXCCLComm
(
phi
::
ccl
::
CCLComm
comm
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
// retrieve a communicator by the ring id in multiprocessing mode
XCCLComm
*
Get
(
int
ring_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator in ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_EQ
(
comm_map_
.
at
(
ring_id
).
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"One device id should be specified to retrieve from "
"multiple communicators."
));
return
comm_map_
.
at
(
ring_id
).
begin
()
->
second
.
get
();
}
int
GetRingId
(
phi
::
ccl
::
CCLComm
comm
)
const
{
for
(
const
auto
&
pair
:
comm_map_
)
{
for
(
const
auto
&
p
:
pair
.
second
)
{
if
(
p
.
second
.
get
()
->
comm
()
==
comm
)
{
return
pair
.
first
;
}
}
}
return
-
1
;
}
// retrieve a communicator by the ring id and the device id
XCCLComm
*
Get
(
int
ring_id
,
int
dev_id
)
const
{
PADDLE_ENFORCE_GT
(
comm_map_
.
count
(
ring_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator of ring id %d has not been initialized."
,
ring_id
));
PADDLE_ENFORCE_GT
(
comm_map_
.
at
(
ring_id
).
count
(
dev_id
),
0
,
platform
::
errors
::
InvalidArgument
(
"Communicator at device id %d has not been initialized in ring %d."
,
dev_id
,
ring_id
));
return
comm_map_
.
at
(
ring_id
).
at
(
dev_id
).
get
();
}
// retrieve a communicator by the ring id and place
XCCLComm
*
Get
(
int
ring_id
,
Place
place
)
const
{
return
Get
(
ring_id
,
place
.
device
);
}
private:
std
::
string
device_type_
;
std
::
once_flag
once_flag_
;
std
::
mutex
comm_map_mutex_
;
// ring id to dev-XCCLComm
std
::
map
<
int
,
std
::
map
<
int
,
std
::
unique_ptr
<
XCCLComm
>>>
comm_map_
;
void
ReleaseXCCLComms
();
XCCLCommContext
()
=
default
;
explicit
XCCLCommContext
(
const
std
::
string
&
device_type
)
:
device_type_
(
device_type
)
{}
DISABLE_COPY_AND_ASSIGN
(
XCCLCommContext
);
};
#endif
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/gen_comm_id_helper.cc
浏览文件 @
d03bbefa
...
...
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL)
|| defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include <arpa/inet.h>
...
...
@@ -33,6 +33,9 @@ limitations under the License. */
#if defined(PADDLE_WITH_XPU_BKCL)
#include "xpu/bkcl.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/phi/backends/c_comm_lib.h"
#endif
PHI_DECLARE_int32
(
get_host_by_name_time
);
...
...
@@ -348,6 +351,58 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
"send comm unique id"
);
}
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template
<
>
void
RecvCommID
<
phi
::
ccl
::
CCLRootId
>
(
int
conn
,
phi
::
ccl
::
CCLRootId
*
nccl_id
)
{
char
buffer
[
MAX_COMMUNIQUEID_LEN
]
=
{
0
};
CHECK_SYS_CALL
(
SocketRecv
(
conn
,
buffer
,
sizeof
(
size_t
)),
"recv comm unique id size"
);
size_t
unique_id_size
=
*
reinterpret_cast
<
size_t
*>
(
buffer
);
VLOG
(
6
)
<<
"RecvCommID size: "
<<
unique_id_size
;
nccl_id
->
resize
(
unique_id_size
);
size_t
n_repeat
=
unique_id_size
/
MAX_COMMUNIQUEID_LEN
;
size_t
n_remain
=
unique_id_size
%
MAX_COMMUNIQUEID_LEN
;
for
(
size_t
i
=
0
;
i
<
n_repeat
;
++
i
)
{
CHECK_SYS_CALL
(
SocketRecv
(
conn
,
buffer
,
MAX_COMMUNIQUEID_LEN
),
"recv comm unique id"
);
memcpy
(
nccl_id
->
data
()
+
i
*
MAX_COMMUNIQUEID_LEN
,
buffer
,
MAX_COMMUNIQUEID_LEN
);
}
if
(
n_remain
)
{
CHECK_SYS_CALL
(
SocketRecv
(
conn
,
buffer
,
n_remain
),
"recv comm unique id"
);
memcpy
(
nccl_id
->
data
()
+
n_repeat
*
MAX_COMMUNIQUEID_LEN
,
buffer
,
n_remain
);
}
VLOG
(
6
)
<<
"RecvCommID done"
;
}
template
<
>
void
SendCommID
<
phi
::
ccl
::
CCLRootId
>
(
int
conn
,
phi
::
ccl
::
CCLRootId
*
nccl_id
)
{
char
buffer
[
MAX_COMMUNIQUEID_LEN
]
=
{
0
};
size_t
unique_id_size
=
nccl_id
->
size
();
VLOG
(
6
)
<<
"SendCommID size: "
<<
unique_id_size
;
memcpy
(
buffer
,
&
unique_id_size
,
sizeof
(
size_t
));
CHECK_SYS_CALL
(
SocketSend
(
conn
,
buffer
,
sizeof
(
size_t
)),
"send comm unique id size"
);
size_t
n_repeat
=
unique_id_size
/
MAX_COMMUNIQUEID_LEN
;
size_t
n_remain
=
unique_id_size
%
MAX_COMMUNIQUEID_LEN
;
for
(
size_t
i
=
0
;
i
<
n_repeat
;
++
i
)
{
memcpy
(
buffer
,
nccl_id
->
data
()
+
i
*
MAX_COMMUNIQUEID_LEN
,
MAX_COMMUNIQUEID_LEN
);
CHECK_SYS_CALL
(
SocketSend
(
conn
,
buffer
,
MAX_COMMUNIQUEID_LEN
),
"send comm unique id"
);
}
if
(
n_remain
)
{
memcpy
(
buffer
,
nccl_id
->
data
()
+
n_repeat
*
MAX_COMMUNIQUEID_LEN
,
n_remain
);
CHECK_SYS_CALL
(
SocketSend
(
conn
,
buffer
,
n_remain
),
"send comm unique id"
);
}
VLOG
(
6
)
<<
"SendCommID done"
;
}
#endif
template
<
typename
CommUniqueId
>
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
...
...
@@ -444,6 +499,9 @@ INSTANT_TEMPLATE(ncclUniqueId)
#ifdef PADDLE_WITH_XPU_BKCL
INSTANT_TEMPLATE
(
BKCLUniqueId
)
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
INSTANT_TEMPLATE
(
phi
::
ccl
::
CCLRootId
)
#endif
}
// namespace platform
}
// namespace paddle
...
...
paddle/fluid/platform/gen_comm_id_helper.h
浏览文件 @
d03bbefa
...
...
@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
defined(PADDLE_WITH_XPU_BKCL)
|| defined(PADDLE_WITH_CUSTOM_DEVICE)
#include <functional>
#include <memory>
#include <mutex>
...
...
paddle/phi/backends/c_comm_lib.h
浏览文件 @
d03bbefa
...
...
@@ -23,8 +23,8 @@
namespace
phi
{
namespace
ccl
{
using
CCLComm
=
void
*
;
using
CCLRootId
=
std
::
vector
<
uint8_t
>
;
typedef
void
*
CCLComm
;
typedef
std
::
vector
<
uint8_t
>
CCLRootId
;
enum
CCLReduceOp
{
SUM
=
0
,
AVG
,
MAX
,
MIN
,
PRODUCT
};
enum
CCLDataType
{
...
...
paddle/phi/backends/custom/custom_context.cc
浏览文件 @
d03bbefa
...
...
@@ -44,9 +44,15 @@ struct CustomContext::Impl {
void
Wait
()
const
{
stream_
->
Wait
();
}
phi
::
ccl
::
CCLComm
xccl_comm
()
const
{
return
comm_
;
}
void
set_xccl_comm
(
phi
::
ccl
::
CCLComm
comm
)
{
comm_
=
comm
;
}
Place
place_
;
std
::
shared_ptr
<
phi
::
stream
::
Stream
>
stream_
;
phi
::
ccl
::
CCLComm
comm_
;
};
void
CustomContext
::
Init
()
{
impl_
->
Init
();
}
...
...
@@ -72,4 +78,11 @@ CustomContext::CustomContext(const CustomPlace& place)
CustomContext
::~
CustomContext
()
{
impl_
->
Init
();
}
phi
::
ccl
::
CCLComm
CustomContext
::
xccl_comm
()
const
{
return
impl_
->
xccl_comm
();
}
void
CustomContext
::
set_xccl_comm
(
phi
::
ccl
::
CCLComm
comm
)
{
impl_
->
set_xccl_comm
(
comm
);
}
}
// namespace phi
paddle/phi/backends/custom/custom_context.h
浏览文件 @
d03bbefa
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory>
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/backends/stream.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
...
...
@@ -63,6 +64,12 @@ class CustomContext : public DeviceContext,
// all resources and delete them when destructing.
void
Init
();
/*! \brief Return xccl communicators. */
phi
::
ccl
::
CCLComm
xccl_comm
()
const
;
/*! \brief Set nccl communicators. */
void
set_xccl_comm
(
phi
::
ccl
::
CCLComm
comm
);
private:
CustomContext
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录