From 5f6d8ce6814e67d66057a2b30bb909cec16cd2ab Mon Sep 17 00:00:00 2001 From: yongqiangma Date: Mon, 17 Aug 2020 16:56:44 +0800 Subject: [PATCH] add reshape pass. test=develop (#4073) --- lite/api/paddle_use_passes.h | 1 + lite/core/mir/CMakeLists.txt | 1 + lite/core/mir/fusion/CMakeLists.txt | 4 ++ lite/core/mir/fusion/reshape_fuse_pass.cc | 43 +++++++++++++++++ lite/core/mir/fusion/reshape_fuse_pass.h | 32 ++++++++++++ lite/core/mir/fusion/reshape_fuser.cc | 59 +++++++++++++++++++++++ lite/core/mir/fusion/reshape_fuser.h | 50 +++++++++++++++++++ lite/core/optimizer.h | 1 + 8 files changed, 191 insertions(+) create mode 100644 lite/core/mir/fusion/reshape_fuse_pass.cc create mode 100644 lite/core/mir/fusion/reshape_fuse_pass.h create mode 100644 lite/core/mir/fusion/reshape_fuser.cc create mode 100644 lite/core/mir/fusion/reshape_fuser.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 390584fe98..2a04db1519 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -48,6 +48,7 @@ USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_layout_cast_pass); USE_MIR_PASS(type_layout_cast_preprocess_pass); USE_MIR_PASS(memory_optimize_pass); +USE_MIR_PASS(lite_reshape_fuse_pass); USE_MIR_PASS(multi_stream_analysis_pass); USE_MIR_PASS(elementwise_mul_constant_eliminate_pass) USE_MIR_PASS(npu_subgraph_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 553963cce3..715fef702b 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -23,6 +23,7 @@ lite_cc_library(mir_passes fusion/quant_dequant_fuse_pass.cc fusion/sequence_pool_concat_fuse_pass.cc fusion/scale_activation_fuse_pass.cc + fusion/reshape_fuse_pass.cc fusion/__xpu__resnet_fuse_pass.cc fusion/__xpu__resnet_cbam_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 00307588f0..cb8d5e31ab 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -37,6 +37,9 @@ lite_cc_library(fuse_sequence_pool_concat lite_cc_library(fuse_scale_activation SRCS scale_activation_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_reshape + SRCS reshape_fuser.cc + DEPS pattern_matcher_high_api) lite_cc_library(fuse_match_matrix_activation SRCS match_matrix_activation_fuser.cc DEPS pattern_matcher_high_api) @@ -61,6 +64,7 @@ set(mir_fusers fuse_interpolate fuse_sequence_pool_concat fuse_scale_activation + fuse_reshape fuse_match_matrix_activation fuse_scales fuse_sequence_reverse_embedding diff --git a/lite/core/mir/fusion/reshape_fuse_pass.cc b/lite/core/mir/fusion/reshape_fuse_pass.cc new file mode 100644 index 0000000000..bbe19460a5 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuse_pass.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/reshape_fuse_pass.h" +#include +#include +#include "lite/core/mir/fusion/reshape_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void ReshapeFusePass::Apply(const std::unique_ptr& graph) { + std::vector reshape_type_cases{"reshape", "reshape2"}; + for (auto type_ : reshape_type_cases) { + fusion::ReshapeFuser reshape_fuser(type_); + reshape_fuser(graph.get()); + } + + for (auto type_ : reshape_type_cases) { + fusion::Reshape2OutFuser reshape2Out_fuser(type_); + reshape2Out_fuser(graph.get()); + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_reshape_fuse_pass, paddle::lite::mir::ReshapeFusePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/fusion/reshape_fuse_pass.h b/lite/core/mir/fusion/reshape_fuse_pass.h new file mode 100644 index 0000000000..0fc4ebf0b6 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class ReshapeFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/reshape_fuser.cc b/lite/core/mir/fusion/reshape_fuser.cc new file mode 100644 index 0000000000..c823fbaac1 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuser.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/reshape_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void ReshapeFuser::BuildPattern() { + auto* x = VarNode("x"); + auto* reshape = OpNode("reshape", type_); + auto* reshape_out = VarNode("Out"); + auto* out1 = OpNode("out1"); + + *x >> *reshape >> *reshape_out >> *out1; +} + +void ReshapeFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { + auto op_desc = const_cast(matched.at("reshape")->stmt()->op_info()); + op_desc->SetAttr("inplace", true); +} + +void Reshape2OutFuser::BuildPattern() { + auto* x = VarNode("x"); + auto* reshape = + OpNode("reshape", type_)->assert_op_attr("inplace", true); + auto* reshape_out = VarNode("Out"); + auto* out1 = OpNode("out1"); + auto* out2 = OpNode("out2"); + + *x >> *reshape >> *reshape_out >> *out1; + *reshape_out >> *out2; +} + +void Reshape2OutFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = const_cast(matched.at("reshape")->stmt()->op_info()); + op_desc->SetAttr("inplace", false); +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/reshape_fuser.h b/lite/core/mir/fusion/reshape_fuser.h new file mode 100644 index 0000000000..f8faa92966 --- /dev/null +++ b/lite/core/mir/fusion/reshape_fuser.h @@ -0,0 +1,50 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class ReshapeFuser : public FuseBase { + public: + explicit ReshapeFuser(const std::string& type) : type_(type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + std::string type_; +}; + +class Reshape2OutFuser : public FuseBase { + public: + explicit Reshape2OutFuser(const std::string& type) : type_(type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + std::string type_; +}; +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 03ed491f58..6b18e929c0 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -164,6 +164,7 @@ class Optimizer { "runtime_context_assign_pass", "argument_type_display_pass", + "lite_reshape_fuse_pass", "memory_optimize_pass"}}; -- GitLab