diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 23665557576fff9efcadbf5e92cb65315bccb38d..c81616c0cede16adda44bfe6add091286709dc75 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -280,6 +280,7 @@ if(WITH_XPU)
   pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(gather_squeeze_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()
diff --git a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
index 3300bbd08dffbbf3d8c7528816c9ad39c5ec2b83..508279b101692e7a750db660592855d661df7f1d 100644
--- a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
+++ b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
@@ -225,6 +225,17 @@ std::string GenAddAttrKey(Node* add_op_node) {
   return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
 }
 
+std::string GenTranspose2AttrKey(Node* transpose_op_node) {
+  auto transpose_op_desc = transpose_op_node->Op();
+  auto axis = transpose_op_desc->GetAttrIfExists<std::vector<int>>("axis");
+  std::string attr_key;
+  attr_key += "axis_";
+  for (auto x : axis) {
+    attr_key += std::to_string(x) + "_";
+  }
+  return attr_key;
+}
+
 std::string GenScaleAttrKey(Node* scale_op_node) {
   auto scale_op_desc = scale_op_node->Op();
   auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
@@ -274,6 +285,7 @@ void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
   DeleteRepeatedOps(graph, "gather", GenGatherAttrKey);
   DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey);
   DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey);
+  DeleteRepeatedOps(graph, "transpose2", GenTranspose2AttrKey);
   LOG(INFO) << "Round " << repeat_time++
             << ": delete op counts: " << delete_op_count;
   total_delete_op_count += delete_op_count;
diff --git a/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.cc b/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9bc9d2fe9229a994dabb02d06e54182e95a4f700
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.cc
@@ -0,0 +1,173 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/xpu/gather_squeeze_pass.h" +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { + +struct GatherSqueeze : public PatternBase { + GatherSqueeze(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(gather); + PATTERN_DECL_NODE(squeeze2); + // declare variable node's name + PATTERN_DECL_NODE(gather_in); + PATTERN_DECL_NODE(gather_index); + PATTERN_DECL_NODE(gather_out); +}; // struct GatherSqueeze + +GatherSqueeze::GatherSqueeze(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* gather_in = pattern->NewNode(gather_in_repr()) + ->assert_is_op_input("gather", "X") + ->assert_more([](Node* node) { + for (auto* op : node->outputs) { + if (op->Op()->Type() != "gather") { + return false; + } + } + return node->outputs.size() >= 2 && + node->Var()->GetShape().size() >= 2 && + node->Var()->GetShape().size() < 5; + }); + auto* gather_index = pattern->NewNode(gather_index_repr()) + ->assert_is_op_input("gather", "Index") + ->assert_more([](Node* node) { + auto shape = node->Var()->GetShape(); + return shape.size() == 1 && shape[0] == 1; + }); + auto* gather = pattern->NewNode(gather_repr())->assert_is_op("gather"); + auto* gather_out = pattern->NewNode(gather_out_repr()) + ->assert_is_op_output("gather", "Out") + ->assert_is_op_input("squeeze2", "X"); + auto* squeeze2 = pattern->NewNode(squeeze2_repr())->assert_is_op("squeeze2"); + + gather->LinksFrom({gather_in, gather_index}).LinksTo({gather_out}); + gather_out->LinksTo({squeeze2}); +} + +} // namespace patterns + +void GatherSqueezePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + AddTranspose(graph); +} + +void GatherSqueezePass::AddTranspose(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + GraphPatternDetector gpd; + patterns::GatherSqueeze pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle GatherSqueezePass"; + GET_IR_NODE(gather); + GET_IR_NODE(gather_in); + GET_IR_NODE(gather_out); + GET_IR_NODE(squeeze2); + + bool flag = true; + auto var_dims = static_cast(gather_in->Var()->GetShape().size()); + auto gather_axis = gather->Op()->GetAttrIfExists("axis"); + auto squeeze_axis = + squeeze2->Op()->GetAttrIfExists>("axes"); + + flag = flag && gather_axis == var_dims - 1; + flag = flag && (squeeze_axis == std::vector{-1} || + squeeze_axis == std::vector{var_dims - 1}); + + if (flag) { + gather->Op()->SetAttr("axis", 0); + squeeze2->Op()->SetAttr("axes", std::vector{0}); + + std::string transpose2_out_name = + patterns::PDNodeName(name_scope_, "transpose2out"); + VarDesc transpose2_out_vardesc(transpose2_out_name); + OpDesc transpose2_op_desc(gather->Op()->Block()); + auto gather_in_shape = gather_in->Var()->GetShape(); + auto gather_out_shape = gather_out->Var()->GetShape(); + 
+      transpose2_out_vardesc.SetDataType(gather_in->Var()->GetDataType());
+
+      if (var_dims == 2) {
+        gather_out->Var()->SetShape({gather_out_shape[1], gather_out_shape[0]});
+        transpose2_out_vardesc.SetShape(
+            {gather_in_shape[1], gather_in_shape[0]});
+        transpose2_op_desc.SetAttr("axis", std::vector<int>{1, 0});
+      } else if (var_dims == 3) {
+        gather_out->Var()->SetShape(
+            {gather_out_shape[2], gather_out_shape[0], gather_out_shape[1]});
+        transpose2_out_vardesc.SetShape(
+            {gather_in_shape[2], gather_in_shape[0], gather_in_shape[1]});
+        transpose2_op_desc.SetAttr("axis", std::vector<int>{2, 0, 1});
+      } else {
+        gather_out->Var()->SetShape({gather_out_shape[3],
+                                     gather_out_shape[0],
+                                     gather_out_shape[1],
+                                     gather_out_shape[2]});
+        transpose2_out_vardesc.SetShape({gather_in_shape[3],
+                                         gather_in_shape[0],
+                                         gather_in_shape[1],
+                                         gather_in_shape[2]});
+        transpose2_op_desc.SetAttr("axis", std::vector<int>{3, 0, 1, 2});
+      }
+
+      auto* transpose2_out = graph->CreateVarNode(&transpose2_out_vardesc);
+
+      transpose2_op_desc.SetType("transpose2");
+      transpose2_op_desc.SetInput("X", {gather_in->Name()});
+      transpose2_op_desc.SetOutput("Out", {transpose2_out->Name()});
+      auto* transpose2 = graph->CreateOpNode(&transpose2_op_desc);
+
+      gather->Op()->SetInput("X", {transpose2_out->Name()});
+
+      IR_NODE_UNLINK(gather_in, gather);
+      IR_NODE_LINK_TO(gather_in, transpose2);
+      IR_NODE_LINK_TO(transpose2, transpose2_out);
+      IR_NODE_LINK_TO(transpose2_out, gather);
+
+      found_subgraph_count++;
+    }
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(gather_squeeze_pass, paddle::framework::ir::GatherSqueezePass);
+
+REGISTER_PASS_CAPABILITY(gather_squeeze_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("gather", 1)
+            .EQ("squeeze2", 0));
diff --git a/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.h b/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..810663493fa0641f945a3182c01e25d787653205
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/gather_squeeze_pass.h
@@ -0,0 +1,69 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class GatherSqueezePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Insert a transpose2 before a gather + squeeze2 pair that works on the
+  last axis, so that both ops can run on axis 0 instead. For example,
+  with a rank-3 input (rank-2 and rank-4 inputs get the matching
+  permutation):
+
+          x
+          |
+      gather (axis = -1)
+          |
+      squeeze2 (axes = [-1])
+          |
+        output
+  ------------------------------------------------------
+  After the pass is applied:
+
+          x
+          |
+      transpose2 (axis = [2, 0, 1])
+          |
+      gather (axis = 0)
+          |
+      squeeze2 (axes = [0])
+          |
+        output
+  */
+  void AddTranspose(ir::Graph* graph) const;
+
+  const std::string name_scope_{"gather_squeeze_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 15b04450b33bff5a9458c9b8e81bac329cfa2d33..b7e380ccaa45b9bb72c678de5c6bc41cea68f2fc 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -509,6 +509,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "delete_assign_op_pass",
       "delete_dropout_op_pass",
       "delete_concat_op_pass",
+      "gather_squeeze_pass",
       "delete_repeated_ops_pass",
       "identity_op_clean_pass",
       "fused_continuous_same_ops_pass",
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index 10a61e474a2bec6a67342e0b8ac5d3704490065f..ad83ce9a39619e6def2dd7874b18c5c2c310eaab 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -26,7 +26,7 @@ XPUOpMap& get_kl2_ops() {
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"add_layernorm_xpu",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
-    {"abs", XPUKernelSet({phi::DataType::FLOAT32})},
+    {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"abs_grad",
      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
     {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
diff --git a/paddle/phi/kernels/xpu/abs_grad_kernel.cc b/paddle/phi/kernels/xpu/abs_grad_kernel.cc
index b9fab28254d29c87f2e10c2100c4b418a97b8ae2..9345de9b22e8b60fb57cb3e9ade1d375fd0ae565 100644
--- a/paddle/phi/kernels/xpu/abs_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/abs_grad_kernel.cc
@@ -26,16 +26,18 @@ void AbsGradKernel(const Context& ctx,
                    const DenseTensor& dout,
                    DenseTensor* dx) {
   ctx.template Alloc<T>(dx);
+  using XPUType = typename XPUTypeTrait<T>::Type;
   int r = xpu::abs_grad(ctx.x_context(),
-                        x.data<T>(),
-                        dout.data<T>(),
-                        dout.data<T>(),
-                        dx->data<T>(),
+                        reinterpret_cast<const XPUType*>(x.data<T>()),
+                        reinterpret_cast<const XPUType*>(dout.data<T>()),
+                        reinterpret_cast<const XPUType*>(dout.data<T>()),
+                        reinterpret_cast<XPUType*>(dx->data<T>()),
                         x.numel());
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs_grad");
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) {
+PD_REGISTER_KERNEL(
+    abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float, phi::dtype::float16) {
   kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
 }
diff --git a/paddle/phi/kernels/xpu/abs_kernel.cc b/paddle/phi/kernels/xpu/abs_kernel.cc
index 4213c92a1ebb43c92bdeb1d240141d6843ffe261..7abdd1f0715b6069935eb88301e6ac438b60be57 100644
--- a/paddle/phi/kernels/xpu/abs_kernel.cc
+++ b/paddle/phi/kernels/xpu/abs_kernel.cc
@@ -22,9 +22,14 @@ namespace phi {
 
 template <typename T, typename Context>
 void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
   ctx.template Alloc<T>(out);
-  int r = xpu::abs(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int r = xpu::abs(ctx.x_context(),
+                   reinterpret_cast<const XPUType*>(x.data<T>()),
+                   reinterpret_cast<XPUType*>(out->data<T>()),
+                   x.numel());
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs");
 }
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(abs, XPU, ALL_LAYOUT, phi::AbsKernel, float) {}
+PD_REGISTER_KERNEL(
+    abs, XPU, ALL_LAYOUT, phi::AbsKernel, float, phi::dtype::float16) {}
diff --git a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
index b6f45c5841c0e8e63f7ea4cbb5fe0d7b2f852b62..90615678342c3d29fb73edad18477a34893ffcf1 100644
--- a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
+++ b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
@@ -727,5 +727,90 @@ class TestDeleteRepeatedGatherPass(PassAutoScanTest):
         )
 
 
+class TestDeleteRepeatedTransposePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['transpose2', 'relu', 'relu', 'relu'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+        H = draw(st.integers(min_value=1, max_value=64))
+        W = draw(st.integers(min_value=1, max_value=64))
+        in_shape = [batch_size, H, W]
+        axis = [0, 2, 1]
+
+        transpose_op0 = OpConfig(
+            type='transpose2',
+            inputs={
+                "X": ["transpose_x"],
+            },
+            outputs={"Out": ["transpose_output0"]},
+            attrs={"axis": axis},
+        )
+        relu_op0 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["transpose_output0"],
+            },
+            outputs={"Out": ["relu0_out"]},
+        )
+        transpose_op1 = OpConfig(
+            type='transpose2',
+            inputs={
+                "X": ["transpose_x"],
+            },
+            outputs={"Out": ["transpose_output1"]},
+            attrs={"axis": axis},
+        )
+        relu_op1 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["transpose_output1"],
+            },
+            outputs={"Out": ["relu1_out"]},
+        )
+        transpose_op2 = OpConfig(
+            type='transpose2',
+            inputs={
+                "X": ["transpose_x"],
+            },
+            outputs={"Out": ["transpose_output2"]},
+            attrs={"axis": axis},
+        )
+        relu_op2 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["transpose_output2"],
+            },
+            outputs={"Out": ["relu2_out"]},
+        )
+
+        ops = [
+            transpose_op0,
+            relu_op0,
+            transpose_op1,
+            relu_op1,
+            transpose_op2,
+            relu_op2,
+        ]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "transpose_x": TensorConfig(shape=in_shape),
+            },
+            outputs=["relu0_out", "relu1_out", "relu2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["delete_repeated_ops_pass"],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/ir/inference/test_xpu_gather_squeeze_pass.py b/test/ir/inference/test_xpu_gather_squeeze_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f90d3f6f5fda90de088de229a7945a6c1d3d57
--- /dev/null
+++ b/test/ir/inference/test_xpu_gather_squeeze_pass.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestGatherAddTransposePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, [
+            "transpose2",
+            "gather",
+            "transpose2",
+            "gather",
+            "squeeze2",
+            "squeeze2",
+        ], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=4), min_size=3, max_size=3
+            )
+        )
+
+        def generate_data(shape):
+            return np.random.random(shape).astype(np.float32)
+
+        def generate_index(*args, **kwargs):
+            return np.array([0]).astype(np.int64)
+
+        axis = 2
+        axes = [2]
+        gather_op0 = OpConfig(
+            "gather",
+            inputs={"X": ["gather_in"], "Index": ["gather_index0"]},
+            outputs={"Out": ["gather_out0"]},
+            axis=axis,
+        )
+
+        gather_op1 = OpConfig(
+            "gather",
+            inputs={"X": ["gather_in"], "Index": ["gather_index1"]},
+            outputs={"Out": ["gather_out1"]},
+            axis=axis,
+        )
+
+        squeeze_op0 = OpConfig(
+            "squeeze2",
+            inputs={
+                "X": ["gather_out0"],
+            },
+            outputs={"Out": ["squeeze_out0"]},
+            axes=axes,
+        )
+
+        squeeze_op1 = OpConfig(
+            "squeeze2",
+            inputs={
+                "X": ["gather_out1"],
+            },
+            outputs={"Out": ["squeeze_out1"]},
+            axes=axes,
+        )
+
+        ops = [gather_op0, gather_op1, squeeze_op0, squeeze_op1]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            inputs={
+                "gather_in": TensorConfig(
+                    data_gen=partial(generate_data, x_shape)
+                ),
+                "gather_index0": TensorConfig(data_gen=partial(generate_index)),
+                "gather_index1": TensorConfig(data_gen=partial(generate_index)),
+            },
+            weights={},
+            outputs=["squeeze_out0", "squeeze_out1"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, max_examples=25, passes=["gather_squeeze_pass"]
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
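Two short illustrative sketches follow; neither is part of the patch. First, a minimal Python model of the keying scheme behind the new GenTranspose2AttrKey: delete_repeated_ops_pass can only merge transpose2 ops whose serialized axis attribute matches (grouping ops by their shared input is handled elsewhere in the pass and is assumed here; the helper name is hypothetical).

def gen_transpose2_attr_key(axis):
    # Mirrors the C++ loop: "axis_" followed by each axis value and "_",
    # e.g. [0, 2, 1] -> "axis_0_2_1_".
    return "axis_" + "".join(str(x) + "_" for x in axis)

# Three transpose2 ops reading the same input, as in the new unit test.
axes_per_op = [[0, 2, 1], [0, 2, 1], [1, 0, 2]]
keys = [gen_transpose2_attr_key(axis) for axis in axes_per_op]
# The first two share a key and collapse into one op; the third survives,
# which is why the test expects exactly one transpose2 and three relus.
assert keys[0] == keys[1] and keys[0] != keys[2]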
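Second, a standalone numpy check (not Paddle API usage) of the identity gather_squeeze_pass relies on, shown for a rank-3 input: gathering a single index along the last axis and squeezing it gives the same tensor as moving the last axis to the front with transpose2 and then gathering and squeezing on axis 0.

import numpy as np

x = np.random.rand(2, 3, 4).astype(np.float32)
idx = np.array([0])

# Pattern before the pass: gather(axis=-1) followed by squeeze2(axes=[-1]).
before = np.squeeze(np.take(x, idx, axis=-1), axis=-1)

# Pattern after the pass: transpose2(axis=[2, 0, 1]) moves the last axis to
# the front, so gather and squeeze2 can both run on axis 0.
after = np.squeeze(np.take(np.transpose(x, (2, 0, 1)), idx, axis=0), axis=0)

assert np.array_equal(before, after)  # both sides have shape (2, 3)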