From 972a8d54890060535fd6f93628bddf80d26f0fc4 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 20 May 2022 20:04:43 +0800 Subject: [PATCH] feat(ci): add model compatibility check in ci GitOrigin-RevId: d17c21ab4fd43005ecfc5a27c7488c6ef866445e --- src/opr/impl/dnn/dnn.sereg.h | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 54130d256..cef8825b1 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -284,6 +284,34 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Param = opr::RNNCellForward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::RNNCellForward::make( + i[0], i[1], i[2], i[3], i[4], i[5], param, config) + .node() + ->owner_opr(); + } +}; + +template <> +struct OprMaker { + using Param = opr::LSTMCellForward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + return opr::LSTMCellForward::make( + i[0], i[1], i[2], i[3], i[4], i[5], i[6], param, config) + .node() + ->owner_opr(); + } +}; + template <> struct OprMaker { using Param = opr::RNNBackward::Param; @@ -718,6 +746,8 @@ MGB_SEREG_OPR(LSQ, 4); MGB_SEREG_OPR(LSQBackward, 5); MGB_SEREG_OPR(LayerNorm, 0); MGB_SEREG_OPR(LayerNormBackward, 0); +MGB_SEREG_OPR(RNNCellForward, 6); +MGB_SEREG_OPR(LSTMCellForward, 7); MGB_SEREG_OPR(RNNForward, 3); MGB_SEREG_OPR(RNNBackward, 7); MGB_SEREG_OPR(LSTMForward, 4); -- GitLab