Unverified commit ec8353e8 authored by HappyAngel, committed by GitHub

[lite] add some fusions (#2604)

* add cv image processing

* fix arm linux build error

* add LITE_WITH_CV define to make cv, test=develop

* fix cv format, and add description in utils/cv

* delete some meaningless comments, test=develop

* set LITE_WITH_CV=OFF in build.sh, test=develop

* delete cv_enum.h in utils/cv, push the contents of cv_enum.h into paddle_image_preprocess.h, test=develop

* redefine paddle_image_preprocess.h according to reviews, test=develop

* add detailed note on flipParam, test=develop

* fix format in paddle_image_preprocess.h, test=develop

* fix error when building x86, test=develop

  LITE_WITH_X86 does not contain LITE_WITH_CV

* fix cmake error in lite/CMakeLists.txt, missing mkdir cxx, test=develop

* apply review changes, test=develop

* change grb to rgb, test=develop

* add elementwise mul constant elimination, and deconv+relu and deconv+batchnorm fusions, test=develop

* fix format, test=develop
Parent 1e9823a0
@@ -39,5 +39,6 @@ USE_MIR_PASS(lite_quant_dequant_fuse_pass);
 USE_MIR_PASS(type_precision_cast_pass);
 USE_MIR_PASS(type_layout_cast_pass);
 USE_MIR_PASS(memory_optimize_pass);
+USE_MIR_PASS(elementwise_mul_constant_eliminate_pass);
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
@@ -20,6 +20,7 @@ lite_cc_library(mir_passes
     fusion/elementwise_add_activation_fuse_pass.cc
     fusion/quant_dequant_fuse_pass.cc
     elimination/identity_scale_eliminate_pass.cc
+    elimination/elementwise_mul_constant_eliminate_pass.cc
     static_kernel_pick_pass.cc
     variable_place_inference_pass.cc
     type_target_cast_pass.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher_high_api.h"

namespace paddle {
namespace lite {
namespace mir {
namespace {

class ElementwiseMulConstantEliminator : public FuseBase {
 public:
  void BuildPattern() override {
    auto* pre_op = OpNode("preop");    // the op that produces x
    auto* post_op = OpNode("postop");  // the op whose input gets rewired to x
    // TODO(Superjomn): check that it has only one output
    auto* x =
        VarNode("x")->assert_is_op_input("elementwise_mul", "X")->AsOutput();
    auto* y = VarNode("Y")->assert_is_op_input("elementwise_mul", "Y");

    // create op nodes
    auto* mul = OpNode("mul", "elementwise_mul")
                    ->assert_is_op("elementwise_mul")
                    ->AsIntermediate();
    auto* fill_constant = OpNode("fill_constant", "fill_constant")
                              ->assert_is_op("fill_constant")
                              ->assert_op_attr<float>("value", 1.)
                              ->AsIntermediate();

    // create output node
    auto* mul_out =
        VarNode("output")->assert_is_op_output("elementwise_mul", "Out");

    // create topology
    std::vector<PMNode*> mul_inputs{x, y};
    *pre_op >> *x;
    *fill_constant >> *y;
    mul_inputs >> *mul >> *mul_out;
    *mul_out >> *post_op;

    // The mul and fill_constant ops will be eliminated, and post_op will be
    // rewired to consume x directly.
    mul_out->AsIntermediate();  // mul's output; removed together with mul
    y->AsIntermediate();        // only feeds the eliminated mul
  }

 private:
  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
    auto& post_op = matched.at("postop")->AsStmt();
    auto op_info = *post_op.op_info();
    // Replace every read of mul's output with x in post_op's inputs.
    op_info.UpdateAllInputs(matched.at("output")->AsArg().name,
                            matched.at("x")->AsArg().name);
    post_op.ResetOp(op_info, graph->valid_places());
    IR_NODE_LINK_TO(matched.at("x"), matched.at("postop"));
  }
};

}  // namespace

class ElementwiseMulConstantEliminatePass : public ProgramPass {
 public:
  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
    ElementwiseMulConstantEliminator eliminator;
    eliminator(graph.get());
  }
};

}  // namespace mir
}  // namespace lite
}  // namespace paddle

REGISTER_MIR_PASS(elementwise_mul_constant_eliminate_pass,
                  paddle::lite::mir::ElementwiseMulConstantEliminatePass)
    .BindTargets({TARGET(kAny)});
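
For readability, a sketch of the rewrite this pass performs (my annotation, inferred from the pattern above; "preop"/"postop" are pattern keys, not real op types):

// Before:
//   pre_op ------------> x --.
//                             \
//   fill_constant(1.0) -> y ---+--> elementwise_mul --> mul_out --> post_op
//
// After:
//   pre_op ------------> x --> post_op
//
// Multiplying by a tensor filled with 1.0 is an identity, so the mul and
// fill_constant ops, together with the intermediate tensors y and mul_out,
// are marked AsIntermediate and removed, and post_op reads x directly.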
@@ -30,7 +30,7 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       break;
     }
   }
-  for (auto conv_type : {"conv2d", "depthwise_conv2d"}) {
+  for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
     for (auto act_type : act_types) {
       for (auto has_bias : {true, false}) {
         fusion::ConvActivationFuser fuser(conv_type, act_type, has_bias);
......
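
This one-line change also instantiates ConvActivationFuser for transposed convolutions, which is the deconv+relu fusion named in the commit message. A standalone sketch of the (conv_type, act_type, has_bias) combinations the loop now covers; note that act_types is computed upstream from the graph, so {"relu"} below is only an assumed example:

#include <iostream>
#include <string>
#include <vector>

int main() {
  // "conv2d_transpose" is the newly added case.
  std::vector<std::string> conv_types{
      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
  // Assumption: the real list is derived from the activations in the graph.
  std::vector<std::string> act_types{"relu"};
  for (const auto& conv_type : conv_types) {
    for (const auto& act_type : act_types) {
      for (bool has_bias : {true, false}) {
        std::cout << conv_type << " + " << act_type
                  << (has_bias ? " (with bias)" : " (no bias)") << "\n";
      }
    }
  }
  return 0;
}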
@@ -26,7 +26,8 @@ namespace mir {
 void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
   // initialize fuser params
   std::vector<bool> conv_has_bias_cases{true, false};
-  std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
+  std::vector<std::string> conv_type_cases{
+      "conv2d", "depthwise_conv2d", "conv2d_transpose"};
   // start fuse using params
   for (auto conv_has_bias : conv_has_bias_cases) {
......
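
For reference, conv+batch_norm fusion folds the normalization into the convolution parameters; this is the standard derivation rather than anything spelled out in this hunk. With per-output-channel BN parameters \gamma_c, \beta_c, running mean \mu_c, and variance \sigma_c^2, the fused weights and bias are

  \hat{W}_c = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \epsilon}}\, W_c, \qquad
  \hat{b}_c = \frac{\gamma_c \,(b_c - \mu_c)}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c,

so that \mathrm{BN}(\mathrm{conv}(x; W, b)) = \mathrm{conv}(x; \hat{W}, \hat{b}). Adding "conv2d_transpose" to conv_type_cases reuses the same folding; one caveat (my note, not from the commit) is that transposed-convolution kernels store output channels on a different axis, so the per-channel scale must be applied along the appropriate weight dimension.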
@@ -48,6 +48,10 @@ class VariablePlaceInferencePass : public DebugPass {
   void CheckAllArgumentTypeDetermined(SSAGraph* graph) {
     for (auto& node : graph->mutable_nodes()) {
       if (node.IsArg()) {
+        if (node.inlinks.size() == 0 && node.outlinks.size() == 0) {
+          // skip isolated nodes; nothing assigns them a type
+          continue;
+        }
         CHECK(node.AsArg().type) << "node " << node.AsArg().name
                                  << " type not determined, " << &node;
       }
......
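
A plausible motivation for this guard (my inference; the commit does not say): elimination passes such as the new elementwise_mul_constant_eliminate_pass can leave Arg nodes behind with no producers or consumers, and those orphans never get a type assigned, which would trip the CHECK. A minimal self-contained sketch of the predicate, with a hypothetical simplified Node:

#include <vector>

// Hypothetical stand-in for mir::Node, for illustration only.
struct Node {
  std::vector<Node*> inlinks;   // edges from producers
  std::vector<Node*> outlinks;  // edges to consumers
};

// A fully disconnected node can be skipped: no kernel will ever touch it.
bool IsIsolated(const Node& node) {
  return node.inlinks.empty() && node.outlinks.empty();
}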
@@ -67,7 +67,9 @@ class Optimizer {
       "lite_transpose_softmax_transpose_fuse_pass",  //
       "lite_interpolate_fuse_pass",                  //
       "identity_scale_eliminate_pass",               //
+      "elementwise_mul_constant_eliminate_pass",     //
-#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA)
+#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
+    (defined LITE_WITH_ARM)
       "lite_elementwise_add_activation_fuse_pass",   //
 #endif
       "static_kernel_pick_pass",  // pick original kernel from graph
......
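
A side note on ordering (my inference from the list, not stated in the commit): the new pass is placed with the other graph-simplification passes, ahead of static_kernel_pick_pass, so eliminated ops are gone before kernel selection runs. A toy sketch of this ordered pass-list pattern:

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct Graph {};  // hypothetical stand-in for SSAGraph

int main() {
  // Passes run strictly in list order, mirroring how the optimizer walks
  // its pass names; the bodies are no-ops here.
  std::vector<std::pair<std::string, std::function<void(Graph*)>>> passes = {
      {"elementwise_mul_constant_eliminate_pass", [](Graph*) {}},
      {"static_kernel_pick_pass", [](Graph*) {}},
  };
  Graph graph;
  for (auto& entry : passes) {
    std::cout << "running " << entry.first << "\n";
    entry.second(&graph);
  }
  return 0;
}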