提交 aba0acc7 编写于 作者: M Megvii Engine Team

fix(sdk): add AssertEqual opr, fix dump_with_testcase_mge

GitOrigin-RevId: 6f797570b674255418b04f2c3bd8d2e19c0e0d04
上级 dd9f54cd
......@@ -418,6 +418,22 @@ OP_TRAIT_REG(Identity, Identity)
.fallback();
}} // identity
namespace { namespace assert_equal {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const AssertEqual&>(def);
mgb_assert(inputs.size() == 2);
return opr::AssertEqual::make(inputs[0],inputs[1],op.param());
}
OP_TRAIT_REG(AssertEqual, AssertEqual)
.apply_on_var_node(apply_on_var_node)
.fallback();
}}
namespace { namespace uniform_rng {
auto apply_on_var_node(
const OpDef& def,
......
......@@ -19,9 +19,9 @@ import megengine.core._imperative_rt as rt
import megengine.core.tensor.megbrain_graph as G
from megengine.utils import comp_graph_tools as cgtools
from megengine.core.ops import builtin
from megengine.core.tensor.core import apply
from megengine.core._imperative_rt.core2 import apply
from megengine.core.tensor.megbrain_graph import VarNode
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine import tensor
logger = mge.get_logger(__name__)
......@@ -195,7 +195,7 @@ def make_feeds(args):
func = cg_rt.compile([node.outputs[0] for node in output_nodes])
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
return tensor(value, dtype=dtype, device=device)._dev_tensor()
def calculate(*args, **kwargs):
output_val = []
......@@ -268,8 +268,8 @@ def make_feeds(args):
def assert_equal(expect, real, **kwargs):
op = builtin.AssertEqual(**kwargs)
(res,) = apply(op, expect, real)
return res
(res,) = G.apply_normal_varnode(op, expect, real)
return G.VarNode(res)
verbose = not args.silent
......@@ -509,7 +509,7 @@ def main():
)
def make_dev_tensor(value, dtype=None, device=None):
return as_raw_tensor(value, dtype=dtype, device=device)._dev_tensor()
return tensor(value, dtype=dtype, device=device)._dev_tensor()
for testcase in feeds["testcases"]:
assert isinstance(testcase, dict)
......
......@@ -231,6 +231,7 @@ def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册