Unverified commit 5dda0ef6, authored by Ghost Screaming, committed by GitHub

Add fused_feed_forward pass (#50423)

* Add fused_feed_forward pass for semi-automatic static graph training.

* Add fused_feedforward property in parallel_executor.cc

* Polish code.

* Polish fused feed_forward pass code. Support the use_dropout1 and
use_dropout2 options.

* Support model parallel in fused_feedforward pass.
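
For context, a minimal usage sketch of the two entry points this change exposes (the fused_feedforward BuildStrategy flag and the "fused_feedforward" pass registered for the distributed PassManager). It assumes a CUDA build of Paddle, since the pass is only registered under PADDLE_WITH_CUDA, and the program construction is elided; the snippet is illustrative, not part of this diff.

import paddle
import paddle.static as static
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()

# Option 1: set the BuildStrategy flag; the parallel executor then appends
# fused_feedforward_pass to its IR pass list.
build_strategy = static.BuildStrategy()
build_strategy.fused_feedforward = True

# Option 2: apply the IR pass explicitly to a program pair that already
# contains the feed-forward block and its backward ops (built elsewhere,
# after optimizer.minimize()):
#     PassManager([new_pass("fused_feedforward")]).apply(
#         [main_prog], [startup_prog])
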
Parent 02296977
......@@ -383,6 +383,7 @@ set(IR_PASS_DEPS
fix_op_run_order_pass
fuse_gemm_epilogue_pass
fused_attention_pass
fused_feedforward_pass
delete_dropout_op_pass)
if(WITH_CINN)
......
......@@ -210,6 +210,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("fuse_sgd_op_pass");
AppendPass("fuse_momentum_op_pass");
}
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.fused_feedforward_, "fused_feedforward_pass");
#endif
}
void SetCollectiveContext() const {
......@@ -529,6 +532,9 @@ USE_PASS(fused_attention_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
#ifdef PADDLE_WITH_CUDA
USE_PASS(fused_feedforward_pass);
#endif
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
......@@ -131,6 +131,8 @@ struct BuildStrategy {
bool fuse_gemm_epilogue_{false};
// Fused multi head attention
bool fused_attention_{false};
// Fused feed forward
bool fused_feedforward_{false};
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
......@@ -264,6 +266,7 @@ inline std::ostream &operator<<(std::ostream &os,
os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
os << "fused_attention_: " << strategy.fused_attention_ << std::endl;
os << "fused_feedforward_: " << strategy.fused_feedforward_ << std::endl;
os << "mkldnn_enabled_op_types_: ";
for (auto str : strategy.mkldnn_enabled_op_types_) {
os << str << ", ";
......
......@@ -126,6 +126,7 @@ message BuildStrategy {
optional bool fuse_gemm_epilogue = 16 [ default = false ];
optional string debug_graphviz_path = 17;
optional bool fused_attention = 18 [ default = false];
optional bool fused_feedforward = 19 [ default = false];
}
message ExecutionStrategy {
......
......@@ -264,6 +264,10 @@ cc_library(
fuse_relu_depthwise_conv_pass
SRCS fuse_relu_depthwise_conv_pass.cc
DEPS pass graph_pattern_detector)
cc_library(
fused_feedforward_pass
SRCS fused_feedforward_pass.cc
DEPS pass graph_pattern_detector)
set(GLOB_PASS_LIB
${INFER_IR_PASSES}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fused_feedforward_pass.h"
#include <string>
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
void FusedFeedForwardPass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init(scope_name, graph);
for (auto use_mp : std::vector<bool>({true, false})) {
for (auto pre_layer_norm : std::vector<bool>({true, false})) {
for (auto add_residual : std::vector<bool>({true, false})) {
for (auto use_dropout_1 : std::vector<bool>({true, false})) {
for (auto use_dropout_2 : std::vector<bool>({true, false})) {
// pre_layer_norm and add_residual can't both be false!
if (!pre_layer_norm && !add_residual) continue;
// use_dropout_1 and use_dropout_2 can't both be false!
if (!use_dropout_1 && !use_dropout_2) continue;
Cache dropout_nodes_map;
graph = FusedFeedForwardFwd(graph,
use_mp,
pre_layer_norm,
add_residual,
use_dropout_1,
use_dropout_2,
&dropout_nodes_map);
graph = FusedFeedForwardBwd(graph,
use_mp,
pre_layer_norm,
add_residual,
use_dropout_1,
use_dropout_2,
&dropout_nodes_map);
}
}
}
}
}
}
ir::Graph *FusedFeedForwardPass::FusedFeedForwardFwd(
ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
const std::string scope_name("fused_feed_forward_fwd_pattern");
GraphPatternDetector gpd;
auto *x = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(scope_name, "x"))
->AsInput();
if (pre_layer_norm) {
x->assert_is_op_input("layer_norm", "X");
} else {
x->assert_is_op_input("matmul_v2", "X");
}
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
// 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
// -> layer_norm (post_layer_norm)
// other cases: the mp, residual_add, dropout1 and dropout2 operators may be absent
patterns::FusedFeedForwardFwd fused_feedforward_pattern(gpd.mutable_pattern(),
scope_name);
std::unordered_set<std::string> act_types = {"gelu", "relu"};
VLOG(4) << "Fused Feedforward forward pass."
<< " pre_layer_norm: " << pre_layer_norm
<< ", add_residual: " << add_residual
<< ", use_dropout_1: " << use_dropout_1
<< ", use_dropout_2: " << use_dropout_2;
fused_feedforward_pattern(x,
act_types,
use_mp,
pre_layer_norm,
add_residual,
use_dropout_1,
use_dropout_2);
int found_fused_feedforward_fwd_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle feed_forward forward fusion";
// LayerNorm
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_op, layer_norm_op, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_out, layer_norm_out, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, fused_feedforward_pattern);
// Linear1
GET_IR_NODE_FROM_SUBGRAPH(
matmul_op_1, matmul_op_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_1, matmul_w_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_out_1, matmul_out_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_1, ele_add_op_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_1, ele_add_bias_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_out_1, ele_add_out_1, fused_feedforward_pattern);
// Activation
GET_IR_NODE_FROM_SUBGRAPH(act_op, act_op, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, fused_feedforward_pattern);
// Linear2
GET_IR_NODE_FROM_SUBGRAPH(
matmul_op_2, matmul_op_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_2, matmul_w_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_out_2, matmul_out_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_2, ele_add_op_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_2, ele_add_bias_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_out_2, ele_add_out_2, fused_feedforward_pattern);
if (use_dropout_1 && use_dropout_2) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_1, dropout_op_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_2, dropout_op_2, fused_feedforward_pattern);
if (PADDLE_GET_CONST(bool, dropout_op_1->Op()->GetAttr("is_test")) !=
PADDLE_GET_CONST(bool, dropout_op_2->Op()->GetAttr("is_test"))) {
LOG(WARNING) << "Dropout 1 and dropout 2 attribute is_test set "
"different values. "
<< "Skip fused_feedforward pattern replacement.";
return;
}
}
OpDesc fused_feedforward_op_desc(layer_norm_op->Op()->Block());
fused_feedforward_op_desc.SetType("fused_feedforward");
fused_feedforward_op_desc.SetInput("X", {subgraph.at(x)->Name()});
fused_feedforward_op_desc.SetInput("Linear1Weight", {matmul_w_1->Name()});
fused_feedforward_op_desc.SetInput("Linear1Bias", {ele_add_bias_1->Name()});
fused_feedforward_op_desc.SetInput("Linear2Weight", {matmul_w_2->Name()});
fused_feedforward_op_desc.SetInput("Linear2Bias", {ele_add_bias_2->Name()});
if (pre_layer_norm) {
fused_feedforward_op_desc.SetInput("Ln1Scale",
{layer_norm_scale->Name()});
fused_feedforward_op_desc.SetInput("Ln1Bias", {layer_norm_bias->Name()});
fused_feedforward_op_desc.SetOutput("Ln1Mean", {layer_norm_mean->Name()});
fused_feedforward_op_desc.SetOutput("Ln1Variance",
{layer_norm_variance->Name()});
fused_feedforward_op_desc.SetOutput("Ln1Out", {layer_norm_out->Name()});
fused_feedforward_op_desc.SetAttr(
"ln1_epsilon", layer_norm_op->Op()->GetAttr("epsilon"));
if (!add_residual) {
if (use_dropout_2) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_out_2, dropout_out_2, fused_feedforward_pattern);
fused_feedforward_op_desc.SetOutput("Out", {dropout_out_2->Name()});
} else {
fused_feedforward_op_desc.SetOutput("Out", {ele_add_out_2->Name()});
}
} else {
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_out_3, ele_add_out_3, fused_feedforward_pattern);
fused_feedforward_op_desc.SetOutput("Out", {ele_add_out_3->Name()});
}
} else {
fused_feedforward_op_desc.SetInput("Ln2Scale",
{layer_norm_scale->Name()});
fused_feedforward_op_desc.SetInput("Ln2Bias", {layer_norm_bias->Name()});
fused_feedforward_op_desc.SetOutput("Ln2Mean", {layer_norm_mean->Name()});
fused_feedforward_op_desc.SetOutput("Ln2Variance",
{layer_norm_variance->Name()});
fused_feedforward_op_desc.SetAttr(
"ln2_epsilon", layer_norm_op->Op()->GetAttr("epsilon"));
fused_feedforward_op_desc.SetOutput("Out", {layer_norm_out->Name()});
}
bool is_test = false;
DropoutNode record;
if (use_dropout_1) {
// Dropout1
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_1, dropout_op_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_mask_1, dropout_mask_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_out_1, dropout_out_1, fused_feedforward_pattern);
record.dropout_mask_node_1 = dropout_mask_1;
record.dropout_out_node_1 = dropout_out_1;
fused_feedforward_op_desc.SetOutput("Dropout1Mask",
{dropout_mask_1->Name()});
fused_feedforward_op_desc.SetOutput("Dropout1Out",
{dropout_out_1->Name()});
fused_feedforward_op_desc.SetAttr(
"dropout1_rate", dropout_op_1->Op()->GetAttr("dropout_prob"));
fused_feedforward_op_desc.SetAttr(
"dropout1_implementation",
dropout_op_1->Op()->GetAttr("dropout_implementation"));
is_test = PADDLE_GET_CONST(bool, dropout_op_1->Op()->GetAttr("is_test"));
} else {
fused_feedforward_op_desc.SetAttr("dropout1_rate", 0.0f);
VarDesc dropout_out_desc_1(
patterns::PDNodeName(scope_name, "dropout_out_1"));
dropout_out_desc_1.SetShape(ele_add_out_1->Var()->GetShape());
dropout_out_desc_1.SetDataType(ele_add_out_1->Var()->GetDataType());
dropout_out_desc_1.SetLoDLevel(ele_add_out_1->Var()->GetLoDLevel());
dropout_out_desc_1.SetStopGradient(static_cast<bool>(true));
record.dropout_out_node_1 = g->CreateVarNode(&dropout_out_desc_1);
fused_feedforward_op_desc.SetOutput("Dropout1Out",
{record.dropout_out_node_1->Name()});
VarDesc dropout_mask_desc_1(
patterns::PDNodeName(scope_name, "dropout_mask_1"));
dropout_mask_desc_1.SetShape(ele_add_out_1->Var()->GetShape());
dropout_mask_desc_1.SetDataType(proto::VarType::UINT8);
dropout_mask_desc_1.SetLoDLevel(ele_add_out_1->Var()->GetLoDLevel());
dropout_mask_desc_1.SetStopGradient(static_cast<bool>(true));
// Transfer to backward operator.
record.dropout_mask_node_1 = g->CreateVarNode(&dropout_mask_desc_1);
fused_feedforward_op_desc.SetOutput("Dropout1Mask",
{record.dropout_mask_node_1->Name()});
}
if (use_dropout_2) {
// Dropout2
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_2, dropout_op_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_mask_2, dropout_mask_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
dropout_out_2, dropout_out_2, fused_feedforward_pattern);
record.dropout_out_node_2 = dropout_out_2;
record.dropout_mask_node_2 = dropout_mask_2;
fused_feedforward_op_desc.SetOutput("Dropout2Out",
{dropout_out_2->Name()});
fused_feedforward_op_desc.SetOutput("Dropout2Mask",
{dropout_mask_2->Name()});
fused_feedforward_op_desc.SetAttr(
"dropout2_rate", dropout_op_2->Op()->GetAttr("dropout_prob"));
fused_feedforward_op_desc.SetAttr(
"dropout2_implementation",
dropout_op_2->Op()->GetAttr("dropout_implementation"));
is_test = PADDLE_GET_CONST(bool, dropout_op_2->Op()->GetAttr("is_test"));
} else {
fused_feedforward_op_desc.SetAttr("dropout2_rate", 0.0f);
VarDesc dropout_out_desc_2(
patterns::PDNodeName(scope_name, "dropout_out_2"));
dropout_out_desc_2.SetShape(ele_add_out_2->Var()->GetShape());
dropout_out_desc_2.SetDataType(ele_add_out_2->Var()->GetDataType());
dropout_out_desc_2.SetLoDLevel(ele_add_out_2->Var()->GetLoDLevel());
dropout_out_desc_2.SetStopGradient(static_cast<bool>(true));
record.dropout_out_node_2 = g->CreateVarNode(&dropout_out_desc_2);
fused_feedforward_op_desc.SetOutput("Dropout2Out",
{record.dropout_out_node_2->Name()});
VarDesc dropout_mask_desc_2(
patterns::PDNodeName(scope_name, "dropout_mask_2"));
dropout_mask_desc_2.SetShape(ele_add_out_2->Var()->GetShape());
dropout_mask_desc_2.SetDataType(proto::VarType::UINT8);
dropout_mask_desc_2.SetLoDLevel(ele_add_out_2->Var()->GetLoDLevel());
dropout_mask_desc_2.SetStopGradient(static_cast<bool>(true));
// Transmit to backward operator.
record.dropout_mask_node_2 = g->CreateVarNode(&dropout_mask_desc_2);
fused_feedforward_op_desc.SetOutput("Dropout2Mask",
{record.dropout_mask_node_2->Name()});
}
// Transmit to backward operator.
dropout_nodes_map->insert(std::make_pair(matmul_w_1, record));
fused_feedforward_op_desc.SetOutput("Linear1Out", {ele_add_out_1->Name()});
fused_feedforward_op_desc.SetAttr("pre_layer_norm", pre_layer_norm);
fused_feedforward_op_desc.SetAttr("act_method", act_op->Op()->Type());
if (!use_dropout_1 && !use_dropout_2) {
is_test = true;
}
fused_feedforward_op_desc.SetAttr("is_test", is_test);
// These attributes are set to their default values.
fused_feedforward_op_desc.SetAttr("dropout1_fix_seed", false);
fused_feedforward_op_desc.SetAttr("dropout2_fix_seed", false);
fused_feedforward_op_desc.SetAttr("dropout1_seed", 0);
fused_feedforward_op_desc.SetAttr("dropout2_seed", 0);
fused_feedforward_op_desc.SetAttr("add_residual", add_residual);
int ring_id = -1;
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(
c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern);
ring_id =
PADDLE_GET_CONST(int, c_allreduce_sum_op->Op()->GetAttr("ring_id"));
}
fused_feedforward_op_desc.SetAttr("ring_id", ring_id);
auto fused_feedforward_node = g->CreateOpNode(&fused_feedforward_op_desc);
IR_NODE_LINK_TO(subgraph.at(x), fused_feedforward_node);
IR_NODE_LINK_TO(matmul_w_1, fused_feedforward_node);
IR_NODE_LINK_TO(ele_add_bias_1, fused_feedforward_node);
IR_NODE_LINK_TO(matmul_w_2, fused_feedforward_node);
IR_NODE_LINK_TO(ele_add_bias_2, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_feedforward_node);
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_out);
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_mean);
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_variance);
IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_mask_node_1);
IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_out_node_1);
IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_mask_node_2);
IR_NODE_LINK_TO(fused_feedforward_node, record.dropout_out_node_2);
IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_1);
if (!pre_layer_norm) {
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_out);
} else {
if (add_residual) {
// Residual Add, dispensable
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_out_3, ele_add_out_3, fused_feedforward_pattern);
IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_3);
} else {
if (!use_dropout_2) {
IR_NODE_LINK_TO(fused_feedforward_node, ele_add_out_2);
}
}
}
std::unordered_set<const Node *> nodes_to_remove = {layer_norm_op,
matmul_op_1,
ele_add_op_1,
act_op,
matmul_op_2,
ele_add_op_2};
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(
c_identity_op, c_identity_op, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern);
nodes_to_remove.insert(c_identity_op);
nodes_to_remove.insert(c_allreduce_sum_op);
}
if (use_dropout_1) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_1, dropout_op_1, fused_feedforward_pattern);
nodes_to_remove.insert(dropout_op_1);
}
if (use_dropout_2) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_2, dropout_op_2, fused_feedforward_pattern);
nodes_to_remove.insert(dropout_op_2);
}
if (add_residual) {
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_3, ele_add_op_3, fused_feedforward_pattern);
nodes_to_remove.insert(ele_add_op_3);
}
GraphSafeRemoveNodes(g, nodes_to_remove);
found_fused_feedforward_fwd_count++;
VLOG(4) << "After remove nodes.";
};
gpd(graph, handler);
AddStatis(found_fused_feedforward_fwd_count);
return graph;
}
ir::Graph *FusedFeedForwardPass::FusedFeedForwardBwd(
ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
const std::string scope_name("fused_feed_forward_bwd_pattern");
// 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
// activation_grad -> linear1_grad -> layer_norm_grad
// 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
// dropout1_grad -> activation_grad -> linear1_grad
// other cases: the mp, residual_add_grad, dropout1_grad and dropout2_grad
// operators may be absent
GraphPatternDetector gpd;
auto *x_grad = gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(scope_name, "x_grad"))
->AsInput();
patterns::FusedFeedForwardBwd fused_feedforward_pattern(gpd.mutable_pattern(),
scope_name);
std::unordered_set<std::string> act_grad_types = {"gelu_grad", "relu_grad"};
fused_feedforward_pattern(x_grad,
act_grad_types,
use_mp,
pre_layer_norm,
add_residual,
use_dropout_1,
use_dropout_2);
VLOG(4) << "Fused Feedforward backward pass."
<< " pre_layer_norm: " << pre_layer_norm
<< ", add_residual: " << add_residual
<< ", use_dropout_1: " << use_dropout_1
<< ", use_dropout_2: " << use_dropout_2;
int found_fused_feedforward_bwd_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle feed_forward backward fusion";
// LayerNorm Grad
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_op_grad, layer_norm_op_grad, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_in, layer_norm_in, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_mean, layer_norm_mean, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_variance, layer_norm_variance, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_scale, layer_norm_scale, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias, layer_norm_bias, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_in_grad, layer_norm_in_grad, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_scale_grad,
layer_norm_scale_grad,
fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
layer_norm_bias_grad, layer_norm_bias_grad, fused_feedforward_pattern);
// Linear Grad 1
GET_IR_NODE_FROM_SUBGRAPH(
matmul_op_grad_1, matmul_op_grad_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_in_1, matmul_in_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_1, matmul_w_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_in_grad_1, matmul_in_grad_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_grad_1, matmul_w_grad_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_grad_1, ele_add_op_grad_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_in_1, ele_add_in_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_1, ele_add_bias_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_in_grad_1, ele_add_in_grad_1, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_grad_1, ele_add_bias_grad_1, fused_feedforward_pattern);
// Activation Grad
GET_IR_NODE_FROM_SUBGRAPH(
act_op_grad, act_op_grad, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(act_in, act_in, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
act_in_grad, act_in_grad, fused_feedforward_pattern);
// Linear Grad 2
GET_IR_NODE_FROM_SUBGRAPH(
matmul_op_grad_2, matmul_op_grad_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_in_2, matmul_in_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_2, matmul_w_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_in_grad_2, matmul_in_grad_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
matmul_w_grad_2, matmul_w_grad_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_grad_2, ele_add_op_grad_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_in_2, ele_add_in_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_2, ele_add_bias_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_in_grad_2, ele_add_in_grad_2, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_bias_grad_2, ele_add_bias_grad_2, fused_feedforward_pattern);
auto record = (*dropout_nodes_map)[matmul_w_1];
if (use_dropout_1) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern);
if (PADDLE_GET_CONST(bool, dropout_op_grad_1->Op()->GetAttr("is_test"))) {
LOG(WARNING) << "Dropout_grad 1 attribute is_test should be set false."
<< " Skip fused_feedforward_grad pattern replacement";
return;
}
} else {
if (record.dropout_mask_node_1 == nullptr ||
record.dropout_out_node_1 == nullptr) {
LOG(WARNING)
<< "Dropout_grad 1 has no mask/out input from forward pass."
<< " Skip fused_feedforward_grad pattern replacement";
return;
}
}
if (use_dropout_2) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern);
if (PADDLE_GET_CONST(bool, dropout_op_grad_2->Op()->GetAttr("is_test"))) {
LOG(WARNING) << "Dropout_grad 2 attribute is_test should be set false."
<< " Skip fused_feedforward_grad pattern replacement";
return;
}
} else {
if (record.dropout_mask_node_2 == nullptr) {
LOG(WARNING) << "Dropout_grad 2 has no mask input from forward pass."
<< " Skip fused_feedforward_grad pattern replacement";
return;
}
}
OpDesc fused_feedforward_op_desc(layer_norm_op_grad->Op()->Block());
fused_feedforward_op_desc.SetType("fused_feedforward_grad");
fused_feedforward_op_desc.SetInput(framework::GradVarName("Out"),
{subgraph.at(x_grad)->Name()});
fused_feedforward_op_desc.SetInput(
"X", {pre_layer_norm ? layer_norm_in->Name() : matmul_in_1->Name()});
fused_feedforward_op_desc.SetInput("Linear1Weight", {matmul_w_1->Name()});
fused_feedforward_op_desc.SetInput("Linear1Bias", {ele_add_bias_1->Name()});
fused_feedforward_op_desc.SetInput("Linear2Weight", {matmul_w_2->Name()});
fused_feedforward_op_desc.SetInput("Linear2Bias", {ele_add_bias_2->Name()});
fused_feedforward_op_desc.SetInput("Linear1Out", {act_in->Name()});
fused_feedforward_op_desc.SetInput("Dropout1Out",
{record.dropout_out_node_1->Name()});
fused_feedforward_op_desc.SetInput("Dropout1Mask",
{record.dropout_mask_node_1->Name()});
fused_feedforward_op_desc.SetInput("Dropout2Mask",
{record.dropout_mask_node_2->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Linear1Weight"),
{matmul_w_grad_1->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Linear1Bias"),
{ele_add_bias_grad_1->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Linear2Weight"),
{matmul_w_grad_2->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Linear2Bias"),
{ele_add_bias_grad_2->Name()});
fused_feedforward_op_desc.SetAttr("pre_layer_norm", pre_layer_norm);
fused_feedforward_op_desc.SetAttr(
"ln1_epsilon", layer_norm_op_grad->Op()->GetAttr("epsilon"));
fused_feedforward_op_desc.SetAttr(
"ln2_epsilon", layer_norm_op_grad->Op()->GetAttr("epsilon"));
fused_feedforward_op_desc.SetAttr("act_method",
act_op_grad->Op()->Type().substr(0, 4));
fused_feedforward_op_desc.SetAttr("add_residual", add_residual);
// These attributes are set to their default values.
fused_feedforward_op_desc.SetAttr("is_test", false);
fused_feedforward_op_desc.SetAttr("dropout1_fix_seed", false);
fused_feedforward_op_desc.SetAttr("dropout2_fix_seed", false);
fused_feedforward_op_desc.SetAttr("dropout1_seed", 0);
fused_feedforward_op_desc.SetAttr("dropout2_seed", 0);
int ring_id = -1;
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(
c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern);
ring_id =
PADDLE_GET_CONST(int, c_allreduce_sum_op->Op()->GetAttr("ring_id"));
}
fused_feedforward_op_desc.SetAttr("ring_id", ring_id);
if (pre_layer_norm) {
fused_feedforward_op_desc.SetInput("Ln1Scale",
{layer_norm_scale->Name()});
fused_feedforward_op_desc.SetInput("Ln1Bias", {layer_norm_bias->Name()});
fused_feedforward_op_desc.SetInput("Ln1Out", {matmul_in_1->Name()});
fused_feedforward_op_desc.SetInput("Ln1Mean", {layer_norm_mean->Name()});
fused_feedforward_op_desc.SetInput("Ln1Variance",
{layer_norm_variance->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Ln1Scale"),
{layer_norm_scale_grad->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Ln1Bias"),
{layer_norm_bias_grad->Name()});
} else {
fused_feedforward_op_desc.SetInput("Ln2Scale",
{layer_norm_scale->Name()});
fused_feedforward_op_desc.SetInput("Ln2Bias", {layer_norm_bias->Name()});
fused_feedforward_op_desc.SetInput("Ln2Mean", {layer_norm_mean->Name()});
fused_feedforward_op_desc.SetInput("Ln2Variance",
{layer_norm_variance->Name()});
// Special
fused_feedforward_op_desc.SetInput("Dropout2Out",
{record.dropout_out_node_2->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Ln2Scale"),
{layer_norm_scale_grad->Name()});
fused_feedforward_op_desc.SetOutput(GradVarName("Ln2Bias"),
{layer_norm_bias_grad->Name()});
}
if (use_dropout_1) {
// Dropout Grad 1
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern);
fused_feedforward_op_desc.SetAttr(
"dropout1_rate", dropout_op_grad_1->Op()->GetAttr("dropout_prob"));
fused_feedforward_op_desc.SetAttr(
"dropout1_implementation",
dropout_op_grad_1->Op()->GetAttr("dropout_implementation"));
} else {
fused_feedforward_op_desc.SetAttr("dropout1_rate", 0.0f);
fused_feedforward_op_desc.SetAttr(
"dropout1_implementation",
static_cast<std::string>("upscale_in_train"));
}
if (use_dropout_2) {
// Dropout Grad 2
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern);
fused_feedforward_op_desc.SetAttr(
"dropout2_rate", dropout_op_grad_2->Op()->GetAttr("dropout_prob"));
fused_feedforward_op_desc.SetAttr(
"dropout2_implementation",
dropout_op_grad_2->Op()->GetAttr("dropout_implementation"));
} else {
fused_feedforward_op_desc.SetAttr("dropout2_rate", 0.0f);
fused_feedforward_op_desc.SetAttr(
"dropout2_implementation",
static_cast<std::string>("upscale_in_train"));
}
if (add_residual) {
GET_IR_NODE_FROM_SUBGRAPH(sum_out, sum_out, fused_feedforward_pattern);
fused_feedforward_op_desc.SetOutput(GradVarName("X"), {sum_out->Name()});
} else {
if (pre_layer_norm) {
fused_feedforward_op_desc.SetOutput(GradVarName("X"),
{layer_norm_in_grad->Name()});
} else {
fused_feedforward_op_desc.SetOutput(GradVarName("X"),
{matmul_in_grad_1->Name()});
}
}
auto fused_feedforward_node = g->CreateOpNode(&fused_feedforward_op_desc);
IR_NODE_LINK_TO(subgraph.at(x_grad), fused_feedforward_node);
IR_NODE_LINK_TO(matmul_w_1, fused_feedforward_node);
IR_NODE_LINK_TO(ele_add_bias_1, fused_feedforward_node);
IR_NODE_LINK_TO(matmul_w_2, fused_feedforward_node);
IR_NODE_LINK_TO(ele_add_bias_2, fused_feedforward_node);
IR_NODE_LINK_TO(record.dropout_mask_node_1, fused_feedforward_node);
IR_NODE_LINK_TO(record.dropout_mask_node_2, fused_feedforward_node);
IR_NODE_LINK_TO(act_in, fused_feedforward_node);
IR_NODE_LINK_TO(record.dropout_out_node_1, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_scale, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_bias, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_mean, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_variance, fused_feedforward_node);
IR_NODE_LINK_TO(layer_norm_in, fused_feedforward_node);
if (pre_layer_norm) {
IR_NODE_LINK_TO(matmul_in_1, fused_feedforward_node);
} else {
IR_NODE_LINK_TO(record.dropout_out_node_2, fused_feedforward_node);
}
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_scale_grad);
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_bias_grad);
IR_NODE_LINK_TO(fused_feedforward_node, matmul_w_grad_1);
IR_NODE_LINK_TO(fused_feedforward_node, ele_add_bias_grad_1);
IR_NODE_LINK_TO(fused_feedforward_node, matmul_w_grad_2);
IR_NODE_LINK_TO(fused_feedforward_node, ele_add_bias_grad_2);
if (add_residual) {
GET_IR_NODE_FROM_SUBGRAPH(sum_out, sum_out, fused_feedforward_pattern);
IR_NODE_LINK_TO(fused_feedforward_node, sum_out);
} else {
if (pre_layer_norm) {
IR_NODE_LINK_TO(fused_feedforward_node, layer_norm_in_grad);
} else {
IR_NODE_LINK_TO(fused_feedforward_node, matmul_in_grad_1);
}
}
std::unordered_set<const Node *> nodes_to_remove = {layer_norm_op_grad,
matmul_op_grad_1,
ele_add_op_grad_1,
act_op_grad,
matmul_op_grad_2,
ele_add_op_grad_2};
if (use_mp) {
GET_IR_NODE_FROM_SUBGRAPH(
c_identity_op, c_identity_op, fused_feedforward_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
c_allreduce_sum_op, c_allreduce_sum_op, fused_feedforward_pattern);
nodes_to_remove.insert(c_identity_op);
nodes_to_remove.insert(c_allreduce_sum_op);
}
if (use_dropout_1) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_1, dropout_op_grad_1, fused_feedforward_pattern);
nodes_to_remove.insert(dropout_op_grad_1);
}
if (use_dropout_2) {
GET_IR_NODE_FROM_SUBGRAPH(
dropout_op_grad_2, dropout_op_grad_2, fused_feedforward_pattern);
nodes_to_remove.insert(dropout_op_grad_2);
}
if (add_residual) {
GET_IR_NODE_FROM_SUBGRAPH(
ele_add_op_grad_3, ele_add_op_grad_3, fused_feedforward_pattern);
// Sum for gradient addition
GET_IR_NODE_FROM_SUBGRAPH(sum_op, sum_op, fused_feedforward_pattern);
nodes_to_remove.insert(ele_add_op_grad_3);
nodes_to_remove.insert(sum_op);
}
GraphSafeRemoveNodes(g, nodes_to_remove);
found_fused_feedforward_bwd_count++;
};
gpd(graph, handler);
AddStatis(found_fused_feedforward_bwd_count);
return graph;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_feedforward_pass,
paddle::framework::ir::FusedFeedForwardPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the FeedForward block that follows attention in a Transformer layer
* Forward:
* 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
* -> residual_add (pre_layer_norm)
* 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
* -> layer_norm (post_layer_norm)
* other cases: the mp, residual_add, dropout1 and dropout2 operators may be absent
* Backward:
* 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
* activation_grad -> linear1_grad -> layer_norm_grad (pre_layer_norm)
* 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
* dropout1_grad -> activation_grad -> linear1_grad (post_layer_norm)
* other cases: the mp, residual_add_grad, dropout1_grad and dropout2_grad
* operators may be absent
*/
class Graph;
class Node;
class FusedFeedForwardPass : public FusePassBase {
public:
virtual ~FusedFeedForwardPass() {}
protected:
// Used to transfer variable nodes created by the pattern
// between the corresponding forward and backward operators.
struct DropoutNode {
Node *dropout_out_node_1;
Node *dropout_mask_node_1;
Node *dropout_out_node_2;
Node *dropout_mask_node_2;
DropoutNode()
: dropout_out_node_1(nullptr),
dropout_mask_node_1(nullptr),
dropout_out_node_2(nullptr),
dropout_mask_node_2(nullptr) {}
};
typedef std::unordered_map<Node *, DropoutNode> Cache;
const std::string scope_name{"fused_feedforward"};
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FusedFeedForwardFwd(ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const;
ir::Graph *FusedFeedForwardBwd(ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -113,7 +113,8 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph &graph) {
if (node.Name().rfind("__control_var") == 0) continue;
for (const auto &pdnode : pattern_.nodes()) {
if (pdnode->Tell(&node)) {
VLOG(4) << "Node " << node.Name() << " marked as " << pdnode->name();
VLOG(4) << "Node " << node.Name() << "(" << node.id() << ")"
<< " marked as " << pdnode->name();
pdnodes2nodes_[pdnode.get()].insert(&node);
}
}
......@@ -231,7 +232,8 @@ GraphPatternDetector::DetectPatterns() {
// source -> target
for (Node *source : pdnodes2nodes_[edge.first]) {
for (Node *target : pdnodes2nodes_[edge.second]) {
VLOG(8) << "check " << source->id() << " -- " << target->id();
VLOG(8) << "check " << source->Name() << "(" << source->id() << ")"
<< " -- " << target->Name() << "(" << target->id() << ")";
// TODO(Superjomn) add some prune strategies.
for (const auto &group : pre_groups) {
if (IsNodesLink(source, target)) {
......@@ -251,7 +253,9 @@ GraphPatternDetector::DetectPatterns() {
VLOG(3) << "step " << step << " get records: " << cur_groups.size();
for (auto &group : cur_groups) {
for (auto &item : group.roles) {
VLOG(4) << "node " << item.second->id() << " as " << item.first->name();
VLOG(4) << "node " << item.second->Name() << "(" << item.second->id()
<< ")"
<< " as " << item.first->name();
}
VLOG(4) << "=========================================================";
}
......@@ -4011,6 +4015,443 @@ PDNode *patterns::MergeLayernormPattern::operator()(PDNode *in) {
return layernorm_40_out;
}
PDNode *patterns::FusedFeedForwardFwd::operator()(
paddle::framework::ir::PDNode *x_var,
std::unordered_set<std::string> act_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2) {
// Possible patterns
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
// 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
// -> layer_norm (post_layer_norm)
// other cases: the residual_add, dropout1 and dropout2 operators may be absent
// intermediate input, and final pattern output
PDNode *out_var = x_var;
// LayerNorm
auto *layer_norm_op =
pattern->NewNode(layer_norm_op_repr())->assert_is_op("layer_norm");
auto *layer_norm_bias_var = pattern->NewNode(layer_norm_bias_repr())
->AsInput()
->assert_is_op_input("layer_norm", "Bias");
auto *layer_norm_scale_var = pattern->NewNode(layer_norm_scale_repr())
->AsInput()
->assert_is_op_input("layer_norm", "Scale");
auto *layer_norm_out_var = pattern->NewNode(layer_norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
auto *layer_norm_mean_var = pattern->NewNode(layer_norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean");
auto *layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance");
if (pre_layer_norm) {
out_var->assert_is_op_input("layer_norm", "X");
layer_norm_op
->LinksFrom({out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
out_var = layer_norm_out_var;
}
// Model parallel, do nothing in forward.
if (use_mp) {
out_var->assert_is_op_input("c_identity", "X");
auto *c_identity_op =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
auto *c_identity_out_var = pattern->NewNode(c_identity_out_repr())
->assert_is_op_output("c_identity", "Out");
c_identity_op->LinksFrom({out_var}).LinksTo({c_identity_out_var});
out_var = c_identity_out_var;
}
// Linear1
out_var->assert_is_op_input("matmul_v2", "X");
auto *matmul_op_1 =
pattern->NewNode(matmul_op_1_repr())->assert_is_op("matmul_v2");
auto *matmul_w_var_1 = pattern->NewNode(matmul_w_1_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "Y");
auto *matmul_out_var_1 = pattern->NewNode(matmul_out_1_repr())
->assert_is_op_output("matmul_v2", "Out");
matmul_op_1->LinksFrom({out_var, matmul_w_var_1}).LinksTo({matmul_out_var_1});
out_var = matmul_out_var_1;
out_var->assert_is_op_input("elementwise_add", "X");
auto *ele_add_op_1 =
pattern->NewNode(ele_add_op_1_repr())->assert_is_op("elementwise_add");
auto *ele_add_bias_var_1 = pattern->NewNode(ele_add_bias_1_repr())
->AsInput()
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add_out_var_1 = pattern->NewNode(ele_add_out_1_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_add_op_1->LinksFrom({out_var, ele_add_bias_var_1})
.LinksTo({ele_add_out_var_1});
out_var = ele_add_out_var_1;
// Activation
out_var->assert_is_ops_input(act_types);
auto *act_op = pattern->NewNode(act_op_repr())->assert_is_ops(act_types);
auto *act_out_var =
pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out");
act_op->LinksFrom({out_var}).LinksTo({act_out_var});
out_var = act_out_var;
// Dropout1
if (use_dropout_1) {
out_var->assert_is_op_input("dropout", "X");
auto *dropout_op_1 =
pattern->NewNode(dropout_op_1_repr())->assert_is_op("dropout");
auto *dropout_mask_var_1 = pattern->NewNode(dropout_mask_1_repr())
->assert_is_op_output("dropout", "Mask");
auto *dropout_out_var_1 = pattern->NewNode(dropout_out_1_repr())
->assert_is_op_output("dropout", "Out");
dropout_op_1->LinksFrom({out_var}).LinksTo(
{dropout_mask_var_1, dropout_out_var_1});
out_var = dropout_out_var_1;
}
// Linear2
out_var->assert_is_op_input("matmul_v2", "X");
auto *matmul_op_2 =
pattern->NewNode(matmul_op_2_repr())->assert_is_op("matmul_v2");
auto *matmul_w_var_2 =
pattern->NewNode(matmul_w_2_repr())->assert_is_op_input("matmul_v2", "Y");
auto *matmul_out_var_2 = pattern->NewNode(matmul_out_2_repr())
->assert_is_op_output("matmul_v2", "Out");
matmul_op_2->LinksFrom({out_var, matmul_w_var_2}).LinksTo({matmul_out_var_2});
out_var = matmul_out_var_2;
// Model parallel, do nothing in forward.
if (use_mp) {
out_var->assert_is_op_input("c_allreduce_sum", "X");
auto *c_allreduce_sum_op = pattern->NewNode(c_allreduce_sum_op_repr())
->assert_is_op("c_allreduce_sum");
auto *c_allreduce_sum_out_var =
pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum", "Out");
c_allreduce_sum_op->LinksFrom({out_var}).LinksTo({c_allreduce_sum_out_var});
out_var = c_allreduce_sum_out_var;
}
out_var->assert_is_op_input("elementwise_add", "X");
auto *ele_add_op_2 =
pattern->NewNode(ele_add_op_2_repr())->assert_is_op("elementwise_add");
auto *ele_add_bias_var_2 = pattern->NewNode(ele_add_bias_2_repr())
->assert_is_op_input("elementwise_add", "Y");
auto *ele_add_out_var_2 = pattern->NewNode(ele_add_out_2_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_add_op_2->LinksFrom({out_var, ele_add_bias_var_2})
.LinksTo({ele_add_out_var_2});
out_var = ele_add_out_var_2;
// Dropout 2
if (use_dropout_2) {
out_var->assert_is_op_input("dropout", "X");
auto *dropout_op_2 =
pattern->NewNode(dropout_op_2_repr())->assert_is_op("dropout");
auto *dropout_mask_var_2 = pattern->NewNode(dropout_mask_2_repr())
->assert_is_op_output("dropout", "Mask");
auto *dropout_out_var_2 = pattern->NewNode(dropout_out_2_repr())
->assert_is_op_output("dropout", "Out");
dropout_op_2->LinksFrom({out_var}).LinksTo(
{dropout_mask_var_2, dropout_out_var_2});
out_var = dropout_out_var_2;
}
// Residual Add
if (add_residual) {
out_var->assert_is_op_input("elementwise_add", "X");
x_var->assert_is_op_input("elementwise_add", "Y");
auto *ele_add_op_3 =
pattern->NewNode(ele_add_op_3_repr())->assert_is_op("elementwise_add");
auto *ele_add_out_var_3 =
pattern->NewNode(ele_add_out_3_repr())
->assert_is_op_output("elementwise_add", "Out");
ele_add_op_3->LinksFrom({out_var, x_var}).LinksTo({ele_add_out_var_3});
out_var = ele_add_out_var_3;
}
// Post LayerNorm
if (!pre_layer_norm) {
out_var->assert_is_op_input("layer_norm", "X");
layer_norm_op
->LinksFrom({out_var, layer_norm_bias_var, layer_norm_scale_var})
.LinksTo(
{layer_norm_out_var, layer_norm_mean_var, layer_norm_variance_var});
out_var = layer_norm_out_var;
}
return out_var;
}
PDNode *patterns::FusedFeedForwardBwd::operator()(
paddle::framework::ir::PDNode *x_grad,
std::unordered_set<std::string> act_grad_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2) {
// Possible patterns
// 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
// activation_grad -> linear1_grad -> layer_norm_grad
// 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
// dropout1_grad -> activation_grad -> linear1_grad
// other cases: the residual_add_grad, dropout1_grad and dropout2_grad
// operators may be absent
// intermediate input_grad, and final pattern output_grad
PDNode *out_grad = x_grad;
// LayerNorm: in["Mean", "Variance", "Scale", "Bias", "Y@GRAD"],
// out["X@GRAD", "Scale@GRAD", "Bias@GRAD"]
auto *layer_norm_op_grad = pattern->NewNode(layer_norm_op_grad_repr())
->assert_is_op("layer_norm_grad");
auto *layer_norm_in_var = pattern->NewNode(layer_norm_in_repr())
->assert_is_op_input("layer_norm_grad", "X");
auto *layer_norm_mean_var =
pattern->NewNode(layer_norm_mean_repr())
->assert_is_op_input("layer_norm_grad", "Mean");
auto *layer_norm_variance_var =
pattern->NewNode(layer_norm_variance_repr())
->assert_is_op_input("layer_norm_grad", "Variance");
auto *layer_norm_scale_var =
pattern->NewNode(layer_norm_scale_repr())
->assert_is_op_input("layer_norm_grad", "Scale");
auto *layer_norm_bias_var =
pattern->NewNode(layer_norm_bias_repr())
->assert_is_op_input("layer_norm_grad", "Bias");
auto *layer_norm_in_grad =
pattern->NewNode(layer_norm_in_grad_repr())
->assert_is_op_output("layer_norm_grad", GradVarName("X"));
auto *layer_norm_scale_grad =
pattern->NewNode(layer_norm_scale_grad_repr())
->assert_is_op_output("layer_norm_grad", GradVarName("Scale"));
auto *layer_norm_bias_grad =
pattern->NewNode(layer_norm_bias_grad_repr())
->assert_is_op_output("layer_norm_grad", GradVarName("Bias"));
// post_layer_norm
if (!pre_layer_norm) {
out_grad->assert_is_op_input("layer_norm_grad", GradVarName("Y"));
layer_norm_op_grad
->LinksFrom({out_grad,
layer_norm_in_var,
layer_norm_mean_var,
layer_norm_variance_var,
layer_norm_scale_var,
layer_norm_bias_var})
.LinksTo(
{layer_norm_in_grad, layer_norm_scale_grad, layer_norm_bias_grad});
out_grad = layer_norm_in_grad;
}
// partial input_grad of residual_add
PDNode *tmp = nullptr;
auto *matmul_in_var_1 = pattern->NewNode(matmul_in_1_repr())
->assert_is_op_input("matmul_v2_grad", "X");
if (add_residual) {
// Residual Add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"]
out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out"));
auto *ele_add_op_grad_3 = pattern->NewNode(ele_add_op_grad_3_repr())
->assert_is_op("elementwise_add_grad");
auto *ele_add_in_var_3 =
pattern->NewNode(ele_add_in_3_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto *ele_add_in_grad_3 =
pattern->NewNode(ele_add_in_grad_3_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("X"));
auto *ele_add_bias_grad_3 =
pattern->NewNode(ele_add_bias_grad_3_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"));
tmp = ele_add_bias_grad_3;
if (pre_layer_norm) {
ele_add_op_grad_3
->LinksFrom({out_grad, ele_add_in_var_3, layer_norm_in_var})
.LinksTo({ele_add_in_grad_3, ele_add_bias_grad_3});
} else {
ele_add_op_grad_3
->LinksFrom({out_grad, ele_add_in_var_3, matmul_in_var_1})
.LinksTo({ele_add_in_grad_3, ele_add_bias_grad_3});
}
out_grad = ele_add_in_grad_3;
}
// Dropout 2: in["Out@GRAD", "Mask"], out["X@GRAD"]
if (use_dropout_2) {
out_grad->assert_is_op_input("dropout_grad", GradVarName("Out"));
auto *dropout_op_grad_2 = pattern->NewNode(dropout_op_grad_2_repr())
->assert_is_op("dropout_grad");
auto *dropout_mask_grad_2 =
pattern->NewNode(dropout_mask_2_repr())
->assert_is_op_input("dropout_grad", "Mask");
auto *dropout_in_grad_2 =
pattern->NewNode(dropout_in_grad_2_repr())
->assert_is_op_output("dropout_grad", GradVarName("X"));
dropout_op_grad_2->LinksFrom({out_grad, dropout_mask_grad_2})
.LinksTo({dropout_in_grad_2});
out_grad = dropout_in_grad_2;
}
// Linear 2:
// elementwise_add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"]
out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out"));
auto *ele_add_op_grad_2 = pattern->NewNode(ele_add_op_grad_2_repr())
->assert_is_op("elementwise_add_grad");
auto *ele_add_in_var_2 =
pattern->NewNode(ele_add_in_2_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto *ele_add_bias_var_2 =
pattern->NewNode(ele_add_bias_2_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
auto *ele_add_in_grad_2 =
pattern->NewNode(ele_add_in_grad_2_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("X"));
auto *ele_add_bias_grad_2 =
pattern->NewNode(ele_add_bias_grad_2_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"));
ele_add_op_grad_2->LinksFrom({out_grad, ele_add_in_var_2, ele_add_bias_var_2})
.LinksTo({ele_add_in_grad_2, ele_add_bias_grad_2});
out_grad = ele_add_in_grad_2;
// Model parallel, do nothing in backward.
if (use_mp) {
out_grad->assert_is_op_input("c_identity", "X");
auto *c_identity_op =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
auto *c_identity_out_grad = pattern->NewNode(c_identity_out_repr())
->assert_is_op_output("c_identity", "Out");
c_identity_op->LinksFrom({out_grad}).LinksTo({c_identity_out_grad});
out_grad = c_identity_out_grad;
}
// matmul_v2: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"]
out_grad->assert_is_op_input("matmul_v2_grad", GradVarName("Out"));
auto *matmul_op_grad_2 =
pattern->NewNode(matmul_op_grad_2_repr())->assert_is_op("matmul_v2_grad");
auto *matmul_in_var_2 = pattern->NewNode(matmul_in_2_repr())
->assert_is_op_input("matmul_v2_grad", "X");
auto *matmul_w_var_2 = pattern->NewNode(matmul_w_2_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto *matmul_in_grad_2 =
pattern->NewNode(matmul_in_grad_2_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("X"));
auto *matmul_w_grad_2 =
pattern->NewNode(matmul_w_grad_2_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("Y"));
matmul_op_grad_2->LinksFrom({out_grad, matmul_in_var_2, matmul_w_var_2})
.LinksTo({matmul_in_grad_2, matmul_w_grad_2});
out_grad = matmul_in_grad_2;
// Dropout 1: in["Out@GRAD", "Mask"], out["X@GRAD"]
if (use_dropout_1) {
out_grad->assert_is_op_input("dropout_grad", GradVarName("Out"));
auto *dropout_op_grad_1 = pattern->NewNode(dropout_op_grad_1_repr())
->assert_is_op("dropout_grad");
auto *dropout_mask_var_1 = pattern->NewNode(dropout_mask_1_repr())
->assert_is_op_input("dropout_grad", "Mask");
auto *dropout_in_grad_1 =
pattern->NewNode(dropout_in_grad_1_repr())
->assert_is_op_output("dropout_grad", GradVarName("X"));
dropout_op_grad_1->LinksFrom({out_grad, dropout_mask_var_1})
.LinksTo({dropout_in_grad_1});
out_grad = dropout_in_grad_1;
}
// Activation: in["Out", "Out@GRAD"], out["X@GRAD"]
out_grad->assert_is_ops_input(act_grad_types, GradVarName("Out"));
auto *act_op_grad =
pattern->NewNode(act_op_grad_repr())->assert_is_ops(act_grad_types);
auto *act_in_var =
pattern->NewNode(act_in_repr())->assert_is_ops_input(act_grad_types, "X");
auto *act_in_grad =
pattern->NewNode(act_in_grad_repr())
->assert_is_ops_output(act_grad_types, GradVarName("X"));
act_op_grad->LinksFrom({out_grad, act_in_var}).LinksTo({act_in_grad});
out_grad = act_in_grad;
// Linear 1:
// elementwise_add: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"]
out_grad->assert_is_op_input("elementwise_add_grad", GradVarName("Out"));
auto *ele_add_op_grad_1 = pattern->NewNode(ele_add_op_grad_1_repr())
->assert_is_op("elementwise_add_grad");
auto *ele_add_in_var_1 =
pattern->NewNode(ele_add_in_1_repr())
->assert_is_op_input("elementwise_add_grad", "X");
auto *ele_add_bias_var_1 =
pattern->NewNode(ele_add_bias_1_repr())
->assert_is_op_input("elementwise_add_grad", "Y");
auto *ele_add_in_grad_1 =
pattern->NewNode(ele_add_in_grad_1_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("X"));
auto *ele_add_bias_grad_1 =
pattern->NewNode(ele_add_bias_grad_1_repr())
->assert_is_op_output("elementwise_add_grad", GradVarName("Y"));
ele_add_op_grad_1->LinksFrom({out_grad, ele_add_in_var_1, ele_add_bias_var_1})
.LinksTo({ele_add_in_grad_1, ele_add_bias_grad_1});
out_grad = ele_add_in_grad_1;
// matmul_v2: in["Out@GRAD", "X", "Y"], out["X@GRAD", "Y@GRAD"]
out_grad->assert_is_op_input("matmul_v2_grad", GradVarName("Out"));
auto *matmul_op_grad_1 =
pattern->NewNode(matmul_op_grad_1_repr())->assert_is_op("matmul_v2_grad");
// auto *matmul_in_var_1 = pattern->NewNode(matmul_in_1_repr())
// ->assert_is_op_input("matmul_v2_grad",
// "X");
auto *matmul_w_var_1 = pattern->NewNode(matmul_w_1_repr())
->assert_is_op_input("matmul_v2_grad", "Y");
auto *matmul_in_grad_1 =
pattern->NewNode(matmul_in_grad_1_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("X"));
auto *matmul_w_grad_1 =
pattern->NewNode(matmul_w_grad_1_repr())
->assert_is_op_output("matmul_v2_grad", GradVarName("Y"));
matmul_op_grad_1->LinksFrom({out_grad, matmul_in_var_1, matmul_w_var_1})
.LinksTo({matmul_in_grad_1, matmul_w_grad_1});
out_grad = matmul_in_grad_1;
// Model parallel, all_reduce in backward.
if (use_mp) {
out_grad->assert_is_op_input("c_allreduce_sum", "X");
auto *c_allreduce_sum_op = pattern->NewNode(c_allreduce_sum_op_repr())
->assert_is_op("c_allreduce_sum");
auto *c_allreduce_sum_out_grad =
pattern->NewNode(c_allreduce_sum_out_repr())
->assert_is_op_output("c_allreduce_sum", "Out");
c_allreduce_sum_op->LinksFrom({out_grad})
.LinksTo({c_allreduce_sum_out_grad});
out_grad = c_allreduce_sum_out_grad;
}
// pre LayerNorm
if (pre_layer_norm) {
out_grad->assert_is_op_input("layer_norm_grad", GradVarName("Y"));
layer_norm_op_grad
->LinksFrom({out_grad,
layer_norm_in_var,
layer_norm_mean_var,
layer_norm_variance_var,
layer_norm_scale_var,
layer_norm_bias_var})
.LinksTo(
{layer_norm_in_grad, layer_norm_scale_grad, layer_norm_bias_grad});
out_grad = layer_norm_in_grad;
}
// sum for final gradient
if (add_residual) {
auto *sum_op = pattern->NewNode(sum_op_repr())->assert_is_op("sum");
auto *sum_out =
pattern->NewNode(sum_out_repr())->assert_is_op_output("sum", "Out");
sum_op->LinksFrom({tmp, out_grad}).LinksTo({sum_out});
out_grad = sum_out;
}
return out_grad;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -2156,6 +2156,133 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};
// The following patterns are used to fuse feedforward in forward
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
// 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
// -> layer_norm (post_layer_norm)
// other cases: the residual_add, dropout1 and dropout2 operators may be absent
struct FusedFeedForwardFwd : public PatternBase {
FusedFeedForwardFwd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_feedforward_fwd") {}
PDNode* operator()(PDNode* x,
std::unordered_set<std::string> act_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2);
#ifndef FEEDFORWARD_LINEAR_DROPOUT_NODE
#define FEEDFORWARD_LINEAR_DROPOUT_NODE(suffix__) \
PATTERN_DECL_NODE(matmul_op_##suffix__); \
PATTERN_DECL_NODE(matmul_w_##suffix__); \
PATTERN_DECL_NODE(matmul_out_##suffix__); \
PATTERN_DECL_NODE(ele_add_op_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_##suffix__); \
PATTERN_DECL_NODE(ele_add_out_##suffix__); \
PATTERN_DECL_NODE(dropout_op_##suffix__); \
PATTERN_DECL_NODE(dropout_out_##suffix__); \
PATTERN_DECL_NODE(dropout_mask_##suffix__);
// LayerNorm: layer_norm
PATTERN_DECL_NODE(layer_norm_op);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
// Model parallelism
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(c_allreduce_sum_op);
PATTERN_DECL_NODE(c_allreduce_sum_out);
// Linear 1 and Dropout 1: matmul_v2 + elementwise_add + dropout
FEEDFORWARD_LINEAR_DROPOUT_NODE(1);
// Activation: gelu or relu
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
// Linear 2 and Dropout 2: matmul_v2 + elementwise_add + dropout
FEEDFORWARD_LINEAR_DROPOUT_NODE(2);
// ResidualAdd: elementwise_add
PATTERN_DECL_NODE(ele_add_op_3);
PATTERN_DECL_NODE(ele_add_out_3);
#undef FEEDFORWARD_LINEAR_DROPOUT_NODE
#endif
};
// The following patterns are used to fuse feedforward in backward
// 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
// activation_grad -> linear1_grad -> layer_norm_grad
// 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
// dropout1_grad -> activation_grad -> linear1_grad
// other cases: the residual_add_grad, dropout1_grad and dropout2_grad
// operators may be absent
struct FusedFeedForwardBwd : public PatternBase {
FusedFeedForwardBwd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_feedforward_bwd") {}
PDNode* operator()(PDNode* x,
std::unordered_set<std::string> act_grad_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2);
#ifndef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
#define FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(suffix__) \
PATTERN_DECL_NODE(matmul_op_grad_##suffix__); \
PATTERN_DECL_NODE(matmul_in_##suffix__); \
PATTERN_DECL_NODE(matmul_w_##suffix__); \
PATTERN_DECL_NODE(matmul_in_grad_##suffix__); \
PATTERN_DECL_NODE(matmul_w_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_op_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_in_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_##suffix__); \
PATTERN_DECL_NODE(ele_add_in_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_grad_##suffix__); \
PATTERN_DECL_NODE(dropout_op_grad_##suffix__); \
PATTERN_DECL_NODE(dropout_mask_##suffix__); \
PATTERN_DECL_NODE(dropout_in_grad_##suffix__);
// LayerNorm Grad: layer_norm_grad
PATTERN_DECL_NODE(layer_norm_op_grad);
PATTERN_DECL_NODE(layer_norm_in);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_in_grad);
PATTERN_DECL_NODE(layer_norm_scale_grad);
PATTERN_DECL_NODE(layer_norm_bias_grad);
// Model parallelism
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(c_allreduce_sum_op);
PATTERN_DECL_NODE(c_allreduce_sum_out);
// Linear 1 and Dropout 1: matmul_v2_grad + elementwise_add_grad +
// dropout_grad
FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(1);
// Activation Grad: gelu_grad or relu_grad
PATTERN_DECL_NODE(act_op_grad);
PATTERN_DECL_NODE(act_in);
PATTERN_DECL_NODE(act_in_grad);
// Linear 2 and Dropout 2: matmul_v2_grad + elementwise_add_grad +
// dropout_grad
FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(2);
// Residual Add: elementwise_add_grad + sum
PATTERN_DECL_NODE(ele_add_op_grad_3);
PATTERN_DECL_NODE(ele_add_in_3);
PATTERN_DECL_NODE(ele_add_bias_3);
PATTERN_DECL_NODE(ele_add_in_grad_3);
PATTERN_DECL_NODE(ele_add_bias_grad_3);
PATTERN_DECL_NODE(sum_op);
PATTERN_DECL_NODE(sum_out);
#undef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
#endif
};
} // namespace patterns
// Link two ir::Nodes to each other.
......
......@@ -723,6 +723,32 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
build_strategy = static.BuildStrategy()
build_strategy.fused_attention = True
)DOC")
.def_property(
"fused_feedforward",
[](const BuildStrategy &self) { return self.fused_feedforward_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(),
true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fused_feedforward_ = b;
},
R"DOC((bool, optional): fused_feedforward indicate whether
to fuse the whole feed_forward part with one op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fused_feedforward = True
)DOC")
.def_property(
"fuse_bn_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
......
......@@ -84,6 +84,19 @@ class FusedAttentionPass(CPPPassWrapper):
return PassType.FUSION_OPT
@register_pass("fused_feedforward")
class FusedFeedforwardPass(CPPPassWrapper):
def __init__(self):
super().__init__()
@property
def cpp_name(self):
return "fused_feedforward_pass"
def _type(self):
return PassType.FUSION_OPT
@register_pass("fuse_gemm_epilogue")
class FuseGemmEpiloguePass(CPPPassWrapper):
def __init__(self):
......
......@@ -253,6 +253,7 @@ PassBase._PASS_PROCESS_ORDER_LIST = [
"fuse_bn_add_act",
"fuse_bn_act",
"fused_attention",
"fused_feedforward",
"fuse_gemm_epilogue",
"fuse_optimizer",
]
......
......@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
endif()
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
......
......@@ -89,6 +89,7 @@ class TestFusedPassBaseList(unittest.TestCase):
[
"fuse_bn_act",
"fused_attention",
"fused_feedforward",
"fuse_optimizer",
"fuse_gemm_epilogue",
"fuse_bn_add_act",
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
class FeedForward(nn.Layer):
def __init__(
self,
in_features,
hidden_features,
out_features,
drop_prob=0.1,
act_layer=nn.GELU,
pre_layer_norm=True,
add_residual=True,
use_dropout_1=True,
use_dropout_2=True,
):
super(FeedForward, self).__init__()
self.in_features = in_features
self.hidden_features = hidden_features
self.out_features = out_features
self.pre_layer_norm = pre_layer_norm
self.add_residual = add_residual
self.use_dropout_1 = use_dropout_1
self.use_dropout_2 = use_dropout_2
self.fc1 = nn.Linear(in_features, in_features)
self.fc2 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc3 = nn.Linear(hidden_features, out_features)
self.drop1 = nn.Dropout(drop_prob)
self.drop2 = nn.Dropout(drop_prob)
self.norm = nn.LayerNorm(in_features, epsilon=1e-5)
self.fc4 = nn.Linear(out_features, out_features)
def forward(self, x):
x = self.fc1(x)
residual = x
if self.pre_layer_norm:
x = self.norm(x)
x = self.fc2(x)
x = self.act(x)
if self.use_dropout_1:
x = self.drop1(x)
x = self.fc3(x)
if self.use_dropout_2:
x = self.drop2(x)
if self.add_residual:
x += residual
if not self.pre_layer_norm:
x = self.norm(x)
x = self.fc4(x)
return x
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedFeedforwardPass(unittest.TestCase):
def setUp(self):
self.pre_layer_norm = True
self.add_residual = True
self.use_dropout_1 = True
self.use_dropout_2 = True
def get_value(self, use_pass=False):
batch_size = 2
in_features = 768
hidden_features = 3072
out_features = 768
act_layer = nn.GELU
pre_layer_norm = self.pre_layer_norm
add_residual = self.add_residual
use_dropout_1 = self.use_dropout_1
use_dropout_2 = self.use_dropout_2
np.random.seed(1234)
x_data = np.random.rand(batch_size, in_features, in_features).astype(
'float32'
)
main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
startup_prog.random_seed = 1234
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[2, in_features, in_features],
dtype='float32',
)
feed_forward = FeedForward(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
drop_prob=1e-10,
act_layer=act_layer,
pre_layer_norm=pre_layer_norm,
add_residual=add_residual,
use_dropout_1=use_dropout_1,
use_dropout_2=use_dropout_2,
)
out = feed_forward(data)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
if use_pass:
pass_manager = PassManager([new_pass("fused_feedforward")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert 'fused_feedforward' in [op.type for op in ops]
assert 'fused_feedforward_grad' in [op.type for op in ops]
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
for i in range(2):
ret_loss = exe.run(
main_prog, feed={"x": x_data}, fetch_list=[loss.name]
)
return ret_loss
def test_pass(self):
for pre_layer_norm in [True, False]:
for add_residual in [True, False]:
for use_dropout_1 in [True, False]:
for use_dropout_2 in [True, False]:
if not pre_layer_norm and not add_residual:
continue
if not use_dropout_1 and not use_dropout_2:
continue
self.pre_layer_norm = pre_layer_norm
self.add_residual = add_residual
self.use_dropout_1 = use_dropout_1
self.use_dropout_2 = use_dropout_2
ret_loss = self.get_value()
ret_loss_fused = self.get_value(use_pass=True)
assert np.allclose(ret_loss, ret_loss_fused)
if __name__ == "__main__":
unittest.main()