提交 58700062 编写于 作者: Z zhupengyang 提交者: GitHub

add transpose-softmax-transpose fuse pass (#1863)

* add transpose-softmax-transpose fuse pass

test=develop

* enable supported lite-npu ops

test=develop
上级 0fc8b4d4
......@@ -30,6 +30,7 @@ USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
......
......@@ -12,6 +12,7 @@ lite_cc_library(mir_passes
SRCS
fusion/fc_fuse_pass.cc
fusion/shuffle_channel_fuse_pass.cc
fusion/transpose_softmax_transpose_fuse_pass.cc
fusion/conv_elementwise_fuse_pass.cc
fusion/conv_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
......
......@@ -19,6 +19,9 @@ lite_cc_library(fuse_elementwise_add_activation
lite_cc_library(fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_transpose_softmax_transpose
SRCS transpose_softmax_transpose_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -28,6 +31,7 @@ set(mir_fusers
fuse_conv_bn
fuse_quant_dequant
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/transpose_softmax_transpose_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/transpose_softmax_transpose_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void TransposeSoftmaxTransposeFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::TransposeSoftmaxTransposeFuser fuser("transpose", "softmax");
fuser(graph.get());
fusion::TransposeSoftmaxTransposeFuser fuser2("transpose2", "softmax");
fuser2(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass,
paddle::lite::mir::TransposeSoftmaxTransposeFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class TransposeSoftmaxTransposeFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/transpose_softmax_transpose_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void TransposeSoftmaxTransposeFuser::BuildPattern() {
// create nodes.
auto* x1 = VarNode("x1")->assert_is_op_input(transpose_type_, "X");
auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out");
auto* y2 = VarNode("y2")->assert_is_op_output(softmax_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(transpose_type_, "Out");
auto* xshape1 =
VarNode("xshape1")->assert_is_op_output(transpose_type_, "XShape");
auto* xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
auto* transpose1 =
OpNode("transpose1", transpose_type_)->assert_is_op(transpose_type_);
auto* softmax = OpNode("softmax", softmax_type_)
->assert_op_attr_satisfied<int>(
"axis", [](int attr) { return attr == -1; });
auto* transpose2 =
OpNode("transpose2", transpose_type_)->assert_is_op(transpose_type_);
// create topology.
*x1 >> *transpose1 >> *y1 >> *softmax >> *y2 >> *transpose2 >> *out;
*transpose1 >> *xshape1;
*transpose2 >> *xshape2;
// nodes to remove
y1->AsIntermediate();
y2->AsIntermediate();
xshape1->AsIntermediate();
xshape2->AsIntermediate();
transpose1->AsIntermediate();
softmax->AsIntermediate();
transpose2->AsIntermediate();
}
void TransposeSoftmaxTransposeFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto softmax_op = LiteOpRegistry::Global().Create(softmax_type_);
auto softmax_old = matched.at("softmax")->stmt()->op();
auto* scope = softmax_old->scope();
auto& valid_places = softmax_old->valid_places();
softmax_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(softmax_op, valid_places);
IR_NODE_LINK_TO(matched.at("x1"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc TransposeSoftmaxTransposeFuser::GenOpDesc(
const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("softmax");
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("axis",
matched.at("transpose1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("axis")
.back());
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class TransposeSoftmaxTransposeFuser : public FuseBase {
public:
explicit TransposeSoftmaxTransposeFuser(const std::string& transpose_type,
const std::string& softmax_type)
: transpose_type_(transpose_type), softmax_type_(softmax_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string transpose_type_;
std::string softmax_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -65,6 +65,7 @@ class Optimizer {
"lite_conv_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
......
......@@ -30,3 +30,8 @@ USE_NPU_BRIDGE(split);
USE_NPU_BRIDGE(transpose);
USE_NPU_BRIDGE(transpose2);
USE_NPU_BRIDGE(shuffle_channel);
USE_NPU_BRIDGE(batch_norm);
USE_NPU_BRIDGE(bilinear_interp);
USE_NPU_BRIDGE(conv2d_transpose);
USE_NPU_BRIDGE(reshape);
USE_NPU_BRIDGE(reshape2);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册