提交 9612c7e5 编写于 作者: Y Yu Yang

Add comments and polish code

上级 76c4ae85
...@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( ...@@ -58,23 +58,20 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
const OpDesc &op, const OpDesc &op,
const platform::Place &p, size_t place_id) const {
const size_t &i) const { auto p = places_[place_id];
auto *op_handle = result->ops_.back().get(); auto *op_handle = result->ops_.back().get();
op_handle->SetDeviceContext(p, op_handle->SetDeviceContext(p,
platform::DeviceContextPool::Instance().Get(p)); platform::DeviceContextPool::Instance().Get(p));
auto var_names = op.InputArgumentNames(); for (auto &each_var_name : op.InputArgumentNames()) {
VarHandle *var =
for (auto &each_var_name : var_names) { CreateOrGetLatestVarHandle(result, each_var_name, p, place_id);
VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i);
op_handle->AddInput(var); op_handle->AddInput(var);
} }
var_names = op.OutputArgumentNames(); for (auto &each_var_name : op.OutputArgumentNames()) {
CreateOpOutput(result, op_handle, each_var_name, p, place_id);
for (auto &each_var_name : var_names) {
CreateOpOutput(result, op_handle, each_var_name, p, i);
} }
} }
...@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, ...@@ -84,17 +81,18 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
return false; return false;
} }
auto checker = [&](const std::vector<std::string> opvars, /**
const std::vector<std::string> sendvars) -> bool { * Check any of opvars contains `.block` and in sendvars
bool is_dist_train_op = false; */
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &sendvars) -> bool {
for (auto &var : opvars) { for (auto &var : opvars) {
if (var.find(".block") != std::string::npos && if (var.find(".block") != std::string::npos &&
std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
is_dist_train_op = true; return true;
break;
} }
} }
return is_dist_train_op; return false;
}; };
if (op.Type() == "split") { if (op.Type() == "split") {
...@@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -117,13 +115,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
places_.size()); places_.size());
// Find "send" op first for split is in front of send. // Find "send" op first for split is in front of send.
OpDesc *send_op = nullptr; OpDesc *send_op = GetSendOpDesc(program);
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
send_op = op;
break;
}
}
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
...@@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -134,6 +126,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
} else if (IsDistTrainOp(*op, send_op)) { } else if (IsDistTrainOp(*op, send_op)) {
CreateComputationalOps(&result, *op, 1); CreateComputationalOps(&result, *op, 1);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if skip_scale_loss_
if (!skip_scale_loss_) { if (!skip_scale_loss_) {
CreateScaleLossGradOp(&result); CreateScaleLossGradOp(&result);
} }
...@@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -142,10 +135,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
CreateComputationalOps(&result, *op, places_.size()); CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding) { if (!is_forwarding) {
// Currently, we assume that once gradient is generated, it can be // Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once. But there are no // broadcast, and each gradient is only broadcast once.
// other cases, for example, we need to adjust the gradient according to
// the input when we get the gradient, which is not considered at
// present.
for (auto &og : op->OutputArgumentNames()) { for (auto &og : op->OutputArgumentNames()) {
if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) {
InsertNCCLAllReduceOp(&result, og); InsertNCCLAllReduceOp(&result, og);
...@@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -175,6 +165,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
} }
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const { SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, ...@@ -243,7 +243,7 @@ void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
auto p = places_[scope_idx]; auto p = places_[scope_idx];
auto s = local_scopes_[scope_idx]; auto s = local_scopes_[scope_idx];
result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
CreateOpHandleIOs(result, op, p, scope_idx); CreateOpHandleIOs(result, op, scope_idx);
} }
} }
...@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, ...@@ -255,7 +255,7 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
result->ops_.emplace_back(new SendOpHandle(op, s, p)); result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output // Create inputs for output on original place and no ssa output
// is created for send op. // is created for send op.
CreateOpHandleIOs(result, op, p, 0); CreateOpHandleIOs(result, op, 0);
} }
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
......
...@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -48,7 +48,7 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
private: private:
void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, void CreateOpHandleIOs(SSAGraph *result, const OpDesc &op,
const platform::Place &p, const size_t &i) const; size_t place_id) const;
private: private:
std::string loss_var_name_; std::string loss_var_name_;
...@@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -65,6 +65,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator as the end-point operator before/after send operator.
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op, void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
...@@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -77,6 +80,12 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
std::unordered_set<std::string> *og_has_been_broadcast) const; std::unordered_set<std::string> *og_has_been_broadcast) const;
void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const; void InsertNCCLAllReduceOp(SSAGraph *result, const std::string &og) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -25,12 +25,22 @@ namespace paddle { ...@@ -25,12 +25,22 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
// A SSA graph used by parallel executor.
struct SSAGraph { struct SSAGraph {
// all variable in each devices.
// The outside vector is the device vector. Each element of this vector is a
// map from variable name to variables. The variables, who have the same name,
// will have a different version. The offset in the
// `std::vector<std::unique_ptr<VarHandle>>` is the version of varaibles.
std::vector< std::vector<
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>> std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>
vars_; vars_;
// aux variables to represent dependency. Useful to resolve data hazard. // aux variables to represent dependency. Useful to resolve data hazard.
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_; std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
// all operators. NOTE that even we use a vector here, the operators is
// unordered.
std::vector<std::unique_ptr<OpHandleBase>> ops_; std::vector<std::unique_ptr<OpHandleBase>> ops_;
}; };
......
...@@ -48,6 +48,8 @@ class SSAGraphBuilder { ...@@ -48,6 +48,8 @@ class SSAGraphBuilder {
const platform::Place &place, const platform::Place &place,
size_t place_offset); size_t place_offset);
// Add an output variable (each_var_name, place, place_offset) to op_handle,
// which belongs to graph
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
const std::string &each_var_name, const std::string &each_var_name,
const platform::Place &place, size_t place_offset); const platform::Place &place, size_t place_offset);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册