diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 390584fe988b8f53a6164ed1312aa8c3136c0a11..2a04db1519431cb2608c8f39997581dc3bc63973 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -48,6 +48,7 @@ USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_layout_cast_pass); USE_MIR_PASS(type_layout_cast_preprocess_pass); USE_MIR_PASS(memory_optimize_pass); +USE_MIR_PASS(lite_reshape_fuse_pass); USE_MIR_PASS(multi_stream_analysis_pass); USE_MIR_PASS(elementwise_mul_constant_eliminate_pass) USE_MIR_PASS(npu_subgraph_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 553963cce3c8c785364863e06c5a49aa303242ff..715fef702bfc41e3a8c8dc9c698eb38895b668ca 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -23,6 +23,7 @@ lite_cc_library(mir_passes fusion/quant_dequant_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc fusion/scale_activation_fuse_pass.cc + fusion/reshape_fuse_pass.cc fusion/__xpu__resnet_fuse_pass.cc fusion/__xpu__resnet_cbam_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 00307588f0d6ebcc483ea255eb0a112e0c807ee6..cb8d5e31abb9f492e54049f58c224f13f60ee7ed 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -37,6 +37,9 @@ lite_cc_library(fuse_sequence_pool_concat lite_cc_library(fuse_scale_activation SRCS scale_activation_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_reshape + SRCS reshape_fuser.cc + DEPS pattern_matcher_high_api) lite_cc_library(fuse_match_matrix_activation SRCS match_matrix_activation_fuser.cc DEPS pattern_matcher_high_api) @@ -61,6 +64,7 @@ set(mir_fusers fuse_interpolate fuse_sequence_pool_concat fuse_scale_activation + fuse_reshape fuse_match_matrix_activation fuse_scales fuse_sequence_reverse_embedding diff --git a/lite/core/mir/fusion/reshape_fuse_pass.cc b/lite/core/mir/fusion/reshape_fuse_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..bbe19460a57572f14d3981e66611963a2cb90ab0 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuse_pass.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/reshape_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/reshape_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ReshapeFusePass::Apply(const std::unique_ptr& graph) { + std::vector reshape_type_cases{"reshape", "reshape2"}; + for (auto type_ : reshape_type_cases) { + fusion::ReshapeFuser reshape_fuser(type_); + reshape_fuser(graph.get()); + } + + for (auto type_ : reshape_type_cases) { + fusion::Reshape2OutFuser reshape2Out_fuser(type_); + reshape2Out_fuser(graph.get()); + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_reshape_fuse_pass, paddle::lite::mir::ReshapeFusePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/reshape_fuse_pass.h b/lite/core/mir/fusion/reshape_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..0fc4ebf0b6fa4fa49ef8c7a2d839d0e6214bf71a --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ReshapeFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/reshape_fuser.cc b/lite/core/mir/fusion/reshape_fuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..c823fbaac1864324c3754287cf7ec9eec686287e --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuser.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/reshape_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ReshapeFuser::BuildPattern() { + auto* x = VarNode("x"); + auto* reshape = OpNode("reshape", type_); + auto* reshape_out = VarNode("Out"); + auto* out1 = OpNode("out1"); + + *x >> *reshape >> *reshape_out >> *out1; +} + +void ReshapeFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { + auto op_desc = const_cast(matched.at("reshape")->stmt()->op_info()); + op_desc->SetAttr("inplace", true); +} + +void Reshape2OutFuser::BuildPattern() { + auto* x = VarNode("x"); + auto* reshape = + OpNode("reshape", type_)->assert_op_attr("inplace", true); + auto* reshape_out = VarNode("Out"); + auto* out1 = OpNode("out1"); + auto* out2 = OpNode("out2"); + + *x >> *reshape >> *reshape_out >> *out1; + *reshape_out >> *out2; +} + +void Reshape2OutFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = const_cast(matched.at("reshape")->stmt()->op_info()); + op_desc->SetAttr("inplace", false); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/reshape_fuser.h b/lite/core/mir/fusion/reshape_fuser.h new file mode 100644 index 0000000000000000000000000000000000000000..f8faa9296621e467650f89d28eff25dd759f8449 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuser.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ReshapeFuser : public FuseBase { + public: + explicit ReshapeFuser(const std::string& type) : type_(type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + std::string type_; +}; + +class Reshape2OutFuser : public FuseBase { + public: + explicit Reshape2OutFuser(const std::string& type) : type_(type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + std::string type_; +}; +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 03ed491f58f097e8eaf3e95e9b476547e0126ebc..6b18e929c077699a723b9dd9db313370d061cbb8 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -164,6 +164,7 @@ class Optimizer { "runtime_context_assign_pass", "argument_type_display_pass", + "lite_reshape_fuse_pass", "memory_optimize_pass"}};