提交 d62dabe5 编写于 作者: M Megvii Engine Team

fix(mgb): fix matmul model compat in flatbuffer

GitOrigin-RevId: 2effad8e5fa8f371bac5b4e28b2110b3a089463e
上级 05692d05
...@@ -13,29 +13,12 @@ decl_opr('BatchedMatrixMul', ...@@ -13,29 +13,12 @@ decl_opr('BatchedMatrixMul',
'False); then :math:`n` independent matrix multiplications would be ' 'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)') 'performed and output shape is (n, a, c)')
decl_opr('MatrixMul',
pyname='matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='matrix multiplication',
version=2, has_out_dtype=True)
decl_opr('BatchedMatrixMul',
pyname='batched_matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='batched matrix multiplication: input shapes should be '
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=2, has_out_dtype=True)
decl_opr('MatrixMul', decl_opr('MatrixMul',
inputs=['opr0', 'opr1'], inputs=['opr0', 'opr1'],
params=[('param', 'MatrixMul'), params=[('param', 'MatrixMul'),
('execution_polity', 'ExecutionPolicy')], ('execution_polity', 'ExecutionPolicy')],
desc='matrix multiplication', desc='matrix multiplication',
version=3, has_out_dtype=True) version=2, has_out_dtype=True)
decl_opr('BatchedMatrixMul', decl_opr('BatchedMatrixMul',
inputs=['opr0', 'opr1'], inputs=['opr0', 'opr1'],
...@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul', ...@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul',
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are ' '(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be ' 'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)', 'performed and output shape is (n, a, c)',
version=3, has_out_dtype=True) version=2, has_out_dtype=True)
decl_opr('Dot', decl_opr('Dot',
inputs=['opr0', 'opr1'], inputs=['opr0', 'opr1'],
......
...@@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl { ...@@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) { static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>(); auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<megdnn::param::MatrixMul>(opr.param()); ctx.write_param<megdnn::param::MatrixMul>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
} }
static VarNode* make(const cg::VarNodeArray& inputs, static VarNode* make(const cg::VarNodeArray& inputs,
...@@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl { ...@@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl {
const cg::VarNodeArray& inputs, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
auto param = ctx.read_param<megdnn::param::MatrixMul>(); auto param = ctx.read_param<megdnn::param::MatrixMul>();
auto execution_policy = return make(inputs, param, {}, config)->owner_opr();
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
} }
}; };
...@@ -90,10 +87,10 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2> ...@@ -90,10 +87,10 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
namespace opr { namespace opr {
using MatrixMulV3 = MatrixMul; using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV3 = BatchedMatrixMul; using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV3, 2); MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV3, 2); MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
MGB_SEREG_OPR(Dot, 2); MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1); MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1); MGB_SEREG_OPR(SVD, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册