Unverified commit 04fdb10a, authored by Yuang Liu, committed by GitHub

[hybrid performance] Optimize Pipeline Scheduler (#35680)

Parent e46ffaf2
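In short: `SectionWorker::Initialize` now classifies every op by its `op_role` attribute exactly once, caching raw `OperatorBase*` pointers in four vectors (`forward_and_lr_ops_`, `forward_ops_`, `backward_ops_`, `optimizer_ops_`). The run loops (`RunForward`, `RunBackward`, `RunUpdate`) then iterate the pre-filtered vector directly instead of re-reading the attribute and re-branching for every op on every micro-batch. A minimal standalone sketch of this classify-once pattern (not Paddle code; `Op`, `Role`, and `Worker` are hypothetical stand-ins):

```cpp
#include <memory>
#include <utility>
#include <vector>

enum class Role { kForward, kBackward, kOptimize };

// Hypothetical stand-in for OperatorBase and its role attribute.
struct Op {
  Role role;
  void Run() const { /* run the kernel */ }
};

class Worker {
 public:
  explicit Worker(std::vector<std::unique_ptr<Op>> ops)
      : ops_(std::move(ops)) {
    // Classify once at initialization. ops_ owns the operators and outlives
    // the cached raw pointers, so storing Op* is safe.
    for (auto &op : ops_) {
      (op->role == Role::kBackward ? backward_ops_ : other_ops_)
          .push_back(op.get());
    }
  }

  // Hot path: no attribute lookup or role branching per op per micro-batch.
  void RunBackward() const {
    for (Op *op : backward_ops_) op->Run();
  }

 private:
  std::vector<std::unique_ptr<Op>> ops_;  // owning container
  std::vector<Op *> backward_ops_;        // cached non-owning views
  std::vector<Op *> other_ops_;
};

int main() {
  std::vector<std::unique_ptr<Op>> ops;
  ops.push_back(std::make_unique<Op>(Op{Role::kForward}));
  ops.push_back(std::make_unique<Op>(Op{Role::kBackward}));
  Worker w(std::move(ops));
  w.RunBackward();
}
```

Keeping non-owning raw pointers alongside the owning `unique_ptr` vector is safe here because both live exactly as long as the worker itself.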
@@ -601,6 +601,10 @@ class SectionWorker : public DeviceWorker {
   std::vector<std::string> backward_send_vars_;
 
   std::vector<std::unique_ptr<OperatorBase>> ops_;
+  std::vector<OperatorBase*> forward_and_lr_ops_;
+  std::vector<OperatorBase*> forward_ops_;
+  std::vector<OperatorBase*> backward_ops_;
+  std::vector<OperatorBase*> optimizer_ops_;
   std::shared_ptr<framework::ProgramDesc> program_;
   std::unordered_map<const OperatorBase*, std::vector<std::string>>
       unused_vars_;
......
@@ -31,6 +31,33 @@ void SectionWorker::Initialize(const TrainerDesc &desc) {
     ops_.push_back(OpRegistry::CreateOp(*op_desc));
   }
 
+  for (auto &op : ops_) {
+    // Cache each op's role during initialization to avoid re-reading the
+    // op_role attribute for every op on every micro-batch at run time.
+    int op_role = op->Attr<int>("op_role");
+    if ((op_role == static_cast<int>(OpRole::kForward)) ||
+        (op_role == (static_cast<int>(OpRole::kForward) |
+                     static_cast<int>(OpRole::kLoss))) ||
+        (op_role == static_cast<int>(OpRole::kLRSched))) {
+      // Forward ops and LR-schedule ops, used for the first micro step.
+      forward_and_lr_ops_.push_back(op.get());
+      if (op_role != static_cast<int>(OpRole::kLRSched)) {
+        // Forward ops only, used for the second and later micro steps.
+        forward_ops_.push_back(op.get());
+      }
+    } else if ((op_role == static_cast<int>(OpRole::kBackward)) ||
+               (op_role == (static_cast<int>(OpRole::kBackward) |
+                            static_cast<int>(OpRole::kLoss)))) {
+      backward_ops_.push_back(op.get());
+    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
+      optimizer_ops_.push_back(op.get());
+    } else {
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "The op %s is none of LRSched, Forward, Backward or Optimize.",
+          op->Type()));
+    }
+  }
+
   // if not 1F1B scheduler
   if (schedule_mode_ != 1) return;
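The exact equality comparisons above (rather than bit-masking) work because an op carries either a plain role or a role fused with the loss bit: the op that produces the loss is marked `kForward | kLoss`. A small self-contained sketch of the check; the flag values are assumptions mirroring Paddle's `OpRole` enum in `op_proto_maker.h`:

```cpp
#include <cassert>

// Assumed to mirror paddle::framework::OpRole (op_proto_maker.h).
enum class OpRole {
  kForward = 0x0000,
  kBackward = 0x0001,
  kOptimize = 0x0002,
  kLRSched = 0x0010,
  kLoss = 0x0100,
};

// True if the role is Forward, either plain or fused with the kLoss bit.
bool IsForward(int op_role) {
  return op_role == static_cast<int>(OpRole::kForward) ||
         op_role == (static_cast<int>(OpRole::kForward) |
                     static_cast<int>(OpRole::kLoss));
}

int main() {
  int loss_op_role =
      static_cast<int>(OpRole::kForward) | static_cast<int>(OpRole::kLoss);
  assert(IsForward(loss_op_role));                         // 0x0100 matches
  assert(!IsForward(static_cast<int>(OpRole::kLRSched)));  // 0x0010 does not
  return 0;
}
```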
@@ -66,25 +93,15 @@ void SectionWorker::RunForward(
     int micro_id, std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
         &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    // We run op with op_role = kLRSched only for the first microbatch
-    // to avoid increasing the @LR_DECAY_STEP@ multiple times.
-    bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
-                            (op_role == (static_cast<int>(OpRole::kForward) |
-                                         static_cast<int>(OpRole::kLoss))) ||
-                            (op_role == static_cast<int>(OpRole::kLRSched));
-    bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
-                      (op_role == (static_cast<int>(OpRole::kForward) |
-                                   static_cast<int>(OpRole::kLoss)));
-    if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
-      VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
-              << micro_id;
-      op->Run(*microbatch_scopes_[micro_id], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
-                            unused_vars_, gc.get());
-      }
-    }
-  }
-}
+  // LR-schedule ops run only for the first micro-batch, to avoid increasing
+  // @LR_DECAY_STEP@ multiple times; later micro-batches use forward_ops_.
+  std::vector<OperatorBase *> &forward_tmp =
+      micro_id == 0 ? forward_and_lr_ops_ : forward_ops_;
+  for (auto &op : forward_tmp) {
+    VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
+            << micro_id;
+    op->Run(*microbatch_scopes_[micro_id], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
+                          gc.get());
+    }
+  }
+}
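For context on the `micro_id == 0` special case: under a pipelined schedule each micro-batch runs its own forward pass, but LR-schedule ops must execute exactly once per global step, so only the first micro-batch uses `forward_and_lr_ops_`. A rough, self-contained sketch of the call order a 1F1B scheduler produces (illustrative only, not the actual `SectionWorker` driver; `Worker` here is a stub):

```cpp
#include <cstdio>

// Hypothetical stand-in for the worker; only the call pattern matters here.
struct Worker {
  void RunForward(int i) { std::printf("F%d ", i); }
  void RunBackward(int i) { std::printf("B%d ", i); }
  void RunUpdate() { std::printf("U\n"); }
};

// Illustrative 1F1B order for one pipeline stage: warm-up forwards, then
// alternate one forward with one backward, then drain remaining backwards.
void Run1F1BStep(Worker &w, int num_microbatches, int num_warmup) {
  int fwd = 0, bwd = 0;
  for (; fwd < num_warmup && fwd < num_microbatches; ++fwd)
    w.RunForward(fwd);                   // warm-up phase
  while (fwd < num_microbatches) {       // steady phase: 1 forward, 1 backward
    w.RunForward(fwd++);
    w.RunBackward(bwd++);
  }
  while (bwd < num_microbatches) w.RunBackward(bwd++);  // drain phase
  w.RunUpdate();  // optimizer ops run once, after all micro-batches
}

int main() {
  Worker w;
  Run1F1BStep(w, /*num_microbatches=*/4, /*num_warmup=*/2);
  // Prints: F0 F1 F2 B0 F3 B1 B2 B3 U
}
```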
@@ -93,18 +110,13 @@ void SectionWorker::RunBackward(
     int micro_id, std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
         &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    if ((op_role == static_cast<int>(OpRole::kBackward)) ||
-        (op_role == (static_cast<int>(OpRole::kBackward) |
-                     static_cast<int>(OpRole::kLoss)))) {
-      VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
-              << micro_id;
-      op->Run(*microbatch_scopes_[micro_id], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(),
-                            unused_vars_, gc.get());
-      }
-    }
-  }
-}
+  for (auto &op : backward_ops_) {
+    VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
+            << micro_id;
+    op->Run(*microbatch_scopes_[micro_id], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[micro_id], op, unused_vars_,
+                          gc.get());
+    }
+  }
+}
@@ -113,15 +125,12 @@ void SectionWorker::RunUpdate(
     std::unique_ptr<GarbageCollector> &gc,
     std::unordered_map<const OperatorBase *, std::vector<std::string>>
         &unused_vars_) {
-  for (auto &op : ops_) {
-    int op_role = op->Attr<int>(std::string("op_role"));
-    if (op_role == static_cast<int>(OpRole::kOptimize)) {
-      VLOG(3) << "Update: running op " << op->Type();
-      op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
-      if (gc) {
-        DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1],
-                            op.get(), unused_vars_, gc.get());
-      }
-    }
-  }
-}
+  for (auto &op : optimizer_ops_) {
+    VLOG(3) << "Update: running op " << op->Type();
+    op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_);
+    if (gc) {
+      DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], op,
+                          unused_vars_, gc.get());
+    }
+  }
+}
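Note that `RunUpdate` executes the cached `optimizer_ops_` once per step, in the last micro-batch's scope (`microbatch_scopes_[num_microbatches_ - 1]`), rather than once per micro-batch; this matches the pipeline model in which a single parameter update follows the forward and backward passes of all micro-batches. The garbage-collection calls in all three loops now pass the cached raw pointer (`op`) directly instead of `op.get()`, since the loop variable is already an `OperatorBase*` rather than a `std::unique_ptr`.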
......