提交 d62dabe5 编写于 作者: M Megvii Engine Team

fix(mgb): fix matmul model compat in flatbuffer

GitOrigin-RevId: 2effad8e5fa8f371bac5b4e28b2110b3a089463e
上级 05692d05
......@@ -13,29 +13,12 @@ decl_opr('BatchedMatrixMul',
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)')
decl_opr('MatrixMul',
pyname='matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='matrix multiplication',
version=2, has_out_dtype=True)
decl_opr('BatchedMatrixMul',
pyname='batched_matrix_mul_v2',
inputs=['opr0', 'opr1'],
params='MatrixMul',
desc='batched matrix multiplication: input shapes should be '
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=2, has_out_dtype=True)
decl_opr('MatrixMul',
inputs=['opr0', 'opr1'],
params=[('param', 'MatrixMul'),
('execution_polity', 'ExecutionPolicy')],
desc='matrix multiplication',
version=3, has_out_dtype=True)
version=2, has_out_dtype=True)
decl_opr('BatchedMatrixMul',
inputs=['opr0', 'opr1'],
......@@ -45,7 +28,7 @@ decl_opr('BatchedMatrixMul',
'(n, a, b) and (n, b, c) (assuming transposeA and transeposeB are '
'False); then :math:`n` independent matrix multiplications would be '
'performed and output shape is (n, a, c)',
version=3, has_out_dtype=True)
version=2, has_out_dtype=True)
decl_opr('Dot',
inputs=['opr0', 'opr1'],
......
......@@ -51,7 +51,6 @@ struct MatrixMulLoadDumpImpl {
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
auto&& opr = opr_.cast_final_safe<Opr>();
ctx.write_param<megdnn::param::MatrixMul>(opr.param());
ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
}
static VarNode* make(const cg::VarNodeArray& inputs,
......@@ -68,9 +67,7 @@ struct MatrixMulLoadDumpImpl {
const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
auto param = ctx.read_param<megdnn::param::MatrixMul>();
auto execution_policy =
ctx.read_param<megdnn::param::ExecutionPolicy>();
return make(inputs, param, execution_policy, config)->owner_opr();
return make(inputs, param, {}, config)->owner_opr();
}
};
......@@ -90,10 +87,10 @@ struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
namespace opr {
using MatrixMulV3 = MatrixMul;
using BatchedMatrixMulV3 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV3, 2);
MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册