diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index f1347e2b0d70c6d7de3c50d3662bbd74e705391d..a2bbfc91b7374523fd6f27d3fc65c9ef090600c1 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -167,6 +167,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
 
   bool is_forwarding = true;
   bool insert_collection_ops = NeedCollectiveOps();
+  if (strategy_.async_mode_) {
+    // async mode does not need to merge gradients
+    insert_collection_ops = false;
+  }
 
   for (ir::Node *node : sorted_ops) {
     if (DealWithSpecialOp(&result, node)) {
@@ -192,8 +196,22 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
           static_cast<bool>(boost::get<int>(node->Op()->GetAttr(
                                 OpProtoAndCheckerMaker::OpRoleAttrName())) &
                             static_cast<int>(OpRole::kBackward));
+      // optimize ops have already been handled in DealWithSpecialOp,
+      // so only backward ops are considered here
       if (!is_bk_op) continue;
 
+      /*
+       * the op that generates the gradient of one parameter carries the
+       * attr op_role_var, which records the parameter and its gradient,
+       * for example:
+       *   attrs {
+       *     name: "op_role_var"
+       *     type: STRINGS
+       *     strings: "fc_1.b_0"
+       *     strings: "fc_1.b_0@GRAD"
+       *   }
+       */
+
       // Currently, we assume that once gradient is generated, it can be
       // broadcast, and each gradient is only broadcast once.
       auto backward_vars =
@@ -204,7 +222,7 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilderBase::ApplyImpl(
       for (size_t i = 0; i < backward_vars.size(); i += 2) {
         auto &p_name = backward_vars[i];
         auto &g_name = backward_vars[i + 1];
-        VLOG(10) << "Bcast " << g_name << " for parameter " << p_name;
+        VLOG(3) << "Bcast " << g_name << " for parameter " << p_name;
 
         InsertCollectiveOp(&result, p_name, g_name);
       }
@@ -385,7 +403,7 @@ void MultiDevSSAGraphBuilderBase::CreateFusedBroadcastOp(
 
 void MultiDevSSAGraphBuilderBase::CreateComputationalOp(ir::Graph *result,
                                                         ir::Node *node,
-                                                        int dev_id) const {
+                                                        size_t dev_id) const {
   result->Get<GraphOps>(kGraphOps).emplace_back(
       new ComputationOpHandle(result->CreateOpNode(node->Op()),
                               local_scopes_[dev_id], places_[dev_id], dev_id));
@@ -454,9 +472,8 @@ void MultiDevSSAGraphBuilderBase::CreateComputationalOps(
   }
 }
 
-VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(ir::Graph *result,
-                                                       const std::string &og,
-                                                       int dst_dev_id) const {
+VarHandle *MultiDevSSAGraphBuilderBase::CreateReduceOp(
+    ir::Graph *result, const std::string &og, size_t dst_dev_id) const {
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   result->Get<GraphOps>(kGraphOps).emplace_back(new ReduceOpHandle(
       result->CreateEmptyNode("reduce", ir::Node::Type::kOperation),
@@ -720,6 +737,10 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
                                             ir::Node *node) const {
   bool insert_op = false;
   if (OpHaveRole(*node, OpRole::kRPC)) {
+    // in async_mode, each graph sends its own gradient.
+    if (strategy_.async_mode_ && node->Op()->Type() == "send") {
+      return false;
+    }
     int op_dev_id = CreateRPCOp(result, node);
     PADDLE_ENFORCE(op_dev_id != -1,
                    "Can not schedule the RPC operator to the right place.");
@@ -737,6 +758,8 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else if (OpHaveRole(*node, OpRole::kDist)) {
     int op_dev_id = CreateDistTrainOp(result, node);
     if (node->Op()->Type() == "concat") {
+      // the inputs of concat (blocks of the parameter) live on different
+      // devices, while its output (the whole parameter) ends up on one device.
       auto origin_param_name = node->Op()->OutputArgumentNames()[0];
       bcast_var_name_set_[op_dev_id].emplace(origin_param_name);
     }
@@ -744,6 +767,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
   } else {
     int op_dev_id = GetOpDeviceID(node);
     if (op_dev_id != -1) {  // This op only runs on one specific device.
+      // optimize ops are handled here.
      CreateComputationalOp(result, node, op_dev_id);
       for (ir::Node *n : node->outputs) {
         sharded_var_device_.emplace(n->Name(), op_dev_id);
@@ -905,6 +929,7 @@ int DistSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
 void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
                                              const std::string &p_name,
                                              const std::string &g_name) const {
+  // insert a collective op to merge this gradient across devices
   size_t cur_device_id = 0;
   switch (strategy_.reduce_) {
     case BuildStrategy::ReduceStrategy::kReduce:
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index e91397816c3a983e1a9d211ea9368d78e2bc89f5..377ba50fccf4abe2ee7c894d64b4728387efcbe4 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -68,10 +68,10 @@ class MultiDevSSAGraphBuilderBase : public ir::Pass {
                              proto::VarType::Type dtype) const;
 
   VarHandle *CreateReduceOp(ir::Graph *result, const std::string &og,
-                            int dst_dev_id) const;
+                            size_t dst_dev_id) const;
 
   void CreateComputationalOp(ir::Graph *result, ir::Node *node,
-                             int dev_id) const;
+                             size_t dev_id) const;
 
   bool IsSparseGradient(const std::string &og) const;
 
@@ -118,16 +118,16 @@ class AllReduceSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
 
 class AsyncSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
  protected:
-  virtual void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
-                                  const std::string &g_name) const {}
+  void InsertCollectiveOp(ir::Graph *result, const std::string &p_name,
+                          const std::string &g_name) const override {}
 
   bool NeedCollectiveOps() const override { return false; }
 
-  virtual bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const {
+  bool DealWithSpecialOp(ir::Graph *result, ir::Node *node) const override {
     return false;
   }
 
-  virtual void InsertPostprocessOps(ir::Graph *result) const {}
+  void InsertPostprocessOps(ir::Graph *result) const override {}
 };
 
 class BalanceVarSSAGraphBuilder : public MultiDevSSAGraphBuilderBase {
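For readers skimming the patch, the core behavioral change can be summarized with a small standalone sketch. This is not Paddle code: GraphBuilderBase and AsyncGraphBuilder below are invented stand-ins for MultiDevSSAGraphBuilderBase and AsyncSSAGraphBuilder, but the control flow mirrors the diff above. The base builder only inserts gradient-merge (collective) ops when NeedCollectiveOps() says so, and the async builder answers false, so each replica keeps and sends its own gradients.

#include <iostream>
#include <string>
#include <vector>

// Base builder: walks the gradient names and inserts gradient-merge
// (collective) ops only when the concrete builder asks for them.
class GraphBuilderBase {
 public:
  virtual ~GraphBuilderBase() = default;

  void Build(const std::vector<std::string> &grad_names) const {
    bool insert_collective_ops = NeedCollectiveOps();
    for (const auto &g : grad_names) {
      if (insert_collective_ops) {
        InsertCollectiveOp(g);  // merge this gradient across devices
      } else {
        std::cout << "keep " << g << " local; each replica sends it itself\n";
      }
    }
  }

 protected:
  virtual bool NeedCollectiveOps() const { return true; }
  virtual void InsertCollectiveOp(const std::string &g) const {
    std::cout << "insert all-reduce/reduce for " << g << "\n";
  }
};

// Async builder: never merges gradients, mirroring the patch's
// AsyncSSAGraphBuilder, whose NeedCollectiveOps() returns false and whose
// InsertCollectiveOp() is a no-op.
class AsyncGraphBuilder : public GraphBuilderBase {
 protected:
  bool NeedCollectiveOps() const override { return false; }
  void InsertCollectiveOp(const std::string &) const override {}
};

int main() {
  AsyncGraphBuilder async_builder;
  async_builder.Build({"fc_1.b_0@GRAD", "fc_1.w_0@GRAD"});
  return 0;
}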