Commit 84367cf8 authored by Qiao Longfei

support async mode in dist mode parallel executor

Parent e72637dd
@@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
   bool is_forwarding = true;
   bool insert_collection_ops = NeedCollectiveOps();
+  if (strategy_.async_mode_) {
+    // async mode does not need to merge gradients
+    insert_collection_ops = false;
+  }
   for (ir::Node *node : sorted_ops) {
     if (DealWithSpecialOp(&result, node)) {
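
The hunk above simply turns gradient merging off when the build strategy is asynchronous. A minimal standalone sketch of that decision (the function name is illustrative, not from the pass):

// Sketch: in async mode each device's graph trains and pushes its gradients
// independently, so no collective (gradient-merge) ops are built.
bool ShouldInsertCollectiveOps(bool async_mode, bool need_collective_ops) {
  if (async_mode) {
    return false;  // gradients are never merged across devices
  }
  return need_collective_ops;
}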
@@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
           static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                 OpProtoAndCheckerMaker::OpRoleAttrName())) &
                             static_cast<int>(OpRole::kBackward));
+      // optimize ops are already processed in DealWithSpecialOp,
+      // here we only consider backward ops
       if (!is_bk_op) continue;
+
+      /*
+       * An op that generates the gradient of a parameter carries the attr
+       * op_role_var, which records the parameter and its gradient, e.g.:
+       *   attrs {
+       *     name: "op_role_var"
+       *     type: STRINGS
+       *     strings: "fc_1.b_0"
+       *     strings: "fc_1.b_0@GRAD"
+       *   }
+       */
       // Currently, we assume that once gradient is generated, it can be
       // broadcast, and each gradient is only broadcast once.
       auto backward_vars =
@@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
       for (size_t i = 0; i < backward_vars.size(); i += 2) {
         auto &p_name = backward_vars[i];
         auto &g_name = backward_vars[i + 1];
-        VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+        VLOG(3) << "Bcast " << g_name << " for parameter " << p_name;
         InsertCollectiveOp(&result, p_name, g_name);
       }
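
For reference, a standalone sketch of the pairing that the loop above performs on the interleaved op_role_var attribute (function and variable names are illustrative only):

#include <string>
#include <utility>
#include <vector>

// backward_vars stores [param0, grad0, param1, grad1, ...];
// unpack it into (parameter, gradient) pairs.
std::vector<std::pair<std::string, std::string>> UnpackOpRoleVar(
    const std::vector<std::string> &backward_vars) {
  std::vector<std::pair<std::string, std::string>> pairs;
  for (size_t i = 0; i + 1 < backward_vars.size(); i += 2) {
    pairs.emplace_back(backward_vars[i], backward_vars[i + 1]);
  }
  return pairs;  // e.g. {("fc_1.b_0", "fc_1.b_0@GRAD")}
}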
@@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
 void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
                                                         ir::Node *node,
-                                                        int dev_id) const {
+                                                        size_t dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id], dev_id));
@@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
   }
 }
-VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
-                                                       const std::string &og,
-                                                       int dst_dev_id) const {
+VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
+    ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
                                             ir::Node *node) const {
   bool insert_op = false;
   if (OpHaveRole(*node, OpRole::kRPC)) {
+    // in async mode, each graph will send its own gradients.
+    if (strategy_.async_mode_ && node->Op()->Type() == "send") {
+      return false;
+    }
     int op_dev_id = CreateRPCOp(result, node);
     PADDLE_ENFORCE(op_dev_id != -1,
                    "Can not schedule the RPC operator to the right place.");
@@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else if (OpHaveRole(*node, OpRole::kDist)) {
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
+      // the inputs (blocks of the parameter) of concat are on different
+      // devices; the output (the whole parameter) will be on one device.
       auto origin_param_name = node->Op()->OutputArgumentNames()[0];
       bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
     }
@@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else {
     int op_dev_id = GetOpDeviceID(node);
     if (op_dev_id != -1) {  // This op only runs on one specific device.
+      // optimize ops are processed here.
       CreateComputationalOp(result, node, op_dev_id);
       for (ir::Node *n : node->outputs) {
         sharded_var_device_.emplace(n->Name(), op_dev_id);
@@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                              const std::string &p_name,
                                              const std::string &g_name) const {
+  // insert the collective op that merges this gradient across devices
   size_t cur_device_id = 0;
   switch (strategy_.reduce_) {
     case BuildStrategy::ReduceStrategy::kReduce:
......
@@ -68,10 +68,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
                               proto::VarType::Type dtype) const;
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
-                            int dst_dev_id) const;
+                            size_t dst_dev_id) const;
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
-                             int dev_id) const;
+                             size_t dev_id) const;
   bool IsSparseGradient(const std::string &og) const;
@@ -118,16 +118,16 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
  protected:
-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
-                                  const std::string &g_name) const {}
+  void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                          const std::string &g_name) const override {}
   bool NeedCollectiveOps() const override { return false; }
-  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
+  bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
     return false;
   }
-  virtual void InsertPostprocessOps(ir::Graph *result) const {}
+  void InsertPostprocessOps(ir::Graph *result) const override {}
 };
 class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
......