From 40006e19f288cdbf8fd95b021508ea8b37d20561 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 8 Apr 2020 09:37:53 +0000 Subject: [PATCH] add assign_value_eliminate_pass to remove assign_value op. test=develop --- lite/api/paddle_use_passes.h | 1 + lite/core/mir/CMakeLists.txt | 1 + .../assign_value_eliminate_pass.cc | 81 +++++++++++++++++++ lite/core/optimizer.h | 1 + 4 files changed, 84 insertions(+) create mode 100644 lite/core/mir/elimination/assign_value_eliminate_pass.cc diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index 41eca021a9..25fd9b4eed 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -47,3 +47,4 @@ USE_MIR_PASS(npu_subgraph_pass); USE_MIR_PASS(xpu_subgraph_pass); USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass); +USE_MIR_PASS(assign_value_eliminate_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 82b19b030c..0e021fa444 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -23,6 +23,7 @@ lite_cc_library(mir_passes fusion/sequence_pool_concat_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc + elimination/assign_value_eliminate_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_cast_pass.cc diff --git a/lite/core/mir/elimination/assign_value_eliminate_pass.cc b/lite/core/mir/elimination/assign_value_eliminate_pass.cc new file mode 100644 index 0000000000..0d2c220787 --- /dev/null +++ b/lite/core/mir/elimination/assign_value_eliminate_pass.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/pass.h" +#include "lite/core/mir/pass_registry.h" +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { + +namespace { + +template +void TensorFromVector(const std::vector& src, lite::Tensor* dst) { + auto* src_ptr = static_cast(src.data()); + auto* dst_ptr = static_cast(dst->mutable_data()); + auto size = src.size() * sizeof(T); + std::memcpy(dst_ptr, src_ptr, size); +} + +class Eliminator : public FuseBase { + public: + void BuildPattern() override { + auto* assign_value_op = OpNode("assign_value", "assign_value"); + auto* out = VarNode("out")->assert_is_op_output("assign_value", "Out"); + *assign_value_op >> *out; + } + + private: + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { + auto* assign_node = matched.at("assign_value"); + auto* scope = assign_node->stmt()->op()->scope(); + auto* op_info = assign_node->stmt()->op()->op_info(); + auto shape = op_info->GetAttr>("shape"); + auto dtype = op_info->GetAttr("dtype"); + auto fp32_values = op_info->GetAttr>("fp32_values"); + auto int32_values = op_info->GetAttr>("int32_values"); + auto* out = matched.at("out"); + auto* out_tensor = scope->FindVar(out->arg()->name) + ->GetMutable(); + if (dtype == static_cast(lite::core::FluidType::INT32)) { + TensorFromVector(int32_values, out_tensor); + } else if (dtype == static_cast(lite::core::FluidType::FP32)) { + TensorFromVector(fp32_values, out_tensor); + } else { + LOG(FATAL) << "Unsupported dtype for assign_value_op:" << dtype; + } + GraphSafeRemoveNodes(graph, {matched.at("assign_value")}); + + } +}; + +} // namespace + +class AssignValueEliminatePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override { + Eliminator eliminator; + eliminator(graph.get()); + } +}; + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(assign_value_eliminate_pass, + paddle::lite::mir::AssignValueEliminatePass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index ca22c86907..e2463ec9b6 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -71,6 +71,7 @@ class Optimizer { "identity_scale_eliminate_pass", // "elementwise_mul_constant_eliminate_pass", // "lite_sequence_pool_concat_fuse_pass", // + "assign_value_eliminate_pass", #if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \ (defined LITE_WITH_ARM) "lite_elementwise_add_activation_fuse_pass", // -- GitLab