Commit 79989c90 authored by Yu Yang

Add SSA builder

Parent 64d7a302
@@ -43,14 +43,20 @@ struct SSAGraph {
  std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
class SSAGraphBuilder {
public:
virtual ~SSAGraphBuilder() {}
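  // Translate a ProgramDesc into an SSAGraph of op/var handles; concrete
  // builders decide how ops are assigned to places.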
virtual void Build(const ProgramDesc &program, SSAGraph *graph) const = 0;
protected:
/**
   * We only handle write after read (WAR) hazards, since a program should not
   * contain write after write. If write-after-write operators do appear, we
   * need to prune them.
   *
   * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
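   *
   * For example, if op1 reads version k of variable v and op2 then writes v
   * (producing version k+1), op2 must not start before op1 finishes reading,
   * so a dummy dependency edge from op1 to op2 is added.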
   */
  static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
    for (auto &var_map : graph->vars_) {
      for (auto &name_pair : var_map) {
        if (name_pair.second.size() <= 1) {
@@ -83,9 +89,9 @@ static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
        }
      }
    }
  }
  static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
                                               const std::string &each_var_name,
                                               const platform::Place &place,
                                               size_t place_offset) {
@@ -103,11 +109,12 @@ static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
      var = &var_holder.rbegin()->second;
    }
    return var;
  }
  static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
                             const std::string &each_var_name,
                             const platform::Place &place,
                             size_t place_offset) {
    auto &vars = graph->vars_[place_offset][each_var_name];
    size_t version = vars.size();
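    // Note: vars appears to be an ordered map keyed by version, so indexing
    // it with version == vars.size() default-constructs the next version
    // entry in place.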
    auto &var = vars[version];
@@ -115,7 +122,132 @@ static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
    var.name_ = each_var_name;
    var.place_ = place;
    op_handle->AddOutput(&var);
  }
};
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
public:
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
const std::string &loss_var_name,
const std::unordered_set<std::string> &params,
const std::vector<Scope *> &local_scopes,
platform::NCCLContextMap *nccl_ctxs)
: loss_var_name_(loss_var_name),
places_(places),
local_scopes_(local_scopes),
nccl_ctxs_(nccl_ctxs) {
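    // Collect the gradient-variable name of every parameter; GradVarName(p)
    // appends the gradient suffix (the removed code below built p + "@GRAD"
    // by hand).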
for (auto &p : params) {
grad_names_.insert(GradVarName(p));
}
}
void Build(const ProgramDesc &program, SSAGraph *graph) const override {
SSAGraph &result = *graph;
result.vars_.resize(places_.size());
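    // One name -> versioned-VarHandle map per place.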
bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) {
bool change_forward = false;
if (!is_forwarding) {
// FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 &&
op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
continue;  // Skip the op that fills the loss gradient with 1; ScaleLossGradOpHandle produces it instead.
}
}
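      // Replicate this op onto every place: one ComputationOpHandle per
      // device.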
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto *s = local_scopes_[i];
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = result.ops_.back().get();
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&result, op_handle, each_var_name, p, i);
}
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
// Insert ScaleCost OpHandle
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
nccl_ctxs_->DevCtx(p));
result.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only uses device_count as the scale
// factor, so it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p,
i);
change_forward = true;
}
}
}
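      // Once the loss-scale op has been emitted, everything that follows is
      // part of the backward pass.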
if (change_forward) {
is_forwarding = false;
}
if (!is_forwarding) {
auto var_names = op->OutputArgumentNames();
for (auto &og : var_names) {
if (grad_names_.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
result.ops_.emplace_back(
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
auto *op_handle = result.ops_.back().get();
for (size_t i = 0; i < places_.size(); ++i) {
auto &p = places_[i];
auto &vars = result.vars_[i][og];
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
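              // vars[vars.size()] appends a fresh entry; after the insertion
              // vars.size() has grown by one, so the new version index is
              // vars.size() - 1.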
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->AddOutput(&var);
}
}
}
}
}
    /*
      The dependency graph has been constructed. However, there are still
      data hazards that need to be handled.
    */
PolishGraphToSupportDataHazards(&result);
}
private:
std::string loss_var_name_;
const std::vector<platform::Place> &places_;
const std::vector<Scope *> &local_scopes_;
platform::NCCLContextMap *nccl_ctxs_;
std::unordered_set<std::string> grad_names_;
};
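// A hypothetical illustration (not part of this commit): since Build() is
// virtual, other placement strategies can be plugged in. A minimal
// single-device builder reusing the protected helpers above might look like
// the sketch below; SingleDevSSAGraphBuilder and its members are invented
// here for illustration, and dev_ctx_ binding is omitted for brevity.
class SingleDevSSAGraphBuilder : public SSAGraphBuilder {
 public:
  SingleDevSSAGraphBuilder(const platform::Place &place, Scope *scope)
      : place_(place), scope_(scope) {}

  void Build(const ProgramDesc &program, SSAGraph *graph) const override {
    graph->vars_.resize(1);  // a single place, so place_offset is always 0
    for (auto *op : program.Block(0).AllOps()) {
      graph->ops_.emplace_back(new ComputationOpHandle(*op, scope_, place_));
      auto *op_handle = graph->ops_.back().get();
      for (auto &name : op->InputArgumentNames()) {
        op_handle->AddInput(CreateOrGetLatestVarHandle(graph, name, place_, 0));
      }
      for (auto &name : op->OutputArgumentNames()) {
        CreateOpOutput(graph, op_handle, name, place_, 0);
      }
    }
    // Reads and writes of the same variable still need dummy dependencies.
    PolishGraphToSupportDataHazards(graph);
  }

 private:
  platform::Place place_;
  Scope *scope_;
};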
class ParallelExecutorPrivate {
 public:
@@ -123,9 +255,7 @@ class ParallelExecutorPrivate {
                          const std::vector<platform::Place> &places)
      : places_(places),
        fetch_dev_ctxs_(places),
        pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {}
  std::vector<platform::Place> places_;
  platform::DeviceContextPool fetch_dev_ctxs_;
@@ -199,7 +329,10 @@ ParallelExecutor::ParallelExecutor(
  // Step 2. Convert main_program to SSA form and dependency graph. Also,
  // insert ncclOp.
  MultiDevSSAGraphBuilder builder(member_->places_, loss_var_name, params,
member_->local_scopes_,
member_->nccl_ctxs_.get());
builder.Build(main_program, &member_->graph_);
  // Step 3. Create vars in each scope;
  for (auto *scope : member_->local_scopes_) {
@@ -213,110 +346,6 @@ ParallelExecutor::ParallelExecutor(
  }
}
void ParallelExecutor::ConstructDependencyGraph(
const std::unordered_set<std::string> &params,
const ProgramDesc &main_program, const std::string &loss_var_name) const {
std::unordered_set<std::string> grads;
for (auto &each_param : params) {
grads.insert(each_param + "@GRAD");
}
bool is_forwarding = true;
for (auto *op : main_program.Block(0).AllOps()) {
bool change_forward = false;
if (!is_forwarding) {
// FIXME(yy): Do not hard code like this
if (op->OutputArgumentNames().size() == 1 &&
op->OutputArgumentNames()[0] == loss_var_name + "@GRAD") {
continue; // Drop fill 1. for backward coeff;
}
}
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &p = member_->places_[i];
auto *s = member_->local_scopes_[i];
member_->graph_.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = member_->graph_.ops_.back().get();
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var =
CreateOrGetLatestVarHandle(&member_->graph_, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
CreateOpOutput(&member_->graph_, op_handle, each_var_name, p, i);
}
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle
op_handle =
new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s,
p, member_->nccl_ctxs_->DevCtx(p));
member_->graph_.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
// VarHandle *loss = GetVarHandle(loss_var_name, place);
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
CreateOpOutput(&member_->graph_, op_handle, loss_var_name + "@GRAD",
p, i);
change_forward = true;
}
}
}
if (change_forward) {
is_forwarding = false;
}
if (!is_forwarding) {
auto var_names = op->OutputArgumentNames();
for (auto &og : var_names) {
if (grads.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
member_->graph_.ops_.emplace_back(new NCCLAllReduceOpHandle(
member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
auto *op_handle = member_->graph_.ops_.back().get();
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &p = member_->places_[i];
auto &vars = member_->graph_.vars_[i][og];
if (vars.empty()) { // This device has no data. continue.
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->AddInput(prev_grad);
auto &var = vars[vars.size()];
var.place_ = p;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->AddOutput(&var);
}
}
}
}
}
  /*
    The dependency graph has been constructed. However, there are still
    data hazards that need to be handled.
  */
PolishGraphToSupportDataHazards(&member_->graph_);
}
void ParallelExecutor::BCastParamsToGPUs(
    const ProgramDesc &startup_program) const {
#ifdef PADDLE_WITH_CUDA
...
@@ -47,10 +47,6 @@ class ParallelExecutor {
  void BCastParamsToGPUs(const ProgramDesc& startup_program) const;
void ConstructDependencyGraph(const std::unordered_set<std::string>& params,
const ProgramDesc& main_program,
const std::string& loss_var_name) const;
  void BuildNCCLCommunicator() const;
};
...