From 5e8b15f5a8110b52c2fdd9e5b09794e831c3600f Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Wed, 28 Aug 2019 13:48:53 +0800 Subject: [PATCH] add transpose-softmax-transpose fuse pass (#1863) * add transpose-softmax-transpose fuse pass test=develop * enable supported lite-npu ops test=develop --- lite/api/paddle_use_passes.h | 1 + lite/core/mir/CMakeLists.txt | 1 + lite/core/mir/fusion/CMakeLists.txt | 4 + .../transpose_softmax_transpose_fuse_pass.cc | 39 ++++++++ .../transpose_softmax_transpose_fuse_pass.h | 32 +++++++ .../transpose_softmax_transpose_fuser.cc | 95 +++++++++++++++++++ .../transpose_softmax_transpose_fuser.h | 44 +++++++++ lite/core/optimizer.h | 11 ++- lite/npu/bridge/paddle_use_npu_bridges.h | 5 + 9 files changed, 227 insertions(+), 5 deletions(-) create mode 100644 lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc create mode 100644 lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h create mode 100644 lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc create mode 100644 lite/core/mir/fusion/transpose_softmax_transpose_fuser.h diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index e7ec702df5..bc2e59f387 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -30,6 +30,7 @@ USE_MIR_PASS(graph_visualze); USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_shuffle_channel_fuse_pass); +USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass); USE_MIR_PASS(identity_scale_eliminate_pass); USE_MIR_PASS(lite_conv_elementwise_fuse_pass); USE_MIR_PASS(lite_conv_activation_fuse_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index a757018e3d..d96a67f52e 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -12,6 +12,7 @@ lite_cc_library(mir_passes SRCS fusion/fc_fuse_pass.cc fusion/shuffle_channel_fuse_pass.cc + fusion/transpose_softmax_transpose_fuse_pass.cc 
fusion/conv_elementwise_fuse_pass.cc fusion/conv_activation_fuse_pass.cc fusion/conv_bn_fuse_pass.cc diff --git a/lite/core/mir/fusion/CMakeLists.txt b/lite/core/mir/fusion/CMakeLists.txt index 1c4a5eac94..92421a2cf8 100644 --- a/lite/core/mir/fusion/CMakeLists.txt +++ b/lite/core/mir/fusion/CMakeLists.txt @@ -19,6 +19,9 @@ lite_cc_library(fuse_elementwise_add_activation lite_cc_library(fuse_quant_dequant SRCS quant_dequant_op_fuser.cc DEPS pattern_matcher_high_api) +lite_cc_library(fuse_transpose_softmax_transpose + SRCS transpose_softmax_transpose_fuser.cc + DEPS pattern_matcher_high_api) set(mir_fusers fuse_fc @@ -28,6 +31,7 @@ set(mir_fusers fuse_conv_bn fuse_quant_dequant fuse_elementwise_add_activation + fuse_transpose_softmax_transpose CACHE INTERNAL "fusers") if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc new file mode 100644 index 0000000000..93bfef0ae5 --- /dev/null +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.cc @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h" +#include <memory> +#include "lite/core/mir/fusion/transpose_softmax_transpose_fuser.h" +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void TransposeSoftmaxTransposeFusePass::Apply( + const std::unique_ptr<SSAGraph>& graph) { + fusion::TransposeSoftmaxTransposeFuser fuser("transpose", "softmax"); + fuser(graph.get()); + + fusion::TransposeSoftmaxTransposeFuser fuser2("transpose2", "softmax"); + fuser2(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass, + paddle::lite::mir::TransposeSoftmaxTransposeFusePass); diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h new file mode 100644 index 0000000000..4ae6ce83c4 --- /dev/null +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include "lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class TransposeSoftmaxTransposeFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr<SSAGraph>& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc new file mode 100644 index 0000000000..5e55999442 --- /dev/null +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/mir/fusion/transpose_softmax_transpose_fuser.h" +#include <memory> +#include <vector> + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void TransposeSoftmaxTransposeFuser::BuildPattern() { + // create nodes. 
+ auto* x1 = VarNode("x1")->assert_is_op_input(transpose_type_, "X"); + auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out"); + auto* y2 = VarNode("y2")->assert_is_op_output(softmax_type_, "Out"); + auto* out = VarNode("out")->assert_is_op_output(transpose_type_, "Out"); + + auto* xshape1 = + VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape"); + auto* xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + + auto* transpose1 = + OpNode("transpose1", transpose_type_)->assert_is_op(transpose_type_); + + auto* softmax = OpNode("softmax", softmax_type_) + ->assert_op_attr_satisfied<int>( + "axis", [](int attr) { return attr == -1; }); + + auto* transpose2 = + OpNode("transpose2", transpose_type_)->assert_is_op(transpose_type_); + + // create topology. + *x1 >> *transpose1 >> *y1 >> *softmax >> *y2 >> *transpose2 >> *out; + *transpose1 >> *xshape1; + *transpose2 >> *xshape2; + + // nodes to remove + y1->AsIntermediate(); + y2->AsIntermediate(); + xshape1->AsIntermediate(); + xshape2->AsIntermediate(); + transpose1->AsIntermediate(); + softmax->AsIntermediate(); + transpose2->AsIntermediate(); +} + +void TransposeSoftmaxTransposeFuser::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto softmax_op = LiteOpRegistry::Global().Create(softmax_type_); + auto softmax_old = matched.at("softmax")->stmt()->op(); + auto* scope = softmax_old->scope(); + auto& valid_places = softmax_old->valid_places(); + softmax_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(softmax_op, valid_places); + + IR_NODE_LINK_TO(matched.at("x1"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("out")); +} + +cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc( + const key2nodes_t& matched) { + cpp::OpDesc op_desc; + op_desc.SetType("softmax"); + op_desc.SetInput("X", {matched.at("x1")->arg()->name}); + op_desc.SetOutput("Out", 
{matched.at("out")->arg()->name}); + op_desc.SetAttr("axis", + matched.at("transpose1") + ->stmt() + ->op_info() + ->GetAttr<std::vector<int>>("axis") + .back()); + + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/mir/fusion/transpose_softmax_transpose_fuser.h b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.h new file mode 100644 index 0000000000..fbccfd2c6a --- /dev/null +++ b/lite/core/mir/fusion/transpose_softmax_transpose_fuser.h @@ -0,0 +1,44 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include <memory> +#include <string> +#include "lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class TransposeSoftmaxTransposeFuser : public FuseBase { + public: + explicit TransposeSoftmaxTransposeFuser(const std::string& transpose_type, + const std::string& softmax_type) + : transpose_type_(transpose_type), softmax_type_(softmax_type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string transpose_type_; + std::string softmax_type_; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 0ee0562706..fcc3470525 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -61,11 +61,12 @@ class Optimizer { // kernels, and the OpenCL devices will be discarded. // TODO(Superjomn) Refine the fusion related design to select fusion // kernels for devices automatically. 
- "lite_conv_elementwise_fuse_pass", // - "lite_conv_activation_fuse_pass", // - "lite_fc_fuse_pass", // - "lite_shuffle_channel_fuse_pass", // - "identity_scale_eliminate_pass", // + "lite_conv_elementwise_fuse_pass", // + "lite_conv_activation_fuse_pass", // + "lite_fc_fuse_pass", // + "lite_shuffle_channel_fuse_pass", // + "lite_transpose_softmax_transpose_fuse_pass", // + "identity_scale_eliminate_pass", // #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK "lite_elementwise_add_activation_fuse_pass", // #endif diff --git a/lite/npu/bridge/paddle_use_npu_bridges.h b/lite/npu/bridge/paddle_use_npu_bridges.h index ba55f52212..9b7f717a41 100644 --- a/lite/npu/bridge/paddle_use_npu_bridges.h +++ b/lite/npu/bridge/paddle_use_npu_bridges.h @@ -30,3 +30,8 @@ USE_NPU_BRIDGE(split); USE_NPU_BRIDGE(transpose); USE_NPU_BRIDGE(transpose2); USE_NPU_BRIDGE(shuffle_channel); +USE_NPU_BRIDGE(batch_norm); +USE_NPU_BRIDGE(bilinear_interp); +USE_NPU_BRIDGE(conv2d_transpose); +USE_NPU_BRIDGE(reshape); +USE_NPU_BRIDGE(reshape2); -- GitLab