diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index 405f4bbe21cbab70e7869d8a6751d639ec97900d..ba3e0566277f74fe6dd09b52d77b07aba14a458d 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -39,5 +39,6 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
 USE_MIR_PASS(memory_optimize_pass);
+USE_MIR_PASS(elementwise_mul_constant_eliminate_pass);
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index e0898693b9a0954113bd8fb32af1bfd8ee4f5821..8d5b3e6876ace41e7e63835e9c847bd381c16635 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -20,6 +20,7 @@ lite_cc_library(mir_passes
        fusion/elementwise_add_activation_fuse_pass.cc
        fusion/quant_dequant_fuse_pass.cc
        elimination/identity_scale_eliminate_pass.cc
+       elimination/elementwise_mul_constant_eliminate_pass.cc
        static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_cast_pass.cc
diff --git a/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..863c01ef0646794b5cbe54d7a81a8f26dbf164ae
--- /dev/null
+++ b/lite/core/mir/elimination/elementwise_mul_constant_eliminate_pass.cc
@@ -0,0 +1,88 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+class ElementwiseMulConstantEliminator : public FuseBase {
+ public:
+  void BuildPattern() override {
+    auto* pre_op = OpNode("preop");    // the previous op's output need update
+    auto* post_op = OpNode("postop");  // the post op's output need update
+    // TODO(Superjomn) check has only one output
+    auto* x =
+        VarNode("x")->assert_is_op_input("elementwise_mul", "X")->AsOutput();
+    auto* y = VarNode("Y")->assert_is_op_input("elementwise_mul", "Y");
+
+    // create op nodes
+    auto* mul = OpNode("mul", "elementwise_mul")
+                    ->assert_is_op("elementwise_mul")
+                    ->AsIntermediate();
+
+    auto* fill_constant = OpNode("fill_constant", "fill_constant")
+                              ->assert_is_op("fill_constant")
+                              ->assert_op_attr<float>("value", 1.)
+                              ->AsIntermediate();
+    // create output node
+    auto* mul_out =
+        VarNode("output")->assert_is_op_output("elementwise_mul", "Out");
+    // create topology.
+    std::vector<PMNode*> add_inputs{x, y};
+    *pre_op >> *x;
+    *fill_constant >> *y;
+    add_inputs >> *mul >> *mul_out;
+    *mul_out >> *post_op;
+
+    // The pre_op will be eliminated, and a new output-updated op will insert.
+    mul_out->AsIntermediate();  // mul_out is pre_op's output, need to update
+    y->AsIntermediate();        // need to update
+  }
+
+ private:
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    auto& post_op = matched.at("postop")->AsStmt();
+    auto op_info = *post_op.op_info();
+
+    op_info.UpdateAllInputs(matched.at("output")->AsArg().name,
+                            matched.at("x")->AsArg().name);
+    post_op.ResetOp(op_info, graph->valid_places());
+
+    IR_NODE_LINK_TO(matched.at("x"), matched.at("postop"));
+  }
+};
+
+}  // namespace
+
+class ElementwiseMulConstantEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    ElementwiseMulConstantEliminator eliminator;
+    eliminator(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(elementwise_mul_constant_eliminate_pass,
+                  paddle::lite::mir::ElementwiseMulConstantEliminatePass)
+    .BindTargets({TARGET(kAny)});
diff --git a/lite/core/mir/fusion/conv_activation_fuse_pass.cc b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
index 0d11b47db6a7f767f8cd032877d8647b0872b8d4..c5ce74e30e34b5878a534010b6cf8b86f91a1118 100644
--- a/lite/core/mir/fusion/conv_activation_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_activation_fuse_pass.cc
@@ -30,7 +30,7 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       break;
     }
   }
-  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+  for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     for (auto act_type : act_types) {
       for (auto has_bias : {true, false}) {
         fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
diff --git a/lite/core/mir/fusion/conv_bn_fuse_pass.cc b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
index 5ab5f8c0a4797e51cce656de43883a68d4931e9b..4725ca74855d72674b922478acd1f6f3a3b59798 100644
--- a/lite/core/mir/fusion/conv_bn_fuse_pass.cc
+++ b/lite/core/mir/fusion/conv_bn_fuse_pass.cc
@@ -26,7 +26,8 @@ namespace mir {
 
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // initialze fuser params
   std::vector<bool> conv_has_bias_cases{true, false};
-  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
   // start fuse using params
   for (auto conv_has_bias : conv_has_bias_cases) {
diff --git a/lite/core/mir/variable_place_inference_pass.h b/lite/core/mir/variable_place_inference_pass.h
index b3609230b993b73edd386cc3aea55081d25538a2..875bf23082a24cb6fcae878b46cc9dcdbb2b76f7 100644
--- a/lite/core/mir/variable_place_inference_pass.h
+++ b/lite/core/mir/variable_place_inference_pass.h
@@ -48,6 +48,10 @@ class VariablePlaceInferencePass : public DebugPass {
   void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
     for (auto& node : graph->mutable_nodes()) {
       if (node.IsArg()) {
+        if (node.inlinks.size() == 0 && node.outlinks.size() == 0) {
+          // empty node
+          continue;
+        }
         CHECK(node.AsArg().type) << "node " << node.AsArg().name
                                  << " type not determined, " << &node;
       }
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 216e2635812ae1f9a00e48bb1f80a2489722c2cc..6b2d7f5b18df8230af856d68a8301ee3cb929900 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -67,7 +67,9 @@
          "lite_transpose_softmax_transpose_fuse_pass",  //
          "lite_interpolate_fuse_pass",                  //
          "identity_scale_eliminate_pass",               //
-#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
+         "elementwise_mul_constant_eliminate_pass",     //
+#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
+    (defined LITE_WITH_ARM)
          "lite_elementwise_add_activation_fuse_pass",   //
 #endif
          "static_kernel_pick_pass",  // pick original kernel from graph