未验证 提交 0a144ca1 编写于 作者: L Leo Chen 提交者: GitHub

convert grad_merge_all_reduce in graph to program (#46353)

上级 173b39bb
...@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle { ...@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
std::string Name() const override; std::string Name() const override;
// Accessor for the name of the gradient-merge condition variable held by
// this handle. Returns the stored name by value.
std::string GradMergeCondName() {
  return grad_merge_cond_name_;
}
protected: protected:
void RunImpl() override; void RunImpl() override;
......
...@@ -19,7 +19,8 @@ cc_library( ...@@ -19,7 +19,8 @@ cc_library(
cc_library( cc_library(
graph_helper graph_helper
SRCS graph_helper.cc SRCS graph_helper.cc
DEPS graph program_utils scale_loss_grad_op_handle) DEPS graph program_utils scale_loss_grad_op_handle
grad_merge_all_reduce_op_handle)
cc_library( cc_library(
pass pass
SRCS pass.cc SRCS pass.cc
......
...@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
all_persistable = false; all_persistable = false;
} }
} }
VLOG(4) << "all_persistable:" << all_persistable;
if (all_persistable) { VLOG(4) << "any_persistable:" << all_persistable;
// All grads are persistable, only need to be executed once at the // NOTE. In scope_buffered_ssa_graph_executor, after each execution of
// beginning. // DropScope(), non persistable vars will be Erase or Clear. So
result->Get<details::ProgramDescs>(details::kStartupProgramDescs) // coalesce_tensor op needs to be executed again after the execution
.emplace_back(); // of DropScope().
ProgramDesc &program_desc =
result->Get<details::ProgramDescs>(details::kStartupProgramDescs) // We can make fused_output persistable, so the memory is not cleared
.back(); // and the coalesce_tensor op does nothing if the inputs are already contiguous.
auto *global_block = program_desc.MutableBlock(0);
AppendAllocSpaceForVarsOp(params_name, result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
grads_name, ProgramDesc &program_desc =
fused_var_name, result->Get<details::ProgramDescs>(details::kProgramDescs).back();
dtype, auto *global_block = program_desc.MutableBlock(0);
all_persistable, AppendAllocSpaceForVarsOp(params_name,
global_block); grads_name,
} else { fused_var_name,
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of dtype,
// DropScope(), non persistable vars will be Erase or Clear. So any_persistable,
// coalesce_tensor op needs to be executed again after the execution global_block);
// of DropScope().
result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
ProgramDesc &program_desc =
result->Get<details::ProgramDescs>(details::kProgramDescs).back();
auto *global_block = program_desc.MutableBlock(0);
AppendAllocSpaceForVarsOp(params_name,
grads_name,
fused_var_name,
dtype,
any_persistable,
global_block);
}
} }
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name, void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
...@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass { ...@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
const proto::VarType::Type &dtype, const proto::VarType::Type &dtype,
bool persistable, bool persistable,
BlockDesc *global_block) const { BlockDesc *global_block) const {
auto fused_out_var = global_block->Var(fused_var_name);
fused_out_var->SetPersistable(persistable);
auto op_desc = global_block->AppendOp(); auto op_desc = global_block->AppendOp();
op_desc->SetType("coalesce_tensor"); op_desc->SetType("coalesce_tensor");
op_desc->SetInput("Input", params_name); op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name); op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name}); op_desc->SetOutput("FusedOutput", {fused_var_name});
op_desc->SetAttr("dtype", static_cast<int>(dtype)); op_desc->SetAttr("dtype", static_cast<int>(dtype));
op_desc->SetAttr("persist_output", persistable); op_desc->SetAttr("persist_output", persistable);
} }
}; };
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <queue> #include <queue>
#include <stack> #include <stack>
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass.h"
...@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node, ...@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetType("c_allreduce_sum"); desc2.SetType("c_allreduce_sum");
if (node.IsWrappedBy<details::OpHandleBase>()) { if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander = details::OpHandleBase &op_handler =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>(); const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
// set inputs // set inputs
auto in_var_handles = op_hander.Inputs(); auto in_var_handles = op_handler.Inputs();
std::vector<std::string> in_names; std::vector<std::string> in_names;
for (const auto &in : in_var_handles) { for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) { if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
...@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node, ...@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetInput("X", {name}); desc2.SetInput("X", {name});
// set outputs // set outputs
auto out_var_handles = op_hander.Outputs(); auto out_var_handles = op_handler.Outputs();
std::vector<std::string> out_names; std::vector<std::string> out_names;
for (const auto &out : out_var_handles) { for (const auto &out : out_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) { if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
...@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node, ...@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetOutput("Out", {name}); desc2.SetOutput("Out", {name});
int ring_id = platform::NCCLCommContext::Instance().GetRingId( int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_hander)->GetComm()); dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
desc2.SetAttr("ring_id", ring_id); desc2.SetAttr("ring_id", ring_id);
desc2.SetAttr("use_calc_stream", true); desc2.SetAttr("use_calc_stream", true);
// handle grad merge
if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
auto cond_name =
dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
->GradMergeCondName();
desc2.SetInput("Cond", {cond_name});
}
} }
desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
...@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph, ...@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
VLOG(8) << "Merge main programs"; VLOG(8) << "Merge main programs";
MergePrograms(program, program_descs, /*append=*/false); MergePrograms(program, program_descs, /*append=*/false);
} }
// handle startup program
} }
static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies( static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
......
...@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst, ...@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
if (dst_block->FindVar(src_new_var->Name())) continue; if (dst_block->FindVar(src_new_var->Name())) continue;
auto *dst_new_var = dst_block->Var(src_new_var->Name()); auto *dst_new_var = dst_block->Var(src_new_var->Name());
*dst_new_var = *src_new_var; *dst_new_var = *src_new_var;
VLOG(10) << "Create new variable " << dst_new_var->Name(); VLOG(10) << "Create new variable " << dst_new_var->Name()
<< ", persistable:" << dst_new_var->Persistable();
} }
}; };
VisitAllElements(srcs, create_var_visitor, reverse); VisitAllElements(srcs, create_var_visitor, reverse);
auto create_op_visitor = [dst, reverse](const ProgramDesc &src) { auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
......
...@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel { ...@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
// Chooses the kernel type associated with a single input variable.
//
// var_name:             name of the input slot being inspected.
// tensor:               the tensor bound to that slot.
// expected_kernel_type: kernel type already selected for this op.
//
// For the "Cond" input the expected kernel type is handed back untouched.
// NOTE(review): presumably this keeps the framework from moving `Cond` to
// the device place -- confirm against the data-transform logic.
// Every other input keeps the expected data type but reports the tensor's
// own place and layout.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name,
    const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const {
  if (var_name != "Cond") {
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
  return expected_kernel_type;
}
}; };
template <ReduceType red_type, typename T> template <ReduceType red_type, typename T>
...@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> { ...@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO) #if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
...@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL) #if defined(PADDLE_WITH_ASCEND_CL)
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
...@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> { ...@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL) #if defined(PADDLE_WITH_XPU_BKCL)
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
...@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T> ...@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
class CAllReduceOpCUDAKernel : public framework::OpKernel<T> { class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
...@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> { ...@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
auto in = ctx.Input<framework::Tensor>("X"); auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out"); auto out = ctx.Output<framework::Tensor>("Out");
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
cnclDataType_t dtype = cnclDataType_t dtype =
platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype())); platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype()));
...@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us ...@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
)DOC", )DOC",
GetName(), GetName(),
GetName())); GetName()));
ExtraMake();
} }
protected: protected:
virtual std::string GetName() const = 0; virtual std::string GetName() const = 0;
virtual void ExtraMake() {}
}; };
} // namespace operators } // namespace operators
......
...@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> { ...@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
class CAllReduceSumOpMaker : public CAllReduceOpMaker { class CAllReduceSumOpMaker : public CAllReduceOpMaker {
protected: protected:
// Declares the extra "Cond" input for this op maker. The input is marked
// dispensable, so an op description may omit it.
void ExtraMake() override {
AddInput("Cond", "(Tensor), whether to do all reduce or not.")
.AsDispensable();
}
std::string GetName() const override { return "Sum"; } std::string GetName() const override { return "Sum"; }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册