未验证 提交 61469eec 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] add multi_encoder_xpu_slice_fuse_pass, generate_sequence_xpu_fuse_pass,...

[XPU] add multi_encoder_xpu_slice_fuse_pass, generate_sequence_xpu_fuse_pass, generate_sequence_xpu kernel (#50570)
上级 60318f0d
......@@ -224,6 +224,8 @@ if(WITH_XPU)
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu)
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu)
endif()
cc_library(
......
......@@ -77,14 +77,12 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
->assert_is_op_input(mul_type_, "Y")
->assert_is_persistable_var()
->assert_more([](Node* node) {
return true;
return node->Var()->GetShape().size() == 2;
});
auto* mul =
pattern->NewNode(mul_repr())
->assert_is_op(mul_type_)
->assert_more([](Node* node) {
return true;
auto op_type = node->Op()->Type();
if (op_type == "matmul") {
return !PADDLE_GET_CONST(bool,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern used by GenerateSequenceXPUFusePass. It names the three operator
// nodes (fill_any_like, cumsum, elementwise_sub) and the variable nodes that
// connect them; the matching constraints themselves are attached in the
// constructor.
struct GenerateSequenceXPUPattern : public PatternBase {
// Registers the pattern nodes on `pattern`; `name_scope` is forwarded to
// PatternBase.
GenerateSequenceXPUPattern(PDPattern* pattern, const std::string& name_scope);
// Operator nodes of the pattern.
PATTERN_DECL_NODE(fill_any_like);
PATTERN_DECL_NODE(cumsum);
PATTERN_DECL_NODE(elementwise_sub);
// Variable nodes of the pattern (op inputs/outputs).
PATTERN_DECL_NODE(fill_any_like_x);
PATTERN_DECL_NODE(fill_any_like_out);
PATTERN_DECL_NODE(cumsum_out);
PATTERN_DECL_NODE(elementwise_sub_out);
};
// Builds the matcher for the subgraph
//
//   fill_any_like(X) -> fill_any_like_out -> cumsum -> cumsum_out
//   elementwise_sub(X = cumsum_out, Y = fill_any_like_out) -> elementwise_sub_out
//
// Constraints attached to the nodes:
//   * fill_any_like's input is a non-persistable variable whose shape has
//     exactly two dimensions;
//   * fill_any_like fills with the value 1;
//   * fill_any_like_out is consumed by exactly two ops (cumsum and
//     elementwise_sub), cumsum_out by exactly one;
//   * cumsum is a plain inclusive, forward, non-flattened scan over axis 1
//     (or -1);
//   * elementwise_sub uses axis == -1.
//
// @param pattern     Pattern graph the nodes are added to.
// @param name_scope  Scope name, forwarded to PatternBase.
GenerateSequenceXPUPattern::GenerateSequenceXPUPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto* fill_any_like_x = pattern->NewNode(fill_any_like_x_repr())
                              ->assert_is_op_input("fill_any_like", "X")
                              ->assert_var_not_persistable()
                              ->assert_more([](Node* node) {
                                return node->Var()->GetShape().size() == 2;
                              });
  auto* fill_any_like =
      pattern->NewNode(fill_any_like_repr())
          ->assert_is_op("fill_any_like")
          ->assert_more([](Node* node) {
            // Compare the float attribute directly. Truncating it to int
            // first would also accept fill values such as 1.5 (and is
            // undefined for values that do not fit in an int), which must
            // not be fused into an integer sequence.
            const float value =
                PADDLE_GET_CONST(float, node->Op()->GetAttr("value"));
            return value == 1.0f;
          });
  auto* fill_any_like_out = pattern->NewNode(fill_any_like_out_repr())
                                ->assert_is_op_output("fill_any_like", "Out")
                                ->assert_is_op_input("cumsum", "X")
                                ->assert_is_op_input("elementwise_sub", "Y")
                                ->assert_var_not_persistable()
                                ->assert_has_n_outputs(2);
  auto* cumsum =
      pattern->NewNode(cumsum_repr())
          ->assert_is_op("cumsum")
          ->assert_more([](Node* node) {
            return !PADDLE_GET_CONST(bool, node->Op()->GetAttr("exclusive")) &&
                   !PADDLE_GET_CONST(bool, node->Op()->GetAttr("reverse")) &&
                   !PADDLE_GET_CONST(bool, node->Op()->GetAttr("flatten")) &&
                   ((PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == 1) ||
                    (PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1));
          });
  auto* cumsum_out = pattern->NewNode(cumsum_out_repr())
                         ->assert_is_op_output("cumsum", "Out")
                         ->assert_is_op_input("elementwise_sub", "X")
                         ->assert_var_not_persistable()
                         ->assert_has_n_outputs(1);
  auto* elementwise_sub =
      pattern->NewNode(elementwise_sub_repr())
          ->assert_is_op("elementwise_sub")
          ->assert_more([](Node* node) {
            return PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1;
          });
  auto* elementwise_sub_out =
      pattern->NewNode(elementwise_sub_out_repr())
          ->assert_is_op_output("elementwise_sub", "Out")
          ->assert_var_not_persistable();

  // Wire the nodes together in data-flow order.
  fill_any_like->LinksFrom({fill_any_like_x}).LinksTo({fill_any_like_out});
  cumsum->LinksFrom({fill_any_like_out}).LinksTo({cumsum_out});
  elementwise_sub->LinksFrom({cumsum_out, fill_any_like_out})
      .LinksTo({elementwise_sub_out});
}
} // namespace patterns
/*
Origin subgraph:
          fill_any_like
            /     \
            |     |
            |   cumsum
            |     |
            \     /
        elementwise_sub

Fused subgraph:
      generate_sequence_xpu
*/
// Fuse pass that rewrites the fill_any_like / cumsum / elementwise_sub
// subgraph into a single generate_sequence_xpu op (see ApplyImpl).
class GenerateSequenceXPUFusePass : public FusePassBase {
protected:
// Runs the pattern detector over `graph` and rewrites every match.
void ApplyImpl(ir::Graph* graph) const override;
private:
// Name scope passed to Init() and to the pattern.
const std::string name_scope_{"generate_sequence_xpu_fuse_pass"};
};
// Replaces every matched fill_any_like -> cumsum -> elementwise_sub subgraph
// with one generate_sequence_xpu op.
//
// The new op reads fill_any_like's input variable ("x"), writes the original
// elementwise_sub output variable ("out"), and carries fill_any_like's
// "dtype" attribute. The three matched ops and the two intermediate variables
// are removed from the graph.
//
// @param graph  Graph to rewrite in place; must not be null.
void GenerateSequenceXPUFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::GenerateSequenceXPUPattern pattern(gpd.mutable_pattern(),
                                               name_scope_);

  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle GenerateSequenceXPUFusePass fuse";
    GET_IR_NODE(fill_any_like);
    GET_IR_NODE(cumsum);
    GET_IR_NODE(elementwise_sub);
    GET_IR_NODE(fill_any_like_x);
    GET_IR_NODE(fill_any_like_out);
    GET_IR_NODE(cumsum_out);
    GET_IR_NODE(elementwise_sub_out);

    auto* block = fill_any_like->Op()->Block();
    framework::OpDesc op_desc(block);
    op_desc.SetType("generate_sequence_xpu");
    op_desc.SetInput("x", {fill_any_like_x->Name()});
    op_desc.SetOutput("out", {elementwise_sub_out->Name()});
    // NOTE(review): the dtype attribute is forwarded unchanged and the
    // pattern does not constrain it. If fill_any_like can carry a
    // "same as input" sentinel here, it would reach the fused op as-is --
    // confirm.
    op_desc.SetAttr(
        "dtype", PADDLE_GET_CONST(int, fill_any_like->Op()->GetAttr("dtype")));
    auto* generate_sequence_xpu = graph->CreateOpNode(&op_desc);

    // The fused op consumes the input *variable* of fill_any_like. Linking
    // from the fill_any_like op node would be wrong: that node is removed
    // below, which would leave the fused op without any input edge.
    IR_NODE_LINK_TO(fill_any_like_x, generate_sequence_xpu);
    IR_NODE_LINK_TO(generate_sequence_xpu, elementwise_sub_out);

    // Drop the ops and intermediate variables that the fused op replaces.
    std::unordered_set<const Node*> delete_nodes{
        fill_any_like, fill_any_like_out, cumsum, cumsum_out, elementwise_sub};
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers the pass under the name "generate_sequence_xpu_fuse_pass".
REGISTER_PASS(generate_sequence_xpu_fuse_pass,
paddle::framework::ir::GenerateSequenceXPUFusePass);
// Declares the op-version combination the pass is tied to:
// generate_sequence_xpu at version 0.
REGISTER_PASS_CAPABILITY(generate_sequence_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"generate_sequence_xpu", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern used by MultiEncoderXPUSliceFusePass: a multi_encoder_xpu op whose
// output variable feeds a slice op. The matching constraints are attached in
// the constructor.
struct MultiEncoderXPUSlicePattern : public PatternBase {
// Registers the pattern nodes on `pattern`; `name_scope` is forwarded to
// PatternBase.
MultiEncoderXPUSlicePattern(PDPattern* pattern,
const std::string& name_scope);
// Operator nodes of the pattern.
PATTERN_DECL_NODE(multi_encoder_xpu);
PATTERN_DECL_NODE(slice);
// Variable nodes of the pattern (op outputs).
PATTERN_DECL_NODE(multi_encoder_xpu_out);
PATTERN_DECL_NODE(slice_out);
};
// Builds the matcher for
//
//   multi_encoder_xpu -> multi_encoder_xpu_out -> slice -> slice_out
//
// The encoder must have norm_before == false and slice_idx == -1, its output
// must have a single consumer, and the slice must use axes == [1],
// decrease_axis == [1], starts == [0] and ends == [1].
//
// @param pattern     Pattern graph the nodes are added to.
// @param name_scope  Scope name, forwarded to PatternBase.
MultiEncoderXPUSlicePattern::MultiEncoderXPUSlicePattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto* multi_encoder_xpu =
      pattern->NewNode(multi_encoder_xpu_repr())
          ->assert_is_op("multi_encoder_xpu")
          ->assert_more([](Node* node) {
            auto* op = node->Op();
            if (PADDLE_GET_CONST(bool, op->GetAttr("norm_before"))) {
              return false;
            }
            return PADDLE_GET_CONST(int, op->GetAttr("slice_idx")) == -1;
          });
  auto* multi_encoder_xpu_out =
      pattern->NewNode(multi_encoder_xpu_out_repr())
          ->assert_is_op_output("multi_encoder_xpu", "out")
          ->assert_is_op_input("slice", "Input")
          ->assert_var_not_persistable()
          ->assert_has_n_outputs(1);
  auto* slice =
      pattern->NewNode(slice_repr())
          ->assert_is_op("slice")
          ->assert_more([](Node* node) {
            auto* op = node->Op();
            // Each attribute is copied out into an owned vector before it is
            // inspected.
            const auto read_ints =
                [op](const char* attr_name) -> std::vector<int> {
              return PADDLE_GET_CONST(std::vector<int>,
                                      op->GetAttr(attr_name));
            };
            const std::vector<int> axes = read_ints("axes");
            const std::vector<int> decrease_axis = read_ints("decrease_axis");
            const std::vector<int> starts = read_ints("starts");
            const std::vector<int> ends = read_ints("ends");

            // True when `values` holds exactly one element equal to
            // `expected`.
            const auto holds_only = [](const std::vector<int>& values,
                                       int expected) {
              return values.size() == 1 && values[0] == expected;
            };
            return holds_only(axes, 1) && holds_only(decrease_axis, 1) &&
                   holds_only(starts, 0) && holds_only(ends, 1);
          });
  auto* slice_out = pattern->NewNode(slice_out_repr())
                        ->assert_is_op_output("slice", "Out")
                        ->assert_var_not_persistable();

  // Wire the nodes together in data-flow order.
  multi_encoder_xpu->LinksTo({multi_encoder_xpu_out});
  slice->LinksFrom({multi_encoder_xpu_out}).LinksTo({slice_out});
}
} // namespace patterns
/*
Origin subgraph:
multi_encoder_xpu
|
slice
Fused subgraph:
multi_encoder_xpu
*/
// Fuse pass that folds a slice op following multi_encoder_xpu into the
// encoder op itself (see ApplyImpl).
class MultiEncoderXPUSliceFusePass : public FusePassBase {
protected:
// Runs the pattern detector over `graph` and rewrites every match.
void ApplyImpl(ir::Graph* graph) const override;
private:
// Name scope passed to Init() and to the pattern.
const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
};
// Folds each matched `multi_encoder_xpu -> slice` pair into the encoder op.
//
// The encoder's "out" output is redirected to the slice's output variable,
// its "slice_idx" attribute is set to 0, and the slice op together with the
// encoder's former output variable is removed from the graph.
//
// @param graph  Graph to rewrite in place; must not be null.
void MultiEncoderXPUSliceFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::MultiEncoderXPUSlicePattern pattern(gpd.mutable_pattern(),
                                                name_scope_);

  int fused_count = 0;
  auto fuse_match = [&](const GraphPatternDetector::subgraph_t& subgraph,
                        Graph* graph) {
    VLOG(4) << "handle MultiEncoderXPUSliceFusePass fuse";
    GET_IR_NODE(multi_encoder_xpu);
    GET_IR_NODE(slice);
    GET_IR_NODE(multi_encoder_xpu_out);
    GET_IR_NODE(slice_out);

    // Let the encoder write straight into the slice's output variable.
    auto* encoder_desc = multi_encoder_xpu->Op();
    encoder_desc->SetAttr("slice_idx", static_cast<int>(0));
    encoder_desc->SetOutput("out", {slice_out->Var()->Name()});
    IR_NODE_LINK_TO(multi_encoder_xpu, slice_out);

    // The slice op and the encoder's old output variable are now unused.
    std::unordered_set<const Node*> obsolete_nodes{multi_encoder_xpu_out,
                                                   slice};
    GraphSafeRemoveNodes(graph, obsolete_nodes);
    ++fused_count;
  };

  gpd(graph, fuse_match);
  AddStatis(fused_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Registers the pass under the name "multi_encoder_xpu_slice_fuse_pass".
REGISTER_PASS(multi_encoder_xpu_slice_fuse_pass,
paddle::framework::ir::MultiEncoderXPUSliceFusePass);
// Declares the op-version combination the pass is tied to:
// multi_encoder_xpu at version 0.
REGISTER_PASS_CAPABILITY(multi_encoder_xpu_slice_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"multi_encoder_xpu", 0));
......@@ -517,11 +517,11 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"generate_sequence_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
// "embedding_with_eltwise_add_xpu_fuse_pass",
"fc_xpu_fuse_pass",
// "multi_encoder_slice_link_xpu_fuse_pass",
// "generate_sequence_xpu_fuse_pass",
// "link_previous_out_max_xpu_pass",
});
use_xpu_ = true;
......
......@@ -8,6 +8,15 @@
data_type : x
optional : bias
# generate_sequence_xpu: takes a tensor `x` and a target `dtype`, produces one
# output tensor. Meta inference is done by GenerateSequenceXPUInferMeta and the
# kernel data type is selected from the `dtype` argument.
- op : generate_sequence_xpu
args : (Tensor x, DataType dtype)
output : Tensor
infer_meta :
func : GenerateSequenceXPUInferMeta
kernel :
func : generate_sequence_xpu
data_type : dtype
- op : multi_encoder_xpu
args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
......
......@@ -106,6 +106,12 @@ XPUOpMap& get_kl1_ops() {
{"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
{"gelu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32})},
{"generate_sequence_xpu",
XPUKernelSet({
phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
})},
{"hard_switch_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})},
{"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
......
......@@ -320,6 +320,12 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"generate_proposals_v2", XPUKernelSet({phi::DataType::FLOAT32})},
{"generate_sequence_xpu",
XPUKernelSet({
phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::INT64,
})},
{"grad_add",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"greater_equal",
......
......@@ -42,6 +42,14 @@ void FcXPUInferMeta(const MetaTensor& x,
out->set_layout(x.layout());
}
// Infers the output meta of generate_sequence_xpu: `out` takes the dims and
// layout of `x` and the explicitly requested `dtype`.
//
// @param x      Input tensor meta; only its dims and layout are read.
// @param dtype  Data type assigned to the output.
// @param out    Output tensor meta to fill in.
void GenerateSequenceXPUInferMeta(const MetaTensor& x,
                                  DataType dtype,
                                  MetaTensor* out) {
  const auto input_dims = x.dims();
  const auto input_layout = x.layout();

  out->set_dims(input_dims);
  out->set_dtype(dtype);
  out->set_layout(input_layout);
}
void MultiEncoderXPUInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& fc_weight,
......
......@@ -34,6 +34,10 @@ void FcXPUInferMeta(const MetaTensor& x,
float act_alpha,
MetaTensor* out);
// Meta inference for the generate_sequence_xpu op.
//
// @param x      Input tensor meta.
// @param dtype  Data type requested for the output.
// @param out    Output tensor meta to fill in.
void GenerateSequenceXPUInferMeta(const MetaTensor& x,
DataType dtype,
MetaTensor* out);
void MultiEncoderXPUInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& fc_weight,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void GenerateSequenceXPU(const Context& ctx,
const DenseTensor& x,
DataType dtype,
DenseTensor* out) {
auto x_dims = x.dims();
int batch = x_dims[0];
int step = x_dims[1];
DenseTensor out_host;
out_host.Resize(x_dims);
out_host.set_type(dtype);
T* out_host_data = ctx.template HostAlloc<T>(&out_host);
for (int i = 0; i < step; i++) {
out_host_data[i] = static_cast<T>(i);
}
for (int i = 1; i < batch; i++) {
std::memcpy(out_host_data + i * step, out_host_data, step * sizeof(T));
}
ctx.template Alloc<T>(out);
phi::Copy(ctx, out_host, out->place(), true, out);
}
} // namespace fusion
} // namespace phi
// Registers phi::fusion::GenerateSequenceXPU as the XPU kernel
// "generate_sequence_xpu" for float, int and int64_t.
PD_REGISTER_KERNEL(generate_sequence_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::GenerateSequenceXPU,
float,
int,
int64_t) {
// Input 0 (`x`) is accepted from any backend -- presumably because the
// kernel only reads its dims, not its data; confirm.
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestGenerateSequenceXPUFusePass(PassAutoScanTest):
    """Auto-scan test for ``generate_sequence_xpu_fuse_pass``.

    Builds a ``fill_any_like -> cumsum -> elementwise_sub`` program and expects
    the pass to leave a single ``generate_sequence_xpu`` op.
    """

    def sample_predictor_configs(self, program_config):
        """Yield the XPU predictor config, expected op list and tolerances."""
        xpu_config = self.create_inference_config(use_xpu=True)
        yield xpu_config, ["generate_sequence_xpu"], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        """Draw a random 2-D input shape and output dtype, then build the program."""
        dim_strategy = st.integers(min_value=1, max_value=32)
        x_shape = draw(st.lists(dim_strategy, min_size=2, max_size=2))
        # Integer dtype codes passed to fill_any_like -- presumably
        # INT32 / INT64 / FP32; confirm against the framework's VarType enum.
        out_dtype = draw(st.sampled_from([2, 3, 5]))

        # ones = fill_any_like(x, 1.0)
        fill_op = OpConfig(
            "fill_any_like",
            inputs={"X": ["fill_any_like_x"]},
            outputs={"Out": ["fill_any_like_out"]},
            dtype=out_dtype,
            value=1.0,
        )
        # scan = cumsum(ones, axis=1)
        scan_op = OpConfig(
            "cumsum",
            inputs={"X": ["fill_any_like_out"]},
            outputs={"Out": ["cumsum_out"]},
            axis=1,
            exclusive=False,
            flatten=False,
            reverse=False,
        )
        # result = scan - ones
        sub_op = OpConfig(
            "elementwise_sub",
            inputs={"X": ["cumsum_out"], "Y": ["fill_any_like_out"]},
            outputs={"Out": ["elementwise_sub_out"]},
            axis=-1,
        )

        return ProgramConfig(
            ops=[fill_op, scan_op, sub_op],
            weights={},
            inputs={"fill_any_like_x": TensorConfig(shape=x_shape)},
            outputs=sub_op.outputs["Out"],
        )

    def test(self):
        """Run the auto-scan harness with only the pass under test enabled."""
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["generate_sequence_xpu_fuse_pass"],
        )
if __name__ == "__main__":
unittest.main()
......@@ -25,7 +25,7 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest):
config = self.create_inference_config(use_xpu=True)
yield config, ["multi_encoder_xpu"], (1e-1, 1e-1)
def sample_program_config(self, draw):
def multi_encoder_xpu_program_config(self, draw):
# q: matmul+add+reshape+transpose
q_matmul_op = OpConfig(
"matmul_v2",
......@@ -325,6 +325,9 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest):
)
return program_config
def sample_program_config(self, draw):
return self.multi_encoder_xpu_program_config(draw)
def test(self):
self.run_and_statis(
quant=False,
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from program_config import OpConfig
from test_xpu_multi_encoder_xpu_fuse_pass import TestMultiEncoderXPUFusePass
class TestMultiEncoderXPUFusePass(TestMultiEncoderXPUFusePass):
    """Auto-scan test for ``multi_encoder_xpu_slice_fuse_pass``.

    Reuses the multi-encoder program from the base test and appends a slice op
    on its ``ln_2_out`` output.
    """

    def sample_program_config(self, draw):
        """Return the base multi-encoder program with a trailing slice op."""
        config = self.multi_encoder_xpu_program_config(draw)
        config.ops.append(
            OpConfig(
                "slice",
                inputs={"Input": ["ln_2_out"]},
                outputs={"Out": ["slice_out"]},
                axes=[1],
                decrease_axis=[1],
                starts=[0],
                ends=[1],
            )
        )
        # The slice result replaces the encoder output as the program output.
        config.outputs = ["slice_out"]
        return config

    def test(self):
        """Run the encoder fuse pass followed by the slice fuse pass."""
        enabled_passes = [
            "multi_encoder_xpu_fuse_pass",
            "multi_encoder_xpu_slice_fuse_pass",
        ]
        self.run_and_statis(
            quant=False,
            max_examples=2,
            min_success_num=2,
            passes=enabled_passes,
        )
if __name__ == "__main__":
np.random.seed(200)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册