From ec8353e82b83e47ba3ef761285c078f16f3692ba Mon Sep 17 00:00:00 2001
From: HappyAngel
Date: Tue, 17 Dec 2019 17:54:21 +0800
Subject: [PATCH] [lite] add some fusion (#2604)

* add cv image process
* fix arm linux build error
* add LITE_WITH_CV define to make cv, test=develop
* fix cv format, and add description in utils/cv
* delete some meaningless comments, test=develop
* set LITE_WITH_CV=OFF in build.sh, test=develop
* delete cv_enum.h in utils/cv, push its contents into paddle_image_preprocess.h, test=develop
* redefine paddle_image_preprocess.h according to reviews, test=develop
* add detailed note for flipParam, test=develop
* fix format in paddle_image_preprocess.h, test=develop
* fix error when building x86, test=develop; LITE_WITH_X86 does not include LITE_WITH_CV
* fix cmake error in lite/CMakeLists.txt, missing mkdir cxx, test=develop
* apply review changes, test=develop
* change grb to rgb, test=develop
* add elementwise mul constant elimination and deconv+relu, deconv+batchnorm fusion, test=develop
* fix format, test=develop
---
 lite/api/paddle_use_passes.h                   |  1 +
 lite/core/mir/CMakeLists.txt                   |  1 +
 ...elementwise_mul_constant_eliminate_pass.cc  | 88 +++++++++++++++++++
 .../mir/fusion/conv_activation_fuse_pass.cc    |  2 +-
 lite/core/mir/fusion/conv_bn_fuse_pass.cc      |  3 +-
 lite/core/mir/variable_place_inference_pass.h  |  4 +
 lite/core/optimizer.h                          |  4 +-
 7 files changed, 100 insertions(+), 3 deletions(-)
 create mode 100644 lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc

diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 405f4bbe21..ba3e056627 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -39,5 +39,6 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
 USE_MIR_PASS(memory_optimize_pass);
+USE_MIR_PASS(elementwise_mul_constant_eliminate_pass);
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index e0898693b9..8d5b3e6876 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -20,6 +20,7 @@ lite_cc_library(mir_passes
        fusion/elementwise_add_activation_fuse_pass.cc
        fusion/quant_dequant_fuse_pass.cc
        elimination/identity_scale_eliminate_pass.cc
+       elimination/elementwise_mul_constant_eliminate_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_cast_pass.cc
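Note on the two one-line additions above: REGISTER_MIR_PASS in the new .cc file (shown next) registers the pass, while the USE_MIR_PASS entry in paddle_use_passes.h adds a symbol reference so that registration is not dropped when the mir_passes library is linked statically. The sketch below is a simplified, self-contained illustration of that register/use pattern; the TOY_* macros and the toy registry are stand-ins and are not Paddle-Lite's actual macro definitions.

// Illustrative sketch only: why a pass needs both a "register" macro in its own
// translation unit and a "use" macro in a header included by the application.
#include <functional>
#include <iostream>
#include <map>
#include <string>

using PassRegistry = std::map<std::string, std::function<void()>>;
PassRegistry& Registry() {
  static PassRegistry r;
  return r;
}

// "Register" side: defines a symbol in the pass's own translation unit and
// fills the registry when that symbol is first referenced.
#define TOY_REGISTER_PASS(name__)                                        \
  bool toy_pass_registry_##name__() {                                    \
    Registry()[#name__] = [] { std::cout << #name__ << " ran\n"; };      \
    return true;                                                         \
  }

// "Use" side: referencing that symbol from the application forces the linker
// to keep the pass's object file (and its registration) in a static build.
#define TOY_USE_PASS(name__)                                             \
  extern bool toy_pass_registry_##name__();                              \
  static bool toy_pass_used_##name__ = toy_pass_registry_##name__();

TOY_REGISTER_PASS(elementwise_mul_constant_eliminate_pass)
TOY_USE_PASS(elementwise_mul_constant_eliminate_pass)

int main() {
  // The optimizer would look the pass up by name, roughly like this.
  Registry()["elementwise_mul_constant_eliminate_pass"]();
  return 0;
}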
diff --git a/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
new file mode 100644
index 0000000000..863c01ef06
--- /dev/null
+++ b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+class ElementwiseMulConstantEliminator : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* pre_op = OpNode("preop");    // the previous op, produces X
+    auto* post_op = OpNode("postop");  // the post op, its input needs update
+    // TODO(Superjomn) check has only one output
+    auto* x =
+        VarNode("x")->assert_is_op_input("elementwise_mul", "X")->AsOutput();
+    auto* y = VarNode("Y")->assert_is_op_input("elementwise_mul", "Y");
+
+    // create op nodes
+    auto* mul = OpNode("mul", "elementwise_mul")
+                    ->assert_is_op("elementwise_mul")
+                    ->AsIntermediate();
+    auto* fill_constant = OpNode("fill_constant", "fill_constant")
+                              ->assert_is_op("fill_constant")
+                              ->assert_op_attr<float>("value", 1.)
+                              ->AsIntermediate();
+    // create output node
+    auto* mul_out =
+        VarNode("output")->assert_is_op_output("elementwise_mul", "Out");
+    // create topology
+    std::vector<PMNode*> mul_inputs{x, y};
+    *pre_op >> *x;
+    *fill_constant >> *y;
+    mul_inputs >> *mul >> *mul_out;
+    *mul_out >> *post_op;
+
+    // The mul subgraph is eliminated and post_op's input is rewired to x.
+    mul_out->AsIntermediate();  // mul_out is the mul's output, needs update
+    y->AsIntermediate();        // needs update
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& post_op = matched.at("postop")->AsStmt();
+    auto op_info = *post_op.op_info();
+
+    op_info.UpdateAllInputs(matched.at("output")->AsArg().name,
+                            matched.at("x")->AsArg().name);
+    post_op.ResetOp(op_info, graph->valid_places());
+
+    IR_NODE_LINK_TO(matched.at("x"), matched.at("postop"));
+  }
+};
+
+}  // namespace
+
+class ElementwiseMulConstantEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    ElementwiseMulConstantEliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(elementwise_mul_constant_eliminate_pass,
+                  paddle::lite::mir::ElementwiseMulConstantEliminatePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index 0d11b47db6..c5ce74e30e 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -30,7 +30,7 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       break;
     }
   }
-  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+  for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     for (auto act_type : act_types) {
       for (auto has_bias : {true, false}) {
         fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index 5ab5f8c0a4..4725ca7485 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -26,7 +26,8 @@ namespace mir {
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // initialze fuser params
   std::vector<bool> conv_has_bias_cases{true, false};
-  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
 
   // start fuse using params
   for (auto conv_has_bias : conv_has_bias_cases) {
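The two fuse-pass changes above only widen the list of convolution op types, so conv2d_transpose (deconv) now goes through the same ConvActivationFuser and ConvBNFuser paths as conv2d and depthwise_conv2d. For readers unfamiliar with batch-norm folding, the snippet below is a minimal numeric sketch of the standard algebra this fusion relies on: scale each output channel's weights by gamma/sqrt(var + eps) and fold the mean and beta into the bias. It is not the ConvBNFuser implementation, and the [out_channels][elements] weight layout is an assumption for illustration; conv2d_transpose stores its filters with the output-channel axis in a different position than conv2d, so the real fuser cannot assume the simple layout used here.

// Minimal sketch of batch-norm folding into convolution weights and bias.
#include <cmath>
#include <cstdio>
#include <vector>

struct BNParams {
  std::vector<float> scale, bias, mean, variance;  // one entry per channel
  float epsilon;
};

// Fold BN into per-output-channel weights/bias. `weights` is assumed to be
// laid out as [out_channels][elements_per_channel] for simplicity.
void FoldBatchNorm(std::vector<std::vector<float>>* weights,
                   std::vector<float>* conv_bias, const BNParams& bn) {
  for (size_t oc = 0; oc < weights->size(); ++oc) {
    float alpha = bn.scale[oc] / std::sqrt(bn.variance[oc] + bn.epsilon);
    for (float& w : (*weights)[oc]) w *= alpha;                      // w' = w * alpha
    (*conv_bias)[oc] =
        ((*conv_bias)[oc] - bn.mean[oc]) * alpha + bn.bias[oc];      // b' = (b - mean) * alpha + beta
  }
}

int main() {
  // One output channel, two weights, identity-like BN except for the mean.
  std::vector<std::vector<float>> w{{1.0f, 2.0f}};
  std::vector<float> b{0.5f};
  BNParams bn{{1.0f}, {0.0f}, {0.25f}, {1.0f}, 1e-5f};
  FoldBatchNorm(&w, &b, bn);
  std::printf("w = [%f, %f], b = %f\n", w[0][0], w[0][1], b[0]);
  return 0;
}

With these sample numbers the program prints roughly w = [1.0, 2.0] and b = 0.25, i.e. only the mean shift survives a near-identity BN.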
diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h
index b3609230b9..875bf23082 100644
--- a/lite/core/mir/variable_place_inference_pass.h
+++ b/lite/core/mir/variable_place_inference_pass.h
@@ -48,6 +48,10 @@
   void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
     for (auto& node : graph->mutable_nodes()) {
       if (node.IsArg()) {
+        if (node.inlinks.size() == 0 && node.outlinks.size() == 0) {
+          // empty node
+          continue;
+        }
         CHECK(node.AsArg().type) << "node " << node.AsArg().name
                                  << " type not determined, " << &node;
       }
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 216e263581..6b2d7f5b18 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -67,7 +67,9 @@
           "lite_transpose_softmax_transpose_fuse_pass",  //
           "lite_interpolate_fuse_pass",                  //
           "identity_scale_eliminate_pass",               //
-#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
+          "elementwise_mul_constant_eliminate_pass",     //
+#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
+    (defined LITE_WITH_ARM)
           "lite_elementwise_add_activation_fuse_pass",   //
 #endif
           "static_kernel_pick_pass",  // pick original kernel from graph
-- 
GitLab
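Together with the optimizer.h change, the new pass is added to the default pass list right after identity_scale_eliminate_pass and removes multiplications by a fill_constant whose value is 1.0. The toy program below illustrates the net effect on a small op list: the fill_constant and elementwise_mul nodes disappear and the consumer is rewired to the producer's output, analogous to what InsertNewNode does via UpdateAllInputs. The ToyOp structure and helper here are invented for illustration and are not Paddle-Lite types.

// Framework-free illustration of eliminating "x * fill_constant(1.0)".
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

struct ToyOp {
  std::string type;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

// Replace every use of `from` with `to` in the remaining ops, mirroring the
// effect of rewriting the post op's inputs in the real pass.
void RewireInputs(std::vector<ToyOp>* ops, const std::string& from,
                  const std::string& to) {
  for (auto& op : *ops)
    std::replace(op.inputs.begin(), op.inputs.end(), from, to);
}

int main() {
  // conv2d -> x, fill_constant(1.0) -> y, x * y -> mul_out, relu(mul_out)
  std::vector<ToyOp> ops{{"conv2d", {"input", "w"}, {"x"}},
                         {"fill_constant", {}, {"y"}},
                         {"elementwise_mul", {"x", "y"}, {"mul_out"}},
                         {"relu", {"mul_out"}, {"out"}}};
  // Drop the fill_constant and elementwise_mul nodes...
  ops.erase(std::remove_if(ops.begin(), ops.end(),
                           [](const ToyOp& op) {
                             return op.type == "fill_constant" ||
                                    op.type == "elementwise_mul";
                           }),
            ops.end());
  // ...then let the consumer read the producer's output directly.
  RewireInputs(&ops, "mul_out", "x");
  for (const auto& op : ops)
    std::cout << op.type << " -> " << op.outputs[0] << "\n";
  return 0;
}

Running it prints "conv2d -> x" and "relu -> out": relu now consumes x directly, which is the simplification the pass performs on the SSA graph.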