未验证 提交 a501a7b0 编写于 作者: L lilong12 提交者: GitHub

[3D-parallel] add 1f1b scheduler for pipeline (#31566)

* add 1f1b scheduler for pp, test=develop
上级 ed7956a8
......@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -454,7 +455,7 @@ class HeterBoxWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual void ProduceTasks() override;
void ProduceTasks() override;
virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
virtual void TrainFilesWithProfiler() {}
......@@ -555,7 +556,7 @@ class PSGPUWorker : public HogwildWorker {
virtual void CacheProgram(const ProgramDesc& main_program) {
new (&program_) ProgramDesc(main_program);
}
virtual void ProduceTasks() override;
void ProduceTasks() override;
virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; }
virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
void ResetStat();
......@@ -659,6 +660,9 @@ class SectionWorker : public DeviceWorker {
void SetDeviceIndex(int tid) override {}
void SetThreadIndex(int thread_id) { thread_id_ = thread_id; }
void SetMicrobatchNum(int num) { num_microbatches_ = num; }
void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
void SetScheduleMode(int mode) { schedule_mode_ = mode; }
void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
microbatch_scopes_ = scope;
}
......@@ -666,11 +670,23 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars;
}
void RunBackward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunForward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
protected:
int section_id_;
int thread_id_;
int num_microbatches_;
int num_pipeline_stages_;
int pipeline_stage_;
int schedule_mode_; // 0 for F-then-B and 1 for 1F1B
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;
......
......@@ -120,6 +120,7 @@ message AsyncConfig {
message PipelineConfig {
optional int32 micro_batch_size = 1 [ default = 1 ];
optional int32 accumulate_steps = 2 [ default = 1 ];
optional string schedule_mode = 3 [ default = '1F1B' ];
}
message DistributedStrategy {
......
......@@ -24,6 +24,9 @@ namespace framework {
void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
const auto& section_params = trainer_desc.section_param();
const int num_pipeline_stages_ = section_params.num_pipeline_stages();
const int pipeline_stage_ = section_params.pipeline_stage();
const int schedule_mode_ = section_params.schedule_mode();
num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
......@@ -39,6 +42,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
}
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......@@ -75,7 +81,9 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
for (auto& var : global_block.AllVars()) {
bool is_param_grad = false;
size_t pos = 0;
if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) {
// A magic suffix to indicate the merged gradient
std::string magicSuffix = std::string(kGradVarSuffix) + "@MERGED";
if ((pos = var->Name().find(magicSuffix)) != std::string::npos) {
auto prefix_name = var->Name().substr(0, pos);
if (param_map.find(prefix_name) != param_map.end()) {
is_param_grad = true;
......
......@@ -22,15 +22,79 @@ class TrainerDesc;
uint64_t SectionWorker::batch_id_(0);
void SectionWorker::Initialize(const TrainerDesc& desc) {
void SectionWorker::Initialize(const TrainerDesc &desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
program_.reset(
new ProgramDesc(desc.section_param().section_config().program_desc()));
for (auto& op_desc : program_->Block(0).AllOps()) {
for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
}
void SectionWorker::RunForward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunBackward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunUpdate(
std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
}
}
}
void SectionWorker::TrainFiles() {
VLOG(5) << "begin section_worker TrainFiles";
......@@ -48,69 +112,56 @@ void SectionWorker::TrainFiles() {
#endif
}
for (int i = 0; i < num_microbatches_; ++i) {
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< i;
op->Run(*microbatch_scopes_[i], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
gc.get());
}
}
if (schedule_mode_ == 0) {
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
#ifdef PADDLE_WITH_RCCL
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
}
// backward pass
for (int i = 0; i < num_microbatches_; ++i) {
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< i;
op->Run(*microbatch_scopes_[i], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
gc.get());
}
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
} else {
// 1F1B scheduler, which runs forward phase and backward phase altertively
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
PADDLE_ENFORCE_GT(
num_microbatches_, startup_steps,
platform::errors::InvalidArgument(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d).",
num_microbatches_, startup_steps));
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
#ifdef PADDLE_WITH_RCCL
hipDeviceSynchronize();
#else
cudaDeviceSynchronize();
#endif
}
// update pass
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[0], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_,
gc.get());
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
}
dev_ctx_->Wait();
++batch_id_;
......
......@@ -93,6 +93,9 @@ message SectionWorkerParameter {
optional int32 start_cpu_core_id = 4 [ default = 1 ];
repeated string param_need_sync = 5;
optional int32 num_microbatches = 6;
optional int32 num_pipeline_stages = 7 [ default = 1 ];
optional int32 pipeline_stage = 8 [ default = 1 ];
optional int32 schedule_mode = 9 [ default = 0 ];
}
message SectionConfig {
......
......@@ -138,7 +138,10 @@ class PipelineOptimizer(MetaOptimizerBase):
super(PipelineOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
......@@ -149,6 +152,8 @@ class PipelineOptimizer(MetaOptimizerBase):
'micro_batch_size']
self.num_microbatches = user_defined_strategy.pipeline_configs[
'accumulate_steps']
self.schedule_mode = user_defined_strategy.pipeline_configs[
'schedule_mode']
def _can_apply(self):
if not self.role_maker._is_collective:
......@@ -167,6 +172,7 @@ class PipelineOptimizer(MetaOptimizerBase):
dist_strategy.pipeline_configs = {
"micro_batch_size": 1,
"accumulate_steps": 1,
"schedule_mode": "1F1B",
}
def minimize_impl(self,
......@@ -192,6 +198,7 @@ class PipelineOptimizer(MetaOptimizerBase):
loss.block.program._pipeline_opt['local_rank'] = self.rank
loss.block.program._pipeline_opt[
'micro_batch_size'] = self.micro_batch_size
loss.block.program._pipeline_opt['schedule_mode'] = self.schedule_mode
optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
assert prog_list
......
......@@ -413,6 +413,18 @@ class Section(DeviceWorker):
section_param = trainer_desc.section_param
section_param.num_microbatches = pipeline_opt["num_microbatches"]
section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
section_param.pipeline_stage = pipeline_opt["pipeline_stage"]
section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"]
schedule_mode_str = pipeline_opt["schedule_mode"]
# F-then-B scheduler which runs Forward phase for all microbatches,
# then runs Backward phase for all microbatches.
# 1F1B scheduler, which runs forward phase and backward phase altertively
# after startup phase.
assert schedule_mode_str in ["F-then-B", "1F1B"], (
"The schedule mode "
"for pipeline must be one of F-then-B or 1F1B")
schedule_mode = 0 if schedule_mode_str == "F-then-B" else 1
section_param.schedule_mode = schedule_mode
cfg = section_param.section_config
program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc()
......
......@@ -4273,6 +4273,7 @@ class PipelineOptimizer(object):
grad_name = self._append_grad_suffix(param_name)
if not main_block.has_var(grad_name): continue
grad_var = main_block.vars[grad_name]
grad_var.persistable = True
main_block._insert_op(
index=0,
type='fill_constant',
......@@ -4517,6 +4518,7 @@ class PipelineOptimizer(object):
"You must use pipeline with fleet"
local_rank = main_program._pipeline_opt['local_rank'] % len(
device_specs)
self.schedule_mode = main_program._pipeline_opt['schedule_mode']
place_list = []
for dev_spec in device_specs:
......@@ -4543,6 +4545,9 @@ class PipelineOptimizer(object):
main_program._pipeline_opt = {
"trainer": "PipelineTrainer",
"device_worker": "Section",
"pipeline_stage": local_rank,
"num_pipeline_stages": len(device_specs),
"schedule_mode": self.schedule_mode,
"inner_parallelism": len(device_specs),
"section_program": program_list[local_rank],
"place": place_list[local_rank],
......
......@@ -110,22 +110,31 @@ class TestDistMnist2x2(TestDistRunnerBase):
lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
opt = fluid.optimizer.Momentum(learning_rate=lr_val, momentum=0.9)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
acc_steps = 2 # accumulated steps for pipeline
if dist_strategy:
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.pipeline = True
strategy.pipeline_configs = {'micro_batch_size': batch_size, }
strategy.pipeline_configs = {
'micro_batch_size': batch_size,
'schedule_mode': '1F1B',
'accumulate_steps': acc_steps
}
dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost)
else:
opt.minimize(avg_cost)
# Reader
train_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size * acc_steps)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size * acc_steps)
if dist_strategy:
return inference_program, avg_cost, train_reader, test_reader, batch_acc, predict, data_loader
......
......@@ -122,6 +122,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
if dist_strategy:
strategy = fleet.DistributedStrategy()
strategy.pipeline = True
strategy.pipeline_configs = {
'schedule_mode': 'F-then-B',
'micro_batch_size': batch_size
}
dist_opt = fleet.distributed_optimizer(
optimizer=opt, strategy=strategy)
dist_opt.minimize(avg_cost)
......
......@@ -34,9 +34,13 @@ class TestPipeline(TestDistBase):
def test_dist_train(self):
import paddle.fluid as fluid
if fluid.core.is_compiled_with_cuda():
# TODO (sandyhouse) fix the delta value.
# Now pipeline only gets the loss value of the last
# microbatch, so it is not consistable with the
# non-pipeline one.
self.check_with_place(
"pipeline_mnist.py",
delta=1e-5,
delta=1e0,
check_error_log=True,
log_name=flag_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册