// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include "paddle/fluid/framework/details/all_reduce_op_handle.h" #include "paddle/fluid/framework/details/broadcast_op_handle.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/data_balance_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/reduce_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/scope.h" namespace paddle { namespace framework { namespace details { static const char kLossVarName[] = "loss_var_name"; static const char kPlaces[] = "places"; static const char kParams[] = "params"; static const char kLocalScopes[] = "local_scopes"; static const char kStrategy[] = "strategy"; void MultiDevSSAGraphBuilder::Init() const { loss_var_name_ = Get(kLossVarName); places_ = Get>(kPlaces); local_scopes_ = Get>(kLocalScopes); strategy_ = Get(kStrategy); #ifdef PADDLE_WITH_CUDA nccl_ctxs_ = &Get("nccl_ctxs"); #endif for (auto &p : Get>(kParams)) { grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); if (strategy_.enable_data_balance_ && places_.size() == 1) { LOG(WARNING) << "It is no need to enable data balance when there is only " "one place. enable_data_balance is set to False."; strategy_.enable_data_balance_ = false; } } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, ir::Node *node, size_t place_id) const { auto p = places_[place_id]; auto *op_handle = result->Get(kGraphOps).back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); for (ir::Node *input : node->inputs) { VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); op_handle->AddInput(var); } for (ir::Node *output : node->outputs) { ir::Node *new_node = nullptr; if (output->Var()) { new_node = result->CreateVarNode(output->Var()); } else { new_node = result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); } CreateOpOutput(result, op_handle, new_node, p, place_id); } } std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( const std::vector &nodes) const { std::vector send_vars; // since parameters are all in block 0, // it's enough to only scan send ops in block 0 for (auto &node : nodes) { OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find send op, // instead of the the hard code string if (op->Type() == "send") { auto op_vars = op->InputArgumentNames(); send_vars.reserve(send_vars.size() + std::distance(op_vars.begin(), op_vars.end())); send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); } } return send_vars; } std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( const std::vector &nodes) const { std::vector recv_vars; for (auto &node : nodes) { OpDesc *op = node->Op(); // TODO(Yancey1989): use a graceful method to find recv op, // instead of the hard code string if (op->Type() == "recv") { auto op_vars = op->OutputArgumentNames(); recv_vars.reserve(recv_vars.size() + std::distance(op_vars.begin(), op_vars.end())); recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); } } return recv_vars; } bool MultiDevSSAGraphBuilder::IsDistTrainOp( ir::Node *node, const std::vector &send_vars, const std::vector &recv_vars) const { if (send_vars.size() == 0 || recv_vars.size() == 0) { return false; } /** * Check any of opvars contains `.block` and in sendvars */ auto checker = [](const std::vector &opvars, const std::vector &rpc_vars) -> bool { for (auto &var : opvars) { // a variable name with the suffix `.block` means it's a splited // variable by (DistributeTranspiler) // [python/paddle/fluid/transpiler/distribute_transpiler.py] if (var.find(".block") != std::string::npos && std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { return true; } } return false; }; std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { output_var_names.push_back(output->Name()); } return checker(output_var_names, send_vars) || checker(input_var_names, recv_vars); } size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( const std::vector &var_names) const { int64_t numel_sum = 0; for (auto var_name : var_names) { if (all_vars_.find(var_name) == all_vars_.end()) continue; auto var_desc = all_vars_.at(var_name); PADDLE_ENFORCE_NOT_NULL(var_desc); auto dim = framework::make_ddim(var_desc->GetShape()); int64_t numel = framework::product(dim); PADDLE_ENFORCE_GT(numel, 0); numel_sum += numel; } auto smallest = std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); size_t dev_id = static_cast(std::distance(std::begin(balance_vars_), smallest)); balance_vars_[dev_id] += numel_sum; return dev_id; } // Topology sort the graph nodes from inputs to outputs. // Since SSAGraphBuilder depends on forward/backward nodes to assign devices // to parameter/gradients before optimizer ops, topo sort is insufficient. ( // some optimizer ops might not depend on any nodes), we manually move all // optimizer nodes after last backward nodes. // However, the assumption by SSAGraphBuilder should be relaxed in the future. std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { std::vector ret = ir::TopologySortOperations(graph); size_t last_backward = 0; for (size_t i = 0; i < ret.size(); ++i) { if (boost::get( ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kBackward)) { last_backward = i; } } std::vector optimize_ops; std::vector sorted_ret; for (size_t i = 0; i < ret.size(); ++i) { if (i < last_backward) { if (boost::get(ret[i]->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kOptimize)) { optimize_ops.push_back(ret[i]); } else { sorted_ret.push_back(ret[i]); } } else if (i == last_backward) { sorted_ret.push_back(ret[i]); // Verify that no operations before optimize ops depends on optimize ops. std::unordered_set optimize_set(optimize_ops.begin(), optimize_ops.end()); for (ir::Node *n : sorted_ret) { for (ir::Node *in : n->inputs) { for (ir::Node *pre_n : in->inputs) { PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), "optimize operations cannot be depended by forward " "or backward node %s -> %s", pre_n->Name(), n->Name()); } } } sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(), optimize_ops.end()); } else { sorted_ret.push_back(ret[i]); } } return sorted_ret; } std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr graph) const { Init(); // Give the topology sort order and rebuild the graph structure. std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; for (auto &node : nodes) { if (node->NodeType() == ir::Node::Type::kVariable && node->Var()) { all_vars_.emplace(node->Name(), node->Var()); } } std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 result.Set(kGraphVars, new GraphVars(places_.size())); result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphOps, new GraphOps); result.Set(kShardedVarDevice, new ShardedVarDevice); // find send/recv vars so that we can place the distributed training // related op in the place 0 auto send_vars = FindDistTrainSendVars(sorted_ops); auto recv_vars = FindDistTrainRecvVars(sorted_ops); std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); size_t cur_device_id = 0; bool is_forwarding = true; for (ir::Node *node : sorted_ops) { if (boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == static_cast(OpRole::kRPC)) { CreateRPCOp(&result, node); } else if (IsDistTrainOp(node, send_vars, recv_vars)) { CreateDistTrainOp(&result, node); } else if (IsScaleLossOp(node)) { // user can customize loss@grad if not use_default_grad_scale_ if (strategy_.gradient_scale_ != BuildStrategy::GradientScaleStrategy::kCustomized) { // TODO(paddle-dev): Why is there no input for this op_handle? auto loss_grad_name = node->Op()->OutputArgumentNames()[0]; CreateScaleLossGradOp(&result, loss_grad_name); } // This assumes the backward generating code will ensure IsScaleLossOp // is true only for the op that scale the final scalar loss. // It also assumes backward op will always follow the forward op in // the block. is_forwarding = false; } else { int op_dev_id = GetOpDeviceID(result, node); if (op_dev_id != -1) { // This op only runs on one specific device. CreateComputationalOp(&result, node, op_dev_id); for (ir::Node *n : node->outputs) { graph->Get(kShardedVarDevice) .emplace(n->Name(), op_dev_id); } } else { // This op runs on all devices, and its output may have parameter's // gradients. // TODO(paddle-dev): Why is so special about "read" op? if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) { node->Op()->SetAttr("throw_eof_exp", false); CreateComputationalOps(&result, node, places_.size()); const auto &data_var_names = node->Op()->Output("Out"); InsertDataBalanceOp(&result, data_var_names); } else { CreateComputationalOps(&result, node, places_.size()); } if (!is_forwarding && places_.size() > 1) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. if (static_cast(boost::get(node->Op()->GetAttr( OpProtoAndCheckerMaker::OpRoleAttrName())) & static_cast(OpRole::kBackward))) { try { auto backward_vars = boost::get>( node->Op()->GetNullableAttr( OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); for (size_t i = 0; i < backward_vars.size(); i += 2) { auto &p_name = backward_vars[i]; auto &g_name = backward_vars[i + 1]; VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; switch (strategy_.reduce_) { case BuildStrategy::ReduceStrategy::kReduce: cur_device_id = GetAppropriateDeviceID({g_name}); CreateReduceOp(&result, g_name, cur_device_id); graph->Get(kShardedVarDevice) .emplace(g_name, cur_device_id); bcast_var_name_set[cur_device_id].emplace(p_name); break; case BuildStrategy::ReduceStrategy::kAllReduce: if (IsSparseGradient(g_name)) { CreateReduceOp(&result, g_name, 0); CreateBroadcastOp(&result, g_name, 0); } else { InsertAllReduceOp(&result, g_name); } break; default: LOG(FATAL) << "Unknown reduce strategy "; break; } } } catch (boost::bad_get e) { } } } } } } bool use_gpu = false; #ifdef PADDLE_WITH_CUDA use_gpu = nccl_ctxs_ != nullptr; #endif if (use_gpu || strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { // Insert BCast Ops for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) { auto &to_bcast_set = bcast_var_name_set[dev_id]; for (auto &bcast_name : to_bcast_set) { CreateBroadcastOp(&result, bcast_name, dev_id); } } } /* Dependency graph has been constructed. However, there are still data hazards need to be handled. */ PolishGraphToSupportDataHazards(&result); /* * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); PADDLE_ENFORCE(!ir::HasCircle(result)); return graph; } bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { PADDLE_ENFORCE(all_vars_.count(og) != 0); if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { return true; } return false; } void MultiDevSSAGraphBuilder::SetCommunicationContext( OpHandleBase *op_handle, const platform::Place &p) const { #ifdef PADDLE_WITH_CUDA if (nccl_ctxs_ == nullptr) { op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); } #else op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); #endif } void MultiDevSSAGraphBuilder::CreateBroadcastOp(ir::Graph *result, const std::string &p_name, size_t src_dev_id) const { #ifdef PADDLE_WITH_CUDA auto *op_handle = new BroadcastOpHandle( result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_); #else auto *op_handle = new BroadcastOpHandle( result->CreateEmptyNode("broadcast", ir::Node::Type::kOperation), local_scopes_, places_); #endif result->Get(kGraphOps).emplace_back(op_handle); auto *in = result->Get(kGraphVars).at(src_dev_id).at(p_name).back().get(); op_handle->AddInput(in); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get(kGraphVars).at(i).at(p_name); auto *out_var = new VarHandle( result->CreateEmptyNode(p_name, ir::Node::Type::kVariable), vars.size(), i, p_name, p); vars.emplace_back(out_var); op_handle->AddOutput(out_var); } } void MultiDevSSAGraphBuilder::CreateComputationalOp(ir::Graph *result, ir::Node *node, int dev_id) const { result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), local_scopes_[dev_id], places_[dev_id])); CreateOpHandleIOs(result, node, dev_id); } void MultiDevSSAGraphBuilder::InsertAllReduceOp(ir::Graph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else result->Get(kGraphOps).emplace_back(new AllReduceOpHandle( result->CreateEmptyNode("allreduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(), i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } } void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ir::Graph *result, const std::vector &datas) const { #ifdef PADDLE_WITH_CUDA result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else result->Get(kGraphOps).emplace_back(new DataBalanceOpHandle( result->CreateEmptyNode("data_balance", ir::Node::Type::kOperation), local_scopes_, places_)); #endif auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); for (const std::string &d_name : datas) { auto &vars = result->Get(kGraphVars)[i][d_name]; PADDLE_ENFORCE(!vars.empty()); op_handle->AddInput(vars.back().get()); auto var = new VarHandle( result->CreateEmptyNode(d_name, ir::Node::Type::kVariable), vars.size(), i, d_name, p); vars.emplace_back(var); op_handle->AddOutput(var); } } } bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const { bool is_pg_once = grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0; if (is_pg_once) { // Insert NCCL AllReduce Op og_has_been_broadcast->insert(og); } return is_pg_once; } int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } int op_role = boost::get( node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); if (op_role != static_cast(framework::OpRole::kOptimize)) { return -1; } auto param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); int dev_id = GetVarDeviceID(graph, param_grad[1]); PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", node->Op()->Type(), param_grad[0], param_grad[1]); return dev_id; } int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const { auto &sharded_var_device = graph.Get(kShardedVarDevice); auto got = sharded_var_device.find(varname); return got == sharded_var_device.end() ? -1 : got->second; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp( ir::Graph *result, const std::string &loss_grad_name) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA auto *communication_dev_ctx = nccl_ctxs_ ? nccl_ctxs_->DevCtx(places_[i]) : platform::DeviceContextPool::Instance().Get(places_[i]); #else auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle( result->CreateEmptyNode("scale_loss_grad", ir::Node::Type::kOperation), local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); result->Get(kGraphOps).emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. // VarHandle *loss = GetVarHandle(loss_var_name, place); // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); CreateOpOutput( result, op_handle, result->CreateEmptyNode(loss_grad_name, ir::Node::Type::kVariable), places_[i], i); } } void MultiDevSSAGraphBuilder::CreateComputationalOps(ir::Graph *result, ir::Node *node, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; result->Get(kGraphOps).emplace_back( new ComputationOpHandle(result->CreateOpNode(node->Op()), s, p)); CreateOpHandleIOs(result, node, scope_idx); } } VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, const std::string &og, int dst_dev_id) const { #ifdef PADDLE_WITH_CUDA result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_, nccl_ctxs_)); #else result->Get(kGraphOps).emplace_back(new ReduceOpHandle( result->CreateEmptyNode("reduce", ir::Node::Type::kOperation), local_scopes_, places_)); #endif auto *op_handle = result->Get(kGraphOps).back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; SetCommunicationContext(op_handle, p); auto &vars = result->Get(kGraphVars)[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); } auto &vars = result->Get(kGraphVars)[dst_dev_id][og]; auto var = new VarHandle(result->CreateEmptyNode(og, ir::Node::Type::kVariable), vars.size(), dst_dev_id, og, places_[dst_dev_id]); vars.emplace_back(var); op_handle->AddOutput(var); return var; } // Find the first occurence of `prev_op_name` and make current `op` depend // on it. void MultiDevSSAGraphBuilder::ConnectOp(ir::Graph *result, OpHandleBase *op, const std::string &prev_op_name) const { for (auto &prev_op : result->Get(kGraphOps)) { if (prev_op->Name() == prev_op_name) { auto *dep_var = new DummyVarHandle(result->CreateControlDepVar()); prev_op->AddOutput(dep_var); result->Get(kGraphDepVars).emplace(dep_var); op->AddInput(dep_var); } } } void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ir::Node *node) const { int op_dev_id = -1; std::vector input_var_names; std::vector output_var_names; for (ir::Node *input : node->inputs) { input_var_names.push_back(input->Name()); } for (ir::Node *output : node->outputs) { output_var_names.push_back(output->Name()); } if (node->Op()->Type() == "split_byref" || node->Op()->Type() == "split_selected_rows") { // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(*result, input_var_names[0]); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } } for (auto &varname : output_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } } else if (node->Op()->Type() == "concat") { op_dev_id = GetVarDeviceID(*result, input_var_names[0]); for (auto &varname : output_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } } else { PADDLE_ENFORCE( "the distribute training related op should be in [split_byref, " "concat]."); } PADDLE_ENFORCE(op_dev_id != -1, "can not find right place for distributed op: %s", node->Op()->Type()); CreateComputationalOp(result, node, op_dev_id); if (node->Op()->Type() == "concat") { ConnectOp(result, result->Get(kGraphOps).back().get(), "fetch_barrier"); } } // Create RPC related op handles that connects its in ops and out ops. void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by // split_byref op // so that we can balance the variable blocks to all the pserver // instances. if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce && node->inputs[0]->Name().find(".block") == std::string::npos) { std::vector input_var_names; for (ir::Node *n : node->inputs) { input_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } } } else if (node->Op()->Type() == "recv") { std::vector output_var_names; for (ir::Node *n : node->outputs) { output_var_names.push_back(n->Name()); } op_dev_id = GetAppropriateDeviceID(output_var_names); for (auto &varname : output_var_names) { result->Get(kShardedVarDevice) .emplace(varname, op_dev_id); } } else { // send_barrier and fetch_barrier op can be scheduled on device 0 op_dev_id = 0; } PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s", node->Op()->Type()); result->Get(kGraphOps).emplace_back(new RPCOpHandle( result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id], node->Op()->Type(), places_[op_dev_id])); // TODO(panyx0718): This might not be needed anymore. if (node->Op()->Type() == "send_barrier") { ConnectOp(result, result->Get(kGraphOps).back().get(), "send"); } else if (node->Op()->Type() == "recv") { ConnectOp(result, result->Get(kGraphOps).back().get(), "send_barrier"); } else if (node->Op()->Type() == "fetch_barrier") { ConnectOp(result, result->Get(kGraphOps).back().get(), "recv"); } else if (node->Op()->Type() == "send") { // do nothing } else { PADDLE_THROW( "rpc op should be in [" "send, send_barrier. recv, fetch_barrier]"); } CreateOpHandleIOs(result, node, op_dev_id); } bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { return boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == (static_cast(OpRole::kBackward) | static_cast(OpRole::kLoss)) && !loss_var_name_.empty(); // If loss_var is empty. This is test mode } } // namespace details } // namespace framework } // namespace paddle REGISTER_PASS(multi_device_pass, paddle::framework::details::MultiDevSSAGraphBuilder) .RequirePassAttr(paddle::framework::details::kLossVarName) .RequirePassAttr(paddle::framework::details::kPlaces) .RequirePassAttr(paddle::framework::details::kParams) .RequirePassAttr(paddle::framework::details::kLocalScopes) .RequirePassAttr(paddle::framework::details::kStrategy);