From 1fba0c578a46e7d82844a86581f4b5fb3ababe59 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Mon, 23 Apr 2018 21:16:43 +0800
Subject: [PATCH] fix multi gpu dist train

---
 .../details/multi_devices_graph_builder.cc   | 45 +++++++++++++++++--
 .../details/multi_devices_graph_builder.h    |  5 ++-
 2 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 002952436e..39131492a4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -77,6 +77,33 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
   }
 }
 
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
+                                            OpDesc *send_op) const {
+  if (send_op == nullptr) {
+    return false;
+  }
+
+  auto checker = [&](const std::vector<std::string> opvars,
+                     const std::vector<std::string> sendvars) -> bool {
+    bool is_dist_train_op = false;
+    for (auto &var : opvars) {
+      if (var.find(".block") != std::string::npos &&
+          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+        is_dist_train_op = true;
+        break;
+      }
+    }
+    return is_dist_train_op;
+  };
+
+  if (op.Type() == "split") {
+    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  } else if (op.Type() == "concat") {
+    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+  }
+  return false;
+}
+
 std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
     const ProgramDesc &program) const {
   auto graph = new SSAGraph();
@@ -88,17 +115,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
           places_.size());
 
+  // Find "send" op first for split is in front of send.
+  OpDesc *send_op = nullptr;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send") {
+      send_op = op;
+      break;
+    }
+  }
+
   bool is_forwarding = true;
   for (auto *op : program.Block(0).AllOps()) {
     if (op->Type() == "send") {
       // append send op if program is distributed trainer main program.
       // always use the first device
       CreateSendOp(&result, *op);
+    } else if (IsDistTrainOp(*op, send_op)) {
+      CreateComputationalOps(&result, *op, 1);
     } else if (IsScaleLossOp(*op)) {
       CreateScaleLossGradOp(&result);
       is_forwarding = false;
     } else {
-      CreateComputationalOps(&result, *op);
+      CreateComputationalOps(&result, *op, places_.size());
       if (!is_forwarding) {
         // Currently, we assume that once gradient is generated, it can be
         // broadcast, and each gradient is only broadcast once. But there are no
@@ -196,8 +234,9 @@ void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const {
 }
 
 void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result,
-                                                     const OpDesc &op) const {
-  for (size_t scope_idx = 0; scope_idx < places_.size(); ++scope_idx) {
+                                                     const OpDesc &op,
+                                                     size_t num_places) const {
+  for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) {
     auto p = places_[scope_idx];
     auto s = local_scopes_[scope_idx];
     result->ops_.emplace_back(new ComputationOpHandle(op, s, p));
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index b5ba2dbd3c..42905a9a28 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -62,7 +62,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
 
-  void CreateComputationalOps(SSAGraph *result, const OpDesc &op) const;
+  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+
+  void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
+                              size_t num_places) const;
 
   void CreateScaleLossGradOp(SSAGraph *result) const;
 
-- 
GitLab
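
Note for readers of the patch: the heuristic in IsDistTrainOp reduces to one
membership test: an op belongs to the distributed-training plumbing when any
of its variable names contains the parameter-block suffix ".block" and also
appears among the "send" op's arguments. The following is a minimal
standalone sketch of that test, not part of the patch; the function name
MatchesSendVars and the variable names (e.g. "fc_0.w_0.block0") are
hypothetical examples chosen for illustration.

// Standalone illustration of the ".block" matching heuristic used by
// IsDistTrainOp. All names below are hypothetical.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Returns true when any name in `opvars` looks like a parameter block
// (contains ".block") and is also one of the send op's variables.
bool MatchesSendVars(const std::vector<std::string> &opvars,
                     const std::vector<std::string> &sendvars) {
  for (const auto &var : opvars) {
    if (var.find(".block") != std::string::npos &&
        std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  // Outputs of a hypothetical "split" op: one slice per parameter server.
  std::vector<std::string> split_outputs = {"fc_0.w_0.block0",
                                            "fc_0.w_0.block1"};
  // Inputs of the trainer's "send" op.
  std::vector<std::string> send_inputs = {"fc_0.w_0.block0",
                                          "fc_0.w_0.block1"};
  // Prints 1: the split op feeds the send op, so the builder treats it as
  // a dist-train op and creates it on a single device.
  std::cout << MatchesSendVars(split_outputs, send_inputs) << "\n";
  return 0;
}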
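The second half of the patch threads a num_places argument through
CreateComputationalOps so that Build can choose how many device copies of
each op to create: "send" and the split/concat ops around it run once, while
ordinary compute ops are still replicated on every GPU. A small sketch of
that placement rule, under the simplifying assumption that an op is reduced
to its type string plus the result of the IsDistTrainOp-style check (the Op
struct and NumReplicas below are illustrative, not Paddle APIs):

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

struct Op {
  std::string type;
  bool is_dist_train;  // result of the IsDistTrainOp-style check
};

// Returns how many device copies of `op` the graph builder should create.
size_t NumReplicas(const Op &op, size_t num_devices) {
  if (op.type == "send") return 1;  // RPC runs on the first device only
  if (op.is_dist_train) return 1;   // split/concat around send: one copy
  return num_devices;               // normal compute ops: one per GPU
}

int main() {
  const size_t num_devices = 4;  // hypothetical 4-GPU trainer
  std::vector<Op> ops = {{"split", true}, {"send", false}, {"mul", false}};
  for (const auto &op : ops) {
    // Prints: split -> 1, send -> 1, mul -> 4.
    std::cout << op.type << " -> " << NumReplicas(op, num_devices)
              << " replica(s)\n";
  }
  return 0;
}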