Commit 56b7ebbc (unverified)
Author: WangXi, Aug 03, 2021
Committer: GitHub, Aug 03, 2021
[hybrid] remove the using of global ring in hybrid parallel (#34525)
Parent: 9b6c7eb9
Showing 15 changed files with 329 additions and 247 deletions (+329, -247)
paddle/fluid/imperative/bkcl_context.cc  (+2, -2)
paddle/fluid/imperative/nccl_context.cc  (+2, -2)
paddle/fluid/operators/collective/c_comm_init_op.cc  (+40, -44)
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc  (+5, -4)
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc  (+5, -4)
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc  (+5, -4)
paddle/fluid/platform/collective_helper.cc  (+4, -4)
paddle/fluid/platform/collective_helper.h  (+4, -4)
paddle/fluid/platform/gen_comm_id_helper.cc  (+55, -21)
paddle/fluid/platform/gen_comm_id_helper.h  (+3, -3)
python/paddle/distributed/fleet/meta_optimizers/common.py  (+9, -15)
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py  (+46, -25)
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py  (+24, -45)
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py  (+31, -52)
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py  (+94, -18)
paddle/fluid/imperative/bkcl_context.cc
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
             << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
             << " ring id: " << ring_id;
     // it will assign bkcl_comm in XPUDeviceContext within ring_id
-    platform::BKCLCommContext::Instance().CreateBKCLComm(
+    platform::BKCLCommContext::Instance().CreateComm(
         &bkcl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, xpu_id,
         ring_id);
   }
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
             << " local rank: " << strategy_.local_rank_ << " xpu id: " << xpu_id
             << " ring id: " << ring_id;
     // it will assign bkcl_comm in XPUDeviceContext within ring_id
-    platform::BKCLCommContext::Instance().CreateBKCLComm(
+    platform::BKCLCommContext::Instance().CreateComm(
         &bkcl_ids[0], strategy_.nranks_, strategy_.local_rank_, xpu_id, ring_id);
   }
paddle/fluid/imperative/nccl_context.cc
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
             << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
             << " ring id: " << ring_id;
     // it will assign nccl_comm in CUDADeviceContext within ring_id
-    platform::NCCLCommContext::Instance().CreateNCCLComm(
+    platform::NCCLCommContext::Instance().CreateComm(
         &nccl_ids[ring_id], strategy_.nranks_, strategy_.local_rank_, gpu_id,
         ring_id);
@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
             << " local rank: " << strategy_.local_rank_ << " gpu id: " << gpu_id
             << " ring id: " << ring_id;
     // it will assign nccl_comm in CUDADeviceContext within ring_id
-    platform::NCCLCommContext::Instance().CreateNCCLComm(
+    platform::NCCLCommContext::Instance().CreateComm(
         &nccl_ids[0], strategy_.nranks_, strategy_.local_rank_, gpu_id, ring_id);

     compute_events_.emplace_back(platform::CudaEventResourcePool::Instance().New(
paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -24,15 +24,16 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"

+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
+
 namespace paddle {
 namespace framework {
 class Scope;
 }  // namespace framework
 }  // namespace paddle
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
-#include "paddle/fluid/platform/collective_helper.h"
-#endif

 namespace paddle {
 namespace operators {
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
+// TODO(wangxi): Put this in the unified header file
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+    using UniqueId = ncclUniqueId;
+    using Place = platform::CUDAPlace;
+    using CommContext = platform::NCCLCommContext;
+#elif defined(PADDLE_WITH_XPU_BKCL)
+    using UniqueId = BKCLUniqueId;
+    using Place = platform::XPUPlace;
+    using CommContext = platform::BKCLCommContext;
+#else
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "PaddlePaddle should be compiled with GPU or XPU."));
+#endif
+
     PADDLE_ENFORCE_EQ(is_gpu_place(place) || is_xpu_place(place), true,
                       platform::errors::PreconditionNotMet(
                           "CCommInitOp can run on gpu or xpu place only."));

+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
     auto var = scope.FindVar(Input("X"));
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input con not be empty."));
-    if (is_gpu_place(place)) {
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-      ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
-      int nranks = Attr<int>("nranks");
-      int rank_id = Attr<int>("rank");
-      int rid = Attr<int>("ring_id");
-      int device_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
-      if (Attr<int>("device_id") >= 0) {
-        device_id = Attr<int>("device_id");
-      }
-      platform::NCCLCommContext::Instance().CreateNCCLComm(
-          nccl_id, nranks, rank_id, device_id, rid);
-#else
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "PaddlePaddle should be compiled with GPU."));
-#endif
-    } else if (is_xpu_place(place)) {
-#if defined(PADDLE_WITH_XPU_BKCL)
-      BKCLUniqueId* bkcl_id = var->GetMutable<BKCLUniqueId>();
-      int nranks = Attr<int>("nranks");
-      int rank_id = Attr<int>("rank");
-      int rid = Attr<int>("ring_id");
-      PADDLE_ENFORCE_EQ(
-          rid, 0,
-          platform::errors::OutOfRange(
-              "Ring id must equal 0 in multi Kunlun cards training, but got %d",
-              rid));
-      int device_id = BOOST_GET_CONST(platform::XPUPlace, place).device;
-      if (Attr<int>("device_id") >= 0) {
-        device_id = Attr<int>("device_id");
-      }
-      platform::BKCLCommContext::Instance().CreateBKCLComm(
-          bkcl_id, nranks, rank_id, device_id, rid);
-#else
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "PaddlePaddle should be compiled with XPU."));
-#endif
-    } else {
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
-          "CCommInitOp can run on gpu or xpu place only."));
-    }
+    UniqueId* comm_id = var->GetMutable<UniqueId>();
+
+    int nranks = Attr<int>("nranks");
+    int rank_id = Attr<int>("rank");
+    int rid = Attr<int>("ring_id");
+
+#if defined(PADDLE_WITH_XPU_BKCL)
+    PADDLE_ENFORCE_EQ(
+        rid, 0,
+        platform::errors::OutOfRange(
+            "Ring id must equal 0 in multi Kunlun cards training, but got %d",
+            rid));
+#endif
+
+    int device_id = BOOST_GET_CONST(Place, place).device;
+    if (Attr<int>("device_id") >= 0) {
+      device_id = Attr<int>("device_id");
+    }
+    CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id,
+                                       rid);
+#endif
   }
 };
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     int rank = Attr<int>("rank");
-    framework::Scope& local_scope = scope.NewScope();
+    int ring_id = Attr<int>("ring_id");

     std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
       return Output("Out");
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
       GenBKCLID(&bkcl_ids);
       std::vector<std::string> endpoint_list =
           Attr<std::vector<std::string>>("other_endpoints");
-      platform::SendBroadCastCommID(endpoint_list, &bkcl_ids);
+      platform::SendBroadCastCommID(endpoint_list, &bkcl_ids, ring_id);
     } else {
       std::string endpoint = Attr<std::string>("endpoint");
-      platform::RecvBroadCastCommID(endpoint, &bkcl_ids);
+      platform::RecvBroadCastCommID(endpoint, &bkcl_ids, ring_id);
     }

     CopyBKCLIDToVar(bkcl_ids, func, scope);
-    scope.DeleteScope(&local_scope);
   }
 };
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
                  "(int default 0) "
                  "The rank of the trainer in distributed training.")
         .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
   }
 };
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     int rank = Attr<int>("rank");
-    framework::Scope& local_scope = scope.NewScope();
+    int ring_id = Attr<int>("ring_id");

     std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
       return Output("Out");
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
       GenHCCLID(&hccl_ids);
       std::vector<std::string> endpoint_list =
           Attr<std::vector<std::string>>("other_endpoints");
-      platform::SendBroadCastCommID(endpoint_list, &hccl_ids);
+      platform::SendBroadCastCommID(endpoint_list, &hccl_ids, ring_id);
     } else {
-      platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids);
+      platform::RecvBroadCastCommID(server_fd, endpoint, &hccl_ids, ring_id);
     }

     CopyHCCLIDToVar(hccl_ids, func, scope);
-    scope.DeleteScope(&local_scope);
   }
 };
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
                  "(int default 0) "
                  "The rank of the trainer in distributed training.")
         .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
   }
 };
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope& scope,
                const platform::Place& dev_place) const override {
     int rank = Attr<int>("rank");
-    framework::Scope& local_scope = scope.NewScope();
+    int ring_id = Attr<int>("ring_id");

     std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
       return Output("Out");
@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
       GenNCCLID(&nccl_ids);
       std::vector<std::string> endpoint_list =
           Attr<std::vector<std::string>>("other_endpoints");
-      platform::SendBroadCastCommID(endpoint_list, &nccl_ids);
+      platform::SendBroadCastCommID(endpoint_list, &nccl_ids, ring_id);
     } else {
-      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids);
+      platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids, ring_id);
     }

     CopyNCCLIDToVar(nccl_ids, func, scope);
-    scope.DeleteScope(&local_scope);
   }
 };
@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
                  "(int default 0) "
                  "The rank of the trainer in distributed training.")
         .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
   }
 };
paddle/fluid/platform/collective_helper.cc
@@ -72,7 +72,7 @@ class NCCLCommImpl : public NCCLComm {
   std::shared_ptr<platform::CudaEventObject> comm_event_;
 };

-NCCLComm* NCCLCommContext::CreateNCCLComm(ncclUniqueId* nccl_id, int nranks,
-                                          int rank, int dev_id, int ring_id) {
+NCCLComm* NCCLCommContext::CreateComm(ncclUniqueId* nccl_id, int nranks,
+                                      int rank, int dev_id, int ring_id) {
   PADDLE_ENFORCE_NOT_NULL(nccl_id,
                           platform::errors::InvalidArgument(
@@ -225,7 +225,7 @@ class BKCLCommImpl : public BKCLComm {
   std::unique_ptr<XPUDeviceContext> dev_ctx_;
 };

-BKCLComm* BKCLCommContext::CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks,
-                                          int rank, int dev_id, int ring_id) {
+BKCLComm* BKCLCommContext::CreateComm(BKCLUniqueId* bkcl_id, int nranks,
+                                      int rank, int dev_id, int ring_id) {
   PADDLE_ENFORCE_NOT_NULL(bkcl_id,
                           platform::errors::InvalidArgument(
paddle/fluid/platform/collective_helper.h
@@ -72,8 +72,8 @@ class NCCLCommContext {
     return comm_ctx;
   }

-  NCCLComm* CreateNCCLComm(ncclUniqueId* nccl_id, int nranks, int rank,
-                           int dev_id, int ring_id = 0);
+  NCCLComm* CreateComm(ncclUniqueId* nccl_id, int nranks, int rank,
+                       int dev_id, int ring_id = 0);

   void CreateAllNCCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
@@ -274,8 +274,8 @@ class BKCLCommContext {
     return comm_ctx;
   }

-  BKCLComm* CreateBKCLComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
-                           int dev_id, int ring_id = 0);
+  BKCLComm* CreateComm(BKCLUniqueId* bkcl_id, int nranks, int rank,
+                       int dev_id, int ring_id = 0);

   void CreateAllBKCLComms(const std::vector<int>& dev_ids, int ring_id = 0);
paddle/fluid/platform/gen_comm_id_helper.cc
@@ -42,7 +42,10 @@ namespace platform {
 std::once_flag SocketServer::init_flag_;

-constexpr char COMM_HEAD[] = "_pd_gen_comm_id_";
+struct CommHead {
+  int version = 1;  // unused for now
+  int ring_id = 0;
+};

 // Check system calls, such as socket, bind.
 #define CHECK_SYS_CALL(call, name) \
@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
 void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); }

-static int SocketAccept(int server_fd, const char* head) {
+static int SocketAccept(int server_fd, const CommHead head) {
+  static_assert(sizeof(CommHead) <= 1024,
+                "sizeof(CommHead) must <= buffer size");
+
   struct sockaddr_in client_addr;
   socklen_t addr_length = sizeof(client_addr);
   char buffer[1024] = {0};
   int conn = -1;
+  const char* phead = reinterpret_cast<const char*>(&head);

   while (true) {
     CHECK_SYS_CALL_VAL(
@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
                        &addr_length),
         "accept", conn);

-    int ret_val = SocketRecv(conn, buffer, strlen(head));
-    if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) {
+    int ret_val = SocketRecv(conn, buffer, sizeof(head));
+    if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
+      // send a message to the sender, indicating that the link is correct
+      CHECK_SYS_CALL(SocketSend(conn, phead, sizeof(head)), "send");
       break;  // accept client
     } else {
       VLOG(3) << "socket read failed with ret_val=" << ret_val;
@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
   return conn;
 }

-static int ConnectAddr(const std::string& ep, const char* head) {
+static int ConnectAddr(const std::string& ep, const CommHead head) {
   auto addr = paddle::string::Split(ep, ':');
   PADDLE_ENFORCE_EQ(
       addr.size(), 2UL,
@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
   std::string host = addr[0];
   int port = std::stoi(addr[1]);

-  int sock = -1;
-  CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
-
   struct sockaddr_in server_addr;
   memset(&server_addr, 0, sizeof(server_addr));
   server_addr.sin_family = AF_INET;
@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
       platform::errors::Unavailable("Open address %s failed: %s",
                                     ep, strerror(errno)));

+  static_assert(sizeof(CommHead) <= 1024,
+                "sizeof(CommHead) must <= buffer size");
+  char buffer[1024] = {0};
+  const char* phead = reinterpret_cast<const char*>(&head);
+
   // TODO(wangxi) Set from env, default 900s=15min
   int timeout = 900 * 1000;
   int try_times = 0;
   int total_time = 0;
+
+  int sock = -1;
+  CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
   while (true) {
     int ret_val = -1;
     RETRY_SYS_CALL_VAL(
@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
       continue;
     }

-    CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send");
-    break;
+    CHECK_SYS_CALL(SocketSend(sock, phead, sizeof(head)), "send");
+    ret_val = SocketRecv(sock, buffer, sizeof(head));
+    if (ret_val > 0 && memcmp(buffer, phead, sizeof(head)) == 0) {
+      // recv same message from recver, indicating that the link is correct
+      break;  // accept client
+    } else {
+      VLOG(3) << "socket read failed with ret_val=" << ret_val;
+      CloseSocket(sock);
+    }
+    sock = -1;
+    CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock);
+    // unmatched link, retry after 80ms
+    std::this_thread::sleep_for(std::chrono::milliseconds(80));
   }
   return sock;
 }
@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
 template <typename CommUniqueId>
 void SendBroadCastCommID(std::vector<std::string> servers,
-                         std::vector<CommUniqueId>* nccl_ids) {
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
+  CommHead head;
+  head.ring_id = ring_id;
+
   // connect with server
   std::vector<int> connects;
   for (auto server : servers) {
     VLOG(3) << "connecting endpoint: " << server;
-    int conn = ConnectAddr(server, COMM_HEAD);
+    int conn = ConnectAddr(server, head);
     connects.push_back(conn);
   }
   VLOG(3) << "connecting completed...";
@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
 template <typename CommUniqueId>
 void RecvBroadCastCommID(std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids) {
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
   int server = CreateListenSocket(endpoint);
-  RecvBroadCastCommID(server, endpoint, nccl_ids);
+  RecvBroadCastCommID(server, endpoint, nccl_ids, ring_id);
   CloseSocket(server);
 }

 template <typename CommUniqueId>
 void RecvBroadCastCommID(int server_fd, std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids) {
-  int client = SocketAccept(server_fd, COMM_HEAD);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id) {
+  CommHead head;
+  head.ring_id = ring_id;
+  int client = SocketAccept(server_fd, head);

   for (size_t i = 0; i < nccl_ids->size(); ++i) {
     VLOG(3) << "trainer: " << endpoint
@@ -362,9 +392,13 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
 /// template instantiation
 #define INSTANT_TEMPLATE(Type)                                                 \
-  template void SendBroadCastCommID<Type>(std::vector<std::string> servers,   \
-                                          std::vector<Type> * nccl_ids);      \
-  template void RecvBroadCastCommID<Type>(std::string endpoint,               \
-                                          std::vector<Type> * nccl_ids);
+  template void SendBroadCastCommID<Type>(std::vector<std::string> servers,   \
+                                          std::vector<Type> * nccl_ids,       \
+                                          int ring_id = 0);                   \
+  template void RecvBroadCastCommID<Type>(                                    \
+      std::string endpoint, std::vector<Type> * nccl_ids, int ring_id = 0);   \
+  template void RecvBroadCastCommID<Type>(int server_fd, std::string endpoint,\
+                                          std::vector<Type>* nccl_ids,        \
+                                          int ring_id = 0);

 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 INSTANT_TEMPLATE(ncclUniqueId)
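The change above replaces the fixed magic string `_pd_gen_comm_id_` with a small fixed-size `CommHead` struct carrying the ring id, and adds an echo step so both sides confirm they joined the same ring before the unique id is transferred. Below is a minimal, self-contained Python sketch of that rendezvous idea; it is not Paddle code, and the helper names, payload and exact-size reads are illustrative assumptions.

```python
# Sketch of the CommHead handshake: client sends (version, ring_id) as a packed
# header, server accepts only a byte-for-byte match and echoes the header back.
import socket
import struct
import threading

HEAD_FMT = "ii"  # (version, ring_id), mirrors `struct CommHead { int version; int ring_id; }`


def comm_head(ring_id, version=1):
    return struct.pack(HEAD_FMT, version, ring_id)


def accept_for_ring(server_sock, ring_id):
    """Accept connections until one presents a matching CommHead, then echo it back."""
    expected = comm_head(ring_id)
    while True:
        conn, _ = server_sock.accept()
        got = conn.recv(len(expected))        # assumes the tiny header arrives in one read
        if got == expected:
            conn.sendall(expected)            # tell the sender the link is correct
            return conn
        conn.close()                          # wrong ring or stale peer: reject and keep waiting


def connect_to_ring(endpoint, ring_id):
    """Connect and prove we belong to `ring_id`; retry if the echo does not match."""
    head = comm_head(ring_id)
    while True:
        sock = socket.create_connection(endpoint)
        sock.sendall(head)
        if sock.recv(len(head)) == head:
            return sock
        sock.close()


if __name__ == "__main__":
    server = socket.socket()
    server.bind(("127.0.0.1", 0))             # ephemeral port for the demo
    server.listen(8)
    endpoint = server.getsockname()

    t = threading.Thread(
        target=lambda: accept_for_ring(server, ring_id=3).sendall(b"unique-id"))
    t.start()
    client = connect_to_ring(endpoint, ring_id=3)
    print(client.recv(16))                    # b'unique-id'
    t.join()
```

The byte-for-byte comparison of the packed header is what keeps two concurrent rendezvous with different ring ids from cross-connecting, which is why the retry loop in `ConnectAddr` now closes the socket, recreates it and sleeps 80 ms before trying again on a mismatch.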
paddle/fluid/platform/gen_comm_id_helper.h
@@ -31,16 +31,16 @@ void CloseSocket(int fd);
 template <typename CommUniqueId>
 void SendBroadCastCommID(std::vector<std::string> servers,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

 template <typename CommUniqueId>
 void RecvBroadCastCommID(std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

 // recv nccl id from socket
 template <typename CommUniqueId>
 void RecvBroadCastCommID(int server_fd, std::string endpoint,
-                         std::vector<CommUniqueId>* nccl_ids);
+                         std::vector<CommUniqueId>* nccl_ids, int ring_id = 0);

 class SocketServer {
  public:
python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -126,11 +126,11 @@ class CollectiveHelper(object):
             _add_sync_by_allreduce(block)
             return

-        if core.is_compiled_with_cuda():
-            comm_id_var = block.create_var(
-                name=unique_name.generate('nccl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
+        comm_id_var = block.create_var(
+            name=unique_name.generate('comm_id'),
+            persistable=True,
+            type=core.VarDesc.VarType.RAW)
+        if core.is_compiled_with_cuda():
             block.append_op(
                 type='c_gen_nccl_id',
                 inputs={},
@@ -139,6 +139,7 @@ class CollectiveHelper(object):
                     'rank': rank,
                     'endpoint': current_endpoint,
                     'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                     OP_ROLE_KEY: OpRole.Forward
                 })
             block.append_op(
@@ -152,10 +153,6 @@ class CollectiveHelper(object):
                     OP_ROLE_KEY: OpRole.Forward
                 })
         elif core.is_compiled_with_xpu():
-            comm_id_var = block.create_var(
-                name=unique_name.generate('bkcl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
             block.append_op(
                 type='c_gen_bkcl_id',
                 inputs={},
@@ -164,6 +161,7 @@ class CollectiveHelper(object):
                     'rank': rank,
                     'endpoint': current_endpoint,
                     'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                     OP_ROLE_KEY: OpRole.Forward
                 })
             block.append_op(
@@ -177,24 +175,20 @@ class CollectiveHelper(object):
                     OP_ROLE_KEY: OpRole.Forward
                 })
         elif core.is_compiled_with_npu():
-            hccl_id_var = block.create_var(
-                name=unique_name.generate('hccl_id'),
-                persistable=True,
-                type=core.VarDesc.VarType.RAW)
-            endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
             block.append_op(
                 type='c_gen_hccl_id',
                 inputs={},
-                outputs={'Out': hccl_id_var},
+                outputs={'Out': comm_id_var},
                 attrs={
                     'rank': rank,
                     'endpoint': current_endpoint,
                     'other_endpoints': other_endpoints,
+                    'ring_id': ring_id,
                     OP_ROLE_KEY: OpRole.Forward
                 })
             block.append_op(
                 type='c_comm_init_hccl',
-                inputs={'X': hccl_id_var},
+                inputs={'X': comm_id_var},
                 outputs={},
                 attrs={
                     'rank': rank,
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
@@ -73,7 +73,7 @@ class FP16Utils(object):
         return inserted_op_num

     @staticmethod
-    def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
+    def prune_fp16(block, shard, reduced_grads_to_param, ring_ids):
        """
        1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
        2. revise amp inifine grad checking for sharding
@@ -146,6 +146,7 @@ class FP16Utils(object):
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
+
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
@@ -156,9 +157,14 @@ class FP16Utils(object):
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
-        # this allreduce communication should not overlap with calc
-        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 1,
+        update_loss_scaling_op_idx += 1
+
+        # allreduce(mp)->allreduce(sharding)->allreduce(pp)
+        for ring_id in ring_ids:
+            if ring_id == -1: continue
+            # this allreduce communication should not overlap with calc
+            block._insert_op_without_sync(
+                update_loss_scaling_op_idx,
                type='c_allreduce_max',
                inputs={'X': inf_var_int32},
                outputs={'Out': inf_var_int32},
@@ -167,8 +173,10 @@ class FP16Utils(object):
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Optimize
                })
+            update_loss_scaling_op_idx += 1

        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_sharding},
@@ -177,11 +185,12 @@ class FP16Utils(object):
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
+        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()

    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    @staticmethod
-    def sync_amp_check_nan_inf(block, ring_id):
+    def sync_amp_check_nan_inf(block, ring_ids):
        update_loss_scaling_op_idx = -1

        for idx, op in reversed(list(enumerate(block.ops))):
@@ -189,10 +198,14 @@ class FP16Utils(object):
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
+                break

        # not use amp
        if update_loss_scaling_op_idx == -1:
            return
+        # 0. inf_var_int32 = cast(inf_var)
+        # 1. inf_var_int32 = allreduce_max(inf_var_int32)
+        # 3. inf_var = cast(inf_var_int32)
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
@@ -212,8 +225,13 @@ class FP16Utils(object):
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
-        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 1,
+        update_loss_scaling_op_idx += 1
+
+        # allreduce(mp)->allreduce(pp)
+        for ring_id in ring_ids:
+            if ring_id == -1: continue
+            block._insert_op_without_sync(
+                update_loss_scaling_op_idx,
                type='c_allreduce_max',
                inputs={'X': inf_var_int32},
                outputs={'Out': inf_var_int32},
@@ -222,8 +240,10 @@ class FP16Utils(object):
                    'use_calc_stream': True,
                    OP_ROLE_KEY: OpRole.Optimize
                })
+            update_loss_scaling_op_idx += 1

        block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_global},
@@ -232,4 +252,5 @@ class FP16Utils(object):
                "out_dtype": inf_var_global.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
+        update_loss_scaling_op_idx += 1
        block._sync_with_cpp()
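For intuition about why `prune_fp16` and `sync_amp_check_nan_inf` can chain `c_allreduce_max` of the found-infinite flag over a list of ring ids instead of one global ring, here is a pure-Python sketch. The 2x2x2 rank layout and the group construction are illustrative assumptions, not Paddle's process-group bookkeeping; the point is that reducing along each orthogonal ring family in turn reaches every rank.

```python
# Chained per-ring max-reduce equals a global max-reduce when the rings span all ranks.
from itertools import product

def allreduce_max(values, groups):
    """One in-place max all-reduce per communication group (a list of rank ids)."""
    for group in groups:
        m = max(values[r] for r in group)
        for r in group:
            values[r] = m

# 8 ranks arranged as mp=2 x sharding=2 x pp=2; rank = mp + 2*sh + 4*pp.
mp_groups = [[mp + 2 * sh + 4 * pp for mp in range(2)] for sh, pp in product(range(2), range(2))]
sh_groups = [[mp + 2 * sh + 4 * pp for sh in range(2)] for mp, pp in product(range(2), range(2))]
pp_groups = [[mp + 2 * sh + 4 * pp for pp in range(2)] for mp, sh in product(range(2), range(2))]

# Only rank 5 saw an inf/nan in its shard of the gradients.
found_inf = [0, 0, 0, 0, 0, 1, 0, 0]

for ring in (mp_groups, sh_groups, pp_groups):   # mirrors `for ring_id in ring_ids`
    allreduce_max(found_inf, ring)

assert found_inf == [1] * 8                      # every rank agrees, no global ring needed
print(found_inf)
```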
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
@@ -25,7 +25,7 @@ class GradientClipHelper(object):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/gradient_clip")

-    def prune_gradient_clip(self, block, shard, pure_dp_degree=1):
+    def prune_gradient_clip(self, block, shard, ring_ids):
        """
        prune gradient_clip related ops for params that not belong to cur shard
        prune: square, reduce_sum, elementwise_mul
@@ -82,33 +82,23 @@ class GradientClipHelper(object):
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]

-                # this allreduce should not overlap with calc and should be scheduled in calc stream
-                block._insert_op_without_sync(
-                    idx + 1,
-                    type='c_allreduce_sum',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={
-                        'ring_id': self.mp_ring_id,
-                        'op_namescope': "/gradient_clip_model_parallelism",
-                        'use_calc_stream': True,
-                        OP_ROLE_KEY: OpRole.Optimize,
-                    })
-
-                # global norm should only be sum within each model parallelism word size when use global group
-                if pure_dp_degree > 1:
-                    block._insert_op_without_sync(
-                        idx + 2,
-                        type='scale',
-                        inputs={'X': sum_res},
-                        outputs={'Out': sum_res},
-                        attrs={
-                            'scale': 1.0 / float(pure_dp_degree),
-                            'op_namescope': "/gradient_clip_model_parallelism",
-                            'bias': 0.0,
-                            'bias_after_scale': False,
-                            OP_ROLE_KEY: OpRole.Optimize
-                        })
+                # allreduce(mp)->allreduce(sharding)->allreduce(pp)
+                idx_offset = 1
+                for ring_id in ring_ids:
+                    if ring_id == -1: continue
+                    # this allreduce should not overlap with calc and should be scheduled in calc stream
+                    block._insert_op_without_sync(
+                        idx + idx_offset,
+                        type='c_allreduce_sum',
+                        inputs={'X': sum_res},
+                        outputs={'Out': sum_res},
+                        attrs={
+                            'ring_id': ring_id,
+                            'op_namescope': "/gradient_clip_model_parallelism",
+                            'use_calc_stream': True,
+                            OP_ROLE_KEY: OpRole.Optimize,
+                        })
+                    idx_offset += 1

                # the grad sum here should take the all and only param in the current shard
                to_check_param = set(reversed_x_paramname)
@@ -126,20 +116,25 @@ class GradientClipHelper(object):
            return

    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
-    def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
+    def sync_global_norm(self, block, ring_ids):
        """
        prune gradient_clip related ops for params that not belong to cur shard
        prune: square, reduce_sum, elementwise_mul
        keep: sum, sqrt, elementwise_max, elementwise_div
        """
+        # FIXME(wangxi): mp should prune duplicated param_grads
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue

            if op.type == "sum":
                sum_res = op.desc.output_arg_names()[0]
-                block._insert_op_without_sync(
-                    idx + 1,
+                for ring_id in ring_ids:
+                    if ring_id == -1: continue
+
+                    idx = idx + 1
+                    block._insert_op_without_sync(
+                        idx,
                        type='c_allreduce_sum',
                        inputs={'X': sum_res},
                        outputs={'Out': sum_res},
@@ -149,20 +144,4 @@ class GradientClipHelper(object):
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Optimize,
                        })

-            # global norm should only be sum within each model parallelism word size
-            if pure_dp_degree > 1:
-                block._insert_op_without_sync(
-                    idx + 2,
-                    type='scale',
-                    inputs={'X': sum_res},
-                    outputs={'Out': sum_res},
-                    attrs={
-                        'scale': 1.0 / float(pure_dp_degree),
-                        'op_namescope': "/gradient_clip_model_parallelism",
-                        'bias': 0.0,
-                        'bias_after_scale': False,
-                        OP_ROLE_KEY: OpRole.Optimize
-                    })
-
        return
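The same chaining argument applies to the global-norm path above: each rank holds the sum of squares of its own shard of the gradients, and summing over the mp, sharding and pp rings in turn reproduces the global sum on every rank, with rings marked -1 simply skipped. A small pure-Python sketch under the same assumed 2x2x2 layout (not Paddle's actual group bookkeeping):

```python
# Chained per-ring sum-reduce over orthogonal rings yields the global gradient norm.
import math
from itertools import product

def allreduce_sum(values, groups):
    for group in groups:                      # one sum all-reduce per communication group
        s = sum(values[r] for r in group)
        for r in group:
            values[r] = s

rank_of = lambda mp, sh, pp: mp + 2 * sh + 4 * pp
mp_groups = [[rank_of(mp, sh, pp) for mp in range(2)] for sh, pp in product(range(2), range(2))]
sh_groups = [[rank_of(mp, sh, pp) for sh in range(2)] for mp, pp in product(range(2), range(2))]
pp_groups = [[rank_of(mp, sh, pp) for pp in range(2)] for mp, sh in product(range(2), range(2))]

local_sq_sums = [float(r + 1) for r in range(8)]   # pretend per-rank sums of squared grads
expected_global = sum(local_sq_sums)

rings = [mp_groups, sh_groups, pp_groups, None]    # None plays the role of ring_id == -1
for ring in rings:
    if ring is None:
        continue                                   # disabled parallelism dimension is skipped
    allreduce_sum(local_sq_sums, ring)

assert all(math.isclose(v, expected_global) for v in local_sq_sums)
print(math.sqrt(local_sq_sums[0]))                 # the global gradient norm on every rank
```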
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
        # if not use sharding, adapt amp/clip, for remain parallelism.
        # cast --> amp --> clip --> opt
        if self.sharding_degree <= 1:
+            # FIXME(wangxi): mp should prune duplicated param_grads when calc
+            # amp inf_var & clip global_norm_var
+
            # amp
-            FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
+            FP16Utils.sync_amp_check_nan_inf(
+                main_block, [self.mp_ring_id, self.pp_ring_id])

            # clip
-            gradientclip_helper = GradientClipHelper(self.global_ring_id)
+            gradientclip_helper = GradientClipHelper(None)
            gradientclip_helper.sync_global_norm(
-                main_block, self.global_ring_id, self.dp_degree)
+                main_block, [self.mp_ring_id, self.pp_ring_id])

        # step6: loss div dp_degree
        global_dp_degree = self.sharding_degree * self.dp_degree
@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            pp_rank,
            ring_id,
            False,
-            global_ring_id=self.global_ring_id,
            sync=False)

    def _init_npu_pipeline_comm(self, startup_block):
@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
        pair = send_to_next_pair if even else recv_from_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair0(even->odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))
@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
        pair = recv_from_next_pair if even else send_to_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
        my_pair.remove(pair)
        logger.info("pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair,
                                                                       ring_id))
@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                pair[0] * 1000 + pair[1],
                max_ring_id + 1)  # 3->0 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info("pair2(odd->even): pp pair:{}, ring_id: {}".format(
@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
                pair[0] * 1000 + pair[1],
                max_ring_id + 2)  # 0->3 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info("pair3(odd<-even): pp pair:{}, ring_id: {}".format(
@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
            self.pp_rank_, self.pp_rank)

+        self._collective_helper._init_communicator(
+            self._startup_program,
+            self.current_endpoint,
+            self.pp_group_endpoints,
+            self.pp_rank,
+            self.pp_ring_id,
+            False,
+            sync=False)
+
        if core.is_compiled_with_npu():
            self._init_npu_pipeline_comm(startup_block)
            return
@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
        logger.info("pp pair:{}, ring_id: {}".format(pair, ring_id))
        if self.pp_rank in pair:
            self._init_pair_comm(pair, ring_id)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

    def _init_comm(self):
@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            dtype=core.VarDesc.VarType.INT32,
            persistable=False)

-        # global ring
-        self._collective_helper._init_communicator(
-            self._startup_program,
-            self.current_endpoint,
-            self.global_endpoints,
-            self.global_rank,
-            self.global_ring_id,
-            False,
-            global_ring_id=self.global_ring_id,
-            sync=False)
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)
-
        # mp ring
        if self.mp_degree > 1:
            self._collective_helper._init_communicator(
@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.mp_rank,
                self.mp_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        # sharding ring
        if self.sharding_degree > 1:
@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.sharding_rank,
                self.sharding_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        # pp ring
        if self.pp_degree > 1:
@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.dp_rank,
                self.dp_ring_id,
                False,
-                global_ring_id=self.global_ring_id,
                sync=False)
-            append_naive_sync(startup_block, self.startup_prog_sync_var,
-                              self.global_ring_id)

        startup_block._sync_with_cpp()
@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
        """
        weightdecay_helper = WeightDecayHelper()
        weightdecay_helper.prune_weight_decay(block, self._shard)
+
+        # FIXME(wangxi): mp should prune duplicated param_grads
        # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
        # group. and each Data Parallelism group should have its own sync of FoundInfinite
        # amp could use global group for sync
-        FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param,
-                             self.global_ring_id)
+        FP16Utils.prune_fp16(
+            block, self._shard, self._reduced_grads_to_param,
+            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])

        # clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
-        if self.mp_degree * self.pp_degree == 1:
-            # separate the sharding-hybrid senario to keep the accuracy
-            gradientclip_helper = GradientClipHelper(self.sharding_ring_id)
-            gradientclip_helper.prune_gradient_clip(
-                block, self._shard, pure_dp_degree=1)
-        else:
-            gradientclip_helper = GradientClipHelper(self.global_ring_id)
-            gradientclip_helper.prune_gradient_clip(
-                block, self._shard, pure_dp_degree=self.dp_degree)
+        gradientclip_helper = GradientClipHelper(None)
+        gradientclip_helper.prune_gradient_clip(
+            block, self._shard,
+            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id])

        # build prog deps
        reduced_grads = []
@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        # pp
        if self.pp_degree > 1:
-            self.pp_ring_id = 20
+            self.pp_pair_ring_id = 20
+            # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
+            self.pp_ring_id = 4
            self.pp_rank = self.global_rank // (self.sharding_degree *
                                                self.mp_degree) % self.pp_degree
            # (NOTE): Already adjust for (outter-pure) dp
@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                    pp_first_stage_idx + pp_stage_offset * i])
            assert self.current_endpoint in self.pp_group_endpoints
        else:
-            self.pp_degree = 1
            self.pp_ring_id = -1
+            self.pp_degree = 1
+            self.pp_pair_ring_id = -1
            self.pp_rank = -1
            self.pp_group_id = -1
            self.pp_group_endpoints = []
@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
            outputs={'Out': params},
            attrs={'ring_id': self.dp_ring_id,
                   OP_ROLE_KEY: OpRole.Forward})
-        # sync within global group
-        append_naive_sync(startup_block, self.startup_prog_sync_var,
-                          self.global_ring_id)

    # sharding gradient merge
    def create_persistable_gradients_and_insert_merge_ops(
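As a reading aid, the ring-id layout implied by this diff can be collected into a small sketch. The sharding/dp/global ids come from the test's setUp below, the pipeline ring 4 and the 20-base pair rings from the optimizer change above; the mp id and the helper function are assumptions for illustration only, not the optimizer's actual bookkeeping.

```python
RING_IDS = {
    "mp": 0,             # assumption: not shown in this diff
    "sharding": 1,       # from TestFleetShardingHybridOptimizer.setUp
    "dp": 2,             # from TestFleetShardingHybridOptimizer.setUp
    "global": 3,         # from TestFleetShardingHybridOptimizer.setUp
    "pp": 4,             # new in this commit: pipeline gets its own ring instead of the global ring
    "pp_pair_base": 20,  # pairwise pipeline send/recv rings (pp_pair_ring_id)
}

def rings_for_helpers(mp_degree, sharding_degree, pp_degree):
    """Hypothetical analogue of the [mp_ring_id, sharding_ring_id, pp_ring_id] lists
    now passed to FP16Utils.prune_fp16 and GradientClipHelper; a ring id of -1 marks
    a parallelism dimension that is disabled, and the helpers simply skip it."""
    return [
        RING_IDS["mp"] if mp_degree > 1 else -1,
        RING_IDS["sharding"] if sharding_degree > 1 else -1,
        RING_IDS["pp"] if pp_degree > 1 else -1,
    ]

print(rings_for_helpers(mp_degree=1, sharding_degree=2, pp_degree=2))  # [-1, 1, 4]
```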
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
56b7ebbc
...
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self
.
set_strategy
(
strategy
,
'sharding'
)
self
.
set_strategy
(
strategy
,
'sharding'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
parameters
=
[
parameters
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
==
True
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
is
True
]
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
vars
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()]
vars
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()]
...
@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
         ])


-class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
+class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
     def setUp(self):
         os.environ["PADDLE_TRAINER_ID"] = "3"
         os.environ[
...
@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         self.sharding_ring_id = 1
         self.dp_ring_id = 2
         self.global_ring_id = 3
-        self.pp_ring_id = 20
+        self.pp_pair_ring_id = 20

     def test_sharding_with_mp(self):
         # NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
...
@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
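The renamed communicator-id variables (nccl_id_N becoming comm_id_N) are checked the same way in several hunks below; a small helper of the same shape, written against the op/desc accessors the test already uses, could collapse that repetition. The helper name is hypothetical, not part of the commit.

def comm_id_to_endpoints(startup_prog_ops):
    """Map each generated comm id var (e.g. 'comm_id_0') to its peer endpoints."""
    mapping = {}
    for op in startup_prog_ops:
        if op.type == "c_gen_nccl_id":
            mapping[op.desc.output_arg_names()[0]] = op.desc.attr("other_endpoints")
    return mapping

# e.g. comm_id_to_endpoints(startup_prog.global_block().ops).get("comm_id_0")
# should give ['127.0.0.1:36003'] for the sharding group in this test.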
...
@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
...
@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
...
@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
...
@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
...
@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
...
@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
             'fill_constant', 'uniform_random', 'fill_constant',
             'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
             'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
-            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
-            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
-            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream',
-            'c_gen_nccl_id', 'c_comm_init', 'fill_constant', 'c_allreduce_sum',
-            'c_sync_calc_stream', 'c_gen_nccl_id', 'c_comm_init',
-            'fill_constant', 'c_allreduce_sum', 'c_sync_calc_stream'
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init'
         ])

         self.assertEqual(main_prog_op_types, [
...
@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
             if op.type == "c_comm_init"
         ]
         self.assertIn(self.sharding_ring_id, created_ring_ids)
-        self.assertIn(self.pp_ring_id, created_ring_ids)
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)

         # check correctness of pp group
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_1":
+                    0] == "comm_id_0":
                 sharding_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(sharding_group_waiting_ports, ['127.0.0.1:36003'])
...
@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
         sharding_group_waiting_port = None
         for op in startup_prog_ops:
             if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
-                    0] == "nccl_id_2":
+                    0] == "comm_id_1":
                 dp_group_waiting_ports = op.desc.attr("other_endpoints")

         self.assertEqual(dp_group_waiting_ports, ['127.0.0.1:36002'])
...
@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
             if op.type == 'c_allreduce_sum':
                 assert 'FusedOutput' in op.input_arg_names[0]

+    def test_hybrid_with_mp_pp_amp_gclip(self):
+        train_prog, startup_prog = paddle.fluid.Program(), paddle.fluid.Program(
+        )
+        avg_cost, strategy = self.pp_net(train_prog, startup_prog)
+        self.set_strategy(strategy, 'amp')
+        strategy.sharding = True
+        strategy.sharding_configs = {
+            "sharding_degree": 1,
+            "mp_degree": 2,
+            "pp_degree": 2,
+            "dp_degree": 1,
+        }
+        strategy.pipeline = True
+        strategy.pipeline_configs = {
+            "schedule_mode": "1F1B",
+            "micro_batch_size": 2,
+            "accumulate_steps": 4,
+        }
+        clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
+        self.optimizer(
+            avg_cost, strategy, train_prog, startup_prog, grad_clip=clip)
+        train_prog = train_prog._pipeline_opt['section_program']
+        startup_prog = startup_prog._pipeline_opt['startup_program']
+
+        startup_prog_ops = startup_prog.global_block().ops
+        main_prog_ops = train_prog.global_block().ops
+
+        # check program
+        startup_prog_op_types = [op.type for op in startup_prog_ops]
+        main_prog_op_types = [op.type for op in main_prog_ops]
+
+        # ring: mp, pp_group, pp_pair, pp_pair
+        self.assertEqual(startup_prog_op_types, [
+            'uniform_random', 'fill_constant', 'uniform_random',
+            'fill_constant', 'uniform_random', 'fill_constant',
+            'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
+            'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
+            'c_gen_nccl_id', 'c_comm_init'
+        ])
+
+        # pp + mp, partial send recv
+        self.assertIn('partial_recv', main_prog_op_types)
+        self.assertIn('partial_allgather', main_prog_op_types)
+        self.assertIn('partial_send', main_prog_op_types)
+
+        # amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_max'), 2)
+
+        # global gradient clip, allreduce(mp)->allreduce(pp)
+        self.assertEqual(main_prog_op_types.count('c_allreduce_sum'), 2)
+
+        # should has ring id for pp
+        created_ring_ids = [
+            op.desc.attr("ring_id") for op in startup_prog_ops
+            if op.type == "c_comm_init"
+        ]
+        self.assertIn(self.mp_ring_id, created_ring_ids)
+        self.assertIn(self.pp_pair_ring_id, created_ring_ids)
+
+        # check correctness of pp group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_0":
+                mp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(mp_group_waiting_ports, ['127.0.0.1:36003'])
+
+        # check correctness of sharding group
+        sharding_group_waiting_port = None
+        for op in startup_prog_ops:
+            if op.type == "c_gen_nccl_id" and op.desc.output_arg_names()[
+                    0] == "comm_id_1":
+                pp_group_waiting_ports = op.desc.attr("other_endpoints")
+
+        self.assertEqual(pp_group_waiting_ports, ['127.0.0.1:36002'])
+

 if __name__ == "__main__":
     unittest.main()