提交 34b401fc 编写于 作者: X Xin Pan

clean up a global graph attr.

上级 8ac2242b
...@@ -90,5 +90,4 @@ REGISTER_PASS(multi_devices_check_pass, ...@@ -90,5 +90,4 @@ REGISTER_PASS(multi_devices_check_pass,
paddle::framework::details::SSAGraghBuilderWithChecker) paddle::framework::details::SSAGraghBuilderWithChecker)
.RequireGraphAttr(paddle::framework::details::kGraphVars) .RequireGraphAttr(paddle::framework::details::kGraphVars)
.RequireGraphAttr(paddle::framework::details::kGraphDepVars) .RequireGraphAttr(paddle::framework::details::kGraphDepVars)
.RequireGraphAttr(paddle::framework::details::kGraphOps) .RequireGraphAttr(paddle::framework::details::kGraphOps);
.RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
namespace { namespace {
void PolishGraphToSupportDataHazards(ir::Graph *graph) { void PolishGraphToSupportDataHazards(ir::Graph *graph) {
for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) { for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
...@@ -303,7 +304,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -303,7 +304,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
result.Set(kGraphVars, new GraphVars(places_.size())); result.Set(kGraphVars, new GraphVars(places_.size()));
result.Set(kGraphDepVars, new GraphDepVars); result.Set(kGraphDepVars, new GraphDepVars);
result.Set(kGraphOps, new GraphOps); result.Set(kGraphOps, new GraphOps);
result.Set(kShardedVarDevice, new ShardedVarDevice);
// find send/recv vars so that we can place the distributed training // find send/recv vars so that we can place the distributed training
// related op in the place 0 // related op in the place 0
...@@ -317,11 +317,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -317,11 +317,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
bool is_forwarding = true; bool is_forwarding = true;
bool is_dist_train = false; bool is_dist_train = false;
std::unordered_map<std::string, int> sharded_var_device;
for (ir::Node *node : sorted_ops) { for (ir::Node *node : sorted_ops) {
if (boost::get<int>( if (boost::get<int>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) { static_cast<int>(OpRole::kRPC)) {
int op_dev_id = CreateRPCOp(&result, node); int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
PADDLE_ENFORCE(op_dev_id != -1, PADDLE_ENFORCE(op_dev_id != -1,
"Can not schedule the RPC operator to the right place."); "Can not schedule the RPC operator to the right place.");
if (node->Op()->Type() == "recv") { if (node->Op()->Type() == "recv") {
...@@ -337,7 +339,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -337,7 +339,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
} else if (boost::get<int>(node->Op()->GetAttr( } else if (boost::get<int>(node->Op()->GetAttr(
OpProtoAndCheckerMaker::OpRoleAttrName())) == OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kDist)) { static_cast<int>(OpRole::kDist)) {
int op_dev_id = CreateDistTrainOp(&result, node); int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
if (node->Op()->Type() == "concat") { if (node->Op()->Type() == "concat") {
auto origin_param_name = node->Op()->OutputArgumentNames()[0]; auto origin_param_name = node->Op()->OutputArgumentNames()[0];
bcast_var_name_set[op_dev_id].emplace(origin_param_name); bcast_var_name_set[op_dev_id].emplace(origin_param_name);
...@@ -356,12 +358,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -356,12 +358,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
// the block. // the block.
is_forwarding = false; is_forwarding = false;
} else { } else {
int op_dev_id = GetOpDeviceID(result, node); int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
if (op_dev_id != -1) { // This op only runs on one specific device. if (op_dev_id != -1) { // This op only runs on one specific device.
CreateComputationalOp(&result, node, op_dev_id); CreateComputationalOp(&result, node, op_dev_id);
for (ir::Node *n : node->outputs) { for (ir::Node *n : node->outputs) {
graph->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device.emplace(n->Name(), op_dev_id);
.emplace(n->Name(), op_dev_id);
} }
} else { } else {
// This op runs on all devices, and its output may have parameter's // This op runs on all devices, and its output may have parameter's
...@@ -398,8 +399,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl( ...@@ -398,8 +399,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
case BuildStrategy::ReduceStrategy::kReduce: case BuildStrategy::ReduceStrategy::kReduce:
cur_device_id = GetAppropriateDeviceID({g_name}); cur_device_id = GetAppropriateDeviceID({g_name});
CreateReduceOp(&result, g_name, cur_device_id); CreateReduceOp(&result, g_name, cur_device_id);
graph->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device.emplace(g_name, cur_device_id);
.emplace(g_name, cur_device_id);
if (!is_dist_train) { if (!is_dist_train) {
bcast_var_name_set[cur_device_id].emplace(p_name); bcast_var_name_set[cur_device_id].emplace(p_name);
} }
...@@ -617,8 +617,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp( ...@@ -617,8 +617,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
} }
} }
int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetOpDeviceID(
ir::Node *node) const { const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const {
if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) { if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
return -1; return -1;
} }
...@@ -631,15 +632,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph, ...@@ -631,15 +632,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
PADDLE_ENFORCE_EQ(param_grad.size(), 2U); PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
int dev_id = GetVarDeviceID(graph, param_grad[1]); int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]", PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
node->Op()->Type(), param_grad[0], param_grad[1]); node->Op()->Type(), param_grad[0], param_grad[1]);
return dev_id; return dev_id;
} }
int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph, int MultiDevSSAGraphBuilder::GetVarDeviceID(
const std::string &varname) const { const ir::Graph &graph, const std::string &varname,
auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice); const std::unordered_map<std::string, int> &sharded_var_device) const {
auto got = sharded_var_device.find(varname); auto got = sharded_var_device.find(varname);
return got == sharded_var_device.end() ? -1 : got->second; return got == sharded_var_device.end() ? -1 : got->second;
} }
...@@ -709,8 +710,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result, ...@@ -709,8 +710,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
return var; return var;
} }
int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateDistTrainOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
std::vector<std::string> input_var_names; std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -725,23 +727,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result, ...@@ -725,23 +727,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
node->Op()->Type() == "split_selected_rows" || node->Op()->Type() == "split_selected_rows" ||
node->Op()->Type() == "split_ids") { node->Op()->Type() == "split_ids") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) { if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
op_dev_id = GetAppropriateDeviceID(input_var_names); op_dev_id = GetAppropriateDeviceID(input_var_names);
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else if (node->Op()->Type() == "concat") { } else if (node->Op()->Type() == "concat") {
op_dev_id = GetVarDeviceID(*result, input_var_names[0]); op_dev_id =
GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type(); LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
...@@ -774,12 +775,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) { ...@@ -774,12 +775,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
} }
// Create RPC related op handles that connects its in ops and out ops. // Create RPC related op handles that connects its in ops and out ops.
int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, int MultiDevSSAGraphBuilder::CreateRPCOp(
ir::Node *node) const { ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const {
int op_dev_id = -1; int op_dev_id = -1;
if (node->Op()->Type() == "send") { if (node->Op()->Type() == "send") {
// TODO(paddle-dev): getting the first var is not safe. // TODO(paddle-dev): getting the first var is not safe.
op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name()); op_dev_id =
GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]), PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
"This hack no longer holds, please fix."); "This hack no longer holds, please fix.");
// the variable name which contains .block means it was splited by // the variable name which contains .block means it was splited by
...@@ -797,11 +800,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -797,11 +800,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
VLOG(10) << "send grad " << input_var_names[0] << " origin " VLOG(10) << "send grad " << input_var_names[0] << " origin "
<< send_param_grad[1] << " place: " << op_dev_id; << send_param_grad[1] << " place: " << op_dev_id;
for (auto &varname : input_var_names) { for (auto &varname : input_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(send_param_grad[1], op_dev_id);
.emplace(send_param_grad[1], op_dev_id);
} }
} else if (node->Op()->Type() == "recv") { } else if (node->Op()->Type() == "recv") {
std::vector<std::string> output_var_names; std::vector<std::string> output_var_names;
...@@ -811,7 +812,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -811,7 +812,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
auto recv_param_grad = boost::get<std::vector<std::string>>( auto recv_param_grad = boost::get<std::vector<std::string>>(
node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName())); node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
if (recv_param_grad.size() == 2U) { if (recv_param_grad.size() == 2U) {
op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]); op_dev_id =
GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
VLOG(10) << "recv param " << recv_param_grad[0] VLOG(10) << "recv param " << recv_param_grad[0]
<< " get grad place: " << recv_param_grad[1] << " get grad place: " << recv_param_grad[1]
<< " place: " << op_dev_id; << " place: " << op_dev_id;
...@@ -819,8 +821,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -819,8 +821,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
op_dev_id = GetAppropriateDeviceID(output_var_names); op_dev_id = GetAppropriateDeviceID(output_var_names);
} }
for (auto &varname : output_var_names) { for (auto &varname : output_var_names) {
result->Get<ShardedVarDevice>(kShardedVarDevice) sharded_var_device->emplace(varname, op_dev_id);
.emplace(varname, op_dev_id);
} }
} else { } else {
// send_barrier, fetch_barrier will run on place 0; // send_barrier, fetch_barrier will run on place 0;
...@@ -847,7 +848,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ...@@ -847,7 +848,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
for (ir::Node *output : node->outputs) { for (ir::Node *output : node->outputs) {
int outvar_dev_id = op_dev_id; int outvar_dev_id = op_dev_id;
if (node->Op()->Type() == "fetch_barrier") { if (node->Op()->Type() == "fetch_barrier") {
outvar_dev_id = GetVarDeviceID(*result, output->Name()); outvar_dev_id =
GetVarDeviceID(*result, output->Name(), *sharded_var_device);
PADDLE_ENFORCE_NE(outvar_dev_id, -1); PADDLE_ENFORCE_NE(outvar_dev_id, -1);
} }
p = places_[outvar_dev_id]; p = places_[outvar_dev_id];
......
...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
mutable platform::NCCLContextMap *nccl_ctxs_; mutable platform::NCCLContextMap *nccl_ctxs_;
#endif #endif
int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const; int GetVarDeviceID(
const ir::Graph &graph, const std::string &varname,
const std::unordered_map<std::string, int> &sharded_var_device) const;
bool IsScaleLossOp(ir::Node *node) const; bool IsScaleLossOp(ir::Node *node) const;
int CreateRPCOp(ir::Graph *result, ir::Node *node) const; int CreateRPCOp(
int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const; ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
int CreateDistTrainOp(
ir::Graph *result, ir::Node *node,
std::unordered_map<std::string, int> *sharded_var_device) const;
std::vector<std::string> FindDistTrainSendVars( std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) const; const std::vector<ir::Node *> &nodes) const;
...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass { ...@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
void CreateComputationalOp(ir::Graph *result, ir::Node *node, void CreateComputationalOp(ir::Graph *result, ir::Node *node,
int dev_id) const; int dev_id) const;
int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const; int GetOpDeviceID(
const ir::Graph &graph, ir::Node *node,
const std::unordered_map<std::string, int> &sharded_var_device) const;
void InsertAllReduceOp(ir::Graph *result, const std::string &og) const; void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
......
...@@ -49,9 +49,6 @@ const char kGraphDepVars[] = "dep_vars"; ...@@ -49,9 +49,6 @@ const char kGraphDepVars[] = "dep_vars";
// unordered. // unordered.
typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps; typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
const char kGraphOps[] = "ops"; const char kGraphOps[] = "ops";
typedef std::unordered_map<std::string, int> ShardedVarDevice;
const char kShardedVarDevice[] = "sharded_var_device";
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册