Commit 34b401fc authored by Xin Pan

clean up a global graph attr.

Parent 8ac2242b
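The change replaces the graph-global `kShardedVarDevice` attribute with a function-local `std::unordered_map<std::string, int>` that `ApplyImpl` owns and threads explicitly through `CreateRPCOp`, `CreateDistTrainOp`, `GetOpDeviceID`, and `GetVarDeviceID`. Below is a minimal self-contained sketch of that pattern; it is an illustration under stated assumptions, not the actual Paddle sources: `Node`, `Builder`, `Apply`, and `PlaceOnDevice` are hypothetical stand-ins, and only the map type and the lookup logic are taken from the diff.

// Sketch of the refactor pattern: state that used to live as a named graph
// attribute becomes a local map, passed explicitly to readers and writers.
// "Node", "Builder", "Apply", "PlaceOnDevice" are hypothetical stand-ins.
#include <string>
#include <unordered_map>

struct Node { std::string name; };  // stand-in for ir::Node

class Builder {
 public:
  void Apply() {
    // Before: graph->Get<ShardedVarDevice>(kShardedVarDevice) held this map
    // as a graph attribute visible to every pass. After: it is a local that
    // dies when the pass finishes.
    std::unordered_map<std::string, int> sharded_var_device;

    Node grad{"w@GRAD"};
    PlaceOnDevice(&grad, 0, &sharded_var_device);             // writer: pointer
    int dev = GetVarDeviceID(grad.name, sharded_var_device);  // reader: const ref
    (void)dev;  // dev == 0 here
  }

 private:
  // Helpers that record a placement take the map by mutable pointer,
  // mirroring CreateRPCOp / CreateDistTrainOp in the diff.
  void PlaceOnDevice(Node *node, int dev_id,
                     std::unordered_map<std::string, int> *sharded_var_device) {
    sharded_var_device->emplace(node->name, dev_id);
  }

  // Helpers that only consult a placement take it by const reference,
  // mirroring GetVarDeviceID / GetOpDeviceID; -1 means "not sharded".
  int GetVarDeviceID(
      const std::string &varname,
      const std::unordered_map<std::string, int> &sharded_var_device) const {
    auto got = sharded_var_device.find(varname);
    return got == sharded_var_device.end() ? -1 : got->second;
  }
};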
@@ -90,5 +90,4 @@ REGISTER_PASS(multi_devices_check_pass,
               paddle::framework::details::SSAGraghBuilderWithChecker)
     .RequireGraphAttr(paddle::framework::details::kGraphVars)
     .RequireGraphAttr(paddle::framework::details::kGraphDepVars)
-    .RequireGraphAttr(paddle::framework::details::kGraphOps)
-    .RequireGraphAttr(paddle::framework::details::kShardedVarDevice);
+    .RequireGraphAttr(paddle::framework::details::kGraphOps);
@@ -34,6 +34,7 @@
 namespace paddle {
 namespace framework {
 namespace details {
+namespace {
 void PolishGraphToSupportDataHazards(ir::Graph *graph) {
   for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
@@ -303,7 +304,6 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   result.Set(kGraphVars, new GraphVars(places_.size()));
   result.Set(kGraphDepVars, new GraphDepVars);
   result.Set(kGraphOps, new GraphOps);
-  result.Set(kShardedVarDevice, new ShardedVarDevice);
   // find send/recv vars so that we can place the distributed training
   // related op in the place 0
@@ -317,11 +317,13 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
   bool is_forwarding = true;
   bool is_dist_train = false;
+  std::unordered_map<std::string, int> sharded_var_device;
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      int op_dev_id = CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
       PADDLE_ENFORCE(op_dev_id != -1,
                      "Can not schedule the RPC operator to the right place.");
       if (node->Op()->Type() == "recv") {
@@ -337,7 +339,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
     } else if (boost::get<int>(node->Op()->GetAttr(
                    OpProtoAndCheckerMaker::OpRoleAttrName())) ==
                static_cast<int>(OpRole::kDist)) {
-      int op_dev_id = CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
       if (node->Op()->Type() == "concat") {
         auto origin_param_name = node->Op()->OutputArgumentNames()[0];
         bcast_var_name_set[op_dev_id].emplace(origin_param_name);
@@ -356,12 +358,11 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       // the block.
       is_forwarding = false;
     } else {
-      int op_dev_id = GetOpDeviceID(result, node);
+      int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
         CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
-          graph->Get<ShardedVarDevice>(kShardedVarDevice)
-              .emplace(n->Name(), op_dev_id);
+          sharded_var_device.emplace(n->Name(), op_dev_id);
         }
       } else {
         // This op runs on all devices, and its output may have parameter's
@@ -398,8 +399,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
         case BuildStrategy::ReduceStrategy::kReduce:
           cur_device_id = GetAppropriateDeviceID({g_name});
           CreateReduceOp(&result, g_name, cur_device_id);
-          graph->Get<ShardedVarDevice>(kShardedVarDevice)
-              .emplace(g_name, cur_device_id);
+          sharded_var_device.emplace(g_name, cur_device_id);
           if (!is_dist_train) {
             bcast_var_name_set[cur_device_id].emplace(p_name);
           }
@@ -617,8 +617,9 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
   }
 }
 
-int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
-                                           ir::Node *node) const {
+int MultiDevSSAGraphBuilder::GetOpDeviceID(
+    const ir::Graph &graph, ir::Node *node,
+    const std::unordered_map<std::string, int> &sharded_var_device) const {
   if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
     return -1;
   }
@@ -631,15 +632,15 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(const ir::Graph &graph,
       node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
 
-  int dev_id = GetVarDeviceID(graph, param_grad[1]);
+  int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
   PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                     node->Op()->Type(), param_grad[0], param_grad[1]);
   return dev_id;
 }
 
-int MultiDevSSAGraphBuilder::GetVarDeviceID(const ir::Graph &graph,
-                                            const std::string &varname) const {
-  auto &sharded_var_device = graph.Get<ShardedVarDevice>(kShardedVarDevice);
+int MultiDevSSAGraphBuilder::GetVarDeviceID(
+    const ir::Graph &graph, const std::string &varname,
+    const std::unordered_map<std::string, int> &sharded_var_device) const {
   auto got = sharded_var_device.find(varname);
   return got == sharded_var_device.end() ? -1 : got->second;
 }
@@ -709,8 +710,9 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }
 
-int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
-                                               ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateDistTrainOp(
+    ir::Graph *result, ir::Node *node,
+    std::unordered_map<std::string, int> *sharded_var_device) const {
   int op_dev_id = -1;
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
@@ -725,23 +727,22 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
              node->Op()->Type() == "split_selected_rows" ||
              node->Op()->Type() == "split_ids") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
+    op_dev_id =
+        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
     if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
       op_dev_id = GetAppropriateDeviceID(input_var_names);
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>(kShardedVarDevice)
-            .emplace(varname, op_dev_id);
+        sharded_var_device->emplace(varname, op_dev_id);
       }
     }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else if (node->Op()->Type() == "concat") {
-    op_dev_id = GetVarDeviceID(*result, input_var_names[0]);
+    op_dev_id =
+        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else {
     LOG(ERROR) << "got unexpected dist op: " << node->Op()->Type();
@@ -774,12 +775,14 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
-                                         ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateRPCOp(
+    ir::Graph *result, ir::Node *node,
+    std::unordered_map<std::string, int> *sharded_var_device) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id = GetVarDeviceID(*result, node->inputs[0]->Name());
+    op_dev_id =
+        GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
     PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                    "This hack no longer holds, please fix.");
     // the variable name which contains .block means it was splited by
@@ -797,11 +800,9 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       VLOG(10) << "send grad " << input_var_names[0] << " origin "
                << send_param_grad[1] << " place: " << op_dev_id;
       for (auto &varname : input_var_names) {
-        result->Get<ShardedVarDevice>(kShardedVarDevice)
-            .emplace(varname, op_dev_id);
+        sharded_var_device->emplace(varname, op_dev_id);
       }
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(send_param_grad[1], op_dev_id);
+      sharded_var_device->emplace(send_param_grad[1], op_dev_id);
     }
   } else if (node->Op()->Type() == "recv") {
     std::vector<std::string> output_var_names;
@@ -811,7 +812,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
     auto recv_param_grad = boost::get<std::vector<std::string>>(
         node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
     if (recv_param_grad.size() == 2U) {
-      op_dev_id = GetVarDeviceID(*result, recv_param_grad[1]);
+      op_dev_id =
+          GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
       VLOG(10) << "recv param " << recv_param_grad[0]
                << " get grad place: " << recv_param_grad[1]
                << " place: " << op_dev_id;
@@ -819,8 +821,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       op_dev_id = GetAppropriateDeviceID(output_var_names);
     }
     for (auto &varname : output_var_names) {
-      result->Get<ShardedVarDevice>(kShardedVarDevice)
-          .emplace(varname, op_dev_id);
+      sharded_var_device->emplace(varname, op_dev_id);
     }
   } else {
     // send_barrier, fetch_barrier will run on place 0;
@@ -847,7 +848,8 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
   for (ir::Node *output : node->outputs) {
     int outvar_dev_id = op_dev_id;
     if (node->Op()->Type() == "fetch_barrier") {
-      outvar_dev_id = GetVarDeviceID(*result, output->Name());
+      outvar_dev_id =
+          GetVarDeviceID(*result, output->Name(), *sharded_var_device);
       PADDLE_ENFORCE_NE(outvar_dev_id, -1);
     }
     p = places_[outvar_dev_id];
@@ -44,12 +44,18 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   mutable platform::NCCLContextMap *nccl_ctxs_;
 #endif
 
-  int GetVarDeviceID(const ir::Graph &graph, const std::string &varname) const;
+  int GetVarDeviceID(
+      const ir::Graph &graph, const std::string &varname,
+      const std::unordered_map<std::string, int> &sharded_var_device) const;
 
   bool IsScaleLossOp(ir::Node *node) const;
 
-  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
-  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
+  int CreateRPCOp(
+      ir::Graph *result, ir::Node *node,
+      std::unordered_map<std::string, int> *sharded_var_device) const;
+  int CreateDistTrainOp(
+      ir::Graph *result, ir::Node *node,
+      std::unordered_map<std::string, int> *sharded_var_device) const;
 
   std::vector<std::string> FindDistTrainSendVars(
       const std::vector<ir::Node *> &nodes) const;
@@ -69,7 +75,9 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
                              int dev_id) const;
 
-  int GetOpDeviceID(const ir::Graph &graph, ir::Node *node) const;
+  int GetOpDeviceID(
+      const ir::Graph &graph, ir::Node *node,
+      const std::unordered_map<std::string, int> &sharded_var_device) const;
 
   void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
@@ -49,9 +49,6 @@ const char kGraphDepVars[] = "dep_vars";
 // unordered.
 typedef std::vector<std::unique_ptr<OpHandleBase>> GraphOps;
 const char kGraphOps[] = "ops";
-typedef std::unordered_map<std::string, int> ShardedVarDevice;
-const char kShardedVarDevice[] = "sharded_var_device";
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
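A note on the resulting signatures: helpers that only consult the mapping (`GetVarDeviceID`, `GetOpDeviceID`) take it by const reference, while helpers that record new placements (`CreateRPCOp`, `CreateDistTrainOp`) take a mutable pointer, so write access is visible at every call site (`&sharded_var_device` versus `sharded_var_device`). With the attribute gone, the check pass can no longer require it (the first hunk above), and the `ShardedVarDevice` typedef and `kShardedVarDevice` key can be deleted from the helper header (the last hunk).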