提交 64d7a302 编写于 作者: Y Yu Yang

Extract SSAGraph

上级 8dec4ad7
......@@ -37,6 +37,86 @@ using details::ScaleLossGradOpHandle;
using details::VarHandle;
using details::VarHandleBase;
// Container for the dependency graph used by the parallel executor: the
// versioned variable handles, the extra dependency-only variables, and the
// operator handles.
struct SSAGraph {
// Indexed by place offset; each entry maps a variable name to its versions
// (version number -> handle). Version 0 is created on first lookup of a name,
// and every op output appends a new, higher version.
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
// Owns the dummy variables that are inserted purely to order a reader op
// before a later writer op.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// Owns every operator handle that has been added to the graph.
std::vector<std::unique_ptr<OpHandleBase>> ops_;
};
/**
 * Inserts explicit dependencies so that write-after-read (WAR) hazards are
 * respected when ops are scheduled.
 *
 * Only WAR is handled, because a well-formed program should not contain a
 * write after write. If write-after-write operators do exist, they need to be
 * pruned beforehand.
 *
 * For every variable that has more than one version, each op that reads an
 * older version is linked (through a freshly created DummyVarHandle owned by
 * graph->dep_vars_) to the op that generates the next version, so the reader
 * is ordered before the writer.
 *
 * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
 *
 * @param graph Graph whose vars_ are scanned; dep_vars_ and the involved op
 *              handles are modified in place.
 */
static void PolishGraphToSupportDataHazards(SSAGraph *graph) {
  for (auto &var_map : graph->vars_) {
    for (auto &name_pair : var_map) {
      auto &versions = name_pair.second;
      if (versions.size() <= 1) {
        // A single version has no later writer, so there is nothing to order.
        // Skip just this variable: returning here would leave every remaining
        // variable (and every remaining place) without its hazard edges.
        continue;
      }
      // Walk adjacent version pairs from newest to oldest:
      // it_new is the later version, it_old the one directly before it.
      auto it_new = versions.rbegin();
      auto it_old = versions.rbegin();
      ++it_old;
      for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
        auto *write_op = it_new->second.generated_op_;
        auto &read_ops = it_old->second.pending_ops_;
        auto *ex_write_op = it_old->second.generated_op_;
        if (ex_write_op == nullptr) {  // Nobody write this var.
          continue;
        }

        for (auto *read_op : read_ops) {
          // Manually add a dependency var from read_op to write_op;
          if (read_op == write_op) {
            // Read Write is the same op.
            continue;
          }

          auto *dep_var = new DummyVarHandle();
          read_op->AddOutput(dep_var);
          write_op->AddInput(dep_var);
          // The graph takes ownership of the dummy variable.
          graph->dep_vars_.emplace(dep_var);
        }
      }
    }
  }
}
// Returns the newest version of `each_var_name` on the place at
// `place_offset`. When the variable has no version yet, version 0 is created
// first, with no generating op, and that handle is returned instead.
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
                                             const std::string &each_var_name,
                                             const platform::Place &place,
                                             size_t place_offset) {
  auto &versions = graph->vars_[place_offset][each_var_name];

  if (!versions.empty()) {
    // The map is ordered by version number, so the last entry is the latest.
    return &versions.rbegin()->second;
  }

  // First reference to this name on this place: materialize version 0.
  auto &initial = versions[0];
  initial.place_ = place;
  initial.name_ = each_var_name;
  initial.generated_op_ = nullptr;
  initial.version_ = 0;
  return &initial;
}
// Appends a new version of `each_var_name` on the place at `place_offset`
// and registers it as an output of `op_handle`. The new version number is the
// count of versions that existed before the call.
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
                           const std::string &each_var_name,
                           const platform::Place &place, size_t place_offset) {
  auto &versions = graph->vars_[place_offset][each_var_name];
  const size_t next_version = versions.size();

  VarHandle &output = versions[next_version];
  output.version_ = next_version;
  output.name_ = each_var_name;
  output.place_ = place;

  op_handle->AddOutput(&output);
}
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads,
......@@ -44,7 +124,7 @@ class ParallelExecutorPrivate {
: places_(places),
fetch_dev_ctxs_(places),
pool_(num_threads <= 1 ? nullptr : new ThreadPool(num_threads)) {
vars_.resize(places.size());
graph_.vars_.resize(places.size());
}
std::vector<platform::Place> places_;
......@@ -54,35 +134,13 @@ class ParallelExecutorPrivate {
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
SSAGraph graph_;
// Use a simpler thread pool, might be faster.
std::unique_ptr<ThreadPool> pool_;
std::unique_ptr<platform::EnforceNotMet> exception_;
VarHandle *GetVarHandle(const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &var_holders = vars_[place_offset];
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
auto &init_var = var_holder[0];
init_var.place_ = place;
init_var.name_ = each_var_name;
init_var.generated_op_ = nullptr;
init_var.version_ = 0;
var = &init_var;
} else {
var = &var_holder.rbegin()->second;
}
return var;
}
void RunOp(
bool use_event,
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
......@@ -113,17 +171,6 @@ class ParallelExecutorPrivate {
op_run();
}
}
void GenerateVar(OpHandleBase *op_handle, const std::string &each_var_name,
const platform::Place &place, size_t place_offset) {
auto &vars = vars_[place_offset][each_var_name];
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.name_ = each_var_name;
var.place_ = place;
op_handle->AddOutput(&var);
}
};
ParallelExecutor::ParallelExecutor(
......@@ -189,21 +236,22 @@ void ParallelExecutor::ConstructDependencyGraph(
auto &p = member_->places_[i];
auto *s = member_->local_scopes_[i];
member_->ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = member_->ops_.back().get();
member_->graph_.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
auto *op_handle = member_->graph_.ops_.back().get();
op_handle->dev_ctx_[p] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(p));
auto var_names = op->InputArgumentNames();
for (auto &each_var_name : var_names) {
VarHandle *var = member_->GetVarHandle(each_var_name, p, i);
VarHandle *var =
CreateOrGetLatestVarHandle(&member_->graph_, each_var_name, p, i);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
for (auto &each_var_name : var_names) {
member_->GenerateVar(op_handle, each_var_name, p, i);
CreateOpOutput(&member_->graph_, op_handle, each_var_name, p, i);
}
if (is_forwarding) {
......@@ -212,7 +260,7 @@ void ParallelExecutor::ConstructDependencyGraph(
op_handle =
new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s,
p, member_->nccl_ctxs_->DevCtx(p));
member_->ops_.emplace_back(op_handle);
member_->graph_.ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
......@@ -220,7 +268,8 @@ void ParallelExecutor::ConstructDependencyGraph(
// loss->pending_ops_.emplace_back(op_handle);
// op_handle->inputs_.emplace_back(loss);
member_->GenerateVar(op_handle, loss_var_name + "@GRAD", p, i);
CreateOpOutput(&member_->graph_, op_handle, loss_var_name + "@GRAD",
p, i);
change_forward = true;
}
}
......@@ -235,13 +284,13 @@ void ParallelExecutor::ConstructDependencyGraph(
for (auto &og : var_names) {
if (grads.count(og) != 0) { // is param grad
// Insert NCCL AllReduce Op
member_->ops_.emplace_back(new NCCLAllReduceOpHandle(
member_->graph_.ops_.emplace_back(new NCCLAllReduceOpHandle(
member_->local_scopes_, member_->places_, *member_->nccl_ctxs_));
auto *op_handle = member_->ops_.back().get();
auto *op_handle = member_->graph_.ops_.back().get();
for (size_t i = 0; i < member_->places_.size(); ++i) {
auto &p = member_->places_[i];
auto &vars = member_->vars_[i][og];
auto &vars = member_->graph_.vars_[i][og];
if (vars.empty()) { // This device has no data. continue.
continue;
......@@ -265,49 +314,7 @@ void ParallelExecutor::ConstructDependencyGraph(
Dependency graph has been constructed. However, there are still data
hazards that need to be handled.
*/
PolishGraphToSupportDataHazards();
}
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
void ParallelExecutor::PolishGraphToSupportDataHazards() const {
for (auto &var_map : member_->vars_) {
for (auto &name_pair : var_map) {
if (name_pair.second.size() <= 1) {
return;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
auto *ex_write_op = it_old->second.generated_op_;
if (ex_write_op == nullptr) { // Nobody write this var.
continue;
}
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
if (read_op == write_op) {
// Read Write is the same op.
continue;
}
auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
member_->dep_vars_.emplace(dep_var);
}
}
}
}
PolishGraphToSupportDataHazards(&member_->graph_);
}
void ParallelExecutor::BCastParamsToGPUs(
......@@ -365,7 +372,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::vector<DummyVarHandle> dummy_vars;
for (auto &var_map : member_->vars_) {
for (auto &var_map : member_->graph_.vars_) {
for (auto &name_pair : var_map) {
for (auto &version_pair : name_pair.second) {
pending_vars[&version_pair.second] =
......@@ -374,13 +381,13 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
}
for (auto &var : member_->dep_vars_) {
for (auto &var : member_->graph_.dep_vars_) {
pending_vars[var.get()] = var->generated_op_ == nullptr;
}
std::vector<OpHandleBase *> to_run;
for (auto &op : member_->ops_) {
for (auto &op : member_->graph_.ops_) {
if (op->inputs_.empty()) { // Special case, Op has no input.
to_run.emplace_back(op.get());
} else {
......@@ -391,7 +398,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
std::unordered_map<std::string, std::vector<VarHandleBase *>> fetched_vars;
for (auto &fetch_var_name : fetch_tensors) {
for (auto &var_map : member_->vars_) {
for (auto &var_map : member_->graph_.vars_) {
auto it = var_map.find(fetch_var_name);
if (it != var_map.end()) {
fetched_vars[fetch_var_name].push_back(&it->second.rbegin()->second);
......
......@@ -52,8 +52,6 @@ class ParallelExecutor {
const std::string& loss_var_name) const;
void BuildNCCLCommunicator() const;
void PolishGraphToSupportDataHazards() const;
};
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册