Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c16f85f9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c16f85f9
编写于
3月 03, 2022
作者:
L
lilong12
提交者:
GitHub
3月 03, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the implementation of Gloo for ProcessGroup (#39892)
* add pg_gloo
上级
272b32fd
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
701 addition
and
41 deletion
+701
-41
paddle/fluid/distributed/collective/CMakeLists.txt
paddle/fluid/distributed/collective/CMakeLists.txt
+3
-0
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
+308
-0
paddle/fluid/distributed/collective/ProcessGroupGloo.h
paddle/fluid/distributed/collective/ProcessGroupGloo.h
+138
-0
paddle/fluid/distributed/store/store.h
paddle/fluid/distributed/store/store.h
+2
-0
paddle/fluid/distributed/store/tcp_store.cc
paddle/fluid/distributed/store/tcp_store.cc
+53
-32
paddle/fluid/distributed/store/tcp_store.h
paddle/fluid/distributed/store/tcp_store.h
+8
-4
paddle/fluid/distributed/store/tcp_utils.cc
paddle/fluid/distributed/store/tcp_utils.cc
+2
-1
paddle/fluid/pybind/CMakeLists.txt
paddle/fluid/pybind/CMakeLists.txt
+3
-0
paddle/fluid/pybind/communication.cc
paddle/fluid/pybind/communication.cc
+9
-3
paddle/fluid/pybind/distributed_py.cc
paddle/fluid/pybind/distributed_py.cc
+53
-1
python/paddle/fluid/tests/unittests/process_group_gloo.py
python/paddle/fluid/tests/unittests/process_group_gloo.py
+119
-0
python/paddle/fluid/tests/unittests/test_collective_process_group.py
...le/fluid/tests/unittests/test_collective_process_group.py
+3
-0
未找到文件。
paddle/fluid/distributed/collective/CMakeLists.txt
浏览文件 @
c16f85f9
cc_library
(
processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api
)
cc_library
(
processgroup SRCS ProcessGroup.cc DEPS phi phi_api eager_api
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
processgroup_gloo SRCS ProcessGroupGloo.cc DEPS phi phi_api eager_api gloo_wrapper
)
endif
()
cc_library
(
eager_reducer SRCS reducer.cc DEPS eager_api processgroup
)
cc_library
(
eager_reducer SRCS reducer.cc DEPS eager_api processgroup
)
if
(
WITH_NCCL
)
if
(
WITH_NCCL
)
...
...
paddle/fluid/distributed/collective/ProcessGroupGloo.cc
0 → 100644
浏览文件 @
c16f85f9
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#ifdef _WIN32
#include <gloo/common/win.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <netdb.h>
#include <sys/socket.h>
#include <unistd.h>
#endif
#include <gloo/broadcast.h>
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
distributed
{
#ifdef _WIN32
#define GENERATE_FUNC(type, func, ...) \
switch (type) { \
case experimental::DataType::FLOAT32: \
func<float>(__VA_ARGS__); \
break; \
case experimental::DataType::FLOAT64: \
func<double>(__VA_ARGS__); \
break; \
case experimental::DataType::FLOAT16: \
func<gloo::float16>(__VA_ARGS__); \
break; \
case experimental::DataType::INT32: \
func<int32_t>(__VA_ARGS__); \
break; \
case experimental::DataType::INT64: \
func<int64_t>(__VA_ARGS__); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#define HOST_NAME_MAX 256
#else
#define GENERATE_FUNC(type, func, args...) \
switch (type) { \
case experimental::DataType::FLOAT32: \
func<float>(args); \
break; \
case experimental::DataType::FLOAT64: \
func<double>(args); \
break; \
case experimental::DataType::FLOAT16: \
func<gloo::float16>(args); \
break; \
case experimental::DataType::INT32: \
func<int32_t>(args); \
break; \
case experimental::DataType::INT64: \
func<int64_t>(args); \
break; \
default: \
VLOG(0) << "Error: Unknown DataType."; \
exit(-1); \
}
#endif
typedef
void
(
*
reduce_func
)(
void
*
,
const
void
*
,
const
void
*
,
size_t
);
template
<
typename
T
>
reduce_func
get_function
(
const
ReduceOp
&
r
)
{
switch
(
r
)
{
case
ReduceOp
::
SUM
:
return
reduce_func
(
&::
gloo
::
sum
<
T
>
);
case
ReduceOp
::
PRODUCT
:
return
reduce_func
(
&::
gloo
::
product
<
T
>
);
case
ReduceOp
::
MIN
:
return
reduce_func
(
&::
gloo
::
min
<
T
>
);
case
ReduceOp
::
MAX
:
return
reduce_func
(
&::
gloo
::
max
<
T
>
);
case
ReduceOp
::
AVG
:
VLOG
(
0
)
<<
"Error: Unsupported ReduceOp::AVG."
;
exit
(
-
1
);
}
VLOG
(
0
)
<<
"Error: Unknown ReduceOp."
;
exit
(
-
1
);
}
bool
CheckTensorsInCPUPlace
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
return
std
::
all_of
(
tensors
.
cbegin
(),
tensors
.
cend
(),
[
&
](
const
Tensor
&
t
)
{
return
t
.
place
()
==
PlaceType
::
kCPU
;
});
}
template
<
typename
T
>
T
*
get_data
(
const
Tensor
&
tensor
)
{
auto
raw_tensor
=
std
::
dynamic_pointer_cast
<
phi
::
DenseTensor
>
(
tensor
.
impl
());
return
static_cast
<
T
*>
(
raw_tensor
->
data
());
}
template
<
typename
T
>
std
::
vector
<
T
*>
get_multi_data
(
const
std
::
vector
<
Tensor
>&
tensors
)
{
std
::
vector
<
T
*>
ret
(
tensors
.
size
());
for
(
size_t
i
=
0
;
i
<
tensors
.
size
();
i
++
)
{
ret
[
i
]
=
get_data
<
T
>
(
tensors
[
i
]);
}
return
ret
;
}
template
<
typename
T
,
typename
P
>
void
set_output
(
P
&
opts
,
const
Tensor
&
tensor
)
{
// NOLINT
opts
.
setOutput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
}
template
<
typename
T
,
typename
P
>
void
set_input
(
P
&
opts
,
const
Tensor
&
tensor
)
{
// NOLINT
opts
.
setInput
(
get_data
<
T
>
(
tensor
),
tensor
.
numel
());
}
template
<
typename
T
,
typename
P
>
void
set_outputs
(
P
&
opts
,
const
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
opts
.
setOutputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
}
template
<
typename
T
,
typename
P
>
void
set_inputs
(
P
&
opts
,
const
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
opts
.
setInputs
(
get_multi_data
<
T
>
(
tensors
),
tensors
[
0
].
numel
());
}
ProcessGroupGloo
::
GlooTask
::
GlooTask
(
int
rank
,
const
std
::
vector
<
Tensor
>&
inputs
,
CommType
comm_type
)
:
ProcessGroup
::
Task
(
rank
,
inputs
,
comm_type
)
{
PADDLE_ENFORCE_EQ
(
CheckTensorsInCPUPlace
(
inputs
),
true
,
platform
::
errors
::
Fatal
(
"Only CPU place is supported for ProcessGroupGloo."
));
}
ProcessGroupGloo
::
ProcessGroupGloo
(
const
std
::
shared_ptr
<
GlooStore
>&
store
,
int
rank
,
int
world_size
,
const
std
::
shared_ptr
<
GlooOptions
>
options
)
:
ProcessGroup
(
rank
,
world_size
),
_tag
(
0
),
_store
(
store
)
{
_context
=
std
::
make_shared
<
gloo
::
rendezvous
::
Context
>
(
rank
,
world_size
);
auto
prefix_store
=
::
gloo
::
rendezvous
::
PrefixStore
(
std
::
to_string
(
0
),
*
_store
);
_context
->
connectFullMesh
(
prefix_store
,
options
->
device
);
}
class
BroadcastGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
BroadcastGlooTask
(
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
const
std
::
vector
<
Tensor
>&
inputs
,
int
rank
,
int
root
,
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
BROADCAST
),
_context
(
context
),
_root
(
root
),
_inputs
(
inputs
),
_tag
(
tag
)
{}
void
Run
()
override
{
_do_broadcast
(
_inputs
[
0
]);
}
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
const
int
_root
;
std
::
vector
<
Tensor
>
_inputs
{};
const
uint32_t
_tag
;
void
_do_broadcast
(
const
Tensor
&
tensor
)
{
gloo
::
BroadcastOptions
opts
(
_context
);
const
auto
&
dtype
=
tensor
.
type
();
GENERATE_FUNC
(
dtype
,
set_output
,
opts
,
tensor
);
opts
.
setRoot
(
_root
);
opts
.
setTag
(
_tag
);
gloo
::
broadcast
(
opts
);
}
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
Broadcast
(
std
::
vector
<
Tensor
>&
inputs
,
const
BroadcastOptions
&
opts
)
{
auto
root
=
opts
.
source_rank
;
std
::
unique_ptr
<
BroadcastGlooTask
>
task
;
auto
tag
=
next_tag
();
auto
context
=
get_context
();
task
=
std
::
make_unique
<
BroadcastGlooTask
>
(
context
,
inputs
,
rank_
,
root
,
tag
);
task
->
Run
();
return
task
;
}
class
AllreduceGlooTask
:
public
ProcessGroupGloo
::
GlooTask
{
public:
AllreduceGlooTask
(
int
rank
,
const
std
::
shared_ptr
<
gloo
::
Context
>&
context
,
std
::
vector
<
Tensor
>&
inputs
,
ReduceOp
reduce_op
,
// NOLINT
uint32_t
tag
)
:
ProcessGroupGloo
::
GlooTask
(
rank
,
inputs
,
CommType
::
ALLREDUCE
),
_context
(
context
),
_inputs
(
inputs
),
_reduce_op
(
reduce_op
),
_tag
(
tag
)
{}
void
Run
()
override
{
_do_allreduce
(
_inputs
);
}
private:
std
::
shared_ptr
<
gloo
::
Context
>
_context
;
std
::
vector
<
Tensor
>
_inputs
;
const
ReduceOp
_reduce_op
;
uint32_t
_tag
;
gloo
::
AllreduceOptions
::
Func
_get_function
(
const
experimental
::
DataType
type
,
const
ReduceOp
op
)
{
gloo
::
AllreduceOptions
::
Func
fn
;
GENERATE_FUNC
(
type
,
_get_function_impl
,
fn
,
op
);
return
fn
;
}
template
<
typename
T
>
void
_get_function_impl
(
gloo
::
AllreduceOptions
::
Func
&
fn
,
// NOLINT
const
ReduceOp
op
)
{
fn
=
get_function
<
T
>
(
op
);
}
void
_do_allreduce
(
std
::
vector
<
Tensor
>&
tensors
)
{
// NOLINT
const
auto
&
dtype
=
tensors
[
0
].
type
();
gloo
::
AllreduceOptions
opts
(
_context
);
GENERATE_FUNC
(
dtype
,
set_inputs
,
opts
,
tensors
);
GENERATE_FUNC
(
dtype
,
set_outputs
,
opts
,
tensors
);
opts
.
setReduceFunction
(
_get_function
(
dtype
,
_reduce_op
));
opts
.
setTag
(
_tag
);
gloo
::
allreduce
(
opts
);
}
};
std
::
shared_ptr
<
ProcessGroup
::
Task
>
ProcessGroupGloo
::
AllReduce
(
std
::
vector
<
Tensor
>&
inputs
,
const
AllreduceOptions
&
opts
)
{
auto
tag
=
next_tag
();
std
::
shared_ptr
<
GlooTask
>
task
;
auto
context
=
get_context
();
task
=
std
::
make_shared
<
AllreduceGlooTask
>
(
rank_
,
context
,
inputs
,
opts
.
reduce_op
,
tag
);
task
->
Run
();
return
task
;
}
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
ProcessGroupGloo
::
createDeviceForInterface
(
const
std
::
string
&
ifname
)
{
::
gloo
::
transport
::
tcp
::
attr
attr
;
attr
.
iface
=
ifname
;
return
::
gloo
::
transport
::
tcp
::
CreateDevice
(
attr
);
}
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
ProcessGroupGloo
::
createDeviceForHostname
(
const
std
::
string
&
hostname
)
{
::
gloo
::
transport
::
tcp
::
attr
attr
;
attr
.
hostname
=
hostname
;
return
::
gloo
::
transport
::
tcp
::
CreateDevice
(
attr
);
}
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
ProcessGroupGloo
::
createDefaultDevice
()
{
std
::
array
<
char
,
HOST_NAME_MAX
>
hostname
{};
auto
ret
=
::
gethostname
(
hostname
.
data
(),
HOST_NAME_MAX
);
PADDLE_ENFORCE_EQ
(
ret
,
0
,
platform
::
errors
::
Fatal
(
"Get hostname error for createDefaultDevice."
));
::
addrinfo
*
result
;
result
=
tcputils
::
get_addr_info
(
hostname
.
data
(),
""
,
0
,
AF_UNSPEC
);
::
addrinfo
*
cur
;
for
(
cur
=
result
;
cur
!=
nullptr
;
cur
=
cur
->
ai_next
)
{
SocketType
socket
=
::
socket
(
cur
->
ai_family
,
cur
->
ai_socktype
,
cur
->
ai_protocol
);
if
(
socket
==
-
1
)
{
continue
;
}
ret
=
::
bind
(
socket
,
cur
->
ai_addr
,
cur
->
ai_addrlen
);
#ifdef _WIN32
closesocket
(
socket
);
#else
close
(
socket
);
#endif
if
(
ret
==
-
1
)
{
continue
;
}
break
;
}
freeaddrinfo
(
result
);
if
(
cur
!=
nullptr
)
{
return
createDeviceForHostname
(
hostname
.
data
());
}
return
createDeviceForHostname
(
"127.0.0.1"
);
}
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/collective/ProcessGroupGloo.h
0 → 100644
浏览文件 @
c16f85f9
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <future>
#include <mutex>
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#ifdef PADDLE_WITH_GLOO
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/fluid/distributed/store/store.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
constexpr
const
char
*
GLOO_BACKEND_NAME
=
"GLOO"
;
namespace
paddle
{
namespace
distributed
{
class
ProcessGroupGloo
:
public
ProcessGroup
{
public:
class
GlooTask
:
public
ProcessGroup
::
Task
,
public
std
::
enable_shared_from_this
<
GlooTask
>
{
public:
explicit
GlooTask
(
int
rank
,
const
std
::
vector
<
Tensor
>&
input_tensors
,
CommType
comm_type
);
~
GlooTask
()
=
default
;
virtual
void
Run
()
=
0
;
bool
Wait
(
std
::
chrono
::
milliseconds
timeout
)
override
{
return
true
;
}
bool
IsCompleted
()
override
{
return
true
;
}
void
Synchronize
()
override
{}
protected:
friend
class
ProcessGroupGloo
;
};
class
GlooStore
:
public
::
gloo
::
rendezvous
::
Store
{
public:
explicit
GlooStore
(
const
std
::
shared_ptr
<
paddle
::
distributed
::
TCPStore
>&
store
)
:
_store
(
store
)
{}
~
GlooStore
()
=
default
;
std
::
vector
<
char
>
get
(
const
std
::
string
&
key
)
override
{
VLOG
(
3
)
<<
"GlooStore::get"
;
auto
value
=
_store
->
get
(
key
);
return
std
::
vector
<
char
>
(
value
.
begin
(),
value
.
end
());
}
void
wait
(
const
std
::
vector
<
std
::
string
>&
keys
)
override
{
VLOG
(
3
)
<<
"GlooStore::wait"
;
for
(
auto
&
key
:
keys
)
{
_store
->
wait
(
key
);
}
}
void
set
(
const
std
::
string
&
key
,
const
std
::
vector
<
char
>&
value
)
override
{
VLOG
(
3
)
<<
"GlooStore::set"
;
std
::
vector
<
uint8_t
>
tmp
(
value
.
begin
(),
value
.
end
());
_store
->
set
(
key
,
tmp
);
}
void
wait
(
const
std
::
vector
<
std
::
string
>&
keys
,
const
std
::
chrono
::
milliseconds
&
timeout
)
override
{
VLOG
(
3
)
<<
"GlooStore::wait"
;
for
(
auto
&
key
:
keys
)
{
_store
->
wait
(
key
);
}
// wait(keys);
}
protected:
std
::
shared_ptr
<
paddle
::
distributed
::
TCPStore
>
_store
;
};
class
GlooOptions
{
public:
GlooOptions
()
=
default
;
~
GlooOptions
()
=
default
;
static
std
::
shared_ptr
<
GlooOptions
>
create
()
{
return
std
::
make_shared
<
GlooOptions
>
();
}
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
device
;
};
explicit
ProcessGroupGloo
(
const
std
::
shared_ptr
<
GlooStore
>&
store
,
int
rank
,
int
world_size
,
std
::
shared_ptr
<
GlooOptions
>
options
);
~
ProcessGroupGloo
()
=
default
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
Broadcast
(
std
::
vector
<
Tensor
>&
inputs
,
const
BroadcastOptions
&
=
BroadcastOptions
())
override
;
std
::
shared_ptr
<
ProcessGroup
::
Task
>
AllReduce
(
std
::
vector
<
Tensor
>&
inputs
,
const
AllreduceOptions
&
opts
=
AllreduceOptions
())
override
;
std
::
shared_ptr
<::
gloo
::
Context
>
get_context
()
{
return
_context
;
}
uint64_t
next_tag
()
{
return
_tag
++
;
}
const
std
::
string
GetBackendName
()
const
override
{
return
GLOO_BACKEND_NAME
;
}
// Helper functions for Gloo.
static
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
createDeviceForHostname
(
const
std
::
string
&
hostname
);
static
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
createDeviceForInterface
(
const
std
::
string
&
ifname
);
static
std
::
shared_ptr
<::
gloo
::
transport
::
Device
>
createDefaultDevice
();
protected:
uint32_t
_tag
;
std
::
shared_ptr
<
gloo
::
rendezvous
::
Context
>
_context
;
std
::
shared_ptr
<
GlooStore
>
_store
;
};
}
// namespace distributed
}
// namespace paddle
paddle/fluid/distributed/store/store.h
浏览文件 @
c16f85f9
...
@@ -32,6 +32,8 @@ class Store {
...
@@ -32,6 +32,8 @@ class Store {
virtual
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
=
0
;
virtual
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
=
0
;
virtual
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
=
0
;
virtual
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
=
0
;
virtual
void
wait
(
const
std
::
string
&
key
)
=
0
;
virtual
void
wait
(
const
std
::
string
&
key
)
=
0
;
virtual
void
set
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>&
value
)
=
0
;
virtual
const
std
::
chrono
::
seconds
&
timeout
()
const
{
return
_timeout
;
}
virtual
const
std
::
chrono
::
seconds
&
timeout
()
const
{
return
_timeout
;
}
...
...
paddle/fluid/distributed/store/tcp_store.cc
浏览文件 @
c16f85f9
...
@@ -27,11 +27,13 @@ namespace detail {
...
@@ -27,11 +27,13 @@ namespace detail {
constexpr
int
INFTIME
=
-
1
;
constexpr
int
INFTIME
=
-
1
;
std
::
unique_ptr
<
MasterDaemon
>
MasterDaemon
::
start
(
SocketType
socket
)
{
std
::
unique_ptr
<
MasterDaemon
>
MasterDaemon
::
start
(
SocketType
socket
,
return
std
::
make_unique
<
MasterDaemon
>
(
socket
);
int
nranks
)
{
return
std
::
make_unique
<
MasterDaemon
>
(
socket
,
nranks
);
}
}
MasterDaemon
::
MasterDaemon
(
SocketType
socket
)
:
_listen_socket
(
socket
)
{
MasterDaemon
::
MasterDaemon
(
SocketType
socket
,
int
nranks
)
:
_listen_socket
(
socket
),
_nranks
(
nranks
)
{
_background_thread
=
std
::
thread
{
&
MasterDaemon
::
run
,
this
};
_background_thread
=
std
::
thread
{
&
MasterDaemon
::
run
,
this
};
}
}
...
@@ -64,6 +66,13 @@ void MasterDaemon::_do_add(SocketType socket) {
...
@@ -64,6 +66,13 @@ void MasterDaemon::_do_add(SocketType socket) {
tcputils
::
send_value
<
int64_t
>
(
socket
,
new_value
);
tcputils
::
send_value
<
int64_t
>
(
socket
,
new_value
);
}
}
void
MasterDaemon
::
_do_set
(
SocketType
socket
)
{
VLOG
(
3
)
<<
"MasterDaemon::_do_set"
;
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
auto
value
=
tcputils
::
receive_vector
<
uint8_t
>
(
socket
);
_store
[
key
]
=
value
;
}
void
MasterDaemon
::
_do_get
(
SocketType
socket
)
{
void
MasterDaemon
::
_do_get
(
SocketType
socket
)
{
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
std
::
string
key
=
tcputils
::
receive_string
(
socket
);
auto
iter
=
_store
.
find
(
key
);
auto
iter
=
_store
.
find
(
key
);
...
@@ -71,16 +80,15 @@ void MasterDaemon::_do_get(SocketType socket) {
...
@@ -71,16 +80,15 @@ void MasterDaemon::_do_get(SocketType socket) {
iter
,
_store
.
end
(),
iter
,
_store
.
end
(),
platform
::
errors
::
InvalidArgument
(
"Key %s not found in TCPStore."
,
key
));
platform
::
errors
::
InvalidArgument
(
"Key %s not found in TCPStore."
,
key
));
std
::
vector
<
uint8_t
>
value
=
iter
->
second
;
std
::
vector
<
uint8_t
>
value
=
iter
->
second
;
VLOG
(
3
)
<<
"TCPStore: value ("
<<
std
::
stoll
(
std
::
string
(
reinterpret_cast
<
char
*>
(
value
.
data
()),
value
.
size
()))
<<
") for key ("
<<
key
<<
")."
;
tcputils
::
send_vector
<
uint8_t
>
(
socket
,
value
);
tcputils
::
send_vector
<
uint8_t
>
(
socket
,
value
);
}
}
void
MasterDaemon
::
_do_stop
(
SocketType
socket
)
{
void
MasterDaemon
::
_do_stop
(
SocketType
socket
)
{
VLOG
(
3
)
<<
"MasterDaemon::_do_stop"
;
ReplyType
value
=
ReplyType
::
STOP_WAIT
;
ReplyType
value
=
ReplyType
::
STOP_WAIT
;
if
(
--
_nranks
==
0
)
{
_stop
=
true
;
_stop
=
true
;
}
tcputils
::
send_value
<
ReplyType
>
(
socket
,
value
);
tcputils
::
send_value
<
ReplyType
>
(
socket
,
value
);
}
}
...
@@ -140,21 +148,27 @@ void MasterDaemon::run() {
...
@@ -140,21 +148,27 @@ void MasterDaemon::run() {
case
Command
::
GET
:
case
Command
::
GET
:
_do_get
(
fds
[
i
].
fd
);
_do_get
(
fds
[
i
].
fd
);
break
;
break
;
case
Command
::
SET
:
_do_set
(
fds
[
i
].
fd
);
break
;
case
Command
::
WAIT
:
case
Command
::
WAIT
:
_do_wait
(
fds
[
i
].
fd
);
_do_wait
(
fds
[
i
].
fd
);
break
;
break
;
case
Command
::
STOP
:
case
Command
::
STOP
:
_do_stop
(
fds
[
i
].
fd
);
_do_stop
(
fds
[
i
].
fd
);
break
;
break
;
default:
VLOG
(
0
)
<<
"Unknow command: "
<<
static_cast
<
int
>
(
command
);
exit
(
-
1
);
}
}
}
}
}
}
}
}
std
::
unique_ptr
<
TCPServer
>
TCPServer
::
create
(
uint16_t
port
)
{
std
::
unique_ptr
<
TCPServer
>
TCPServer
::
create
(
uint16_t
port
,
int
nranks
)
{
int
socket
=
tcputils
::
tcp_listen
(
""
,
std
::
to_string
(
port
),
AF_INET
);
int
socket
=
tcputils
::
tcp_listen
(
""
,
std
::
to_string
(
port
),
AF_INET
);
auto
server
=
std
::
make_unique
<
TCPServer
>
();
auto
server
=
std
::
make_unique
<
TCPServer
>
();
server
->
_master_daemon
=
MasterDaemon
::
start
(
socket
);
server
->
_master_daemon
=
MasterDaemon
::
start
(
socket
,
nranks
);
return
server
;
return
server
;
}
}
...
@@ -200,7 +214,7 @@ TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
...
@@ -200,7 +214,7 @@ TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
size_t
num_workers
,
std
::
chrono
::
seconds
timeout
)
size_t
num_workers
,
std
::
chrono
::
seconds
timeout
)
:
Store
(
timeout
),
_is_master
(
is_master
),
_num_workers
(
num_workers
)
{
:
Store
(
timeout
),
_is_master
(
is_master
),
_num_workers
(
num_workers
)
{
if
(
_is_master
)
{
if
(
_is_master
)
{
_server
=
detail
::
TCPServer
::
create
(
port
);
_server
=
detail
::
TCPServer
::
create
(
port
,
num_workers
);
}
}
_client
=
detail
::
TCPClient
::
connect
(
host
,
port
);
_client
=
detail
::
TCPClient
::
connect
(
host
,
port
);
...
@@ -213,7 +227,6 @@ void TCPStore::waitWorkers() {
...
@@ -213,7 +227,6 @@ void TCPStore::waitWorkers() {
}
}
add
(
_init_key
,
1
);
add
(
_init_key
,
1
);
if
(
_server
)
{
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
auto
begin
=
std
::
chrono
::
steady_clock
::
now
();
do
{
do
{
auto
value
=
get
(
_init_key
);
auto
value
=
get
(
_init_key
);
...
@@ -233,16 +246,22 @@ void TCPStore::waitWorkers() {
...
@@ -233,16 +246,22 @@ void TCPStore::waitWorkers() {
"TCPStore timeouted and not all workers got ready."
));
"TCPStore timeouted and not all workers got ready."
));
}
}
}
while
(
true
);
}
while
(
true
);
}
VLOG
(
3
)
<<
"TCPStore initialized."
;
VLOG
(
3
)
<<
"TCPStore initialized."
;
}
}
int64_t
TCPStore
::
add
(
const
std
::
string
&
key
,
int64_t
value
)
{
int64_t
TCPStore
::
add
(
const
std
::
string
&
key
,
int64_t
value
)
{
VLOG
(
3
)
<<
"TCPStore add."
;
_client
->
send_command_for_key
(
Command
::
ADD
,
_key_prefix
+
key
);
_client
->
send_command_for_key
(
Command
::
ADD
,
_key_prefix
+
key
);
_client
->
send_value
<
std
::
int64_t
>
(
value
);
_client
->
send_value
<
std
::
int64_t
>
(
value
);
return
_client
->
receive_value
<
std
::
int64_t
>
();
return
_client
->
receive_value
<
std
::
int64_t
>
();
}
}
void
TCPStore
::
set
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>&
value
)
{
VLOG
(
3
)
<<
"TCPStore set."
;
_client
->
send_command_for_key
(
Command
::
SET
,
_key_prefix
+
key
);
_client
->
send_vector
<
std
::
uint8_t
>
(
value
);
}
std
::
vector
<
uint8_t
>
TCPStore
::
get
(
const
std
::
string
&
key
)
{
std
::
vector
<
uint8_t
>
TCPStore
::
get
(
const
std
::
string
&
key
)
{
wait
(
key
);
wait
(
key
);
_client
->
send_command_for_key
(
Command
::
GET
,
_key_prefix
+
key
);
_client
->
send_command_for_key
(
Command
::
GET
,
_key_prefix
+
key
);
...
@@ -252,6 +271,7 @@ std::vector<uint8_t> TCPStore::get(const std::string& key) {
...
@@ -252,6 +271,7 @@ std::vector<uint8_t> TCPStore::get(const std::string& key) {
void
TCPStore
::
wait
(
const
std
::
string
&
key
)
{
void
TCPStore
::
wait
(
const
std
::
string
&
key
)
{
ReplyType
reply
;
ReplyType
reply
;
VLOG
(
3
)
<<
"TCPStore wait."
;
do
{
do
{
_client
->
send_command_for_key
(
Command
::
WAIT
,
_key_prefix
+
key
);
_client
->
send_command_for_key
(
Command
::
WAIT
,
_key_prefix
+
key
);
...
@@ -262,6 +282,7 @@ void TCPStore::wait(const std::string& key) {
...
@@ -262,6 +282,7 @@ void TCPStore::wait(const std::string& key) {
TCPStore
::~
TCPStore
()
{
TCPStore
::~
TCPStore
()
{
_client
->
send_command_for_key
(
Command
::
STOP
,
""
);
_client
->
send_command_for_key
(
Command
::
STOP
,
""
);
VLOG
(
3
)
<<
"~TCPStore"
;
ReplyType
ret
=
_client
->
receive_value
<
ReplyType
>
();
ReplyType
ret
=
_client
->
receive_value
<
ReplyType
>
();
PADDLE_ENFORCE_EQ
(
ret
,
ReplyType
::
STOP_WAIT
,
PADDLE_ENFORCE_EQ
(
ret
,
ReplyType
::
STOP_WAIT
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/distributed/store/tcp_store.h
浏览文件 @
c16f85f9
...
@@ -27,15 +27,16 @@ namespace paddle {
...
@@ -27,15 +27,16 @@ namespace paddle {
namespace
distributed
{
namespace
distributed
{
enum
class
ReplyType
{
WAITING
,
STOP_WAIT
};
enum
class
ReplyType
{
WAITING
,
STOP_WAIT
};
enum
class
Command
{
ADD
,
GET
,
WAIT
,
STOP
};
enum
class
Command
{
ADD
,
GET
,
SET
,
WAIT
,
STOP
};
namespace
detail
{
namespace
detail
{
class
MasterDaemon
{
class
MasterDaemon
{
public:
public:
static
std
::
unique_ptr
<
MasterDaemon
>
start
(
SocketType
listen_socket
);
static
std
::
unique_ptr
<
MasterDaemon
>
start
(
SocketType
listen_socket
,
int
nranks
);
MasterDaemon
()
=
delete
;
MasterDaemon
()
=
delete
;
explicit
MasterDaemon
(
SocketType
listen_socket
);
explicit
MasterDaemon
(
SocketType
listen_socket
,
int
nranks
);
~
MasterDaemon
();
~
MasterDaemon
();
private:
private:
...
@@ -43,18 +44,20 @@ class MasterDaemon {
...
@@ -43,18 +44,20 @@ class MasterDaemon {
void
_do_add
(
SocketType
socket
);
void
_do_add
(
SocketType
socket
);
void
_do_wait
(
SocketType
socket
);
void
_do_wait
(
SocketType
socket
);
void
_do_get
(
SocketType
socket
);
void
_do_get
(
SocketType
socket
);
void
_do_set
(
SocketType
socket
);
void
_do_stop
(
SocketType
socket
);
void
_do_stop
(
SocketType
socket
);
SocketType
_listen_socket
;
SocketType
_listen_socket
;
std
::
vector
<
SocketType
>
_sockets
;
std
::
vector
<
SocketType
>
_sockets
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
uint8_t
>>
_store
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
uint8_t
>>
_store
;
std
::
thread
_background_thread
{};
std
::
thread
_background_thread
{};
int
_nranks
;
bool
_stop
=
false
;
bool
_stop
=
false
;
};
};
class
TCPServer
{
class
TCPServer
{
public:
public:
TCPServer
()
=
default
;
TCPServer
()
=
default
;
static
std
::
unique_ptr
<
TCPServer
>
create
(
std
::
uint16_t
port
);
static
std
::
unique_ptr
<
TCPServer
>
create
(
std
::
uint16_t
port
,
int
nranks
);
private:
private:
std
::
unique_ptr
<
MasterDaemon
>
_master_daemon
;
std
::
unique_ptr
<
MasterDaemon
>
_master_daemon
;
...
@@ -97,6 +100,7 @@ class TCPStore : public Store {
...
@@ -97,6 +100,7 @@ class TCPStore : public Store {
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
override
;
int64_t
add
(
const
std
::
string
&
key
,
int64_t
value
)
override
;
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
override
;
std
::
vector
<
uint8_t
>
get
(
const
std
::
string
&
key
)
override
;
void
wait
(
const
std
::
string
&
key
)
override
;
void
wait
(
const
std
::
string
&
key
)
override
;
void
set
(
const
std
::
string
&
key
,
const
std
::
vector
<
uint8_t
>&
value
)
override
;
private:
private:
void
waitWorkers
();
void
waitWorkers
();
...
...
paddle/fluid/distributed/store/tcp_utils.cc
浏览文件 @
c16f85f9
...
@@ -46,9 +46,10 @@ void close_socket(SocketType socket) {
...
@@ -46,9 +46,10 @@ void close_socket(SocketType socket) {
hints
.
ai_socktype
=
SOCK_STREAM
;
hints
.
ai_socktype
=
SOCK_STREAM
;
const
char
*
node
=
host
.
empty
()
?
nullptr
:
host
.
c_str
();
const
char
*
node
=
host
.
empty
()
?
nullptr
:
host
.
c_str
();
const
char
*
port_cstr
=
port
.
empty
()
?
nullptr
:
port
.
c_str
();
int
n
;
int
n
;
n
=
::
getaddrinfo
(
node
,
port
.
c_str
()
,
&
hints
,
&
res
);
n
=
::
getaddrinfo
(
node
,
port
_cstr
,
&
hints
,
&
res
);
const
char
*
gai_err
=
::
gai_strerror
(
n
);
const
char
*
gai_err
=
::
gai_strerror
(
n
);
const
char
*
proto
=
const
char
*
proto
=
(
family
==
AF_INET
?
"IPv4"
:
family
==
AF_INET6
?
"IPv6"
:
""
);
(
family
==
AF_INET
?
"IPv4"
:
family
==
AF_INET6
?
"IPv6"
:
""
);
...
...
paddle/fluid/pybind/CMakeLists.txt
浏览文件 @
c16f85f9
...
@@ -85,6 +85,9 @@ if(NOT ON_INFER)
...
@@ -85,6 +85,9 @@ if(NOT ON_INFER)
if
(
WITH_NCCL
)
if
(
WITH_NCCL
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_nccl
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_nccl
)
endif
()
endif
()
if
(
WITH_GLOO
)
set
(
PYBIND_DEPS
${
PYBIND_DEPS
}
processgroup_gloo
)
endif
()
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
distributed_py.cc
)
set
(
PYBIND_SRCS
${
PYBIND_SRCS
}
distributed_py.cc
)
endif
()
endif
()
...
...
paddle/fluid/pybind/communication.cc
浏览文件 @
c16f85f9
...
@@ -31,9 +31,15 @@ namespace pybind {
...
@@ -31,9 +31,15 @@ namespace pybind {
using
TCPStore
=
paddle
::
distributed
::
TCPStore
;
using
TCPStore
=
paddle
::
distributed
::
TCPStore
;
void
BindTCPStore
(
py
::
module
*
m
)
{
void
BindTCPStore
(
py
::
module
*
m
)
{
py
::
class_
<
TCPStore
>
(
*
m
,
"TCPStore"
)
py
::
class_
<
TCPStore
,
std
::
shared_ptr
<
TCPStore
>>
(
*
m
,
"TCPStore"
)
.
def
(
.
def
(
py
::
init
([](
std
::
string
hostname
,
uint16_t
port
,
bool
is_master
,
py
::
init
<
std
::
string
,
uint16_t
,
bool
,
size_t
,
std
::
chrono
::
seconds
>
())
size_t
world_size
,
std
::
chrono
::
seconds
timeout
)
{
return
std
::
make_shared
<
TCPStore
>
(
hostname
,
port
,
is_master
,
world_size
,
timeout
);
}),
py
::
arg
(
"hostname"
),
py
::
arg
(
"port"
),
py
::
arg
(
"is_master"
),
py
::
arg
(
"world_size"
),
py
::
arg
(
"timeout"
),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"add"
,
&
TCPStore
::
add
)
.
def
(
"add"
,
&
TCPStore
::
add
)
.
def
(
"get"
,
&
TCPStore
::
get
);
.
def
(
"get"
,
&
TCPStore
::
get
);
}
}
...
...
paddle/fluid/pybind/distributed_py.cc
浏览文件 @
c16f85f9
...
@@ -35,6 +35,11 @@ limitations under the License. */
...
@@ -35,6 +35,11 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#endif
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/distributed/collective/ProcessGroupGloo.h"
#include "paddle/fluid/distributed/store/tcp_store.h"
#endif
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
paddle
{
...
@@ -42,6 +47,14 @@ namespace pybind {
...
@@ -42,6 +47,14 @@ namespace pybind {
using
Tensor
=
paddle
::
experimental
::
Tensor
;
using
Tensor
=
paddle
::
experimental
::
Tensor
;
#if defined(PADDLE_WITH_GLOO)
using
ProcessGroupGloo
=
paddle
::
distributed
::
ProcessGroupGloo
;
using
GlooStore
=
paddle
::
distributed
::
ProcessGroupGloo
::
GlooStore
;
using
GlooOptions
=
paddle
::
distributed
::
ProcessGroupGloo
::
GlooOptions
;
#endif
static
std
::
string
GLOO_SOCKET_IFNAME_ENV
=
"GLOO_SOCKET_IFNAME"
;
// NOLINT
void
BindDistributed
(
py
::
module
*
m
)
{
void
BindDistributed
(
py
::
module
*
m
)
{
py
::
enum_
<
distributed
::
ReduceOp
>
(
*
m
,
"ReduceOp"
)
py
::
enum_
<
distributed
::
ReduceOp
>
(
*
m
,
"ReduceOp"
)
.
value
(
"SUM"
,
distributed
::
ReduceOp
::
SUM
)
.
value
(
"SUM"
,
distributed
::
ReduceOp
::
SUM
)
...
@@ -129,6 +142,7 @@ void BindDistributed(py::module *m) {
...
@@ -129,6 +142,7 @@ void BindDistributed(py::module *m) {
*
m
,
"ProcessGroupNCCL"
,
ProcessGroup
)
*
m
,
"ProcessGroupNCCL"
,
ProcessGroup
)
.
def
(
py
::
init
<
const
distributed
::
ProcessGroupStrategy
&
,
int
,
int
>
(),
.
def
(
py
::
init
<
const
distributed
::
ProcessGroupStrategy
&
,
int
,
int
>
(),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#endif
py
::
class_
<
distributed
::
ProcessGroup
::
Task
,
py
::
class_
<
distributed
::
ProcessGroup
::
Task
,
std
::
shared_ptr
<
distributed
::
ProcessGroup
::
Task
>>
(
*
m
,
"task"
)
std
::
shared_ptr
<
distributed
::
ProcessGroup
::
Task
>>
(
*
m
,
"task"
)
...
@@ -138,7 +152,6 @@ void BindDistributed(py::module *m) {
...
@@ -138,7 +152,6 @@ void BindDistributed(py::module *m) {
py
::
call_guard
<
py
::
gil_scoped_release
>
())
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"synchronize"
,
&
distributed
::
ProcessGroup
::
Task
::
Synchronize
,
.
def
(
"synchronize"
,
&
distributed
::
ProcessGroup
::
Task
::
Synchronize
,
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
call_guard
<
py
::
gil_scoped_release
>
());
#endif
// define parallel strategy, it will be removed
// define parallel strategy, it will be removed
py
::
class_
<
distributed
::
ProcessGroupStrategy
>
pg_strategy
(
py
::
class_
<
distributed
::
ProcessGroupStrategy
>
pg_strategy
(
...
@@ -178,6 +191,45 @@ void BindDistributed(py::module *m) {
...
@@ -178,6 +191,45 @@ void BindDistributed(py::module *m) {
self
.
nrings_
=
nrings
;
self
.
nrings_
=
nrings
;
});
});
#if defined(PADDLE_WITH_GLOO)
py
::
class_
<
GlooOptions
>
(
*
m
,
"GlooOptions"
)
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"_device"
,
&
GlooOptions
::
device
)
.
def_static
(
"create"
,
&
GlooOptions
::
create
);
py
::
class_
<
GlooStore
,
std
::
shared_ptr
<
GlooStore
>>
(
*
m
,
"GlooStore"
)
.
def
(
py
::
init
(
[](
const
std
::
shared_ptr
<
paddle
::
distributed
::
TCPStore
>
&
store
)
{
return
std
::
make_shared
<
GlooStore
>
(
store
);
}),
py
::
call_guard
<
py
::
gil_scoped_release
>
());
py
::
class_
<
ProcessGroupGloo
,
std
::
shared_ptr
<
ProcessGroupGloo
>>
(
*
m
,
"ProcessGroupGloo"
,
ProcessGroup
)
.
def
(
py
::
init
<
const
std
::
shared_ptr
<
GlooStore
>
&
,
int
,
int
,
std
::
shared_ptr
<
GlooOptions
>
&>
(),
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
py
::
init
([](
const
std
::
shared_ptr
<
GlooStore
>
&
store
,
int
rank
,
int
world_size
)
{
auto
opts
=
GlooOptions
::
create
();
char
*
ifname
=
getenv
(
GLOO_SOCKET_IFNAME_ENV
.
c_str
());
if
(
ifname
&&
strlen
(
ifname
)
>
1
)
{
opts
->
device
=
ProcessGroupGloo
::
createDeviceForInterface
(
std
::
string
(
ifname
));
}
else
{
opts
->
device
=
ProcessGroupGloo
::
createDefaultDevice
();
}
return
std
::
make_shared
<
ProcessGroupGloo
>
(
store
,
rank
,
world_size
,
opts
);
}),
py
::
arg
(
"store"
),
py
::
arg
(
"rank"
),
py
::
arg
(
"world_size"
),
// py::arg("timeout") =
// kProcessGroupDefaultTimeout,
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def_static
(
"create_default_device"
,
&
ProcessGroupGloo
::
createDefaultDevice
);
#endif
m
->
def
(
"eager_assign_group_by_size"
,
m
->
def
(
"eager_assign_group_by_size"
,
[](
py
::
handle
py_tensors
,
std
::
vector
<
bool
>
is_sparse_gradient
,
[](
py
::
handle
py_tensors
,
std
::
vector
<
bool
>
is_sparse_gradient
,
std
::
vector
<
size_t
>
group_size_limits
,
std
::
vector
<
size_t
>
group_size_limits
,
...
...
python/paddle/fluid/tests/unittests/process_group_gloo.py
0 → 100644
浏览文件 @
c16f85f9
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
random
import
numpy
as
np
import
os
import
shutil
import
paddle
from
paddle.fluid
import
core
import
datetime
from
datetime
import
timedelta
import
paddle.fluid.core
as
core
from
paddle.fluid.framework
import
_test_eager_guard
from
paddle.fluid.dygraph.parallel
import
ParallelEnv
class
TestProcessGroupFp32
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
seed
(
2022
)
random
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
self
.
config
()
def
config
(
self
):
self
.
dtype
=
"float32"
self
.
shape
=
(
2
,
10
,
5
)
def
test_create_process_group_gloo
(
self
):
with
_test_eager_guard
():
nranks
=
ParallelEnv
().
nranks
rank
=
ParallelEnv
().
local_rank
is_master
=
True
if
rank
==
0
else
False
store
=
paddle
.
fluid
.
core
.
TCPStore
(
"127.0.0.1"
,
6172
,
is_master
,
nranks
,
datetime
.
timedelta
(
0
))
gloo_store
=
paddle
.
fluid
.
core
.
GlooStore
(
store
)
opt
=
paddle
.
fluid
.
core
.
GlooOptions
()
pg
=
paddle
.
fluid
.
core
.
ProcessGroupGloo
(
gloo_store
,
rank
,
nranks
)
# test allreduce sum
# rank 0
paddle
.
device
.
set_device
(
'cpu'
)
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
sum_result
=
x
+
y
if
rank
==
0
:
task
=
pg
.
allreduce
(
tensor_x
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
sum_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
sum_result
)
print
(
"test allreduce sum api ok"
)
# test allreduce max
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
max_result
=
paddle
.
maximum
(
tensor_x
,
tensor_y
)
if
rank
==
0
:
task
=
pg
.
allreduce
(
tensor_x
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_x
,
max_result
)
else
:
task
=
pg
.
allreduce
(
tensor_y
,
core
.
ReduceOp
.
MAX
)
task
.
wait
()
assert
np
.
array_equal
(
tensor_y
,
max_result
)
print
(
"test allreduce max api ok"
)
# test broadcast
# rank 0
x
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_x
=
paddle
.
to_tensor
(
x
)
# rank 1
y
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
tensor_y
=
paddle
.
to_tensor
(
y
)
broadcast_result
=
paddle
.
assign
(
tensor_x
)
if
rank
==
0
:
task
=
pg
.
broadcast
(
tensor_x
,
0
)
task
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_x
)
else
:
task
=
pg
.
broadcast
(
tensor_y
,
0
)
task
.
synchronize
()
assert
task
.
is_completed
()
assert
np
.
array_equal
(
broadcast_result
,
tensor_y
)
print
(
"test broadcast api ok"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_collective_process_group.py
浏览文件 @
c16f85f9
...
@@ -22,6 +22,9 @@ class TestProcessGroup(TestMultipleGpus):
...
@@ -22,6 +22,9 @@ class TestProcessGroup(TestMultipleGpus):
def
test_process_group_nccl
(
self
):
def
test_process_group_nccl
(
self
):
self
.
run_mnist_2gpu
(
'process_group_nccl.py'
)
self
.
run_mnist_2gpu
(
'process_group_nccl.py'
)
def
test_process_group_gloo
(
self
):
self
.
run_mnist_2gpu
(
'process_group_gloo.py'
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录