Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
56b7ebbc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56b7ebbc
编写于
8月 03, 2021
作者:
W
WangXi
提交者:
GitHub
8月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] remove the using of global ring in hybrid parallel (#34525)
上级
9b6c7eb9
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
329 addition
and
247 deletion
+329
-247
paddle/fluid/imperative/bkcl_context.cc
paddle/fluid/imperative/bkcl_context.cc
+2
-2
paddle/fluid/imperative/nccl_context.cc
paddle/fluid/imperative/nccl_context.cc
+2
-2
paddle/fluid/operators/collective/c_comm_init_op.cc
paddle/fluid/operators/collective/c_comm_init_op.cc
+40
-44
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
+5
-4
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
+5
-4
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+5
-4
paddle/fluid/platform/collective_helper.cc
paddle/fluid/platform/collective_helper.cc
+4
-4
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+4
-4
paddle/fluid/platform/gen_comm_id_helper.cc
paddle/fluid/platform/gen_comm_id_helper.cc
+55
-21
paddle/fluid/platform/gen_comm_id_helper.h
paddle/fluid/platform/gen_comm_id_helper.h
+3
-3
python/paddle/distributed/fleet/meta_optimizers/common.py
python/paddle/distributed/fleet/meta_optimizers/common.py
+9
-15
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
...distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+46
-25
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
...ed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+24
-45
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+31
-52
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
...uid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+94
-18
未找到文件。
paddle/fluid/imperative/bkcl_context.cc
浏览文件 @
56b7ebbc
...
...
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" ring id: "
<<
ring_id
;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform
::
BKCLCommContext
::
Instance
().
Create
BKCL
Comm
(
platform
::
BKCLCommContext
::
Instance
().
CreateComm
(
&
bkcl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
ring_id
);
}
...
...
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" ring id: "
<<
ring_id
;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform
::
BKCLCommContext
::
Instance
().
Create
BKCL
Comm
(
platform
::
BKCLCommContext
::
Instance
().
CreateComm
(
&
bkcl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
ring_id
);
}
...
...
paddle/fluid/imperative/nccl_context.cc
浏览文件 @
56b7ebbc
...
...
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" ring id: "
<<
ring_id
;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform
::
NCCLCommContext
::
Instance
().
Create
NCCL
Comm
(
platform
::
NCCLCommContext
::
Instance
().
CreateComm
(
&
nccl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
ring_id
);
...
...
@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" ring id: "
<<
ring_id
;
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform
::
NCCLCommContext
::
Instance
().
Create
NCCL
Comm
(
platform
::
NCCLCommContext
::
Instance
().
CreateComm
(
&
nccl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
ring_id
);
compute_events_
.
emplace_back
(
platform
::
CudaEventResourcePool
::
Instance
().
New
(
...
...
paddle/fluid/operators/collective/c_comm_init_op.cc
浏览文件 @
56b7ebbc
...
...
@@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
operators
{
...
...
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
// TODO(wangxi): Put this in the unified header file
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
using
UniqueId
=
ncclUniqueId
;
using
Place
=
platform
::
CUDAPlace
;
using
CommContext
=
platform
::
NCCLCommContext
;
#elif defined(PADDLE_WITH_XPU_BKCL)
using
UniqueId
=
BKCLUniqueId
;
using
Place
=
platform
::
XPUPlace
;
using
CommContext
=
platform
::
BKCLCommContext
;
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with GPU or XPU."
));
#endif
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
)
||
is_xpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"CCommInitOp can run on gpu or xpu place only."
));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
auto
var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
InvalidArgument
(
"Input con not be empty."
));
if
(
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclUniqueId
*
nccl_id
=
var
->
GetMutable
<
ncclUniqueId
>
();
int
nranks
=
Attr
<
int
>
(
"nranks"
);
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
device
;
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
device_id
=
Attr
<
int
>
(
"device_id"
);
}
platform
::
NCCLCommContext
::
Instance
().
CreateNCCLComm
(
nccl_id
,
nranks
,
rank_id
,
device_id
,
rid
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with GPU."
));
#endif
}
else
if
(
is_xpu_place
(
place
))
{
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId
*
bkcl_id
=
var
->
GetMutable
<
BKCLUniqueId
>
();
UniqueId
*
comm_id
=
var
->
GetMutable
<
UniqueId
>
();
int
nranks
=
Attr
<
int
>
(
"nranks"
);
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ
(
rid
,
0
,
platform
::
errors
::
OutOfRange
(
"Ring id must equal 0 in multi Kunlun cards training, but got %d"
,
rid
));
int
device_id
=
BOOST_GET_CONST
(
platform
::
XPUPlace
,
place
).
device
;
#endif
int
device_id
=
BOOST_GET_CONST
(
Place
,
place
).
device
;
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
device_id
=
Attr
<
int
>
(
"device_id"
);
}
platform
::
BKCLCommContext
::
Instance
().
CreateBKCLComm
(
bkcl_id
,
nranks
,
rank_id
,
device_id
,
rid
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with XPU."
));
CommContext
::
Instance
().
CreateComm
(
comm_id
,
nranks
,
rank_id
,
device_id
,
rid
);
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"CCommInitOp can run on gpu or xpu place only."
));
}
}
};
...
...
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
浏览文件 @
56b7ebbc
...
...
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
...
...
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
GenBKCLID
(
&
bkcl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
bkcl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
bkcl_ids
,
ring_id
);
}
else
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
platform
::
RecvBroadCastCommID
(
endpoint
,
&
bkcl_ids
);
platform
::
RecvBroadCastCommID
(
endpoint
,
&
bkcl_ids
,
ring_id
);
}
CopyBKCLIDToVar
(
bkcl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
};
...
...
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
};
...
...
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
浏览文件 @
56b7ebbc
...
...
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
...
...
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
GenHCCLID
(
&
hccl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
hccl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
hccl_ids
,
ring_id
);
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
hccl_ids
);
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
hccl_ids
,
ring_id
);
}
CopyHCCLIDToVar
(
hccl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
};
...
...
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
};
...
...
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
浏览文件 @
56b7ebbc
...
...
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
...
...
@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
GenNCCLID
(
&
nccl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
nccl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
nccl_ids
,
ring_id
);
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
nccl_ids
);
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
nccl_ids
,
ring_id
);
}
CopyNCCLIDToVar
(
nccl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
};
...
...
@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
};
...
...
paddle/fluid/platform/collective_helper.cc
浏览文件 @
56b7ebbc
...
...
@@ -72,7 +72,7 @@ class NCCLCommImpl : public NCCLComm {
std
::
shared_ptr
<
platform
::
CudaEventObject
>
comm_event_
;
};
NCCLComm
*
NCCLCommContext
::
Create
NCCL
Comm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
NCCLComm
*
NCCLCommContext
::
CreateComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
nccl_id
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -225,7 +225,7 @@ class BKCLCommImpl : public BKCLComm {
std
::
unique_ptr
<
XPUDeviceContext
>
dev_ctx_
;
};
BKCLComm
*
BKCLCommContext
::
Create
BKCL
Comm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
BKCLComm
*
BKCLCommContext
::
CreateComm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
bkcl_id
,
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/platform/collective_helper.h
浏览文件 @
56b7ebbc
...
...
@@ -72,8 +72,8 @@ class NCCLCommContext {
return
comm_ctx
;
}
NCCLComm
*
Create
NCCLComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
NCCLComm
*
Create
Comm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
void
CreateAllNCCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
...
...
@@ -274,8 +274,8 @@ class BKCLCommContext {
return
comm_ctx
;
}
BKCLComm
*
Create
BKCLComm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
BKCLComm
*
Create
Comm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
=
0
);
void
CreateAllBKCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
...
...
paddle/fluid/platform/gen_comm_id_helper.cc
浏览文件 @
56b7ebbc
...
...
@@ -42,7 +42,10 @@ namespace platform {
std
::
once_flag
SocketServer
::
init_flag_
;
constexpr
char
COMM_HEAD
[]
=
"_pd_gen_comm_id_"
;
struct
CommHead
{
int
version
=
1
;
// unused for now
int
ring_id
=
0
;
};
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
...
...
@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
void
CloseSocket
(
int
fd
)
{
CHECK_SYS_CALL
(
close
(
fd
),
"close"
);
}
static
int
SocketAccept
(
int
server_fd
,
const
char
*
head
)
{
static
int
SocketAccept
(
int
server_fd
,
const
CommHead
head
)
{
static_assert
(
sizeof
(
CommHead
)
<=
1024
,
"sizeof(CommHead) must <= buffer size"
);
struct
sockaddr_in
client_addr
;
socklen_t
addr_length
=
sizeof
(
client_addr
);
char
buffer
[
1024
]
=
{
0
};
int
conn
=
-
1
;
const
char
*
phead
=
reinterpret_cast
<
const
char
*>
(
&
head
);
while
(
true
)
{
CHECK_SYS_CALL_VAL
(
...
...
@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
&
addr_length
),
"accept"
,
conn
);
int
ret_val
=
SocketRecv
(
conn
,
buffer
,
strlen
(
head
));
if
(
ret_val
>
0
&&
strncmp
(
buffer
,
head
,
strlen
(
head
))
==
0
)
{
int
ret_val
=
SocketRecv
(
conn
,
buffer
,
sizeof
(
head
));
if
(
ret_val
>
0
&&
memcmp
(
buffer
,
phead
,
sizeof
(
head
))
==
0
)
{
// send a message to the sender, indicating that the link is correct
CHECK_SYS_CALL
(
SocketSend
(
conn
,
phead
,
sizeof
(
head
)),
"send"
);
break
;
// accept client
}
else
{
VLOG
(
3
)
<<
"socket read failed with ret_val="
<<
ret_val
;
...
...
@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
return
conn
;
}
static
int
ConnectAddr
(
const
std
::
string
&
ep
,
const
char
*
head
)
{
static
int
ConnectAddr
(
const
std
::
string
&
ep
,
const
CommHead
head
)
{
auto
addr
=
paddle
::
string
::
Split
(
ep
,
':'
);
PADDLE_ENFORCE_EQ
(
addr
.
size
(),
2UL
,
...
...
@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
std
::
string
host
=
addr
[
0
];
int
port
=
std
::
stoi
(
addr
[
1
]);
int
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
struct
sockaddr_in
server_addr
;
memset
(
&
server_addr
,
0
,
sizeof
(
server_addr
));
server_addr
.
sin_family
=
AF_INET
;
...
...
@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
platform
::
errors
::
Unavailable
(
"Open address %s failed: %s"
,
ep
,
strerror
(
errno
)));
static_assert
(
sizeof
(
CommHead
)
<=
1024
,
"sizeof(CommHead) must <= buffer size"
);
char
buffer
[
1024
]
=
{
0
};
const
char
*
phead
=
reinterpret_cast
<
const
char
*>
(
&
head
);
// TODO(wangxi) Set from env, default 900s=15min
int
timeout
=
900
*
1000
;
int
try_times
=
0
;
int
total_time
=
0
;
int
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
while
(
true
)
{
int
ret_val
=
-
1
;
RETRY_SYS_CALL_VAL
(
...
...
@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
continue
;
}
CHECK_SYS_CALL
(
SocketSend
(
sock
,
head
,
strlen
(
head
)),
"send"
);
break
;
CHECK_SYS_CALL
(
SocketSend
(
sock
,
phead
,
sizeof
(
head
)),
"send"
);
ret_val
=
SocketRecv
(
sock
,
buffer
,
sizeof
(
head
));
if
(
ret_val
>
0
&&
memcmp
(
buffer
,
phead
,
sizeof
(
head
))
==
0
)
{
// recv same message from recver, indicating that the link is correct
break
;
// accept client
}
else
{
VLOG
(
3
)
<<
"socket read failed with ret_val="
<<
ret_val
;
CloseSocket
(
sock
);
}
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
// unmatched link, retry after 80ms
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
80
));
}
return
sock
;
}
...
...
@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
template
<
typename
CommUniqueId
>
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
CommHead
head
;
head
.
ring_id
=
ring_id
;
// connect with server
std
::
vector
<
int
>
connects
;
for
(
auto
server
:
servers
)
{
VLOG
(
3
)
<<
"connecting endpoint: "
<<
server
;
int
conn
=
ConnectAddr
(
server
,
COMM_HEAD
);
int
conn
=
ConnectAddr
(
server
,
head
);
connects
.
push_back
(
conn
);
}
VLOG
(
3
)
<<
"connecting completed..."
;
...
...
@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
int
server
=
CreateListenSocket
(
endpoint
);
RecvBroadCastCommID
(
server
,
endpoint
,
nccl_ids
);
RecvBroadCastCommID
(
server
,
endpoint
,
nccl_ids
,
ring_id
);
CloseSocket
(
server
);
}
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
int
client
=
SocketAccept
(
server_fd
,
COMM_HEAD
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
CommHead
head
;
head
.
ring_id
=
ring_id
;
int
client
=
SocketAccept
(
server_fd
,
head
);
for
(
size_t
i
=
0
;
i
<
nccl_ids
->
size
();
++
i
)
{
VLOG
(
3
)
<<
"trainer: "
<<
endpoint
...
...
@@ -362,9 +392,13 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
std::vector<Type> * nccl_ids);
std::vector<Type> * nccl_ids, \
int ring_id = 0); \
template void RecvBroadCastCommID<Type>( \
std::string endpoint, std::vector<Type> * nccl_ids, int ring_id = 0); \
template void RecvBroadCastCommID<Type>(int server_fd, std::string endpoint, \
std::vector<Type>* nccl_ids, \
int ring_id = 0);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE
(
ncclUniqueId
)
...
...
paddle/fluid/platform/gen_comm_id_helper.h
浏览文件 @
56b7ebbc
...
...
@@ -31,16 +31,16 @@ void CloseSocket(int fd);
template
<
typename
CommUniqueId
>
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
// recv nccl id from socket
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
class
SocketServer
{
public:
...
...
python/paddle/distributed/fleet/meta_optimizers/common.py
浏览文件 @
56b7ebbc
...
...
@@ -126,11 +126,11 @@ class CollectiveHelper(object):
_add_sync_by_allreduce
(
block
)
return
if
core
.
is_compiled_with_cuda
():
comm_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'nccl
_id'
),
name
=
unique_name
.
generate
(
'comm
_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
if
core
.
is_compiled_with_cuda
():
block
.
append_op
(
type
=
'c_gen_nccl_id'
,
inputs
=
{},
...
...
@@ -139,6 +139,7 @@ class CollectiveHelper(object):
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
block
.
append_op
(
...
...
@@ -152,10 +153,6 @@ class CollectiveHelper(object):
OP_ROLE_KEY
:
OpRole
.
Forward
})
elif
core
.
is_compiled_with_xpu
():
comm_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'bkcl_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
block
.
append_op
(
type
=
'c_gen_bkcl_id'
,
inputs
=
{},
...
...
@@ -164,6 +161,7 @@ class CollectiveHelper(object):
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
block
.
append_op
(
...
...
@@ -177,24 +175,20 @@ class CollectiveHelper(object):
OP_ROLE_KEY
:
OpRole
.
Forward
})
elif
core
.
is_compiled_with_npu
():
hccl_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'hccl_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
endpoint_to_index_map
=
{
e
:
idx
for
idx
,
e
in
enumerate
(
endpoints
)}
block
.
append_op
(
type
=
'c_gen_hccl_id'
,
inputs
=
{},
outputs
=
{
'Out'
:
hccl
_id_var
},
outputs
=
{
'Out'
:
comm
_id_var
},
attrs
=
{
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
block
.
append_op
(
type
=
'c_comm_init_hccl'
,
inputs
=
{
'X'
:
hccl
_id_var
},
inputs
=
{
'X'
:
comm
_id_var
},
outputs
=
{},
attrs
=
{
'rank'
:
rank
,
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
浏览文件 @
56b7ebbc
...
...
@@ -73,7 +73,7 @@ class FP16Utils(object):
return
inserted_op_num
@
staticmethod
def
prune_fp16
(
block
,
shard
,
reduced_grads_to_param
,
ring_id
):
def
prune_fp16
(
block
,
shard
,
reduced_grads_to_param
,
ring_id
s
):
"""
1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding
...
...
@@ -146,6 +146,7 @@ class FP16Utils(object):
name
=
inf_var_name
+
"@sharding"
,
shape
=
inf_var
.
shape
,
dtype
=
inf_var
.
dtype
)
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
,
type
=
'cast'
,
...
...
@@ -156,9 +157,14 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_int32
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
# this allreduce communication should not overlap with calc
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
1
,
update_loss_scaling_op_idx
,
type
=
'c_allreduce_max'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
...
...
@@ -167,8 +173,10 @@ class FP16Utils(object):
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
2
,
update_loss_scaling_op_idx
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_sharding
},
...
...
@@ -177,11 +185,12 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_sharding
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
block
.
_sync_with_cpp
()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@
staticmethod
def
sync_amp_check_nan_inf
(
block
,
ring_id
):
def
sync_amp_check_nan_inf
(
block
,
ring_id
s
):
update_loss_scaling_op_idx
=
-
1
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
...
...
@@ -189,10 +198,14 @@ class FP16Utils(object):
update_loss_scaling_op_idx
=
idx
inf_var_name
=
op
.
desc
.
input
(
'FoundInfinite'
)[
0
]
op
.
_rename_input
(
inf_var_name
,
inf_var_name
+
"@GLOBAL_WORLD"
)
break
# not use amp
if
update_loss_scaling_op_idx
==
-
1
:
return
# 0. inf_var_int32 = cast(inf_var)
# 1. inf_var_int32 = allreduce_max(inf_var_int32)
# 3. inf_var = cast(inf_var_int32)
inf_var
=
block
.
var
(
inf_var_name
)
inf_var_int32
=
block
.
create_var
(
name
=
inf_var_name
+
"@cast_int32"
,
...
...
@@ -212,8 +225,13 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_int32
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
# allreduce(mp)->allreduce(pp)
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
1
,
update_loss_scaling_op_idx
,
type
=
'c_allreduce_max'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
...
...
@@ -222,8 +240,10 @@ class FP16Utils(object):
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
2
,
update_loss_scaling_op_idx
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_global
},
...
...
@@ -232,4 +252,5 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_global
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
update_loss_scaling_op_idx
+=
1
block
.
_sync_with_cpp
()
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
浏览文件 @
56b7ebbc
...
...
@@ -25,7 +25,7 @@ class GradientClipHelper(object):
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip"
)
def
prune_gradient_clip
(
self
,
block
,
shard
,
pure_dp_degree
=
1
):
def
prune_gradient_clip
(
self
,
block
,
shard
,
ring_ids
):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
...
...
@@ -82,33 +82,23 @@ class GradientClipHelper(object):
assert
(
len
(
op
.
desc
.
output_arg_names
())
==
1
)
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset
=
1
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
# this allreduce should not overlap with calc and should be scheduled in calc stream
block
.
_insert_op_without_sync
(
idx
+
1
,
idx
+
idx_offset
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'ring_id'
:
self
.
mp_
ring_id
,
'ring_id'
:
ring_id
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
# global norm should only be sum within each model parallelism word size when use global group
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
idx
+
2
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
idx_offset
+=
1
# the grad sum here should take the all and only param in the current shard
to_check_param
=
set
(
reversed_x_paramname
)
...
...
@@ -126,20 +116,25 @@ class GradientClipHelper(object):
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def
sync_global_norm
(
self
,
block
,
ring_id
,
pure_dp_degree
=
1
):
def
sync_global_norm
(
self
,
block
,
ring_id
s
):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
# FIXME(wangxi): mp should prune duplicated param_grads
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
if
op
.
type
==
"sum"
:
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
idx
=
idx
+
1
block
.
_insert_op_without_sync
(
idx
+
1
,
idx
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
...
...
@@ -149,20 +144,4 @@ class GradientClipHelper(object):
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
# global norm should only be sum within each model parallelism word size
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
idx
+
2
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
return
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
56b7ebbc
...
...
@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
if
self
.
sharding_degree
<=
1
:
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp
FP16Utils
.
sync_amp_check_nan_inf
(
main_block
,
self
.
global_ring_id
)
FP16Utils
.
sync_amp_check_nan_inf
(
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
])
# clip
gradientclip_helper
=
GradientClipHelper
(
self
.
global_ring_id
)
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
.
sync_global_norm
(
main_block
,
self
.
global_ring_id
,
self
.
dp_degree
)
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
)
# step6: loss div dp_degree
global_dp_degree
=
self
.
sharding_degree
*
self
.
dp_degree
...
...
@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_rank
,
ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
def
_init_npu_pipeline_comm
(
self
,
startup_block
):
...
...
@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
=
send_to_next_pair
if
even
else
recv_from_prev_pair
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair0(even->odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
...
...
@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
=
recv_from_next_pair
if
even
else
send_to_prev_pair
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair1(even<-odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
...
...
@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
[
0
]
*
1000
+
pair
[
1
],
max_ring_id
+
1
)
# 3->0 not in pp_ring_map
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair2(odd->even): pp pair:{}, ring_id: {}"
.
format
(
...
...
@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
[
0
]
*
1000
+
pair
[
1
],
max_ring_id
+
2
)
# 0->3 not in pp_ring_map
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair3(odd<-even): pp pair:{}, ring_id: {}"
.
format
(
...
...
@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
assert
self
.
pp_rank_
==
self
.
pp_rank
,
"pp rank for pp opt [{}], pp rank for sharding opt [{}]"
.
format
(
self
.
pp_rank_
,
self
.
pp_rank
)
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
pp_group_endpoints
,
self
.
pp_rank
,
self
.
pp_ring_id
,
False
,
sync
=
False
)
if
core
.
is_compiled_with_npu
():
self
.
_init_npu_pipeline_comm
(
startup_block
)
return
...
...
@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
logger
.
info
(
"pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
if
self
.
pp_rank
in
pair
:
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
def
_init_comm
(
self
):
...
...
@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
persistable
=
False
)
# global ring
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
global_endpoints
,
self
.
global_rank
,
self
.
global_ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# mp ring
if
self
.
mp_degree
>
1
:
self
.
_collective_helper
.
_init_communicator
(
...
...
@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
mp_rank
,
self
.
mp_ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# sharding ring
if
self
.
sharding_degree
>
1
:
...
...
@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
sharding_rank
,
self
.
sharding_ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# pp ring
if
self
.
pp_degree
>
1
:
...
...
@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
dp_rank
,
self
.
dp_ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
startup_block
.
_sync_with_cpp
()
...
...
@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
weightdecay_helper
=
WeightDecayHelper
()
weightdecay_helper
.
prune_weight_decay
(
block
,
self
.
_shard
)
# FIXME(wangxi): mp should prune duplicated param_grads
# NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
# group. and each Data Parallelism group should have its own sync of FoundInfinite
# amp could use global group for sync
FP16Utils
.
prune_fp16
(
block
,
self
.
_shard
,
self
.
_reduced_grads_to_param
,
self
.
global_ring_id
)
FP16Utils
.
prune_fp16
(
block
,
self
.
_shard
,
self
.
_reduced_grads_to_param
,
[
self
.
mp_ring_id
,
self
.
sharding_ring_id
,
self
.
pp_ring_id
])
# clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
if
self
.
mp_degree
*
self
.
pp_degree
==
1
:
# separate the sharding-hybrid senario to keep the accuracy
gradientclip_helper
=
GradientClipHelper
(
self
.
sharding_ring_id
)
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
.
prune_gradient_clip
(
block
,
self
.
_shard
,
pure_dp_degree
=
1
)
else
:
gradientclip_helper
=
GradientClipHelper
(
self
.
global_ring_id
)
gradientclip_helper
.
prune_gradient_clip
(
block
,
self
.
_shard
,
pure_dp_degree
=
self
.
dp_degree
)
block
,
self
.
_shard
,
[
self
.
mp_ring_id
,
self
.
sharding_ring_id
,
self
.
pp_ring_id
])
# build prog deps
reduced_grads
=
[]
...
...
@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp
if
self
.
pp_degree
>
1
:
self
.
pp_ring_id
=
20
self
.
pp_pair_ring_id
=
20
# pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
self
.
pp_ring_id
=
4
self
.
pp_rank
=
self
.
global_rank
//
(
self
.
sharding_degree
*
self
.
mp_degree
)
%
self
.
pp_degree
# (NOTE): Already adjust for (outter-pure) dp
...
...
@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_first_stage_idx
+
pp_stage_offset
*
i
])
assert
self
.
current_endpoint
in
self
.
pp_group_endpoints
else
:
self
.
pp_degree
=
1
self
.
pp_ring_id
=
-
1
self
.
pp_degree
=
1
self
.
pp_pair_ring_id
=
-
1
self
.
pp_rank
=
-
1
self
.
pp_group_id
=
-
1
self
.
pp_group_endpoints
=
[]
...
...
@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
outputs
=
{
'Out'
:
params
},
attrs
=
{
'ring_id'
:
self
.
dp_ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
# sync within global group
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# sharding gradient merge
def
create_persistable_gradients_and_insert_merge_ops
(
...
...
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
56b7ebbc
...
...
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self
.
set_strategy
(
strategy
,
'sharding'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
parameters
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
==
True
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
is
True
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
vars
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()]
...
...
@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
])
class
TestFleet
MetaOptimizer_V1
(
TestFleetMetaOptimizer
):
class
TestFleet
ShardingHybridOptimizer
(
TestFleetMetaOptimizer
):
def
setUp
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"3"
os
.
environ
[
...
...
@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
self
.
sharding_ring_id
=
1
self
.
dp_ring_id
=
2
self
.
global_ring_id
=
3
self
.
pp_ring_id
=
20
self
.
pp_
pair_
ring_id
=
20
def
test_sharding_with_mp
(
self
):
# NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
...
...
@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
...
@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
...
@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
...
@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
...
@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
...
@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
...
@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
])
self
.
assertEqual
(
main_prog_op_types
,
[
...
...
@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if
op
.
type
==
"c_comm_init"
]
self
.
assertIn
(
self
.
sharding_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_
pair_
ring_id
,
created_ring_ids
)
# check correctness of pp group
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
...
@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
...
@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if
op
.
type
==
'c_allreduce_sum'
:
assert
'FusedOutput'
in
op
.
input_arg_names
[
0
]
def
test_hybrid_with_mp_pp_amp_gclip
(
self
):
train_prog
,
startup_prog
=
paddle
.
fluid
.
Program
(),
paddle
.
fluid
.
Program
(
)
avg_cost
,
strategy
=
self
.
pp_net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'amp'
)
strategy
.
sharding
=
True
strategy
.
sharding_configs
=
{
"sharding_degree"
:
1
,
"mp_degree"
:
2
,
"pp_degree"
:
2
,
"dp_degree"
:
1
,
}
strategy
.
pipeline
=
True
strategy
.
pipeline_configs
=
{
"schedule_mode"
:
"1F1B"
,
"micro_batch_size"
:
2
,
"accumulate_steps"
:
4
,
}
clip
=
paddle
.
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
1.0
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
,
grad_clip
=
clip
)
train_prog
=
train_prog
.
_pipeline_opt
[
'section_program'
]
startup_prog
=
startup_prog
.
_pipeline_opt
[
'startup_program'
]
startup_prog_ops
=
startup_prog
.
global_block
().
ops
main_prog_ops
=
train_prog
.
global_block
().
ops
# check program
startup_prog_op_types
=
[
op
.
type
for
op
in
startup_prog_ops
]
main_prog_op_types
=
[
op
.
type
for
op
in
main_prog_ops
]
# ring: mp, pp_group, pp_pair, pp_pair
self
.
assertEqual
(
startup_prog_op_types
,
[
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
])
# pp + mp, partial send recv
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_send'
,
main_prog_op_types
)
# amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
self
.
assertEqual
(
main_prog_op_types
.
count
(
'c_allreduce_max'
),
2
)
# global gradient clip, allreduce(mp)->allreduce(pp)
self
.
assertEqual
(
main_prog_op_types
.
count
(
'c_allreduce_sum'
),
2
)
# should has ring id for pp
created_ring_ids
=
[
op
.
desc
.
attr
(
"ring_id"
)
for
op
in
startup_prog_ops
if
op
.
type
==
"c_comm_init"
]
self
.
assertIn
(
self
.
mp_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_pair_ring_id
,
created_ring_ids
)
# check correctness of pp group
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"comm_id_0"
:
mp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
mp_group_waiting_ports
,
[
'127.0.0.1:36003'
])
# check correctness of sharding group
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"comm_id_1"
:
pp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
pp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录