Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
76dbaa27
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
76dbaa27
编写于
9月 03, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/imperative): add name, make_h2d, dump_graph to graph runtime
GitOrigin-RevId: b8681a31a81502f12340dafd56c1b4d466b22020
上级
7336b306
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
52 addition
and
2 deletion
+52
-2
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+24
-0
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+28
-2
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
76dbaa27
...
...
@@ -78,6 +78,14 @@ class Graph(_imperative_rt.ComputingGraph):
opnode
=
InputNode
(
*
args
,
device
=
device
,
dtype
=
dtype
,
shape
=
shape
,
graph
=
self
)
return
opnode
.
outputs
[
0
]
def
make_h2d
(
self
,
*
,
dtype
,
device
):
device
=
as_device
(
device
).
to_c
()
return
self
.
_wrap
(
_imperative_rt
.
make_h2d
(
self
,
device
,
dtype
))
def
dump
(
*
args
):
return
_imperative_rt
.
dump_graph
([
i
.
_node
for
i
in
args
])
class
VarNode
(
TensorBase
):
def
__init__
(
self
,
node
:
_imperative_rt
.
VarNode
):
...
...
@@ -92,6 +100,14 @@ class VarNode(TensorBase):
def
op
(
self
):
return
self
.
graph
.
_wrap
(
self
.
_node
.
owner
)
@
property
def
name
(
self
):
return
self
.
_node
.
name
@
name
.
setter
def
name
(
self
,
name
):
self
.
_node
.
name
=
name
@
property
def
dtype
(
self
):
return
self
.
_node
.
dtype
...
...
@@ -118,6 +134,14 @@ class OpNode:
def
graph
(
self
)
->
Graph
:
return
self
.
_node
.
graph
@
property
def
name
(
self
):
return
self
.
_node
.
name
@
name
.
setter
def
name
(
self
,
name
):
self
.
_node
.
name
=
name
@
property
def
inputs
(
self
):
return
tuple
(
map
(
self
.
graph
.
_wrap
,
self
.
_node
.
inputs
))
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
76dbaa27
...
...
@@ -11,6 +11,7 @@
#include "./graph_rt.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
...
...
@@ -47,7 +48,8 @@ void init_graph_rt(py::module m) {
py
::
class_
<
cg
::
VarNode
,
GraphNodePtr
<
cg
::
VarNode
>>
(
m
,
"VarNode"
)
.
def_property_readonly
(
"owner"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
owner_opr
();})
.
def_property_readonly
(
"graph"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
owner_graph
();})
.
def_property_readonly
(
"name"
,
py
::
overload_cast
<>
(
&
VarNode
::
name
,
py
::
const_
))
.
def_property
(
"name"
,
py
::
overload_cast
<>
(
&
VarNode
::
name
,
py
::
const_
),
py
::
overload_cast
<
std
::
string
>
(
&
VarNode
::
name
))
.
def_property_readonly
(
"dtype"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
dtype
();})
.
def_property_readonly
(
"comp_node"
,
[](
cg
::
VarNode
*
v
)
{
return
v
->
comp_node
();})
.
def_property_readonly
(
"shape"
,
[](
cg
::
VarNode
*
v
)
->
const
TensorShape
*
{
...
...
@@ -75,7 +77,8 @@ void init_graph_rt(py::module m) {
py
::
class_
<
cg
::
OperatorNodeBase
,
GraphNodePtr
<
cg
::
OperatorNodeBase
>>
(
m
,
"OperatorNode"
)
.
def_property_readonly
(
"graph"
,
[](
cg
::
OperatorNodeBase
*
opr
)
{
return
opr
->
owner_graph
();})
.
def_property_readonly
(
"name"
,
py
::
overload_cast
<>
(
&
cg
::
OperatorNodeBase
::
name
,
py
::
const_
))
.
def_property
(
"name"
,
py
::
overload_cast
<>
(
&
cg
::
OperatorNodeBase
::
name
,
py
::
const_
),
py
::
overload_cast
<
std
::
string
>
(
&
cg
::
OperatorNodeBase
::
name
))
.
def_property_readonly
(
"inputs"
,
[](
cg
::
OperatorNodeBase
*
opr
)
{
return
to_tuple
(
opr
->
input
());
})
...
...
@@ -99,6 +102,15 @@ void init_graph_rt(py::module m) {
})
.
def_property_readonly
(
"options"
,
py
::
overload_cast
<>
(
&
cg
::
ComputingGraph
::
options
));
m
.
def
(
"dump_graph"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
)
{
using
namespace
mgb
::
serialization
;
std
::
vector
<
uint8_t
>
buf
;
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_vector_proxy
(
&
buf
));
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
dumper
->
dump
(
symvars
);
return
py
::
bytes
(
reinterpret_cast
<
const
char
*>
(
&
buf
[
0
]),
buf
.
size
());
});
#define CURRENT_CLASS cg::ComputingGraph::Options
auto
PyComputingGraphOptions
=
py
::
class_
<
cg
::
ComputingGraph
::
Options
>
(
PyComputingGraph
,
"Options"
)
...
...
@@ -198,6 +210,20 @@ void init_graph_rt(py::module m) {
return
opr
::
ImmutableTensor
::
make
(
*
graph
,
hv
,
OperatorNodeConfig
(
cn
)).
node
();
});
m
.
def
(
"make_h2d"
,
[](
cg
::
ComputingGraph
&
graph
,
CompNode
cn
,
DType
dtype
,
std
::
optional
<
std
::
string
>
name
)
{
if
(
!
cn
.
valid
())
{
throw
py
::
type_error
(
"device must be valid"
);
}
if
(
!
dtype
.
valid
())
{
throw
py
::
type_error
(
"dtype must be valid"
);
}
OperatorNodeConfig
config
;
if
(
name
)
{
config
.
name
(
*
name
);
}
return
opr
::
Host2DeviceCopy
::
make
(
graph
,
std
::
make_shared
<
HostTensorND
>
(
cn
,
dtype
),
config
).
node
();
},
py
::
arg
(),
py
::
arg
(),
py
::
arg
(),
py
::
arg
()
=
py
::
none
());
m
.
def
(
"input_callback"
,
[
input_callback
](
std
::
function
<
DeviceTensorND
(
void
)
>
callback
,
const
CompNode
&
comp_node
,
const
DType
&
dtype
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录