From aba0acc7976e250686d8ce7c526283e76ffffa73 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 31 Dec 2020 10:51:47 +0800 Subject: [PATCH] fix(sdk): add AssertEqual opr, fix dump_with_testcase_mge GitOrigin-RevId: 6f797570b674255418b04f2c3bd8d2e19c0e0d04 --- imperative/src/impl/ops/specializations.cpp | 16 ++++++++++++++++ sdk/load-and-run/dump_with_testcase_mge.py | 12 ++++++------ src/core/include/megbrain/ir/ops.td | 1 + 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index c04fe09d..bf798101 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -418,6 +418,22 @@ OP_TRAIT_REG(Identity, Identity) .fallback(); }} // identity +namespace { namespace assert_equal { +auto apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { + auto&& op = static_cast(def); + mgb_assert(inputs.size() == 2); + return opr::AssertEqual::make(inputs[0],inputs[1],op.param()); + + } + +OP_TRAIT_REG(AssertEqual, AssertEqual) + .apply_on_var_node(apply_on_var_node) + .fallback(); + +}} + namespace { namespace uniform_rng { auto apply_on_var_node( const OpDef& def, diff --git a/sdk/load-and-run/dump_with_testcase_mge.py b/sdk/load-and-run/dump_with_testcase_mge.py index c6881226..04c256a0 100755 --- a/sdk/load-and-run/dump_with_testcase_mge.py +++ b/sdk/load-and-run/dump_with_testcase_mge.py @@ -19,9 +19,9 @@ import megengine.core._imperative_rt as rt import megengine.core.tensor.megbrain_graph as G from megengine.utils import comp_graph_tools as cgtools from megengine.core.ops import builtin -from megengine.core.tensor.core import apply +from megengine.core._imperative_rt.core2 import apply from megengine.core.tensor.megbrain_graph import VarNode -from megengine.core.tensor.raw_tensor import as_raw_tensor +from megengine import tensor logger = mge.get_logger(__name__) @@ -195,7 +195,7 @@ def make_feeds(args): func = cg_rt.compile([node.outputs[0] for node in output_nodes]) def make_dev_tensor(value, dtype=None, device=None): - return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() + return tensor(value, dtype=dtype, device=device)._dev_tensor() def calculate(*args, **kwargs): output_val = [] @@ -268,8 +268,8 @@ def make_feeds(args): def assert_equal(expect, real, **kwargs): op = builtin.AssertEqual(**kwargs) - (res,) = apply(op, expect, real) - return res + (res,) = G.apply_normal_varnode(op, expect, real) + return G.VarNode(res) verbose = not args.silent @@ -509,7 +509,7 @@ def main(): ) def make_dev_tensor(value, dtype=None, device=None): - return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor() + return tensor(value, dtype=dtype, device=device)._dev_tensor() for testcase in feeds["testcases"]: assert isinstance(testcase, dict) diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index 9d8a54d0..1b614157 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -231,6 +231,7 @@ def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">; def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">; def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>; +def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>; def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> { let extraArguments = (ins MgbDTypeAttr:$dtype -- GitLab