From e9abc66910a9ee613c60c6ccfcba86f3eed8d429 Mon Sep 17 00:00:00 2001
From: Yancey1989 <yancey1989@gmail.com>
Date: Tue, 22 May 2018 16:48:40 +0800
Subject: [PATCH] fix pe

---
 .../details/computation_op_handle.cc          |  2 +
 .../details/multi_devices_graph_builder.cc    | 84 +++++++++++++------
 .../details/multi_devices_graph_builder.h     | 14 +++-
 paddle/fluid/operators/detail/grpc_client.cc  |  6 --
 .../fluid/transpiler/distribute_transpiler.py | 10 +++
 5 files changed, 82 insertions(+), 34 deletions(-)

diff --git a/paddle/fluid/framework/details/computation_op_handle.cc b/paddle/fluid/framework/details/computation_op_handle.cc
index df05bb0633..f6e1208a01 100644
--- a/paddle/fluid/framework/details/computation_op_handle.cc
+++ b/paddle/fluid/framework/details/computation_op_handle.cc
@@ -29,7 +29,9 @@ void ComputationOpHandle::RunImpl() {
   WaitInputVarGenerated(place_);
 
   this->RunAndRecordEvent([this] {
+    VLOG(3) << "begin run op type is " << op_->Type();
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
+    VLOG(3) << "end run op type is " << op_->Type();
   });
 }
 
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index 50998fb8e0..fb5b8608b3 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
-#include <fstream>
 #include <utility>
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/computation_op_handle.h"
@@ -79,9 +78,39 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
     CreateOpOutput(result, op_handle, each_var_name, p, place_id);
   }
 }
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "send_vars" || op->Type() == "send") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    if (op->Type() == "recv" || op->Type() == "send") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
     return false;
   }
 
@@ -89,21 +118,23 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
    * Check any of opvars contains `.block` and in sendvars
    */
   auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
     for (auto &var : opvars) {
       if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
         return true;
       }
     }
     return false;
   };
 
-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
+  if (op.Type() == "split" || op.Type() == "split_byref" ||
+      op.Type() == "split_selected_rows") {
+    return checker(op.OutputArgumentNames(), send_vars);
   } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
+    return checker(op.InputArgumentNames(), recv_vars);
   }
+
   return false;
 }
 
@@ -132,8 +163,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
       places_.size());
 
-  // Find "send" op first for split is in front of send.
-  OpDesc *send_op = GetSendOpDesc(program);
+  // find send/recv vars so that we can place the distributed training
+  // realted op in the place 0
+  auto send_vars = FindDistTrainSendVars(program);
+  auto recv_vars = FindDistTrainRecvVars(program);
 
   size_t cur_device_id = 0;
   std::vector<std::unordered_set<std::string>> var_name_on_devices;
@@ -147,8 +180,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
       // append rpc op if program is distributed trainer main program.
       // always use the first device
       CreateRPCOp(&result, *op);
-    } else if (IsDistTrainOp(*op, send_op)) {
-      CreateComputationalOps(&result, *op, 1);
+    } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
+      // CreateComputationalOps(&result, *op, 1);
+      CreateComputationalOp(&result, *op, 0);
     } else if (IsScaleLossOp(*op)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -213,9 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
   AddOutputToLeafOps(&result);
 
   if (VLOG_IS_ON(10)) {
-    std::string filename = "/tmp/graph";
-    std::ofstream fout(filename);
-    PrintGraphviz(*graph, fout);
+    std::ostringstream sout;
+    PrintGraphviz(*graph, sout);
+    VLOG(10) << sout.str();
   }
 
   return std::unique_ptr<SSAGraph>(graph);
@@ -274,6 +308,7 @@ OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
   }
   return nullptr;
 }
+
 void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
     SSAGraph *result, const std::string &og) const {
 #ifdef PADDLE_WITH_CUDA
@@ -396,14 +431,14 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result,
-                                        std::string op_name) const {
+void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
+                                        const std::string &prev_op_name) const {
   for (auto &prev_op : result->ops_) {
-    if (prev_op->Name() == op_name) {
+    if (prev_op->Name() == prev_op_name) {
       auto *dep_var = new DummyVarHandle();
       prev_op->AddOutput(dep_var);
       result->dep_vars_.emplace(dep_var);
-      result->ops_.back().get()->AddInput(dep_var);
+      op->AddInput(dep_var);
     }
   }
 }
@@ -412,14 +447,14 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
                                           const OpDesc &op) const {
   auto &p = places_[0];
   auto *s = local_scopes_[0];
-  VLOG(3) << "create rpc op: " << op.Type();
   result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
+
   if (op.Type() == "send_barrier") {
-    ConnectOp(result, "send_vars");
+    ConnectOp(result, result->ops_.back().get(), "send_vars");
   } else if (op.Type() == "recv") {
-    ConnectOp(result, "send_barrier");
+    ConnectOp(result, result->ops_.back().get(), "send_barrier");
   } else if (op.Type() == "fetch_barrier") {
-    ConnectOp(result, "recv");
+    ConnectOp(result, result->ops_.back().get(), "recv");
   } else if (op.Type() == "send" || op.Type() == "send_vars") {
     // do nothing
   } else {
@@ -429,7 +464,6 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
   }
 
   // FIXME(wuyi): send op always copy from GPU 0
-  // result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
   // Create inputs for output on original place and no ssa output
   // is created for send op.
   CreateOpHandleIOs(result, op, 0);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.h b/paddle/fluid/framework/details/multi_devices_graph_builder.h
index 45713b0c4f..1d0021c954 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.h
@@ -64,17 +64,25 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
 
   bool IsScaleLossOp(const OpDesc &op) const;
 
-  void CreateSendOp(SSAGraph *result, const OpDesc &op) const;
   void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
    */
-  bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const;
+  bool IsDistTrainOp(const OpDesc &op,
+                     const std::vector<std::string> &send_vars,
+                     const std::vector<std::string> &recv_vars) const;
+
+  std::vector<std::string> FindDistTrainSendVars(
+      const ProgramDesc &program) const;
+
+  std::vector<std::string> FindDistTrainRecvVars(
+      const ProgramDesc &program) const;
 
   bool IsRPCOp(const OpDesc &op) const;
 
-  void ConnectOp(SSAGraph *result, std::string op_name) const;
+  void ConnectOp(SSAGraph *result, OpHandleBase *op,
+                 const std::string &prev_op_name) const;
 
   void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
                               size_t num_places) const;
diff --git a/paddle/fluid/operators/detail/grpc_client.cc b/paddle/fluid/operators/detail/grpc_client.cc
index ca0518d4dc..a758205938 100644
--- a/paddle/fluid/operators/detail/grpc_client.cc
+++ b/paddle/fluid/operators/detail/grpc_client.cc
@@ -245,17 +245,11 @@ bool RPCClient::Proceed() {
 }
 std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
                                                      const std::string& key) {
-  VLOG(3) << "this addr: " << this;
   std::unique_lock<std::mutex> lock(mutex_);
   auto it = channels_.find(key);
   if (it != channels_.end()) {
-    VLOG(3) << "find ep: " << ep;
     return it->second;
   }
-  VLOG(3) << "can not find ep: " << ep;
-  for (auto it = channels_.begin(); it != channels_.end(); ++it) {
-    VLOG(3) << "ep: " << it->first;
-  }
 
   grpc::ChannelArguments args;
   args.SetCompressionAlgorithm(GRPC_COMPRESS_NONE);
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index 806cc2fcc1..cf7775e8ed 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -373,6 +373,16 @@ class DistributeTranspiler:
         for i, ep in enumerate(eplist):
             self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
+        # step4: Concat the parameters splits together after recv.
+        for varname, splited_var in param_var_mapping.iteritems():
+            if len(splited_var) <= 1:
+                continue
+            orig_param = program.global_block().vars[varname]
+            program.global_block().append_op(
+                type="concat",
+                inputs={"X": splited_var},
+                outputs={"Out": [orig_param]},
+                attrs={"axis": 0})
 
         # TODO(Yancey1989): check dist lookup table
         if self.has_distributed_lookup_table:
-- 
GitLab