提交 8b170ffa 编写于 作者: S sandyhouse

update

上级 af17a6ee
......@@ -28,6 +28,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/heter_service.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -656,6 +657,7 @@ class SectionWorker : public DeviceWorker {
void SetMicrobatchNum(int num) { num_microbatches_ = num; }
void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; }
void SetPipelineStage(int stage) { pipeline_stage_ = stage; }
void SetScheduleMode(int mode) { schedule_mode_ = mode; }
void SetMicrobatchScopes(const std::vector<Scope*>& scope) {
microbatch_scopes_ = scope;
}
......@@ -663,6 +665,15 @@ class SectionWorker : public DeviceWorker {
void SetSkipVars(const std::vector<std::string>& skip_vars) {
skip_vars_ = skip_vars;
}
void RunBackward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunForward(
int micro_id, std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
void RunUpdate(
std::unique_ptr<GarbageCollector>&,
std::unordered_map<const OperatorBase*, std::vector<std::string>>&);
protected:
int section_id_;
......@@ -670,6 +681,7 @@ class SectionWorker : public DeviceWorker {
int num_microbatches_;
int num_pipeline_stages_;
int pipeline_stage_;
int schedule_mode_; // 0 for GPipe and 1 for deepspeed
std::vector<Scope*> microbatch_scopes_;
std::vector<std::string> skip_vars_;
const Scope* minibatch_scope_;
......
......@@ -36,6 +36,7 @@ message ShardingConfig {
optional int32 parallelism = 5 [ default = 1 ];
optional bool use_pipeline = 6 [ default = false ];
optional int32 acc_steps = 7 [ default = 1 ];
optional int32 schedule_mode = 8 [ default = 0 ];
}
message AMPConfig {
......
......@@ -27,6 +27,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
const auto& section_params = trainer_desc.section_param();
const auto num_pipeline_stages_ = section_params.num_pipeline_stages();
const auto pipeline_stage_ = section_params.pipeline_stage();
const auto schedule_mode_ = section_params.schedule_mode();
num_microbatches_ = section_params.num_microbatches();
VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_;
trainer_desc_ = trainer_desc;
......@@ -44,6 +45,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
this_worker->SetMicrobatchNum(num_microbatches_);
this_worker->SetPipelineStageNum(num_pipeline_stages_);
this_worker->SetPipelineStage(pipeline_stage_);
this_worker->SetScheduleMode(schedule_mode_);
}
void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) {
......
......@@ -22,15 +22,79 @@ class TrainerDesc;
uint64_t SectionWorker::batch_id_(0);
void SectionWorker::Initialize(const TrainerDesc& desc) {
void SectionWorker::Initialize(const TrainerDesc &desc) {
dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_);
program_.reset(
new ProgramDesc(desc.section_param().section_config().program_desc()));
for (auto& op_desc : program_->Block(0).AllOps()) {
for (auto &op_desc : program_->Block(0).AllOps()) {
ops_.push_back(OpRegistry::CreateOp(*op_desc));
}
}
void SectionWorker::RunForward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunBackward(
int micro_id, std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
unused_vars_, gc.get());
}
}
}
}
void SectionWorker::RunUpdate(
std::unique_ptr<GarbageCollector> &gc,
std::unordered_map<const OperatorBase *, std::vector<std::string>>
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
}
}
}
void SectionWorker::TrainFiles() {
VLOG(5) << "begin section_worker TrainFiles";
......@@ -48,168 +112,49 @@ void SectionWorker::TrainFiles() {
#endif
}
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
if (startup_steps > num_microbatches_) {
startup_steps = num_microbatches_;
}
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
VLOG(3) << "to run forward batch:" << fw_step;
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((fw_step == 0 && run_first_mbatch) || (fw_step != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< fw_step;
op->Run(*microbatch_scopes_[fw_step], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[fw_step], op.get(),
unused_vars_, gc.get());
}
}
if (schedule_mode_ == 0) {
// Gpipe scheduler which runs all forwards first, then backwards, then
// update
// step1: run forward
for (int i = 0; i < num_microbatches_; ++i) {
RunForward(i, gc, unused_vars_);
}
fw_step += 1;
}
// 1f1b phase
while (fw_step < num_microbatches_) {
VLOG(3) << "to run forward batch:" << fw_step;
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
if ((fw_step == 0 && run_first_mbatch) || (fw_step != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< fw_step;
op->Run(*microbatch_scopes_[fw_step], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[fw_step], op.get(),
unused_vars_, gc.get());
}
}
// step2: run backward
for (int i = 0; i < num_microbatches_; ++i) {
RunBackward(i, gc, unused_vars_);
}
fw_step += 1;
VLOG(3) << "to run backward batch:" << bw_step;
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< bw_step;
op->Run(*microbatch_scopes_[bw_step], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[bw_step], op.get(),
unused_vars_, gc.get());
}
}
// step2: run update
RunUpdate(gc, unused_vars_);
} else {
// 1F1B scheduler
auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1;
VLOG(3) << "startup_steps:" << startup_steps
<< ", num_stages: " << num_pipeline_stages_
<< ", stage:" << pipeline_stage_;
if (startup_steps > num_microbatches_) {
startup_steps = num_microbatches_;
}
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
VLOG(3) << "to run backward batch:" << bw_step;
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< bw_step;
op->Run(*microbatch_scopes_[bw_step], place_);
if (gc) {
DeleteUnusedTensors(*microbatch_scopes_[bw_step], op.get(),
unused_vars_, gc.get());
}
}
int fw_step = 0;
int bw_step = 0;
// startup phase
while (fw_step < startup_steps) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
}
bw_step += 1;
}
// for (int i = 0; i < num_microbatches_; ++i) {
// for (auto& op : ops_) {
// int op_role = op->Attr<int>(std::string("op_role"));
// // We run op with op_role = kLRSched only for the first microbatch
// // to avoid increasing the @LR_DECAY_STEP@ multiple times.
// bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward)
// ||
// op_role == (static_cast<int>(OpRole::kForward)
// |
// static_cast<int>(OpRole::kLoss)) ||
// op_role == static_cast<int>(OpRole::kLRSched);
// bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
// op_role == (static_cast<int>(OpRole::kForward) |
// static_cast<int>(OpRole::kLoss));
// if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
// VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch
// "
// << i;
// op->Run(*microbatch_scopes_[i], place_);
// if (gc) {
// DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
// gc.get());
// }
// }
// }
// cudaDeviceSynchronize();
// }
// // backward pass
// for (int i = 0; i < num_microbatches_; ++i) {
// for (auto& op : ops_) {
// int op_role = op->Attr<int>(std::string("op_role"));
// if (op_role == static_cast<int>(OpRole::kBackward) ||
// op_role == (static_cast<int>(OpRole::kBackward) |
// static_cast<int>(OpRole::kLoss))) {
// VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch
// "
// << i;
// op->Run(*microbatch_scopes_[i], place_);
// if (gc) {
// DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_,
// gc.get());
// }
// }
// }
// cudaDeviceSynchronize();
// }
// update pass
for (auto& op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kOptimize)) {
VLOG(3) << "Update: running op " << op->Type();
op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
if (gc) {
// for (int i = 0; i < num_microbatches_; ++i) {
// DeleteUnusedTensors(*microbatch_scopes_[i],
// op.get(), unused_vars_, gc.get());
//}
DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
op.get(), unused_vars_, gc.get());
}
// 1f1b phase
while (fw_step < num_microbatches_) {
RunForward(fw_step, gc, unused_vars_);
fw_step += 1;
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
// backward phase
while (bw_step < num_microbatches_) {
RunBackward(bw_step, gc, unused_vars_);
bw_step += 1;
}
RunUpdate(gc, unused_vars_);
}
dev_ctx_->Wait();
++batch_id_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册