Unverified commit bfa5d6b8, authored by zhupengyang, committed by GitHub

transform cachekv datalayout of fused_multi_transformer_xpu (#53144)

Parent ae426b78
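
For orientation, a minimal sketch (not part of the commit) of the cache-KV layout change this pass encodes: the dim order [d0,d1,d2,d3,d4] becomes [d0,d3,d1,d2,d4], which per the updated checks in FusedMultiTransformerXpuInferMeta further down corresponds to [2, batch, num_head, max_seq_len, head_dim] -> [2, max_seq_len, batch, num_head, head_dim]. The helper name below is hypothetical.

#include <array>
#include <cstdint>

// Illustrative only: apply the same 5-D permutation that the pass applies to
// fill_constant's ShapeTensorList and to the gather in/out shapes.
std::array<int64_t, 5> TransCacheKVDims(const std::array<int64_t, 5>& d) {
  // [2, batch, num_head, max_seq_len, head_dim]
  //   -> [2, max_seq_len, batch, num_head, head_dim]
  return {d[0], d[3], d[1], d[2], d[4]};
}
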
......@@ -245,6 +245,8 @@ if(WITH_XPU)
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_cachekv_layout_trans_pass inference DIR
xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -528,4 +530,8 @@ if(WITH_XPU)
test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass)
cc_test(
test_fused_multi_transformer_cachekv_layout_trans_pass
SRCS xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc
DEPS fused_multi_transformer_cachekv_layout_trans_pass)
endif()
......@@ -62,6 +62,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_cachekv_layout_trans_pass",
"one_beam_size_fuse_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct FusedMultiTransformerFillConstantPattern : public PatternBase {
FusedMultiTransformerFillConstantPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(fill_constant);
PATTERN_DECL_NODE(fused_multi_transformer);
// declare variable node's name
PATTERN_DECL_NODE(fill_constant_out);
}; // struct FusedMultiTransformerFillConstantPattern
FusedMultiTransformerFillConstantPattern::
FusedMultiTransformerFillConstantPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
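  // Match only fill_constant ops not yet tagged with friendly_device_type ==
  // "xpu"; the handler sets that attr after rewriting, so each op is
  // transformed at most once.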
auto* fill_constant = pattern->NewNode(fill_constant_repr())
->assert_is_op("fill_constant")
->assert_has_n_inputs(5)
->assert_more([](Node* node) {
return node->Op()->GetAttrIfExists<std::string>(
"friendly_device_type") != "xpu";
});
auto* fill_constant_out = pattern->NewNode(fill_constant_out_repr())
->assert_is_op_output("fill_constant", "Out");
auto* fused_multi_transformer =
pattern->NewNode(fused_multi_transformer_repr())
->assert_is_op("fused_multi_transformer");
fill_constant->LinksTo({fill_constant_out});
fused_multi_transformer->LinksFrom({fill_constant_out});
}
struct FusedMultiTransformerGatherPattern : public PatternBase {
FusedMultiTransformerGatherPattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(fused_multi_transformer);
PATTERN_DECL_NODE(gather);
// declare variable node's name
PATTERN_DECL_NODE(gather_in);
PATTERN_DECL_NODE(gather_out);
}; // struct FusedMultiTransformerGatherPattern
FusedMultiTransformerGatherPattern::FusedMultiTransformerGatherPattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* gather_in =
pattern->NewNode(gather_in_repr())->assert_is_op_input("gather", "X");
auto* gather = pattern->NewNode(gather_repr())
->assert_is_op("gather")
->assert_more([](Node* node) {
return node->Op()->GetAttrIfExists<int>("axis") == 1;
});
auto* gather_out =
pattern->NewNode(gather_out_repr())->assert_is_op_output("gather", "Out");
auto* fused_multi_transformer =
pattern->NewNode(fused_multi_transformer_repr())
->assert_is_op("fused_multi_transformer");
gather->LinksFrom({gather_in}).LinksTo({gather_out});
fused_multi_transformer->LinksFrom({gather_out});
}
} // namespace patterns
void FusedMultiTransformerCacheKVLayoutTransPass::FillConstantReshapePass(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd;
patterns::FusedMultiTransformerFillConstantPattern pattern(
gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FillConstantReshapePass";
GET_IR_NODE(fused_multi_transformer);
GET_IR_NODE(fill_constant);
GET_IR_NODE(fill_constant_out);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
fill_constant_out->Name()) == 0)
return;
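    // Reorder the ShapeTensorList inputs [d0,d1,d2,d3,d4] -> [d0,d3,d1,d2,d4]
    // so the fill_constant output feeding CacheKV uses the transposed layout.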
auto fill_constant_input_names =
fill_constant->Op()->Input("ShapeTensorList");
auto fill_constant_trans_input_names =
std::vector<std::string>{fill_constant_input_names[0],
fill_constant_input_names[3],
fill_constant_input_names[1],
fill_constant_input_names[2],
fill_constant_input_names[4]};
fill_constant->Op()->SetInput("ShapeTensorList",
fill_constant_trans_input_names);
auto fill_constant_output_shape = fill_constant_out->Var()->GetShape();
fill_constant_out->Var()->SetShape({fill_constant_output_shape[0],
fill_constant_output_shape[3],
fill_constant_output_shape[1],
fill_constant_output_shape[2],
fill_constant_output_shape[4]});
fused_multi_transformer->Op()->SetAttr("friendly_device_type",
std::string("xpu"));
fill_constant->Op()->SetAttr("friendly_device_type", std::string("xpu"));
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd;
patterns::FusedMultiTransformerGatherPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle GatherReshapePass";
GET_IR_NODE(gather);
GET_IR_NODE(fused_multi_transformer);
GET_IR_NODE(gather_in);
GET_IR_NODE(gather_out);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
gather_out->Name()) == 0)
return;
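    // Under the transposed cache-KV layout the batch dimension moves from
    // axis 1 to axis 2, so permute the gather in/out shapes and update the
    // gather axis accordingly.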
auto gather_in_shape = gather_in->Var()->GetShape();
auto gather_out_shape = gather_out->Var()->GetShape();
gather_in->Var()->SetShape({gather_in_shape[0],
gather_in_shape[3],
gather_in_shape[1],
gather_in_shape[2],
gather_in_shape[4]});
gather_out->Var()->SetShape({gather_out_shape[0],
gather_out_shape[3],
gather_out_shape[1],
gather_out_shape[2],
gather_out_shape[4]});
gather->Op()->SetAttr("axis", 2);
fused_multi_transformer->Op()->SetAttr("friendly_device_type",
std::string("xpu"));
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FillConstantReshapePass(graph);
GatherReshapePass(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(
fused_multi_transformer_cachekv_layout_trans_pass,
paddle::framework::ir::FusedMultiTransformerCacheKVLayoutTransPass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
Origin subgraph:
(ShapeTensorList: [d0,d1,d2,d3,d4])
|
fill_constant
|
fused_multi_transformer
Fused subgraph:
(ShapeTensorList: [d0,d3,d1,d2,d4])
|
fill_constant
|
fused_multi_transformer
*/
void FillConstantReshapePass(ir::Graph* graph) const;
/*
Origin subgraph:
(gather_x: [d0,d1,d2,d3,d4])
|
gather(axis=1)
|
fused_multi_transformer
Fused subgraph:
(gather_x: [d0,d3,d1,d2,d4])
|
gather(axis=2)
|
fused_multi_transformer
*/
void GatherReshapePass(ir::Graph* graph) const;
const std::string name_scope_{
"fused_multi_transformer_cachekv_layout_trans_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
VarDesc* fill_constant(BlockDesc* block, std::vector<VarDesc*> shapes) {
VarDesc* out = Data(block, shapes[0]->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("fill_constant");
std::vector<std::string> shape_names;
for (auto shape : shapes) {
shape_names.push_back(shape->Name());
}
op->SetInput("ShapeTensorList", {shape_names});
op->SetOutput("Out", {out->Name()});
return out;
}
TEST(FillConstantReshapePass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* shape0 = Data(block, "shape0");
auto* shape1 = Data(block, "shape1");
auto* shape2 = Data(block, "shape2");
auto* shape3 = Data(block, "shape3");
auto* shape4 = Data(block, "shape4");
auto* shape5 = Data(block, "shape5");
auto* shape6 = Data(block, "shape6");
auto* shape7 = Data(block, "shape7");
auto* shape8 = Data(block, "shape8");
auto* shape9 = Data(block, "shape9");
auto* fill0 = fill_constant(block, {shape0, shape1, shape2, shape3, shape4});
fill0->SetShape({1, 2, 3, 4, 5});
auto* fill1 = fill_constant(block, {shape5, shape6, shape7, shape8, shape9});
fill1->SetShape({1, 2, 3, 4, 5});
OpDesc* fused_multi_transformer = block->AppendOp();
fused_multi_transformer->SetType("fused_multi_transformer");
fused_multi_transformer->SetInput("CacheKV", {fill0->Name(), fill1->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_cachekv_layout_trans_pass");
pass->Apply(graph.get());
auto fills = GetOpNodes(graph, "fill_constant");
auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList");
std::vector<std::string> expect_fill0_in_names{
"shape0", "shape3", "shape1", "shape2", "shape4"};
PADDLE_ENFORCE_EQ(fill0_in_names,
expect_fill0_in_names,
platform::errors::PreconditionNotMet(
"fill_constant name should be updated."));
auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList");
std::vector<std::string> expect_fill1_in_names{
"shape5", "shape8", "shape6", "shape7", "shape9"};
PADDLE_ENFORCE_EQ(fill1_in_names,
expect_fill1_in_names,
platform::errors::PreconditionNotMet(
"fill_constant name should be updated."));
}
TEST(GatherReshapePass, basic) {
Layers layers;
auto* gather0_x = layers.data("gather0_x", {2, 1, 24, 512, 64});
auto* gather0_index = layers.data("gather0_index", {1});
auto* gather0_out = layers.gather(gather0_x, gather0_index, 1);
gather0_out->SetShape({2, 1, 24, 512, 64});
auto* gather1_x = layers.data("gather1_x", {2, 1, 24, 512, 64});
auto* gather1_index = layers.data("gather1_index", {1});
auto* gather1_out = layers.gather(gather1_x, gather1_index, 1);
gather1_out->SetShape({2, 1, 24, 512, 64});
auto* block = layers.Block();
OpDesc* fused_multi_transformer = block->AppendOp();
fused_multi_transformer->SetType("fused_multi_transformer");
fused_multi_transformer->SetInput("CacheKV",
{gather0_out->Name(), gather1_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_cachekv_layout_trans_pass");
pass->Apply(graph.get());
auto gathers = GetOpNodes(graph, "gather");
for (auto* gather : gathers) {
PADDLE_ENFORCE_EQ(
gather->Op()->GetAttrIfExists<int>("axis"),
2,
platform::errors::PreconditionNotMet(
"gather's axis attr should be updated to 2 by pass."));
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fused_multi_transformer_cachekv_layout_trans_pass);
......@@ -519,6 +519,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_cachekv_layout_trans_pass",
"one_beam_size_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
......
......@@ -291,31 +291,26 @@ void FusedMultiTransformerXpuInferMeta(
std::vector<MetaTensor*> cache_kv_out) {
auto x_dim = x.dims();
auto y_dim = qkvw[0]->dims();
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
phi::errors::InvalidArgument("The dimensions of x must be 3"
"(batch_size, seq_len, dim_embed),"
"but received dimensions of"
"Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(x_dim.size(),
3,
phi::errors::InvalidArgument(
"The dimensions of x must be 3(batch_size, seq_len, "
"dim_embed), but received dimensions of Input is [%d]",
x_dim.size()));
PADDLE_ENFORCE_EQ(
y_dim.size(),
4,
phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4"
"(3, num_head, dim_head, dim_embed),"
"but received dimensions of"
"Input is [%d]",
y_dim.size()));
phi::errors::InvalidArgument(
"The dimensions of qkv_weight must be 4(3, num_head, dim_head, "
"dim_embed), but received dimensions of qkv_weight is [%d]",
y_dim.size()));
PADDLE_ENFORCE_EQ(
x_dim[2],
trans_qkvw ? y_dim[3] : y_dim[0],
phi::errors::InvalidArgument(
"ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is "
"true) or y_dim[0](trans_qkvw is false)"
"must be equal. But received: the shape "
"of input x = [%s], and the shape of "
"input qkv_weight = [%s]",
"The dimension of x_dim[2] and y_dim[3](trans_qkvw is true) or "
"y_dim[0](trans_qkvw is false) must be equal, but received: the "
"shape of input x = [%s], and the shape of input qkv_weight = [%s]",
x_dim,
y_dim));
if (cache_kv.size() > 0) {
......@@ -330,27 +325,27 @@ void FusedMultiTransformerXpuInferMeta(
phi::errors::InvalidArgument(
"The first dim of CacheKV must be 2, but got %d",
c_dim[0])); // 2
PADDLE_ENFORCE_EQ(c_dim[1],
x_dim[0],
phi::errors::InvalidArgument(
"The second dim of CacheKV must be equal with "
"batch size %d, but got %d",
x_dim[0],
c_dim[1])); // batch_size
PADDLE_ENFORCE_EQ(c_dim[2],
trans_qkvw ? y_dim[1] : y_dim[2],
phi::errors::InvalidArgument(
"The third dim of CacheKV must be equal with num "
"head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[2])); // num_head
PADDLE_ENFORCE_EQ(c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
phi::errors::InvalidArgument(
"The fifth dim of CacheKV must be equal with head "
"size %d, but got %d",
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
PADDLE_ENFORCE_EQ(
c_dim[2],
x_dim[0],
phi::errors::InvalidArgument("The third dim of CacheKV must be equal "
"with batch size %d, but got %d",
x_dim[0],
c_dim[2])); // batch_size
PADDLE_ENFORCE_EQ(
c_dim[3],
trans_qkvw ? y_dim[1] : y_dim[2],
phi::errors::InvalidArgument("The fourth dim of CacheKV must be equal "
"with num head %d, but got %d",
trans_qkvw ? y_dim[1] : y_dim[2],
c_dim[3])); // num_head
PADDLE_ENFORCE_EQ(
c_dim[4],
trans_qkvw ? y_dim[2] : y_dim[3],
phi::errors::InvalidArgument("The fifth dim of CacheKV must be equal "
"with head size %d, but got %d",
trans_qkvw ? y_dim[2] : y_dim[3],
c_dim[4])); // head_size
}
out->set_dims(x_dim);
......