Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
4101d5bc
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
4101d5bc
编写于
8月 31, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/imperative): add BackwardGraph.interpret
GitOrigin-RevId: bb3a59380ec937c7fd60daed161d3f41172da972
上级
afddefb6
变更
4
隐藏空白更改
内联
并排
Showing
4 changed files
with
55 additions
and
29 deletions
+55
-29
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+8
-0
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+12
-1
imperative/src/impl/ops/backward_graph.cpp
imperative/src/impl/ops/backward_graph.cpp
+4
-28
imperative/src/include/megbrain/imperative/ops/backward_graph.h
...tive/src/include/megbrain/imperative/ops/backward_graph.h
+31
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
4101d5bc
...
...
@@ -12,6 +12,7 @@ import weakref
from
concurrent.futures
import
Future
,
ThreadPoolExecutor
from
..
import
_imperative_rt
from
.._imperative_rt.ops
import
BackwardGraph
from
.._wrap
import
device
as
as_device
from
..ops.builtin
import
OpDef
from
.core
import
OpBase
,
TensorBase
,
apply
...
...
@@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode):
return
_wrap
(
outputs
)
@apply.register()
def _(op: BackwardGraph, *args: VarNode):
    """Apply a ``BackwardGraph`` op on symbolic graph variables.

    Interprets the recorded backward graph: every inner op is re-dispatched
    through :func:`apply`, and recorded constants are materialized with the
    owning graph's ``make_const``.
    """
    assert args
    owner_graph = args[0].graph

    def run_expr(inner_op, inner_args):
        # Re-dispatch each recorded expression through the generic apply().
        return apply(inner_op, *inner_args)

    return op.interpret(run_expr, owner_graph.make_const, args)
def
input_callback
(
callback
,
*
args
,
device
=
None
,
dtype
=
None
,
shape
=
None
,
graph
=
None
):
outputs
=
_imperative_rt
.
input_callback
(
callback
,
as_device
(
device
).
to_c
(),
dtype
,
shape
,
_unwrap
(
args
),
graph
=
graph
...
...
imperative/python/src/ops.cpp
浏览文件 @
4101d5bc
...
...
@@ -40,6 +40,18 @@ void init_ops(py::module m) {
attr
.
param
.
insert
(
attr
.
param
.
end
(),
s
.
begin
(),
s
.
end
());
});
// Expose BackwardGraph to Python and attach an `interpret` method that
// evaluates the recorded backward graph using Python-side callbacks.
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
        .def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
                             const mgb::SmallVector<py::object>& inputs) {
            // f: applies one inner OpDef by calling the Python callable `pyf`.
            // `op.copy()` hands Python its own OpDef instance; the Python
            // return value is cast back to a SmallVector of py::object.
            auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
                return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs));
            };
            // c: converts a recorded constant TensorPtr into a Python value
            // by passing its device tensor to `pyc`.
            auto c = [pyc](const TensorPtr& tensor) {
                return pyc(tensor->dev_tensor());
            };
            // Run the internal graph with py::object as the value type.
            return self.graph().interpret<py::object>(f, c, inputs);
        });
py
::
class_
<
GetVarShape
,
std
::
shared_ptr
<
GetVarShape
>
,
OpDef
>
(
m
,
"GetVarShape"
)
.
def
(
py
::
init
());
...
...
@@ -98,7 +110,6 @@ void init_ops(py::module m) {
.
def
(
py
::
init
<>
())
.
def_readwrite
(
"offsets"
,
&
ParamPackConcat
::
offsets
);
py
::
class_
<
BackwardGraph
,
std
::
shared_ptr
<
BackwardGraph
>
,
OpDef
>
(
m
,
"BackwardGraph"
);
py
::
class_
<
CondTake
,
std
::
shared_ptr
<
CondTake
>
,
OpDef
>
(
m
,
"CondTake"
)
.
def
(
py
::
init
<>
());
...
...
imperative/src/impl/ops/backward_graph.cpp
浏览文件 @
4101d5bc
...
...
@@ -18,34 +18,10 @@ namespace imperative {
/*!
 * \brief execute the recorded backward graph on physical tensors
 *
 * \param inputs concrete tensors bound, in order, to the graph's input nodes
 * \return output tensors of the graph, in recorded output order
 *
 * Fix: the body previously contained the full hand-rolled interpreter loop
 * followed by an unreachable duplicate `return interpret<TensorPtr>(...)`.
 * The generic `interpret` template performs exactly the same computation
 * (apply each expression via OpDef::apply_on_physical_tensor, use constants
 * as-is), so the dead duplicated logic is removed and the single delegation
 * kept.
 */
SmallVector<TensorPtr> BackwardGraph::InternalGraph::apply(
        const SmallVector<TensorPtr>& inputs) const {
    // Delegate to the generic interpreter: expressions are evaluated with
    // OpDef::apply_on_physical_tensor and recorded constants are used
    // unchanged (identity conversion).
    return interpret<TensorPtr>(
            &OpDef::apply_on_physical_tensor,
            [](const TensorPtr& x) { return x; },
            inputs);
}
SmallVector
<
LogicalTensorDesc
>
...
...
imperative/src/include/megbrain/imperative/ops/backward_graph.h
浏览文件 @
4101d5bc
...
...
@@ -40,6 +40,37 @@ public:
SmallVector
<
LogicalTensorDesc
>
infer_attrs
(
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
;
/*!
 * \brief evaluate the recorded graph over an arbitrary value type T
 *
 * \param f callable invoked for each recorded expression as
 *          f(op, inputs) -> SmallVector<T> of that expression's outputs
 * \param c converts a stored constant tensor into a T
 * \param inputs values bound, in order, to the graph's input nodes
 * \return values of the graph's output nodes, in recorded order
 */
template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
    // Environment: node id -> current value for that node.
    ThinHashMap<size_t, T> env;
    mgb_assert(inputs.size() == this->inputs.size());
    // Bind the formal inputs.
    for (size_t idx = 0; idx < inputs.size(); ++idx) {
        env[this->inputs[idx]] = inputs[idx];
    }
    // Materialize constants through the caller-supplied converter.
    for (auto&& kv : constants) {
        env[kv.first] = c(kv.second);
    }
    // Evaluate expressions in recorded order; each expr is a tuple of
    // (op, input node ids, output node ids).
    for (auto&& expr : exprs) {
        SmallVector<T> expr_inputs;
        for (auto&& in : std::get<1>(expr)) {
            expr_inputs.push_back(env.at(in));
        }
        auto&& expr_outputs = f(*std::get<0>(expr), std::move(expr_inputs));
        auto&& out_nodes = std::get<2>(expr);
        mgb_assert(expr_outputs.size() == out_nodes.size());
        for (size_t j = 0; j < expr_outputs.size(); ++j) {
            env[out_nodes[j]] = std::move(expr_outputs[j]);
        }
    }
    // Collect the graph outputs.
    SmallVector<T> result;
    for (auto&& node : outputs) {
        result.push_back(env.at(node));
    }
    return result;
}
};
const
InternalGraph
&
graph
()
const
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录