diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 7caefe90e76c55f9bb7c9f8ffd360b6a161aab26..98fac59e52a0a911a2fcd4cf9e8e61d336d5a2c0 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1106,6 +1106,7 @@ def matmul( transposeB=transpose_b, compute_mode=compute_mode, format=format, + strategy=get_conv_execution_strategy(), ) else: op = builtin.MatrixMul( @@ -1113,6 +1114,7 @@ def matmul( transposeB=transpose_b, compute_mode=compute_mode, format=format, + strategy=get_conv_execution_strategy(), ) (result,) = apply(op, inp1, inp2) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index f609e1228f12f4dbe49f365d83332da03699f663..15392a092979f977b5c35ea01ab40de48d4ac860 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -243,7 +243,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& matmul = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param()); + return opr::MatrixMul::make(inputs[0], inputs[1], matmul.param(), + matmul.policy()); } OP_TRAIT_REG(MatrixMul, MatrixMul) .apply_on_var_node(apply_on_var_node) @@ -256,7 +257,8 @@ auto apply_on_var_node( const VarNodeArray& inputs) { auto&& matmul = static_cast(def); mgb_assert(inputs.size() == 2); - return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param()); + return opr::BatchedMatrixMul::make(inputs[0], inputs[1], matmul.param(), + matmul.policy()); } OP_TRAIT_REG(BatchedMatrixMul, BatchedMatrixMul) .apply_on_var_node(apply_on_var_node) @@ -428,7 +430,7 @@ auto apply_on_var_node( return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); } - + OP_TRAIT_REG(AssertEqual, AssertEqual) .apply_on_var_node(apply_on_var_node) .fallback(); diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 7c3834de1bff9d08a624919a8be8cd57facecbb3..7f06a49ac65413da74b95d6db08180ffb2f5e2d6 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -34,9 +34,9 @@ def TypeCvt: MgbHashableOp<"TypeCvt", [], [NoSideEffect]> { let results = (outs AnyType); } -def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam]>; +def MatrixMul: MgbHashableOp<"MatrixMul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; -def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam]>; +def BatchedMatrixMul: MgbHashableOp<"BatchedMatmul", [MatrixMulParam, ExecutionPolicyParamBase<"policy">]>; def Dot: MgbHashableOp<"Dot", [EmptyParam]>;