Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a501a7b0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a501a7b0
编写于
3月 22, 2021
作者:
L
lilong12
提交者:
GitHub
3月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[3D-parallel] add 1f1b scheduler for pipeline (#31566)
* add 1f1b scheduler for pp, test=develop
上级
ed7956a8
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
193 addition
and
73 deletion
+193
-73
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+18
-2
paddle/fluid/framework/distributed_strategy.proto
paddle/fluid/framework/distributed_strategy.proto
+1
-0
paddle/fluid/framework/pipeline_trainer.cc
paddle/fluid/framework/pipeline_trainer.cc
+9
-1
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+112
-61
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+3
-0
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
...e/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+8
-1
python/paddle/fluid/device_worker.py
python/paddle/fluid/device_worker.py
+12
-0
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+5
-0
python/paddle/fluid/tests/unittests/pipeline_mnist.py
python/paddle/fluid/tests/unittests/pipeline_mnist.py
+16
-7
python/paddle/fluid/tests/unittests/pipeline_mnist_one_device.py
...paddle/fluid/tests/unittests/pipeline_mnist_one_device.py
+4
-0
python/paddle/fluid/tests/unittests/test_pipeline.py
python/paddle/fluid/tests/unittests/test_pipeline.py
+5
-1
未找到文件。
paddle/fluid/framework/device_worker.h
浏览文件 @
a501a7b0
...
...
@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
@@ -454,7 +455,7 @@ class HeterBoxWorker : public HogwildWorker {
virtual
void
CacheProgram
(
const
ProgramDesc
&
main_program
)
{
new
(
&
program_
)
ProgramDesc
(
main_program
);
}
v
irtual
v
oid
ProduceTasks
()
override
;
void
ProduceTasks
()
override
;
virtual
void
SetStream
(
const
gpuStream_t
stream
)
{
copy_stream_
=
stream
;
}
virtual
void
SetEvent
(
const
gpuEvent_t
event
)
{
event_
=
event
;
}
virtual
void
TrainFilesWithProfiler
()
{}
...
...
@@ -555,7 +556,7 @@ class PSGPUWorker : public HogwildWorker {
virtual
void
CacheProgram
(
const
ProgramDesc
&
main_program
)
{
new
(
&
program_
)
ProgramDesc
(
main_program
);
}
v
irtual
v
oid
ProduceTasks
()
override
;
void
ProduceTasks
()
override
;
virtual
void
SetStream
(
const
gpuStream_t
stream
)
{
copy_stream_
=
stream
;
}
virtual
void
SetEvent
(
const
gpuEvent_t
event
)
{
event_
=
event
;
}
void
ResetStat
();
...
...
@@ -659,6 +660,9 @@ class SectionWorker : public DeviceWorker {
void
SetDeviceIndex
(
int
tid
)
override
{}
void
SetThreadIndex
(
int
thread_id
)
{
thread_id_
=
thread_id
;
}
void
SetMicrobatchNum
(
int
num
)
{
num_microbatches_
=
num
;
}
void
SetPipelineStageNum
(
int
num
)
{
num_pipeline_stages_
=
num
;
}
void
SetPipelineStage
(
int
stage
)
{
pipeline_stage_
=
stage
;
}
void
SetScheduleMode
(
int
mode
)
{
schedule_mode_
=
mode
;
}
void
SetMicrobatchScopes
(
const
std
::
vector
<
Scope
*>&
scope
)
{
microbatch_scopes_
=
scope
;
}
...
...
@@ -666,11 +670,23 @@ class SectionWorker : public DeviceWorker {
void
SetSkipVars
(
const
std
::
vector
<
std
::
string
>&
skip_vars
)
{
skip_vars_
=
skip_vars
;
}
void
RunBackward
(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>&
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>&
);
void
RunForward
(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>&
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>&
);
void
RunUpdate
(
std
::
unique_ptr
<
GarbageCollector
>&
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>&
);
protected:
int
section_id_
;
int
thread_id_
;
int
num_microbatches_
;
int
num_pipeline_stages_
;
int
pipeline_stage_
;
int
schedule_mode_
;
// 0 for F-then-B and 1 for 1F1B
std
::
vector
<
Scope
*>
microbatch_scopes_
;
std
::
vector
<
std
::
string
>
skip_vars_
;
const
Scope
*
minibatch_scope_
;
...
...
paddle/fluid/framework/distributed_strategy.proto
浏览文件 @
a501a7b0
...
...
@@ -120,6 +120,7 @@ message AsyncConfig {
message
PipelineConfig
{
optional
int32
micro_batch_size
=
1
[
default
=
1
];
optional
int32
accumulate_steps
=
2
[
default
=
1
];
optional
string
schedule_mode
=
3
[
default
=
'1F1B'
];
}
message
DistributedStrategy
{
...
...
paddle/fluid/framework/pipeline_trainer.cc
浏览文件 @
a501a7b0
...
...
@@ -24,6 +24,9 @@ namespace framework {
void
PipelineTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
dataset
)
{
const
auto
&
section_params
=
trainer_desc
.
section_param
();
const
int
num_pipeline_stages_
=
section_params
.
num_pipeline_stages
();
const
int
pipeline_stage_
=
section_params
.
pipeline_stage
();
const
int
schedule_mode_
=
section_params
.
schedule_mode
();
num_microbatches_
=
section_params
.
num_microbatches
();
VLOG
(
3
)
<<
"Number of microbatches per minibatch: "
<<
num_microbatches_
;
trainer_desc_
=
trainer_desc
;
...
...
@@ -39,6 +42,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker
->
SetPlace
(
place_
);
this_worker
->
Initialize
(
trainer_desc
);
this_worker
->
SetMicrobatchNum
(
num_microbatches_
);
this_worker
->
SetPipelineStageNum
(
num_pipeline_stages_
);
this_worker
->
SetPipelineStage
(
pipeline_stage_
);
this_worker
->
SetScheduleMode
(
schedule_mode_
);
}
void
PipelineTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
...
...
@@ -75,7 +81,9 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
for
(
auto
&
var
:
global_block
.
AllVars
())
{
bool
is_param_grad
=
false
;
size_t
pos
=
0
;
if
((
pos
=
var
->
Name
().
find
(
kGradVarSuffix
))
!=
std
::
string
::
npos
)
{
// A magic suffix to indicate the merged gradient
std
::
string
magicSuffix
=
std
::
string
(
kGradVarSuffix
)
+
"@MERGED"
;
if
((
pos
=
var
->
Name
().
find
(
magicSuffix
))
!=
std
::
string
::
npos
)
{
auto
prefix_name
=
var
->
Name
().
substr
(
0
,
pos
);
if
(
param_map
.
find
(
prefix_name
)
!=
param_map
.
end
())
{
is_param_grad
=
true
;
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
a501a7b0
...
...
@@ -22,34 +22,20 @@ class TrainerDesc;
uint64_t
SectionWorker
::
batch_id_
(
0
);
void
SectionWorker
::
Initialize
(
const
TrainerDesc
&
desc
)
{
void
SectionWorker
::
Initialize
(
const
TrainerDesc
&
desc
)
{
dev_ctx_
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
place_
);
program_
.
reset
(
new
ProgramDesc
(
desc
.
section_param
().
section_config
().
program_desc
()));
for
(
auto
&
op_desc
:
program_
->
Block
(
0
).
AllOps
())
{
for
(
auto
&
op_desc
:
program_
->
Block
(
0
).
AllOps
())
{
ops_
.
push_back
(
OpRegistry
::
CreateOp
(
*
op_desc
));
}
}
void
SectionWorker
::
TrainFiles
()
{
VLOG
(
5
)
<<
"begin section_worker TrainFiles"
;
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
std
::
unique_ptr
<
GarbageCollector
>
gc
;
auto
unused_vars_
=
GetUnusedVars
(
program_
->
Block
(
0
),
ops_
,
skip_vars_
);
if
(
max_memory_size
>=
0
)
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
IsFastEagerDeletionModeEnabled
())
{
gc
.
reset
(
new
UnsafeFastGPUGarbageCollector
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place_
),
max_memory_size
));
}
}
#endif
}
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
auto
&
op
:
ops_
)
{
void
SectionWorker
::
RunForward
(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
...
...
@@ -60,57 +46,122 @@ void SectionWorker::TrainFiles() {
bool
run_others
=
op_role
==
static_cast
<
int
>
(
OpRole
::
kForward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kForward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
));
if
((
i
==
0
&&
run_first_mbatch
)
||
(
i
!=
0
&&
run_others
))
{
if
((
micro_id
==
0
&&
run_first_mbatch
)
||
(
micro_id
!=
0
&&
run_others
))
{
VLOG
(
3
)
<<
"Forward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
i
;
op
->
Run
(
*
microbatch_scopes_
[
i
],
place_
);
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
i
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
.
get
()
,
unused_vars_
,
gc
.
get
());
}
}
}
#ifdef PADDLE_WITH_RCCL
hipDeviceSynchronize
();
#else
cudaDeviceSynchronize
();
#endif
}
}
// backward pass
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
for
(
auto
&
op
:
ops_
)
{
void
SectionWorker
::
RunBackward
(
int
micro_id
,
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kBackward
)
||
op_role
==
(
static_cast
<
int
>
(
OpRole
::
kBackward
)
|
static_cast
<
int
>
(
OpRole
::
kLoss
)))
{
VLOG
(
3
)
<<
"Backward: running op "
<<
op
->
Type
()
<<
" for micro-batch "
<<
i
;
op
->
Run
(
*
microbatch_scopes_
[
i
],
place_
);
<<
micro_id
;
op
->
Run
(
*
microbatch_scopes_
[
micro_id
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
i
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
DeleteUnusedTensors
(
*
microbatch_scopes_
[
micro_id
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
}
#ifdef PADDLE_WITH_RCCL
hipDeviceSynchronize
();
#else
cudaDeviceSynchronize
();
#endif
}
}
// update pass
for
(
auto
&
op
:
ops_
)
{
void
SectionWorker
::
RunUpdate
(
std
::
unique_ptr
<
GarbageCollector
>
&
gc
,
std
::
unordered_map
<
const
OperatorBase
*
,
std
::
vector
<
std
::
string
>>
&
unused_vars_
)
{
for
(
auto
&
op
:
ops_
)
{
int
op_role
=
op
->
Attr
<
int
>
(
std
::
string
(
"op_role"
));
if
(
op_role
==
static_cast
<
int
>
(
OpRole
::
kOptimize
))
{
VLOG
(
3
)
<<
"Update: running op "
<<
op
->
Type
();
op
->
Run
(
*
microbatch_scopes_
[
0
],
place_
);
op
->
Run
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
place_
);
if
(
gc
)
{
DeleteUnusedTensors
(
*
microbatch_scopes_
[
0
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
DeleteUnusedTensors
(
*
microbatch_scopes_
[
num_microbatches_
-
1
],
op
.
get
(),
unused_vars_
,
gc
.
get
());
}
}
}
}
void
SectionWorker
::
TrainFiles
()
{
VLOG
(
5
)
<<
"begin section_worker TrainFiles"
;
int64_t
max_memory_size
=
GetEagerDeletionThreshold
();
std
::
unique_ptr
<
GarbageCollector
>
gc
;
auto
unused_vars_
=
GetUnusedVars
(
program_
->
Block
(
0
),
ops_
,
skip_vars_
);
if
(
max_memory_size
>=
0
)
{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if
(
platform
::
is_gpu_place
(
place_
))
{
if
(
IsFastEagerDeletionModeEnabled
())
{
gc
.
reset
(
new
UnsafeFastGPUGarbageCollector
(
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place_
),
max_memory_size
));
}
}
#endif
}
if
(
schedule_mode_
==
0
)
{
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
RunForward
(
i
,
gc
,
unused_vars_
);
}
// step2: run backward
for
(
int
i
=
0
;
i
<
num_microbatches_
;
++
i
)
{
RunBackward
(
i
,
gc
,
unused_vars_
);
}
// step3: run update
RunUpdate
(
gc
,
unused_vars_
);
}
else
{
// 1F1B scheduler, which runs forward phase and backward phase altertively
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto
startup_steps
=
num_pipeline_stages_
-
pipeline_stage_
-
1
;
VLOG
(
3
)
<<
"startup_steps:"
<<
startup_steps
<<
", num_stages: "
<<
num_pipeline_stages_
<<
", stage:"
<<
pipeline_stage_
;
PADDLE_ENFORCE_GT
(
num_microbatches_
,
startup_steps
,
platform
::
errors
::
InvalidArgument
(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d)."
,
num_microbatches_
,
startup_steps
));
int
fw_step
=
0
;
int
bw_step
=
0
;
// startup phase
while
(
fw_step
<
startup_steps
)
{
RunForward
(
fw_step
,
gc
,
unused_vars_
);
fw_step
+=
1
;
}
// 1f1b phase
while
(
fw_step
<
num_microbatches_
)
{
RunForward
(
fw_step
,
gc
,
unused_vars_
);
fw_step
+=
1
;
RunBackward
(
bw_step
,
gc
,
unused_vars_
);
bw_step
+=
1
;
}
// backward phase
while
(
bw_step
<
num_microbatches_
)
{
RunBackward
(
bw_step
,
gc
,
unused_vars_
);
bw_step
+=
1
;
}
RunUpdate
(
gc
,
unused_vars_
);
}
dev_ctx_
->
Wait
();
++
batch_id_
;
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
a501a7b0
...
...
@@ -93,6 +93,9 @@ message SectionWorkerParameter {
optional
int32
start_cpu_core_id
=
4
[
default
=
1
];
repeated
string
param_need_sync
=
5
;
optional
int32
num_microbatches
=
6
;
optional
int32
num_pipeline_stages
=
7
[
default
=
1
];
optional
int32
pipeline_stage
=
8
[
default
=
1
];
optional
int32
schedule_mode
=
9
[
default
=
0
];
}
message
SectionConfig
{
...
...
python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
浏览文件 @
a501a7b0
...
...
@@ -138,7 +138,10 @@ class PipelineOptimizer(MetaOptimizerBase):
super
(
PipelineOptimizer
,
self
).
__init__
(
optimizer
)
self
.
inner_opt
=
optimizer
# we do not allow meta optimizer to be inner optimizer currently
self
.
meta_optimizers_white_list
=
[]
self
.
meta_optimizers_white_list
=
[
"RecomputeOptimizer"
,
"AMPOptimizer"
,
]
self
.
meta_optimizers_black_list
=
[
"GraphExecutionOptimizer"
,
]
def
_set_basic_info
(
self
,
loss
,
role_maker
,
user_defined_optimizer
,
...
...
@@ -149,6 +152,8 @@ class PipelineOptimizer(MetaOptimizerBase):
'micro_batch_size'
]
self
.
num_microbatches
=
user_defined_strategy
.
pipeline_configs
[
'accumulate_steps'
]
self
.
schedule_mode
=
user_defined_strategy
.
pipeline_configs
[
'schedule_mode'
]
def
_can_apply
(
self
):
if
not
self
.
role_maker
.
_is_collective
:
...
...
@@ -167,6 +172,7 @@ class PipelineOptimizer(MetaOptimizerBase):
dist_strategy
.
pipeline_configs
=
{
"micro_batch_size"
:
1
,
"accumulate_steps"
:
1
,
"schedule_mode"
:
"1F1B"
,
}
def
minimize_impl
(
self
,
...
...
@@ -192,6 +198,7 @@ class PipelineOptimizer(MetaOptimizerBase):
loss
.
block
.
program
.
_pipeline_opt
[
'local_rank'
]
=
self
.
rank
loss
.
block
.
program
.
_pipeline_opt
[
'micro_batch_size'
]
=
self
.
micro_batch_size
loss
.
block
.
program
.
_pipeline_opt
[
'schedule_mode'
]
=
self
.
schedule_mode
optimize_ops
,
params_grads
,
prog_list
=
self
.
wrapped_opt
.
minimize
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
)
assert
prog_list
...
...
python/paddle/fluid/device_worker.py
浏览文件 @
a501a7b0
...
...
@@ -413,6 +413,18 @@ class Section(DeviceWorker):
section_param
=
trainer_desc
.
section_param
section_param
.
num_microbatches
=
pipeline_opt
[
"num_microbatches"
]
section_param
.
start_cpu_core_id
=
pipeline_opt
[
"start_cpu_core_id"
]
section_param
.
pipeline_stage
=
pipeline_opt
[
"pipeline_stage"
]
section_param
.
num_pipeline_stages
=
pipeline_opt
[
"num_pipeline_stages"
]
schedule_mode_str
=
pipeline_opt
[
"schedule_mode"
]
# F-then-B scheduler which runs Forward phase for all microbatches,
# then runs Backward phase for all microbatches.
# 1F1B scheduler, which runs forward phase and backward phase altertively
# after startup phase.
assert
schedule_mode_str
in
[
"F-then-B"
,
"1F1B"
],
(
"The schedule mode "
"for pipeline must be one of F-then-B or 1F1B"
)
schedule_mode
=
0
if
schedule_mode_str
==
"F-then-B"
else
1
section_param
.
schedule_mode
=
schedule_mode
cfg
=
section_param
.
section_config
program
=
pipeline_opt
[
"section_program"
]
cfg
.
program_desc
.
ParseFromString
(
program
[
"program"
].
_get_desc
()
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a501a7b0
...
...
@@ -4273,6 +4273,7 @@ class PipelineOptimizer(object):
grad_name
=
self
.
_append_grad_suffix
(
param_name
)
if
not
main_block
.
has_var
(
grad_name
):
continue
grad_var
=
main_block
.
vars
[
grad_name
]
grad_var
.
persistable
=
True
main_block
.
_insert_op
(
index
=
0
,
type
=
'fill_constant'
,
...
...
@@ -4517,6 +4518,7 @@ class PipelineOptimizer(object):
"You must use pipeline with fleet"
local_rank
=
main_program
.
_pipeline_opt
[
'local_rank'
]
%
len
(
device_specs
)
self
.
schedule_mode
=
main_program
.
_pipeline_opt
[
'schedule_mode'
]
place_list
=
[]
for
dev_spec
in
device_specs
:
...
...
@@ -4543,6 +4545,9 @@ class PipelineOptimizer(object):
main_program
.
_pipeline_opt
=
{
"trainer"
:
"PipelineTrainer"
,
"device_worker"
:
"Section"
,
"pipeline_stage"
:
local_rank
,
"num_pipeline_stages"
:
len
(
device_specs
),
"schedule_mode"
:
self
.
schedule_mode
,
"inner_parallelism"
:
len
(
device_specs
),
"section_program"
:
program_list
[
local_rank
],
"place"
:
place_list
[
local_rank
],
...
...
python/paddle/fluid/tests/unittests/pipeline_mnist.py
浏览文件 @
a501a7b0
...
...
@@ -110,22 +110,31 @@ class TestDistMnist2x2(TestDistRunnerBase):
lr_val
=
fluid
.
layers
.
piecewise_decay
(
boundaries
=
bd
,
values
=
lr
)
opt
=
fluid
.
optimizer
.
Momentum
(
learning_rate
=
lr_val
,
momentum
=
0.9
)
acc_steps
=
2
# accumulated steps for pipeline
if
dist_strategy
:
# Reader
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
)
if
dist_strategy
:
fleet
.
init
(
is_collective
=
True
)
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
pipeline
=
True
strategy
.
pipeline_configs
=
{
'micro_batch_size'
:
batch_size
,
}
strategy
.
pipeline_configs
=
{
'micro_batch_size'
:
batch_size
,
'schedule_mode'
:
'1F1B'
,
'accumulate_steps'
:
acc_steps
}
dist_opt
=
fleet
.
distributed_optimizer
(
optimizer
=
opt
,
strategy
=
strategy
)
dist_opt
.
minimize
(
avg_cost
)
else
:
opt
.
minimize
(
avg_cost
)
# Reader
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
*
acc_steps
)
test_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
batch_size
*
acc_steps
)
if
dist_strategy
:
return
inference_program
,
avg_cost
,
train_reader
,
test_reader
,
batch_acc
,
predict
,
data_loader
...
...
python/paddle/fluid/tests/unittests/pipeline_mnist_one_device.py
浏览文件 @
a501a7b0
...
...
@@ -122,6 +122,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
if
dist_strategy
:
strategy
=
fleet
.
DistributedStrategy
()
strategy
.
pipeline
=
True
strategy
.
pipeline_configs
=
{
'schedule_mode'
:
'F-then-B'
,
'micro_batch_size'
:
batch_size
}
dist_opt
=
fleet
.
distributed_optimizer
(
optimizer
=
opt
,
strategy
=
strategy
)
dist_opt
.
minimize
(
avg_cost
)
...
...
python/paddle/fluid/tests/unittests/test_pipeline.py
浏览文件 @
a501a7b0
...
...
@@ -34,9 +34,13 @@ class TestPipeline(TestDistBase):
def
test_dist_train
(
self
):
import
paddle.fluid
as
fluid
if
fluid
.
core
.
is_compiled_with_cuda
():
# TODO (sandyhouse) fix the delta value.
# Now pipeline only gets the loss value of the last
# microbatch, so it is not consistable with the
# non-pipeline one.
self
.
check_with_place
(
"pipeline_mnist.py"
,
delta
=
1e
-5
,
delta
=
1e
0
,
check_error_log
=
True
,
log_name
=
flag_name
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录