Unverified commit fe8495a7, authored by chengduo, committed by GitHub

[WIP] Refine MultiDevSSAGraph (#15040)

* refine parallel_exe
test=develop

* rename shared_var_device

* code refine

* add test_weight_decay

* remove Sort
test=develop

* Add SortForReduce
test=develop

* code refine
test=develop

* follow comment
test=develop
Parent 85471533
@@ -42,6 +42,12 @@ namespace {
typedef std::vector<OpHandleBase *> GraphOps;
const char kGraphOps[] = "ops";

+bool OpHaveRole(const ir::Node &node, const framework::OpRole &role) {
+  return boost::get<int>(
+             node.Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+         static_cast<int>(role);
+}
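Reviewer note: the new `OpHaveRole` helper factors out the `boost::get<int>` comparison that previously appeared at every call site. Below is a self-contained sketch, not Paddle code (`MockOp` and the enum values are illustrative stand-ins), contrasting the two role-check patterns this pass relies on: an exact match for roles like kRPC/kDist/kOptimize, and a bitmask test for kBackward, which can be OR-ed with other flags.

```cpp
#include <cassert>

// Illustrative stand-ins only: Paddle stores the role as an int attribute
// on the op, and some roles are composed as bit flags (e.g. backward|loss).
enum class OpRole : int {
  kForward = 0x0,
  kBackward = 0x1,
  kOptimize = 0x2,
  kRPC = 0x4,
  kLoss = 0x100,
};

struct MockOp {
  int role_attr;  // models what GetAttr(OpRoleAttrName()) would yield
};

// Exact-role test: the pattern OpHaveRole factors out.
bool OpHaveRole(const MockOp &op, OpRole role) {
  return op.role_attr == static_cast<int>(role);
}

// Bitmask test: kBackward may be OR-ed with kLoss, so `&` is required,
// which is why the is_bk_op check later in this file keeps the inline form.
bool HasBackwardBit(const MockOp &op) {
  return (op.role_attr & static_cast<int>(OpRole::kBackward)) != 0;
}

int main() {
  MockOp rpc{static_cast<int>(OpRole::kRPC)};
  MockOp bk_loss{static_cast<int>(OpRole::kBackward) |
                 static_cast<int>(OpRole::kLoss)};
  assert(OpHaveRole(rpc, OpRole::kRPC));
  assert(!OpHaveRole(bk_loss, OpRole::kBackward));  // exact match misses
  assert(HasBackwardBit(bk_loss));                  // bitmask test hits
  return 0;
}
```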

void PolishGraphToSupportDataHazards(ir::Graph *graph) {
  for (auto &var_map : graph->Get<GraphVars>(kGraphVars)) {
    for (auto &name_pair : var_map) {
@@ -147,6 +153,7 @@ void MultiDevSSAGraphBuilder::Init() const {
#endif
  balance_vars_.resize(places_.size(), 0);
  if (strategy_.enable_data_balance_ && places_.size() == 1) {
    LOG(WARNING) << "It is no need to enable data balance when there is only "
                    "one place. enable_data_balance is set to False.";
@@ -154,145 +161,16 @@ void MultiDevSSAGraphBuilder::Init() const {
  }
}
-void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
-                                                ir::Node *node,
-                                                size_t place_id) const {
-  auto p = places_[place_id];
-  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
-  op_handle->SetDeviceContext(p,
-                              platform::DeviceContextPool::Instance().Get(p));
-
-  for (ir::Node *input : node->inputs) {
-    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
-    op_handle->AddInput(var);
-  }
-
-  for (ir::Node *output : node->outputs) {
-    ir::Node *new_node = nullptr;
-    if (output->Var()) {
-      new_node = result->CreateVarNode(output->Var());
-    } else {
-      new_node =
-          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
-    }
-    CreateOpOutput(result, op_handle, new_node, p, place_id);
-  }
-}
-
-std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
-    const std::vector<ir::Node *> &nodes) const {
-  std::vector<std::string> send_vars;
-  // since parameters are all in block 0,
-  // it's enough to only scan send ops in block 0
-  for (auto &node : nodes) {
-    OpDesc *op = node->Op();
-    // TODO(Yancey1989): use a graceful method to find send op,
-    // instead of the the hard code string
-    if (op->Type() == "send") {
-      auto op_vars = op->InputArgumentNames();
-      send_vars.reserve(send_vars.size() +
-                        std::distance(op_vars.begin(), op_vars.end()));
-      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
-    }
-  }
-  return send_vars;
-}
-
-std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
-    const std::vector<ir::Node *> &nodes) const {
-  std::vector<std::string> recv_vars;
-  for (auto &node : nodes) {
-    OpDesc *op = node->Op();
-    // TODO(Yancey1989): use a graceful method to find recv op,
-    // instead of the hard code string
-    if (op->Type() == "recv") {
-      auto op_vars = op->OutputArgumentNames();
-      recv_vars.reserve(recv_vars.size() +
-                        std::distance(op_vars.begin(), op_vars.end()));
-      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
-    }
-  }
-  return recv_vars;
-}
-
-size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
-    const std::vector<std::string> &var_names) const {
-  int64_t numel_sum = 0;
-  for (auto var_name : var_names) {
-    if (all_vars_.find(var_name) == all_vars_.end()) continue;
-    auto var_desc = all_vars_.at(var_name);
-    PADDLE_ENFORCE_NOT_NULL(var_desc);
-    auto dim = framework::make_ddim(var_desc->GetShape());
-    int64_t numel = framework::product(dim);
-    PADDLE_ENFORCE_GT(numel, 0);
-    numel_sum += numel;
-  }
-
-  auto smallest =
-      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
-  size_t dev_id =
-      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
-  balance_vars_[dev_id] += numel_sum;
-  return dev_id;
-}
-
-// Topology sort the graph nodes from inputs to outputs.
-// Since SSAGraphBuilder depends on forward/backward nodes to assign devices
-// to parameter/gradients before optimizer ops, topo sort is insufficient. (
-// some optimizer ops might not depend on any nodes), we manually move all
-// optimizer nodes after last backward nodes.
-// However, the assumption by SSAGraphBuilder should be relaxed in the future.
-std::vector<ir::Node *> SortOpsAndDelayOptimizeOp(const ir::Graph &graph) {
-  std::vector<ir::Node *> ret = ir::TopologySortOperations(graph);
-  size_t last_backward = 0;
-  for (size_t i = 0; i < ret.size(); ++i) {
-    if (boost::get<int>(
-            ret[i]->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-        static_cast<int>(OpRole::kBackward)) {
-      last_backward = i;
-    }
-  }
-
-  std::vector<ir::Node *> optimize_ops;
-  std::vector<ir::Node *> sorted_ret;
-  for (size_t i = 0; i < ret.size(); ++i) {
-    if (i < last_backward) {
-      if (static_cast<bool>(boost::get<int>(ret[i]->Op()->GetAttr(
-                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                            static_cast<int>(OpRole::kOptimize))) {
-        optimize_ops.push_back(ret[i]);
-      } else {
-        sorted_ret.push_back(ret[i]);
-      }
-    } else if (i == last_backward) {
-      sorted_ret.push_back(ret[i]);
-      // Verify that no operations before optimize ops depends on optimize ops.
-      std::unordered_set<ir::Node *> optimize_set(optimize_ops.begin(),
-                                                  optimize_ops.end());
-      for (ir::Node *n : sorted_ret) {
-        for (ir::Node *in : n->inputs) {
-          for (ir::Node *pre_n : in->inputs) {
-            PADDLE_ENFORCE(optimize_set.find(pre_n) == optimize_set.end(),
-                           "optimize operations cannot be depended by forward "
-                           "or backward node %s -> %s",
-                           pre_n->Name(), n->Name());
-          }
-        }
-      }
-      sorted_ret.insert(sorted_ret.end(), optimize_ops.begin(),
-                        optimize_ops.end());
-    } else {
-      sorted_ret.push_back(ret[i]);
-    }
-  }
-  return sorted_ret;
-}

std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  Init();
  // Give the topology sort order and rebuild the graph structure.
-  std::vector<ir::Node *> sorted_ops = SortOpsAndDelayOptimizeOp(*graph);
+  std::vector<ir::Node *> sorted_ops = ir::TopologySortOperations(*graph);
+
+  if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
+    sorted_ops = SortForReduceMode(sorted_ops);
+  }
+
  auto nodes = graph->ReleaseNodes();
  ir::Graph &result = *graph;
@@ -303,31 +181,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
      all_vars_.emplace(node->Name(), node->Var());
    }
  }
-  std::unordered_set<std::string> og_has_been_broadcast;

  // We cannot invoke resize. It is a bug of GCC 4.8
  result.Set(kGraphVars, new GraphVars(places_.size()));
  result.Set(kGraphDepVars, new GraphDepVars);
  result.Set(kGraphOps, new GraphOps);

-  // find send/recv vars so that we can place the distributed training
-  // related op in the place 0
-  auto send_vars = FindDistTrainSendVars(sorted_ops);
-  auto recv_vars = FindDistTrainRecvVars(sorted_ops);

  std::vector<std::unordered_set<std::string>> bcast_var_name_set;
  bcast_var_name_set.resize(places_.size());

-  size_t cur_device_id = 0;
  bool is_forwarding = true;
  bool is_dist_train = false;

  std::unordered_map<std::string, int> sharded_var_device;

  for (ir::Node *node : sorted_ops) {
-    if (boost::get<int>(
-            node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-        static_cast<int>(OpRole::kRPC)) {
+    if (OpHaveRole(*node, OpRole::kRPC)) {
      int op_dev_id = CreateRPCOp(&result, node, &sharded_var_device);
      PADDLE_ENFORCE(op_dev_id != -1,
                     "Can not schedule the RPC operator to the right place.");
@@ -341,9 +210,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
        }
      }
      is_dist_train = true;
-    } else if (boost::get<int>(node->Op()->GetAttr(
-                   OpProtoAndCheckerMaker::OpRoleAttrName())) ==
-               static_cast<int>(OpRole::kDist)) {
+    } else if (OpHaveRole(*node, OpRole::kDist)) {
      int op_dev_id = CreateDistTrainOp(&result, node, &sharded_var_device);
      if (node->Op()->Type() == "concat") {
        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
@@ -365,7 +232,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
      // the block.
      is_forwarding = false;
    } else {
-      int op_dev_id = GetOpDeviceID(result, node, sharded_var_device);
+      int op_dev_id = GetOpDeviceID(node, sharded_var_device);
      if (op_dev_id != -1) {  // This op only runs on one specific device.
        CreateComputationalOp(&result, node, op_dev_id);
        for (ir::Node *n : node->outputs) {
@@ -385,47 +252,48 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
      }

      if (!is_forwarding && (places_.size() > 1 || num_trainers > 1)) {
+        bool is_bk_op =
+            static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
+                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                              static_cast<int>(OpRole::kBackward));
+        if (!is_bk_op) continue;
+
        // Currently, we assume that once gradient is generated, it can be
        // broadcast, and each gradient is only broadcast once.
-        if (static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
-                                  OpProtoAndCheckerMaker::OpRoleAttrName())) &
-                              static_cast<int>(OpRole::kBackward))) {
-          try {
-            auto backward_vars = boost::get<std::vector<std::string>>(
-                node->Op()->GetNullableAttr(
-                    OpProtoAndCheckerMaker::OpRoleVarAttrName()));
-
-            PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
-
-            for (size_t i = 0; i < backward_vars.size(); i += 2) {
-              auto &p_name = backward_vars[i];
-              auto &g_name = backward_vars[i + 1];
-              VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
-
-              switch (strategy_.reduce_) {
-                case BuildStrategy::ReduceStrategy::kReduce:
-                  cur_device_id = GetAppropriateDeviceID({g_name});
-                  CreateReduceOp(&result, g_name, cur_device_id);
-                  sharded_var_device.emplace(g_name, cur_device_id);
-                  if (!is_dist_train) {
-                    bcast_var_name_set[cur_device_id].emplace(p_name);
-                  }
-                  break;
-                case BuildStrategy::ReduceStrategy::kAllReduce:
-                  if (IsSparseGradient(g_name)) {
-                    CreateReduceOp(&result, g_name, 0);
-                    CreateBroadcastOp(&result, g_name, 0);
-                  } else {
-                    InsertAllReduceOp(&result, g_name);
-                  }
-                  break;
-                default:
-                  LOG(FATAL) << "Unknown reduce strategy ";
-                  break;
-              }
-            }
-          } catch (boost::bad_get e) {
-          }
-        }
+        try {
+          auto backward_vars = boost::get<std::vector<std::string>>(
+              node->Op()->GetNullableAttr(
+                  OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+          PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+          for (size_t i = 0; i < backward_vars.size(); i += 2) {
+            auto &p_name = backward_vars[i];
+            auto &g_name = backward_vars[i + 1];
+            VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+            size_t cur_device_id = -1;
+            switch (strategy_.reduce_) {
+              case BuildStrategy::ReduceStrategy::kReduce:
+                cur_device_id = GetAppropriateDeviceID({g_name});
+                CreateReduceOp(&result, g_name, cur_device_id);
+                sharded_var_device.emplace(g_name, cur_device_id);
+                if (!is_dist_train) {
+                  bcast_var_name_set[cur_device_id].emplace(p_name);
+                }
+                break;
+              case BuildStrategy::ReduceStrategy::kAllReduce:
+                if (IsSparseGradient(g_name)) {
+                  CreateReduceOp(&result, g_name, 0);
+                  CreateBroadcastOp(&result, g_name, 0);
+                } else {
+                  InsertAllReduceOp(&result, g_name);
+                }
+                break;
+              default:
+                LOG(FATAL) << "Unknown reduce strategy ";
+                break;
+            }
+          }
+        } catch (boost::bad_get e) {
+        }
      }
    }
@@ -469,12 +337,108 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
  return graph;
}
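The restructured loop above decides, per gradient, how it is aggregated across devices. A minimal sketch of that decision table follows, assuming a hypothetical `Grad` descriptor in place of the graph nodes and `sharded_var_device` bookkeeping the real pass uses:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical gradient descriptor; "sparse" models a SELECTED_ROWS gradient.
struct Grad {
  std::string name;
  bool sparse;
};

enum class ReduceStrategy { kAllReduce, kReduce };

int main() {
  std::vector<Grad> grads = {
      {"w0@GRAD", false}, {"emb@GRAD", true}, {"b0@GRAD", false}};
  for (ReduceStrategy s :
       {ReduceStrategy::kReduce, ReduceStrategy::kAllReduce}) {
    for (const Grad &g : grads) {
      switch (s) {
        case ReduceStrategy::kReduce:
          // Each gradient is reduced onto one chosen device; the updated
          // parameter is broadcast back to the others afterwards.
          std::cout << g.name << ": reduce onto one device, bcast param\n";
          break;
        case ReduceStrategy::kAllReduce:
          // Dense gradients are all-reduced in place; sparse gradients fall
          // back to reduce-to-device-0 followed by a broadcast, as above.
          std::cout << g.name
                    << (g.sparse ? ": reduce to dev 0 + bcast" : ": allreduce")
                    << "\n";
          break;
      }
    }
  }
  return 0;
}
```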

-bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
-  PADDLE_ENFORCE(all_vars_.count(og) != 0);
-  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
-    return true;
-  }
-  return false;
-}
-
+std::vector<ir::Node *> MultiDevSSAGraphBuilder::SortForReduceMode(
+    const std::vector<ir::Node *> &topo_ops) const {
+  std::unordered_map<std::string, int> sharded_var_device;
+  std::vector<ir::Node *> sorted_ops;
+  std::unordered_map<std::string, std::vector<ir::Node *>> delayed_op;
+  sorted_ops.reserve(topo_ops.size());
+
+  auto insert_delayed_op = [&](const std::string &var_name, int dev_id) {
+    sharded_var_device.emplace(var_name, dev_id);
+    if (delayed_op.count(var_name)) {
+      auto &ops = delayed_op.at(var_name);
+      sorted_ops.insert(sorted_ops.end(), ops.begin(), ops.end());
+      delayed_op.at(var_name).clear();
+    }
+  };
+
+  for (ir::Node *node : topo_ops) {
+    int op_dev_id = GetOpDeviceID(node, sharded_var_device, &delayed_op);
+    if (op_dev_id > -1) {
+      // This op only runs on one specific device.
+      sorted_ops.emplace_back(node);
+      for (ir::Node *n : node->outputs) {
+        insert_delayed_op(n->Name(), op_dev_id);
+      }
+    } else if (op_dev_id == -1) {
+      // This op runs on all devices, and its output may have parameter's
+      // gradients.
+      sorted_ops.emplace_back(node);
+      bool is_bk_op =
+          static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
+                                OpProtoAndCheckerMaker::OpRoleAttrName())) &
+                            static_cast<int>(OpRole::kBackward));
+      if (!is_bk_op) continue;
+      // Currently, we assume that once gradient is generated, it can be
+      // broadcast, and each gradient is only broadcast once.
+      std::vector<std::string> backward_vars;
+      try {
+        backward_vars =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+      } catch (boost::bad_get e) {
+      }
+      PADDLE_ENFORCE_EQ(backward_vars.size() % 2, 0);
+
+      for (size_t i = 0; i < backward_vars.size(); i += 2) {
+        auto &g_name = backward_vars[i + 1];
+        size_t cur_device_id = GetAppropriateDeviceID({g_name});
+        insert_delayed_op(g_name, static_cast<int>(cur_device_id));
+      }
+    } else if (op_dev_id == -2) {
+      // The Op on which the Op depends has not yet been generated.
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(sorted_ops.size(), topo_ops.size());
+  return sorted_ops;
+}
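`SortForReduceMode` is the heart of this commit: in kReduce mode an optimizer op must be scheduled on the device its gradient was reduced to, but that device is only known once the gradient has been placed, so such ops are parked in `delayed_op` and flushed by `insert_delayed_op` as soon as the variable's device is recorded. Here is a toy model of that buffering, under the simplifying assumption that each delayed op waits on exactly one named variable (all types here are illustrative, not Paddle's):

```cpp
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Minimal model: an "op" either runs on all devices or waits for the device
// of one variable. The -1/-2 comments mirror the pass's sentinels:
// -1 = runs everywhere, -2 = producer's device not decided yet.
struct Op {
  std::string name;
  std::string waits_on;  // empty: runs on all devices
};

int main() {
  std::unordered_map<std::string, int> var_device;
  std::unordered_map<std::string, std::vector<Op>> delayed;
  std::vector<Op> sorted;

  auto place_var = [&](const std::string &var, int dev) {
    var_device.emplace(var, dev);
    // Flush every op that was waiting for this variable's device.
    auto it = delayed.find(var);
    if (it != delayed.end()) {
      for (const Op &op : it->second) sorted.push_back(op);
      it->second.clear();
    }
  };

  std::vector<Op> topo = {{"sgd_on_w@GRAD", "w@GRAD"},  // optimizer, delayed
                          {"backward_w", ""}};          // produces w@GRAD
  for (const Op &op : topo) {
    if (op.waits_on.empty()) {
      sorted.push_back(op);    // dev id -1: runs on all devices
      place_var("w@GRAD", 0);  // its gradient was just reduced onto device 0
    } else if (!var_device.count(op.waits_on)) {
      delayed[op.waits_on].push_back(op);  // dev id -2: park it for later
    } else {
      sorted.push_back(op);  // device already known, keep topo position
    }
  }

  assert(sorted.size() == topo.size());
  for (const Op &op : sorted) std::cout << op.name << "\n";  // backward first
  return 0;
}
```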

+void MultiDevSSAGraphBuilder::CreateOpHandleIOs(ir::Graph *result,
+                                                ir::Node *node,
+                                                size_t place_id) const {
+  auto p = places_[place_id];
+  auto *op_handle = result->Get<GraphOps>(kGraphOps).back();
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));
+
+  for (ir::Node *input : node->inputs) {
+    VarHandle *var = CreateOrGetLatestVarHandle(result, input, p, place_id);
+    op_handle->AddInput(var);
+  }
+
+  for (ir::Node *output : node->outputs) {
+    ir::Node *new_node = nullptr;
+    if (output->Var()) {
+      new_node = result->CreateVarNode(output->Var());
+    } else {
+      new_node =
+          result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
+    }
+    CreateOpOutput(result, op_handle, new_node, p, place_id);
+  }
+}
+
+size_t MultiDevSSAGraphBuilder::GetAppropriateDeviceID(
+    const std::vector<std::string> &var_names) const {
+  int64_t numel_sum = 0;
+  for (auto var_name : var_names) {
+    if (all_vars_.find(var_name) == all_vars_.end()) continue;
+    auto var_desc = all_vars_.at(var_name);
+    PADDLE_ENFORCE_NOT_NULL(var_desc);
+    auto dim = framework::make_ddim(var_desc->GetShape());
+    int64_t numel = framework::product(dim);
+    PADDLE_ENFORCE_GT(numel, 0);
+    numel_sum += numel;
+  }
+
+  auto smallest =
+      std::min_element(std::begin(balance_vars_), std::end(balance_vars_));
+  size_t dev_id =
+      static_cast<size_t>(std::distance(std::begin(balance_vars_), smallest));
+  balance_vars_[dev_id] += numel_sum;
+  return dev_id;
+}
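`GetAppropriateDeviceID` (unchanged, only moved) implements greedy load balancing: sum the element counts of the variables, then charge them to the device with the smallest accumulated load. A self-contained sketch of the same policy, with sizes passed in directly instead of being looked up in `all_vars_`:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

// A minimal stand-in for the pass's balancing state (balance_vars_).
class Balancer {
 public:
  explicit Balancer(std::size_t num_devices) : load_(num_devices, 0) {}

  // Pick the device with the smallest accumulated element count, then
  // charge it for the new variables, like GetAppropriateDeviceID.
  std::size_t Pick(const std::vector<int64_t> &numels) {
    int64_t numel_sum = 0;
    for (int64_t n : numels) numel_sum += n;
    auto smallest = std::min_element(load_.begin(), load_.end());
    std::size_t dev_id =
        static_cast<std::size_t>(std::distance(load_.begin(), smallest));
    load_[dev_id] += numel_sum;
    return dev_id;
  }

 private:
  std::vector<int64_t> load_;
};

int main() {
  Balancer balancer(3);
  std::cout << balancer.Pick({1 << 20}) << "\n";  // 0
  std::cout << balancer.Pick({1 << 10}) << "\n";  // 1
  std::cout << balancer.Pick({1 << 10}) << "\n";  // 2
  std::cout << balancer.Pick({1 << 10}) << "\n";  // 1: ties go to the first
  return 0;                                       //    smallest entry
}
```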

void MultiDevSSAGraphBuilder::SetCommunicationContext(
@@ -625,28 +589,52 @@ void MultiDevSSAGraphBuilder::InsertDataBalanceOp(
}
int MultiDevSSAGraphBuilder::GetOpDeviceID(
-    const ir::Graph &graph, ir::Node *node,
+    ir::Node *node,
+    const std::unordered_map<std::string, int> &sharded_var_device,
+    std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops) const {
+  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
+    return -1;
+  }
+
+  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
+    return -1;
+  }
+
+  auto param_grad = boost::get<std::vector<std::string>>(
+      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+
+  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
+  int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
+
+  if (dev_id == -1) {
+    (*delay_ops)[param_grad[1]].push_back(node);
+    return -2;
+  }
+  return dev_id;
+}
+
+int MultiDevSSAGraphBuilder::GetOpDeviceID(
+    ir::Node *node,
    const std::unordered_map<std::string, int> &sharded_var_device) const {
  if (strategy_.reduce_ != BuildStrategy::ReduceStrategy::kReduce) {
    return -1;
  }
-  int op_role = boost::get<int>(
-      node->Op()->GetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
-  if (op_role != static_cast<int>(framework::OpRole::kOptimize)) {
+
+  if (!OpHaveRole(*node, framework::OpRole::kOptimize)) {
    return -1;
  }
  auto param_grad = boost::get<std::vector<std::string>>(
      node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));

  PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
-  int dev_id = GetVarDeviceID(graph, param_grad[1], sharded_var_device);
+  int dev_id = GetVarDeviceID(param_grad[1], sharded_var_device);
  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
                    node->Op()->Type(), param_grad[0], param_grad[1]);
  return dev_id;
}
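A possible reading of the two overloads: the pre-existing one still enforces that an optimize op's device is already known, while the new one turns an unknown device into the `-2` "delay me" sentinel consumed by `SortForReduceMode`. A compressed sketch of that three-way contract (illustrative names; the real overload also records the node in `*delay_ops`):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Returns:  >= 0  device the op is pinned to
//           -1    op is not pinned (runs on every device)
//           -2    op is pinned, but the gradient's device is not known yet
int LookupOpDevice(bool is_optimize_op, const std::string &grad_name,
                   const std::unordered_map<std::string, int> &var_device) {
  if (!is_optimize_op) return -1;
  auto it = var_device.find(grad_name);
  return it == var_device.end() ? -2 : it->second;
}

int main() {
  std::unordered_map<std::string, int> var_device{{"w@GRAD", 1}};
  assert(LookupOpDevice(false, "", var_device) == -1);       // plain op
  assert(LookupOpDevice(true, "b@GRAD", var_device) == -2);  // delay
  assert(LookupOpDevice(true, "w@GRAD", var_device) == 1);   // pinned
  return 0;
}
```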

int MultiDevSSAGraphBuilder::GetVarDeviceID(
-    const ir::Graph &graph, const std::string &varname,
+    const std::string &varname,
    const std::unordered_map<std::string, int> &sharded_var_device) const {
  auto got = sharded_var_device.find(varname);
  if (got == sharded_var_device.end()) {
@@ -740,8 +728,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
      node->Op()->Type() == "split_selected_rows" ||
      node->Op()->Type() == "split_ids") {
    // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id =
-        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
+    op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
    if (strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
      op_dev_id = GetAppropriateDeviceID(input_var_names);
      for (auto &varname : input_var_names) {
@@ -752,8 +739,7 @@ int MultiDevSSAGraphBuilder::CreateDistTrainOp(
      sharded_var_device->emplace(varname, op_dev_id);
    }
  } else if (node->Op()->Type() == "concat") {
-    op_dev_id =
-        GetVarDeviceID(*result, input_var_names[0], *sharded_var_device);
+    op_dev_id = GetVarDeviceID(input_var_names[0], *sharded_var_device);
    for (auto &varname : output_var_names) {
      sharded_var_device->emplace(varname, op_dev_id);
    }
@@ -794,8 +780,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
  int op_dev_id = -1;
  if (node->Op()->Type() == "send") {
    // TODO(paddle-dev): getting the first var is not safe.
-    op_dev_id =
-        GetVarDeviceID(*result, node->inputs[0]->Name(), *sharded_var_device);
+    op_dev_id = GetVarDeviceID(node->inputs[0]->Name(), *sharded_var_device);
    PADDLE_ENFORCE(!ir::IsControlDepVar(*node->inputs[0]),
                   "This hack no longer holds, please fix.");
    // the variable name which contains .block means it was splited by
@@ -825,8 +810,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
      auto recv_param_grad = boost::get<std::vector<std::string>>(
          node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleVarAttrName()));
      if (recv_param_grad.size() == 2U) {
-        op_dev_id =
-            GetVarDeviceID(*result, recv_param_grad[1], *sharded_var_device);
+        op_dev_id = GetVarDeviceID(recv_param_grad[1], *sharded_var_device);
        VLOG(10) << "recv param " << recv_param_grad[0]
                 << " get grad place: " << recv_param_grad[1]
                 << " place: " << op_dev_id;
@@ -861,8 +845,7 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
  for (ir::Node *output : node->outputs) {
    int outvar_dev_id = op_dev_id;
    if (node->Op()->Type() == "fetch_barrier") {
-      outvar_dev_id =
-          GetVarDeviceID(*result, output->Name(), *sharded_var_device);
+      outvar_dev_id = GetVarDeviceID(output->Name(), *sharded_var_device);
      PADDLE_ENFORCE_NE(outvar_dev_id, -1, "output name %s", output->Name());
    }
    p = places_[outvar_dev_id];
@@ -879,6 +862,14 @@ int MultiDevSSAGraphBuilder::CreateRPCOp(
  return op_dev_id;
}

+bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
+  PADDLE_ENFORCE(all_vars_.count(og) != 0);
+  if (all_vars_.at(og)->GetType() == proto::VarType::SELECTED_ROWS) {
+    return true;
+  }
+  return false;
+}
+
bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
  return boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
......
@@ -45,7 +45,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
#endif

  int GetVarDeviceID(
-      const ir::Graph &graph, const std::string &varname,
+      const std::string &varname,
      const std::unordered_map<std::string, int> &sharded_var_device) const;

  bool IsScaleLossOp(ir::Node *node) const;
@@ -57,12 +57,6 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
      ir::Graph *result, ir::Node *node,
      std::unordered_map<std::string, int> *sharded_var_device) const;

-  std::vector<std::string> FindDistTrainSendVars(
-      const std::vector<ir::Node *> &nodes) const;
-
-  std::vector<std::string> FindDistTrainRecvVars(
-      const std::vector<ir::Node *> &nodes) const;
-
  void CreateComputationalOps(ir::Graph *result, ir::Node *node,
                              size_t num_places) const;
@@ -77,7 +71,7 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
      int dev_id) const;

  int GetOpDeviceID(
-      const ir::Graph &graph, ir::Node *node,
+      ir::Node *node,
      const std::unordered_map<std::string, int> &sharded_var_device) const;

  void InsertAllReduceOp(ir::Graph *result, const std::string &og) const;
@@ -100,6 +94,15 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
  void SetCommunicationContext(OpHandleBase *op_handle,
                               const platform::Place &p) const;

+  std::vector<ir::Node *> SortForReduceMode(
+      const std::vector<ir::Node *> &) const;
+
+  int GetOpDeviceID(
+      ir::Node *node,
+      const std::unordered_map<std::string, int> &shared_var_device,
+      std::unordered_map<std::string, std::vector<ir::Node *>> *delay_ops)
+      const;
+
  mutable std::string loss_var_name_;
  mutable std::vector<platform::Place> places_;
  mutable std::vector<Scope *> local_scopes_;
......
@@ -23,66 +23,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
namespace ir {
-namespace {
-
-void CheckProgram(const ProgramDesc &program) {
-#define _INT(role) static_cast<int>(role)
-
-  std::map<int, bool> visit;
-  for (OpDesc *op : program.Block(0).AllOps()) {
-    // For backward compatibility, some program doesn't have role added.
-    if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
-    int role_id =
-        boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
-    visit[role_id] = true;
-    switch (role_id) {
-      case _INT(OpRole::kForward):
-        if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
-          LOG(ERROR) << "Cannot add backward operator before forward operator "
-                     << op->Type();
-        }
-        break;
-      case _INT(OpRole::kBackward):
-      case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
-        PADDLE_ENFORCE(
-            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-            "Cannot add backward operator %s after optimize operator.",
-            op->Type());
-        break;
-      case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
-        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
-                                  _INT(OpRole::kLoss)) == visit.end(),
-                       "Cannot add backward|loss operator before "
-                       "forward|loss operator %s.",
-                       op->Type());
-        PADDLE_ENFORCE(
-            visit.find(_INT(OpRole::kOptimize)) == visit.end(),
-            "Cannot add forward|loss operator %s after optimize operator.",
-            op->Type());
-        break;
-      case _INT(OpRole::kOptimize):
-      case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
-        PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
-                       "Optimize operators %s must follow backward operator.",
-                       op->Type());
-        break;
-      case _INT(OpRole::kLRSched):
-      case _INT(OpRole::kDist):
-      case _INT(OpRole::kRPC):
-      case _INT(OpRole::kNotSpecified):
-        break;
-      default:
-        LOG(FATAL) << "Unknown operator role. Don't add new role because "
-                      "you don't know what you are doing.";
-    }
-  }
-#undef _INT
-}
-}  // namespace

Graph::Graph(const ProgramDesc &program) : program_(program) {
-  CheckProgram(program_);
  auto var_nodes = InitFromProgram(program_);
  ResolveHazard(var_nodes);
}
......
@@ -320,6 +320,7 @@ void ParallelExecutor::BCastParamsToDevices(
    if (paddle::platform::is_gpu_place(main_tensor.place())) {
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
      std::vector<void *> buffers;
+      buffers.reserve(member_->places_.size());
      size_t numel = main_tensor.numel();
      ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
      for (size_t i = 0; i < member_->places_.size(); ++i) {
@@ -353,9 +354,7 @@ void ParallelExecutor::BCastParamsToDevices(
#endif
    } else {
      platform::CPUPlace cpu;
-      for (size_t i = 0; i < member_->places_.size(); ++i) {
-        if (i == 0) continue;
+      for (size_t i = 1; i < member_->places_.size(); ++i) {
        auto local_scope = member_->local_scopes_[i];
        auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
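The CPU-branch change is cosmetic: starting the loop at `i = 1` replaces the old `if (i == 0) continue;` guard, keeping scope 0 as the broadcast source. A trivial sketch of the pattern, with nested vectors standing in for per-device scopes:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Scope 0 owns the freshly initialized parameter; every other scope
  // receives a copy (the real code copies or shares LoDTensors).
  std::vector<std::vector<float>> scopes(4);
  scopes[0] = {0.5f, -1.25f, 3.0f};

  // Starting at i = 1 replaces the old `if (i == 0) continue;` guard.
  for (std::size_t i = 1; i < scopes.size(); ++i) {
    scopes[i] = scopes[0];
  }

  for (std::size_t i = 0; i < scopes.size(); ++i) {
    std::cout << "scope " << i << " holds " << scopes[i].size()
              << " elements\n";
  }
  return 0;
}
```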
......
@@ -148,7 +148,7 @@ class ParallelExecutor(object):
                trainers_endpoints), "num_trainers == len(end_points)"
            build_strategy.trainers_endpoints = trainers_endpoints

-        # step5: get persistable_vars, parameter_vars, places. persistable_vars
+        # step6: get persistable_vars, places. persistable_vars
        # need be broadcast to other local_scope.
        persistable_vars = set([
            cpt.to_text(v.name) for v in [
@@ -164,7 +164,7 @@ class ParallelExecutor(object):
        places = list(map(place_obj, self._places))

-        # step6: init ParallelExecutor
+        # step7: init ParallelExecutor
        self.executor = core.ParallelExecutor(
            places, persistable_vars, main.desc,
            cpt.to_text(loss_name)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
from functools import partial
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
def get_places():
places = []
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
return places
@contextlib.contextmanager
def prog_scope_guard(main_prog, startup_prog):
scope = fluid.core.Scope()
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
with fluid.program_guard(main_prog, startup_prog):
yield
def bow_net(data,
label,
dict_dim,
is_sparse=False,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(
input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim])
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
cost = fluid.layers.cross_entropy(input=prediction, label=label)
avg_cost = fluid.layers.mean(x=cost)
return avg_cost
class TestWeightDecay(unittest.TestCase):
def setUp(self):
self.word_dict = paddle.dataset.imdb.word_dict()
reader = paddle.batch(
paddle.dataset.imdb.train(self.word_dict), batch_size=4)()
self.train_data = [next(reader) for _ in range(5)]
self.learning_rate = .5
def run_executor(self, place, feed_list, loss):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
main_prog = fluid.default_main_program()
loss_set = []
for data in self.train_data:
out = exe.run(main_prog,
feed=feeder.feed(data),
fetch_list=[loss.name])
print("loss %s" % (np.average(out)))
loss_set.append(np.average(out))
return loss_set
def run_parallel_exe(self,
place,
feed_list,
loss,
use_cuda=True,
use_reduce=False,
use_fast_executor=False,
use_ir_memory_optimize=False):
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=feed_list, place=place)
exe.run(fluid.default_startup_program())
exec_strategy = fluid.ExecutionStrategy()
if use_fast_executor:
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce \
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.memory_optimize = use_ir_memory_optimize
parallel_exe = fluid.ParallelExecutor(
use_cuda,
loss_name=loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy)
loss_set = []
for data in self.train_data:
out = parallel_exe.run(feed=feeder.feed(data),
fetch_list=[loss.name])
print("loss %s" % (np.average(out)))
loss_set.append(np.average(out))
return loss_set
def check_weight_decay(self,
place,
model,
use_parallel_exe=False,
use_reduce=False):
main_prog = fluid.framework.Program()
startup_prog = fluid.framework.Program()
startup_prog.random_seed = 1
with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog):
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
avg_cost = model(data, label, len(self.word_dict))
param_list = [(var, var * self.learning_rate)
for var in main_prog.block(0).all_parameters()]
optimizer = fluid.optimizer.Adagrad(
learning_rate=self.learning_rate)
optimizer.minimize(avg_cost)
for params in param_list:
updated_p = fluid.layers.elementwise_sub(
x=params[0], y=params[1])
fluid.layers.assign(input=updated_p, output=params[0])
if use_parallel_exe:
loss = self.run_parallel_exe(
place, [data, label],
loss=avg_cost,
use_cuda=True,
use_reduce=use_reduce)
else:
loss = self.run_executor(place, [data, label], loss=avg_cost)
return loss
def test_weight_decay(self):
model = partial(bow_net, is_sparse=False)
for place in get_places():
loss = self.check_weight_decay(place, model, use_parallel_exe=False)
loss2 = self.check_weight_decay(
place, model, use_parallel_exe=True, use_reduce=False)
for i in range(len(loss)):
assert np.isclose(a=loss[i], b=loss2[i], rtol=5e-5)
loss3 = self.check_weight_decay(
place, model, use_parallel_exe=True, use_reduce=True)
for i in range(len(loss)):
assert np.isclose(a=loss[i], b=loss3[i], rtol=5e-5)
if __name__ == '__main__':
unittest.main()