Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
aba0acc7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
aba0acc7
编写于
12月 31, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(sdk): add AssertEqual opr, fix dump_with_testcase_mge
GitOrigin-RevId: 6f797570b674255418b04f2c3bd8d2e19c0e0d04
上级
dd9f54cd
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
23 addition
and
6 deletion
+23
-6
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+16
-0
sdk/load-and-run/dump_with_testcase_mge.py
sdk/load-and-run/dump_with_testcase_mge.py
+6
-6
src/core/include/megbrain/ir/ops.td
src/core/include/megbrain/ir/ops.td
+1
-0
未找到文件。
imperative/src/impl/ops/specializations.cpp
浏览文件 @
aba0acc7
...
...
@@ -418,6 +418,22 @@ OP_TRAIT_REG(Identity, Identity)
.
fallback
();
}}
// identity
namespace
{
namespace
assert_equal
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
op
=
static_cast
<
const
AssertEqual
&>
(
def
);
mgb_assert
(
inputs
.
size
()
==
2
);
return
opr
::
AssertEqual
::
make
(
inputs
[
0
],
inputs
[
1
],
op
.
param
());
}
OP_TRAIT_REG
(
AssertEqual
,
AssertEqual
)
.
apply_on_var_node
(
apply_on_var_node
)
.
fallback
();
}}
namespace
{
namespace
uniform_rng
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
...
...
sdk/load-and-run/dump_with_testcase_mge.py
浏览文件 @
aba0acc7
...
...
@@ -19,9 +19,9 @@ import megengine.core._imperative_rt as rt
import
megengine.core.tensor.megbrain_graph
as
G
from
megengine.utils
import
comp_graph_tools
as
cgtools
from
megengine.core.ops
import
builtin
from
megengine.core.
tensor.core
import
apply
from
megengine.core.
_imperative_rt.core2
import
apply
from
megengine.core.tensor.megbrain_graph
import
VarNode
from
megengine
.core.tensor.raw_tensor
import
as_raw_
tensor
from
megengine
import
tensor
logger
=
mge
.
get_logger
(
__name__
)
...
...
@@ -195,7 +195,7 @@ def make_feeds(args):
func
=
cg_rt
.
compile
([
node
.
outputs
[
0
]
for
node
in
output_nodes
])
def
make_dev_tensor
(
value
,
dtype
=
None
,
device
=
None
):
return
as_raw_
tensor
(
value
,
dtype
=
dtype
,
device
=
device
).
_dev_tensor
()
return
tensor
(
value
,
dtype
=
dtype
,
device
=
device
).
_dev_tensor
()
def
calculate
(
*
args
,
**
kwargs
):
output_val
=
[]
...
...
@@ -268,8 +268,8 @@ def make_feeds(args):
def
assert_equal
(
expect
,
real
,
**
kwargs
):
op
=
builtin
.
AssertEqual
(
**
kwargs
)
(
res
,)
=
apply
(
op
,
expect
,
real
)
return
res
(
res
,)
=
G
.
apply_normal_varnode
(
op
,
expect
,
real
)
return
G
.
VarNode
(
res
)
verbose
=
not
args
.
silent
...
...
@@ -509,7 +509,7 @@ def main():
)
def
make_dev_tensor
(
value
,
dtype
=
None
,
device
=
None
):
return
as_raw_
tensor
(
value
,
dtype
=
dtype
,
device
=
device
).
_dev_tensor
()
return
tensor
(
value
,
dtype
=
dtype
,
device
=
device
).
_dev_tensor
()
for
testcase
in
feeds
[
"testcases"
]:
assert
isinstance
(
testcase
,
dict
)
...
...
src/core/include/megbrain/ir/ops.td
浏览文件 @
aba0acc7
...
...
@@ -231,6 +231,7 @@ def BatchedIncrMeshIndexing: FancyIndexingBase<"BatchedIncrMeshIndexing">;
def BatchedSetMeshIndexing: FancyIndexingBase<"BatchedSetMeshIndexing">;
def FakeQuant: MgbHashableOp<"FakeQuant", [FakeQuantParam]>;
def AssertEqual: MgbHashableOp<"AssertEqual",[AssertEqualParam]>;
def ElemwiseMultiType: MgbHashableOp<"ElemwiseMultiType", [ElemwiseMultiTypeParam]> {
let extraArguments = (ins
MgbDTypeAttr:$dtype
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录