diff --git a/lite/core/mir/fusion/interpolate_fuse_pass.cc b/lite/core/mir/fusion/interpolate_fuse_pass.cc index 51c9868cf3ed76ee6f02ac954f74c330e9f1a8e1..ab152c94561410f8febc5f5db7a1709bb114fb94 100644 --- a/lite/core/mir/fusion/interpolate_fuse_pass.cc +++ b/lite/core/mir/fusion/interpolate_fuse_pass.cc @@ -23,11 +23,15 @@ namespace lite { namespace mir { void InterpolateFusePass::Apply(const std::unique_ptr& graph) { - fusion::InterpolateFuser bilinear_interp_fuser("bilinear_interp"); - bilinear_interp_fuser(graph.get()); + std::vector Interpolate_type_cases{"bilinear_interp", + "nearest_interp"}; + for (auto type_ : Interpolate_type_cases) { + fusion::InterpolateFuser interp_fuser(type_); + interp_fuser(graph.get()); - fusion::InterpolateFuser nearest_interp_fuser("nearest_interp"); - nearest_interp_fuser(graph.get()); + fusion::InterpolateFuser2 interp_fuser2(type_); + interp_fuser2(graph.get()); + } } } // namespace mir diff --git a/lite/core/mir/fusion/interpolate_fuser.cc b/lite/core/mir/fusion/interpolate_fuser.cc index 458ef76cb4432dd54678824b1a179e554bcbbf78..ebbd63f8613fb6d62b580004cf7522683db08e38 100644 --- a/lite/core/mir/fusion/interpolate_fuser.cc +++ b/lite/core/mir/fusion/interpolate_fuser.cc @@ -22,6 +22,9 @@ namespace mir { namespace fusion { void InterpolateFuser::BuildPattern() { + // type1 fill_constant --> + // x --> shape --> slice --> cast --> elementwise_mul --> interpolate + // `--------------------------------------------------> auto* x = VarNode("x"); auto* shape = OpNode("shape", "shape")->AsIntermediate(); auto* shape_out = VarNode("shape_out")->AsIntermediate(); @@ -89,6 +92,64 @@ cpp::OpDesc InterpolateFuser::GenOpDesc(const key2nodes_t& matched) { return op_desc; } +void InterpolateFuser2::BuildPattern() { + // type2 x --> shape --> slice --> cast --> scale --> interpolate + // `----------------------------------------> + auto* x = VarNode("x"); + auto* shape = OpNode("shape", "shape")->AsIntermediate(); + auto* shape_out = VarNode("shape_out")->AsIntermediate(); + auto* slice = OpNode("slice", "slice") + ->assert_op_attr_satisfied>( + "axes", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 0; + }) + ->assert_op_attr_satisfied>( + "starts", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 2; + }) + ->assert_op_attr_satisfied>( + "ends", + [](const std::vector& attr) { + return attr.size() == 1 && attr[0] == 4; + }) + ->AsIntermediate(); + auto* slice_out = VarNode("slice_out")->AsIntermediate(); + auto* cast = OpNode("cast", "cast")->AsIntermediate(); + auto* cast_out = VarNode("cast_out")->AsIntermediate(); + auto* scale = OpNode("scale", "scale")->AsIntermediate(); + auto* scale_out = VarNode("scale_out")->AsIntermediate(); + auto* interpolate = OpNode("interpolate", interp_type_)->AsIntermediate(); + auto* interpolate_out = VarNode("interpolate_out"); + + // create topology. + *x >> *shape >> *shape_out >> *slice >> *slice_out >> *cast >> *cast_out >> + *scale >> *scale_out >> *interpolate >> *interpolate_out; + *x >> *interpolate; +} + +void InterpolateFuser2::InsertNewNode(SSAGraph* graph, + const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto interp_op = LiteOpRegistry::Global().Create(interp_type_); + auto interp_old = matched.at("interpolate")->stmt()->op(); + auto* scope = interp_old->scope(); + auto& valid_places = interp_old->valid_places(); + interp_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(interp_op, valid_places); + + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("interpolate_out")); +} + +cpp::OpDesc InterpolateFuser2::GenOpDesc(const key2nodes_t& matched) { + auto op_desc = *matched.at("interpolate")->stmt()->op_info(); + op_desc.SetInput("OutSize", {}); + return op_desc; +} + } // namespace fusion } // namespace mir } // namespace lite diff --git a/lite/core/mir/fusion/interpolate_fuser.h b/lite/core/mir/fusion/interpolate_fuser.h index 51f5655e76749ea4de6e1789f499862f2ac46437..96fa6b260190114d41fe6308217fef05de21bd44 100644 --- a/lite/core/mir/fusion/interpolate_fuser.h +++ b/lite/core/mir/fusion/interpolate_fuser.h @@ -36,6 +36,19 @@ class InterpolateFuser : public FuseBase { std::string interp_type_; }; +class InterpolateFuser2 : public FuseBase { + public: + explicit InterpolateFuser2(const std::string& interp_type) + : interp_type_(interp_type) {} + + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + std::string interp_type_; +}; + } // namespace fusion } // namespace mir } // namespace lite