Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0b205817
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0b205817
编写于
2月 23, 2022
作者:
S
ShenLiang
提交者:
GitHub
2月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add ProcessGroupNCCL for distributed training (#39737)
* add processgroup_nccl
上级
058e1d85
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
1253 addition
and
6 deletion
+1253
-6
paddle/fluid/distributed/CMakeLists.txt
paddle/fluid/distributed/CMakeLists.txt
+1
-1
paddle/fluid/distributed/collective/CMakeLists.txt
paddle/fluid/distributed/collective/CMakeLists.txt
+5
-0
paddle/fluid/distributed/collective/NCCLTools.h
paddle/fluid/distributed/collective/NCCLTools.h
+198
-0
paddle/fluid/distributed/collective/ProcessGroup.cc
paddle/fluid/distributed/collective/ProcessGroup.cc
+40
-0
paddle/fluid/distributed/collective/ProcessGroup.h
paddle/fluid/distributed/collective/ProcessGroup.h
+108
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
+321
-0
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
+126
-0
paddle/fluid/distributed/collective/Types.h
paddle/fluid/distributed/collective/Types.h
+36
-0
paddle/fluid/platform/cuda_device_guard.h
paddle/fluid/platform/cuda_device_guard.h
+19
-5
paddle/fluid/platform/device/gpu/nccl_helper.h
paddle/fluid/platform/device/gpu/nccl_helper.h
+17
-0
paddle/fluid/platform/flags.cc
paddle/fluid/platform/flags.cc
+12
-0
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+8
-0
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+149
-0
paddle/fluid/pybind/distributed_py.h
paddle/fluid/pybind/distributed_py.h
+29
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+4
-0
python/paddle/fluid/tests/unittests/CMakeLists.txt
python/paddle/fluid/tests/unittests/CMakeLists.txt
+4
-0
python/paddle/fluid/tests/unittests/process_group_nccl.py
python/paddle/fluid/tests/unittests/process_group_nccl.py
+149
-0
python/paddle/fluid/tests/unittests/test_collective_process_group.py
...le/fluid/tests/unittests/test_collective_process_group.py
+27
-0
未找到文件。
paddle/fluid/distributed/CMakeLists.txt
浏览文件 @
0b205817
add_subdirectory
(
collective
)
add_subdirectory
(
store
)
if
(
NOT WITH_PSCORE
)
add_subdirectory
(
fleet_executor
)
return
()
...
...
paddle/fluid/distributed/collective/CMakeLists.txt
0 → 100644
浏览文件 @
0b205817
cc_library
(
processgroup SRCS ProcessGroup.cc DEPS pten pten_api eager_api
)
if
(
WITH_NCCL
)
cc_library
(
processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context pten pten_api eager_api
)
endif
()
paddle/fluid/distributed/collective/NCCLTools.h
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include <error.h>
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
#define NCCLCHECK(cmd) \
do { \
ncclResult_t r = cmd; \
if (r != ncclSuccess) { \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
platform::dynload::ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
// NOTE(shenliang03): EventManager are movable not copyable CudaEvent wrapper.
// EventManage is different from paddle::platform::CudaEvent.
// It uses lazy initialization and is only created when the
// Record() method is called for the first time; it also monitors
// device information to ensure that recorded stream and event
// are on the same device.
class
EventManager
{
public:
EventManager
()
{}
explicit
EventManager
(
unsigned
int
flags
)
:
flags_
{
flags
}
{}
~
EventManager
()
{
if
(
is_created_
)
{
platform
::
CUDADeviceGuard
guard
(
device_index_
);
cudaEventDestroy
(
event_
);
}
}
EventManager
(
const
EventManager
&
)
=
delete
;
EventManager
&
operator
=
(
const
EventManager
&
)
=
delete
;
EventManager
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
}
EventManager
&
operator
=
(
EventManager
&&
other
)
{
std
::
swap
(
flags_
,
other
.
flags_
);
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
return
*
this
;
}
bool
IsCreated
()
const
{
return
is_created_
;
}
bool
DeviceId
()
const
{
return
device_index_
;
}
gpuEvent_t
GetRawCudaEvent
()
const
{
return
event_
;
}
void
Record
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
if
(
!
is_created_
)
{
CreateEvent
(
device_index
);
}
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"CUDADeviceContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventRecord
(
event_
,
ctx
.
stream
()));
}
bool
Query
()
const
{
gpuError_t
err
=
cudaEventQuery
(
event_
);
if
(
err
==
cudaSuccess
)
{
return
true
;
}
else
if
(
err
==
cudaErrorNotReady
)
{
return
false
;
}
else
{
PADDLE_ENFORCE_GPU_SUCCESS
(
err
);
return
false
;
}
}
void
Synchronize
()
const
{
if
(
is_created_
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventSynchronize
(
event_
));
}
}
void
Block
(
const
paddle
::
platform
::
CUDADeviceContext
&
ctx
)
const
{
if
(
is_created_
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"CUDADeviceContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
CUDADeviceGuard
guard
(
device_index_
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaStreamWaitEvent
(
ctx
.
stream
(),
event_
,
0
));
}
}
private:
unsigned
int
flags_
=
cudaEventDefault
;
bool
is_created_
{
false
};
gpuEvent_t
event_
{};
int8_t
device_index_
{
0
};
private:
void
CreateEvent
(
int
device_index
)
{
device_index_
=
device_index
;
platform
::
CUDADeviceGuard
guard
(
device_index
);
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaEventCreateWithFlags
(
&
event_
,
flags_
));
is_created_
=
true
;
}
};
// NOTE(shenliang03): NCCLCommManager is more lightweight than
// platform::NCCLComm
class
NCCLCommManager
{
public:
explicit
NCCLCommManager
(
ncclComm_t
ncclComm
)
:
nccl_comm_
(
ncclComm
)
{}
NCCLCommManager
()
:
NCCLCommManager
(
nullptr
)
{}
~
NCCLCommManager
()
noexcept
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
nccl_comm_
)
{
platform
::
dynload
::
ncclCommDestroy
(
nccl_comm_
);
}
}
static
std
::
shared_ptr
<
NCCLCommManager
>
Create
(
int
num_ranks
,
int
rank
,
ncclUniqueId
comm_id
)
{
auto
nccl_manager
=
std
::
make_shared
<
NCCLCommManager
>
();
NCCLCHECK
(
platform
::
dynload
::
ncclCommInitRank
(
&
(
nccl_manager
->
nccl_comm_
),
num_ranks
,
comm_id
,
rank
));
nccl_manager
->
nccl_id_
=
comm_id
;
nccl_manager
->
rank_
=
rank
;
return
nccl_manager
;
}
ncclUniqueId
GetNcclId
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_id_
;
}
ncclComm_t
GetNcclComm
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
nccl_comm_
;
}
NCCLCommManager
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
const
NCCLCommManager
&
)
=
delete
;
NCCLCommManager
&
operator
=
(
NCCLCommManager
&&
other
)
=
delete
;
NCCLCommManager
(
NCCLCommManager
&&
other
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
other
.
mutex_
);
std
::
swap
(
nccl_comm_
,
other
.
nccl_comm_
);
}
protected:
ncclComm_t
nccl_comm_
;
ncclUniqueId
nccl_id_
;
int
rank_
;
mutable
std
::
mutex
mutex_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroup.cc
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
namespace
paddle
{
namespace
distributed
{
ProcessGroup
::
Task
::
Task
(
int
rank
,
const
std
::
vector
<
Tensor
>&
inputTensors
,
CommType
comm_type
)
:
rank_
(
rank
),
comm_type_
(
comm_type
)
{}
ProcessGroup
::
Task
::~
Task
()
=
default
;
bool
ProcessGroup
::
Task
::
IsCompleted
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
return
is_completed_
;
}
bool
ProcessGroup
::
Task
::
Wait
(
std
::
chrono
::
milliseconds
timeout
)
{
return
false
;
}
void
ProcessGroup
::
Task
::
Synchronize
()
{}
ProcessGroup
::
ProcessGroup
(
int
rank
,
int
size
)
:
rank_
(
rank
),
size_
(
size
)
{}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroup.h
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/enforce.h"
constexpr
auto
kWaitTimeout
=
std
::
chrono
::
milliseconds
(
0
);
namespace
paddle
{
namespace
distributed
{
using
Tensor
=
paddle
::
experimental
::
Tensor
;
enum
class
CommType
:
std
::
uint8_t
{
BROADCAST
=
0
,
ALLREDUCE
=
1
,
ALLREDUCE_SPARSE
=
2
,
// TODO(shenliang03): to support sparse in allreduce
REDUCE
=
3
,
ALLGATHER
=
4
,
GATHER
=
5
,
SCATTER
=
6
,
REDUCE_SCATTER
=
7
,
ALLTOALL
=
8
,
SEND
=
9
,
RECV
=
10
,
BARRIER
=
11
,
UNKNOWN
=
100
,
};
struct
ProcessGroupStrategy
{
int
nranks_
{
1
};
int
local_rank_
{
0
};
std
::
vector
<
std
::
string
>
trainer_endpoints_
{};
std
::
string
current_endpoint_
{
""
};
int
nrings_
{
1
};
};
class
ProcessGroup
{
public:
class
Task
{
public:
Task
(
int
rank
,
const
std
::
vector
<
Tensor
>&
inputTensors
,
CommType
opType
=
CommType
::
UNKNOWN
);
virtual
~
Task
();
virtual
bool
IsCompleted
();
virtual
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
);
virtual
void
Synchronize
();
protected:
const
int
rank_
;
CommType
comm_type_
;
std
::
mutex
mutex_
;
bool
is_completed_
=
false
;
};
explicit
ProcessGroup
(
int
rank
,
int
size
);
virtual
~
ProcessGroup
()
{}
int
GetRank
()
const
{
return
rank_
;
}
int
GetSize
()
const
{
return
size_
;
}
virtual
const
std
::
string
GetBackendName
()
const
=
0
;
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
/* tensors */
,
const
AllreduceOptions
&
=
AllreduceOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support allreduce"
,
GetBackendName
()));
}
virtual
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
/* tensors */
,
const
BroadcastOptions
&
=
BroadcastOptions
())
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"ProcessGroup%s does not support allreduce"
,
GetBackendName
()));
}
protected:
const
int
rank_
;
const
int
size_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupNCCL.cc
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
DECLARE_bool
(
nccl_blocking_wait
);
DECLARE_bool
(
use_stream_safe_cuda_allocator
);
constexpr
int64_t
kWaitBlockTImeout
=
10
;
namespace
paddle
{
namespace
distributed
{
static
ncclRedOp_t
ToNCCLRedType
(
ReduceOp
reduction
)
{
static
const
std
::
map
<
ReduceOp
,
ncclRedOp_t
>
red_type
=
{
{
ReduceOp
::
MIN
,
ncclMin
},
{
ReduceOp
::
MAX
,
ncclMax
},
{
ReduceOp
::
SUM
,
ncclSum
},
{
ReduceOp
::
PRODUCT
,
ncclProd
},
};
auto
it
=
red_type
.
find
(
reduction
);
PADDLE_ENFORCE_EQ
(
it
!=
red_type
.
end
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Invalid nccl reduction. Must be ncclMin | ncclMax | "
"ncclProd | ncclSum"
));
return
it
->
second
;
}
std
::
string
SerializeNCCLUniqueId
(
const
ncclUniqueId
&
ncclID
)
{
const
uint8_t
*
bytes
=
reinterpret_cast
<
const
uint8_t
*>
(
&
ncclID
);
std
::
ostringstream
oss
;
for
(
auto
i
=
0
;
i
<
NCCL_UNIQUE_ID_BYTES
;
++
i
)
{
oss
<<
std
::
hex
<<
static_cast
<
int
>
(
bytes
[
i
]);
}
return
oss
.
str
();
}
// Get the list of devices from list of tensors
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
std
::
vector
<
Place
>
places
;
places
.
reserve
(
tensors
.
size
());
for
(
auto
&
tensor
:
tensors
)
{
places
.
push_back
(
tensor
.
inner_place
());
}
return
places
;
}
// Get the deviceList String from the list of devices
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
)
{
std
::
string
placeList
;
for
(
auto
&
place
:
places
)
{
std
::
stringstream
tmp
;
tmp
<<
place
;
if
(
placeList
.
empty
())
{
placeList
+=
tmp
.
str
();
}
else
{
placeList
+=
","
+
tmp
.
str
();
}
}
return
placeList
;
}
bool
CheckTensorsInCudaPlace
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
Tensor
&
t
)
{
return
t
.
place
()
==
PlaceType
::
kGPU
;
});
}
void
SyncDefaultStream
(
const
std
::
vector
<
Place
>&
places
,
std
::
vector
<
EventManager
>&
ncclEvents
,
// NOLINT
std
::
vector
<
std
::
unique_ptr
<
CUDADeviceContext
>>&
dev_ctx
)
{
// NOLINT
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
auto
*
default_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
places
[
i
]));
ncclEvents
[
i
].
Record
(
*
dev_ctx
[
i
]);
ncclEvents
[
i
].
Block
(
*
default_ctx
);
}
}
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
ProcessGroupNCCL
::
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
Tensor
>&
inputs
)
{
return
std
::
make_shared
<
ProcessGroupNCCL
::
NCCLTask
>
(
places
,
rank
,
comm_type
,
inputs
);
}
ProcessGroupNCCL
::
NCCLTask
::
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
)
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
control_events_
.
resize
(
places
.
size
());
ncclComms_
.
resize
(
places
.
size
());
}
ProcessGroupNCCL
::
NCCLTask
::~
NCCLTask
()
{}
void
ProcessGroupNCCL
::
NCCLTask
::
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
)
{
// NOLINT
outputs_
=
std
::
make_shared
<
std
::
vector
<
Tensor
>>
(
outputs
);
}
void
ProcessGroupNCCL
::
NCCLTask
::
SynchronizeStreams
()
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
*
default_ctx
=
static_cast
<
platform
::
CUDADeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]));
default_ctx
->
WaitEvent
(
control_events_
[
i
].
GetRawCudaEvent
());
}
}
bool
ProcessGroupNCCL
::
NCCLTask
::
IsCompleted
()
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
if
(
!
control_events_
[
i
].
Query
())
{
return
false
;
}
}
return
true
;
}
// TODO(sheniang03): Add timeout for wait, now timeout unused
bool
ProcessGroupNCCL
::
NCCLTask
::
Wait
(
std
::
chrono
::
milliseconds
timeout
)
{
SynchronizeStreams
();
if
(
FLAGS_nccl_blocking_wait
)
{
// NOTE(shenliang03): It will block host for sync
while
(
!
IsCompleted
())
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
kWaitBlockTImeout
));
}
}
return
true
;
}
// Same as Wait
void
ProcessGroupNCCL
::
NCCLTask
::
Synchronize
()
{
Wait
(
kWaitTimeout
);
}
ProcessGroupNCCL
::
ProcessGroupNCCL
(
const
ProcessGroupStrategy
&
strategy
,
int
rank
,
int
size
)
:
ProcessGroup
(
rank
,
size
),
strategy_
(
strategy
)
{}
void
ProcessGroupNCCL
::
BcastNCCLId
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
,
// NOLINT
int
root
,
int
server_fd
)
{
if
(
strategy_
.
local_rank_
==
root
)
{
std
::
vector
<
std
::
string
>
other_trainers
;
for
(
auto
&
ep
:
strategy_
.
trainer_endpoints_
)
{
if
(
ep
!=
strategy_
.
current_endpoint_
)
{
other_trainers
.
push_back
(
ep
);
}
}
platform
::
SendBroadCastCommID
(
other_trainers
,
&
nccl_ids
);
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
strategy_
.
current_endpoint_
,
&
nccl_ids
);
}
}
void
ProcessGroupNCCL
::
BroadcastUniqueNCCLID
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
)
{
// NOLINT
int
server_fd
=
-
1
;
if
(
rank_
!=
0
)
{
server_fd
=
platform
::
SocketServer
::
GetInstance
(
strategy_
.
current_endpoint_
)
.
socket
();
}
BcastNCCLId
(
nccl_ids
,
0
,
server_fd
);
}
// create NCCLManager cache for places_key
void
ProcessGroupNCCL
::
CreateNCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
)
{
PADDLE_ENFORCE_EQ
(
places_key
.
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"Not able to create/get the NCCL Communicator since "
"the GPU place are not known"
));
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
nccl_comms
;
nccl_comms
.
resize
(
places
.
size
());
// using vector just for broadcast
std
::
vector
<
ncclUniqueId
>
nccl_ids
;
nccl_ids
.
resize
(
1
);
auto
&
nccl_id
=
nccl_ids
.
front
();
if
(
rank_
==
0
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGetUniqueId
(
&
nccl_id
));
}
BroadcastUniqueNCCLID
(
nccl_ids
);
VLOG
(
3
)
<<
"init nccl rank: "
<<
strategy_
.
local_rank_
<<
", nranks: "
<<
strategy_
.
nranks_
<<
", place: "
<<
places_key
<<
", nccl uniqueid: "
<<
SerializeNCCLUniqueId
(
nccl_id
);
std
::
vector
<
std
::
unique_ptr
<
CUDADeviceContext
>>
dev_ctx
;
dev_ctx
.
resize
(
places
.
size
());
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupStart
());
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
platform
::
CUDADeviceGuard
guard
(
places
[
i
]);
nccl_comms
[
i
]
=
NCCLCommManager
::
Create
(
GetSize
(),
GetRank
(),
nccl_id
);
dev_ctx
[
i
].
reset
(
new
CUDADeviceContext
(
places
[
i
]));
}
PADDLE_ENFORCE_GPU_SUCCESS
(
platform
::
dynload
::
ncclGroupEnd
());
std
::
vector
<
EventManager
>
events
;
events
.
resize
(
places
.
size
());
// These caches will be useful to process sync/wait/communicate
places_to_events_
.
emplace
(
places_key
,
std
::
move
(
events
));
places_to_ncclcomm_
.
emplace
(
places_key
,
std
::
move
(
nccl_comms
));
places_to_ctx_
.
emplace
(
places_key
,
std
::
move
(
dev_ctx
));
}
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
Tensor
>&
outputs
,
Fn
fn
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_ncclcomm_
.
find
(
key
)
==
places_to_ncclcomm_
.
end
())
{
CreateNCCLManagerCache
(
key
,
places
);
}
}
auto
&
nccl_comms
=
places_to_ncclcomm_
[
key
];
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
inputs
);
task
->
SetOutputs
(
outputs
);
// construct uninitialize guard for device
platform
::
CUDADeviceGuard
cuda_guard
;
if
(
FLAGS_use_stream_safe_cuda_allocator
)
{
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
auto
dense_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
inputs
[
i
].
impl
());
memory
::
RecordStream
(
dense_tensor
->
Holder
(),
places_to_ctx_
[
key
][
i
]
->
stream
());
}
}
{
platform
::
NCCLGroupGuard
nccl_guard
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
const
auto
&
nccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
inputs
[
i
],
outputs
[
i
],
nccl_comms
[
i
]
->
GetNcclComm
(),
nccl_stream
);
}
}
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
cuda_guard
.
SetDevice
(
places
[
i
]);
task
->
control_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
}
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
tensors
,
tensors
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
return
platform
::
dynload
::
ncclAllReduce
(
input_tensor
->
data
(),
output_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
ToNCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupNCCL
::
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
opts
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCudaPlace
(
tensors
),
true
,
platform
::
errors
::
InvalidArgument
(
"All inputs should be in CudaPlace."
));
return
Collective
(
tensors
,
tensors
,
[
&
](
Tensor
&
input
,
Tensor
&
output
,
ncclComm_t
comm
,
const
gpuStream_t
&
stream
)
{
const
auto
root
=
opts
.
source_rank
*
tensors
.
size
()
+
opts
.
source_root
;
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
return
platform
::
dynload
::
ncclBcast
(
input_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToNCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
},
CommType
::
BROADCAST
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupNCCL.h
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
constexpr
const
char
*
NCCL_BACKEND_NAME
=
"NCCL"
;
namespace
paddle
{
namespace
distributed
{
using
Place
=
paddle
::
platform
::
Place
;
using
CUDAStream
=
platform
::
stream
::
CUDAStream
;
using
CUDADeviceContext
=
paddle
::
platform
::
CUDADeviceContext
;
class
ProcessGroupNCCL
:
public
ProcessGroup
{
public:
class
NCCLTask
:
public
ProcessGroup
::
Task
,
public
std
::
enable_shared_from_this
<
NCCLTask
>
{
public:
NCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
);
bool
IsCompleted
();
void
SynchronizeStreams
();
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
);
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
);
// NOLINT
virtual
~
NCCLTask
();
std
::
vector
<
EventManager
>
control_events_
;
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>
ncclComms_
;
std
::
shared_ptr
<
std
::
vector
<
Tensor
>>
outputs_
;
private:
};
ProcessGroupNCCL
(
const
ProcessGroupStrategy
&
strategy
,
int
rank
,
int
size
);
const
std
::
string
GetBackendName
()
const
override
{
return
std
::
string
(
NCCL_BACKEND_NAME
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
protected:
virtual
std
::
shared_ptr
<
ProcessGroupNCCL
::
NCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
const
std
::
vector
<
Tensor
>&
inputs
);
protected:
ProcessGroupStrategy
strategy_
;
std
::
shared_ptr
<
NCCLCommManager
>
nccl_comm_
;
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
NCCLCommManager
>>>
places_to_ncclcomm_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
EventManager
>>
places_to_events_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
CUDADeviceContext
>>>
places_to_ctx_
;
private:
void
BcastNCCLId
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
,
int
root
,
// NOLINT
int
server_fd
);
void
BroadcastUniqueNCCLID
(
std
::
vector
<
ncclUniqueId
>&
nccl_ids
);
// NOLINT
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
Fn
fn
,
CommType
op_type
);
void
CreateNCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/Types.h
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <cstdint>
#include <vector>
namespace
paddle
{
namespace
distributed
{
// TODO(shenliang03): To support AVG for reduce
enum
class
ReduceOp
:
std
::
uint8_t
{
SUM
=
0
,
AVG
,
MAX
,
MIN
,
PRODUCT
};
struct
AllreduceOptions
{
ReduceOp
reduce_op
=
ReduceOp
::
SUM
;
};
struct
BroadcastOptions
{
int
source_rank
=
0
;
int
source_root
=
0
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/platform/cuda_device_guard.h
浏览文件 @
0b205817
...
...
@@ -14,13 +14,28 @@
#pragma once
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
platform
{
class
CUDADeviceGuard
{
public:
explicit
inline
CUDADeviceGuard
(
int
dev_id
)
{
explicit
CUDADeviceGuard
(
int
dev_id
)
{
SetDeviceIndex
(
dev_id
);
}
explicit
CUDADeviceGuard
(
const
CUDAPlace
&
place
)
:
CUDADeviceGuard
(
place
.
device
)
{}
// create uninitialized CUDADeviceGuard
CUDADeviceGuard
()
{}
~
CUDADeviceGuard
()
{
if
(
prev_id_
!=
-
1
)
{
platform
::
SetDeviceId
(
prev_id_
);
}
}
inline
void
SetDeviceIndex
(
const
int
dev_id
)
{
int
prev_id
=
platform
::
GetCurrentDeviceId
();
if
(
prev_id
!=
dev_id
)
{
prev_id_
=
prev_id
;
...
...
@@ -28,10 +43,9 @@ class CUDADeviceGuard {
}
}
inline
~
CUDADeviceGuard
()
{
if
(
prev_id_
!=
-
1
)
{
platform
::
SetDeviceId
(
prev_id_
);
}
void
SetDevice
(
const
CUDAPlace
&
place
)
{
int
dev_id
=
place
.
device
;
SetDeviceIndex
(
dev_id
);
}
CUDADeviceGuard
(
const
CUDADeviceGuard
&
o
)
=
delete
;
...
...
paddle/fluid/platform/device/gpu/nccl_helper.h
浏览文件 @
0b205817
...
...
@@ -56,6 +56,23 @@ inline ncclDataType_t ToNCCLDataType(framework::proto::VarType::Type type) {
}
}
inline
ncclDataType_t
ToNCCLDataType
(
experimental
::
DataType
type
)
{
if
(
type
==
experimental
::
DataType
::
FLOAT32
)
{
return
ncclFloat
;
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT64
)
{
return
ncclDouble
;
}
else
if
(
type
==
experimental
::
DataType
::
INT32
)
{
return
ncclInt
;
}
else
if
(
type
==
experimental
::
DataType
::
INT64
)
{
return
ncclInt64
;
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
return
ncclFloat16
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This datatype in nccl is not supported."
));
}
}
// NOTE(minqiyang): according to the ncclGroupEnd documentations:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which will
...
...
paddle/fluid/platform/flags.cc
浏览文件 @
0b205817
...
...
@@ -761,3 +761,15 @@ DEFINE_bool(enable_slotrecord_reset_shrink, false,
"enable slotrecord obejct reset shrink memory, default false"
);
DEFINE_bool
(
enable_ins_parser_file
,
false
,
"enable parser ins file , default false"
);
/**
* ProcessGroupNCCL related FLAG
* Name: nccl_blocking_wait
* Since Version:
* Value Range: bool, default=false
* Example:
* Note: nccl blocking wait.
*/
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PADDLE_DEFINE_EXPORTED_bool
(
nccl_blocking_wait
,
false
,
"nccl blocking wait"
);
#endif
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
0b205817
...
...
@@ -80,6 +80,14 @@ set(PYBIND_SRCS
communication.cc
cuda_streams_py.cc
)
if
(
NOT ON_INFER
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup
)
if
(
WITH_NCCL
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_nccl
)
endif
()
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
distributed_py.cc
)
endif
()
if
(
WITH_ASCEND
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
ascend_wrapper
)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
ascend_wrapper_py.cc
)
...
...
paddle/fluid/pybind/distributed_py.cc
0 → 100644
浏览文件 @
0b205817
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/api/all.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
using
Tensor
=
paddle
::
experimental
::
Tensor
;
void
BindDistributed
(
py
::
module
*
m
)
{
py
::
enum_
<
distributed
::
ReduceOp
>
(
*
m
,
"ReduceOp"
)
.
value
(
"SUM"
,
distributed
::
ReduceOp
::
SUM
)
.
value
(
"AVG"
,
distributed
::
ReduceOp
::
AVG
)
.
value
(
"MAX"
,
distributed
::
ReduceOp
::
MAX
)
.
value
(
"MIN"
,
distributed
::
ReduceOp
::
MIN
)
.
value
(
"PRODUCT"
,
distributed
::
ReduceOp
::
PRODUCT
);
py
::
class_
<
distributed
::
AllreduceOptions
>
(
*
m
,
"AllreduceOptions"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"reduce_op"
,
&
distributed
::
AllreduceOptions
::
reduce_op
);
py
::
class_
<
distributed
::
BroadcastOptions
>
(
*
m
,
"BroadcastOptions"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"source_rank"
,
&
distributed
::
BroadcastOptions
::
source_rank
)
.
def_readwrite
(
"source_root"
,
&
distributed
::
BroadcastOptions
::
source_root
);
auto
ProcessGroup
=
py
::
class_
<
distributed
::
ProcessGroup
,
std
::
shared_ptr
<
distributed
::
ProcessGroup
>>
(
*
m
,
"ProcessGroup"
)
.
def
(
"rank"
,
&
distributed
::
ProcessGroup
::
GetRank
)
.
def
(
"size"
,
&
distributed
::
ProcessGroup
::
GetSize
)
.
def
(
"name"
,
&
distributed
::
ProcessGroup
::
GetBackendName
)
.
def
(
"allreduce"
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
distributed
::
ReduceOp
op
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
AllreduceOptions
opts
;
opts
.
reduce_op
=
op
;
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
return
self
.
AllReduce
(
tensors
,
opts
);
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"op"
)
=
distributed
::
ReduceOp
::
SUM
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"broadcast"
,
[](
distributed
::
ProcessGroup
&
self
,
py
::
handle
py_tensor
,
int
source_rank
)
{
auto
tensor
=
CastPyArg2Tensor
(
py_tensor
.
ptr
(),
0
);
distributed
::
BroadcastOptions
opts
;
opts
.
source_rank
=
source_rank
;
std
::
vector
<
Tensor
>
tensors
=
{
tensor
};
return
self
.
Broadcast
(
tensors
,
opts
);
},
py
::
arg
(
"tensor"
),
py
::
arg
(
"source_rank"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#if defined(PADDLE_WITH_NCCL)
py
::
class_
<
distributed
::
ProcessGroupNCCL
,
std
::
shared_ptr
<
distributed
::
ProcessGroupNCCL
>>
(
*
m
,
"ProcessGroupNCCL"
,
ProcessGroup
)
.
def
(
py
::
init
<
const
distributed
::
ProcessGroupStrategy
&
,
int
,
int
>
(),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
class_
<
distributed
::
ProcessGroup
::
Task
,
std
::
shared_ptr
<
distributed
::
ProcessGroup
::
Task
>>
(
*
m
,
"task"
)
.
def
(
"is_completed"
,
&
distributed
::
ProcessGroup
::
Task
::
IsCompleted
)
.
def
(
"wait"
,
&
distributed
::
ProcessGroup
::
Task
::
Wait
,
py
::
arg
(
"timeout"
)
=
kWaitTimeout
,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"synchronize"
,
&
distributed
::
ProcessGroup
::
Task
::
Synchronize
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#endif
// define parallel strategy, it will be removed
py
::
class_
<
distributed
::
ProcessGroupStrategy
>
pg_strategy
(
*
m
,
"ProcessGroupStrategy"
,
""
);
pg_strategy
.
def
(
py
::
init
())
.
def_property
(
"nranks"
,
[](
const
distributed
::
ProcessGroupStrategy
&
self
)
{
return
self
.
nranks_
;
},
[](
distributed
::
ProcessGroupStrategy
&
self
,
int
nranks
)
{
self
.
nranks_
=
nranks
;
})
.
def_property
(
"local_rank"
,
[](
const
distributed
::
ProcessGroupStrategy
&
self
)
{
return
self
.
local_rank_
;
},
[](
distributed
::
ProcessGroupStrategy
&
self
,
int
local_rank
)
{
self
.
local_rank_
=
local_rank
;
})
.
def_property
(
"trainer_endpoints"
,
[](
const
distributed
::
ProcessGroupStrategy
&
self
)
{
return
self
.
trainer_endpoints_
;
},
[](
distributed
::
ProcessGroupStrategy
&
self
,
std
::
vector
<
std
::
string
>
eps
)
{
self
.
trainer_endpoints_
=
eps
;
})
.
def_property
(
"current_endpoint"
,
[](
const
distributed
::
ProcessGroupStrategy
&
self
)
{
return
self
.
current_endpoint_
;
},
[](
distributed
::
ProcessGroupStrategy
&
self
,
const
std
::
string
&
ep
)
{
self
.
current_endpoint_
=
ep
;
})
.
def_property
(
"nrings"
,
[](
const
distributed
::
ProcessGroupStrategy
&
self
)
{
return
self
.
nrings_
;
},
[](
distributed
::
ProcessGroupStrategy
&
self
,
int
nrings
)
{
self
.
nrings_
=
nrings
;
});
}
}
// end namespace pybind
}
// namespace paddle
paddle/fluid/pybind/distributed_py.h
0 → 100644
浏览文件 @
0b205817
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/chrono.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
void
BindDistributed
(
py
::
module
*
m
);
}
// namespace pybind
}
// namespace paddle
paddle/fluid/pybind/pybind.cc
浏览文件 @
0b205817
...
...
@@ -78,6 +78,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/fluid/pybind/distributed_py.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/lod_utils.h"
#ifndef PADDLE_ON_INFERENCE
...
...
@@ -3895,6 +3896,9 @@ All parameter, weight, gradient are variables in Paddle.
BindCompatible
(
&
m
);
BindDataset
(
&
m
);
BindGenerator
(
&
m
);
#ifndef PADDLE_ON_INFERENCE
BindDistributed
(
&
m
);
#endif
#ifdef PADDLE_WITH_ASCEND
BindAscendWrapper
(
&
m
);
BindAscendGraph
(
&
m
);
...
...
python/paddle/fluid/tests/unittests/CMakeLists.txt
浏览文件 @
0b205817
...
...
@@ -54,6 +54,7 @@ list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_data_unshard
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_save_load
)
list
(
APPEND DIST_TEST_OPS test_auto_parallel_autoconvert
)
list
(
APPEND DIST_TEST_OPS test_collective_process_group
)
set
(
MIXED_DIST_TEST_OPS
${
DIST_TEST_OPS
}
)
#remove distribute unittests.
list
(
APPEND MIXED_DIST_TEST_OPS test_dgc_op
)
...
...
@@ -290,6 +291,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
LIST
(
REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard
)
LIST
(
REMOVE_ITEM TEST_OPS test_auto_parallel_save_load
)
LIST
(
REMOVE_ITEM TEST_OPS test_auto_parallel_autoconvert
)
LIST
(
REMOVE_ITEM TEST_OPS test_collective_process_group
)
elseif
(
WITH_GPU
)
if
(
${
CUDNN_VERSION
}
VERSION_LESS 7100
)
LIST
(
REMOVE_ITEM TEST_OPS test_conv2d_fusion_op
)
...
...
@@ -1114,6 +1116,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties
(
test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_save_load PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_collective_process_group PROPERTIES TIMEOUT 120
)
if
(
${
NCCL_VERSION
}
VERSION_GREATER_EQUAL 2212
)
set_tests_properties
(
test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120
)
set_tests_properties
(
test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120
)
...
...
python/paddle/fluid/tests/unittests/process_group_nccl.py
0 → 100644
浏览文件 @
0b205817
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
random
import
numpy
as
np
import
os
import
shutil
import
paddle
from
paddle.fluid
import
core
from
datetime
import
timedelta
import
paddle.fluid.core
as
core
from
paddle.fluid.framework
import
_test_eager_guard
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
ProcessGroupStrategy
=
core
.
ProcessGroupStrategy
def
init_process_group
(
strategy
=
None
):
# this will remove
if
strategy
is
None
:
strategy
=
ProcessGroupStrategy
()
strategy
.
nranks
=
ParallelEnv
().
nranks
strategy
.
local_rank
=
ParallelEnv
().
local_rank
strategy
.
trainer_endpoints
=
ParallelEnv
().
trainer_endpoints
strategy
.
current_endpoint
=
ParallelEnv
().
current_endpoint
if
strategy
.
nranks
<
2
:
return
pg_group
=
core
.
ProcessGroupNCCL
(
strategy
,
strategy
.
local_rank
,
strategy
.
nranks
)
return
pg_group
class
TestProcessGroupFp32
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
seed
(
2022
)
random
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
self
.
config
()
def
config
(
self
):
self
.
dtype
=
"float32"
self
.
shape
=
(
2
,
10
,
5
)
def
test_create_process_group_nccl
(
self
):
with
_test_eager_guard
():
paddle
.
set_device
(
'gpu:%d'
%
paddle
.
distributed
.
ParallelEnv
().
dev_id
)
pg
=
init_process_group
()
print
(
"rank:"
,
pg
.
rank
(),
"size:"
,
pg
.
size
(),
"name:"
,
pg
.
name
())
print
(
"test new group api ok"
)
# test allreduce sum
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
sum_result
=
tensor_x
+
tensor_y
if
pg
.
rank
()
==
0
:
task
=
pg
.
allreduce
(
tensor_x
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
sum_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
sum_result
)
print
(
"test allreduce sum api ok"
)
# test allreduce max
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
max_result
=
paddle
.
maximum
(
tensor_x
,
tensor_y
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
allreduce
(
tensor_x
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
max_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
max_result
)
print
(
"test allreduce max api ok"
)
# test broadcast
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
broadcast_result
=
paddle
.
assign
(
tensor_x
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
broadcast
(
tensor_x
,
0
)
task
.
synchronize
()
paddle
.
device
.
cuda
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_x
)
else
:
task
=
pg
.
broadcast
(
tensor_y
,
0
)
task
.
synchronize
()
paddle
.
device
.
cuda
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_y
)
print
(
"test broadcast api ok"
)
class
TestProcessGroupFp16
(
TestProcessGroupFp32
):
def
setUp
(
self
):
paddle
.
seed
(
2022
)
random
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
self
.
config
()
def
config
(
self
):
self
.
dtype
=
"float16"
self
.
shape
=
(
4
,
20
,
20
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_collective_process_group.py
0 → 100644
浏览文件 @
0b205817
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestProcessGroup
(
TestMultipleGpus
):
def
test_process_group_nccl
(
self
):
self
.
run_mnist_2gpu
(
'process_group_nccl.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录