diff --git a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h index 9cfc3ada6ac3d73b2e505b2513af4bb7a0195b99..ce01f85eaba52a5035a905b5463ffaddbe540cff 100644 --- a/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h @@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle { std::string Name() const override; + std::string GradMergeCondName() { return grad_merge_cond_name_; } + protected: void RunImpl() override; diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a58434eed61ba973fd0b996209c2e49f49a0ac1c..455417b521d9ad871ef316109fc0f6b7b85f8f3c 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -19,7 +19,8 @@ cc_library( cc_library( graph_helper SRCS graph_helper.cc - DEPS graph program_utils scale_loss_grad_op_handle) + DEPS graph program_utils scale_loss_grad_op_handle + grad_merge_all_reduce_op_handle) cc_library( pass SRCS pass.cc diff --git a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc index 0cfc77728f94227ba9318a9ca2dcf827e6812422..a7a5a21b3291c4bf2e5406c92f37e1d2206865cf 100644 --- a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc +++ b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc @@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass { all_persistable = false; } } - - if (all_persistable) { - // All grads are persistable, only need to be executed once at the - // beginning. - result->Get(details::kStartupProgramDescs) - .emplace_back(); - ProgramDesc &program_desc = - result->Get(details::kStartupProgramDescs) - .back(); - auto *global_block = program_desc.MutableBlock(0); - AppendAllocSpaceForVarsOp(params_name, - grads_name, - fused_var_name, - dtype, - all_persistable, - global_block); - } else { - // NOTE. In scope_buffered_ssa_graph_executor, after each execution of - // DropScope(), non persistable vars will be Erase or Clear. So - // coalesce_tensor op needs to be executed again after the execution - // of DropScope(). - result->Get(details::kProgramDescs).emplace_back(); - ProgramDesc &program_desc = - result->Get(details::kProgramDescs).back(); - auto *global_block = program_desc.MutableBlock(0); - AppendAllocSpaceForVarsOp(params_name, - grads_name, - fused_var_name, - dtype, - any_persistable, - global_block); - } + VLOG(4) << "all_persistable:" << all_persistable; + VLOG(4) << "any_persistable:" << all_persistable; + // NOTE. In scope_buffered_ssa_graph_executor, after each execution of + // DropScope(), non persistable vars will be Erase or Clear. So + // coalesce_tensor op needs to be executed again after the execution + // of DropScope(). + + // we can make fused_output persistable, so the memeory is not cleared + // and coalesce_tensor op do nothing if the inputs are already continue. + + result->Get(details::kProgramDescs).emplace_back(); + ProgramDesc &program_desc = + result->Get(details::kProgramDescs).back(); + auto *global_block = program_desc.MutableBlock(0); + AppendAllocSpaceForVarsOp(params_name, + grads_name, + fused_var_name, + dtype, + any_persistable, + global_block); } void AppendAllocSpaceForVarsOp(const std::vector ¶ms_name, @@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass { const proto::VarType::Type &dtype, bool persistable, BlockDesc *global_block) const { + auto fused_out_var = global_block->Var(fused_var_name); + fused_out_var->SetPersistable(persistable); + auto op_desc = global_block->AppendOp(); op_desc->SetType("coalesce_tensor"); op_desc->SetInput("Input", params_name); op_desc->SetOutput("Output", grads_name); op_desc->SetOutput("FusedOutput", {fused_var_name}); op_desc->SetAttr("dtype", static_cast(dtype)); - op_desc->SetAttr("persist_output", persistable); } }; diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 51e49be5f96f469ae4e5db641592801906bb8850..41dabfbd36bc61b4e939dca3d93139a82586308a 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/ir/pass.h" @@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node, desc2.SetType("c_allreduce_sum"); if (node.IsWrappedBy()) { - details::OpHandleBase &op_hander = + details::OpHandleBase &op_handler = const_cast(&node)->Wrapper(); // set inputs - auto in_var_handles = op_hander.Inputs(); + auto in_var_handles = op_handler.Inputs(); std::vector in_names; for (const auto &in : in_var_handles) { if (dynamic_cast(in) != nullptr) { @@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node, desc2.SetInput("X", {name}); // set outputs - auto out_var_handles = op_hander.Outputs(); + auto out_var_handles = op_handler.Outputs(); std::vector out_names; for (const auto &out : out_var_handles) { if (dynamic_cast(out) != nullptr) { @@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node, desc2.SetOutput("Out", {name}); int ring_id = platform::NCCLCommContext::Instance().GetRingId( - dynamic_cast(&op_hander)->GetComm()); + dynamic_cast(&op_handler)->GetComm()); desc2.SetAttr("ring_id", ring_id); desc2.SetAttr("use_calc_stream", true); + + // handle grad merge + if (dynamic_cast(&op_handler)) { + VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum"; + auto cond_name = + dynamic_cast(&op_handler) + ->GradMergeCondName(); + desc2.SetInput("Cond", {cond_name}); + } } desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), @@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph, VLOG(8) << "Merge main programs"; MergePrograms(program, program_descs, /*append=*/false); } + // handle startup program } static std::vector> GetOpDependencies( diff --git a/paddle/fluid/framework/program_utils.cc b/paddle/fluid/framework/program_utils.cc index a32350a569db0a9f46daf88d6e001b8c239c5519..197c74ccac3d928211e59d3cce259f3b4a70ea3c 100644 --- a/paddle/fluid/framework/program_utils.cc +++ b/paddle/fluid/framework/program_utils.cc @@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst, if (dst_block->FindVar(src_new_var->Name())) continue; auto *dst_new_var = dst_block->Var(src_new_var->Name()); *dst_new_var = *src_new_var; - VLOG(10) << "Create new variable " << dst_new_var->Name(); + VLOG(10) << "Create new variable " << dst_new_var->Name() + << ", persistable:" << dst_new_var->Persistable(); } }; + VisitAllElements(srcs, create_var_visitor, reverse); auto create_op_visitor = [dst, reverse](const ProgramDesc &src) { diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 299dd59d5efa7202ab70a317d7929b391dcf03fd..ac9f5858e408caf76883c8899027d915378130ac 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel { return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); } + + framework::OpKernelType GetKernelTypeForVar( + const std::string& var_name, + const framework::Tensor& tensor, + const framework::OpKernelType& expected_kernel_type) const { + if (var_name == "Cond") { + return expected_kernel_type; + } else { + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); + } + } }; template @@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if defined(PADDLE_WITH_GLOO) + auto in = ctx.Input("X"); auto out = ctx.Output("Out"); @@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if defined(PADDLE_WITH_ASCEND_CL) + if (ctx.HasInput("Cond")) { + auto cond = ctx.Input("Cond"); + auto place = cond->place(); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(place), + true, + platform::errors::PreconditionNotMet( + "The input `cond` tensor should be on cpu place")); + PADDLE_ENFORCE_EQ(cond->numel(), + 1, + platform::errors::PreconditionNotMet( + "The input `cond` should be shape [1]")); + if (!cond->data()[0]) { + VLOG(4) << "Skip all reduce Op since cond is 0"; + return; + } + } + auto in = ctx.Input("X"); auto out = ctx.Output("Out"); auto place = ctx.GetPlace(); @@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if defined(PADDLE_WITH_XPU_BKCL) + if (ctx.HasInput("Cond")) { + auto cond = ctx.Input("Cond"); + auto place = cond->place(); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(place), + true, + platform::errors::PreconditionNotMet( + "The input `cond` tensor should be on cpu place")); + PADDLE_ENFORCE_EQ(cond->numel(), + 1, + platform::errors::PreconditionNotMet( + "The input `cond` should be shape [1]")); + if (!cond->data()[0]) { + VLOG(4) << "Skip all reduce Op since cond is 0"; + return; + } + } + auto in = ctx.Input("X"); auto out = ctx.Output("Out"); @@ -364,6 +411,23 @@ template class CAllReduceOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { + if (ctx.HasInput("Cond")) { + auto cond = ctx.Input("Cond"); + auto place = cond->place(); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(place), + true, + platform::errors::PreconditionNotMet( + "The input `cond` tensor should be on cpu place")); + PADDLE_ENFORCE_EQ(cond->numel(), + 1, + platform::errors::PreconditionNotMet( + "The input `cond` should be shape [1]")); + if (!cond->data()[0]) { + VLOG(4) << "Skip all reduce Op since cond is 0"; + return; + } + } + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) auto in = ctx.Input("X"); auto out = ctx.Output("Out"); @@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel { auto in = ctx.Input("X"); auto out = ctx.Output("Out"); + if (ctx.HasInput("Cond")) { + auto cond = ctx.Input("Cond"); + auto place = cond->place(); + PADDLE_ENFORCE_EQ(platform::is_cpu_place(place), + true, + platform::errors::PreconditionNotMet( + "The input `cond` tensor should be on cpu place")); + PADDLE_ENFORCE_EQ(cond->numel(), + 1, + platform::errors::PreconditionNotMet( + "The input `cond` should be shape [1]")); + if (!cond->data()[0]) { + VLOG(4) << "Skip all reduce Op since cond is 0"; + return; + } + } + auto place = ctx.GetPlace(); cnclDataType_t dtype = platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype())); @@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us )DOC", GetName(), GetName())); + ExtraMake(); } protected: virtual std::string GetName() const = 0; + virtual void ExtraMake() {} }; } // namespace operators diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc index 5aac406bd6e26c861049a2cebf9ead620ce3e927..b7831972872894500c0b7b6ed60d0e19d228b9ff 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.cc @@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker { class CAllReduceSumOpMaker : public CAllReduceOpMaker { protected: + void ExtraMake() override { + AddInput("Cond", "(Tensor), whether to do all reduce or not.") + .AsDispensable(); + } std::string GetName() const override { return "Sum"; } };