diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index c04fe09db5e6df3132540ed5856ccf59000c71eb..bf798101c65d72eb9569ebd877404e7632037383 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -418,6 +418,22 @@ OP_TRAIT_REG(Identity, Identity) .fallback(); }} // identity +namespace { namespace assert_equal { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); + + } + +OP_TRAIT_REG(AssertEqual, AssertEqual) + .apply_on_var_node(apply_on_var_node) + .fallback(); + +}} + namespace { namespace uniform_rng { auto apply_on_var_node( const OpDef& def, diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index c6881226c37e6d593138ece8d4e65cd5ad9694a8..04c256a09663bdf21021d2508bd7b46001c8c061 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -19,9 +19,9 @@ import megengine.core._imperative_rt as rt import megengine.core.tensor.megbrain_graph as G from megengine.utils import comp_graph_tools as cgtools from megengine.core.ops import builtin -from megengine.core.tensor.core import apply +from megengine.core._imperative_rt.core2 import apply from megengine.core.tensor.megbrain_graph import VarNode -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine import tensor logger = mge.get_logger(__name__) @@ -195,7 +195,7 @@ def make_feeds(args): func = cg_rt.compile([node.outputs[0] for node in output_nodes]) def make_dev_tensor(value, dtype=None, device=None): - return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() + return tensor(value, dtype=dtype, device=device)._dev_tensor() def calculate(*args, **kwargs): output_val = [] @@ -268,8 +268,8 @@ def make_feeds(args): def assert_equal(expect, real, **kwargs): op = builtin.AssertEqual(**kwargs) - (res,) = apply(op, expect, real) - return res + (res,) = G.apply_normal_varnode(op, expect, real) + return G.VarNode(res) verbose = not args.silent @@ -509,7 +509,7 @@ def main(): ) def make_dev_tensor(value, dtype=None, device=None): - return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() + return tensor(value, dtype=dtype, device=device)._dev_tensor() for testcase in feeds["testcases"]: assert isinstance(testcase, dict) diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 9d8a54d0c133c47e0ffa115ae19bc96b9d8f775a..1b614157676900ffc4363f0c800567b0299a58d3 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -231,6 +231,7 @@ def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">; def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; +def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { let extraArguments = (ins MgbDTypeAttr:$dtype