Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
56b7ebbc
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
56b7ebbc
编写于
8月 03, 2021
作者:
W
WangXi
提交者:
GitHub
8月 03, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[hybrid] remove the using of global ring in hybrid parallel (#34525)
上级
9b6c7eb9
变更
15
显示空白变更内容
内联
并排
Showing
15 changed file
with
329 addition
and
247 deletion
+329
-247
paddle/fluid/imperative/bkcl_context.cc
paddle/fluid/imperative/bkcl_context.cc
+2
-2
paddle/fluid/imperative/nccl_context.cc
paddle/fluid/imperative/nccl_context.cc
+2
-2
paddle/fluid/operators/collective/c_comm_init_op.cc
paddle/fluid/operators/collective/c_comm_init_op.cc
+40
-44
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
+5
-4
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
+5
-4
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
+5
-4
paddle/fluid/platform/collective_helper.cc
paddle/fluid/platform/collective_helper.cc
+4
-4
paddle/fluid/platform/collective_helper.h
paddle/fluid/platform/collective_helper.h
+4
-4
paddle/fluid/platform/gen_comm_id_helper.cc
paddle/fluid/platform/gen_comm_id_helper.cc
+55
-21
paddle/fluid/platform/gen_comm_id_helper.h
paddle/fluid/platform/gen_comm_id_helper.h
+3
-3
python/paddle/distributed/fleet/meta_optimizers/common.py
python/paddle/distributed/fleet/meta_optimizers/common.py
+9
-15
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
...distributed/fleet/meta_optimizers/sharding/fp16_helper.py
+46
-25
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
...ed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
+24
-45
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+31
-52
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
...uid/tests/unittests/test_fleet_sharding_meta_optimizer.py
+94
-18
未找到文件。
paddle/fluid/imperative/bkcl_context.cc
浏览文件 @
56b7ebbc
...
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
...
@@ -92,7 +92,7 @@ void BKCLParallelContext::Init() {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" ring id: "
<<
ring_id
;
<<
" ring id: "
<<
ring_id
;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform
::
BKCLCommContext
::
Instance
().
Create
BKCL
Comm
(
platform
::
BKCLCommContext
::
Instance
().
CreateComm
(
&
bkcl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
&
bkcl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
ring_id
);
ring_id
);
}
}
...
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
...
@@ -116,7 +116,7 @@ void BKCLParallelContext::InitWithRingID(int ring_id) {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" xpu id: "
<<
xpu_id
<<
" ring id: "
<<
ring_id
;
<<
" ring id: "
<<
ring_id
;
// it will assign bkcl_comm in XPUDeviceContext within ring_id
// it will assign bkcl_comm in XPUDeviceContext within ring_id
platform
::
BKCLCommContext
::
Instance
().
Create
BKCL
Comm
(
platform
::
BKCLCommContext
::
Instance
().
CreateComm
(
&
bkcl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
ring_id
);
&
bkcl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
xpu_id
,
ring_id
);
}
}
...
...
paddle/fluid/imperative/nccl_context.cc
浏览文件 @
56b7ebbc
...
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
...
@@ -75,7 +75,7 @@ void NCCLParallelContext::Init() {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" ring id: "
<<
ring_id
;
<<
" ring id: "
<<
ring_id
;
// it will assign nccl_comm in CUDADeviceContext within ring_id
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform
::
NCCLCommContext
::
Instance
().
Create
NCCL
Comm
(
platform
::
NCCLCommContext
::
Instance
().
CreateComm
(
&
nccl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
&
nccl_ids
[
ring_id
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
ring_id
);
ring_id
);
...
@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
...
@@ -108,7 +108,7 @@ void NCCLParallelContext::InitWithRingID(int ring_id) {
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" local rank: "
<<
strategy_
.
local_rank_
<<
" gpu id: "
<<
gpu_id
<<
" ring id: "
<<
ring_id
;
<<
" ring id: "
<<
ring_id
;
// it will assign nccl_comm in CUDADeviceContext within ring_id
// it will assign nccl_comm in CUDADeviceContext within ring_id
platform
::
NCCLCommContext
::
Instance
().
Create
NCCL
Comm
(
platform
::
NCCLCommContext
::
Instance
().
CreateComm
(
&
nccl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
ring_id
);
&
nccl_ids
[
0
],
strategy_
.
nranks_
,
strategy_
.
local_rank_
,
gpu_id
,
ring_id
);
compute_events_
.
emplace_back
(
platform
::
CudaEventResourcePool
::
Instance
().
New
(
compute_events_
.
emplace_back
(
platform
::
CudaEventResourcePool
::
Instance
().
New
(
...
...
paddle/fluid/operators/collective/c_comm_init_op.cc
浏览文件 @
56b7ebbc
...
@@ -24,15 +24,16 @@ limitations under the License. */
...
@@ -24,15 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
Scope
;
class
Scope
;
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
#include "paddle/fluid/platform/collective_helper.h"
#endif
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
...
@@ -46,56 +47,51 @@ class CCommInitOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
place
)
const
override
{
const
platform
::
Place
&
place
)
const
override
{
// TODO(wangxi): Put this in the unified header file
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
using
UniqueId
=
ncclUniqueId
;
using
Place
=
platform
::
CUDAPlace
;
using
CommContext
=
platform
::
NCCLCommContext
;
#elif defined(PADDLE_WITH_XPU_BKCL)
using
UniqueId
=
BKCLUniqueId
;
using
Place
=
platform
::
XPUPlace
;
using
CommContext
=
platform
::
BKCLCommContext
;
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with GPU or XPU."
));
#endif
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
)
||
is_xpu_place
(
place
),
true
,
PADDLE_ENFORCE_EQ
(
is_gpu_place
(
place
)
||
is_xpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
platform
::
errors
::
PreconditionNotMet
(
"CCommInitOp can run on gpu or xpu place only."
));
"CCommInitOp can run on gpu or xpu place only."
));
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_XPU_BKCL)
auto
var
=
scope
.
FindVar
(
Input
(
"X"
));
auto
var
=
scope
.
FindVar
(
Input
(
"X"
));
PADDLE_ENFORCE_NOT_NULL
(
PADDLE_ENFORCE_NOT_NULL
(
var
,
platform
::
errors
::
InvalidArgument
(
"Input con not be empty."
));
var
,
platform
::
errors
::
InvalidArgument
(
"Input con not be empty."
));
if
(
is_gpu_place
(
place
))
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ncclUniqueId
*
nccl_id
=
var
->
GetMutable
<
ncclUniqueId
>
();
int
nranks
=
Attr
<
int
>
(
"nranks"
);
UniqueId
*
comm_id
=
var
->
GetMutable
<
UniqueId
>
();
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
int
device_id
=
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place
).
device
;
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
device_id
=
Attr
<
int
>
(
"device_id"
);
}
platform
::
NCCLCommContext
::
Instance
().
CreateNCCLComm
(
nccl_id
,
nranks
,
rank_id
,
device_id
,
rid
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with GPU."
));
#endif
}
else
if
(
is_xpu_place
(
place
))
{
#if defined(PADDLE_WITH_XPU_BKCL)
BKCLUniqueId
*
bkcl_id
=
var
->
GetMutable
<
BKCLUniqueId
>
();
int
nranks
=
Attr
<
int
>
(
"nranks"
);
int
nranks
=
Attr
<
int
>
(
"nranks"
);
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
rank_id
=
Attr
<
int
>
(
"rank"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
int
rid
=
Attr
<
int
>
(
"ring_id"
);
#if defined(PADDLE_WITH_XPU_BKCL)
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
rid
,
0
,
rid
,
0
,
platform
::
errors
::
OutOfRange
(
platform
::
errors
::
OutOfRange
(
"Ring id must equal 0 in multi Kunlun cards training, but got %d"
,
"Ring id must equal 0 in multi Kunlun cards training, but got %d"
,
rid
));
rid
));
int
device_id
=
BOOST_GET_CONST
(
platform
::
XPUPlace
,
place
).
device
;
#endif
int
device_id
=
BOOST_GET_CONST
(
Place
,
place
).
device
;
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
if
(
Attr
<
int
>
(
"device_id"
)
>=
0
)
{
device_id
=
Attr
<
int
>
(
"device_id"
);
device_id
=
Attr
<
int
>
(
"device_id"
);
}
}
platform
::
BKCLCommContext
::
Instance
().
CreateBKCLComm
(
CommContext
::
Instance
().
CreateComm
(
comm_id
,
nranks
,
rank_id
,
device_id
,
bkcl_id
,
nranks
,
rank_id
,
device_id
,
rid
);
rid
);
#else
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"PaddlePaddle should be compiled with XPU."
));
#endif
#endif
}
else
{
PADDLE_THROW
(
platform
::
errors
::
PreconditionNotMet
(
"CCommInitOp can run on gpu or xpu place only."
));
}
}
}
};
};
...
...
paddle/fluid/operators/collective/c_gen_bkcl_id_op.cc
浏览文件 @
56b7ebbc
...
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
...
@@ -62,7 +62,7 @@ class CGenBKCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
return
Output
(
"Out"
);
...
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
...
@@ -75,14 +75,13 @@ class CGenBKCLIdOp : public framework::OperatorBase {
GenBKCLID
(
&
bkcl_ids
);
GenBKCLID
(
&
bkcl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
bkcl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
bkcl_ids
,
ring_id
);
}
else
{
}
else
{
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
std
::
string
endpoint
=
Attr
<
std
::
string
>
(
"endpoint"
);
platform
::
RecvBroadCastCommID
(
endpoint
,
&
bkcl_ids
);
platform
::
RecvBroadCastCommID
(
endpoint
,
&
bkcl_ids
,
ring_id
);
}
}
CopyBKCLIDToVar
(
bkcl_ids
,
func
,
scope
);
CopyBKCLIDToVar
(
bkcl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
}
};
};
...
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
...
@@ -108,6 +107,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"(int default 0) "
"The rank of the trainer in distributed training."
)
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
}
};
};
...
...
paddle/fluid/operators/collective/c_gen_hccl_id_op.cc
浏览文件 @
56b7ebbc
...
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
...
@@ -63,7 +63,7 @@ class CGenHCCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
return
Output
(
"Out"
);
...
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
...
@@ -79,13 +79,12 @@ class CGenHCCLIdOp : public framework::OperatorBase {
GenHCCLID
(
&
hccl_ids
);
GenHCCLID
(
&
hccl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
hccl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
hccl_ids
,
ring_id
);
}
else
{
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
hccl_ids
);
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
hccl_ids
,
ring_id
);
}
}
CopyHCCLIDToVar
(
hccl_ids
,
func
,
scope
);
CopyHCCLIDToVar
(
hccl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
}
};
};
...
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
...
@@ -128,6 +127,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"(int default 0) "
"The rank of the trainer in distributed training."
)
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
}
};
};
...
...
paddle/fluid/operators/collective/c_gen_nccl_id_op.cc
浏览文件 @
56b7ebbc
...
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
...
@@ -60,7 +60,7 @@ class CGenNCCLIdOp : public framework::OperatorBase {
void
RunImpl
(
const
framework
::
Scope
&
scope
,
void
RunImpl
(
const
framework
::
Scope
&
scope
,
const
platform
::
Place
&
dev_place
)
const
override
{
const
platform
::
Place
&
dev_place
)
const
override
{
int
rank
=
Attr
<
int
>
(
"rank"
);
int
rank
=
Attr
<
int
>
(
"rank"
);
framework
::
Scope
&
local_scope
=
scope
.
NewScope
(
);
int
ring_id
=
Attr
<
int
>
(
"ring_id"
);
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
std
::
function
<
std
::
string
(
size_t
)
>
func
=
[
&
](
size_t
i
)
->
std
::
string
{
return
Output
(
"Out"
);
return
Output
(
"Out"
);
...
@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
...
@@ -76,13 +76,12 @@ class CGenNCCLIdOp : public framework::OperatorBase {
GenNCCLID
(
&
nccl_ids
);
GenNCCLID
(
&
nccl_ids
);
std
::
vector
<
std
::
string
>
endpoint_list
=
std
::
vector
<
std
::
string
>
endpoint_list
=
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
Attr
<
std
::
vector
<
std
::
string
>>
(
"other_endpoints"
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
nccl_ids
);
platform
::
SendBroadCastCommID
(
endpoint_list
,
&
nccl_ids
,
ring_id
);
}
else
{
}
else
{
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
nccl_ids
);
platform
::
RecvBroadCastCommID
(
server_fd
,
endpoint
,
&
nccl_ids
,
ring_id
);
}
}
CopyNCCLIDToVar
(
nccl_ids
,
func
,
scope
);
CopyNCCLIDToVar
(
nccl_ids
,
func
,
scope
);
scope
.
DeleteScope
(
&
local_scope
);
}
}
};
};
...
@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
...
@@ -123,6 +122,8 @@ For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the ser
"(int default 0) "
"(int default 0) "
"The rank of the trainer in distributed training."
)
"The rank of the trainer in distributed training."
)
.
SetDefault
(
0
);
.
SetDefault
(
0
);
AddAttr
<
int
>
(
"ring_id"
,
"(int default 0) user specified ring id"
)
.
SetDefault
(
0
);
}
}
};
};
...
...
paddle/fluid/platform/collective_helper.cc
浏览文件 @
56b7ebbc
...
@@ -72,7 +72,7 @@ class NCCLCommImpl : public NCCLComm {
...
@@ -72,7 +72,7 @@ class NCCLCommImpl : public NCCLComm {
std
::
shared_ptr
<
platform
::
CudaEventObject
>
comm_event_
;
std
::
shared_ptr
<
platform
::
CudaEventObject
>
comm_event_
;
};
};
NCCLComm
*
NCCLCommContext
::
Create
NCCL
Comm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
NCCLComm
*
NCCLCommContext
::
CreateComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
nccl_id
,
PADDLE_ENFORCE_NOT_NULL
(
nccl_id
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
@@ -225,7 +225,7 @@ class BKCLCommImpl : public BKCLComm {
...
@@ -225,7 +225,7 @@ class BKCLCommImpl : public BKCLComm {
std
::
unique_ptr
<
XPUDeviceContext
>
dev_ctx_
;
std
::
unique_ptr
<
XPUDeviceContext
>
dev_ctx_
;
};
};
BKCLComm
*
BKCLCommContext
::
Create
BKCL
Comm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
BKCLComm
*
BKCLCommContext
::
CreateComm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
ring_id
)
{
int
rank
,
int
dev_id
,
int
ring_id
)
{
PADDLE_ENFORCE_NOT_NULL
(
bkcl_id
,
PADDLE_ENFORCE_NOT_NULL
(
bkcl_id
,
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
...
...
paddle/fluid/platform/collective_helper.h
浏览文件 @
56b7ebbc
...
@@ -72,8 +72,8 @@ class NCCLCommContext {
...
@@ -72,8 +72,8 @@ class NCCLCommContext {
return
comm_ctx
;
return
comm_ctx
;
}
}
NCCLComm
*
Create
NCCLComm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
NCCLComm
*
Create
Comm
(
ncclUniqueId
*
nccl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
dev_id
,
int
ring_id
=
0
);
int
ring_id
=
0
);
void
CreateAllNCCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
void
CreateAllNCCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
...
@@ -274,8 +274,8 @@ class BKCLCommContext {
...
@@ -274,8 +274,8 @@ class BKCLCommContext {
return
comm_ctx
;
return
comm_ctx
;
}
}
BKCLComm
*
Create
BKCLComm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
BKCLComm
*
Create
Comm
(
BKCLUniqueId
*
bkcl_id
,
int
nranks
,
int
rank
,
int
dev_id
,
int
dev_id
,
int
ring_id
=
0
);
int
ring_id
=
0
);
void
CreateAllBKCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
void
CreateAllBKCLComms
(
const
std
::
vector
<
int
>&
dev_ids
,
int
ring_id
=
0
);
...
...
paddle/fluid/platform/gen_comm_id_helper.cc
浏览文件 @
56b7ebbc
...
@@ -42,7 +42,10 @@ namespace platform {
...
@@ -42,7 +42,10 @@ namespace platform {
std
::
once_flag
SocketServer
::
init_flag_
;
std
::
once_flag
SocketServer
::
init_flag_
;
constexpr
char
COMM_HEAD
[]
=
"_pd_gen_comm_id_"
;
struct
CommHead
{
int
version
=
1
;
// unused for now
int
ring_id
=
0
;
};
// Check system calls, such as socket, bind.
// Check system calls, such as socket, bind.
#define CHECK_SYS_CALL(call, name) \
#define CHECK_SYS_CALL(call, name) \
...
@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
...
@@ -188,11 +191,15 @@ int CreateListenSocket(const std::string& ep) {
void
CloseSocket
(
int
fd
)
{
CHECK_SYS_CALL
(
close
(
fd
),
"close"
);
}
void
CloseSocket
(
int
fd
)
{
CHECK_SYS_CALL
(
close
(
fd
),
"close"
);
}
static
int
SocketAccept
(
int
server_fd
,
const
char
*
head
)
{
static
int
SocketAccept
(
int
server_fd
,
const
CommHead
head
)
{
static_assert
(
sizeof
(
CommHead
)
<=
1024
,
"sizeof(CommHead) must <= buffer size"
);
struct
sockaddr_in
client_addr
;
struct
sockaddr_in
client_addr
;
socklen_t
addr_length
=
sizeof
(
client_addr
);
socklen_t
addr_length
=
sizeof
(
client_addr
);
char
buffer
[
1024
]
=
{
0
};
char
buffer
[
1024
]
=
{
0
};
int
conn
=
-
1
;
int
conn
=
-
1
;
const
char
*
phead
=
reinterpret_cast
<
const
char
*>
(
&
head
);
while
(
true
)
{
while
(
true
)
{
CHECK_SYS_CALL_VAL
(
CHECK_SYS_CALL_VAL
(
...
@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
...
@@ -200,8 +207,10 @@ static int SocketAccept(int server_fd, const char* head) {
&
addr_length
),
&
addr_length
),
"accept"
,
conn
);
"accept"
,
conn
);
int
ret_val
=
SocketRecv
(
conn
,
buffer
,
strlen
(
head
));
int
ret_val
=
SocketRecv
(
conn
,
buffer
,
sizeof
(
head
));
if
(
ret_val
>
0
&&
strncmp
(
buffer
,
head
,
strlen
(
head
))
==
0
)
{
if
(
ret_val
>
0
&&
memcmp
(
buffer
,
phead
,
sizeof
(
head
))
==
0
)
{
// send a message to the sender, indicating that the link is correct
CHECK_SYS_CALL
(
SocketSend
(
conn
,
phead
,
sizeof
(
head
)),
"send"
);
break
;
// accept client
break
;
// accept client
}
else
{
}
else
{
VLOG
(
3
)
<<
"socket read failed with ret_val="
<<
ret_val
;
VLOG
(
3
)
<<
"socket read failed with ret_val="
<<
ret_val
;
...
@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
...
@@ -211,7 +220,7 @@ static int SocketAccept(int server_fd, const char* head) {
return
conn
;
return
conn
;
}
}
static
int
ConnectAddr
(
const
std
::
string
&
ep
,
const
char
*
head
)
{
static
int
ConnectAddr
(
const
std
::
string
&
ep
,
const
CommHead
head
)
{
auto
addr
=
paddle
::
string
::
Split
(
ep
,
':'
);
auto
addr
=
paddle
::
string
::
Split
(
ep
,
':'
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
addr
.
size
(),
2UL
,
addr
.
size
(),
2UL
,
...
@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
...
@@ -220,9 +229,6 @@ static int ConnectAddr(const std::string& ep, const char* head) {
std
::
string
host
=
addr
[
0
];
std
::
string
host
=
addr
[
0
];
int
port
=
std
::
stoi
(
addr
[
1
]);
int
port
=
std
::
stoi
(
addr
[
1
]);
int
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
struct
sockaddr_in
server_addr
;
struct
sockaddr_in
server_addr
;
memset
(
&
server_addr
,
0
,
sizeof
(
server_addr
));
memset
(
&
server_addr
,
0
,
sizeof
(
server_addr
));
server_addr
.
sin_family
=
AF_INET
;
server_addr
.
sin_family
=
AF_INET
;
...
@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
...
@@ -245,10 +251,18 @@ static int ConnectAddr(const std::string& ep, const char* head) {
platform
::
errors
::
Unavailable
(
"Open address %s failed: %s"
,
platform
::
errors
::
Unavailable
(
"Open address %s failed: %s"
,
ep
,
strerror
(
errno
)));
ep
,
strerror
(
errno
)));
static_assert
(
sizeof
(
CommHead
)
<=
1024
,
"sizeof(CommHead) must <= buffer size"
);
char
buffer
[
1024
]
=
{
0
};
const
char
*
phead
=
reinterpret_cast
<
const
char
*>
(
&
head
);
// TODO(wangxi) Set from env, default 900s=15min
// TODO(wangxi) Set from env, default 900s=15min
int
timeout
=
900
*
1000
;
int
timeout
=
900
*
1000
;
int
try_times
=
0
;
int
try_times
=
0
;
int
total_time
=
0
;
int
total_time
=
0
;
int
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
while
(
true
)
{
while
(
true
)
{
int
ret_val
=
-
1
;
int
ret_val
=
-
1
;
RETRY_SYS_CALL_VAL
(
RETRY_SYS_CALL_VAL
(
...
@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
...
@@ -260,8 +274,19 @@ static int ConnectAddr(const std::string& ep, const char* head) {
continue
;
continue
;
}
}
CHECK_SYS_CALL
(
SocketSend
(
sock
,
head
,
strlen
(
head
)),
"send"
);
CHECK_SYS_CALL
(
SocketSend
(
sock
,
phead
,
sizeof
(
head
)),
"send"
);
break
;
ret_val
=
SocketRecv
(
sock
,
buffer
,
sizeof
(
head
));
if
(
ret_val
>
0
&&
memcmp
(
buffer
,
phead
,
sizeof
(
head
))
==
0
)
{
// recv same message from recver, indicating that the link is correct
break
;
// accept client
}
else
{
VLOG
(
3
)
<<
"socket read failed with ret_val="
<<
ret_val
;
CloseSocket
(
sock
);
}
sock
=
-
1
;
CHECK_SYS_CALL_VAL
(
socket
(
AF_INET
,
SOCK_STREAM
,
0
),
"socket"
,
sock
);
// unmatched link, retry after 80ms
std
::
this_thread
::
sleep_for
(
std
::
chrono
::
milliseconds
(
80
));
}
}
return
sock
;
return
sock
;
}
}
...
@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
...
@@ -295,12 +320,15 @@ static void SendCommID(int conn, CommUniqueId* nccl_id) {
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
CommHead
head
;
head
.
ring_id
=
ring_id
;
// connect with server
// connect with server
std
::
vector
<
int
>
connects
;
std
::
vector
<
int
>
connects
;
for
(
auto
server
:
servers
)
{
for
(
auto
server
:
servers
)
{
VLOG
(
3
)
<<
"connecting endpoint: "
<<
server
;
VLOG
(
3
)
<<
"connecting endpoint: "
<<
server
;
int
conn
=
ConnectAddr
(
server
,
COMM_HEAD
);
int
conn
=
ConnectAddr
(
server
,
head
);
connects
.
push_back
(
conn
);
connects
.
push_back
(
conn
);
}
}
VLOG
(
3
)
<<
"connecting completed..."
;
VLOG
(
3
)
<<
"connecting completed..."
;
...
@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
...
@@ -322,16 +350,18 @@ void SendBroadCastCommID(std::vector<std::string> servers,
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
int
server
=
CreateListenSocket
(
endpoint
);
int
server
=
CreateListenSocket
(
endpoint
);
RecvBroadCastCommID
(
server
,
endpoint
,
nccl_ids
);
RecvBroadCastCommID
(
server
,
endpoint
,
nccl_ids
,
ring_id
);
CloseSocket
(
server
);
CloseSocket
(
server
);
}
}
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
)
{
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
)
{
int
client
=
SocketAccept
(
server_fd
,
COMM_HEAD
);
CommHead
head
;
head
.
ring_id
=
ring_id
;
int
client
=
SocketAccept
(
server_fd
,
head
);
for
(
size_t
i
=
0
;
i
<
nccl_ids
->
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nccl_ids
->
size
();
++
i
)
{
VLOG
(
3
)
<<
"trainer: "
<<
endpoint
VLOG
(
3
)
<<
"trainer: "
<<
endpoint
...
@@ -362,9 +392,13 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
...
@@ -362,9 +392,13 @@ SocketServer& SocketServer::GetInstance(const std::string& end_point) {
/// template instantiation
/// template instantiation
#define INSTANT_TEMPLATE(Type) \
#define INSTANT_TEMPLATE(Type) \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
template void SendBroadCastCommID<Type>(std::vector<std::string> servers, \
std::vector<Type> * nccl_ids); \
std::vector<Type> * nccl_ids, \
template void RecvBroadCastCommID<Type>(std::string endpoint, \
int ring_id = 0); \
std::vector<Type> * nccl_ids);
template void RecvBroadCastCommID<Type>( \
std::string endpoint, std::vector<Type> * nccl_ids, int ring_id = 0); \
template void RecvBroadCastCommID<Type>(int server_fd, std::string endpoint, \
std::vector<Type>* nccl_ids, \
int ring_id = 0);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
INSTANT_TEMPLATE
(
ncclUniqueId
)
INSTANT_TEMPLATE
(
ncclUniqueId
)
...
...
paddle/fluid/platform/gen_comm_id_helper.h
浏览文件 @
56b7ebbc
...
@@ -31,16 +31,16 @@ void CloseSocket(int fd);
...
@@ -31,16 +31,16 @@ void CloseSocket(int fd);
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
void
SendBroadCastCommID
(
std
::
vector
<
std
::
string
>
servers
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
void
RecvBroadCastCommID
(
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
// recv nccl id from socket
// recv nccl id from socket
template
<
typename
CommUniqueId
>
template
<
typename
CommUniqueId
>
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
void
RecvBroadCastCommID
(
int
server_fd
,
std
::
string
endpoint
,
std
::
vector
<
CommUniqueId
>*
nccl_ids
);
std
::
vector
<
CommUniqueId
>*
nccl_ids
,
int
ring_id
=
0
);
class
SocketServer
{
class
SocketServer
{
public:
public:
...
...
python/paddle/distributed/fleet/meta_optimizers/common.py
浏览文件 @
56b7ebbc
...
@@ -126,11 +126,11 @@ class CollectiveHelper(object):
...
@@ -126,11 +126,11 @@ class CollectiveHelper(object):
_add_sync_by_allreduce
(
block
)
_add_sync_by_allreduce
(
block
)
return
return
if
core
.
is_compiled_with_cuda
():
comm_id_var
=
block
.
create_var
(
comm_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'nccl
_id'
),
name
=
unique_name
.
generate
(
'comm
_id'
),
persistable
=
True
,
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
type
=
core
.
VarDesc
.
VarType
.
RAW
)
if
core
.
is_compiled_with_cuda
():
block
.
append_op
(
block
.
append_op
(
type
=
'c_gen_nccl_id'
,
type
=
'c_gen_nccl_id'
,
inputs
=
{},
inputs
=
{},
...
@@ -139,6 +139,7 @@ class CollectiveHelper(object):
...
@@ -139,6 +139,7 @@ class CollectiveHelper(object):
'rank'
:
rank
,
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
OP_ROLE_KEY
:
OpRole
.
Forward
})
})
block
.
append_op
(
block
.
append_op
(
...
@@ -152,10 +153,6 @@ class CollectiveHelper(object):
...
@@ -152,10 +153,6 @@ class CollectiveHelper(object):
OP_ROLE_KEY
:
OpRole
.
Forward
OP_ROLE_KEY
:
OpRole
.
Forward
})
})
elif
core
.
is_compiled_with_xpu
():
elif
core
.
is_compiled_with_xpu
():
comm_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'bkcl_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
block
.
append_op
(
block
.
append_op
(
type
=
'c_gen_bkcl_id'
,
type
=
'c_gen_bkcl_id'
,
inputs
=
{},
inputs
=
{},
...
@@ -164,6 +161,7 @@ class CollectiveHelper(object):
...
@@ -164,6 +161,7 @@ class CollectiveHelper(object):
'rank'
:
rank
,
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
OP_ROLE_KEY
:
OpRole
.
Forward
})
})
block
.
append_op
(
block
.
append_op
(
...
@@ -177,24 +175,20 @@ class CollectiveHelper(object):
...
@@ -177,24 +175,20 @@ class CollectiveHelper(object):
OP_ROLE_KEY
:
OpRole
.
Forward
OP_ROLE_KEY
:
OpRole
.
Forward
})
})
elif
core
.
is_compiled_with_npu
():
elif
core
.
is_compiled_with_npu
():
hccl_id_var
=
block
.
create_var
(
name
=
unique_name
.
generate
(
'hccl_id'
),
persistable
=
True
,
type
=
core
.
VarDesc
.
VarType
.
RAW
)
endpoint_to_index_map
=
{
e
:
idx
for
idx
,
e
in
enumerate
(
endpoints
)}
block
.
append_op
(
block
.
append_op
(
type
=
'c_gen_hccl_id'
,
type
=
'c_gen_hccl_id'
,
inputs
=
{},
inputs
=
{},
outputs
=
{
'Out'
:
hccl
_id_var
},
outputs
=
{
'Out'
:
comm
_id_var
},
attrs
=
{
attrs
=
{
'rank'
:
rank
,
'rank'
:
rank
,
'endpoint'
:
current_endpoint
,
'endpoint'
:
current_endpoint
,
'other_endpoints'
:
other_endpoints
,
'other_endpoints'
:
other_endpoints
,
'ring_id'
:
ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
OP_ROLE_KEY
:
OpRole
.
Forward
})
})
block
.
append_op
(
block
.
append_op
(
type
=
'c_comm_init_hccl'
,
type
=
'c_comm_init_hccl'
,
inputs
=
{
'X'
:
hccl
_id_var
},
inputs
=
{
'X'
:
comm
_id_var
},
outputs
=
{},
outputs
=
{},
attrs
=
{
attrs
=
{
'rank'
:
rank
,
'rank'
:
rank
,
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py
浏览文件 @
56b7ebbc
...
@@ -73,7 +73,7 @@ class FP16Utils(object):
...
@@ -73,7 +73,7 @@ class FP16Utils(object):
return
inserted_op_num
return
inserted_op_num
@
staticmethod
@
staticmethod
def
prune_fp16
(
block
,
shard
,
reduced_grads_to_param
,
ring_id
):
def
prune_fp16
(
block
,
shard
,
reduced_grads_to_param
,
ring_id
s
):
"""
"""
1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard
2. revise amp inifine grad checking for sharding
2. revise amp inifine grad checking for sharding
...
@@ -146,6 +146,7 @@ class FP16Utils(object):
...
@@ -146,6 +146,7 @@ class FP16Utils(object):
name
=
inf_var_name
+
"@sharding"
,
name
=
inf_var_name
+
"@sharding"
,
shape
=
inf_var
.
shape
,
shape
=
inf_var
.
shape
,
dtype
=
inf_var
.
dtype
)
dtype
=
inf_var
.
dtype
)
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
,
update_loss_scaling_op_idx
,
type
=
'cast'
,
type
=
'cast'
,
...
@@ -156,9 +157,14 @@ class FP16Utils(object):
...
@@ -156,9 +157,14 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_int32
.
dtype
,
"out_dtype"
:
inf_var_int32
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
# this allreduce communication should not overlap with calc
# this allreduce communication should not overlap with calc
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
1
,
update_loss_scaling_op_idx
,
type
=
'c_allreduce_max'
,
type
=
'c_allreduce_max'
,
inputs
=
{
'X'
:
inf_var_int32
},
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
...
@@ -167,8 +173,10 @@ class FP16Utils(object):
...
@@ -167,8 +173,10 @@ class FP16Utils(object):
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
2
,
update_loss_scaling_op_idx
,
type
=
'cast'
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_sharding
},
outputs
=
{
'Out'
:
inf_var_sharding
},
...
@@ -177,11 +185,12 @@ class FP16Utils(object):
...
@@ -177,11 +185,12 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_sharding
.
dtype
,
"out_dtype"
:
inf_var_sharding
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@
staticmethod
@
staticmethod
def
sync_amp_check_nan_inf
(
block
,
ring_id
):
def
sync_amp_check_nan_inf
(
block
,
ring_id
s
):
update_loss_scaling_op_idx
=
-
1
update_loss_scaling_op_idx
=
-
1
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
...
@@ -189,10 +198,14 @@ class FP16Utils(object):
...
@@ -189,10 +198,14 @@ class FP16Utils(object):
update_loss_scaling_op_idx
=
idx
update_loss_scaling_op_idx
=
idx
inf_var_name
=
op
.
desc
.
input
(
'FoundInfinite'
)[
0
]
inf_var_name
=
op
.
desc
.
input
(
'FoundInfinite'
)[
0
]
op
.
_rename_input
(
inf_var_name
,
inf_var_name
+
"@GLOBAL_WORLD"
)
op
.
_rename_input
(
inf_var_name
,
inf_var_name
+
"@GLOBAL_WORLD"
)
break
# not use amp
# not use amp
if
update_loss_scaling_op_idx
==
-
1
:
if
update_loss_scaling_op_idx
==
-
1
:
return
return
# 0. inf_var_int32 = cast(inf_var)
# 1. inf_var_int32 = allreduce_max(inf_var_int32)
# 3. inf_var = cast(inf_var_int32)
inf_var
=
block
.
var
(
inf_var_name
)
inf_var
=
block
.
var
(
inf_var_name
)
inf_var_int32
=
block
.
create_var
(
inf_var_int32
=
block
.
create_var
(
name
=
inf_var_name
+
"@cast_int32"
,
name
=
inf_var_name
+
"@cast_int32"
,
...
@@ -212,8 +225,13 @@ class FP16Utils(object):
...
@@ -212,8 +225,13 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_int32
.
dtype
,
"out_dtype"
:
inf_var_int32
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
# allreduce(mp)->allreduce(pp)
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
1
,
update_loss_scaling_op_idx
,
type
=
'c_allreduce_max'
,
type
=
'c_allreduce_max'
,
inputs
=
{
'X'
:
inf_var_int32
},
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_int32
},
...
@@ -222,8 +240,10 @@ class FP16Utils(object):
...
@@ -222,8 +240,10 @@ class FP16Utils(object):
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
update_loss_scaling_op_idx
+
2
,
update_loss_scaling_op_idx
,
type
=
'cast'
,
type
=
'cast'
,
inputs
=
{
'X'
:
inf_var_int32
},
inputs
=
{
'X'
:
inf_var_int32
},
outputs
=
{
'Out'
:
inf_var_global
},
outputs
=
{
'Out'
:
inf_var_global
},
...
@@ -232,4 +252,5 @@ class FP16Utils(object):
...
@@ -232,4 +252,5 @@ class FP16Utils(object):
"out_dtype"
:
inf_var_global
.
dtype
,
"out_dtype"
:
inf_var_global
.
dtype
,
OP_ROLE_KEY
:
OpRole
.
Optimize
OP_ROLE_KEY
:
OpRole
.
Optimize
})
})
update_loss_scaling_op_idx
+=
1
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py
浏览文件 @
56b7ebbc
...
@@ -25,7 +25,7 @@ class GradientClipHelper(object):
...
@@ -25,7 +25,7 @@ class GradientClipHelper(object):
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
return
op
.
desc
.
has_attr
(
"op_namescope"
)
\
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip"
)
and
op
.
desc
.
attr
(
"op_namescope"
).
startswith
(
"/gradient_clip"
)
def
prune_gradient_clip
(
self
,
block
,
shard
,
pure_dp_degree
=
1
):
def
prune_gradient_clip
(
self
,
block
,
shard
,
ring_ids
):
"""
"""
prune gradient_clip related ops for params that not belong to cur shard
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
prune: square, reduce_sum, elementwise_mul
...
@@ -82,33 +82,23 @@ class GradientClipHelper(object):
...
@@ -82,33 +82,23 @@ class GradientClipHelper(object):
assert
(
len
(
op
.
desc
.
output_arg_names
())
==
1
)
assert
(
len
(
op
.
desc
.
output_arg_names
())
==
1
)
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
# allreduce(mp)->allreduce(sharding)->allreduce(pp)
idx_offset
=
1
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
# this allreduce should not overlap with calc and should be scheduled in calc stream
# this allreduce should not overlap with calc and should be scheduled in calc stream
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
idx
+
1
,
idx
+
idx_offset
,
type
=
'c_allreduce_sum'
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
attrs
=
{
'ring_id'
:
self
.
mp_
ring_id
,
'ring_id'
:
ring_id
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
})
idx_offset
+=
1
# global norm should only be sum within each model parallelism word size when use global group
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
idx
+
2
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
# the grad sum here should take the all and only param in the current shard
# the grad sum here should take the all and only param in the current shard
to_check_param
=
set
(
reversed_x_paramname
)
to_check_param
=
set
(
reversed_x_paramname
)
...
@@ -126,20 +116,25 @@ class GradientClipHelper(object):
...
@@ -126,20 +116,25 @@ class GradientClipHelper(object):
return
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def
sync_global_norm
(
self
,
block
,
ring_id
,
pure_dp_degree
=
1
):
def
sync_global_norm
(
self
,
block
,
ring_id
s
):
"""
"""
prune gradient_clip related ops for params that not belong to cur shard
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
keep: sum, sqrt, elementwise_max, elementwise_div
"""
"""
# FIXME(wangxi): mp should prune duplicated param_grads
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
self
.
_is_gradient_clip_op
(
op
):
if
not
self
.
_is_gradient_clip_op
(
op
):
continue
continue
if
op
.
type
==
"sum"
:
if
op
.
type
==
"sum"
:
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
sum_res
=
op
.
desc
.
output_arg_names
()[
0
]
for
ring_id
in
ring_ids
:
if
ring_id
==
-
1
:
continue
idx
=
idx
+
1
block
.
_insert_op_without_sync
(
block
.
_insert_op_without_sync
(
idx
+
1
,
idx
,
type
=
'c_allreduce_sum'
,
type
=
'c_allreduce_sum'
,
inputs
=
{
'X'
:
sum_res
},
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
...
@@ -149,20 +144,4 @@ class GradientClipHelper(object):
...
@@ -149,20 +144,4 @@ class GradientClipHelper(object):
'use_calc_stream'
:
True
,
'use_calc_stream'
:
True
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
})
# global norm should only be sum within each model parallelism word size
if
pure_dp_degree
>
1
:
block
.
_insert_op_without_sync
(
idx
+
2
,
type
=
'scale'
,
inputs
=
{
'X'
:
sum_res
},
outputs
=
{
'Out'
:
sum_res
},
attrs
=
{
'scale'
:
1.0
/
float
(
pure_dp_degree
),
'op_namescope'
:
"/gradient_clip_model_parallelism"
,
'bias'
:
0.0
,
'bias_after_scale'
:
False
,
OP_ROLE_KEY
:
OpRole
.
Optimize
})
return
return
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
56b7ebbc
...
@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -328,13 +328,17 @@ class ShardingOptimizer(MetaOptimizerBase):
# if not use sharding, adapt amp/clip, for remain parallelism.
# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
# cast --> amp --> clip --> opt
if
self
.
sharding_degree
<=
1
:
if
self
.
sharding_degree
<=
1
:
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var
# amp
# amp
FP16Utils
.
sync_amp_check_nan_inf
(
main_block
,
self
.
global_ring_id
)
FP16Utils
.
sync_amp_check_nan_inf
(
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
])
# clip
# clip
gradientclip_helper
=
GradientClipHelper
(
self
.
global_ring_id
)
gradientclip_helper
=
GradientClipHelper
(
None
)
gradientclip_helper
.
sync_global_norm
(
gradientclip_helper
.
sync_global_norm
(
main_block
,
self
.
global_ring_id
,
self
.
dp_degree
)
main_block
,
[
self
.
mp_ring_id
,
self
.
pp_ring_id
]
)
# step6: loss div dp_degree
# step6: loss div dp_degree
global_dp_degree
=
self
.
sharding_degree
*
self
.
dp_degree
global_dp_degree
=
self
.
sharding_degree
*
self
.
dp_degree
...
@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -392,7 +396,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_rank
,
pp_rank
,
ring_id
,
ring_id
,
False
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
sync
=
False
)
def
_init_npu_pipeline_comm
(
self
,
startup_block
):
def
_init_npu_pipeline_comm
(
self
,
startup_block
):
...
@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -426,8 +429,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
=
send_to_next_pair
if
even
else
recv_from_prev_pair
pair
=
send_to_next_pair
if
even
else
recv_from_prev_pair
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
self
.
_init_pair_comm
(
pair
,
ring_id
)
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
my_pair
.
remove
(
pair
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair0(even->odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
logger
.
info
(
"pair0(even->odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
ring_id
))
...
@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -436,8 +437,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
=
recv_from_next_pair
if
even
else
send_to_prev_pair
pair
=
recv_from_next_pair
if
even
else
send_to_prev_pair
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
ring_id
=
self
.
pp_ring_map
[
pair
[
0
]
*
1000
+
pair
[
1
]]
self
.
_init_pair_comm
(
pair
,
ring_id
)
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
my_pair
.
remove
(
pair
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair1(even<-odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
logger
.
info
(
"pair1(even<-odd): pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
ring_id
))
...
@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -450,8 +449,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
[
0
]
*
1000
+
pair
[
1
],
pair
[
0
]
*
1000
+
pair
[
1
],
max_ring_id
+
1
)
# 3->0 not in pp_ring_map
max_ring_id
+
1
)
# 3->0 not in pp_ring_map
self
.
_init_pair_comm
(
pair
,
ring_id
)
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
my_pair
.
remove
(
pair
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair2(odd->even): pp pair:{}, ring_id: {}"
.
format
(
logger
.
info
(
"pair2(odd->even): pp pair:{}, ring_id: {}"
.
format
(
...
@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -463,8 +460,6 @@ class ShardingOptimizer(MetaOptimizerBase):
pair
[
0
]
*
1000
+
pair
[
1
],
pair
[
0
]
*
1000
+
pair
[
1
],
max_ring_id
+
2
)
# 0->3 not in pp_ring_map
max_ring_id
+
2
)
# 0->3 not in pp_ring_map
self
.
_init_pair_comm
(
pair
,
ring_id
)
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
if
self
.
pp_rank
!=
0
and
self
.
pp_rank
!=
self
.
pp_degree
-
1
:
my_pair
.
remove
(
pair
)
my_pair
.
remove
(
pair
)
logger
.
info
(
"pair3(odd<-even): pp pair:{}, ring_id: {}"
.
format
(
logger
.
info
(
"pair3(odd<-even): pp pair:{}, ring_id: {}"
.
format
(
...
@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -478,6 +473,15 @@ class ShardingOptimizer(MetaOptimizerBase):
assert
self
.
pp_rank_
==
self
.
pp_rank
,
"pp rank for pp opt [{}], pp rank for sharding opt [{}]"
.
format
(
assert
self
.
pp_rank_
==
self
.
pp_rank
,
"pp rank for pp opt [{}], pp rank for sharding opt [{}]"
.
format
(
self
.
pp_rank_
,
self
.
pp_rank
)
self
.
pp_rank_
,
self
.
pp_rank
)
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
pp_group_endpoints
,
self
.
pp_rank
,
self
.
pp_ring_id
,
False
,
sync
=
False
)
if
core
.
is_compiled_with_npu
():
if
core
.
is_compiled_with_npu
():
self
.
_init_npu_pipeline_comm
(
startup_block
)
self
.
_init_npu_pipeline_comm
(
startup_block
)
return
return
...
@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -489,8 +493,6 @@ class ShardingOptimizer(MetaOptimizerBase):
logger
.
info
(
"pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
logger
.
info
(
"pp pair:{}, ring_id: {}"
.
format
(
pair
,
ring_id
))
if
self
.
pp_rank
in
pair
:
if
self
.
pp_rank
in
pair
:
self
.
_init_pair_comm
(
pair
,
ring_id
)
self
.
_init_pair_comm
(
pair
,
ring_id
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
def
_init_comm
(
self
):
def
_init_comm
(
self
):
...
@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -505,19 +507,6 @@ class ShardingOptimizer(MetaOptimizerBase):
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
dtype
=
core
.
VarDesc
.
VarType
.
INT32
,
persistable
=
False
)
persistable
=
False
)
# global ring
self
.
_collective_helper
.
_init_communicator
(
self
.
_startup_program
,
self
.
current_endpoint
,
self
.
global_endpoints
,
self
.
global_rank
,
self
.
global_ring_id
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# mp ring
# mp ring
if
self
.
mp_degree
>
1
:
if
self
.
mp_degree
>
1
:
self
.
_collective_helper
.
_init_communicator
(
self
.
_collective_helper
.
_init_communicator
(
...
@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -527,10 +516,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
mp_rank
,
self
.
mp_rank
,
self
.
mp_ring_id
,
self
.
mp_ring_id
,
False
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# sharding ring
# sharding ring
if
self
.
sharding_degree
>
1
:
if
self
.
sharding_degree
>
1
:
...
@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -541,10 +527,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
sharding_rank
,
self
.
sharding_rank
,
self
.
sharding_ring_id
,
self
.
sharding_ring_id
,
False
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# pp ring
# pp ring
if
self
.
pp_degree
>
1
:
if
self
.
pp_degree
>
1
:
...
@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -559,10 +542,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self
.
dp_rank
,
self
.
dp_rank
,
self
.
dp_ring_id
,
self
.
dp_ring_id
,
False
,
False
,
global_ring_id
=
self
.
global_ring_id
,
sync
=
False
)
sync
=
False
)
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
startup_block
.
_sync_with_cpp
()
startup_block
.
_sync_with_cpp
()
...
@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -736,21 +716,20 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
"""
weightdecay_helper
=
WeightDecayHelper
()
weightdecay_helper
=
WeightDecayHelper
()
weightdecay_helper
.
prune_weight_decay
(
block
,
self
.
_shard
)
weightdecay_helper
.
prune_weight_decay
(
block
,
self
.
_shard
)
# FIXME(wangxi): mp should prune duplicated param_grads
# NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
# NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism
# group. and each Data Parallelism group should have its own sync of FoundInfinite
# group. and each Data Parallelism group should have its own sync of FoundInfinite
# amp could use global group for sync
# amp could use global group for sync
FP16Utils
.
prune_fp16
(
block
,
self
.
_shard
,
self
.
_reduced_grads_to_param
,
FP16Utils
.
prune_fp16
(
self
.
global_ring_id
)
block
,
self
.
_shard
,
self
.
_reduced_grads_to_param
,
[
self
.
mp_ring_id
,
self
.
sharding_ring_id
,
self
.
pp_ring_id
])
# clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
# clipbyglobalnorm should only use the Model paramllelism group (mp-sharding-pp)
if
self
.
mp_degree
*
self
.
pp_degree
==
1
:
gradientclip_helper
=
GradientClipHelper
(
None
)
# separate the sharding-hybrid senario to keep the accuracy
gradientclip_helper
=
GradientClipHelper
(
self
.
sharding_ring_id
)
gradientclip_helper
.
prune_gradient_clip
(
gradientclip_helper
.
prune_gradient_clip
(
block
,
self
.
_shard
,
pure_dp_degree
=
1
)
block
,
self
.
_shard
,
else
:
[
self
.
mp_ring_id
,
self
.
sharding_ring_id
,
self
.
pp_ring_id
])
gradientclip_helper
=
GradientClipHelper
(
self
.
global_ring_id
)
gradientclip_helper
.
prune_gradient_clip
(
block
,
self
.
_shard
,
pure_dp_degree
=
self
.
dp_degree
)
# build prog deps
# build prog deps
reduced_grads
=
[]
reduced_grads
=
[]
...
@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -1143,7 +1122,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# pp
# pp
if
self
.
pp_degree
>
1
:
if
self
.
pp_degree
>
1
:
self
.
pp_ring_id
=
20
self
.
pp_pair_ring_id
=
20
# pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
self
.
pp_ring_id
=
4
self
.
pp_rank
=
self
.
global_rank
//
(
self
.
sharding_degree
*
self
.
pp_rank
=
self
.
global_rank
//
(
self
.
sharding_degree
*
self
.
mp_degree
)
%
self
.
pp_degree
self
.
mp_degree
)
%
self
.
pp_degree
# (NOTE): Already adjust for (outter-pure) dp
# (NOTE): Already adjust for (outter-pure) dp
...
@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -1159,8 +1140,9 @@ class ShardingOptimizer(MetaOptimizerBase):
pp_first_stage_idx
+
pp_stage_offset
*
i
])
pp_first_stage_idx
+
pp_stage_offset
*
i
])
assert
self
.
current_endpoint
in
self
.
pp_group_endpoints
assert
self
.
current_endpoint
in
self
.
pp_group_endpoints
else
:
else
:
self
.
pp_degree
=
1
self
.
pp_ring_id
=
-
1
self
.
pp_ring_id
=
-
1
self
.
pp_degree
=
1
self
.
pp_pair_ring_id
=
-
1
self
.
pp_rank
=
-
1
self
.
pp_rank
=
-
1
self
.
pp_group_id
=
-
1
self
.
pp_group_id
=
-
1
self
.
pp_group_endpoints
=
[]
self
.
pp_group_endpoints
=
[]
...
@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -1256,9 +1238,6 @@ class ShardingOptimizer(MetaOptimizerBase):
outputs
=
{
'Out'
:
params
},
outputs
=
{
'Out'
:
params
},
attrs
=
{
'ring_id'
:
self
.
dp_ring_id
,
attrs
=
{
'ring_id'
:
self
.
dp_ring_id
,
OP_ROLE_KEY
:
OpRole
.
Forward
})
OP_ROLE_KEY
:
OpRole
.
Forward
})
# sync within global group
append_naive_sync
(
startup_block
,
self
.
startup_prog_sync_var
,
self
.
global_ring_id
)
# sharding gradient merge
# sharding gradient merge
def
create_persistable_gradients_and_insert_merge_ops
(
def
create_persistable_gradients_and_insert_merge_ops
(
...
...
python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py
浏览文件 @
56b7ebbc
...
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -34,7 +34,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self
.
set_strategy
(
strategy
,
'sharding'
)
self
.
set_strategy
(
strategy
,
'sharding'
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
)
parameters
=
[
parameters
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
==
True
x
.
name
for
x
in
train_prog
.
list_vars
()
if
x
.
persistable
is
True
]
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
ops
=
[
op
.
type
for
op
in
avg_cost
.
block
.
ops
]
vars
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()]
vars
=
[
x
.
name
for
x
in
train_prog
.
list_vars
()]
...
@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
...
@@ -292,7 +292,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
])
])
class
TestFleet
MetaOptimizer_V1
(
TestFleetMetaOptimizer
):
class
TestFleet
ShardingHybridOptimizer
(
TestFleetMetaOptimizer
):
def
setUp
(
self
):
def
setUp
(
self
):
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"3"
os
.
environ
[
"PADDLE_TRAINER_ID"
]
=
"3"
os
.
environ
[
os
.
environ
[
...
@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -303,7 +303,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
self
.
sharding_ring_id
=
1
self
.
sharding_ring_id
=
1
self
.
dp_ring_id
=
2
self
.
dp_ring_id
=
2
self
.
global_ring_id
=
3
self
.
global_ring_id
=
3
self
.
pp_ring_id
=
20
self
.
pp_
pair_
ring_id
=
20
def
test_sharding_with_mp
(
self
):
def
test_sharding_with_mp
(
self
):
# NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
# NOTE(JZ-LIANG) MP parallelism need user to build model with MP API
...
@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -336,7 +336,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -345,7 +345,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -381,7 +381,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -390,7 +390,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -450,7 +450,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -459,7 +459,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -530,12 +530,8 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'fill_constant'
,
'c_allreduce_sum'
,
'c_sync_calc_stream'
])
])
self
.
assertEqual
(
main_prog_op_types
,
[
self
.
assertEqual
(
main_prog_op_types
,
[
...
@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -566,13 +562,13 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if
op
.
type
==
"c_comm_init"
if
op
.
type
==
"c_comm_init"
]
]
self
.
assertIn
(
self
.
sharding_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
sharding_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_
pair_
ring_id
,
created_ring_ids
)
# check correctness of pp group
# check correctness of pp group
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_1
"
:
0
]
==
"
comm_id_0
"
:
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
sharding_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
self
.
assertEqual
(
sharding_group_waiting_ports
,
[
'127.0.0.1:36003'
])
...
@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -581,7 +577,7 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
sharding_group_waiting_port
=
None
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"
nccl_id_2
"
:
0
]
==
"
comm_id_1
"
:
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
dp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
self
.
assertEqual
(
dp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
...
@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
...
@@ -616,6 +612,86 @@ class TestFleetMetaOptimizer_V1(TestFleetMetaOptimizer):
if
op
.
type
==
'c_allreduce_sum'
:
if
op
.
type
==
'c_allreduce_sum'
:
assert
'FusedOutput'
in
op
.
input_arg_names
[
0
]
assert
'FusedOutput'
in
op
.
input_arg_names
[
0
]
def
test_hybrid_with_mp_pp_amp_gclip
(
self
):
train_prog
,
startup_prog
=
paddle
.
fluid
.
Program
(),
paddle
.
fluid
.
Program
(
)
avg_cost
,
strategy
=
self
.
pp_net
(
train_prog
,
startup_prog
)
self
.
set_strategy
(
strategy
,
'amp'
)
strategy
.
sharding
=
True
strategy
.
sharding_configs
=
{
"sharding_degree"
:
1
,
"mp_degree"
:
2
,
"pp_degree"
:
2
,
"dp_degree"
:
1
,
}
strategy
.
pipeline
=
True
strategy
.
pipeline_configs
=
{
"schedule_mode"
:
"1F1B"
,
"micro_batch_size"
:
2
,
"accumulate_steps"
:
4
,
}
clip
=
paddle
.
fluid
.
clip
.
GradientClipByGlobalNorm
(
clip_norm
=
1.0
)
self
.
optimizer
(
avg_cost
,
strategy
,
train_prog
,
startup_prog
,
grad_clip
=
clip
)
train_prog
=
train_prog
.
_pipeline_opt
[
'section_program'
]
startup_prog
=
startup_prog
.
_pipeline_opt
[
'startup_program'
]
startup_prog_ops
=
startup_prog
.
global_block
().
ops
main_prog_ops
=
train_prog
.
global_block
().
ops
# check program
startup_prog_op_types
=
[
op
.
type
for
op
in
startup_prog_ops
]
main_prog_op_types
=
[
op
.
type
for
op
in
main_prog_ops
]
# ring: mp, pp_group, pp_pair, pp_pair
self
.
assertEqual
(
startup_prog_op_types
,
[
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'uniform_random'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'fill_constant'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
,
'c_gen_nccl_id'
,
'c_comm_init'
])
# pp + mp, partial send recv
self
.
assertIn
(
'partial_recv'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_allgather'
,
main_prog_op_types
)
self
.
assertIn
(
'partial_send'
,
main_prog_op_types
)
# amp check_finite_and_unscale, allreduce(mp)->allreduce(pp)
self
.
assertEqual
(
main_prog_op_types
.
count
(
'c_allreduce_max'
),
2
)
# global gradient clip, allreduce(mp)->allreduce(pp)
self
.
assertEqual
(
main_prog_op_types
.
count
(
'c_allreduce_sum'
),
2
)
# should has ring id for pp
created_ring_ids
=
[
op
.
desc
.
attr
(
"ring_id"
)
for
op
in
startup_prog_ops
if
op
.
type
==
"c_comm_init"
]
self
.
assertIn
(
self
.
mp_ring_id
,
created_ring_ids
)
self
.
assertIn
(
self
.
pp_pair_ring_id
,
created_ring_ids
)
# check correctness of pp group
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"comm_id_0"
:
mp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
mp_group_waiting_ports
,
[
'127.0.0.1:36003'
])
# check correctness of sharding group
sharding_group_waiting_port
=
None
for
op
in
startup_prog_ops
:
if
op
.
type
==
"c_gen_nccl_id"
and
op
.
desc
.
output_arg_names
()[
0
]
==
"comm_id_1"
:
pp_group_waiting_ports
=
op
.
desc
.
attr
(
"other_endpoints"
)
self
.
assertEqual
(
pp_group_waiting_ports
,
[
'127.0.0.1:36002'
])
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录