Unverified commit 1e3245a8, authored by RichardWooSJTU, committed by GitHub

Fuse multi transformer layer pass (#47541)

* add fuse_multi_transformer_layer_pass
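
In short (as implemented in the code below): this pass runs after the fused_multi_transformer encoder/decoder fusion passes and collapses the chain of per-layer fused_multi_transformer (or fused_multi_transformer_int8) ops they produce into a single op. The per-layer weight/bias/CacheKV inputs are concatenated onto the first op, the CacheKVOut lists are merged, and the remaining ops together with the intermediate Out nodes are removed. The number of layers to fuse is read from the kFusedMultiTransformerEncoderFusionCount / kFusedMultiTransformerDecoderFusionCount values that the encoder/decoder passes now store on the graph. Roughly:

    before: x -> fused_multi_transformer(layer 0) -> ... -> fused_multi_transformer(layer N-1) -> out
    after:  x -> fused_multi_transformer(layers 0..N-1, inputs merged) -> out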
Parent 3addd568
@@ -107,6 +107,7 @@ pass_library(skip_layernorm_fuse_pass base)
pass_library(multihead_matmul_fuse_pass inference)
pass_library(fused_multi_transformer_encoder_pass inference)
pass_library(fused_multi_transformer_decoder_pass inference)
pass_library(fuse_multi_transformer_layer_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(yolo_box_fuse_pass inference)
@@ -326,6 +327,10 @@ cc_test(
test_fused_multi_transformer_decoder_pass
SRCS fused_multi_transformer_decoder_pass_tester.cc
DEPS fused_multi_transformer_decoder_pass)
cc_test(
test_fuse_multi_transformer_layer_pass
SRCS fuse_multi_transformer_layer_pass_tester.cc
DEPS fuse_multi_transformer_layer_pass)
cc_test(
test_conv_bn_fuse_pass_cc
SRCS conv_bn_fuse_pass_tester.cc
...
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h"
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
std::unordered_map<std::string, std::string>
MultiTransformerLayerPattern::operator()(bool enable_int8,
int num_fused_op,
bool is_decoder) {
std::string fused_multi_transformer_name =
enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer";
std::unordered_map<std::string, std::string> node_reprs;
// x0 and src_mask are the unique (shared) inputs of the subgraph
auto* x0 = pattern->NewNode(x0_repr());
x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput();
auto* src_mask = pattern->NewNode(src_mask_repr());
src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask")
->AsInput();
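// Chain num_fused_op fused_multi_transformer ops: the Out of op i is reused
// as the X of op i+1. For decoders, each op is additionally fed by a
// shape -> slice chain over src_mask (its TimeStep input); for encoders,
// each op's CacheKV comes from a fill_constant_batch_size_like op fed by
// that op's X.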
for (int i = 0; i < num_fused_op; ++i) {
auto fuse_op_repr =
PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i));
node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr;
auto* fused_multi_transformer =
pattern->NewNode(fuse_op_repr)
->assert_is_op(fused_multi_transformer_name);
auto out_repr =
PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i));
node_reprs["out_" + std::to_string(i)] = out_repr;
auto* out = pattern->NewNode(out_repr)->assert_is_op_output(
fused_multi_transformer_name, "Out");
if (is_decoder) {
auto shape_repr =
PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
node_reprs["shape_" + std::to_string(i)] = shape_repr;
auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape");
auto shape_out_repr =
PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i));
node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr;
auto* shape_out =
pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out");
shape->LinksFrom({src_mask}).LinksTo({shape_out});
auto slice_repr =
PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
node_reprs["slice_" + std::to_string(i)] = slice_repr;
auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice");
auto slice_out_repr =
PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i));
node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr;
auto* slice_out =
pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out");
slice->LinksFrom({shape_out}).LinksTo({slice_out});
fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
.LinksTo({out});
} else {
auto cache_kv_repr =
PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr;
auto* cache_kv = pattern->NewNode(cache_kv_repr);
cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV");
cache_kv->AsInput();
auto fill_const_op_repr =
PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i));
node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr;
auto fill_const_op = pattern->NewNode(fill_const_op_repr)
->assert_is_op("fill_constant_batch_size_like");
fused_multi_transformer->LinksFrom({x0, src_mask, cache_kv})
.LinksTo({out});
fill_const_op->LinksFrom({x0}).LinksTo({cache_kv});
}
x0 = out;
}
x0->AsOutput();
return node_reprs;
}
} // namespace patterns
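// Concatenates, in layer order, the variable names bound to input_name across
// all matched fused_multi_transformer ops and assigns the result to op's
// input_name slot, so the first op ends up carrying every layer's parameters.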
inline void MergeInput(OpDesc* op,
const std::vector<VariableNameMap>& input_name_maps,
const std::string& input_name) {
std::vector<std::string> tmp = input_name_maps[0].at(input_name);
for (size_t i = 1; i < input_name_maps.size(); ++i) {
tmp.insert(tmp.end(),
input_name_maps[i].at(input_name).begin(),
input_name_maps[i].at(input_name).end());
}
op->SetInput(input_name, tmp);
}
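// Concatenates a std::vector<T> attribute (attr_name) across all given ops
// and stores the merged vector on ops[0].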
template <typename T>
inline void MergeAttrs(const std::vector<OpDesc*>& ops,
const std::string& attr_name) {
std::vector<T> res;
for (size_t i = 0; i < ops.size(); ++i) {
auto scale_vec =
PADDLE_GET_CONST(std::vector<T>, ops[i]->GetAttr(attr_name));
res.insert(res.end(), scale_vec.begin(), scale_vec.end());
}
ops[0]->SetAttr(attr_name, res);
}
int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// TODO(wufeisheng): Get the enable_int8 attr from the graph once the
// fused_multi_transformer int8 pass has been merged
bool enable_int8 = false;
int num_fuse_op = 0;
bool is_decoder = false;
if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) {
num_fuse_op = graph->Get<int>(kFusedMultiTransformerEncoderFusionCount);
is_decoder = false;
} else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) {
num_fuse_op = graph->Get<int>(kFusedMultiTransformerDecoderFusionCount);
is_decoder = true;
}
if (num_fuse_op == 0) {
VLOG(4) << "fuse_multi_transformer_layer_pass will be skipped "
"cause num_fuse_op is not been set or set to 0";
return 0;
}
if (!is_decoder) {
VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern";
} else {
VLOG(4) << "fuse_multi_transformer_layer_pass will match decoder pattern";
}
patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern,
name_scope);
auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder);
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
///////////////////
//// Get nodes ////
///////////////////
GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern);
GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern);
std::vector<Node*> fuse_op_nodes;
std::vector<Node*> out_nodes;
std::vector<std::string> unused_node_prefixes = {
"shape_", "shape_out_", "slice_", "slice_out_"};
std::vector<Node*> unused_nodes;
std::vector<OpDesc*> fuse_op_descs;
std::vector<VariableNameMap> fuse_op_input_var_name_maps;
std::vector<VariableNameMap> fuse_op_output_var_name_maps;
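// Walk the matched ops layer by layer: record each fused_multi_transformer
// op, its Out node and its input/output name maps. For encoder layers after
// the first, rewire their fill_constant_batch_size_like op to take the
// subgraph input x0 instead of the previous layer's Out; for decoder layers
// after the first, collect the redundant shape/slice nodes for removal below.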
for (int i = 0; i < num_fuse_op; ++i) {
PDNode* fuse_op_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs["fuse_op_" + std::to_string(i)]);
Node* fuse_op_node = subgraph.at(fuse_op_pdnode);
fuse_op_nodes.push_back(fuse_op_node);
fuse_op_descs.push_back(fuse_op_node->Op());
fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs());
fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs());
PDNode* out_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs["out_" + std::to_string(i)]);
out_nodes.push_back(subgraph.at(out_pdnode));
// Make the fill_constant_batch_size_like op use x0 as its input
if (!is_decoder && i != 0) {
PDNode* fill_op_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs["fill_op_" + std::to_string(i)]);
Node* fill_op_node = subgraph.at(fill_op_pdnode);
fill_op_node->Op()->SetInput("Input", {x0->Name()});
IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node);
IR_NODE_LINK_TO(x0, fill_op_node);
} else if (is_decoder && i != 0) {
for (const auto& unused_node_prefix : unused_node_prefixes) {
PDNode* unused_pdnode =
multi_layer_pattern.PatternBase::pattern->RetrieveNode(
node_reprs[unused_node_prefix + std::to_string(i)]);
Node* unused_node = subgraph.at(unused_pdnode);
unused_nodes.push_back(unused_node);
}
}
}
///////////////
//// Merge ////
///////////////
// Merge inputs
std::vector<std::string> inputs_names = {"CacheKV",
"FFN1Bias",
"FFN1Weight",
"FFN2Bias",
"FFN2Weight",
"FFNLnBias",
"FFNLnScale",
"LnBias",
"LnScale",
"OutLinearBias",
"OutLinearW",
"QKVBias",
"QKVW"};
for (const auto& input_name : inputs_names) {
MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
}
// Merge outputs
fuse_op_descs[0]->SetOutput(
"Out", fuse_op_output_var_name_maps[num_fuse_op - 1]["Out"]);
auto& merged_cache_kv_out_names =
fuse_op_output_var_name_maps[0]["CacheKVOut"];
for (int i = 1; i < num_fuse_op; ++i) {
const auto& out_var_names = fuse_op_output_var_name_maps[i]["CacheKVOut"];
merged_cache_kv_out_names.insert(merged_cache_kv_out_names.end(),
out_var_names.begin(),
out_var_names.end());
}
fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);
////////////////
//// ReLink ////
////////////////
// Before relinking, remove the intermediate out nodes (all but the last one)
std::unordered_set<const Node*> marked_out_nodes(out_nodes.begin(),
out_nodes.end() - 1);
GraphSafeRemoveNodes(graph, marked_out_nodes);
// Relink all input nodes of fused_multi_transformer ops to the first op
auto& merged_inputs = fuse_op_nodes[0]->inputs;
for (int i = 1; i < num_fuse_op; ++i) {
merged_inputs.insert(merged_inputs.end(),
fuse_op_nodes[i]->inputs.begin(),
fuse_op_nodes[i]->inputs.end());
}
// Relink fuse op -> out
IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]);
IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]);
/////////////////////////////
//// Delete unused nodes ////
/////////////////////////////
// Delete all fused_multi_transformer ops except the first one
std::unordered_set<const Node*> marked_fuse_op_nodes(
fuse_op_nodes.begin() + 1, fuse_op_nodes.end());
if (is_decoder) {
marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
}
GraphSafeRemoveNodes(graph, marked_fuse_op_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
"the scope should not be null."));
int fusion_count = BuildFusion(graph, name_scope_, scope);
AddStatis(fusion_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fuse_multi_transformer_layer_pass,
paddle::framework::ir::FuseMultiTransformerLayerPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
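// Matches a chain of num_fused_op consecutive fused_multi_transformer ops that
// share x0 (the first op's X) and src_mask as subgraph inputs. Returns a map
// from logical node names ("fuse_op_i", "out_i", ...) to the PDNode reprs
// registered in the pattern, so the pass can retrieve the matched nodes.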
struct MultiTransformerLayerPattern : public PatternBase {
MultiTransformerLayerPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "fuse_multi_transformer_layer") {}
std::unordered_map<std::string, std::string> operator()(
bool enable_int8, int num_fused_op = 1, bool is_decoder = false);
PATTERN_DECL_NODE(src_mask);
PATTERN_DECL_NODE(x0);
};
} // namespace patterns
class FuseMultiTransformerLayerPass : public FusePassBase {
public:
FuseMultiTransformerLayerPass() {}
virtual ~FuseMultiTransformerLayerPass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"fuse_multi_transformer_layer"};
private:
int BuildFusion(Graph* graph,
const std::string& name_scope,
Scope* scope) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_version_registry.h"
#define DEF_INPUT_DATA \
Layers layers; \
int num_layers = 3; \
auto* x = layers.data("x", {1, 128, 1024}); \
auto* src_mask = layers.data("src_mask", {1, 16, 128, 128}); \
auto* ln_scale = layers.data("ln_scale", {1024}, true); \
auto* ln_bias = layers.data("ln_bias", {1024}, true); \
auto* ffn_ln_scale = layers.data("ffn_ln_scale", {1024}, true); \
auto* ffn_ln_bias = layers.data("ffn_ln_bias", {1024}, true); \
auto* qkv_w = layers.data("qkv_w", {3, 16, 64, 1024}, true); \
auto* out_linear_w = layers.data("out_linear_w", {1024, 1024}, true); \
auto* ffn1_w = layers.data("ffn1_w", {1024, 4096}, true); \
auto* ffn2_w = layers.data("ffn2_w", {4096, 1024}, true); \
auto* qkv_bias = layers.data("qkv_bias", {3072}, true); \
auto* out_linear_bias = layers.data("out_linear_bias", {1024}, true); \
auto* ffn1_bias = layers.data("ffn1_bias", {4096}, true); \
auto* ffn2_bias = layers.data("ffn2_bias", {1024}, true);
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "ln_scale", {1024});
AddVarToScope(param_scope, "ln_bias", {1024});
AddVarToScope(param_scope, "ffn_ln_scale", {1024});
AddVarToScope(param_scope, "ffn_ln_bias", {1024});
AddVarToScope(param_scope, "qkv_w", {3, 16, 64, 1024});
AddVarToScope(param_scope, "out_linear_w", {1024, 1024});
AddVarToScope(param_scope, "ffn1_w", {1024, 4096});
AddVarToScope(param_scope, "ffn2_w", {4096, 1024});
AddVarToScope(param_scope, "qkv_bias", {3072});
AddVarToScope(param_scope, "out_linear_bias", {1024});
AddVarToScope(param_scope, "ffn1_bias", {4096});
AddVarToScope(param_scope, "ffn2_bias", {1024});
return param_scope;
}
TEST(FuseMultiTransformerLayerPass, encoder_fp) {
DEF_INPUT_DATA
// Layers
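// Build num_layers consecutive fused_multi_transformer ops; each layer creates
// its own CacheKV from its input x via fill_constant_batch_size_like and feeds
// its Out into the next layer's X.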
for (int i = 0; i < num_layers; ++i) {
auto* cache_kv = layers.fill_constant_batch_size_like(
x,
static_cast<int>(proto::VarType::FP32),
0,
1,
{2, -1, 16, 1024, 64},
0);
auto* out = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12);
x = out;
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(num_layers));
auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fuse_multi_transformer_layer_pass failed";
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
TEST(FuseMultiTransformerLayerPass, decoder_fp) {
DEF_INPUT_DATA
x = layers.data("x", {1, 1, 1024});
auto* cache_kv = layers.data("cache_kv", {2, 1, 16, 1024, 64}, true);
src_mask = layers.data("src_mask", {1, 16, 1, 129});
// Layers
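// Build num_layers decoder-style layers: every layer shares the same cache_kv
// input and derives its TimeStep from slice(shape(src_mask), axes={0},
// starts={3}, ends={4}), i.e. the last dimension of src_mask.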
for (int i = 0; i < num_layers; ++i) {
auto* shape_out = layers.shape(src_mask);
auto* time_stamp = layers.slice(shape_out, {0}, {3}, {4});
auto* out = layers.fused_multi_transformer(x,
cache_kv,
src_mask,
qkv_w,
qkv_bias,
out_linear_w,
out_linear_bias,
ffn1_w,
ffn1_bias,
ffn2_w,
ffn2_bias,
ln_scale,
ln_bias,
ffn_ln_scale,
ffn_ln_bias,
0.1,
1e-12,
time_stamp);
x = out;
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto param_scope = CreateParamScope();
AddVarToScope(param_scope, "cache_kv", {2, 1, 16, 1024, 64});
graph->Set("__param_scope__", param_scope);
graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(num_layers));
auto pass = PassRegistry::Instance().Get("fuse_multi_transformer_layer_pass");
if (pass.get() == nullptr)
LOG(INFO) << "get fuse_multi_transformer_layer_pass failed";
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = GetNumOpNodes(graph, "fused_multi_transformer");
PADDLE_ENFORCE_EQ(
num_nodes_after,
1,
platform::errors::InvalidArgument(
"After the fuse_multi_transformer_layer_pass, "
"The node num in graph should be 1, but the result is %d",
num_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fuse_multi_transformer_layer_pass);
@@ -1565,6 +1565,7 @@ void FusedMultiTransformerDecoderPass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderPass, new bool(true));
graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count));
}
AddStatis(fusion_count);
}
@@ -2178,6 +2179,7 @@ void FusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count));
}
AddStatis(fusion_count);
}
@@ -2833,6 +2835,7 @@ void MultiDevicesFusedMultiTransformerDecoderFuseQKVPass::ApplyImpl(
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerDecoderFuseQKVPass, new bool(true));
graph->Set(kFusedMultiTransformerDecoderFusionCount, new int(fusion_count));
}
AddStatis(fusion_count);
}
...
@@ -1728,6 +1728,7 @@ void FusedMultiTransformerEncoderPass::ApplyImpl(Graph* graph) const {
int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderPass, new bool(true));
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
}
AddStatis(fusion_count);
}
int fusion_count = BuildFusion(graph, name_scope_, scope); int fusion_count = BuildFusion(graph, name_scope_, scope);
if (fusion_count > 0) { if (fusion_count > 0) {
graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true)); graph->Set(kFusedMultiTransformerEncoderFuseQKVPass, new bool(true));
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
} }
AddStatis(fusion_count); AddStatis(fusion_count);
} }
...@@ -3076,6 +3078,7 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl( ...@@ -3076,6 +3078,7 @@ void MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::ApplyImpl(
if (fusion_count > 0) { if (fusion_count > 0) {
graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass, graph->Set(kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass,
new bool(true)); new bool(true));
graph->Set(kFusedMultiTransformerEncoderFusionCount, new int(fusion_count));
} }
AddStatis(fusion_count); AddStatis(fusion_count);
} }
......
@@ -46,7 +46,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fused_multi_transformer_decoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass",
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
"fuse_multi_transformer_layer_pass"};
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
...
@@ -59,6 +59,10 @@ constexpr char kMultiDevicesFusedMultiTransformerEncoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass_flag";
constexpr char kMultiDevicesFusedMultiTransformerDecoderFuseQKVPass[] =
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass_flag";
constexpr char kFusedMultiTransformerEncoderFusionCount[] =
"fused_multi_transformer_encoder_fusion_count";
constexpr char kFusedMultiTransformerDecoderFusionCount[] =
"fused_multi_transformer_decoder_fusion_count";
constexpr char kPrelnEmbEltwiseLayernormPass[] =
"preln_embedding_eltwise_layernorm_fuse_pass_flag";
...
@@ -528,6 +528,119 @@ struct Layers {
return out;
}
VarDesc* shape(VarDesc* input) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("shape");
op->SetInput("Input", {input->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* slice(VarDesc* input,
std::vector<int> axes,
std::vector<int> starts,
std::vector<int> ends) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("slice");
op->SetInput("Input", {input->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("axes", axes);
op->SetAttr("starts", starts);
op->SetAttr("ends", ends);
return out;
}
VarDesc* fill_constant_batch_size_like(VarDesc* x,
int dtype,
int input_dim_idx,
int output_dim_idx,
std::vector<int> shape,
float value) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("fill_constant_batch_size_like");
op->SetInput("Input", {x->Name()});
op->SetAttr("dtype", dtype);
op->SetAttr("input_dim_idx", input_dim_idx);
op->SetAttr("output_dim_idx", output_dim_idx);
op->SetAttr("shape", shape);
op->SetAttr("value", value);
op->SetOutput("Out", {out->Name()});
return out;
}
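// Appends a fused_multi_transformer op (or fused_multi_transformer_int8 when
// the *_out_scale vars are supplied). TimeStep is only wired up when
// time_stamp is non-null, which is how the decoder tests differ from the
// encoder tests.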
VarDesc* fused_multi_transformer(VarDesc* x,
VarDesc* cache_kv,
VarDesc* src_mask,
VarDesc* qkv_w,
VarDesc* qkv_bias,
VarDesc* out_linear_w,
VarDesc* out_linear_bias,
VarDesc* ffn1_w,
VarDesc* ffn1_bias,
VarDesc* ffn2_w,
VarDesc* ffn2_bias,
VarDesc* ln_scale,
VarDesc* ln_bias,
VarDesc* ffn_ln_scale,
VarDesc* ffn_ln_bias,
float epsilon,
float dropout_rate,
VarDesc* time_stamp = nullptr,
VarDesc* qkv_out_scale = nullptr,
VarDesc* out_linear_out_scale = nullptr,
VarDesc* ffn1_out_scale = nullptr,
VarDesc* ffn2_out_scale = nullptr,
std::vector<float> qkv_in_scale = {},
std::vector<float> out_linear_in_scale = {},
std::vector<float> ffn1_in_scale = {},
std::vector<float> ffn2_in_scale = {}) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
std::string op_type = qkv_out_scale ? "fused_multi_transformer_int8"
: "fused_multi_transformer";
op->SetType(op_type);
op->SetInput("X", {x->Name()});
op->SetInput("CacheKV", {cache_kv->Name()});
op->SetInput("SrcMask", {src_mask->Name()});
op->SetInput("QKVW", {qkv_w->Name()});
op->SetInput("QKVBias", {qkv_bias->Name()});
op->SetInput("OutLinearW", {out_linear_w->Name()});
op->SetInput("OutLinearBias", {out_linear_bias->Name()});
op->SetInput("FFN1Weight", {ffn1_w->Name()});
op->SetInput("FFN1Bias", {ffn1_bias->Name()});
op->SetInput("FFN2Weight", {ffn2_w->Name()});
op->SetInput("FFN2Bias", {ffn2_bias->Name()});
op->SetInput("LnScale", {ln_scale->Name()});
op->SetInput("LnBias", {ln_bias->Name()});
op->SetInput("FFNLnScale", {ffn_ln_scale->Name()});
op->SetInput("FFNLnBias", {ffn_ln_bias->Name()});
op->SetAttr("pre_layer_norm", true);
op->SetAttr("is_test", true);
op->SetAttr("dropout_implementation", "upscale_in_train");
op->SetAttr("dropout_rate", dropout_rate);
op->SetAttr("epsilon", epsilon);
op->SetOutput("Out", {out->Name()});
if (time_stamp) {
op->SetInput("TimeStep", {time_stamp->Name()});
}
if (qkv_out_scale) {
op->SetInput("QKVOutScale", {qkv_out_scale->Name()});
op->SetInput("OutLinearOutScale", {out_linear_out_scale->Name()});
op->SetInput("FFN1OutScale", {ffn1_out_scale->Name()});
op->SetInput("FFN2OutScale", {ffn2_out_scale->Name()});
op->SetAttr("qkv_in_scale", qkv_in_scale);
op->SetAttr("out_linear_in_scale", out_linear_in_scale);
op->SetAttr("ffn1_in_scale", ffn1_in_scale);
op->SetAttr("ffn2_in_scale", ffn2_in_scale);
}
return out;
}
void backward(std::vector<VarDesc*> targets) {
// This function is designed to simulate the structure of training program,
// but is constructed differently as the actual program.
...
@@ -212,6 +212,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"fused_multi_transformer_decoder_fuse_qkv_pass",               //
"multi_devices_fused_multi_transformer_encoder_fuse_qkv_pass", //
"multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", //
"fuse_multi_transformer_layer_pass",                           //
"gpu_cpu_squeeze2_matmul_fuse_pass",                           //
"gpu_cpu_reshape2_matmul_fuse_pass",                           //
"gpu_cpu_flatten2_matmul_fuse_pass",                           //
...