Unverified commit 5dda0ef6, authored by Ghost Screaming, committed by GitHub

Add fused_feed_forward pass (#50423)

* Add fused_feed_forward pass for semi-automatic static graph training.

* Add fused_feedforward property in parallel_executor.cc

* Polish code.

* Polish fused feed_forward pass code. Support the use_dropout1 and
use_dropout2 options.

* Support model parallel in fused_feedforward pass.
Parent commit 02296977
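The diffs below add a fused_feedforward switch to BuildStrategy and register a fused_feedforward pass with the distributed pass framework. As a rough sketch of how the two entry points added here would be used (assuming a CUDA build of Paddle, since the C++ pass is only registered under PADDLE_WITH_CUDA; the program construction is illustrative, and the unit test at the end of this page exercises the real flow):

import paddle
import paddle.static as static
from paddle.distributed.passes import PassManager, new_pass

paddle.enable_static()

# Route 1: let the executor apply fused_feedforward_pass by flipping the
# BuildStrategy property introduced in this PR.
build_strategy = static.BuildStrategy()
build_strategy.fused_feedforward = True

# Route 2: apply the registered pass to existing programs directly.
# main_prog / startup_prog are placeholders here; they would normally
# already contain an unfused feed-forward block such as the one built
# in the unit test below.
main_prog, startup_prog = static.Program(), static.Program()
pass_manager = PassManager([new_pass("fused_feedforward")])
pass_manager.apply([main_prog], [startup_prog])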
@@ -383,6 +383,7 @@ set(IR_PASS_DEPS
fix_op_run_order_pass
fuse_gemm_epilogue_pass
fused_attention_pass
fused_feedforward_pass
delete_dropout_op_pass)
if(WITH_CINN)
......
@@ -210,6 +210,9 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPass("fuse_sgd_op_pass");
AppendPass("fuse_momentum_op_pass");
}
#ifdef PADDLE_WITH_CUDA
AppendPassWithCheck(strategy_.fused_feedforward_, "fused_feedforward_pass");
#endif
}
void SetCollectiveContext() const {
@@ -529,6 +532,9 @@ USE_PASS(fused_attention_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
#ifdef PADDLE_WITH_CUDA
USE_PASS(fused_feedforward_pass);
#endif
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
@@ -131,6 +131,8 @@ struct BuildStrategy {
bool fuse_gemm_epilogue_{false};
// Fused multi head attention
bool fused_attention_{false};
// Fused feed forward
bool fused_feedforward_{false};
// mkldnn_enabled_op_types specify the operator type list to
// use MKLDNN acceleration. It is null in default, means
@@ -264,6 +266,7 @@ inline std::ostream &operator<<(std::ostream &os,
os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
os << "fused_attention_: " << strategy.fused_attention_ << std::endl;
os << "fused_feedforward_: " << strategy.fused_feedforward_ << std::endl;
os << "mkldnn_enabled_op_types_: ";
for (auto str : strategy.mkldnn_enabled_op_types_) {
os << str << ", ";
......
@@ -126,6 +126,7 @@ message BuildStrategy {
optional bool fuse_gemm_epilogue = 16 [ default = false ];
optional string debug_graphviz_path = 17;
optional bool fused_attention = 18 [ default = false ];
optional bool fused_feedforward = 19 [ default = false ];
}
message ExecutionStrategy {
......
@@ -264,6 +264,10 @@ cc_library(
fuse_relu_depthwise_conv_pass
SRCS fuse_relu_depthwise_conv_pass.cc
DEPS pass graph_pattern_detector)
cc_library(
fused_feedforward_pass
SRCS fused_feedforward_pass.cc
DEPS pass graph_pattern_detector)
set(GLOB_PASS_LIB
${INFER_IR_PASSES}
......
This diff is collapsed.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex>
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the FeedForward part of a Transformer layer
* Forward:
* 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
* -> residual_add (pre_layer_norm)
* 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
* -> layer_norm (post_layer_norm)
* other cases: the mp, residual_add, dropout1 and dropout2 operators may be absent
* Backward:
* 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
* activation_grad -> linear1_grad -> layer_norm_grad (pre_layer_norm)
* 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
* dropout1_grad -> activation_grad -> linear1_grad (post_layer_norm)
* other cases: the mp, residual_add_grad, dropout1_grad and dropout2_grad
* operators may be absent
*/
class Graph;
class Node;
class FusedFeedForwardPass : public FusePassBase {
public:
virtual ~FusedFeedForwardPass() {}
protected:
// Used to transfer the dropout output/mask variable nodes matched by the
// forward pattern to the corresponding backward pattern.
struct DropoutNode {
Node *dropout_out_node_1;
Node *dropout_mask_node_1;
Node *dropout_out_node_2;
Node *dropout_mask_node_2;
DropoutNode()
: dropout_out_node_1(nullptr),
dropout_mask_node_1(nullptr),
dropout_out_node_2(nullptr),
dropout_mask_node_2(nullptr) {}
};
typedef std::unordered_map<Node *, DropoutNode> Cache;
const std::string scope_name{"fused_feedforward"};
void ApplyImpl(ir::Graph *graph) const override;
ir::Graph *FusedFeedForwardFwd(ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const;
ir::Graph *FusedFeedForwardBwd(ir::Graph *graph,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2,
Cache *dropout_nodes_map) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
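For orientation, the pre_layer_norm forward chain described in the header comment above corresponds to an unfused block along the following lines. This is a minimal sketch: the class name and layer sizes are illustrative, and the unit test at the end of this page builds a fuller variant with the optional residual/dropout switches.

import paddle.nn as nn


class UnfusedFeedForward(nn.Layer):
    def __init__(self, d_model=768, d_hidden=3072, drop_prob=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(drop_prob)
        self.linear2 = nn.Linear(d_hidden, d_model)
        self.drop2 = nn.Dropout(drop_prob)

    def forward(self, x):
        residual = x
        x = self.norm(x)     # layer_norm (pre_layer_norm)
        x = self.linear1(x)  # linear1
        x = self.act(x)      # activation (gelu)
        x = self.drop1(x)    # dropout1
        x = self.linear2(x)  # linear2
        x = self.drop2(x)    # dropout2
        return x + residual  # residual_add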
@@ -2156,6 +2156,133 @@ struct AddSupportInt8 : public PatternBase {
PATTERN_DECL_NODE(quant_out);
};
// The following patterns are used to fuse feedforward in forward
// 1. layer_norm -> linear1 -> activation -> dropout1 -> linear2 -> dropout2
// -> residual_add (pre_layer_norm)
// 2. linear1 -> activation -> dropout1 -> linear2 -> dropout2 -> residual_add
// -> layer_norm (post_layer_norm)
// other cases: the residual_add, dropout1 and dropout2 operators may be absent
struct FusedFeedForwardFwd : public PatternBase {
FusedFeedForwardFwd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_feedforward_fwd") {}
PDNode* operator()(PDNode* x,
std::unordered_set<std::string> act_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2);
#ifndef FEEDFORWARD_LINEAR_DROPOUT_NODE
#define FEEDFORWARD_LINEAR_DROPOUT_NODE(suffix__) \
PATTERN_DECL_NODE(matmul_op_##suffix__); \
PATTERN_DECL_NODE(matmul_w_##suffix__); \
PATTERN_DECL_NODE(matmul_out_##suffix__); \
PATTERN_DECL_NODE(ele_add_op_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_##suffix__); \
PATTERN_DECL_NODE(ele_add_out_##suffix__); \
PATTERN_DECL_NODE(dropout_op_##suffix__); \
PATTERN_DECL_NODE(dropout_out_##suffix__); \
PATTERN_DECL_NODE(dropout_mask_##suffix__);
// LayerNorm: layer_norm
PATTERN_DECL_NODE(layer_norm_op);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
// Model parallelism
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(c_allreduce_sum_op);
PATTERN_DECL_NODE(c_allreduce_sum_out);
// Linear 1 and Dropout 1: matmul_v2 + elementwise_add + dropout
FEEDFORWARD_LINEAR_DROPOUT_NODE(1);
// Activation: gelu or relu
PATTERN_DECL_NODE(act_op);
PATTERN_DECL_NODE(act_out);
// Linear 2 and Dropout 2: matmul_v2 + elementwise_add + dropout
FEEDFORWARD_LINEAR_DROPOUT_NODE(2);
// ResidualAdd: elementwise_add
PATTERN_DECL_NODE(ele_add_op_3);
PATTERN_DECL_NODE(ele_add_out_3);
#undef FEEDFORWARD_LINEAR_DROPOUT_NODE
#endif
};
// The following patterns are used to fuse feedforward in backward
// 1. residual_add_grad -> dropout2_grad -> linear2_grad -> dropout1_grad ->
// activation_grad -> linear1_grad -> layer_norm_grad
// 2. layer_norm_grad -> residual_add_grad -> dropout2_grad -> linear2_grad ->
// dropout1_grad -> activation_grad -> linear1_grad
// other cases: the residual_add_grad, dropout1_grad and dropout2_grad
// operators may be absent
struct FusedFeedForwardBwd : public PatternBase {
FusedFeedForwardBwd(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fused_feedforward_bwd") {}
PDNode* operator()(PDNode* x,
std::unordered_set<std::string> act_grad_types,
bool use_mp,
bool pre_layer_norm,
bool add_residual,
bool use_dropout_1,
bool use_dropout_2);
#ifndef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
#define FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(suffix__) \
PATTERN_DECL_NODE(matmul_op_grad_##suffix__); \
PATTERN_DECL_NODE(matmul_in_##suffix__); \
PATTERN_DECL_NODE(matmul_w_##suffix__); \
PATTERN_DECL_NODE(matmul_in_grad_##suffix__); \
PATTERN_DECL_NODE(matmul_w_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_op_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_in_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_##suffix__); \
PATTERN_DECL_NODE(ele_add_in_grad_##suffix__); \
PATTERN_DECL_NODE(ele_add_bias_grad_##suffix__); \
PATTERN_DECL_NODE(dropout_op_grad_##suffix__); \
PATTERN_DECL_NODE(dropout_mask_##suffix__); \
PATTERN_DECL_NODE(dropout_in_grad_##suffix__);
// LayerNorm Grad: layer_norm_grad
PATTERN_DECL_NODE(layer_norm_op_grad);
PATTERN_DECL_NODE(layer_norm_in);
PATTERN_DECL_NODE(layer_norm_mean);
PATTERN_DECL_NODE(layer_norm_variance);
PATTERN_DECL_NODE(layer_norm_scale);
PATTERN_DECL_NODE(layer_norm_bias);
PATTERN_DECL_NODE(layer_norm_in_grad);
PATTERN_DECL_NODE(layer_norm_scale_grad);
PATTERN_DECL_NODE(layer_norm_bias_grad);
// Model parallelism
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_out);
PATTERN_DECL_NODE(c_allreduce_sum_op);
PATTERN_DECL_NODE(c_allreduce_sum_out);
// Linear 1 and Dropout 1: matmul_v2_grad + elementwise_add_grad +
// dropout_grad
FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(1);
// Activation Grad: gelu_grad or relu_grad
PATTERN_DECL_NODE(act_op_grad);
PATTERN_DECL_NODE(act_in);
PATTERN_DECL_NODE(act_in_grad);
// Linear 2 and Dropout 2: matmul_v2_grad + elementwise_add_grad +
// dropout_grad
FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE(2);
// Residual Add: elementwise_add_grad
PATTERN_DECL_NODE(ele_add_op_grad_3);
PATTERN_DECL_NODE(ele_add_in_3);
PATTERN_DECL_NODE(ele_add_bias_3);
PATTERN_DECL_NODE(ele_add_in_grad_3);
PATTERN_DECL_NODE(ele_add_bias_grad_3);
PATTERN_DECL_NODE(sum_op);
PATTERN_DECL_NODE(sum_out);
#undef FEEDFORWARD_LINEAR_DROPOUT_GRAD_NODE
#endif
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
@@ -723,6 +723,32 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
build_strategy = static.BuildStrategy()
build_strategy.fused_attention = True
)DOC")
.def_property(
"fused_feedforward",
[](const BuildStrategy &self) { return self.fused_feedforward_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(),
true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, cannot be "
"configured again."));
self.fused_feedforward_ = b;
},
R"DOC((bool, optional): fused_feedforward indicate whether
to fuse the whole feed_forward part with one op,
it may make the execution faster. Default is False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.fused_feedforward = True
)DOC")
.def_property(
"fuse_bn_act_ops",
[](const BuildStrategy &self) { return self.fuse_bn_act_ops_; },
......
@@ -84,6 +84,19 @@ class FusedAttentionPass(CPPPassWrapper):
return PassType.FUSION_OPT
@register_pass("fused_feedforward")
class FusedFeedforwardPass(CPPPassWrapper):
def __init__(self):
super().__init__()
@property
def cpp_name(self):
return "fused_feedforward_pass"
def _type(self):
return PassType.FUSION_OPT
@register_pass("fuse_gemm_epilogue") @register_pass("fuse_gemm_epilogue")
class FuseGemmEpiloguePass(CPPPassWrapper): class FuseGemmEpiloguePass(CPPPassWrapper):
def __init__(self): def __init__(self):
......
@@ -253,6 +253,7 @@ PassBase._PASS_PROCESS_ORDER_LIST = [
"fuse_bn_add_act",
"fuse_bn_act",
"fused_attention",
"fused_feedforward",
"fuse_gemm_epilogue",
"fuse_optimizer",
]
......
@@ -76,6 +76,7 @@ if(NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op)
list(REMOVE_ITEM TEST_OPS test_fused_bias_dropout_residual_layer_norm_op_api)
list(REMOVE_ITEM TEST_OPS test_fused_attention_pass)
list(REMOVE_ITEM TEST_OPS test_fused_feedforward_pass)
endif()
list(REMOVE_ITEM TEST_OPS test_fused_ec_moe_op)
......
@@ -89,6 +89,7 @@ class TestFusedPassBaseList(unittest.TestCase):
[
"fuse_bn_act",
"fused_attention",
"fused_feedforward",
"fuse_optimizer",
"fuse_gemm_epilogue",
"fuse_bn_add_act",
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.nn as nn
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
class FeedForward(nn.Layer):
def __init__(
self,
in_features,
hidden_features,
out_features,
drop_prob=0.1,
act_layer=nn.GELU,
pre_layer_norm=True,
add_residual=True,
use_dropout_1=True,
use_dropout_2=True,
):
super(FeedForward, self).__init__()
self.in_features = in_features
self.hidden_features = hidden_features
self.out_features = out_features
self.pre_layer_norm = pre_layer_norm
self.add_residual = add_residual
self.use_dropout_1 = use_dropout_1
self.use_dropout_2 = use_dropout_2
self.fc1 = nn.Linear(in_features, in_features)
self.fc2 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc3 = nn.Linear(hidden_features, out_features)
self.drop1 = nn.Dropout(drop_prob)
self.drop2 = nn.Dropout(drop_prob)
self.norm = nn.LayerNorm(in_features, epsilon=1e-5)
self.fc4 = nn.Linear(out_features, out_features)
def forward(self, x):
x = self.fc1(x)
residual = x
if self.pre_layer_norm:
x = self.norm(x)
x = self.fc2(x)
x = self.act(x)
if self.use_dropout_1:
x = self.drop1(x)
x = self.fc3(x)
if self.use_dropout_2:
x = self.drop2(x)
if self.add_residual:
x += residual
if not self.pre_layer_norm:
x = self.norm(x)
x = self.fc4(x)
return x
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestFusedFeedforwardPass(unittest.TestCase):
def setUp(self):
self.pre_layer_norm = True
self.add_residual = True
self.use_dropout_1 = True
self.use_dropout_2 = True
def get_value(self, use_pass=False):
batch_size = 2
in_features = 768
hidden_features = 3072
out_features = 768
act_layer = nn.GELU
pre_layer_norm = self.pre_layer_norm
add_residual = self.add_residual
use_dropout_1 = self.use_dropout_1
use_dropout_2 = self.use_dropout_2
np.random.seed(1234)
x_data = np.random.rand(batch_size, in_features, in_features).astype(
'float32'
)
main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
startup_prog.random_seed = 1234
with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
name="x",
shape=[batch_size, in_features, in_features],
dtype='float32',
)
feed_forward = FeedForward(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
drop_prob=1e-10,
act_layer=act_layer,
pre_layer_norm=pre_layer_norm,
add_residual=add_residual,
use_dropout_1=use_dropout_1,
use_dropout_2=use_dropout_2,
)
out = feed_forward(data)
loss = paddle.mean(out)
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)
if use_pass:
pass_manager = PassManager([new_pass("fused_feedforward")])
pass_manager.apply([main_prog], [startup_prog])
ops = main_prog.global_block().ops
assert 'fused_feedforward' in [op.type for op in ops]
assert 'fused_feedforward_grad' in [op.type for op in ops]
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
for i in range(2):
ret_loss = exe.run(
main_prog, feed={"x": x_data}, fetch_list=[loss.name]
)
return ret_loss
def test_pass(self):
for pre_layer_norm in [True, False]:
for add_residual in [True, False]:
for use_dropout_1 in [True, False]:
for use_dropout_2 in [True, False]:
if not pre_layer_norm and not add_residual:
continue
if not use_dropout_1 and not use_dropout_2:
continue
self.pre_layer_norm = pre_layer_norm
self.add_residual = add_residual
self.use_dropout_1 = use_dropout_1
self.use_dropout_2 = use_dropout_2
ret_loss = self.get_value()
ret_loss_fused = self.get_value(use_pass=True)
assert np.allclose(ret_loss, ret_loss_fused)
if __name__ == "__main__":
unittest.main()