From 61469eec0bee98e1bd65ba54e99fe39998ded605 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 17 Feb 2023 16:54:52 +0800 Subject: [PATCH] [XPU] add multi_encoder_xpu_slice_fuse_pass, generate_sequence_xpu_fuse_pass, generate_sequence_xpu kernel (#50570) --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/xpu/fc_xpu_fuse_pass.cc | 2 - .../ir/xpu/generate_sequence_xpu_fuse_pass.cc | 182 ++++++++++++++++++ .../xpu/multi_encoder_xpu_slice_fuse_pass.cc | 154 +++++++++++++++ .../inference/api/paddle_pass_builder.cc | 4 +- paddle/phi/api/yaml/static_ops.yaml | 9 + paddle/phi/backends/xpu/xpu1_op_list.cc | 6 + paddle/phi/backends/xpu/xpu2_op_list.cc | 6 + paddle/phi/infermeta/fusion.cc | 8 + paddle/phi/infermeta/fusion.h | 4 + .../xpu/generate_sequence_xpu_kernel.cc | 56 ++++++ ...est_xpu_generate_sequence_xpu_fuse_pass.py | 78 ++++++++ .../test_xpu_multi_encoder_xpu_fuse_pass.py | 5 +- ...t_xpu_multi_encoder_xpu_slice_fuse_pass.py | 52 +++++ 14 files changed, 563 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc create mode 100644 paddle/phi/kernels/fusion/xpu/generate_sequence_xpu_kernel.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_xpu_generate_sequence_xpu_fuse_pass.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_slice_fuse_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 98e164a1ef5..0c2a65878d6 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -224,6 +224,8 @@ if(WITH_XPU) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu) + pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu) endif() cc_library( diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index 268e9e30137..c7cc1dfc07f 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -77,14 +77,12 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern, ->assert_is_op_input(mul_type_, "Y") ->assert_is_persistable_var() ->assert_more([](Node* node) { - return true; return node->Var()->GetShape().size() == 2; }); auto* mul = pattern->NewNode(mul_repr()) ->assert_is_op(mul_type_) ->assert_more([](Node* node) { - return true; auto op_type = node->Op()->Type(); if (op_type == "matmul") { return !PADDLE_GET_CONST(bool, diff --git a/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc new file mode 100644 index 00000000000..ed17144b6b6 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct GenerateSequenceXPUPattern : public PatternBase {
+  GenerateSequenceXPUPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(fill_any_like);
+  PATTERN_DECL_NODE(cumsum);
+  PATTERN_DECL_NODE(elementwise_sub);
+  // declare variable node's name
+  PATTERN_DECL_NODE(fill_any_like_x);
+  PATTERN_DECL_NODE(fill_any_like_out);
+  PATTERN_DECL_NODE(cumsum_out);
+  PATTERN_DECL_NODE(elementwise_sub_out);
+};
+
+GenerateSequenceXPUPattern::GenerateSequenceXPUPattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* fill_any_like_x = pattern->NewNode(fill_any_like_x_repr())
+                              ->assert_is_op_input("fill_any_like", "X")
+                              ->assert_var_not_persistable()
+                              ->assert_more([](Node* node) {
+                                return node->Var()->GetShape().size() == 2;
+                              });
+  auto* fill_any_like =
+      pattern->NewNode(fill_any_like_repr())
+          ->assert_is_op("fill_any_like")
+          ->assert_more([](Node* node) {
+            float value = PADDLE_GET_CONST(float, node->Op()->GetAttr("value"));
+            return static_cast<int>(value) == 1;
+          });
+  auto* fill_any_like_out = pattern->NewNode(fill_any_like_out_repr())
+                                ->assert_is_op_output("fill_any_like", "Out")
+                                ->assert_is_op_input("cumsum", "X")
+                                ->assert_is_op_input("elementwise_sub", "Y")
+                                ->assert_var_not_persistable()
+                                ->assert_has_n_outputs(2);
+  auto* cumsum =
+      pattern->NewNode(cumsum_repr())
+          ->assert_is_op("cumsum")
+          ->assert_more([](Node* node) {
+            return !PADDLE_GET_CONST(bool, node->Op()->GetAttr("exclusive")) &&
+                   !PADDLE_GET_CONST(bool, node->Op()->GetAttr("reverse")) &&
+                   !PADDLE_GET_CONST(bool, node->Op()->GetAttr("flatten")) &&
+                   ((PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == 1) ||
+                    (PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1));
+          });
+  auto* cumsum_out = pattern->NewNode(cumsum_out_repr())
+                         ->assert_is_op_output("cumsum", "Out")
+                         ->assert_is_op_input("elementwise_sub", "X")
+                         ->assert_var_not_persistable()
+                         ->assert_has_n_outputs(1);
+  auto* elementwise_sub =
+      pattern->NewNode(elementwise_sub_repr())
+          ->assert_is_op("elementwise_sub")
+          ->assert_more([](Node* node) {
+            return PADDLE_GET_CONST(int, node->Op()->GetAttr("axis")) == -1;
+          });
+  auto* elementwise_sub_out =
+      pattern->NewNode(elementwise_sub_out_repr())
+          ->assert_is_op_output("elementwise_sub", "Out")
+          ->assert_var_not_persistable();
+
+  fill_any_like->LinksFrom({fill_any_like_x}).LinksTo({fill_any_like_out});
+  cumsum->LinksFrom({fill_any_like_out}).LinksTo({cumsum_out});
+  elementwise_sub->LinksFrom({cumsum_out, fill_any_like_out})
+      .LinksTo({elementwise_sub_out});
+}
+
+}  // namespace patterns
+
+/*
+Origin subgraph:
+          fill_any_like
+            /      \
+           |        |
+           |      cumsum
+           |        |
+            \      /
+        elementwise_sub
+
+Fused subgraph:
+        generate_sequence_xpu
+*/
+class GenerateSequenceXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  const std::string name_scope_{"generate_sequence_xpu_fuse_pass"};
+};
+
+void GenerateSequenceXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  GraphPatternDetector gpd;
+  patterns::GenerateSequenceXPUPattern pattern(gpd.mutable_pattern(),
+                                               name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle GenerateSequenceXPUFusePass fuse";
+    GET_IR_NODE(fill_any_like);
+    GET_IR_NODE(cumsum);
+    GET_IR_NODE(elementwise_sub);
+    GET_IR_NODE(fill_any_like_x);
+    GET_IR_NODE(fill_any_like_out);
+    GET_IR_NODE(cumsum_out);
+    GET_IR_NODE(elementwise_sub_out);
+
+    auto* block = fill_any_like->Op()->Block();
+    framework::OpDesc op_desc(block);
+    op_desc.SetType("generate_sequence_xpu");
+    op_desc.SetInput("x", {fill_any_like_x->Name()});
+    op_desc.SetOutput("out", {elementwise_sub_out->Name()});
+    op_desc.SetAttr(
+        "dtype", PADDLE_GET_CONST(int, fill_any_like->Op()->GetAttr("dtype")));
+    auto* generate_sequence_xpu = graph->CreateOpNode(&op_desc);
+    IR_NODE_LINK_TO(fill_any_like_x, generate_sequence_xpu);
+    IR_NODE_LINK_TO(generate_sequence_xpu, elementwise_sub_out);
+
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes{
+        fill_any_like, fill_any_like_out, cumsum, cumsum_out, elementwise_sub};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(generate_sequence_xpu_fuse_pass,
+              paddle::framework::ir::GenerateSequenceXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(generate_sequence_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "generate_sequence_xpu", 0));
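For reference, the subgraph this pass matches computes a per-row position
sequence: fill_any_like produces a tensor of ones shaped like x, cumsum
accumulates along axis 1 to give 1, 2, ..., seq_len, and elementwise_sub
subtracts the ones to leave 0, 1, ..., seq_len - 1 in every row. A minimal
NumPy sketch of the same computation (illustrative only, not part of the
patch):

    import numpy as np

    x = np.zeros((2, 5), dtype=np.float32)  # any [batch, seq_len] input
    ones = np.ones_like(x)                  # fill_any_like with value == 1
    out = np.cumsum(ones, axis=1) - ones    # cumsum, then elementwise_sub
    # out: [[0. 1. 2. 3. 4.]
    #       [0. 1. 2. 3. 4.]] -- the sequence generate_sequence_xpu must emit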
diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc
new file mode 100644
index 00000000000..64693ebd082
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc
@@ -0,0 +1,154 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct MultiEncoderXPUSlicePattern : public PatternBase {
+  MultiEncoderXPUSlicePattern(PDPattern* pattern,
+                              const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(multi_encoder_xpu);
+  PATTERN_DECL_NODE(slice);
+  // declare variable node's name
+  PATTERN_DECL_NODE(multi_encoder_xpu_out);
+  PATTERN_DECL_NODE(slice_out);
+};
+
+MultiEncoderXPUSlicePattern::MultiEncoderXPUSlicePattern(
+    PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* multi_encoder_xpu =
+      pattern->NewNode(multi_encoder_xpu_repr())
+          ->assert_is_op("multi_encoder_xpu")
+          ->assert_more([](Node* node) {
+            return (!PADDLE_GET_CONST(bool,
+                                      node->Op()->GetAttr("norm_before"))) &&
+                   (PADDLE_GET_CONST(int, node->Op()->GetAttr("slice_idx")) ==
+                    -1);
+          });
+  auto* multi_encoder_xpu_out =
+      pattern->NewNode(multi_encoder_xpu_out_repr())
+          ->assert_is_op_output("multi_encoder_xpu", "out")
+          ->assert_is_op_input("slice", "Input")
+          ->assert_var_not_persistable()
+          ->assert_has_n_outputs(1);
+  auto* slice =
+      pattern->NewNode(slice_repr())
+          ->assert_is_op("slice")
+          ->assert_more([](Node* node) {
+            std::vector<int> axes =
+                PADDLE_GET_CONST(std::vector<int>, node->Op()->GetAttr("axes"));
+            std::vector<int> decrease_axis = PADDLE_GET_CONST(
+                std::vector<int>, node->Op()->GetAttr("decrease_axis"));
+            std::vector<int> starts = PADDLE_GET_CONST(
+                std::vector<int>, node->Op()->GetAttr("starts"));
+            std::vector<int> ends =
+                PADDLE_GET_CONST(std::vector<int>, node->Op()->GetAttr("ends"));
+            return axes.size() == 1 && axes[0] == 1 &&
+                   decrease_axis.size() == 1 && decrease_axis[0] == 1 &&
+                   starts.size() == 1 && starts[0] == 0 &&  //
+                   ends.size() == 1 && ends[0] == 1;
+          });
+  auto* slice_out = pattern->NewNode(slice_out_repr())
+                        ->assert_is_op_output("slice", "Out")
+                        ->assert_var_not_persistable();
+  multi_encoder_xpu->LinksTo({multi_encoder_xpu_out});
+  slice->LinksFrom({multi_encoder_xpu_out}).LinksTo({slice_out});
+}
+
+}  // namespace patterns
+
+/*
+Origin subgraph:
+      multi_encoder_xpu
+              |
+            slice
+
+Fused subgraph:
+      multi_encoder_xpu
+*/
+class MultiEncoderXPUSliceFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  const std::string name_scope_{"multi_encoder_xpu_slice_fuse_pass"};
+};
+
+void MultiEncoderXPUSliceFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  GraphPatternDetector gpd;
+  patterns::MultiEncoderXPUSlicePattern pattern(gpd.mutable_pattern(),
+                                                name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle MultiEncoderXPUSliceFusePass fuse";
+    GET_IR_NODE(multi_encoder_xpu);
+    GET_IR_NODE(slice);
+    GET_IR_NODE(multi_encoder_xpu_out);
+    GET_IR_NODE(slice_out);
+
+    auto* op_desc = multi_encoder_xpu->Op();
+    op_desc->SetOutput("out", {slice_out->Var()->Name()});
+    op_desc->SetAttr("slice_idx", static_cast<int>(0));
+    IR_NODE_LINK_TO(multi_encoder_xpu, slice_out);
+
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes{multi_encoder_xpu_out, slice};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(multi_encoder_xpu_slice_fuse_pass,
+              paddle::framework::ir::MultiEncoderXPUSliceFusePass);
+
+REGISTER_PASS_CAPABILITY(multi_encoder_xpu_slice_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "multi_encoder_xpu", 0));
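The slice matched above is the common first-token ("pooled output") pattern:
with axes=[1], starts=[0], ends=[1] and decrease_axis=[1], it takes position 0
along the sequence axis and squeezes that axis, which is why the pass can
simply set slice_idx = 0 on multi_encoder_xpu and drop the slice op. A NumPy
sketch of the matched slice's semantics (illustrative only, not part of the
patch):

    import numpy as np

    x = np.random.rand(8, 128, 768).astype(np.float32)  # [batch, seq, hidden]
    out = x[:, 0, :]  # axes=[1], starts=[0], ends=[1], decrease_axis=[1]
    # out.shape == (8, 768): the slice the fused encoder now emits directly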
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index ed422d9aded..2b0e8e430d0 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -517,11 +517,11 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
 XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_dropout_op_pass",
+      "generate_sequence_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
+      "multi_encoder_xpu_slice_fuse_pass",
       // "embedding_with_eltwise_add_xpu_fuse_pass",
       "fc_xpu_fuse_pass",
-      // "multi_encoder_slice_link_xpu_fuse_pass",
-      // "generate_sequence_xpu_fuse_pass",
       // "link_previous_out_max_xpu_pass",
   });
   use_xpu_ = true;
diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml
index 9a3626cabd6..9c9e5c0c8b2 100644
--- a/paddle/phi/api/yaml/static_ops.yaml
+++ b/paddle/phi/api/yaml/static_ops.yaml
@@ -8,6 +8,15 @@
     data_type : x
   optional : bias
 
+- op : generate_sequence_xpu
+  args : (Tensor x, DataType dtype)
+  output : Tensor
+  infer_meta :
+    func : GenerateSequenceXPUInferMeta
+  kernel :
+    func : generate_sequence_xpu
+    data_type : dtype
+
 - op : multi_encoder_xpu
   args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
   output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
diff --git a/paddle/phi/backends/xpu/xpu1_op_list.cc b/paddle/phi/backends/xpu/xpu1_op_list.cc
index 6b8f9b47011..f40daae0c5d 100644
--- a/paddle/phi/backends/xpu/xpu1_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu1_op_list.cc
@@ -106,6 +106,12 @@ XPUOpMap& get_kl1_ops() {
     {"gaussian_random", XPUKernelSet({phi::DataType::FLOAT32})},
     {"gelu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"gelu", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"generate_sequence_xpu",
+     XPUKernelSet({
+         phi::DataType::FLOAT32,
+         phi::DataType::INT32,
+         phi::DataType::INT64,
+     })},
     {"hard_switch_grad", XPUKernelSet({phi::DataType::FLOAT32})},
     {"hard_switch", XPUKernelSet({phi::DataType::FLOAT32})},
     {"iou_similarity", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index bc883bc6e32..3450f5bc04a 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -320,6 +320,12 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"gelu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"generate_proposals_v2", XPUKernelSet({phi::DataType::FLOAT32})}, + {"generate_sequence_xpu", + XPUKernelSet({ + phi::DataType::FLOAT32, + phi::DataType::INT32, + phi::DataType::INT64, + })}, {"grad_add", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"greater_equal", diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 863a18ad848..699fcc830ce 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -42,6 +42,14 @@ void FcXPUInferMeta(const MetaTensor& x, out->set_layout(x.layout()); } +void GenerateSequenceXPUInferMeta(const MetaTensor& x, + DataType dtype, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_dtype(dtype); + out->set_layout(x.layout()); +} + void MultiEncoderXPUInferMeta( const MetaTensor& x, const std::vector& fc_weight, diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 5806ecf640c..7848bb40f00 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -34,6 +34,10 @@ void FcXPUInferMeta(const MetaTensor& x, float act_alpha, MetaTensor* out); +void GenerateSequenceXPUInferMeta(const MetaTensor& x, + DataType dtype, + MetaTensor* out); + void MultiEncoderXPUInferMeta( const MetaTensor& x, const std::vector& fc_weight, diff --git a/paddle/phi/kernels/fusion/xpu/generate_sequence_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/generate_sequence_xpu_kernel.cc new file mode 100644 index 00000000000..21117938fbd --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/generate_sequence_xpu_kernel.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void GenerateSequenceXPU(const Context& ctx,
+                         const DenseTensor& x,
+                         DataType dtype,
+                         DenseTensor* out) {
+  auto x_dims = x.dims();
+  int batch = x_dims[0];
+  int step = x_dims[1];
+
+  DenseTensor out_host;
+  out_host.Resize(x_dims);
+  out_host.set_type(dtype);
+  T* out_host_data = ctx.template HostAlloc<T>(&out_host);
+  for (int i = 0; i < step; i++) {
+    out_host_data[i] = static_cast<T>(i);
+  }
+  for (int i = 1; i < batch; i++) {
+    std::memcpy(out_host_data + i * step, out_host_data, step * sizeof(T));
+  }
+
+  ctx.template Alloc<T>(out);
+  phi::Copy(ctx, out_host, out->place(), true, out);
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(generate_sequence_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::GenerateSequenceXPU,
+                   float,
+                   int,
+                   int64_t) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_generate_sequence_xpu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_generate_sequence_xpu_fuse_pass.py
new file mode 100644
index 00000000000..6552883eaad
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_generate_sequence_xpu_fuse_pass.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import unittest + +import hypothesis.strategies as st +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestGenerateSequenceXPUFusePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ["generate_sequence_xpu"], (1e-5, 1e-5) + + def sample_program_config(self, draw): + fill_any_like_x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=32), min_size=2, max_size=2 + ) + ) + fill_any_like_dtype = draw(st.sampled_from([2, 3, 5])) + + fill_any_like_op = OpConfig( + "fill_any_like", + inputs={"X": ["fill_any_like_x"]}, + outputs={"Out": ["fill_any_like_out"]}, + dtype=fill_any_like_dtype, + value=1.0, + ) + cumsum_op = OpConfig( + "cumsum", + inputs={"X": ["fill_any_like_out"]}, + outputs={"Out": ["cumsum_out"]}, + axis=1, + exclusive=False, + flatten=False, + reverse=False, + ) + elementwise_sub_op = OpConfig( + "elementwise_sub", + inputs={"X": ["cumsum_out"], "Y": ["fill_any_like_out"]}, + outputs={"Out": ["elementwise_sub_out"]}, + axis=-1, + ) + + ops = [fill_any_like_op, cumsum_op, elementwise_sub_op] + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "fill_any_like_x": TensorConfig(shape=fill_any_like_x_shape), + }, + outputs=ops[-1].outputs["Out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["generate_sequence_xpu_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py index 2bc46ff6845..a43fb2e3839 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_fuse_pass.py @@ -25,7 +25,7 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest): config = self.create_inference_config(use_xpu=True) yield config, ["multi_encoder_xpu"], (1e-1, 1e-1) - def sample_program_config(self, draw): + def multi_encoder_xpu_program_config(self, draw): # q: matmul+add+reshape+transpose q_matmul_op = OpConfig( "matmul_v2", @@ -325,6 +325,9 @@ class TestMultiEncoderXPUFusePass(PassAutoScanTest): ) return program_config + def sample_program_config(self, draw): + return self.multi_encoder_xpu_program_config(draw) + def test(self): self.run_and_statis( quant=False, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_slice_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_slice_fuse_pass.py new file mode 100644 index 00000000000..7f32ca416a1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_multi_encoder_xpu_slice_fuse_pass.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from program_config import OpConfig
+from test_xpu_multi_encoder_xpu_fuse_pass import TestMultiEncoderXPUFusePass
+
+
+class TestMultiEncoderXPUSliceFusePass(TestMultiEncoderXPUFusePass):
+    def sample_program_config(self, draw):
+        slice_op = OpConfig(
+            "slice",
+            inputs={"Input": ["ln_2_out"]},
+            outputs={"Out": ["slice_out"]},
+            axes=[1],
+            decrease_axis=[1],
+            starts=[0],
+            ends=[1],
+        )
+        program_config = self.multi_encoder_xpu_program_config(draw)
+        program_config.ops.append(slice_op)
+        program_config.outputs = ["slice_out"]
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=2,
+            min_success_num=2,
+            passes=[
+                "multi_encoder_xpu_fuse_pass",
+                "multi_encoder_xpu_slice_fuse_pass",
+            ],
+        )
+
+
+if __name__ == "__main__":
+    np.random.seed(200)
+    unittest.main()
-- 
GitLab
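With an XPU inference build, both new passes run by default through the
XpuPassStrategy shown above, so no per-model configuration is needed. A
minimal sketch, assuming the standard paddle.inference Python API of this
Paddle version (model paths are placeholders):

    from paddle.inference import Config, create_predictor

    config = Config("model.pdmodel", "model.pdiparams")  # placeholder paths
    config.enable_xpu(10 * 1024 * 1024)  # XPU target, ~10 MB L3 workspace
    predictor = create_predictor(config)  # XpuPassStrategy applies the passes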