未验证 提交 5f6d8ce6 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

add reshape pass. test=develop (#4073)

上级 d9c72521
......@@ -48,6 +48,7 @@ USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
USE_MIR_PASS(type_layout_cast_preprocess_pass);
USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(lite_reshape_fuse_pass);
USE_MIR_PASS(multi_stream_analysis_pass);
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
......
......@@ -23,6 +23,7 @@ lite_cc_library(mir_passes
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
fusion/scale_activation_fuse_pass.cc
fusion/reshape_fuse_pass.cc
fusion/__xpu__resnet_fuse_pass.cc
fusion/__xpu__resnet_cbam_fuse_pass.cc
fusion/__xpu__multi_encoder_fuse_pass.cc
......
......@@ -37,6 +37,9 @@ lite_cc_library(fuse_sequence_pool_concat
lite_cc_library(fuse_scale_activation
SRCS scale_activation_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_reshape
SRCS reshape_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_match_matrix_activation
SRCS match_matrix_activation_fuser.cc
DEPS pattern_matcher_high_api)
......@@ -61,6 +64,7 @@ set(mir_fusers
fuse_interpolate
fuse_sequence_pool_concat
fuse_scale_activation
fuse_reshape
fuse_match_matrix_activation
fuse_scales
fuse_sequence_reverse_embedding
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/reshape_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/reshape_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ReshapeFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> reshape_type_cases{"reshape", "reshape2"};
for (auto type_ : reshape_type_cases) {
fusion::ReshapeFuser reshape_fuser(type_);
reshape_fuser(graph.get());
}
for (auto type_ : reshape_type_cases) {
fusion::Reshape2OutFuser reshape2Out_fuser(type_);
reshape2Out_fuser(graph.get());
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_reshape_fuse_pass, paddle::lite::mir::ReshapeFusePass)
.BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ReshapeFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/reshape_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ReshapeFuser::BuildPattern() {
auto* x = VarNode("x");
auto* reshape = OpNode("reshape", type_);
auto* reshape_out = VarNode("Out");
auto* out1 = OpNode("out1");
*x >> *reshape >> *reshape_out >> *out1;
}
void ReshapeFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = const_cast<OpInfo*>(matched.at("reshape")->stmt()->op_info());
op_desc->SetAttr<bool>("inplace", true);
}
void Reshape2OutFuser::BuildPattern() {
auto* x = VarNode("x");
auto* reshape =
OpNode("reshape", type_)->assert_op_attr<bool>("inplace", true);
auto* reshape_out = VarNode("Out");
auto* out1 = OpNode("out1");
auto* out2 = OpNode("out2");
*x >> *reshape >> *reshape_out >> *out1;
*reshape_out >> *out2;
}
void Reshape2OutFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = const_cast<OpInfo*>(matched.at("reshape")->stmt()->op_info());
op_desc->SetAttr<bool>("inplace", false);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ReshapeFuser : public FuseBase {
public:
explicit ReshapeFuser(const std::string& type) : type_(type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
std::string type_;
};
class Reshape2OutFuser : public FuseBase {
public:
explicit Reshape2OutFuser(const std::string& type) : type_(type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
std::string type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -164,6 +164,7 @@ class Optimizer {
"runtime_context_assign_pass",
"argument_type_display_pass",
"lite_reshape_fuse_pass",
"memory_optimize_pass"}};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册