Unverified commit ca552933 authored by zhangbo9674 and committed by GitHub

Add fuse_act_add_grad_pass (#48346)

* add fuse act add grad pass

* polish code

* refine code

* add test

* refine code
Parent e337d280
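For orientation before the diff: the new pass extends BuildStrategy's fuse_elewise_add_act_ops option so that, for a forward graph of the form elementwise_add(y, act(x)), the backward pair elementwise_add_grad + act_grad is rewritten into a single fused_elemwise_add_activation_grad op. A minimal sketch of how the option is switched on, modeled directly on the unit test added at the end of this diff (program layout and variable names are illustrative):

import numpy
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_program, startup_program = fluid.Program(), fluid.Program()
with paddle.static.program_guard(main_program, startup_program):
    X = fluid.data(name="X", shape=[3, 3], dtype='float32')
    Y = fluid.data(name="Y", shape=[3, 3], dtype='float32')
    # the forward pattern targeted by the new grad fusion: add(Y, relu(...))
    out = fluid.layers.elementwise_add(Y, fluid.layers.relu(X * 5), axis=1)
    loss = paddle.mean(out)
    fluid.optimizer.SGD(learning_rate=0.001).minimize(loss)

# the backward fusion is applied while building the compiled program
build_strategy = fluid.BuildStrategy()
build_strategy.fuse_elewise_add_act_ops = True
compiled = paddle.static.CompiledProgram(main_program, build_strategy=build_strategy)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
x = numpy.random.random(size=(3, 3)).astype('float32')
y = numpy.random.random(size=(3, 3)).astype('float32')
loss_val, = exe.run(compiled, feed={"X": x, "Y": y}, fetch_list=[loss.name])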
@@ -31,6 +31,7 @@ void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
{
std::unordered_set<std::string> in_place_act_types = {"relu_grad"};
graph = FuseElewiseAddActInplaceGrad(graph, in_place_act_types);
graph = FuseActElewiseAddInplaceGrad(graph, in_place_act_types);
}
// Remove the removable intermediate_out.
@@ -110,7 +111,7 @@ ir::Graph *FuseElewiseAddActPass::FuseActElewiseAdd(
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle FuseActElewiseAdd fuse";
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(ele_x, ele_x, act_elewise_add_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
@@ -220,6 +221,86 @@ ir::Graph *FuseElewiseAddActPass::FuseElewiseAddActInplaceGrad(
return graph;
}
// the backward of ele_add(act(x), y)
// act_grad: in["Out", "Out@GRAD"], out["X@GRAD"]
// ele_add_grad: in["Y", "Out@GRAD"], out["X@GRAD", "Y@GRAD"]
ir::Graph *FuseElewiseAddActPass::FuseActElewiseAddInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("act_elewise_add_grad", graph);
GraphPatternDetector gpd;
auto *d_out_var =
gpd.mutable_pattern()
->NewNode("act_elewise_add_grad_inplace/d_out_var")
->AsInput()
->assert_is_ops_input({"elementwise_add_grad"}, GradVarName("Out"));
patterns::ActElewiseAddInplaceGrad act_elewise_add_grad_pattern(
gpd.mutable_pattern(), "act_elewise_add_grad_inplace");
act_elewise_add_grad_pattern(d_out_var, act_types);
int found_elewise_add_act_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle ActFuseElewiseAddGrad1 fuse";
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_grad_op, ele_add_grad_op, act_elewise_add_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
act_grad_op, act_grad_op, act_elewise_add_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
intermediate_var, intermediate_var, act_elewise_add_grad_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
d_intermediate_var, d_intermediate_var, act_elewise_add_grad_pattern);
std::string d_out_var_n = subgraph.at(d_out_var)->Name();
std::string intermediate_var_n = intermediate_var->Name();
std::string d_intermediate_var_n = d_intermediate_var->Name();
OpDesc desc;
desc.SetType("fused_elemwise_add_activation_grad");
desc.SetInput("IntermediateOut",
std::vector<std::string>({intermediate_var_n}));
desc.SetInput("X", {});
desc.SetInput("Y", ele_add_grad_op->Op()->Input("X"));
desc.SetInput("Out", {});
desc.SetInput(GradVarName("Out"), std::vector<std::string>({d_out_var_n}));
desc.SetOutput(GradVarName("X"),
act_grad_op->Op()->Output(GradVarName("X")));
desc.SetOutput(GradVarName("Y"),
ele_add_grad_op->Op()->Output(GradVarName("X")));
desc.SetOutput(GradVarName("IntermediateOut"),
std::vector<std::string>({d_intermediate_var_n}));
desc.SetAttr("save_intermediate_out", false);
desc.SetAttr("functor_list",
std::vector<std::string>({ele_add_grad_op->Op()->Type(),
act_grad_op->Op()->Type()}));
for (auto &n : {ele_add_grad_op->Op(), act_grad_op->Op()}) {
for (auto &m_ele : n->GetAttrMap()) {
desc.SetAttr(m_ele.first, m_ele.second);
}
}
auto fused_node = g->CreateOpNode(&desc);
VLOG(4) << "\n\t " << d_out_var_n << " -> " << ele_add_grad_op->Name()
<< " -> " << d_intermediate_var_n << "\n\t " << intermediate_var_n
<< " and " << d_intermediate_var_n << " -> " << act_grad_op->Name();
ReLinkNodes2(
g, d_intermediate_var, ele_add_grad_op, act_grad_op, fused_node);
found_elewise_add_act_count++;
};
gpd(graph, handler);
AddStatis(found_elewise_add_act_count);
return graph;
}
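To restate the handler above in data form, the OpDesc it assembles looks roughly like the following (a descriptive sketch rather than a Paddle API; the quoted variable names are placeholders for whatever nodes the matched subgraph actually carries):

# Placeholder names ("act_out", "y", "d_out", ...) stand in for the matched nodes;
# the slot assignments mirror the SetInput/SetOutput calls in FuseActElewiseAddInplaceGrad.
fused_grad_op = {
    "type": "fused_elemwise_add_activation_grad",
    "inputs": {
        "IntermediateOut": ["act_out"],   # intermediate_var: the act output, i.e. the add's "Y"
        "X": [],                          # intentionally left empty
        "Y": ["y"],                       # copied from elementwise_add_grad's Input("X")
        "Out": [],                        # intentionally left empty
        "Out@GRAD": ["d_out"],
    },
    "outputs": {
        "X@GRAD": ["d_x"],                # act_grad's Output("X@GRAD")
        "Y@GRAD": ["d_y"],                # elementwise_add_grad's Output("X@GRAD")
        "IntermediateOut@GRAD": ["d_act_out"],
    },
    "attrs": {
        "save_intermediate_out": False,
        "functor_list": ["elementwise_add_grad", "relu_grad"],
        # plus the attributes copied over from both original grad ops
    },
}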
Node *FuseElewiseAddActPass::CreateFuseElewiseAddActNode(
Graph *g,
const Node *op_1,
@@ -364,6 +445,52 @@ void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
GraphSafeRemoveNodes(graph, nodes2delete);
}
void FuseElewiseAddActPass::ReLinkNodes2(Graph *graph,
const Node *intermediate_out,
Node *op_1,
Node *op_2,
Node *fused_op) const { // delete act
for (auto &in : op_1->inputs) {
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_1, fused_op, in->outputs);
}
std::unordered_set<const Node *> nodes2delete;
for (auto &out : op_1->outputs) {
if (out->IsCtrlVar()) {
auto result_iter = std::find_if(
op_2->inputs.begin(),
op_2->inputs.end(),
[&out](const Node *node) -> bool { return node == out; });
if (result_iter == op_2->inputs.end()) {
IR_OP_VAR_LINK(fused_op, out);
} else {
nodes2delete.emplace(out);
}
} else {
IR_OP_VAR_LINK(fused_op, out);
}
}
for (auto &in : op_2->inputs) {
if (in == intermediate_out || nodes2delete.count(in)) {
continue;
}
fused_op->inputs.emplace_back(in);
in->outputs = this->ReplaceNode(op_2, fused_op, in->outputs);
}
for (auto &out : op_2->outputs) {
IR_OP_VAR_LINK(fused_op, out);
}
nodes2delete.insert(std::move(op_1));
nodes2delete.insert(std::move(op_2));
GraphSafeRemoveNodes(graph, nodes2delete);
}
std::vector<Node *> FuseElewiseAddActPass::ReplaceNode(
Node *cur_node, Node *new_node, const std::vector<Node *> &nodes) const {
std::vector<Node *> new_list(nodes.size());
...
@@ -49,6 +49,9 @@ class FuseElewiseAddActPass : public FusePassBase {
ir::Graph *FuseElewiseAddActInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
ir::Graph *FuseActElewiseAddInplaceGrad(
ir::Graph *graph, const std::unordered_set<std::string> &act_types) const;
/**
* Remove the removable intermediate_out.
* - If the intermediate_out is only used by the backward op, but the
@@ -69,6 +72,11 @@
Node *op_1,
Node *op_2,
Node *fused_op) const;
void ReLinkNodes2(Graph *graph,
const Node *intermediate_out,
Node *op_1,
Node *op_2,
Node *fused_op) const;
Node *CreateFuseElewiseAddActNode(Graph *g,
const Node *op_1,
const Node *op_2,
...
@@ -91,7 +91,6 @@ void GraphPatternDetector::operator()(Graph *graph,
if (!MarkPDNodesInGraph(*graph)) {
return;
}
auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs);
SortSubgraphs(&subgraphs);
@@ -99,7 +98,6 @@ void GraphPatternDetector::operator()(Graph *graph,
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
int id = 0;
for (auto &g : subgraphs) {
VLOG(3) << "optimizing #" << id++ << " subgraph";
@@ -1613,6 +1611,33 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
return ele_add_grad;
}
PDNode *patterns::ActElewiseAddInplaceGrad::operator()(
paddle::framework::ir::PDNode *d_out_var,
std::unordered_set<std::string> act_types) {
VLOG(4) << "ActElewiseAddInplaceGrad::operator";
auto *ele_add_grad_op = pattern->NewNode(ele_add_grad_op_repr())
->assert_is_op("elementwise_add_grad");
auto *act_grad_op =
pattern->NewNode(act_grad_op_repr())->assert_is_ops(act_types);
auto *d_intermediate_out_var =
pattern->NewNode(d_intermediate_var_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"))
->assert_is_ops_input(act_types, GradVarName("Out"));
auto *intermediate_out_var =
pattern->NewNode(intermediate_var_repr())
->assert_is_op_input("elementwise_add_grad", "Y")
->assert_is_ops_input(act_types, "Out");
ele_add_grad_op->LinksFrom({d_out_var});
d_intermediate_out_var->LinksFrom({ele_add_grad_op}).LinksTo({act_grad_op});
intermediate_out_var->LinksTo({ele_add_grad_op});
intermediate_out_var->LinksTo({act_grad_op});
return act_grad_op;
}
PDNode *patterns::ElewiseAddAct::operator()(
paddle::framework::ir::PDNode *ele_x_var,
std::unordered_set<std::string> act_types) {
...
@@ -928,6 +928,27 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
PATTERN_DECL_NODE(ele_y);
};
// the backward of ele_add(act(x), y)
// the act is inplace.
// op: elementwise_add_grad + act_grad
// named nodes: elementwise_add_grad, act_grad
// ele_y, d_ele_y, d_intermediate_out, intermediate_out, d_x
struct ActElewiseAddInplaceGrad : public PatternBase {
ActElewiseAddInplaceGrad(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "act_elewise_add_grad1") {}
// ele_add_grad: in["Y", "Out@GRAD"], out["IntermediateOut@GRAD", "Y@GRAD"]
// act_grad: in["IntermediateOut", "IntermediateOut@GRAD"], out["X@GRAD"]
PDNode* operator()(PDNode* d_out_var, std::unordered_set<std::string> acts);
// declare operator node's name
PATTERN_DECL_NODE(ele_add_grad_op);
PATTERN_DECL_NODE(act_grad_op);
// declare variable node's name
PATTERN_DECL_NODE(intermediate_var);
PATTERN_DECL_NODE(d_intermediate_var);
};
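A side note on the in-place restriction (ApplyImpl above only admits relu_grad into in_place_act_types, and the comment here stresses that the act is in-place): relu's gradient can be recovered from its own output alone, so the activation's input never needs to be kept for the backward pass. A small NumPy sketch of that identity, independent of Paddle:

import numpy as np

x = np.random.randn(3, 3).astype("float32")
out = np.maximum(x, 0.0)          # relu forward
d_out = np.ones_like(out)

# relu_grad expressed purely through the forward output ...
d_x_from_out = d_out * (out > 0).astype("float32")
# ... matches the gradient computed from the original input
d_x_from_x = d_out * (x > 0).astype("float32")
assert np.allclose(d_x_from_out, d_x_from_x)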
// The following patterns are used to fuse linear and act (ReLu or GeLU)
// formula: act(F.linear(x))
// op: matmul_v2 + elementwise_add + act
...
@@ -462,7 +462,6 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
for (auto& var_name_item : *ins_map_temp) {
bool should_skip_input =
no_buffer_ins && no_buffer_ins->count(var_name_item.first) > 0;
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i];
auto var_name = new_ins[var_name_item.first].at(i);
...
@@ -664,11 +664,9 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
in_y,
nullptr,
platform::errors::InvalidArgument("Input(Y) should not be nullptr."));
phi::DenseTensor *in_out =
const_cast<phi::DenseTensor *>(ctx.Input<phi::DenseTensor>("Out"));
auto in_out_grad =
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_NE(in_out_grad,
@@ -726,6 +724,23 @@ class FusedElemwiseActivationGradKernel : public framework::OpKernel<T> {
in_x = const_cast<phi::DenseTensor *>(in_out_grad);
}
// Get in_Out
if (ctx.HasInput("Out")) {
PADDLE_ENFORCE_NE(
in_out,
nullptr,
platform::errors::InvalidArgument("Input(X) should not be null."));
} else {
// If functor_list contains elementwise_add, the backward doesn't use
// in_x, in_y and in_out.
PADDLE_ENFORCE_EQ(InputXCanBeAbsent(functor_list),
true,
platform::errors::InvalidArgument(
"'Out' can only be absent when the compound functor "
"contains elementwise_add_grad."));
in_out = const_cast<phi::DenseTensor *>(in_out_grad);
}
bool has_in_place = HasInPlaceUnary(functor_list);
if (has_in_place) {
RunGradFunctors<DeviceContext, T, true /*InPlace*/>(ctx,
...
@@ -14,10 +14,11 @@
import os
import unittest
import numpy
from parallel_executor_test_base import DeviceType, TestParallelExecutorBase
from simple_nets import fc_with_batchnorm, init_data, simple_fc_net
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@@ -89,8 +90,72 @@ class TestMNIST(TestParallelExecutorBase):
)
class TestFuseActElewiseAddInplaceGradPass(unittest.TestCase):
    def build_program(self, main_program, startup_program):
        with paddle.static.program_guard(main_program, startup_program):
            X = fluid.data(name="X", shape=[3, 3], dtype='float32')
            Y = fluid.data(name="Y", shape=[3, 3], dtype='float32')
            Out1 = X * 5
            Out2 = fluid.layers.relu(Out1)
            prediction = fluid.layers.elementwise_add(Y, Out2, axis=1)
            loss = paddle.mean(prediction)
            sgd = fluid.optimizer.SGD(learning_rate=0.001)
            sgd.minimize(loss)
        return X, Y, loss

    def check(self, place):
        paddle.seed(1)
        numpy.random.seed(1)
        paddle.framework.random._manual_program_seed(1)
        main_program = fluid.Program()
        startup_program = fluid.Program()
        X, Y, loss = self.build_program(main_program, startup_program)
        exe = fluid.Executor(place)
        x = numpy.random.random(size=(3, 3)).astype('float32')
        y = numpy.random.random(size=(3, 3)).astype('float32')
        label = numpy.random.random(size=(3, 3)).astype('float32')

        # run with the fuse_elewise_add_act_ops pass enabled
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_elewise_add_act_ops = True
        compiled_prog_fused = paddle.static.CompiledProgram(
            main_program, build_strategy=build_strategy
        )
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            loss_data_fused = exe.run(
                compiled_prog_fused,
                feed={"X": x, "Y": y},
                fetch_list=[loss.name],
            )

        # run again with the pass disabled
        build_strategy = fluid.BuildStrategy()
        build_strategy.fuse_elewise_add_act_ops = False
        compiled_prog = paddle.static.CompiledProgram(
            main_program, build_strategy=build_strategy
        )
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)
            loss_data = exe.run(
                compiled_prog, feed={"X": x, "Y": y}, fetch_list=[loss.name]
            )

        self.assertEqual(loss_data_fused, loss_data)

    def test_fuse_act_add_grad_pass_cpu(self):
        place = fluid.CPUPlace()
        self.check(place)

    def test_fuse_act_add_grad_pass_cuda(self):
        if fluid.core.is_compiled_with_cuda():
            place = fluid.CUDAPlace(0)
            self.check(place)


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()