未验证 提交 0a144ca1 编写于 作者: Leo Chen 提交者: GitHub

convert grad_merge_all_reduce in graph to program (#46353)

上级 173b39bb
......@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
std::string Name() const override;
// Returns the name of the condition variable that gates gradient-merge
// all-reduce (presumably a bool scalar wired in as the "Cond" input of
// c_allreduce_sum — the kernels skip the collective when it is false).
std::string GradMergeCondName() { return grad_merge_cond_name_; }
protected:
void RunImpl() override;
......
......@@ -19,7 +19,8 @@ cc_library(
cc_library(
graph_helper
SRCS graph_helper.cc
DEPS graph program_utils scale_loss_grad_op_handle)
DEPS graph program_utils scale_loss_grad_op_handle
grad_merge_all_reduce_op_handle)
cc_library(
pass
SRCS pass.cc
......
......@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
all_persistable = false;
}
}
if (all_persistable) {
// All grads are persistable, only need to be executed once at the
// beginning.
result->Get<details::ProgramDescs>(details::kStartupProgramDescs)
.emplace_back();
ProgramDesc &program_desc =
result->Get<details::ProgramDescs>(details::kStartupProgramDescs)
.back();
auto *global_block = program_desc.MutableBlock(0);
AppendAllocSpaceForVarsOp(params_name,
grads_name,
fused_var_name,
dtype,
all_persistable,
global_block);
} else {
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
// DropScope(), non persistable vars will be Erase or Clear. So
// coalesce_tensor op needs to be executed again after the execution
// of DropScope().
result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
ProgramDesc &program_desc =
result->Get<details::ProgramDescs>(details::kProgramDescs).back();
auto *global_block = program_desc.MutableBlock(0);
AppendAllocSpaceForVarsOp(params_name,
grads_name,
fused_var_name,
dtype,
any_persistable,
global_block);
}
VLOG(4) << "all_persistable:" << all_persistable;
// Fix: log the value of any_persistable under the "any_persistable" label
// (previously this line printed all_persistable twice).
VLOG(4) << "any_persistable:" << any_persistable;
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
// DropScope(), non persistable vars will be Erase or Clear. So
// coalesce_tensor op needs to be executed again after the execution
// of DropScope().
// We can make fused_output persistable, so the memory is not cleared
// and coalesce_tensor op does nothing if the inputs are already contiguous.
result->Get<details::ProgramDescs>(details::kProgramDescs).emplace_back();
ProgramDesc &program_desc =
    result->Get<details::ProgramDescs>(details::kProgramDescs).back();
auto *global_block = program_desc.MutableBlock(0);
// Mark the fused output persistable whenever any grad is persistable, so
// the fused buffer survives DropScope().
AppendAllocSpaceForVarsOp(params_name,
                          grads_name,
                          fused_var_name,
                          dtype,
                          any_persistable,
                          global_block);
}
void AppendAllocSpaceForVarsOp(const std::vector<std::string> &params_name,
......@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
const proto::VarType::Type &dtype,
bool persistable,
BlockDesc *global_block) const {
auto fused_out_var = global_block->Var(fused_var_name);
fused_out_var->SetPersistable(persistable);
auto op_desc = global_block->AppendOp();
op_desc->SetType("coalesce_tensor");
op_desc->SetInput("Input", params_name);
op_desc->SetOutput("Output", grads_name);
op_desc->SetOutput("FusedOutput", {fused_var_name});
op_desc->SetAttr("dtype", static_cast<int>(dtype));
op_desc->SetAttr("persist_output", persistable);
}
};
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <queue>
#include <stack>
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/pass.h"
......@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetType("c_allreduce_sum");
if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
details::OpHandleBase &op_handler =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
// set inputs
auto in_var_handles = op_hander.Inputs();
auto in_var_handles = op_handler.Inputs();
std::vector<std::string> in_names;
for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
......@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetInput("X", {name});
// set outputs
auto out_var_handles = op_hander.Outputs();
auto out_var_handles = op_handler.Outputs();
std::vector<std::string> out_names;
for (const auto &out : out_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
......@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
desc2.SetOutput("Out", {name});
int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_hander)->GetComm());
dynamic_cast<details::NCCLOpHandleBase *>(&op_handler)->GetComm());
desc2.SetAttr("ring_id", ring_id);
desc2.SetAttr("use_calc_stream", true);
// handle grad merge
if (dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)) {
VLOG(4) << "FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum";
auto cond_name =
dynamic_cast<details::FusedGradMergeAllReduceOpHandle *>(&op_handler)
->GradMergeCondName();
desc2.SetInput("Cond", {cond_name});
}
}
desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
......@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
VLOG(8) << "Merge main programs";
MergePrograms(program, program_descs, /*append=*/false);
}
// handle startup program
}
static std::vector<std::vector<ir::Node::Dep>> GetOpDependencies(
......
......@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
if (dst_block->FindVar(src_new_var->Name())) continue;
auto *dst_new_var = dst_block->Var(src_new_var->Name());
*dst_new_var = *src_new_var;
VLOG(10) << "Create new variable " << dst_new_var->Name();
VLOG(10) << "Create new variable " << dst_new_var->Name()
<< ", persistable:" << dst_new_var->Persistable();
}
};
VisitAllElements(srcs, create_var_visitor, reverse);
auto create_op_visitor = [dst, reverse](const ProgramDesc &src) {
......
......@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
// Per-variable kernel-type dispatch for c_allreduce inputs.
// The dispensable "Cond" input is a bool scalar that the kernels require to
// live on CPU (they PADDLE_ENFORCE is_cpu_place and read it on the host), so
// it must be exempt from the usual place/layout transform: returning the
// expected kernel type unchanged tells the framework not to transfer it.
// All other inputs keep their own place and layout so no unnecessary data
// transform is inserted before the collective.
// Fix: mark the method `override` — it overrides the virtual
// OperatorWithKernel::GetKernelTypeForVar, and `override` lets the compiler
// catch any future signature drift.
framework::OpKernelType GetKernelTypeForVar(
    const std::string& var_name,
    const framework::Tensor& tensor,
    const framework::OpKernelType& expected_kernel_type) const override {
  if (var_name == "Cond") {
    // Keep "Cond" untouched (stays on CPU); skip data transform.
    return expected_kernel_type;
  }
  return framework::OpKernelType(
      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
};
template <ReduceType red_type, typename T>
......@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_GLOO)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
......@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_ASCEND_CL)
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
auto place = ctx.GetPlace();
......@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_XPU_BKCL)
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
......@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
......@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
if (ctx.HasInput("Cond")) {
auto cond = ctx.Input<framework::Tensor>("Cond");
auto place = cond->place();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(place),
true,
platform::errors::PreconditionNotMet(
"The input `cond` tensor should be on cpu place"));
PADDLE_ENFORCE_EQ(cond->numel(),
1,
platform::errors::PreconditionNotMet(
"The input `cond` should be shape [1]"));
if (!cond->data<bool>()[0]) {
VLOG(4) << "Skip all reduce Op since cond is 0";
return;
}
}
auto place = ctx.GetPlace();
cnclDataType_t dtype =
platform::ToCNCLDataType(framework::TransToProtoVarType(in->dtype()));
......@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
)DOC",
GetName(),
GetName()));
ExtraMake();
}
protected:
virtual std::string GetName() const = 0;
virtual void ExtraMake() {}
};
} // namespace operators
......
......@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
// Proto maker for c_allreduce_sum: the common inputs/outputs/attrs come from
// the shared CAllReduceOpMaker base; this subclass supplies only the
// sum-specific pieces through the two protected hooks below.
class CAllReduceSumOpMaker : public CAllReduceOpMaker {
 protected:
  // Reduction name substituted into the base maker's documentation template.
  std::string GetName() const override { return "Sum"; }

  // Extra proto entries beyond the common all-reduce interface: an optional
  // gate tensor that lets the kernel skip the collective entirely.
  void ExtraMake() override {
    AddInput("Cond", "(Tensor), whether to do all reduce or not.")
        .AsDispensable();
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册