提交 e9a7be46 编写于 作者: M Megvii Engine Team

feat(mgb): adapt to imperative runtime

GitOrigin-RevId: 3bccc17b6268e45f4b0f5a022be908420eb9cfc5
上级 5063a206
...@@ -1106,6 +1106,7 @@ def matmul( ...@@ -1106,6 +1106,7 @@ def matmul(
transposeB=transpose_b, transposeB=transpose_b,
compute_mode=compute_mode, compute_mode=compute_mode,
format=format, format=format,
strategy=get_conv_execution_strategy(),
) )
else: else:
op = builtin.MatrixMul( op = builtin.MatrixMul(
...@@ -1113,6 +1114,7 @@ def matmul( ...@@ -1113,6 +1114,7 @@ def matmul(
transposeB=transpose_b, transposeB=transpose_b,
compute_mode=compute_mode, compute_mode=compute_mode,
format=format, format=format,
strategy=get_conv_execution_strategy(),
) )
(result,) = apply(op, inp1, inp2) (result,) = apply(op, inp1, inp2)
......
...@@ -243,7 +243,8 @@ auto apply_on_var_node( ...@@ -243,7 +243,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& matmul = static_cast<const MatrixMul&>(def); auto&& matmul = static_cast<const MatrixMul&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
} }
OP_TRAIT_REG(MatrixMul, MatrixMul) OP_TRAIT_REG(MatrixMul, MatrixMul)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
...@@ -256,7 +257,8 @@ auto apply_on_var_node( ...@@ -256,7 +257,8 @@ auto apply_on_var_node(
const VarNodeArray& inputs) { const VarNodeArray& inputs) {
auto&& matmul = static_cast<const BatchedMatrixMul&>(def); auto&& matmul = static_cast<const BatchedMatrixMul&>(def);
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(),
matmul.policy());
} }
OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
...@@ -428,7 +430,7 @@ auto apply_on_var_node( ...@@ -428,7 +430,7 @@ auto apply_on_var_node(
return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); return opr::AssertEqual::make(inputs[0],inputs[1],op.param());
} }
OP_TRAIT_REG(AssertEqual, AssertEqual) OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node) .apply_on_var_node(apply_on_var_node)
.fallback(); .fallback();
......
...@@ -34,9 +34,9 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { ...@@ -34,9 +34,9 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> {
let results = (outs AnyType); let results = (outs AnyType);
} }
def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>; def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>; def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>;
def Dot: MgbHashableOp<"Dot", [EmptyParam]>; def Dot: MgbHashableOp<"Dot", [EmptyParam]>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册