提交 e9a7be46 编写于 作者: M Megvii Engine Team

feat(mgb): adapt to imperative runtime

GitOrigin-RevId: 3bccc17b6268e45f4b0f5a022be908420eb9cfc5
上级 5063a206
......@@ -1106,6 +1106,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
else:
op = builtin.MatrixMul(
......@@ -1113,6 +1114,7 @@ def matmul(
transposeB=transpose_b,
compute_mode=compute_mode,
format=format,
strategy=get_conv_execution_strategy(),
)
(result,) = apply(op, inp1, inp2)
......
......@@ -243,7 +243,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param());
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
}
OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node)
......@@ -256,7 +257,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2);
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param());
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
}
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node)
......
......@@ -34,9 +34,9 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let results = (outs AnyType);
}
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>;
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册