Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
73583f86
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
73583f86
编写于
3月 08, 2022
作者:
L
lilong12
提交者:
GitHub
3月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add the implementation of process group for hccl (#40228)
* add pg_hccl
上级
7024ade7
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
995 addition
and
0 deletion
+995
-0
paddle/fluid/distributed/collective/CMakeLists.txt
paddle/fluid/distributed/collective/CMakeLists.txt
+3
-0
paddle/fluid/distributed/collective/HCCLTools.h
paddle/fluid/distributed/collective/HCCLTools.h
+174
-0
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
+356
-0
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
+152
-0
paddle/fluid/platform/device/npu/hccl_helper.h
paddle/fluid/platform/device/npu/hccl_helper.h
+17
-0
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+3
-0
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+12
-0
python/paddle/fluid/tests/unittests/npu/process_group_hccl.py
...on/paddle/fluid/tests/unittests/npu/process_group_hccl.py
+249
-0
python/paddle/fluid/tests/unittests/npu/test_collective_process_group_hccl.py
...tests/unittests/npu/test_collective_process_group_hccl.py
+29
-0
未找到文件。
paddle/fluid/distributed/collective/CMakeLists.txt
浏览文件 @
73583f86
...
...
@@ -7,3 +7,6 @@ cc_library(eager_reducer SRCS reducer.cc DEPS eager_api processgroup)
if
(
WITH_NCCL
)
cc_library
(
processgroup_nccl SRCS ProcessGroupNCCL.cc DEPS place cuda_stream enforce collective_helper device_context phi phi_api eager_api
)
endif
()
if
(
WITH_ASCEND_CL
)
cc_library
(
processgroup_hccl SRCS ProcessGroupHCCL.cc DEPS place npu_stream enforce collective_helper device_context phi phi_api eager_api
)
endif
()
paddle/fluid/distributed/collective/HCCLTools.h
0 → 100644
浏览文件 @
73583f86
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <error.h>
#include <string>
#include "boost/variant.hpp"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/npu/enforce_npu.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
class
NPUEventManager
{
public:
NPUEventManager
()
=
default
;
~
NPUEventManager
()
{
if
(
is_created_
)
{
platform
::
NPUDeviceGuard
guard
(
device_index_
);
platform
::
NPUEventDestroy
(
event_
);
}
}
NPUEventManager
(
const
NPUEventManager
&
)
=
delete
;
NPUEventManager
&
operator
=
(
const
NPUEventManager
&
)
=
delete
;
NPUEventManager
(
NPUEventManager
&&
other
)
{
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
}
NPUEventManager
&
operator
=
(
NPUEventManager
&&
other
)
{
std
::
swap
(
is_created_
,
other
.
is_created_
);
std
::
swap
(
device_index_
,
other
.
device_index_
);
std
::
swap
(
event_
,
other
.
event_
);
return
*
this
;
}
bool
IsCreated
()
const
{
return
is_created_
;
}
bool
DeviceId
()
const
{
return
device_index_
;
}
aclrtEvent
GetRawNPUEvent
()
const
{
return
event_
;
}
void
Record
(
const
paddle
::
platform
::
NPUDeviceContext
&
ctx
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
if
(
!
is_created_
)
{
CreateEvent
(
device_index
);
}
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"NPUDeviceContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
NPUDeviceGuard
guard
(
device_index_
);
platform
::
NPUEventRecord
(
event_
,
ctx
.
stream
());
}
bool
Query
()
const
{
aclrtEventStatus
status
=
ACL_EVENT_STATUS_COMPLETE
;
platform
::
NPUEventQuery
(
event_
,
&
status
);
if
(
status
==
ACL_EVENT_STATUS_COMPLETE
)
{
return
true
;
}
return
false
;
}
void
Block
(
const
paddle
::
platform
::
NPUDeviceContext
&
ctx
)
const
{
if
(
is_created_
)
{
auto
device_index
=
ctx
.
GetPlace
().
device
;
PADDLE_ENFORCE_EQ
(
device_index
,
device_index_
,
platform
::
errors
::
PreconditionNotMet
(
"CUDADeviceContext's device %d does not match"
"Event's device %d"
,
device_index
,
device_index_
));
platform
::
NPUDeviceGuard
guard
(
device_index_
);
platform
::
NPUStreamWaitEvent
(
ctx
.
stream
(),
event_
);
}
}
private:
bool
is_created_
{
false
};
aclrtEvent
event_
{};
int8_t
device_index_
{
0
};
private:
void
CreateEvent
(
int
device_index
)
{
device_index_
=
device_index
;
platform
::
NPUDeviceGuard
guard
(
device_index
);
platform
::
NPUEventCreate
(
&
event_
);
is_created_
=
true
;
}
};
class
HCCLCommManager
{
public:
explicit
HCCLCommManager
(
HcclComm
hcclComm
)
:
hccl_comm_
(
hcclComm
)
{}
HCCLCommManager
()
:
HCCLCommManager
(
nullptr
)
{}
~
HCCLCommManager
()
noexcept
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
if
(
hccl_comm_
)
{
platform
::
dynload
::
HcclCommDestroy
(
hccl_comm_
);
}
}
static
std
::
shared_ptr
<
HCCLCommManager
>
Create
(
int
num_ranks
,
int
rank
,
HcclRootInfo
*
comm_id
,
HcclComm
hccl_comm
)
{
auto
hccl_manager
=
std
::
make_shared
<
HCCLCommManager
>
();
auto
ret
=
platform
::
dynload
::
HcclCommInitRootInfo
(
num_ranks
,
comm_id
,
rank
,
&
hccl_comm
);
using
__NPU_STATUS_TYPE__
=
decltype
(
ret
);
constexpr
auto
__success_type__
=
platform
::
details
::
NPUStatusType
<
__NPU_STATUS_TYPE__
>::
kSuccess
;
if
(
UNLIKELY
(
ret
!=
__success_type__
))
{
VLOG
(
0
)
<<
"Error: create hccl_id error."
;
exit
(
-
1
);
}
hccl_manager
->
hccl_id_
=
comm_id
;
hccl_manager
->
rank_
=
rank
;
hccl_manager
->
hccl_comm_
=
hccl_comm
;
return
hccl_manager
;
}
HcclRootInfo
*
GetHcclId
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
hccl_id_
;
}
HcclComm
GetHcclComm
()
const
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_
);
return
hccl_comm_
;
}
HCCLCommManager
(
const
HCCLCommManager
&
)
=
delete
;
HCCLCommManager
&
operator
=
(
const
HCCLCommManager
&
)
=
delete
;
HCCLCommManager
&
operator
=
(
HCCLCommManager
&&
other
)
=
delete
;
HCCLCommManager
(
HCCLCommManager
&&
other
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
other
.
mutex_
);
std
::
swap
(
hccl_comm_
,
other
.
hccl_comm_
);
}
protected:
HcclComm
hccl_comm_
;
HcclRootInfo
*
hccl_id_
;
int
rank_
;
mutable
std
::
mutex
mutex_
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupHCCL.cc
0 → 100644
浏览文件 @
73583f86
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device/npu/hccl_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/place.h"
DECLARE_bool
(
hccl_blocking_wait
);
// DECLARE_bool(use_stream_safe_npu_allocator);
constexpr
int64_t
kWaitBlockTImeout
=
10
;
namespace
paddle
{
namespace
distributed
{
static
HcclReduceOp
ToHCCLRedType
(
ReduceOp
reduction
)
{
static
const
std
::
map
<
ReduceOp
,
HcclReduceOp
>
red_type
=
{
{
ReduceOp
::
MIN
,
HCCL_REDUCE_MIN
},
{
ReduceOp
::
MAX
,
HCCL_REDUCE_MAX
},
{
ReduceOp
::
SUM
,
HCCL_REDUCE_SUM
},
{
ReduceOp
::
PRODUCT
,
HCCL_REDUCE_PROD
},
};
auto
it
=
red_type
.
find
(
reduction
);
PADDLE_ENFORCE_EQ
(
it
!=
red_type
.
end
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Invalid hccl reduction. "
"Must be Min | Max | Prod | Sum"
));
return
it
->
second
;
}
std
::
string
SerializeHCCLUniqueId
(
const
HcclRootInfo
&
hcclID
)
{
const
uint8_t
*
bytes
=
reinterpret_cast
<
const
uint8_t
*>
(
&
hcclID
);
std
::
ostringstream
oss
;
for
(
size_t
i
=
0
;
i
<
sizeof
(
hcclID
);
++
i
)
{
oss
<<
std
::
hex
<<
static_cast
<
int
>
(
bytes
[
i
]);
}
return
oss
.
str
();
}
// Get the list of devices from list of tensors
std
::
vector
<
Place
>
GetPlaceList
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
std
::
vector
<
Place
>
places
;
places
.
reserve
(
tensors
.
size
());
for
(
auto
&
tensor
:
tensors
)
{
places
.
push_back
(
tensor
.
inner_place
());
}
return
places
;
}
// Get the deviceList String from the list of devices
std
::
string
GetKeyFromPlaces
(
const
std
::
vector
<
Place
>&
places
)
{
std
::
string
placeList
;
for
(
auto
&
place
:
places
)
{
std
::
stringstream
tmp
;
tmp
<<
place
;
if
(
placeList
.
empty
())
{
placeList
+=
tmp
.
str
();
}
else
{
placeList
+=
","
+
tmp
.
str
();
}
}
return
placeList
;
}
// bool CheckTensorsInNPUPlace(const std::vector<Tensor>& tensors) {
// return std::all_of(tensors.cbegin(), tensors.cend(), [&](const Tensor& t) {
// return t.place() == platform::DeviceType::NPU;
// });
// }
void
SyncDefaultStream
(
const
std
::
vector
<
Place
>&
places
,
std
::
vector
<
NPUEventManager
>&
hcclEvents
,
// NOLINT
std
::
vector
<
std
::
unique_ptr
<
NPUDeviceContext
>>&
dev_ctx
)
{
// NOLINT
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
auto
*
default_ctx
=
static_cast
<
platform
::
NPUDeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
places
[
i
]));
hcclEvents
[
i
].
Record
(
*
dev_ctx
[
i
]);
hcclEvents
[
i
].
Block
(
*
default_ctx
);
}
}
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
ProcessGroupHCCL
::
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
comm_type
,
const
std
::
vector
<
Tensor
>&
inputs
)
{
return
std
::
make_shared
<
ProcessGroupHCCL
::
HCCLTask
>
(
places
,
rank
,
comm_type
,
inputs
);
}
ProcessGroupHCCL
::
HCCLTask
::
HCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
)
:
Task
(
rank
,
inputs
,
CommType
),
places_
(
places
)
{
control_events_
.
resize
(
places
.
size
());
hcclComms_
.
resize
(
places
.
size
());
}
ProcessGroupHCCL
::
HCCLTask
::~
HCCLTask
()
{}
void
ProcessGroupHCCL
::
HCCLTask
::
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
)
{
// NOLINT
outputs_
=
std
::
make_shared
<
std
::
vector
<
Tensor
>>
(
outputs
);
}
void
ProcessGroupHCCL
::
HCCLTask
::
SynchronizeStreams
()
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
*
default_ctx
=
static_cast
<
platform
::
NPUDeviceContext
*>
(
platform
::
DeviceContextPool
::
Instance
().
Get
(
places_
[
i
]));
platform
::
NPUStreamWaitEvent
(
default_ctx
->
stream
(),
control_events_
[
i
].
GetRawNPUEvent
());
}
}
bool
ProcessGroupHCCL
::
HCCLTask
::
IsCompleted
()
{
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
if
(
!
control_events_
[
i
].
Query
())
{
return
false
;
}
}
return
true
;
}
// TODO(sandyhouse): Add timeout for wait, now timeout unused
bool
ProcessGroupHCCL
::
HCCLTask
::
Wait
(
std
::
chrono
::
milliseconds
timeout
)
{
SynchronizeStreams
();
if
(
FLAGS_hccl_blocking_wait
)
{
// NOTE(sandyhouse): It will block host for sync
while
(
!
IsCompleted
())
{
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
kWaitBlockTImeout
));
}
}
return
true
;
}
// Same as Wait
void
ProcessGroupHCCL
::
HCCLTask
::
Synchronize
()
{
Wait
(
kWaitTimeout
);
}
ProcessGroupHCCL
::
ProcessGroupHCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
int
rank
,
int
size
)
:
ProcessGroup
(
rank
,
size
),
store_
(
store
)
{}
void
ProcessGroupHCCL
::
BroadcastUniqueHCCLID
(
std
::
vector
<
HcclRootInfo
>&
hccl_ids
)
{
// NOLINT
if
(
rank_
==
0
)
{
for
(
size_t
i
=
0
;
i
<
hccl_ids
.
size
();
i
++
)
{
auto
key
=
"ProcessGroupHCCL/hccl_ids/"
+
std
::
to_string
(
i
);
auto
hccl_id
=
std
::
vector
<
uint8_t
>
(
reinterpret_cast
<
uint8_t
*>
(
&
hccl_ids
[
i
]),
reinterpret_cast
<
uint8_t
*>
(
&
hccl_ids
[
i
])
+
sizeof
(
HcclRootInfo
));
store_
->
set
(
key
,
hccl_id
);
}
}
else
{
for
(
size_t
i
=
0
;
i
<
hccl_ids
.
size
();
i
++
)
{
auto
key
=
"ProcessGroupHCCL/hccl_ids/"
+
std
::
to_string
(
i
);
auto
ret
=
store_
->
get
(
key
);
std
::
memcpy
(
&
hccl_ids
[
i
],
ret
.
data
(),
ret
.
size
());
}
}
}
// create HCCLManager cache for places_key
void
ProcessGroupHCCL
::
CreateHCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
)
{
PADDLE_ENFORCE_EQ
(
places_key
.
empty
(),
false
,
platform
::
errors
::
PreconditionNotMet
(
"Not able to create/get the HCCL Communicator since "
"the NPU place are not known"
));
std
::
vector
<
std
::
shared_ptr
<
HCCLCommManager
>>
hccl_comms
;
hccl_comms
.
resize
(
places
.
size
());
// using vector just for broadcast
std
::
vector
<
HcclRootInfo
>
hccl_ids
;
hccl_ids
.
resize
(
1
);
auto
&
hccl_id
=
hccl_ids
.
front
();
if
(
rank_
==
0
)
{
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
HcclGetRootInfo
(
&
hccl_id
));
}
BroadcastUniqueHCCLID
(
hccl_ids
);
VLOG
(
3
)
<<
"init hccl rank: "
<<
rank_
<<
", nranks: "
<<
size_
<<
", place: "
<<
places_key
<<
", hccl uniqueid: "
<<
SerializeHCCLUniqueId
(
hccl_id
);
std
::
vector
<
std
::
unique_ptr
<
NPUDeviceContext
>>
dev_ctx
;
dev_ctx
.
resize
(
places
.
size
());
std
::
unique_ptr
<
HcclComm
[]
>
comms
(
new
HcclComm
[
places
.
size
()]);
for
(
size_t
i
=
0
;
i
<
places
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
hccl_comms
[
i
]
=
HCCLCommManager
::
Create
(
GetSize
(),
GetRank
(),
&
hccl_id
,
comms
.
get
()
+
i
);
dev_ctx
[
i
].
reset
(
new
NPUDeviceContext
(
places
[
i
]));
}
std
::
vector
<
NPUEventManager
>
events
;
events
.
resize
(
places
.
size
());
// These caches will be useful to process sync/wait/communicate
places_to_events_
.
emplace
(
places_key
,
std
::
move
(
events
));
places_to_hcclcomm_
.
emplace
(
places_key
,
std
::
move
(
hccl_comms
));
places_to_ctx_
.
emplace
(
places_key
,
std
::
move
(
dev_ctx
));
}
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
std
::
vector
<
Tensor
>&
outputs
,
Fn
fn
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
inputs
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_hcclcomm_
.
find
(
key
)
==
places_to_hcclcomm_
.
end
())
{
CreateHCCLManagerCache
(
key
,
places
);
}
}
auto
&
hccl_comms
=
places_to_hcclcomm_
[
key
];
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
inputs
);
task
->
SetOutputs
(
outputs
);
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < inputs.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(inputs[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
const
auto
&
hccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
inputs
[
i
],
outputs
[
i
],
hccl_comms
[
i
]
->
GetHcclComm
(),
hccl_stream
);
}
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
task
->
control_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
}
return
task
;
}
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
Fn
fn
,
int
dst_rank
,
CommType
op_type
)
{
const
auto
places
=
GetPlaceList
(
tensors
);
const
auto
key
=
GetKeyFromPlaces
(
places
);
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_
);
if
(
places_to_hcclcomm_
.
find
(
key
)
==
places_to_hcclcomm_
.
end
())
{
CreateHCCLManagerCache
(
key
,
places
);
}
}
auto
&
hccl_comms
=
places_to_hcclcomm_
[
key
];
SyncDefaultStream
(
places
,
places_to_events_
[
key
],
places_to_ctx_
[
key
]);
auto
task
=
CreateTask
(
places
,
rank_
,
op_type
,
tensors
);
// construct uninitialize guard for device
// if (FLAGS_use_stream_safe_npu_allocator) {
// for (size_t i = 0; i < tensors.size(); ++i) {
// platform::NPUDeviceGuard guard(places[i].GetDeviceId());
// auto dense_tensor =
// std::dynamic_pointer_cast<phi::DenseTensor>(tensors[i].impl());
// memory::RecordStream(dense_tensor->Holder(),
// places_to_ctx_[key][i]->stream());
// }
// }
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
const
auto
&
hccl_stream
=
places_to_ctx_
[
key
][
i
]
->
stream
();
fn
(
tensors
[
i
],
hccl_comms
[
i
]
->
GetHcclComm
(),
hccl_stream
,
dst_rank
);
}
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
++
i
)
{
platform
::
NPUDeviceGuard
guard
(
places
[
i
].
GetDeviceId
());
task
->
control_events_
[
i
].
Record
(
*
places_to_ctx_
[
key
][
i
]);
}
return
task
;
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
opts
)
{
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// NPUPlace."));
return
Collective
(
tensors
,
tensors
,
[
&
](
const
Tensor
&
input
,
Tensor
&
output
,
HcclComm
comm
,
const
aclrtStream
&
stream
)
{
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
return
platform
::
dynload
::
HcclAllReduce
(
input_tensor
->
data
(),
output_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToHCCLDataType
(
input
.
type
()),
ToHCCLRedType
(
opts
.
reduce_op
),
comm
,
stream
);
},
CommType
::
ALLREDUCE
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupHCCL
::
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
opts
)
{
// PADDLE_ENFORCE_EQ(
// CheckTensorsInNPUPlace(tensors), true,
// platform::errors::InvalidArgument("All inputs should be in
// CudaPlace."));
return
Collective
(
tensors
,
tensors
,
[
&
](
Tensor
&
input
,
Tensor
&
output
,
HcclComm
comm
,
const
aclrtStream
&
stream
)
{
const
auto
root
=
opts
.
source_rank
*
tensors
.
size
()
+
opts
.
source_root
;
auto
input_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
input
.
impl
());
auto
output_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
output
.
impl
());
return
platform
::
dynload
::
HcclBroadcast
(
input_tensor
->
data
(),
input_tensor
->
numel
(),
platform
::
ToHCCLDataType
(
input
.
type
()),
root
,
comm
,
stream
);
},
CommType
::
BROADCAST
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupHCCL.h
0 → 100644
浏览文件 @
73583f86
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <chrono>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/platform/device/npu/npu_stream.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/distributed/collective/HCCLTools.h"
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gen_comm_id_helper.h"
#include "paddle/fluid/platform/place.h"
constexpr
const
char
*
HCCL_BACKEND_NAME
=
"HCCL"
;
namespace
paddle
{
namespace
distributed
{
using
Place
=
paddle
::
platform
::
Place
;
using
NPUStream
=
platform
::
stream
::
NPUStream
;
using
NPUDeviceContext
=
paddle
::
platform
::
NPUDeviceContext
;
class
ProcessGroupHCCL
:
public
ProcessGroup
{
public:
class
HCCLTask
:
public
ProcessGroup
::
Task
,
public
std
::
enable_shared_from_this
<
HCCLTask
>
{
public:
HCCLTask
(
const
std
::
vector
<
Place
>&
places
,
int
rank
,
CommType
CommType
,
const
std
::
vector
<
Tensor
>&
inputs
);
bool
IsCompleted
();
void
SynchronizeStreams
();
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
=
kWaitTimeout
);
void
Synchronize
();
void
SetOutputs
(
std
::
vector
<
Tensor
>&
outputs
);
// NOLINT
virtual
~
HCCLTask
();
std
::
vector
<
NPUEventManager
>
control_events_
;
protected:
std
::
vector
<
Place
>
places_
;
std
::
vector
<
std
::
shared_ptr
<
HCCLCommManager
>>
hcclComms_
;
std
::
shared_ptr
<
std
::
vector
<
Tensor
>>
outputs_
;
private:
};
ProcessGroupHCCL
(
const
std
::
shared_ptr
<
Store
>&
store
,
int
rank
,
int
size
);
const
std
::
string
GetBackendName
()
const
override
{
return
std
::
string
(
HCCL_BACKEND_NAME
);
}
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
AllreduceOptions
&
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
tensors
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Barrier
(
const
BarrierOptions
&
=
BarrierOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Send
(
std
::
vector
<
Tensor
>&
tensors
,
int
dst_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Recv
(
std
::
vector
<
Tensor
>&
tensors
,
int
src_rank
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllGather
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllToAll
(
std
::
vector
<
Tensor
>&
in
,
std
::
vector
<
Tensor
>&
out
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Reduce
(
std
::
vector
<
Tensor
>&
tensors
,
const
ReduceOptions
&
opts
)
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Scatter
(
std
::
vector
<
Tensor
>&
in_tensors
,
std
::
vector
<
Tensor
>&
out_tensors
,
const
ScatterOptions
&
)
override
;
protected:
virtual
std
::
shared_ptr
<
ProcessGroupHCCL
::
HCCLTask
>
CreateTask
(
std
::
vector
<
Place
>
places
,
int
rank
,
CommType
opType
,
const
std
::
vector
<
Tensor
>&
inputs
);
std
::
shared_ptr
<
Store
>
store_
;
std
::
shared_ptr
<
HCCLCommManager
>
hccl_comm_
;
std
::
mutex
mutex_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
HCCLCommManager
>>>
places_to_hcclcomm_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
NPUEventManager
>>
places_to_events_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
unique_ptr
<
NPUDeviceContext
>>>
places_to_ctx_
;
std
::
set
<
int
>
used_place_ids_
;
private:
void
BcastHCCLId
(
std
::
vector
<
HcclRootInfo
>&
hccl_ids
,
int
root
,
// NOLINT
int
server_fd
);
void
BroadcastUniqueHCCLID
(
std
::
vector
<
HcclRootInfo
>&
hccl_ids
);
// NOLINT
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Collective
(
std
::
vector
<
Tensor
>&
inputs
,
// NOLINT
std
::
vector
<
Tensor
>&
outputs
,
// NOLINT
Fn
fn
,
CommType
op_type
);
template
<
typename
Fn
>
std
::
shared_ptr
<
ProcessGroup
::
Task
>
PointToPoint
(
std
::
vector
<
Tensor
>&
tensors
,
// NOLINT
Fn
fn
,
int
dst_rank
,
CommType
op_type
);
void
CreateHCCLManagerCache
(
const
std
::
string
&
places_key
,
const
std
::
vector
<
Place
>&
places
);
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/platform/device/npu/hccl_helper.h
浏览文件 @
73583f86
...
...
@@ -53,6 +53,23 @@ inline HcclDataType ToHCCLDataType(framework::proto::VarType::Type type) {
}
}
inline
HcclDataType
ToHCCLDataType
(
experimental
::
DataType
type
)
{
if
(
type
==
experimental
::
DataType
::
FLOAT32
)
{
return
HCCL_DATA_TYPE_FP32
;
}
else
if
(
type
==
experimental
::
DataType
::
FLOAT16
)
{
return
HCCL_DATA_TYPE_FP16
;
}
else
if
(
type
==
experimental
::
DataType
::
INT64
)
{
return
HCCL_DATA_TYPE_INT64
;
}
else
if
(
type
==
experimental
::
DataType
::
INT32
)
{
return
HCCL_DATA_TYPE_INT32
;
}
else
if
(
type
==
experimental
::
DataType
::
INT8
)
{
return
HCCL_DATA_TYPE_INT8
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"This datatype in hccl is not supported."
));
}
}
// NOTE(minqiyang): according to the ncclGroupEnd documentations:
// https://docs.nvidia.com/deeplearning/sdk/nccl-api/ncclapidoc.html,
// ncclGroupEnd will wait for all communicators to be initialized, which will
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
73583f86
...
...
@@ -88,6 +88,9 @@ if(NOT ON_INFER)
if
(
WITH_GLOO
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_gloo
)
endif
()
if
(
WITH_ASCEND
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_hccl
)
endif
()
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
distributed_py.cc
)
endif
()
...
...
paddle/fluid/pybind/distributed_py.cc
浏览文件 @
73583f86
...
...
@@ -35,6 +35,10 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/distributed/collective/ProcessGroupHCCL.h"
#endif
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
...
...
@@ -201,6 +205,14 @@ void BindDistributed(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
py
::
class_
<
distributed
::
ProcessGroupHCCL
,
std
::
shared_ptr
<
distributed
::
ProcessGroupHCCL
>>
(
*
m
,
"ProcessGroupHCCL"
,
ProcessGroup
)
.
def
(
py
::
init
<
const
std
::
shared_ptr
<
distributed
::
Store
>
&
,
int
,
int
>
(),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#endif
py
::
class_
<
distributed
::
ProcessGroup
::
Task
,
std
::
shared_ptr
<
distributed
::
ProcessGroup
::
Task
>>
(
*
m
,
"task"
)
.
def
(
"is_completed"
,
&
distributed
::
ProcessGroup
::
Task
::
IsCompleted
)
...
...
python/paddle/fluid/tests/unittests/npu/process_group_hccl.py
0 → 100644
浏览文件 @
73583f86
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
random
import
numpy
as
np
import
os
import
shutil
import
paddle
from
paddle.fluid
import
core
from
datetime
import
timedelta
import
paddle.fluid.core
as
core
from
paddle.fluid.framework
import
_test_eager_guard
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
def
init_process_group
(
strategy
=
None
):
nranks
=
ParallelEnv
().
nranks
rank
=
ParallelEnv
().
local_rank
is_master
=
True
if
rank
==
0
else
False
store
=
paddle
.
fluid
.
core
.
TCPStore
(
"127.0.0.1"
,
6173
,
is_master
,
nranks
)
pg_group
=
core
.
ProcessGroupHCCL
(
store
,
rank
,
nranks
)
return
pg_group
class
TestProcessGroupFp32
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
seed
(
2022
)
random
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
self
.
config
()
def
config
(
self
):
self
.
dtype
=
"float32"
self
.
shape
=
(
2
,
10
,
5
)
def
test_create_process_group_nccl
(
self
):
with
_test_eager_guard
():
paddle
.
set_device
(
'npu:%d'
%
paddle
.
distributed
.
ParallelEnv
().
dev_id
)
pg
=
init_process_group
()
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
sum_result
=
tensor_x
+
tensor_y
if
pg
.
rank
()
==
0
:
task
=
pg
.
allreduce
(
tensor_x
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
sum_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
sum_result
)
print
(
"test allreduce sum api ok"
)
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
max_result
=
paddle
.
maximum
(
tensor_x
,
tensor_y
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
allreduce
(
tensor_x
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
max_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
max_result
)
print
(
"test allreduce max api ok"
)
# test broadcast
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
broadcast_result
=
paddle
.
assign
(
tensor_x
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
broadcast
(
tensor_x
,
0
)
task
.
synchronize
()
paddle
.
device
.
cuda
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_x
)
else
:
task
=
pg
.
broadcast
(
tensor_y
,
0
)
task
.
synchronize
()
paddle
.
device
.
cuda
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_y
)
print
(
"test broadcast api ok"
)
# test barrier
# rank 0
if
pg
.
rank
()
==
0
:
task
=
pg
.
barrier
()
task
.
wait
()
# rank 1
else
:
task
=
pg
.
barrier
()
task
.
wait
()
print
(
"test barrier api ok
\n
"
)
exit
(
0
)
# test allgather
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
tensor_y
=
paddle
.
to_tensor
(
y
)
out_shape
=
list
(
self
.
shape
)
out_shape
[
0
]
*=
2
out
=
np
.
random
.
random
(
out_shape
).
astype
(
self
.
dtype
)
tensor_out
=
paddle
.
to_tensor
(
out
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
all_gather
(
tensor_x
,
tensor_out
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
# rank 1
else
:
task
=
pg
.
all_gather
(
tensor_y
,
tensor_out
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
out_1
=
paddle
.
slice
(
tensor_out
,
[
0
],
[
0
],
[
out_shape
[
0
]
//
2
])
out_2
=
paddle
.
slice
(
tensor_out
,
[
0
],
[
out_shape
[
0
]
//
2
],
[
out_shape
[
0
]])
assert
np
.
array_equal
(
tensor_x
,
out_1
)
assert
np
.
array_equal
(
tensor_y
,
out_2
)
print
(
"test allgather api ok
\n
"
)
# test alltoall
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
out1
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
out2
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
tensor_y
=
paddle
.
to_tensor
(
y
)
tensor_out1
=
paddle
.
to_tensor
(
out1
)
tensor_out2
=
paddle
.
to_tensor
(
out2
)
raw_tensor_x_2
=
paddle
.
slice
(
tensor_x
,
[
0
],
[
self
.
shape
[
0
]
//
2
],
[
self
.
shape
[
0
]])
raw_tensor_y_1
=
paddle
.
slice
(
tensor_y
,
[
0
],
[
0
],
[
self
.
shape
[
0
]
//
2
])
if
pg
.
rank
()
==
0
:
task
=
pg
.
alltoall
(
tensor_x
,
tensor_out1
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
# rank 1
else
:
task
=
pg
.
alltoall
(
tensor_y
,
tensor_out2
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
out1_2
=
paddle
.
slice
(
tensor_out1
,
[
0
],
[
self
.
shape
[
0
]
//
2
],
[
self
.
shape
[
0
]])
out2_1
=
paddle
.
slice
(
tensor_out2
,
[
0
],
[
0
],
[
self
.
shape
[
0
]
//
2
])
if
pg
.
rank
()
==
0
:
assert
np
.
array_equal
(
out1_2
.
numpy
(),
raw_tensor_y_1
.
numpy
())
else
:
assert
np
.
array_equal
(
out2_1
,
raw_tensor_x_2
)
print
(
"test alltoall api ok
\n
"
)
# test Reduce
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
tensor_y
=
paddle
.
to_tensor
(
y
)
sum_result
=
tensor_x
+
tensor_y
if
pg
.
rank
()
==
0
:
task
=
pg
.
reduce
(
tensor_x
,
0
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
# rank 1
else
:
task
=
pg
.
reduce
(
tensor_y
,
0
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
if
pg
.
rank
()
==
0
:
assert
np
.
array_equal
(
tensor_x
,
sum_result
)
print
(
"test reduce sum api ok
\n
"
)
# test Scatter
# rank 0
in_shape
=
list
(
self
.
shape
)
in_shape
[
0
]
*=
2
x
=
np
.
random
.
random
(
in_shape
).
astype
(
self
.
dtype
)
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
tensor_y
=
paddle
.
to_tensor
(
y
)
if
pg
.
rank
()
==
0
:
task
=
pg
.
scatter
(
tensor_x
,
tensor_y
,
0
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
# rank 1
else
:
task
=
pg
.
scatter
(
tensor_x
,
tensor_y
,
0
)
task
.
wait
()
paddle
.
device
.
cuda
.
synchronize
()
out1
=
paddle
.
slice
(
tensor_x
,
[
0
],
[
0
],
[
self
.
shape
[
0
]])
out2
=
paddle
.
slice
(
tensor_x
,
[
0
],
[
self
.
shape
[
0
]],
[
self
.
shape
[
0
]
*
2
])
if
pg
.
rank
()
==
0
:
assert
np
.
array_equal
(
tensor_y
,
out1
)
else
:
assert
np
.
array_equal
(
tensor_y
,
out2
)
print
(
"test scatter api ok
\n
"
)
class
TestProcessGroupFp16
(
TestProcessGroupFp32
):
def
setUp
(
self
):
paddle
.
seed
(
2022
)
random
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
self
.
config
()
def
config
(
self
):
self
.
dtype
=
"float16"
self
.
shape
=
(
4
,
20
,
20
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/npu/test_collective_process_group_hccl.py
0 → 100644
浏览文件 @
73583f86
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
sys
sys
.
path
.
append
(
".."
)
from
test_parallel_dygraph_dataparallel
import
TestMultipleGpus
class
TestProcessGroup
(
TestMultipleGpus
):
def
test_process_group_nccl
(
self
):
self
.
run_mnist_2gpu
(
'process_group_hccl.py'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录