Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3dbac4f4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3dbac4f4
编写于
8月 24, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): add atlas_subgraph module
GitOrigin-RevId: 11530383c0a31f4648ed89d3070b2dab178ea5b2
上级
00ef6772
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
69 addition
and
4 deletion
+69
-4
python_module/megengine/functional/external.py
python_module/megengine/functional/external.py
+11
-0
python_module/megengine/module/external.py
python_module/megengine/module/external.py
+28
-1
python_module/test/unit/module/AtlasRuntimeOprTest.basic.om
python_module/test/unit/module/AtlasRuntimeOprTest.basic.om
+0
-0
python_module/test/unit/module/test_external.py
python_module/test/unit/module/test_external.py
+30
-3
未找到文件。
python_module/megengine/functional/external.py
浏览文件 @
3dbac4f4
...
...
@@ -34,6 +34,17 @@ def cambricon_subgraph(
)
@
wrap_io_tensor
def
atlas_subgraph
(
inputs
:
List
[
Tensor
],
data
:
bytes
)
->
List
[
Tensor
]:
"""Load a serialized Atlas subgraph (i.e. om model) and
execute the operations defined in the subgraph.
:param inputs: List of input tensors of the subgraph.
:param data: The serialized subgraph.
"""
return
mgb
.
opr
.
atlas_runtime
(
tuple
(
map
(
lambda
x
:
x
.
_symvar
,
inputs
)),
data
)
@
wrap_io_tensor
def
extern_opr_subgraph
(
inputs
,
output_shapes
:
List
[
tuple
],
dump_name
:
str
,
dump_data
:
bytes
,
...
...
python_module/megengine/module/external.py
浏览文件 @
3dbac4f4
...
...
@@ -8,7 +8,11 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
from
..functional.external
import
cambricon_subgraph
,
extern_opr_subgraph
from
..functional.external
import
(
atlas_subgraph
,
cambricon_subgraph
,
extern_opr_subgraph
,
)
from
.module
import
Module
...
...
@@ -41,6 +45,29 @@ class CambriconSubgraph(Module):
return
outputs
class
AtlasSubgraph
(
Module
):
r
"""Load a serialized Atlas subgraph.
See :func:`~.atlas_subgraph` for more details.
"""
def
__init__
(
self
,
data
):
super
(
AtlasSubgraph
,
self
).
__init__
()
self
.
_data
=
data
@
property
def
data
(
self
):
return
self
.
_data
.
tobytes
()
@
data
.
setter
def
data
(
self
,
val
):
self
.
_data
=
np
.
frombuffer
(
val
,
dtype
=
np
.
uint8
)
def
forward
(
self
,
inputs
):
outputs
=
atlas_subgraph
(
inputs
,
self
.
_data
)
return
outputs
class
ExternOprSubgraph
(
Module
):
r
"""Load a serialized extern opr subgraph.
"""
...
...
python_module/test/unit/module/AtlasRuntimeOprTest.basic.om
0 → 100644
浏览文件 @
3dbac4f4
文件已添加
python_module/test/unit/module/test_external.py
浏览文件 @
3dbac4f4
...
...
@@ -13,10 +13,10 @@ import numpy as np
import
megengine
as
mge
from
megengine
import
tensor
from
megengine.module
import
Module
from
megengine.module.external
import
CambriconSubgraph
from
megengine.module.external
import
AtlasSubgraph
,
CambriconSubgraph
class
My
Module
(
Module
):
class
Cambricon
Module
(
Module
):
def
__init__
(
self
,
data
):
super
().
__init__
()
self
.
cambricon
=
CambriconSubgraph
(
data
,
"subnet0"
,
True
)
...
...
@@ -31,7 +31,7 @@ def test_cambricon_module():
model
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model
)
with
open
(
model
,
"rb"
)
as
f
:
data
=
f
.
read
()
m
=
My
Module
(
data
)
m
=
Cambricon
Module
(
data
)
inputs
=
[]
inputs
.
append
(
tensor
(
dtype
=
np
.
float16
,
device
=
"cambricon0"
))
inputs
[
0
].
set_value
(
np
.
random
.
normal
(
size
=
(
1
,
64
,
32
,
32
)).
astype
(
np
.
float16
))
...
...
@@ -41,3 +41,30 @@ def test_cambricon_module():
return
pred
pred
=
inference
(
inputs
)
class
AtlasModule
(
Module
):
def
__init__
(
self
,
data
):
super
().
__init__
()
self
.
atlas
=
AtlasSubgraph
(
data
)
def
forward
(
self
,
inputs
):
out
=
self
.
atlas
(
inputs
)
return
out
def
test_atlas_module
():
model
=
"AtlasRuntimeOprTest.basic.om"
model
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
model
)
with
open
(
model
,
"rb"
)
as
f
:
data
=
f
.
read
()
m
=
AtlasModule
(
data
)
inputs
=
[]
inputs
.
append
(
tensor
(
dtype
=
np
.
float32
,
device
=
"atlas0"
))
inputs
[
0
].
set_value
(
np
.
random
.
normal
(
size
=
(
4
,
3
,
16
,
16
)).
astype
(
np
.
float32
))
def
inference
(
inps
):
pred
=
m
(
inps
)
return
pred
pred
=
inference
(
inputs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录