diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 10d39e779336e2001d66a55ac6d01ee768ddd4ff..3413467b149539bcff42d78a9a6fe315d6558bb4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -78,6 +78,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }
 
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
+                                            OpDesc *send_op) const {
+  if (send_op == nullptr) {
+    return false;
+  }
+
+  auto checker = [&](const std::vector<std::string> opvars,
+                     const std::vector<std::string> sendvars) -> bool {
+    bool is_dist_train_op = false;
+    for (auto &var : opvars) {
+      if (var.find(".block") != std::string::npos &&
+          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+        is_dist_train_op = true;
+        break;
+      }
+    }
+    return is_dist_train_op;
+  };
+
+  if (op.Type() == "split") {
+    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  } else if (op.Type() == "concat") {
+    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+  }
+  return false;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -89,19 +116,30 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
+  // Find the "send" op first, because the "split" ops appear before it.
+  OpDesc *send_op = nullptr;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      send_op = op;
+      break;
+    }
+  }
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
       // Append the send op if this program is a distributed trainer main
       // program, and always place it on the first device.
       CreateSendOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_op)) {
+      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       if (!skip_scale_loss_) {
         CreateScaleLossGradOp(&result);
       }
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op);
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once a gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once. But there are no
@@ -199,8 +237,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
 }
 
 void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
-                                                     const OpDesc &op) const {
-  for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+                                                     const OpDesc &op,
+                                                     size_t num_places) const {
+  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 009c31b40c279ae3e924d2d7c67933e7444ed85c..dc3da70eda2abaa1a312c25aedf94fa7e427c78a 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -65,7 +65,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
 
-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
+  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+
+  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+                              size_t num_places) const;
 
   void CreateScaleLossGradOp(SSAGraph *result) const;
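
Reviewer note (not part of the patch): IsDistTrainOp classifies a split or concat op as belonging to the distributed send path when one of its argument names contains ".block" and also appears among the send op's arguments. The standalone C++ program below replays only that matching rule so the heuristic can be exercised in isolation; it is a minimal sketch, and the function name MatchesSendPath as well as variable names such as w@GRAD.block0 are invented for illustration, not taken from Paddle.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Mirror of the checker lambda in IsDistTrainOp: true when some name in
// opvars contains ".block" and is also listed in sendvars.
static bool MatchesSendPath(const std::vector<std::string> &opvars,
                            const std::vector<std::string> &sendvars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Hypothetical names: the outputs of a "split" op and the inputs of the
  // "send" op in a distributed trainer program.
  std::vector<std::string> split_outputs = {"w@GRAD.block0", "w@GRAD.block1"};
  std::vector<std::string> send_inputs = {"w@GRAD.block0", "w@GRAD.block1"};
  std::cout << std::boolalpha
            << MatchesSendPath(split_outputs, send_inputs) << std::endl;  // true
  return 0;
}

Ops matched this way are built with num_places = 1, so the split/concat surrounding the send runs in a single place, while every other computational op is still replicated across places_.size() devices; this avoids creating redundant split/concat work on devices whose results would never be sent.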