未验证 提交 f23b3028 编写于 作者: myq406450149's avatar myq406450149 提交者: GitHub

add interpolate type2 pass (#3396)

* add interpolate type2 pass. test=develp
上级 01f25d78
...@@ -23,11 +23,15 @@ namespace lite { ...@@ -23,11 +23,15 @@ namespace lite {
namespace mir { namespace mir {
void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void InterpolateFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::InterpolateFuser bilinear_interp_fuser("bilinear_interp"); std::vector<std::string> Interpolate_type_cases{"bilinear_interp",
bilinear_interp_fuser(graph.get()); "nearest_interp"};
for (auto type_ : Interpolate_type_cases) {
fusion::InterpolateFuser interp_fuser(type_);
interp_fuser(graph.get());
fusion::InterpolateFuser nearest_interp_fuser("nearest_interp"); fusion::InterpolateFuser2 interp_fuser2(type_);
nearest_interp_fuser(graph.get()); interp_fuser2(graph.get());
}
} }
} // namespace mir } // namespace mir
......
...@@ -22,6 +22,9 @@ namespace mir { ...@@ -22,6 +22,9 @@ namespace mir {
namespace fusion { namespace fusion {
void InterpolateFuser::BuildPattern() { void InterpolateFuser::BuildPattern() {
// type1 fill_constant -->
// x --> shape --> slice --> cast --> elementwise_mul --> interpolate
// `-------------------------------------------------->
auto* x = VarNode("x"); auto* x = VarNode("x");
auto* shape = OpNode("shape", "shape")->AsIntermediate(); auto* shape = OpNode("shape", "shape")->AsIntermediate();
auto* shape_out = VarNode("shape_out")->AsIntermediate(); auto* shape_out = VarNode("shape_out")->AsIntermediate();
...@@ -89,6 +92,64 @@ cpp::OpDesc InterpolateFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -89,6 +92,64 @@ cpp::OpDesc InterpolateFuser::GenOpDesc(const key2nodes_t& matched) {
return op_desc; return op_desc;
} }
void InterpolateFuser2::BuildPattern() {
// type2 x --> shape --> slice --> cast --> scale --> interpolate
// `---------------------------------------->
auto* x = VarNode("x");
auto* shape = OpNode("shape", "shape")->AsIntermediate();
auto* shape_out = VarNode("shape_out")->AsIntermediate();
auto* slice = OpNode("slice", "slice")
->assert_op_attr_satisfied<std::vector<int>>(
"axes",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 0;
})
->assert_op_attr_satisfied<std::vector<int>>(
"starts",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 2;
})
->assert_op_attr_satisfied<std::vector<int>>(
"ends",
[](const std::vector<int>& attr) {
return attr.size() == 1 && attr[0] == 4;
})
->AsIntermediate();
auto* slice_out = VarNode("slice_out")->AsIntermediate();
auto* cast = OpNode("cast", "cast")->AsIntermediate();
auto* cast_out = VarNode("cast_out")->AsIntermediate();
auto* scale = OpNode("scale", "scale")->AsIntermediate();
auto* scale_out = VarNode("scale_out")->AsIntermediate();
auto* interpolate = OpNode("interpolate", interp_type_)->AsIntermediate();
auto* interpolate_out = VarNode("interpolate_out");
// create topology.
*x >> *shape >> *shape_out >> *slice >> *slice_out >> *cast >> *cast_out >>
*scale >> *scale_out >> *interpolate >> *interpolate_out;
*x >> *interpolate;
}
void InterpolateFuser2::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto interp_op = LiteOpRegistry::Global().Create(interp_type_);
auto interp_old = matched.at("interpolate")->stmt()->op();
auto* scope = interp_old->scope();
auto& valid_places = interp_old->valid_places();
interp_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(interp_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("interpolate_out"));
}
cpp::OpDesc InterpolateFuser2::GenOpDesc(const key2nodes_t& matched) {
auto op_desc = *matched.at("interpolate")->stmt()->op_info();
op_desc.SetInput("OutSize", {});
return op_desc;
}
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
...@@ -36,6 +36,19 @@ class InterpolateFuser : public FuseBase { ...@@ -36,6 +36,19 @@ class InterpolateFuser : public FuseBase {
std::string interp_type_; std::string interp_type_;
}; };
class InterpolateFuser2 : public FuseBase {
public:
explicit InterpolateFuser2(const std::string& interp_type)
: interp_type_(interp_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string interp_type_;
};
} // namespace fusion } // namespace fusion
} // namespace mir } // namespace mir
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册