From 86776a1f96db7459c381ba5faf7989f71621e2c7 Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Thu, 23 Aug 2018 13:47:16 +0800 Subject: [PATCH] Resovle multi gpu async deps (#12828) * dist transpiler add control dependency var between send and recv * fix async deps * follow comments and refine * fix deps connect for rpc ops --- .../details/multi_devices_graph_pass.cc | 26 ++++++++++++++++--- paddle/fluid/framework/ir/node.cc | 2 +- paddle/fluid/framework/ir/node.h | 2 +- paddle/fluid/pybind/const_value.cc | 5 +++- python/paddle/fluid/framework.py | 6 +++++ .../fluid/transpiler/distribute_transpiler.py | 18 +++++++++++-- 6 files changed, 50 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index c5a13e7e1..bc61b0eac 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -763,6 +763,8 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, // Create RPC related op handles that connects its in ops and out ops. void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { + // FIXME(typhoonzero): Cleanup this deps for both sync mode and async mode + // put them into transpiler. int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. @@ -771,26 +773,42 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by // split_byref op - // so that we can balance the variable blocks to all the pserver - // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; for (ir::Node *n : node->inputs) { input_var_names.push_back(n->Name()); } - op_dev_id = GetAppropriateDeviceID(input_var_names); + auto send_param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + PADDLE_ENFORCE_EQ(send_param_grad.size(), 2U); + op_dev_id = GetAppropriateDeviceID({send_param_grad[1]}); + VLOG(10) << "send grad " << input_var_names[0] << " origin " + << send_param_grad[1] << " place: " << op_dev_id; for (auto &varname : input_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } + result->Get(kShardedVarDevice) + .emplace(send_param_grad[1], op_dev_id); } } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); } - op_dev_id = GetAppropriateDeviceID(output_var_names); + auto recv_param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + // FIXME(typhoonzero): assume each recv op output one param + // Use the same place as send. + if (recv_param_grad.size() == 2U) { + op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); + VLOG(10) << "recv param " << recv_param_grad[0] + << " get grad place: " << recv_param_grad[1] + << " place: " << op_dev_id; + } else { + op_dev_id = GetAppropriateDeviceID(output_var_names); + } for (auto &varname : output_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); diff --git a/paddle/fluid/framework/ir/node.cc b/paddle/fluid/framework/ir/node.cc index aca77da8d..65c45c7d2 100644 --- a/paddle/fluid/framework/ir/node.cc +++ b/paddle/fluid/framework/ir/node.cc @@ -17,7 +17,7 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { -const char Node::kControlDepVarName[] = "__control_var"; +constexpr char Node::kControlDepVarName[]; } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 9c0765ab8..a6667de0a 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -27,7 +27,7 @@ namespace ir { class Node { public: enum class Type { kOperation, kVariable }; - static const char kControlDepVarName[]; + static constexpr char kControlDepVarName[] = "__control_var"; explicit Node(const std::string& name, Type type) : name_(name), var_desc_(nullptr), op_desc_(nullptr), type_(type) {} diff --git a/paddle/fluid/pybind/const_value.cc b/paddle/fluid/pybind/const_value.cc index 76aa7d201..e4415ed15 100644 --- a/paddle/fluid/pybind/const_value.cc +++ b/paddle/fluid/pybind/const_value.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/pybind/const_value.h" -#include +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" namespace paddle { @@ -24,6 +25,8 @@ void BindConstValue(pybind11::module* m) { m->def("kTempVarName", [] { return framework::kTempVarName; }); m->def("kGradVarSuffix", [] { return framework::kGradVarSuffix; }); m->def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; }); + m->def("kControlDepVarName", + [] { return framework::ir::Node::kControlDepVarName; }); auto op_proto_and_checker_maker = m->def_submodule("op_proto_and_checker_maker"); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 2377ac5f9..b05fe9571 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -49,6 +49,12 @@ EMPTY_VAR_NAME = core.kEmptyVarName() TEMP_VAR_NAME = core.kTempVarName() GRAD_VAR_SUFFIX = core.kGradVarSuffix() ZERO_VAR_SUFFIX = core.kZeroVarSuffix() +CONTROL_DEP_VAR_PREFIX = core.kControlDepVarName() + + +def generate_control_dev_var_name(): + import random + return CONTROL_DEP_VAR_PREFIX + "@" + str(random.random()) def grad_var_name(var_name): diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 08fcc6971..8a083422c 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -211,8 +211,10 @@ class DistributeTranspiler(object): ps_dispatcher = self.config.split_method(self.pserver_endpoints) self.has_distributed_lookup_table = self._has_distributed_lookup_table() self.param_name_to_grad_name = dict() + self.grad_name_to_param_name = dict() for param_var, grad_var in self.params_grads: self.param_name_to_grad_name[param_var.name] = grad_var.name + self.grad_name_to_param_name[grad_var.name] = param_var.name # step 1: split and create vars, then put splited vars in dicts for later use. self._init_splited_vars() @@ -254,8 +256,10 @@ class DistributeTranspiler(object): AssertionError("Can not insert the send op by original " "variable name :", splited_grad_varname) - dummy_output = program.global_block().create_var() + dummy_output = program.global_block().create_var( + name=framework.generate_control_dev_var_name()) grad_name_to_send_dummy_out[grad_varname] = dummy_output + program.global_block()._insert_op( index=index + 1, type="send", @@ -264,6 +268,8 @@ class DistributeTranspiler(object): attrs={ "epmap": eplist, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: + [self.grad_name_to_param_name[grad_varname], grad_varname], "sync_mode": not self.sync_mode, }) for _, var in enumerate(splited_vars): @@ -305,6 +311,10 @@ class DistributeTranspiler(object): attrs={ "epmap": eps, RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: [ + param_varname, + self.param_name_to_grad_name[param_varname] + ], "sync_mode": not self.sync_mode }) @@ -934,7 +944,11 @@ class DistributeTranspiler(object): attrs={ "sync_mode": True, "epmap": pserver_endpoints, - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE, + OP_ROLE_VAR_ATTR_NAME: [ + self.grad_name_to_param_name[table_grad_name], + table_grad_name + ] }) break -- GitLab