diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ea35664abb5526f8014ac9091d9c47f199e94d9c..4faa9cd2183a46d450d85392bdb34e999bb50d45 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -129,6 +129,7 @@ pass_library(dense_multihead_matmul_to_sparse_pass inference)
 pass_library(delete_cast_op_pass inference)
 pass_library(delete_elementwise_mul_op_pass inference)
 pass_library(delete_repeated_ops_pass inference)
+pass_library(fused_continuous_same_ops_pass inference)
 pass_library(sigmoid_elementmul_fuse_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
diff --git a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
index 13393ec1b6895245348ccf4238220749e6cb2380..3300bbd08dffbbf3d8c7528816c9ad39c5ec2b83 100644
--- a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
+++ b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
@@ -101,18 +101,18 @@ class DeleteRepeatedOpsPass : public FusePassBase {
   void ApplyImpl(ir::Graph* graph) const override;
 
  private:
-  void DeleteRepeatedOps(
-      ir::Graph* graph,
-      const std::string& op_type,
-      std::function<std::string(OpDesc*)> gen_op_key_fn) const;
+  void DeleteRepeatedOps(ir::Graph* graph,
+                         const std::string& op_type,
+                         std::function<std::string(Node*)> gen_op_key_fn) const;
 
   const std::string name_scope_{"delete_repeated_ops_pass"};
+  mutable int delete_op_count{0};
 };
 
 void DeleteRepeatedOpsPass::DeleteRepeatedOps(
     ir::Graph* graph,
     const std::string& op_type,
-    std::function<std::string(OpDesc*)> gen_op_key_fn) const {
+    std::function<std::string(Node*)> gen_op_key_fn) const {
   GraphPatternDetector gpd;
   patterns::VarWithRepeatedOpsPattern pattern(
       gpd.mutable_pattern(), name_scope_, op_type);
@@ -140,7 +140,7 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
       }
     }
     if (out_op_is_invalid) continue;
-    auto attr_key = gen_op_key_fn(op->Op());
+    auto attr_key = gen_op_key_fn(op);
     ops_map[attr_key].push_back(op);
   }
   for (auto iter = ops_map.begin(); iter != ops_map.end();) {
@@ -173,16 +173,18 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
   };
 
   gpd(graph, handler);
+  delete_op_count += delete_counts;
   if (delete_counts > 0) {
     LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
               << " ops";
   }
 }
 
-std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }
+std::string GenShapeAttrKey(Node* shape_op_node) { return ""; }
 
-std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
+std::string GenSliceAttrKey(Node* slice_op_node) {
   std::string attr_key;
+  auto slice_op_desc = slice_op_node->Op();
   auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
   auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
   auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
@@ -207,21 +209,24 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
   return attr_key;
 }
 
-std::string GenCastAttrKey(OpDesc* cast_op_desc) {
+std::string GenCastAttrKey(Node* cast_op_node) {
+  auto cast_op_desc = cast_op_node->Op();
   auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
   auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
   return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
          std::to_string(out_dtype);
 }
 
-std::string GenAddAttrKey(OpDesc* add_op_desc) {
+std::string GenAddAttrKey(Node* add_op_node) {
+  auto add_op_desc = add_op_node->Op();
   std::string x_name = add_op_desc->Input("X")[0];
   std::string y_name = add_op_desc->Input("Y")[0];
   auto axis = add_op_desc->GetAttrIfExists<int>("axis");
   return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
 }
y_name + "_axis_" + std::to_string(axis); } -std::string GenScaleAttrKey(OpDesc* scale_op_desc) { +std::string GenScaleAttrKey(Node* scale_op_node) { + auto scale_op_desc = scale_op_node->Op(); auto scale = scale_op_desc->GetAttrIfExists("scale"); auto bias = scale_op_desc->GetAttrIfExists("bias"); auto bias_after_scale = @@ -230,17 +235,53 @@ std::string GenScaleAttrKey(OpDesc* scale_op_desc) { "_bias_after_scale_" + std::to_string(bias_after_scale); } +std::string GenGatherAttrKey(Node* gather_op_node) { + std::string input_names{""}; + for (auto input_var : gather_op_node->inputs) { + input_names += input_var->Var()->Name(); + } + auto gather_op_desc = gather_op_node->Op(); + auto axis = gather_op_desc->GetAttrIfExists("axis"); + return "axis_" + std::to_string(axis) + "_input_names_" + input_names; +} + +std::string GenSqueeze2AttrKey(Node* squeeze2_op_node) { + auto squeeze2_op_desc = squeeze2_op_node->Op(); + auto axes = squeeze2_op_desc->GetAttrIfExists>("axes"); + std::string attr_key{""}; + attr_key += "axes_"; + for (auto axis : axes) { + attr_key += std::to_string(axis) + "_"; + } + return attr_key; +} + void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - - DeleteRepeatedOps(graph, "shape", GenShapeAttrKey); - DeleteRepeatedOps(graph, "slice", GenSliceAttrKey); - DeleteRepeatedOps(graph, "cast", GenCastAttrKey); - DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey); - DeleteRepeatedOps(graph, "scale", GenScaleAttrKey); - DeleteRepeatedOps(graph, "cast", GenCastAttrKey); + int repeat_time = 0; + int total_delete_op_count = 0; + // This pass needs to loop run until there are no nodes in the graph that need + // to be deleted. + while (true) { + delete_op_count = 0; + DeleteRepeatedOps(graph, "shape", GenShapeAttrKey); + DeleteRepeatedOps(graph, "slice", GenSliceAttrKey); + DeleteRepeatedOps(graph, "cast", GenCastAttrKey); + DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey); + DeleteRepeatedOps(graph, "scale", GenScaleAttrKey); + DeleteRepeatedOps(graph, "gather", GenGatherAttrKey); + DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey); + DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey); + LOG(INFO) << "Round " << repeat_time++ + << ": delete op counts: " << delete_op_count; + total_delete_op_count += delete_op_count; + if (delete_op_count == 0) { + break; // No node need to delete. + } + } + LOG(INFO) << "Total delete op counts: " << total_delete_op_count; } } // namespace ir diff --git a/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc b/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..62b043bf3294b0d97c8f0cbded8e4c100886f228 --- /dev/null +++ b/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc @@ -0,0 +1,237 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc b/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..62b043bf3294b0d97c8f0cbded8e4c100886f228
--- /dev/null
+++ b/paddle/fluid/framework/ir/fused_continuous_same_ops_pass.cc
@@ -0,0 +1,237 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+namespace patterns {
+
+struct ContinuousSameOpsPattern : public PatternBase {
+  ContinuousSameOpsPattern(PDPattern* pattern,
+                           const std::string& name_scope,
+                           const std::string& op_type);
+  PATTERN_DECL_NODE(first_in_var_node);
+  PATTERN_DECL_NODE(first_out_var_node);
+  PATTERN_DECL_NODE(second_out_var_node);
+  // declare op node's name
+  PATTERN_DECL_NODE(first_op_node);
+  PATTERN_DECL_NODE(second_op_node);
+  std::string op_type_;
+};
+
+ContinuousSameOpsPattern::ContinuousSameOpsPattern(
+    PDPattern* pattern,
+    const std::string& name_scope,
+    const std::string& op_type)
+    : PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
+  auto* first_in_var_node =
+      pattern->NewNode(first_in_var_node_repr())
+          ->assert_var_not_persistable()
+          ->assert_is_op_input(op_type_, "X")
+          ->AsInput()
+          ->assert_more([&](Node* node) {
+            // The producer of this var must not be the same op type, so
+            // matching always starts at the head of a chain.
+            auto input_nodes = node->inputs;
+            if (input_nodes.size() != 1) return false;
+            if (!input_nodes.empty() && input_nodes[0]->IsOp() &&
+                input_nodes[0]->Op()->Type() == op_type_) {
+              return false;
+            }
+            return true;
+          });
+  auto* first_op_node =
+      pattern->NewNode(first_op_node_repr())->assert_is_op(op_type_);
+  auto* first_out_var_node = pattern->NewNode(first_out_var_node_repr())
+                                 ->assert_var_not_persistable()
+                                 ->assert_is_op_output(op_type_, "Out")
+                                 ->assert_has_n_outputs(1);
+  first_op_node->LinksFrom({first_in_var_node}).LinksTo({first_out_var_node});
+  auto* second_op_node =
+      pattern->NewNode(second_op_node_repr())->assert_is_op(op_type_);
+  auto* second_out_var_node = pattern->NewNode(second_out_var_node_repr())
+                                  ->assert_var_not_persistable()
+                                  ->assert_is_op_output(op_type_, "Out")
+                                  ->AsOutput();
+  second_op_node->LinksFrom({first_out_var_node})
+      .LinksTo({second_out_var_node});
+}
+
+}  // namespace patterns
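
The pattern encodes two structural guards: the intermediate variable must have exactly one consumer (assert_has_n_outputs(1)), and the first op's input must not itself be produced by an op of the same type, so every chain is matched from its head. The same conditions restated as a runnable Python predicate (hypothetical Op/Var classes, only to make the guards concrete):

    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class Op:
        type: str
        in_var: "Var"
        out_var: "Var"

    @dataclass
    class Var:
        producer: Optional["Op"] = None
        consumers: list = field(default_factory=list)

    def is_fusable_pair(first: "Op", second: "Op", op_type: str) -> bool:
        if first.type != op_type or second.type != op_type:
            return False
        # The intermediate var may feed only `second`.
        if first.out_var.consumers != [second]:
            return False
        # `first` must be the chain head: its input is not produced by op_type.
        prod = first.in_var.producer
        return prod is None or prod.type != op_type
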
+
+/*
+Fuse continuous ops of the same type into one.
+
+Origin graph:
+        input
+          |
+          |
+      unsqueeze2
+          |
+          |
+      unsqueeze2
+          |
+          |
+      unsqueeze2
+          |
+          |
+         out
+
+After:
+        input
+          |
+          |
+      unsqueeze2
+          |
+          |
+         out
+*/
+
+class FusedContinuousSameOpsPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FusedReshapeOps(ir::Graph* graph) const;
+  void FusedUnsqueezeOps(ir::Graph* graph) const;
+
+  const std::string name_scope_{"fused_continuous_same_ops_pass"};
+  mutable int delete_op_count{0};
+};
+
+void FusedContinuousSameOpsPass::FusedReshapeOps(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ContinuousSameOpsPattern pattern(
+      gpd.mutable_pattern(), name_scope_, "reshape2");
+  int delete_counts = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle fused continuous reshape ops.";
+    GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        second_out_var_node, second_out_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
+    auto first_node_attr_shape =
+        first_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
+    if (first_node_attr_shape.empty()) return;
+    auto second_node_attr_shape =
+        second_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
+    if (second_node_attr_shape.empty()) return;
+    // Only the second reshape's target shape matters: bypass the first op.
+    second_op_node->Op()->RenameInput(first_out_var_node->Name(),
+                                      first_in_var_node->Name());
+    IR_NODE_LINK_TO(first_in_var_node, second_op_node);
+    GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
+    delete_counts++;
+  };
+  gpd(graph, handler);
+  delete_op_count += delete_counts;
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " repeated reshape2 ops";
+  }
+}
+
+void FusedContinuousSameOpsPass::FusedUnsqueezeOps(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::ContinuousSameOpsPattern pattern(
+      gpd.mutable_pattern(), name_scope_, "unsqueeze2");
+  int delete_counts = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle fused continuous unsqueeze ops.";
+    GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        second_out_var_node, second_out_var_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
+    auto first_node_attr_axes =
+        first_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
+    if (first_node_attr_axes.empty()) return;
+    auto second_node_attr_axes =
+        second_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
+    if (second_node_attr_axes.empty()) return;
+    second_op_node->Op()->RenameInput(first_out_var_node->Name(),
+                                      first_in_var_node->Name());
+    // Prepend the first op's axes so the merged op applies them in the
+    // original order.
+    second_node_attr_axes.insert(second_node_attr_axes.begin(),
+                                 first_node_attr_axes.begin(),
+                                 first_node_attr_axes.end());
+    second_op_node->Op()->SetAttr("axes", second_node_attr_axes);
+    IR_NODE_LINK_TO(first_in_var_node, second_op_node);
+    GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
+    delete_counts++;
+  };
+  gpd(graph, handler);
+  delete_op_count += delete_counts;
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " repeated unsqueeze2 ops";
+  }
+}
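
Both handlers lean on a shape-level identity: for reshape2 only the last target shape matters, so the first reshape can be bypassed outright, while for unsqueeze2 the two axes lists are merged with the first op's axes prepended, on the reading that unsqueeze2 inserts its axes one at a time. A quick numpy sanity check of both identities (an illustration of that reading, not a Paddle test):

    import numpy as np

    x = np.random.rand(8, 16)

    # reshape2 chain: the intermediate shape is irrelevant.
    assert np.array_equal(x.reshape(-1, 16, 4).reshape(-1, 8), x.reshape(-1, 8))

    def unsqueeze(t, axes):
        for a in axes:  # insert one axis at a time
            t = np.expand_dims(t, a)
        return t

    # unsqueeze2 chain: applying axes [0] then [2] equals one op with [0, 2].
    a = unsqueeze(unsqueeze(x, [0]), [2])
    b = unsqueeze(x, [0] + [2])
    assert a.shape == b.shape == (1, 8, 1, 16)
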
+
+void FusedContinuousSameOpsPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+  int repeat_time = 0;
+  int total_delete_op_count = 0;
+  // This pass needs to run in a loop until no more nodes in the graph
+  // can be deleted.
+  while (true) {
+    delete_op_count = 0;
+    FusedReshapeOps(graph);
+    FusedUnsqueezeOps(graph);
+    LOG(INFO) << "Round " << repeat_time++
+              << ": delete op counts: " << delete_op_count;
+    total_delete_op_count += delete_op_count;
+    if (delete_op_count == 0) {
+      LOG(INFO) << "--- no more nodes need to be deleted --- break";
+      break;  // No more nodes need to be deleted.
+    }
+  }
+  LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(fused_continuous_same_ops_pass,
+              paddle::framework::ir::FusedContinuousSameOpsPass);
+
+REGISTER_PASS_CAPABILITY(fused_continuous_same_ops_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "reshape2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "unsqueeze2", 0));
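
A single round only merges adjacent pairs, so a chain of N identical ops can take several rounds to collapse into one; the while (true) loop above simply reruns both handlers until a round deletes nothing. The control flow, reduced to a few lines of Python (fuse_once stands in for one round of FusedReshapeOps plus FusedUnsqueezeOps):

    def run_to_fixed_point(graph, fuse_once):
        total = 0
        while True:
            deleted = fuse_once(graph)  # ops removed in this round
            total += deleted
            if deleted == 0:
                return total  # fixed point: nothing left to merge

    # A chain of three unsqueeze2 ops typically shrinks 3 -> 2 -> 1 over two
    # rounds; the third round deletes nothing and ends the loop.
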
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index c872db45117dc513bb2c0be0281e6e36df10da13..958ee89af0c31d6378e30ee8bcf905f249686326 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -507,8 +507,9 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
         "delete_assign_op_pass",
         "delete_dropout_op_pass",
         "delete_concat_op_pass",
-        "identity_op_clean_pass",
         "delete_repeated_ops_pass",
+        "identity_op_clean_pass",
+        "fused_continuous_same_ops_pass",
         "reshape_unstack_concat_fuse_pass",
         "delete_op_device_pass",
         "constant_folding_pass",
diff --git a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
index 5f7799aaee83d45fb9dffdafc8914889d94a106b..b6f45c5841c0e8e63f7ea4cbb5fe0d7b2f852b62 100644
--- a/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
+++ b/test/ir/inference/test_xpu_delete_repeated_ops_pass.py
@@ -13,8 +13,10 @@
 # limitations under the License.
 
 import unittest
+from functools import partial
 
 import hypothesis.strategies as st
+import numpy as np
 from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig
@@ -380,5 +382,350 @@ class TestDeleteRepeatedScalePass(PassAutoScanTest):
         )
 
 
+class TestDeleteRepeatedSqueezePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['scale', 'squeeze2', 'relu', 'relu', 'relu'], (
+            1e-5,
+            1e-5,
+        )
+
+    def sample_program_config(self, draw):
+        scale_x = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=20), min_size=2, max_size=4
+            )
+        )
+        scale_x[0] = 1
+        axis = 0
+        scale_op0 = OpConfig(
+            "scale",
+            inputs={
+                "X": ["scale_x"],
+            },
+            scale=2.0,
+            bias=1.0,
+            bias_after_scale=True,
+            outputs={"Out": ["scale0_out"]},
+        )
+        squeeze_op0 = OpConfig(
+            "squeeze2",
+            inputs={
+                "X": ["scale0_out"],
+            },
+            axes=[axis],
+            outputs={"Out": ["squeeze0_out"]},
+        )
+        relu_op0 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["squeeze0_out"],
+            },
+            outputs={"Out": ["relu0_out"]},
+        )
+        scale_op1 = OpConfig(
+            "scale",
+            inputs={
+                "X": ["scale_x"],
+            },
+            scale=2.0,
+            bias=1.0,
+            bias_after_scale=True,
+            outputs={"Out": ["scale1_out"]},
+        )
+        squeeze_op1 = OpConfig(
+            "squeeze2",
+            inputs={
+                "X": ["scale1_out"],
+            },
+            axes=[axis],
+            outputs={"Out": ["squeeze1_out"]},
+        )
+        relu_op1 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["squeeze1_out"],
+            },
+            outputs={"Out": ["relu1_out"]},
+        )
+        scale_op2 = OpConfig(
+            "scale",
+            inputs={
+                "X": ["scale_x"],
+            },
+            scale=2.0,
+            bias=1.0,
+            bias_after_scale=True,
+            outputs={"Out": ["scale2_out"]},
+        )
+        squeeze_op2 = OpConfig(
+            "squeeze2",
+            inputs={
+                "X": ["scale2_out"],
+            },
+            axes=[axis],
+            outputs={"Out": ["squeeze2_out"]},
+        )
+        relu_op2 = OpConfig(
+            "relu",
+            inputs={
+                "X": ["squeeze2_out"],
+            },
+            outputs={"Out": ["relu2_out"]},
+        )
+        ops = [
+            scale_op0,
+            squeeze_op0,
+            relu_op0,
+            scale_op1,
+            squeeze_op1,
+            relu_op1,
+            scale_op2,
+            squeeze_op2,
+            relu_op2,
+        ]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "scale_x": TensorConfig(shape=scale_x),
+            },
+            outputs=["relu0_out", "relu1_out", "relu2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["delete_repeated_ops_pass"],
+        )
+ "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale2_out"]}, + ) + unsqueeze_op2 = OpConfig( + "unsqueeze2", + inputs={ + "X": ["scale2_out"], + }, + axes=[axis], + outputs={"Out": ["unsqueeze2_out"]}, + ) + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["unsqueeze2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + ops = [ + scale_op0, + unsqueeze_op0, + relu_op0, + scale_op1, + unsqueeze_op1, + relu_op1, + scale_op2, + unsqueeze_op2, + relu_op2, + ] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "scale_x": TensorConfig(shape=scale_x), + }, + outputs=["relu0_out", "relu1_out", "relu2_out"], + ) + return program_config + + +class TestDeleteRepeatedGatherPass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ['scale', 'gather', 'relu', 'relu', 'relu'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + scale_x = draw( + st.lists( + st.integers(min_value=3, max_value=20), min_size=2, max_size=4 + ) + ) + axis = 0 + + def generate_index(*args, **kwargs): + return np.array([0]).astype(np.int64) + + gather_index = np.array([0]).astype(np.int64) + scale_op0 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale0_out"]}, + ) + gather_op0 = OpConfig( + "gather", + inputs={"X": ["scale0_out"], "Index": ["gather_index"]}, + axis=axis, + outputs={"Out": ["gather0_out"]}, + ) + relu_op0 = OpConfig( + "relu", + inputs={ + "X": ["gather0_out"], + }, + outputs={"Out": ["relu0_out"]}, + ) + scale_op1 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale1_out"]}, + ) + gather_op1 = OpConfig( + "gather", + inputs={"X": ["scale1_out"], "Index": ["gather_index"]}, + axis=axis, + outputs={"Out": ["gather1_out"]}, + ) + relu_op1 = OpConfig( + "relu", + inputs={ + "X": ["gather1_out"], + }, + outputs={"Out": ["relu1_out"]}, + ) + scale_op2 = OpConfig( + "scale", + inputs={ + "X": ["scale_x"], + }, + scale=2.0, + bias=1.0, + bias_after_scale=True, + outputs={"Out": ["scale2_out"]}, + ) + gather_op2 = OpConfig( + "gather", + inputs={"X": ["scale2_out"], "Index": ["gather_index"]}, + axis=axis, + outputs={"Out": ["gather2_out"]}, + ) + relu_op2 = OpConfig( + "relu", + inputs={ + "X": ["gather2_out"], + }, + outputs={"Out": ["relu2_out"]}, + ) + + ops = [ + scale_op0, + gather_op0, + relu_op0, + scale_op1, + gather_op1, + relu_op1, + scale_op2, + gather_op2, + relu_op2, + ] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "scale_x": TensorConfig(shape=scale_x), + "gather_index": TensorConfig(data_gen=partial(generate_index)), + }, + outputs=["relu0_out", "relu1_out", "relu2_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["delete_repeated_ops_pass"], + ) + + if __name__ == "__main__": unittest.main() diff --git a/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py b/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..70fdb2f34fb2f95223780e641d20bb53073f0466 --- /dev/null +++ b/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
diff --git a/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py b/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..70fdb2f34fb2f95223780e641d20bb53073f0466
--- /dev/null
+++ b/test/ir/inference/test_xpu_fused_continuous_same_ops_pass.py
@@ -0,0 +1,142 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestFusedSameUnSqueezePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['scale', 'unsqueeze2'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        scale_x = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=20), min_size=1, max_size=3
+            )
+        )
+        first_unsqueeze_axis = 0
+        second_unsqueeze_axis = 1
+        third_unsqueeze_axis = 2
+        scale_op0 = OpConfig(
+            "scale",
+            inputs={
+                "X": ["scale_x"],
+            },
+            scale=2.0,
+            bias=1.0,
+            bias_after_scale=True,
+            outputs={"Out": ["scale0_out"]},
+        )
+        unsqueeze_op0 = OpConfig(
+            "unsqueeze2",
+            inputs={
+                "X": ["scale0_out"],
+            },
+            axes=[first_unsqueeze_axis],
+            outputs={"Out": ["unsqueeze0_out"]},
+        )
+        unsqueeze_op1 = OpConfig(
+            "unsqueeze2",
+            inputs={
+                "X": ["unsqueeze0_out"],
+            },
+            axes=[second_unsqueeze_axis],
+            outputs={"Out": ["unsqueeze1_out"]},
+        )
+        unsqueeze_op2 = OpConfig(
+            "unsqueeze2",
+            inputs={
+                "X": ["unsqueeze1_out"],
+            },
+            axes=[third_unsqueeze_axis],
+            outputs={"Out": ["unsqueeze2_out"]},
+        )
+        ops = [scale_op0, unsqueeze_op0, unsqueeze_op1, unsqueeze_op2]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "scale_x": TensorConfig(shape=scale_x),
+            },
+            outputs=["unsqueeze2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            min_success_num=5,
+            passes=["fused_continuous_same_ops_pass"],
+        )
+
+
+class TestFusedSameReshapePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['scale', 'reshape2'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        scale_x = draw(
+            st.sampled_from([[8, 16], [16, 32], [64, 16], [16, 8], [16, 16]])
+        )
+        first_reshape_shape = [-1, 16, 4]
+        second_reshape_shape = [-1, 8]
+        scale_op0 = OpConfig(
+            "scale",
+            inputs={
+                "X": ["scale_x"],
+            },
+            scale=2.0,
+            bias=1.0,
+            bias_after_scale=True,
+            outputs={"Out": ["scale0_out"]},
+        )
+        reshape_op0 = OpConfig(
+            "reshape2",
+            inputs={
+                "X": ["scale0_out"],
+            },
+            shape=first_reshape_shape,
+            outputs={"Out": ["reshape0_out"]},
+        )
+        reshape_op1 = OpConfig(
+            "reshape2",
+            inputs={
+                "X": ["reshape0_out"],
+            },
+            shape=second_reshape_shape,
+            outputs={"Out": ["reshape1_out"]},
+        )
+        ops = [scale_op0, reshape_op0, reshape_op1]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "scale_x": TensorConfig(shape=scale_x),
+            },
+            outputs=["reshape1_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            min_success_num=5,
+            passes=["fused_continuous_same_ops_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()