diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 1c24fa14b5fdc400943f71d39d66dca4e10cc0e4..dd78d0e8b3b8fa581059fb7e85c1cdc1c0232d2c 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -252,6 +252,8 @@ if(WITH_XPU)
                xpu DEPS ${XPU_PASS_DEPS})
   pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
@@ -536,4 +538,8 @@ if(WITH_XPU)
     test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
     SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
     DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass)
+  cc_test(
+    test_fold_interp_outsize_fuse_pass
+    SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
+    DEPS fold_interp_outsize_fuse_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 9811ac01b0b5cf5e5de0c19bba83ef523f45eeba..ed52eb3190c50609c38547270213a1139800e22c 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -361,20 +361,31 @@ struct Layers {
     return outs;
   }
 
-  std::vector<VarDesc*> split(VarDesc* x, int num_or_section, int axis = 0) {
-    std::vector<VarDesc*> outs(num_or_section);
-    for (int i = 0; i < num_or_section; i++) {
+  std::vector<VarDesc*> split(VarDesc* x,
+                              int num_or_section = 0,
+                              int axis = 0,
+                              std::vector<int> sections = {-1}) {
+    int out_num = num_or_section;
+    if (num_or_section == 0) {
+      out_num = sections.size();
+    }
+    std::vector<VarDesc*> outs(out_num);
+    for (int i = 0; i < out_num; i++) {
       outs[i] = lod_tensor(unique_name());
     }
-    std::vector<std::string> out_names(num_or_section);
-    for (int i = 0; i < num_or_section; i++) {
+    std::vector<std::string> out_names(out_num);
+    for (int i = 0; i < out_num; i++) {
       out_names[i] = outs[i]->Name();
     }
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("split");
     op->SetInput("X", {x->Name()});
     op->SetOutput("Out", out_names);
-    op->SetAttr("num_or_section", num_or_section);
+    if (num_or_section == 0) {
+      op->SetAttr("sections", sections);
+    } else {
+      op->SetAttr("num_or_section", num_or_section);
+    }
     op->SetAttr("axis", axis);
     op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                 static_cast<int>(OpRole::kForward));
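For reference, a minimal usage sketch of the extended Layers::split() helper (not part of the patch; the function and variable names are illustrative). Passing num_or_section == 0 now makes the helper emit the explicit "sections" attribute on the generated "split" op instead of "num_or_section", which is what the new pass test below relies on:

// Sketch only: assumes pass_tester_helper.h is on the include path.
#include "paddle/fluid/framework/ir/pass_tester_helper.h"

void BuildSplitVariants() {
  paddle::framework::ir::Layers layers;
  auto* x = layers.data("x", {4, 8});
  // Old behavior: two equal chunks via the "num_or_section" attribute.
  auto by_num = layers.split(x, /*num_or_section=*/2, /*axis=*/0);
  // New behavior: explicit chunk sizes via the "sections" attribute.
  auto by_sections =
      layers.split(x, /*num_or_section=*/0, /*axis=*/0, /*sections=*/{2, 2});
  (void)by_num;
  (void)by_sections;
}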
+ +#include "paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h" +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +namespace patterns { +struct DetectorFusePattern : public PatternBase { + DetectorFusePattern(PDPattern* pattern, const std::string& name_scope); + + // declare operator node's name + PATTERN_DECL_NODE(shape); + PATTERN_DECL_NODE(cast1); + PATTERN_DECL_NODE(slice); + PATTERN_DECL_NODE(concat); + PATTERN_DECL_NODE(split); + PATTERN_DECL_NODE(cast2); + PATTERN_DECL_NODE(bilinear_interp); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(shape_out); + PATTERN_DECL_NODE(cast1_out); + PATTERN_DECL_NODE(slice_out); + PATTERN_DECL_NODE(concat_y); + PATTERN_DECL_NODE(concat_out); + PATTERN_DECL_NODE(split_out_0); + PATTERN_DECL_NODE(split_out_1); + PATTERN_DECL_NODE(cast2_out); +}; + +DetectorFusePattern::DetectorFusePattern(PDPattern* pattern, + const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("shape", "Input") + ->assert_is_op_input("bilinear_interp_v2", "X"); + auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape"); + auto* shape_out = pattern->NewNode(shape_out_repr()) + ->assert_is_op_output("shape", "Out") + ->assert_is_op_input("cast", "X"); + shape->LinksFrom({x}).LinksTo({shape_out}); + auto* cast1 = pattern->NewNode(cast1_repr()) + ->assert_is_op("cast") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("in_dtype") == 2 && + op_desc->GetAttrIfExists("out_dtype") == 3; + }); + auto* cast1_out = pattern->NewNode(cast1_out_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("slice", "Input"); + cast1->LinksFrom({shape_out}).LinksTo({cast1_out}); + auto* slice = + pattern->NewNode(slice_repr()) + ->assert_is_op("slice") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists>("axes") == + std::vector{0} && + op_desc->GetAttrIfExists>("starts") == + std::vector{0} && + op_desc->GetAttrIfExists>("ends") == + std::vector{2}; + }); + auto* slice_out = pattern->NewNode(slice_out_repr()) + ->assert_is_op_output("slice", "Out") + ->assert_is_op_nth_input("concat", "X", 0); + slice->LinksFrom({cast1_out}).LinksTo({slice_out}); + auto* concat = pattern->NewNode(concat_repr()) + ->assert_is_op("concat") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("axis") == 0; + }); + auto* concat_y = pattern->NewNode(concat_y_repr()) + ->assert_is_op_nth_input("concat", "X", 1) + ->assert_is_persistable_var(); + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_output("concat", "Out") + ->assert_is_op_input("split", "X"); + concat->LinksFrom({slice_out, concat_y}).LinksTo({concat_out}); + auto* split = pattern->NewNode(split_repr()) + ->assert_is_op("split") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("axis") == 0 && + (op_desc->GetAttrIfExists>( + "sections") == std::vector{2, 2} || + 
op_desc->GetAttrIfExists("num") == 2); + }); + auto* split_out_0 = pattern->NewNode(split_out_0_repr()) + ->assert_is_op_nth_output("split", "Out", 0); + auto* split_out_1 = pattern->NewNode(split_out_1_repr()) + ->assert_is_op_nth_output("split", "Out", 1) + ->assert_is_op_input("cast", "X"); + split->LinksFrom({concat_out}).LinksTo({split_out_0, split_out_1}); + auto* cast2 = pattern->NewNode(cast2_repr()) + ->assert_is_op("cast") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + return op_desc->GetAttrIfExists("in_dtype") == 3 && + op_desc->GetAttrIfExists("out_dtype") == 2; + }); + auto* cast2_out = pattern->NewNode(cast2_out_repr()) + ->assert_is_op_output("cast", "Out") + ->assert_is_op_input("bilinear_interp_v2", "OutSize"); + cast2->LinksFrom({split_out_1}).LinksTo({cast2_out}); + auto* bilinear_interp = pattern->NewNode(bilinear_interp_repr()) + ->assert_is_op("bilinear_interp_v2"); + bilinear_interp->LinksFrom({x, cast2_out}); +} + +} // namespace patterns + +void FoldInterpOutsizeFusePass::DetectorFuse(ir::Graph* graph) const { + GraphPatternDetector gpd; + patterns::DetectorFusePattern pattern(gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle DetectorFuse"; + /* declare operator node's name */ + GET_IR_NODE(shape); + GET_IR_NODE(cast1); + GET_IR_NODE(slice); + GET_IR_NODE(concat); + GET_IR_NODE(split); + GET_IR_NODE(cast2); + GET_IR_NODE(bilinear_interp); + /* declare variable node's name*/ + GET_IR_NODE(x); + GET_IR_NODE(shape_out); + GET_IR_NODE(cast1_out); + GET_IR_NODE(slice_out); + GET_IR_NODE(concat_y); + GET_IR_NODE(concat_out); + GET_IR_NODE(split_out_0); + GET_IR_NODE(split_out_1); + GET_IR_NODE(cast2_out); + + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + + auto* concat_y_t = + scope->GetVar(concat_y->Name())->GetMutable(); + // concat_y int64 --> int32 + auto tensor_type = concat_y_t->dtype(); + if (tensor_type == phi::DataType::INT64) { + CastToInt32(concat_y_t, nullptr); + } + bilinear_interp->Op()->RenameInput(cast2_out->Name(), concat_y->Name()); + IR_NODE_UNLINK(x, shape); + IR_NODE_UNLINK(cast2_out, bilinear_interp); + IR_NODE_LINK_TO(concat_y, bilinear_interp); + // delete useless node + std::unordered_set delete_nodes = {shape, + cast1, + slice, + concat, + split, + cast2, + shape_out, + cast1_out, + slice_out, + concat_out, + split_out_0, + split_out_1, + cast2_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +void FoldInterpOutsizeFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + DetectorFuse(graph); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fold_interp_outsize_fuse_pass, + paddle::framework::ir::FoldInterpOutsizeFusePass); + +REGISTER_PASS_CAPABILITY(fold_interp_outsize_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "shape", 0)); diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..08dc0fe7b73976b2e209ccf03e71bdd3d76ab8ef --- /dev/null +++ 
diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..08dc0fe7b73976b2e209ccf03e71bdd3d76ab8ef
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h
@@ -0,0 +1,74 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FoldInterpOutsizeFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  /*
+  Origin subgraph:
+            x
+          /   \
+         |   shape
+         |     |
+         |    cast
+         |     |
+         |   slice
+         |     |
+         |   concat
+         |     |
+         |   split
+         |    |  \
+         |    |   \
+         | outvar_1  outvar_0
+         |    |
+         |   cast
+         |   /
+          \ /
+   bilinear_interp_v2
+
+  Fused subgraph:
+            x
+            |   concat_y
+            |   /
+   bilinear_interp_v2
+  */
+  void DetectorFuse(ir::Graph* graph) const;
+
+  const std::string name_scope_{"fold_interp_outsize_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7836a27b4561c353274f7c34f62be601ed48373
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass_test.cc
@@ -0,0 +1,58 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(DetectorFuse, basic) {
+  Layers layers;
+  auto* block = layers.Block();
+
+  auto* shape_x = layers.data("shape_x", {1, 18, 288, 288});
+  auto* concat_y =
+      layers.data("concat_y", {576, 576}, true, proto::VarType::INT64);
+  auto* shape_out = layers.shape(shape_x);
+  auto* cast1_out = layers.cast(shape_out, 2, 3);
+  auto* slice_out = layers.slice(cast1_out, {0}, {0}, {2});
+  auto* concat_out = layers.concat({slice_out, concat_y}, 0);
+  auto split_outs = layers.split(concat_out, 0, 0, {2, 2});
+  auto* split_out_1 = split_outs[1];
+  auto* cast2_out = layers.cast(split_out_1, 3, 2);
+
+  OpDesc* bilinear_interp_v2_op = block->AppendOp();
+  bilinear_interp_v2_op->SetType("bilinear_interp_v2");
+  bilinear_interp_v2_op->SetInput("X", {shape_x->Name()});
+  bilinear_interp_v2_op->SetInput("OutSize", {cast2_out->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("fold_interp_outsize_fuse_pass");
+  pass->Apply(graph.get());
+  auto ops_num = GetNumOpNodes(graph);
+  PADDLE_ENFORCE_EQ(
+      ops_num,
+      1,
+      platform::errors::PreconditionNotMet(
+          "graph should only have 1 op node, but received %d.", ops_num));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(fold_interp_outsize_fuse_pass);
diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc
index d075d42d29506893c2ae789275fa9b4ce0978fef..643e0e33744df7cfeae1071f795ace31035315ff 100644
--- a/paddle/fluid/framework/ir/xpu/quant_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -70,6 +70,39 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
   }
 }
 
+void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) {
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+
+  phi::DenseTensor int32_tensor;
+  phi::DenseTensor* out_ptr = out == nullptr ? &int32_tensor : out;
+  out_ptr->Resize(in->dims());
+  out_ptr->set_type(phi::DataType::INT32);
+  out_ptr->set_layout(in->layout());
+
+  switch (in->dtype()) {
+    case phi::DataType::INT64:
+      phi::CastKernel<int64_t>(*cpu_ctx, *in, phi::DataType::INT32, out_ptr);
+      break;
+    case phi::DataType::INT32:
+      if (out == nullptr) {
+        return;
+      } else {
+        phi::AssignKernel(*cpu_ctx, *in, out_ptr);
+      }
+      break;
+    default:
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Only support int64 and int32, but received dtype is %s.",
+          phi::DataTypeToString(in->dtype())));
+      break;
+  }
+
+  if (out == nullptr) {
+    Assign(*out_ptr, in);
+  }
+}
+
 void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
   auto* cpu_ctx = static_cast<phi::CPUContext*>(
       platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.h b/paddle/fluid/framework/ir/xpu/quant_utils.h
index 85e9ddb11825d721c457dde0f939e2eaabea4ee5..b417fa03323db8ac0bd51b5c42b0f4e29870d2ed 100644
--- a/paddle/fluid/framework/ir/xpu/quant_utils.h
+++ b/paddle/fluid/framework/ir/xpu/quant_utils.h
@@ -25,6 +25,8 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
 void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 
+void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
+
 // 1. Quant weight from fp32 to int16/int31
 // 2. Weight data is in-place update.
 // 3. Generate weight max tensor
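A minimal sketch (not part of the patch; the wrapper function is hypothetical) of the in-place form of the new CastToInt32 helper, mirroring how DetectorFuse uses it on the concat_y weight: passing out = nullptr writes the int32 result back into the input tensor.

#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/phi/core/dense_tensor.h"

// Shrinks an int64 OutSize weight (e.g. concat_y) to int32 in place.
void ShrinkOutSizeWeight(phi::DenseTensor* weight) {
  if (weight->dtype() == phi::DataType::INT64) {
    // out == nullptr -> the casted result is assigned back into `weight`.
    paddle::framework::ir::CastToInt32(weight, nullptr);
  }
}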
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 41a07585984e5f8e538f2445272068c99fa922d4..bac89b17fae3cc54f2b955756551dcd4e9deb236 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -522,6 +522,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "multi_encoder_xpu_slice_fuse_pass",
       "fused_multi_transformer_cachekv_layout_trans_pass",
       "one_beam_size_fuse_pass",
+      "fold_interp_outsize_fuse_pass",
       "delete_cast_op_pass",
       "stack_fuse_pass",
       "fused_multi_transformer_xpu_pass",
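With this change the pass runs by default in the XPU inference pass strategy. A rough sketch (assuming the standard paddle_infer C++ Config API; the function name is illustrative) of how it could be switched off when comparing against the unfused graph, since PaddlePassBuilder::DeletePass removes a pass by name:

#include "paddle/fluid/inference/api/paddle_inference_api.h"

void ConfigureXpuPredictor(paddle_infer::Config* config) {
  config->EnableXpu();
  // Drop the new fusion if the unfused graph is needed for debugging.
  config->pass_builder()->DeletePass("fold_interp_outsize_fuse_pass");
}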