diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc index 03f5f2e73a8d3e1cd1816f47a92ccfdb9bba4850..4b2fc826f7412cf1f1d844cabcaceff33e305f72 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc @@ -42,6 +42,12 @@ namespace { typedef std::vector GraphOps; const char kGraphOps[] = "ops"; +bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) { + return boost::get( + node.Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == + static_cast(role); +} + void PolishGraphToSupportDataHazards(ir::Graph *graph) { for (auto &var_map : graph->Get(kGraphVars)) { for (auto &name_pair : var_map) { @@ -150,6 +156,7 @@ void MultiDevSSAGraphBuilder::Init() const { grad_names_.insert(GradVarName(p)); } balance_vars_.resize(places_.size(), 0); + if (strategy_.enable_data_balance_ && places_.size() == 1) { LOG(WARNING) << "It is no need to enable data balance when there is only " "one place. enable_data_balance is set to False."; @@ -157,145 +164,16 @@ void MultiDevSSAGraphBuilder::Init() const { } } -void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, - ir::Node *node, - size_t place_id) const { - auto p = places_[place_id]; - auto *op_handle = result->Get(kGraphOps).back(); - op_handle->SetDeviceContext(p, - platform::DeviceContextPool::Instance().Get(p)); - - for (ir::Node *input : node->inputs) { - VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); - op_handle->AddInput(var); - } - - for (ir::Node *output : node->outputs) { - ir::Node *new_node = nullptr; - if (output->Var()) { - new_node = result->CreateVarNode(output->Var()); - } else { - new_node = - result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); - } - CreateOpOutput(result, op_handle, new_node, p, place_id); - } -} - -std::vector MultiDevSSAGraphBuilder::FindDistTrainSendVars( - const std::vector &nodes) const { - std::vector send_vars; - // since parameters are all in block 0, - // it's enough to only scan send ops in block 0 - for (auto &node : nodes) { - OpDesc *op = node->Op(); - // TODO(Yancey1989): use a graceful method to find send op, - // instead of the the hard code string - if (op->Type() == "send") { - auto op_vars = op->InputArgumentNames(); - send_vars.reserve(send_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end()); - } - } - return send_vars; -} - -std::vector MultiDevSSAGraphBuilder::FindDistTrainRecvVars( - const std::vector &nodes) const { - std::vector recv_vars; - for (auto &node : nodes) { - OpDesc *op = node->Op(); - // TODO(Yancey1989): use a graceful method to find recv op, - // instead of the hard code string - if (op->Type() == "recv") { - auto op_vars = op->OutputArgumentNames(); - recv_vars.reserve(recv_vars.size() + - std::distance(op_vars.begin(), op_vars.end())); - recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end()); - } - } - return recv_vars; -} - -size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( - const std::vector &var_names) const { - int64_t numel_sum = 0; - for (auto var_name : var_names) { - if (all_vars_.find(var_name) == all_vars_.end()) continue; - auto var_desc = all_vars_.at(var_name); - PADDLE_ENFORCE_NOT_NULL(var_desc); - auto dim = framework::make_ddim(var_desc->GetShape()); - int64_t numel = framework::product(dim); - PADDLE_ENFORCE_GT(numel, 0); - numel_sum += numel; - } - - auto smallest = - std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); - size_t dev_id = - static_cast(std::distance(std::begin(balance_vars_), smallest)); - balance_vars_[dev_id] += numel_sum; - return dev_id; -} - -// Topology sort the graph nodes from inputs to outputs. -// Since SSAGraphBuilder depends on forward/backward nodes to assign devices -// to parameter/gradients before optimizer ops, topo sort is insufficient. ( -// some optimizer ops might not depend on any nodes), we manually move all -// optimizer nodes after last backward nodes. -// However, the assumption by SSAGraphBuilder should be relaxed in the future. -std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { - std::vector ret = ir::TopologySortOperations(graph); - size_t last_backward = 0; - for (size_t i = 0; i < ret.size(); ++i) { - if (boost::get( - ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kBackward)) { - last_backward = i; - } - } - - std::vector optimize_ops; - std::vector sorted_ret; - for (size_t i = 0; i < ret.size(); ++i) { - if (i < last_backward) { - if (static_cast(boost::get(ret[i]->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) & - static_cast(OpRole::kOptimize))) { - optimize_ops.push_back(ret[i]); - } else { - sorted_ret.push_back(ret[i]); - } - } else if (i == last_backward) { - sorted_ret.push_back(ret[i]); - // Verify that no operations before optimize ops depends on optimize ops. - std::unordered_set optimize_set(optimize_ops.begin(), - optimize_ops.end()); - for (ir::Node *n : sorted_ret) { - for (ir::Node *in : n->inputs) { - for (ir::Node *pre_n : in->inputs) { - PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(), - "optimize operations cannot be depended by forward " - "or backward node %s -> %s", - pre_n->Name(), n->Name()); - } - } - } - sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(), - optimize_ops.end()); - } else { - sorted_ret.push_back(ret[i]); - } - } - return sorted_ret; -} - std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( std::unique_ptr graph) const { Init(); // Give the topology sort order and rebuild the graph structure. - std::vector sorted_ops = SortOpsAndDelayOptimizeOp(*graph); + std::vector sorted_ops = ir::TopologySortOperations(*graph); + + if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) { + sorted_ops = SortForReduceMode(sorted_ops); + } + auto nodes = graph->ReleaseNodes(); ir::Graph &result = *graph; @@ -304,31 +182,22 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( all_vars_.emplace(node->Name(), node->Var()); } } - std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 result.Set(kGraphVars, new GraphVars(places_.size())); result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphOps, new GraphOps); - // find send/recv vars so that we can place the distributed training - // related op in the place 0 - auto send_vars = FindDistTrainSendVars(sorted_ops); - auto recv_vars = FindDistTrainRecvVars(sorted_ops); - std::vector> bcast_var_name_set; bcast_var_name_set.resize(places_.size()); - size_t cur_device_id = 0; bool is_forwarding = true; bool is_dist_train = false; std::unordered_map sharded_var_device; for (ir::Node *node : sorted_ops) { - if (boost::get( - node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kRPC)) { + if (OpHaveRole(*node, OpRole::kRPC)) { int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device); PADDLE_ENFORCE(op_dev_id != -1, "Can not schedule the RPC operator to the right place."); @@ -342,9 +211,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( } } is_dist_train = true; - } else if (boost::get(node->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kDist)) { + } else if (OpHaveRole(*node, OpRole::kDist)) { int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device); if (node->Op()->Type() == "concat") { auto origin_param_name = node->Op()->OutputArgumentNames()[0]; @@ -364,7 +231,7 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( // the block. is_forwarding = false; } else { - int op_dev_id = GetOpDeviceID(result, node, sharded_var_device); + int op_dev_id = GetOpDeviceID(node, sharded_var_device); if (op_dev_id != -1) { // This op only runs on one specific device. CreateComputationalOp(&result, node, op_dev_id); for (ir::Node *n : node->outputs) { @@ -384,47 +251,48 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( } if (!is_forwarding && places_.size() > 1) { + bool is_bk_op = + static_cast(boost::get(node->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward)); + if (!is_bk_op) continue; // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. - if (static_cast(boost::get(node->Op()->GetAttr( - OpProtoAndCheckerMaker::OpRoleAttrName())) & - static_cast(OpRole::kBackward))) { - try { - auto backward_vars = boost::get>( - node->Op()->GetNullableAttr( - OpProtoAndCheckerMaker::OpRoleVarAttrName())); - - PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); - - for (size_t i = 0; i < backward_vars.size(); i += 2) { - auto &p_name = backward_vars[i]; - auto &g_name = backward_vars[i + 1]; - VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; - - switch (strategy_.reduce_) { - case BuildStrategy::ReduceStrategy::kReduce: - cur_device_id = GetAppropriateDeviceID({g_name}); - CreateReduceOp(&result, g_name, cur_device_id); - sharded_var_device.emplace(g_name, cur_device_id); - if (!is_dist_train) { - bcast_var_name_set[cur_device_id].emplace(p_name); - } - break; - case BuildStrategy::ReduceStrategy::kAllReduce: - if (IsSparseGradient(g_name)) { - CreateReduceOp(&result, g_name, 0); - CreateBroadcastOp(&result, g_name, 0); - } else { - InsertAllReduceOp(&result, g_name); - } - break; - default: - LOG(FATAL) << "Unknown reduce strategy "; - break; - } + try { + auto backward_vars = boost::get>( + node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + for (size_t i = 0; i < backward_vars.size(); i += 2) { + auto &p_name = backward_vars[i]; + auto &g_name = backward_vars[i + 1]; + VLOG(10) << "Bcast " << g_name << " for parameter " << p_name; + size_t cur_device_id = -1; + switch (strategy_.reduce_) { + case BuildStrategy::ReduceStrategy::kReduce: + cur_device_id = GetAppropriateDeviceID({g_name}); + CreateReduceOp(&result, g_name, cur_device_id); + sharded_var_device.emplace(g_name, cur_device_id); + if (!is_dist_train) { + bcast_var_name_set[cur_device_id].emplace(p_name); + } + break; + case BuildStrategy::ReduceStrategy::kAllReduce: + if (IsSparseGradient(g_name)) { + CreateReduceOp(&result, g_name, 0); + CreateBroadcastOp(&result, g_name, 0); + } else { + InsertAllReduceOp(&result, g_name); + } + break; + default: + LOG(FATAL) << "Unknown reduce strategy "; + break; } - } catch (boost::bad_get e) { } + } catch (boost::bad_get e) { } } } @@ -468,12 +336,108 @@ std::unique_ptr MultiDevSSAGraphBuilder::ApplyImpl( return graph; } -bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { - PADDLE_ENFORCE(all_vars_.count(og) != 0); - if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { - return true; +std::vector MultiDevSSAGraphBuilder::SortForReduceMode( + const std::vector &topo_ops) const { + std::unordered_map sharded_var_device; + std::vector sorted_ops; + std::unordered_map> delayed_op; + sorted_ops.reserve(topo_ops.size()); + + auto insert_delayed_op = [&](const std::string &var_name, int dev_id) { + sharded_var_device.emplace(var_name, dev_id); + if (delayed_op.count(var_name)) { + auto &ops = delayed_op.at(var_name); + sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end()); + delayed_op.at(var_name).clear(); + } + }; + + for (ir::Node *node : topo_ops) { + int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op); + if (op_dev_id > -1) { + // This op only runs on one specific device. + sorted_ops.emplace_back(node); + for (ir::Node *n : node->outputs) { + insert_delayed_op(n->Name(), op_dev_id); + } + } else if (op_dev_id == -1) { + // This op runs on all devices, and its output may have parameter's + // gradients. + sorted_ops.emplace_back(node); + bool is_bk_op = + static_cast(boost::get(node->Op()->GetAttr( + OpProtoAndCheckerMaker::OpRoleAttrName())) & + static_cast(OpRole::kBackward)); + if (!is_bk_op) continue; + // Currently, we assume that once gradient is generated, it can be + // broadcast, and each gradient is only broadcast once. + std::vector backward_vars; + try { + backward_vars = + boost::get>(node->Op()->GetNullableAttr( + OpProtoAndCheckerMaker::OpRoleVarAttrName())); + } catch (boost::bad_get e) { + } + PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0); + + for (size_t i = 0; i < backward_vars.size(); i += 2) { + auto &g_name = backward_vars[i + 1]; + size_t cur_device_id = GetAppropriateDeviceID({g_name}); + insert_delayed_op(g_name, static_cast(cur_device_id)); + } + } else if (op_dev_id == -2) { + // The Op on which the Op depends has not yet been generated. + } } - return false; + + PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size()); + return sorted_ops; +} + +void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result, + ir::Node *node, + size_t place_id) const { + auto p = places_[place_id]; + auto *op_handle = result->Get(kGraphOps).back(); + op_handle->SetDeviceContext(p, + platform::DeviceContextPool::Instance().Get(p)); + + for (ir::Node *input : node->inputs) { + VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id); + op_handle->AddInput(var); + } + + for (ir::Node *output : node->outputs) { + ir::Node *new_node = nullptr; + if (output->Var()) { + new_node = result->CreateVarNode(output->Var()); + } else { + new_node = + result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable); + } + CreateOpOutput(result, op_handle, new_node, p, place_id); + } +} + +size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID( + const std::vector &var_names) const { + int64_t numel_sum = 0; + for (auto var_name : var_names) { + if (all_vars_.find(var_name) == all_vars_.end()) continue; + auto var_desc = all_vars_.at(var_name); + PADDLE_ENFORCE_NOT_NULL(var_desc); + auto dim = framework::make_ddim(var_desc->GetShape()); + int64_t numel = framework::product(dim); + PADDLE_ENFORCE_GT(numel, 0); + numel_sum += numel; + } + + auto smallest = + std::min_element(std::begin(balance_vars_), std::end(balance_vars_)); + size_t dev_id = + static_cast(std::distance(std::begin(balance_vars_), smallest)); + balance_vars_[dev_id] += numel_sum; + return dev_id; } void MultiDevSSAGraphBuilder::SetCommunicationContext( @@ -624,28 +588,52 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( } int MultiDevSSAGraphBuilder::GetOpDeviceID( - const ir::Graph &graph, ir::Node *node, + ir::Node *node, + const std::unordered_map &sharded_var_device, + std::unordered_map> *delay_ops) const { + if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { + return -1; + } + + if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { + return -1; + } + + auto param_grad = boost::get>( + node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); + + PADDLE_ENFORCE_EQ(param_grad.size(), 2U); + int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device); + + if (dev_id == -1) { + (*delay_ops)[param_grad[1]].push_back(node); + return -2; + } + return dev_id; +} + +int MultiDevSSAGraphBuilder::GetOpDeviceID( + ir::Node *node, const std::unordered_map &sharded_var_device) const { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { return -1; } - int op_role = boost::get( - node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName())); - if (op_role != static_cast(framework::OpRole::kOptimize)) { + + if (!OpHaveRole(*node, framework::OpRole::kOptimize)) { return -1; } auto param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); PADDLE_ENFORCE_EQ(param_grad.size(), 2U); - int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device); + int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device); PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", node->Op()->Type(), param_grad[0], param_grad[1]); return dev_id; } int MultiDevSSAGraphBuilder::GetVarDeviceID( - const ir::Graph &graph, const std::string &varname, + const std::string &varname, const std::unordered_map &sharded_var_device) const { auto got = sharded_var_device.find(varname); if (got == sharded_var_device.end()) { @@ -739,8 +727,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp( node->Op()->Type() == "split_selected_rows" || node->Op()->Type() == "split_ids") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = - GetVarDeviceID(*result, input_var_names[0], *sharded_var_device); + op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { op_dev_id = GetAppropriateDeviceID(input_var_names); for (auto &varname : input_var_names) { @@ -751,8 +738,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp( sharded_var_device->emplace(varname, op_dev_id); } } else if (node->Op()->Type() == "concat") { - op_dev_id = - GetVarDeviceID(*result, input_var_names[0], *sharded_var_device); + op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device); for (auto &varname : output_var_names) { sharded_var_device->emplace(varname, op_dev_id); } @@ -793,8 +779,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( int op_dev_id = -1; if (node->Op()->Type() == "send") { // TODO(paddle-dev): getting the first var is not safe. - op_dev_id = - GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device); + op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device); PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), "This hack no longer holds, please fix."); // the variable name which contains .block means it was splited by @@ -824,8 +809,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( auto recv_param_grad = boost::get>( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); if (recv_param_grad.size() == 2U) { - op_dev_id = - GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device); + op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device); VLOG(10) << "recv param " << recv_param_grad[0] << " get grad place: " << recv_param_grad[1] << " place: " << op_dev_id; @@ -860,8 +844,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( for (ir::Node *output : node->outputs) { int outvar_dev_id = op_dev_id; if (node->Op()->Type() == "fetch_barrier") { - outvar_dev_id = - GetVarDeviceID(*result, output->Name(), *sharded_var_device); + outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device); PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name()); } p = places_[outvar_dev_id]; @@ -878,6 +861,14 @@ int MultiDevSSAGraphBuilder::CreateRPCOp( return op_dev_id; } +bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const { + PADDLE_ENFORCE(all_vars_.count(og) != 0); + if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) { + return true; + } + return false; +} + bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const { return boost::get( node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h index 8e462aec7dc7ce45cad592b89de0b6edde8c9146..17a418d3fd3df24b89e268f3815a7841812ce9a7 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_pass.h +++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h @@ -45,7 +45,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass { #endif int GetVarDeviceID( - const ir::Graph &graph, const std::string &varname, + const std::string &varname, const std::unordered_map &sharded_var_device) const; bool IsScaleLossOp(ir::Node *node) const; @@ -57,12 +57,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ir::Graph *result, ir::Node *node, std::unordered_map *sharded_var_device) const; - std::vector FindDistTrainSendVars( - const std::vector &nodes) const; - - std::vector FindDistTrainRecvVars( - const std::vector &nodes) const; - void CreateComputationalOps(ir::Graph *result, ir::Node *node, size_t num_places) const; @@ -76,7 +70,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass { int dev_id) const; int GetOpDeviceID( - const ir::Graph &graph, ir::Node *node, + ir::Node *node, const std::unordered_map &sharded_var_device) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; @@ -99,6 +93,15 @@ class MultiDevSSAGraphBuilder : public ir::Pass { void SetCommunicationContext(OpHandleBase *op_handle, const platform::Place &p) const; + std::vector SortForReduceMode( + const std::vector &) const; + + int GetOpDeviceID( + ir::Node *node, + const std::unordered_map &shared_var_device, + std::unordered_map> *delay_ops) + const; + mutable std::string loss_var_name_; mutable std::vector places_; mutable std::vector local_scopes_; diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index fc91564bbaecf7b1725908fc1eb8b1e4d2e20d32..9323d079824d5a468cf8c911df4a4d2530eca4a9 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -23,67 +23,8 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { -namespace { - -void CheckProgram(const ProgramDesc &program) { -#define _INT(role) static_cast(role) - - std::map visit; - for (OpDesc *op : program.Block(0).AllOps()) { - // For backward compatibility, some program doesn't have role added. - if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue; - int role_id = - boost::get(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())); - visit[role_id] = true; - switch (role_id) { - case _INT(OpRole::kForward): - if (visit.find(_INT(OpRole::kBackward)) != visit.end()) { - LOG(ERROR) - << "Cannot add backward operator before forward operator %s." - << op->Type(); - } - break; - case _INT(OpRole::kBackward): - case _INT(OpRole::kBackward) | _INT(OpRole::kLoss): - PADDLE_ENFORCE( - visit.find(_INT(OpRole::kOptimize)) == visit.end(), - "Cannot add backward operator %s after optimize operator.", - op->Type()); - break; - case _INT(OpRole::kForward) | _INT(OpRole::kLoss): - PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) | - _INT(OpRole::kLoss)) == visit.end(), - "Cannot add backward|loss operator before " - "forward|loss operator %s.", - op->Type()); - PADDLE_ENFORCE( - visit.find(_INT(OpRole::kOptimize)) == visit.end(), - "Cannot add forward|loss operator %s after optimize operator.", - op->Type()); - break; - case _INT(OpRole::kOptimize): - case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched): - PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(), - "Optimize operators %s must follow backward operator.", - op->Type()); - break; - case _INT(OpRole::kLRSched): - case _INT(OpRole::kDist): - case _INT(OpRole::kRPC): - case _INT(OpRole::kNotSpecified): - break; - default: - LOG(FATAL) << "Unknown operator role. Don't add new role because " - "you don't know what you are doing."; - } - } - -#undef _INT -} -} // namespace Graph::Graph(const ProgramDesc &program) : program_(program) { - CheckProgram(program_); auto var_nodes = InitFromProgram(program_); ResolveHazard(var_nodes); } diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index b98408ee7726768a108772329b8dc95c2df3c891..c11055bb35c633c3a8cf720e07876ec0abcb50c9 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -215,6 +215,7 @@ void ParallelExecutor::BCastParamsToDevices( if (paddle::platform::is_gpu_place(main_tensor.place())) { #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) std::vector buffers; + buffers.reserve(member_->places_.size()); size_t numel = main_tensor.numel(); ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); for (size_t i = 0; i < member_->places_.size(); ++i) { @@ -248,9 +249,7 @@ void ParallelExecutor::BCastParamsToDevices( #endif } else { platform::CPUPlace cpu; - for (size_t i = 0; i < member_->places_.size(); ++i) { - if (i == 0) continue; - + for (size_t i = 1; i < member_->places_.size(); ++i) { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable();