You need to sign in or sign up before continuing.
提交 193c0a7e 编写于 作者: Y Yu Yang

Handle var hazard

上级 35744e7b
......@@ -28,42 +28,79 @@ namespace framework {
struct OpHandle;
struct VarHandle {
struct VarHandleBase {
virtual ~VarHandleBase() {}
virtual std::string DebugString() const = 0;
OpHandle *generated_op_;
std::vector<OpHandle *> pending_ops_;
};
struct VarHandle : public VarHandleBase {
std::string DebugString() const override {
std::stringstream ss;
ss << name_ << ":" << place_;
return ss.str();
}
size_t version_;
std::string name_;
platform::Place place_;
};
OpHandle *generated_op_;
std::vector<OpHandle *> pending_ops_;
struct DependencyVarHandle : public VarHandleBase {
std::string DebugString() const override { return "Deps var"; }
};
struct OpHandle {
std::vector<VarHandle *> inputs_;
std::vector<VarHandle *> outputs_;
std::vector<VarHandleBase *> inputs_;
std::vector<VarHandleBase *> outputs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
dev_ctx_;
std::string DebugString() {
std::stringstream ss;
ss << "(";
for (auto *var : inputs_) {
ss << var->name_ << ":" << var->place_ << ", ";
ss << var->DebugString() << ", ";
}
ss << ") --> (";
for (auto *var : outputs_) {
ss << var->name_ << ":" << var->place_ << ", ";
ss << var->DebugString() << ", ";
}
ss << ")\n";
return ss.str();
}
virtual ~OpHandle() {}
virtual void Run() {}
virtual void Wait() {}
};
struct ComputationOpHandle : public OpHandle {
std::unique_ptr<OperatorBase> op_;
Scope *scope_;
platform::Place place_;
explicit ComputationOpHandle(const OpDesc &op_desc)
: op_(framework::OpRegistry::CreateOp(op_desc)) {}
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
platform::Place place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
scope_(scope),
place_(place) {}
void Run() override {
// Wait other op if necessary
auto *cur_ctx = dev_ctx_[place_];
for (auto *in : inputs_) {
if (in->generated_op_ && in->generated_op_->dev_ctx_[place_] != cur_ctx) {
in->generated_op_->Wait();
}
}
op_->Run(*scope_, place_);
}
};
struct ScaleLossGradOpHandle : public OpHandle {};
......@@ -122,12 +159,27 @@ class ParallelExecutorPrivate {
#endif
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
return const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
} else {
#ifdef PADDLE_WITH_CUDA
return GetNCCLCtx(place).ctx_.get();
#else
PADDLE_THROW("Not compiled with CUDA")
#endif
}
}
platform::Place main_place_;
std::unordered_map<platform::Place,
std::unordered_map<std::string, std::map<int, VarHandle>>,
platform::PlaceHash>
vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandle>> ops_;
ThreadPool pool_;
......@@ -170,7 +222,7 @@ ParallelExecutor::ParallelExecutor(
void ParallelExecutor::ConstructDependencyGraph(
const std::unordered_set<std::string> &params,
const ProgramDesc &main_program, const std::string &loss_var_name) const {
std::unordered_set<std::__cxx11::string> grads;
std::unordered_set<std::string> grads;
for (auto &each_param : params) {
grads.insert(each_param + "@GRAD");
}
......@@ -188,8 +240,11 @@ void ParallelExecutor::ConstructDependencyGraph(
}
for (auto &pair : member_->local_scopes_) {
member_->ops_.emplace_back(new ComputationOpHandle(*op));
member_->ops_.emplace_back(
new ComputationOpHandle(*op, pair.second, pair.first));
auto *op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(pair.first));
auto var_names = op->InputArgumentNames();
......@@ -210,8 +265,11 @@ void ParallelExecutor::ConstructDependencyGraph(
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle
member_->ops_.emplace_back(new ScaleLossGradOpHandle());
op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[pair.first] =
member_->CommunicationDevCtx(pair.first);
auto &place = pair.first;
VarHandle *loss = GetVarHandle(loss_var_name, place);
loss->pending_ops_.emplace_back(op_handle);
......@@ -251,11 +309,54 @@ void ParallelExecutor::ConstructDependencyGraph(
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
for (auto &pair : member_->local_scopes_) {
op_handle->dev_ctx_[pair.first] =
member_->CommunicationDevCtx(pair.first);
}
}
}
}
}
}
/**
* Dependency graph has been constructed. However, there are still data
* harzaeds need to be handled.
*
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
* prune them.
*
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
*/
for (auto &place_pair : member_->vars_) {
for (auto &name_pair : place_pair.second) {
if (name_pair.second.size() <= 1) {
return;
}
auto it_new = name_pair.second.rbegin();
auto it_old = name_pair.second.rbegin();
++it_old;
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
auto *write_op = it_new->second.generated_op_;
auto &read_ops = it_old->second.pending_ops_;
for (auto *read_op : read_ops) {
// Manually add a dependency var from read_op to write_op;
auto *dep_var = new DependencyVarHandle();
dep_var->generated_op_ = read_op;
read_op->outputs_.emplace_back(dep_var);
dep_var->pending_ops_.emplace_back(write_op);
write_op->inputs_.emplace_back(dep_var);
member_->dep_vars_.emplace(dep_var);
}
}
}
}
}
void ParallelExecutor::GenerateVar(OpHandle *op_handle,
......@@ -349,7 +450,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
// Version --> VarHandle
std::unordered_map<VarHandle *, bool> pending_vars;
std::unordered_map<VarHandleBase *, bool> pending_vars;
std::unordered_map<OpHandle *, size_t> pending_ops;
for (auto &place_pair : member_->vars_) {
......@@ -361,12 +462,16 @@ std::vector<LoDTensor> ParallelExecutor::Run(
}
}
for (auto &var : member_->dep_vars_) {
pending_vars[var.get()] = var->generated_op_ == nullptr;
}
for (auto &op : member_->ops_) {
pending_ops.insert({op.get(), op->inputs_.size()});
}
while (!pending_ops.empty()) {
VarHandle *ready_var = nullptr;
VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) {
if (pair.second) {
ready_var = pair.first;
......@@ -400,7 +505,7 @@ std::vector<LoDTensor> ParallelExecutor::Run(
auto op_run = [ready_buffer, op] {
// TODO(yy) Check Previous Op has same dev ctx.
LOG(INFO) << "Run " << op->DebugString();
op->Run();
for (auto *ready : ready_buffer) {
*ready = true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册