From 5eee25b26ae93067a013400841f871d4ad0bc7ea Mon Sep 17 00:00:00 2001
From: yongqiangma
Date: Thu, 9 Apr 2020 09:36:17 +0000
Subject: [PATCH] add shape_eliminate_pass for rm shape. test=develop

---
 lite/api/paddle_use_passes.h                  |  1 +
 lite/core/mir/CMakeLists.txt                  |  1 +
 .../mir/elimination/shape_eliminate_pass.cc   | 85 +++++++++++++++++++
 lite/core/optimizer.h                         |  1 +
 4 files changed, 88 insertions(+)
 create mode 100644 lite/core/mir/elimination/shape_eliminate_pass.cc

diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 41eca021a9..c4d7eceb29 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -47,3 +47,4 @@ USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
 USE_MIR_PASS(weight_quantization_preprocess_pass);
 USE_MIR_PASS(quantized_op_attributes_inference_pass);
+USE_MIR_PASS(shape_eliminate_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index 82b19b030c..e55c46b488 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -23,6 +23,7 @@ lite_cc_library(mir_passes
        fusion/sequence_pool_concat_fuse_pass.cc
        elimination/identity_scale_eliminate_pass.cc
        elimination/elementwise_mul_constant_eliminate_pass.cc
+       elimination/shape_eliminate_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_cast_pass.cc
diff --git a/lite/core/mir/elimination/shape_eliminate_pass.cc b/lite/core/mir/elimination/shape_eliminate_pass.cc
new file mode 100644
index 0000000000..e3d6a74838
--- /dev/null
+++ b/lite/core/mir/elimination/shape_eliminate_pass.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/mir/pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace { + +class Eliminator : public FuseBase { + public: + void BuildPattern() override { + // the previous op's output need updat + auto* pre_op = OpNode("preop")->assert_is_not_op_type("conditional_block"); + auto* input = VarNode("input") + ->assert_is_op_input("shape", "Input") + ->AsIntermediate(); + auto* shape_op = + OpNode("shape", "shape")->assert_is_op("shape")->AsIntermediate(); + auto* out = VarNode("out")->assert_is_op_output("shape", "Out"); + LOG(INFO) << "shapeshape"; + + *pre_op >> *input >> *shape_op >> *out; + } + + private: + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto& pre_op = matched.at("preop")->AsStmt(); + auto op_info = *pre_op.op_info(); + auto* shape_node = matched.at("shape"); + auto* scope = shape_node->stmt()->op()->scope(); + auto* in = matched.at("input"); + auto shape_in_tensor = scope->FindVar(in->arg()->name)->Get(); + auto* out = matched.at("out"); + auto* shape_out_tensor = + scope->FindVar(out->arg()->name)->GetMutable(); + auto dim_data = shape_in_tensor.dims(); + std::vector shape_vec; + shape_vec.push_back(static_cast(dim_data.size())); + shape_out_tensor->Resize(shape_vec); + auto* out_data = shape_out_tensor->mutable_data(); + for (int i = 0; i < dim_data.size(); i++) { + out_data[i] = dim_data[i]; + } + op_info.UpdateAllOutputs(matched.at("input")->AsArg().name, + matched.at("out")->AsArg().name); + pre_op.ResetOp(op_info, graph->valid_places()); + + GraphSafeRemoveNodes(graph, {matched.at("shape")}); + + IR_NODE_LINK_TO(matched.at("preop"), matched.at("out")); + } +}; + +} // namespace + +class ShapeEliminatePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + Eliminator eliminator; + eliminator(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(shape_eliminate_pass, paddle::lite::mir::ShapeEliminatePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index ca22c86907..50db757cba 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -71,6 +71,7 @@ class Optimizer { "identity_scale_eliminate_pass", // "elementwise_mul_constant_eliminate_pass", // "lite_sequence_pool_concat_fuse_pass", // + "shape_eliminate_pass", // #if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \ (defined LITE_WITH_ARM) "lite_elementwise_add_activation_fuse_pass", // -- GitLab