未验证 提交 c49958a2 编写于 作者: Z zhupengyang 提交者: GitHub

add interpolate fuse pass (#1980)

test=develop
上级 febfd7d6
......@@ -31,6 +31,7 @@ USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
......
......@@ -13,6 +13,7 @@ lite_cc_library(mir_passes
fusion/fc_fuse_pass.cc
fusion/shuffle_channel_fuse_pass.cc
fusion/transpose_softmax_transpose_fuse_pass.cc
fusion/interpolate_fuse_pass.cc
fusion/conv_elementwise_fuse_pass.cc
fusion/conv_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
......
......@@ -22,6 +22,9 @@ lite_cc_library(fuse_quant_dequant
lite_cc_library(fuse_transpose_softmax_transpose
SRCS transpose_softmax_transpose_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_interpolate
SRCS interpolate_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
......@@ -32,6 +35,7 @@ set(mir_fusers
fuse_quant_dequant
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
fuse_interpolate
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/interpolate_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/interpolate_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::InterpolateFuser bilinear_interp_fuser("bilinear_interp");
bilinear_interp_fuser(graph.get());
fusion::InterpolateFuser nearest_interp_fuser("nearest_interp");
nearest_interp_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_interpolate_fuse_pass,
paddle::lite::mir::InterpolateFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class InterpolateFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/interpolate_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void InterpolateFuser::BuildPattern() {
auto* x = VarNode("x");
auto* shape = OpNode("shape", "shape")->AsIntermediate();
auto* shape_out = VarNode("shape_out")->AsIntermediate();
auto* slice = OpNode("slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 0;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 2;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 4;
})
->AsIntermediate();
auto* slice_out = VarNode("slice_out")->AsIntermediate();
auto* cast = OpNode("cast", "cast")->AsIntermediate();
auto* cast_out = VarNode("cast_out")->AsIntermediate();
auto* fill_constant =
OpNode("fill_constant", "fill_constant")->AsIntermediate();
auto* fill_constant_out = VarNode("fill_constant_out")->AsIntermediate();
auto* elementwise_mul =
OpNode("elementwise_mul", "elementwise_mul")
->assert_op_attr_satisfied<int>(
"axis", [](int attr) { return attr == -1 || attr == 0; })
->AsIntermediate();
auto* elementwise_mul_out = VarNode("elementwise_mul_out")->AsIntermediate();
auto* interpolate = OpNode("interpolate", interp_type_)->AsIntermediate();
auto* interpolate_out = VarNode("interpolate_out");
// create topology.
*x >> *shape >> *shape_out >> *slice >> *slice_out >> *cast >> *cast_out >>
*elementwise_mul >> *elementwise_mul_out >> *interpolate >>
*interpolate_out;
*fill_constant >> *fill_constant_out >> *elementwise_mul;
*x >> *interpolate;
}
void InterpolateFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto interp_op = LiteOpRegistry::Global().Create(interp_type_);
auto interp_old = matched.at("interpolate")->stmt()->op();
auto* scope = interp_old->scope();
auto& valid_places = interp_old->valid_places();
interp_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(interp_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("interpolate_out"));
}
cpp::OpDesc InterpolateFuser::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("interpolate")->stmt()->op_info();
op_desc.SetInput("OutSize", {});
op_desc.SetAttr(
"scale",
matched.at("fill_constant")->stmt()->op_info()->GetAttr<float>("value"));
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class InterpolateFuser : public FuseBase {
public:
explicit InterpolateFuser(const std::string& interp_type)
: interp_type_(interp_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string interp_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -67,6 +67,7 @@ class Optimizer {
"lite_fc_fuse_pass", //
"lite_shuffle_channel_fuse_pass", //
"lite_transpose_softmax_transpose_fuse_pass", //
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册