From d62dabe5f8eebeafdf58afe1fdb4fde710dd9ddf Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sat, 6 Feb 2021 01:09:26 +0800 Subject: [PATCH] fix(mgb): fix matmul model compat in flatbuffer GitOrigin-RevId: 2effad8e5fa8f371bac5b4e28b2110b3a089463e --- src/opr/impl/blas.oprdecl | 21 ++------------------- src/opr/impl/blas.sereg.h | 13 +++++-------- 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/src/opr/impl/blas.oprdecl b/src/opr/impl/blas.oprdecl index cfe629330..bc3a95f20 100644 --- a/src/opr/impl/blas.oprdecl +++ b/src/opr/impl/blas.oprdecl @@ -13,29 +13,12 @@ decl_opr('BatchedMatrixMul', 'False); then :math:`n` independent matrix multiplications would be ' 'performed and output shape is (n, a, c)') -decl_opr('MatrixMul', - pyname='matrix_mul_v2', - inputs=['opr0', 'opr1'], - params='MatrixMul', - desc='matrix multiplication', - version=2, has_out_dtype=True) - -decl_opr('BatchedMatrixMul', - pyname='batched_matrix_mul_v2', - inputs=['opr0', 'opr1'], - params='MatrixMul', - desc='batched matrix multiplication: input shapes should be ' - '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' - 'False); then :math:`n` independent matrix multiplications would be ' - 'performed and output shape is (n, a, c)', - version=2, has_out_dtype=True) - decl_opr('MatrixMul', inputs=['opr0', 'opr1'], params=[('param', 'MatrixMul'), ('execution_polity', 'ExecutionPolicy')], desc='matrix multiplication', - version=3, has_out_dtype=True) + version=2, has_out_dtype=True) decl_opr('BatchedMatrixMul', inputs=['opr0', 'opr1'], @@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul', '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' 'False); then :math:`n` independent matrix multiplications would be ' 'performed and output shape is (n, a, c)', - version=3, has_out_dtype=True) + version=2, has_out_dtype=True) decl_opr('Dot', inputs=['opr0', 'opr1'], diff --git a/src/opr/impl/blas.sereg.h b/src/opr/impl/blas.sereg.h index 04813fc9b..e959d8889 100644 --- a/src/opr/impl/blas.sereg.h +++ b/src/opr/impl/blas.sereg.h @@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { auto&& opr = opr_.cast_final_safe(); ctx.write_param(opr.param()); - ctx.write_param(opr.execution_policy()); } static VarNode* make(const cg::VarNodeArray& inputs, @@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl { const cg::VarNodeArray& inputs, const OperatorNodeConfig& config) { auto param = ctx.read_param(); - auto execution_policy = - ctx.read_param(); - return make(inputs, param, execution_policy, config)->owner_opr(); + return make(inputs, param, {}, config)->owner_opr(); } }; @@ -90,10 +87,10 @@ struct OprLoadDumpImpl namespace opr { -using MatrixMulV3 = MatrixMul; -using BatchedMatrixMulV3 = BatchedMatrixMul; -MGB_SEREG_OPR(MatrixMulV3, 2); -MGB_SEREG_OPR(BatchedMatrixMulV3, 2); +using MatrixMulV2 = MatrixMul; +using BatchedMatrixMulV2 = BatchedMatrixMul; +MGB_SEREG_OPR(MatrixMulV2, 2); +MGB_SEREG_OPR(BatchedMatrixMulV2, 2); MGB_SEREG_OPR(Dot, 2); MGB_SEREG_OPR(MatrixInverse, 1); MGB_SEREG_OPR(SVD, 1); -- GitLab