Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8ffcc7c8
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8ffcc7c8
编写于
10月 14, 2021
作者:
S
ShenLiang
提交者:
GitHub
10月 14, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[HybridParallel]Rebuild code for pipeline (#36396)
* add no_sync for parameters sync * add pipeline for moe
上级
693b1aa1
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
40 addition
and
25 deletion
+40
-25
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
...ddle/distributed/fleet/meta_parallel/pipeline_parallel.py
+32
-23
python/paddle/fluid/dygraph/parallel.py
python/paddle/fluid/dygraph/parallel.py
+8
-2
未找到文件。
python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
浏览文件 @
8ffcc7c8
...
@@ -77,26 +77,15 @@ class PipelineParallel(MetaParallelBase):
...
@@ -77,26 +77,15 @@ class PipelineParallel(MetaParallelBase):
logger
.
info
(
"start broadcast dp parameters"
)
logger
.
info
(
"start broadcast dp parameters"
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
broadcast_dp_parameters
(
self
.
_layers
,
self
.
_hcg
)
def
train_batch
(
self
,
data
,
optimizer
,
lr_scheduler
=
None
,
scaler
=
None
):
def
forward_backward_pipeline
(
self
,
data
,
scaler
=
None
):
assert
isinstance
(
optimizer
,
HybridParallelOptimizer
),
(
# use the 1f1b scheduling strategy.
'optimizer should be HybridParallelOptimizer subclass.'
)
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
assert
fluid
.
framework
.
_dygraph_tracer
().
_has_grad
,
(
'Please enable the generation of gradients.'
)
if
self
.
is_first_stage
or
self
.
is_last_stage
:
assert
data
is
not
None
,
(
"For the first and the last stage, the data must be set."
)
else
:
data
=
None
self
.
optimizer
=
optimizer
self
.
lr_scheduler
=
lr_scheduler
self
.
scaler
=
scaler
self
.
scaler
=
scaler
self
.
data
=
data
self
.
_compute_loss
=
True
self
.
_layers
.
train
()
# store data for train
self
.
data
=
data
# store total loss of entire batch
# store total loss of entire batch
self
.
total_loss
=
None
self
.
total_loss
=
None
...
@@ -104,10 +93,6 @@ class PipelineParallel(MetaParallelBase):
...
@@ -104,10 +93,6 @@ class PipelineParallel(MetaParallelBase):
# store data id for micro_batch
# store data id for micro_batch
self
.
micro_batch_id
=
0
self
.
micro_batch_id
=
0
# Next, use the 1f1b scheduling strategy.
# this strategy is inspired by:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
startup_steps
=
(
self
.
num_stages
-
self
.
stage_id
-
1
)
startup_steps
=
(
self
.
num_stages
-
self
.
stage_id
-
1
)
startup_steps
=
min
(
startup_steps
,
self
.
accumulate_steps
)
startup_steps
=
min
(
startup_steps
,
self
.
accumulate_steps
)
steady_steps
=
self
.
accumulate_steps
-
startup_steps
steady_steps
=
self
.
accumulate_steps
-
startup_steps
...
@@ -161,11 +146,35 @@ class PipelineParallel(MetaParallelBase):
...
@@ -161,11 +146,35 @@ class PipelineParallel(MetaParallelBase):
self
.
_layers
.
allreduce_shared_weight_gradients
()
self
.
_layers
.
allreduce_shared_weight_gradients
()
self
.
train_loss
=
self
.
_broadcast_final_loss
()
train_loss
=
self
.
_broadcast_final_loss
()
return
train_loss
def
train_batch
(
self
,
data
,
optimizer
,
lr_scheduler
=
None
,
scaler
=
None
):
assert
isinstance
(
optimizer
,
HybridParallelOptimizer
),
(
'optimizer should be HybridParallelOptimizer subclass.'
)
assert
fluid
.
framework
.
_dygraph_tracer
().
_has_grad
,
(
'Please enable the generation of gradients.'
)
if
self
.
is_first_stage
or
self
.
is_last_stage
:
assert
data
is
not
None
,
(
"For the first and the last stage, the data must be set."
)
else
:
data
=
None
self
.
optimizer
=
optimizer
self
.
lr_scheduler
=
lr_scheduler
self
.
_layers
.
train
()
# 1f1b for pipeline
train_loss
=
self
.
forward_backward_pipeline
(
data
,
scaler
)
# optimizer
# optimizer
self
.
_optimizer_step
()
self
.
_optimizer_step
()
return
self
.
train_loss
return
train_loss
def
eval_batch
(
self
,
data
,
compute_loss
=
False
):
def
eval_batch
(
self
,
data
,
compute_loss
=
False
):
self
.
_layers
.
eval
()
self
.
_layers
.
eval
()
...
...
python/paddle/fluid/dygraph/parallel.py
浏览文件 @
8ffcc7c8
...
@@ -354,9 +354,15 @@ def sync_params_buffers(model,
...
@@ -354,9 +354,15 @@ def sync_params_buffers(model,
if
not
isinstance
(
param
,
core
.
VarBase
):
if
not
isinstance
(
param
,
core
.
VarBase
):
raise
TypeError
(
"The data type of '%s' must be Varbase"
%
raise
TypeError
(
"The data type of '%s' must be Varbase"
%
param
.
name
)
param
.
name
)
# is_distributed param not need to sync when in mp mode
# is_distributed param not need to sync when in mp mode
if
is_model_parallel
and
isinstance
(
param
,
ParamBase
):
if
isinstance
(
param
,
ParamBase
):
if
param
.
is_distributed
:
if
is_model_parallel
and
param
.
is_distributed
:
continue
# NOTE(shenliang03): Support situations that do not require synchronization parameters,
# such as moe's expert parameters
if
getattr
(
param
,
"no_sync"
,
False
):
continue
continue
model_vars
.
append
(
param
.
detach
())
model_vars
.
append
(
param
.
detach
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录