未验证 提交 a74208c1 编写于 作者: W WangXi 提交者: GitHub

[hybrid parallel] Optimize pipeline memory (#34230)

上级 056b8741
......@@ -548,6 +548,7 @@ class SectionWorker : public DeviceWorker {
~SectionWorker() override {}
void Initialize(const TrainerDesc& desc) override;
void PrepareUnusedVar();
void BindingDataFeedMemory() override {}
void CreateDeviceResource(const ProgramDesc& main_prog) override{};
......@@ -581,7 +582,8 @@ class SectionWorker : public DeviceWorker {
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void PrepareUnusedVar();
void RunFThenB(std::unique_ptr<GarbageCollector>&);
void Run1F1B(std::unique_ptr<GarbageCollector>&);
protected:
int section_id_;
......@@ -591,9 +593,12 @@ class SectionWorker : public DeviceWorker {
int pipeline_stage_;
int schedule_mode_; // 0 for F-then-B and 1 for 1F1B
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;
// skip&backward vars are only used in 1F1B
std::vector<std::string> skip_vars_;
std::vector<std::string> backward_send_vars_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
std::shared_ptr<framework::ProgramDesc> program_;
std::unordered_map<const OperatorBase*, std::vector<std::string>>
......
......@@ -146,18 +146,9 @@ GetUnusedVars(const BlockDesc &block,
return result;
}
void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
const std::unordered_map<const OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc) {
auto iter = delete_vars_map.find(op);
if (iter == delete_vars_map.end()) {
return;
}
auto &delete_vars = iter->second;
void DeleteUnusedTensors(const Scope &scope,
const std::vector<std::string> &delete_vars,
GarbageCollector *gc) {
std::deque<std::shared_ptr<memory::Allocation>> garbages;
for (auto &var_name : delete_vars) {
......@@ -189,6 +180,20 @@ void DeleteUnusedTensors(
}
}
void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
const std::unordered_map<const OperatorBase *, std::vector<std::string>>
&delete_vars_map,
GarbageCollector *gc) {
auto iter = delete_vars_map.find(op);
if (iter == delete_vars_map.end()) {
return;
}
auto &delete_vars = iter->second;
DeleteUnusedTensors(scope, delete_vars, gc);
}
static std::vector<std::unique_ptr<OperatorBase>> CreateOpsFromBlock(
const BlockDesc &block) {
std::vector<std::unique_ptr<OperatorBase>> ops;
......
......@@ -36,6 +36,11 @@ GetUnusedVars(const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_vars);
// Collect unused tensors
void DeleteUnusedTensors(const Scope &scope,
const std::vector<std::string> &delete_vars,
GarbageCollector *gc);
// Collect unused tensors after op runs
void DeleteUnusedTensors(
const Scope &scope, const OperatorBase *op,
......
......@@ -45,11 +45,11 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
auto this_worker =
std::dynamic_pointer_cast<paddle::framework::SectionWorker>(worker_);
this_worker->SetPlace(place_);
this_worker->Initialize(trainer_desc);
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
this_worker->Initialize(trainer_desc);
}
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......
......@@ -30,6 +30,36 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
// if not 1F1B scheduler
if (schedule_mode_ != 1) return;
bool is_first_stage = (pipeline_stage_ == 0);
int BACKWARD = static_cast<int>(OpRole::kBackward);
for (auto &op : ops_) {
int op_role = op->Attr<int>("op_role");
auto op_type = op->Type();
// pipeline backward send op
if (op_role != BACKWARD) continue;
if (op_type != "send_v2" && op_type != "partial_send") continue;
auto var_name = op->InputVars()[0];
VLOG(3) << "Pipeline backward send var " << var_name;
PADDLE_ENFORCE_NE(is_first_stage, true,
platform::errors::PreconditionNotMet(
"The first pipeline stage must do not have a "
"backward send var, please check var %s",
var_name));
backward_send_vars_.push_back(var_name);
skip_vars_.push_back(var_name);
}
}
void SectionWorker::PrepareUnusedVar() {
VLOG(5) << "begin prepare the unsed vars";
unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
}
void SectionWorker::RunForward(
......@@ -96,9 +126,79 @@ void SectionWorker::RunUpdate(
}
}
void SectionWorker::PrepareUnusedVar() {
VLOG(5) << "begin prepare the unsed vars";
unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_);
void SectionWorker::RunFThenB(std::unique_ptr<GarbageCollector> &gc) {
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
}
void SectionWorker::Run1F1B(std::unique_ptr<GarbageCollector> &gc) {
// 1F1B scheduler, which runs forward phase and backward phase altertively
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
PADDLE_ENFORCE_GT(
num_microbatches_, startup_steps,
platform::errors::InvalidArgument(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d).",
num_microbatches_, startup_steps));
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
// delete backward send var at step=(bw_step - 2)
if (gc && bw_step >= 2) {
DeleteUnusedTensors(*microbatch_scopes_[bw_step - 2], backward_send_vars_,
gc.get());
}
RunBackward(bw_step, gc, unused_vars_);
fw_step += 1;
bw_step += 1;
}
int reserve_bw_send_step = bw_step - 2;
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
if (gc) {
// NOTE(wangxi): program must add sync backward send comm at update
// delete backward send var
for (int i = reserve_bw_send_step; i < num_microbatches_; ++i) {
DeleteUnusedTensors(*microbatch_scopes_[i], backward_send_vars_,
gc.get());
}
}
}
void SectionWorker::TrainFiles() {
......@@ -132,56 +232,11 @@ void SectionWorker::TrainFiles() {
} // max_memory_size >= 0
if (schedule_mode_ == 0) {
// F-then-B scheduler which runs Forward phase for all microbatches,
// then runs Backward phase for all microbatches.
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
// step3: run update
RunUpdate(gc, unused_vars_);
RunFThenB(gc);
} else {
// 1F1B scheduler, which runs forward phase and backward phase altertively
// after startup phase. For a stage, the number of microbatches for
// startup is num_pipeline_stages_ - pipeline_stage_ - 1, where
// num_pipeline_stages_ is the total number of pipeline stages and
// pipeline_stage_ is the pipeline stage of the current device.
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
PADDLE_ENFORCE_GT(
num_microbatches_, startup_steps,
platform::errors::InvalidArgument(
"To use pipeline with 1F1B scheduler, please make sure number of "
"microbatches (%d) is than startup steps (%d).",
num_microbatches_, startup_steps));
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
Run1F1B(gc);
}
dev_ctx_->Wait();
++batch_id_;
}
......
......@@ -5277,6 +5277,7 @@ class PipelineOptimizer(object):
backward_recv_index = index
break
# last pipeline stage
if backward_recv_index is None: return
offset = 0
......
......@@ -1258,7 +1258,7 @@ class TestDistBase(unittest.TestCase):
"fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10," \
"alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,executor=10,operator=10," \
"sparse_all_reduce_op_handle=10,gen_nccl_id_op=10,gen_nccl_id_op_help=10,nccl_helper=10,grpc_client=10," \
"grpc_server=10,request_handler_impl=10"
"grpc_server=10,request_handler_impl=10,section_worker=10"
required_envs["GLOG_logtostderr"] = "1"
required_envs.update(need_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册