Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
241b35a6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
241b35a6
编写于
5月 24, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(ops): remove BackwardGraph op
GitOrigin-RevId: eda20e57606daad69790f6abbc7cd7fba2ba934c
上级
d2e33af5
变更
21
隐藏空白更改
内联
并排
Showing
21 changed file
with
221 addition
and
422 deletion
+221
-422
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+0
-16
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+8
-17
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+3
-3
imperative/python/src/imperative_rt.cpp
imperative/python/src/imperative_rt.cpp
+9
-21
imperative/python/src/ops.cpp
imperative/python/src/ops.cpp
+0
-37
imperative/python/src/tensor.cpp
imperative/python/src/tensor.cpp
+0
-3
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+37
-15
imperative/python/src/trace.cpp
imperative/python/src/trace.cpp
+1
-1
imperative/src/impl/backward_graph_opt.cpp
imperative/src/impl/backward_graph_opt.cpp
+5
-7
imperative/src/impl/interpreter/interpreter_impl.cpp
imperative/src/impl/interpreter/interpreter_impl.cpp
+1
-2
imperative/src/impl/op_def.cpp
imperative/src/impl/op_def.cpp
+64
-0
imperative/src/impl/ops/backward_graph.cpp
imperative/src/impl/ops/backward_graph.cpp
+0
-141
imperative/src/impl/proxy_graph.cpp
imperative/src/impl/proxy_graph.cpp
+3
-34
imperative/src/impl/proxy_graph.h
imperative/src/impl/proxy_graph.h
+1
-1
imperative/src/impl/proxy_graph/common.h
imperative/src/impl/proxy_graph/common.h
+1
-1
imperative/src/impl/proxy_graph_detail.cpp
imperative/src/impl/proxy_graph_detail.cpp
+1
-24
imperative/src/include/megbrain/imperative/backward_graph_opt.h
...tive/src/include/megbrain/imperative/backward_graph_opt.h
+4
-4
imperative/src/include/megbrain/imperative/op_def.h
imperative/src/include/megbrain/imperative/op_def.h
+53
-3
imperative/src/include/megbrain/imperative/ops/backward_graph.h
...tive/src/include/megbrain/imperative/ops/backward_graph.h
+0
-86
imperative/src/include/megbrain/imperative/ops/utility.h
imperative/src/include/megbrain/imperative/ops/utility.h
+1
-1
imperative/src/test/backward_graph.cpp
imperative/src/test/backward_graph.cpp
+29
-5
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
241b35a6
...
...
@@ -18,7 +18,6 @@ import numpy as np
from
..
import
_imperative_rt
from
.._imperative_rt
import
GraphOptimizeOptions
from
.._imperative_rt.core2
import
apply
,
set_cpp_apply_backward_varnode
from
.._imperative_rt.ops
import
BackwardGraph
from
.._wrap
import
device
as
as_device
from
..ops.builtin
import
OpDef
from
.core
import
TensorBase
...
...
@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode):
return
_wrap
(
outputs
)
def
apply_backward_varnode
(
op
:
BackwardGraph
,
*
args
:
VarNode
):
assert
args
graph
=
args
[
0
].
graph
outputs
=
op
.
interpret
(
op
,
lambda
op
,
args
:
apply_normal_varnode
(
op
,
*
args
),
graph
.
_make_const_for_backward
,
args
,
)
return
outputs
set_cpp_apply_backward_varnode
(
apply_backward_varnode
)
def
input_callback
(
callback
,
*
args
,
device
=
None
,
dtype
=
None
,
shape
=
None
,
graph
=
None
):
outputs
=
_imperative_rt
.
input_callback
(
callback
,
as_device
(
device
).
to_c
(),
dtype
,
shape
,
_unwrap
(
args
),
graph
=
graph
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
241b35a6
...
...
@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import (
)
from
..core._trace_option
import
set_symbolic_shape
from
..core._wrap
import
device
as
as_device
from
..core.ops.builtin
import
Ba
ckwardGraph
,
Ba
tchNorm
,
OpDef
from
..core.ops.builtin
import
BatchNorm
,
OpDef
from
..core.ops.special
import
Const
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor.utils
import
setscalar
...
...
@@ -587,10 +587,7 @@ class trace:
ivars
.
append
(
info
.
varnode
)
if
isinstance
(
op
,
BackwardGraph
):
ovars
=
G
.
apply_backward_varnode
(
op
,
*
ivars
)
else
:
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
if
require_links
and
len
(
ovars
)
>
0
:
io_links
=
(
ovars
[
0
],)
...
...
@@ -805,14 +802,11 @@ class trace:
name
=
info
.
name
,
)
ivars
.
append
(
h2v
[
h
])
if
isinstance
(
op
,
BackwardGraph
):
ovars
=
G
.
apply_backward_varnode
(
op
,
*
ivars
)
else
:
if
isinstance
(
op
,
BatchNorm
):
assert
(
op
.
fwd_mode
==
BatchNorm
.
FwdMode
.
INFERENCE
),
"can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
if
isinstance
(
op
,
BatchNorm
):
assert
(
op
.
fwd_mode
==
BatchNorm
.
FwdMode
.
INFERENCE
),
"can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
AutoNaming
.
record_opnode
(
ovars
[
0
].
op
)
...
...
@@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
ivars
[
0
]
=
opnode
.
outputs
[
0
]
active_trace
.
_lazy_eval_links
=
(
ivars
[
0
],)
if
isinstance
(
op
,
BackwardGraph
):
ovars
=
G
.
apply_backward_varnode
(
op
,
*
ivars
)
else
:
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
ovars
=
G
.
apply_normal_varnode
(
op
,
*
ivars
)
outputs
=
[
RawTensor
(
o
)
for
o
in
ovars
]
if
require_links
:
...
...
imperative/python/src/grad.cpp
浏览文件 @
241b35a6
...
...
@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
input_requires_grad
[
i
]
=
python
::
input_requires_grad
(
ctx
,
i
);
}
std
::
shared_ptr
<
OptimizedBackwardGraphResult
>
ret
;
auto
bg
=
proxy_graph_detail
::
make_backward_graph
(
auto
bg
=
OpDef
::
make_backward_graph
(
*
ctx
.
op
,
inputs
,
input_requires_grad
,
output_has_grad
);
if
(
bg
.
backward
)
{
if
(
!
bg
.
backward
.
empty
()
)
{
ret
=
std
::
make_shared
<
OptimizedBackwardGraphResult
>
(
bg
);
}
backward_graph_cache
.
emplace
(
key
,
ret
);
...
...
@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure {
size_t
count
=
std
::
count_if
(
save_for_backward
.
begin
(),
save_for_backward
.
end
(),
ranges
::
identity
{});
if
(
backward_graph
->
precomp
)
{
if
(
!
backward_graph
->
precomp
.
empty
()
)
{
auto
&&
irng
=
ranges
::
span
(
ctx
.
args
,
ctx
.
nargs
);
auto
&&
orng
=
views
::
transform
(
outputs
,
[](
auto
&&
i
){
return
i
.
get
();});
auto
precomp
=
apply
(
backward_graph
->
precomp
,
views
::
concat
(
irng
,
orng
));
...
...
imperative/python/src/imperative_rt.cpp
浏览文件 @
241b35a6
...
...
@@ -30,26 +30,14 @@ using namespace imperative;
using
namespace
interpreter
;
namespace
{
std
::
optional
<
std
::
tuple
<
std
::
shared_ptr
<
OpDef
>
,
std
::
vector
<
bool
>
,
std
::
vector
<
bool
>>>
make_backward_graph
(
const
OpDef
&
opdef
,
std
::
vector
<
LogicalTensorDesc
>
inputs
,
std
::
vector
<
bool
>
input_requires_grad
,
std
::
vector
<
bool
>
output_has_grad
)
{
auto
res
=
OpDef
::
make_backward_graph
(
opdef
,
SmallVector
<
LogicalTensorDesc
>
(
inputs
.
begin
(),
inputs
.
end
()),
SmallVector
<
bool
>
(
input_requires_grad
.
begin
(),
input_requires_grad
.
end
()),
SmallVector
<
bool
>
(
output_has_grad
.
begin
(),
output_has_grad
.
end
()));
if
(
res
.
backward
)
{
return
std
::
optional
<
std
::
tuple
<
std
::
shared_ptr
<
OpDef
>
,
std
::
vector
<
bool
>
,
std
::
vector
<
bool
>>>
{
std
::
in_place
,
res
.
backward
,
res
.
save_for_backward
,
res
.
input_has_grad
};
}
else
{
return
{};
}
}
}
// namespace
void
init_imperative_rt
(
py
::
module
m
)
{
m
.
def
(
"make_backward_graph"
,
&
make_backward_graph
);
auto
make_backward_graph
=
[](
const
OpDef
&
def
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
,
const
SmallVector
<
bool
>&
input_requires_grad
,
const
SmallVector
<
bool
>&
output_has_grad
){
auto
result
=
OpDef
::
make_backward_graph
(
def
,
inputs
,
input_requires_grad
,
output_has_grad
);
return
std
::
make_tuple
(
"backward_graph"
,
result
.
save_for_backward
,
result
.
input_has_grad
);
};
m
.
def
(
"make_backward_graph"
,
make_backward_graph
);
}
imperative/python/src/ops.cpp
浏览文件 @
241b35a6
...
...
@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) {
}
/*********** begin of hand-write opdefs **************/
PyOpDefBegin
(
BackwardGraph
)
// {{
// };
PyOpDefEnd
(
BackwardGraph
)
void
_init_py_backward_graph
(
py
::
module
m
)
{
using
py_op
=
PyOp
(
BackwardGraph
);
auto
&
py_type
=
PyOpType
(
BackwardGraph
);
py_type
=
{
PyVarObject_HEAD_INIT
(
NULL
,
0
)};
py_type
.
tp_name
=
"megengine.core._imperative_rt.ops.BackwardGraph"
;
py_type
.
tp_basicsize
=
sizeof
(
PyOp
(
BackwardGraph
));
py_type
.
tp_flags
=
Py_TPFLAGS_DEFAULT
|
Py_TPFLAGS_BASETYPE
;
py_type
.
tp_doc
=
"BackwardGraph"
;
py_type
.
tp_base
=
&
PyOpType
(
OpDef
);
py_type
.
tp_dealloc
=
py_dealloc_generic
<
py_op
>
;
py_type
.
tp_new
=
py_new_generic
<
py_op
>
;
mgb_assert
(
PyType_Ready
(
&
py_type
)
>=
0
);
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
auto
interpret
=
py
::
cpp_function
(
[](
OpDef
&
self
,
py
::
object
pyf
,
py
::
object
pyc
,
const
mgb
::
SmallVector
<
py
::
object
>&
inputs
)
{
auto
f
=
[
pyf
](
OpDef
&
op
,
const
mgb
::
SmallVector
<
py
::
object
>&
inputs
)
{
return
py
::
cast
<
mgb
::
SmallVector
<
py
::
object
>>
(
pyf
(
op
.
shared_from_this
(),
inputs
));
};
auto
c
=
[
pyc
](
const
TensorPtr
&
tensor
)
{
return
pyc
(
tensor
->
dev_tensor
());
};
return
self
.
cast_final_safe
<
BackwardGraph
>
().
graph
().
interpret
<
py
::
object
>
(
f
,
c
,
inputs
);
});
mgb_assert
(
PyDict_SetItemString
(
py_type
.
tp_dict
,
"interpret"
,
interpret
.
release
().
ptr
())
>=
0
);
PyType_Modified
(
&
py_type
);
m
.
add_object
(
"BackwardGraph"
,
reinterpret_cast
<
PyObject
*>
(
&
py_type
));
mgb_assert
(
PyOp
(
OpDef
)
::
ctype2pytype
.
emplace
(
BackwardGraph
::
typeinfo
(),
&
py_type
).
second
);
}
struct
PyOpBase
:
PyOpDef
{
static
PyTypeObject
py_type
;
...
...
@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)
void
init_ops
(
py
::
module
m
)
{
_init_py_op_def
(
m
);
_init_py_backward_graph
(
m
);
_init_py_op_base
(
m
);
INIT_ALL_OP
(
m
)
...
...
imperative/python/src/tensor.cpp
浏览文件 @
241b35a6
...
...
@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ctx
.
args
=
&
tensors
[
0
];
ctx
.
nargs
=
nargs
;
ctx
.
pytype
=
pytype
;
if
(
ctx
.
op
->
same_type
<
BackwardGraph
>
())
{
ctx
.
backward
=
true
;
}
if
(
py
::
isinstance
<
PySymbolVar
>
(
py
::
handle
(
args
[
0
]))){
SmallVector
<
cg
::
VarNode
*>
vinputs
(
nargs
);
...
...
imperative/python/src/tensor.h
浏览文件 @
241b35a6
...
...
@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
return
apply
(
ctx
);
}
template
<
typename
T
>
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
T
&&
tensors
)
->
std
::
enable_if_t
<
std
::
is_same_v
<
decltype
(
resolve_arrow
(
tensors
[
0
])),
Tensor
*>
,
apply_result_t
>
{
inline
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
ApplyContext
ctx
;
ctx
.
op
=
std
::
move
(
op
);
ctx
.
nargs
=
tensors
.
size
();
Tensor
*
args
[
ctx
.
nargs
];
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
args
[
i
]
=
resolve_arrow
(
tensors
[
i
]);
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
ctx
.
flags
|=
args
[
i
]
->
m_flags
;
}
return
apply
(
ctx
);
}
inline
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
ApplyContext
ctx
;
ctx
.
op
=
std
::
move
(
op
);
ctx
.
nargs
=
nargs
;
ctx
.
args
=
args
;
template
<
typename
T
>
auto
apply
(
std
::
shared_ptr
<
OpDef
>
op
,
T
&&
tensors
)
->
std
::
enable_if_t
<
std
::
is_same_v
<
decltype
(
resolve_arrow
(
tensors
[
0
])),
Tensor
*>
,
apply_result_t
>
{
size_t
nargs
=
tensors
.
size
();
Tensor
*
args
[
nargs
];
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
ctx
.
flags
|=
args
[
i
]
->
m_flags
;
args
[
i
]
=
resolve_arrow
(
tensors
[
i
])
;
}
return
apply
(
ctx
);
return
apply
(
op
,
args
,
nargs
);
}
inline
auto
apply
(
Subgraph
graph
,
Tensor
*
const
*
args
,
size_t
nargs
)
{
SmallVector
<
std
::
shared_ptr
<
Tensor
>>
inputs
;
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
inputs
.
push_back
(
args
[
i
]
->
shared_from_this
());
}
auto
apply_functor
=
[](
std
::
shared_ptr
<
OpDef
>
op
,
SmallVector
<
std
::
shared_ptr
<
Tensor
>>
inputs
)
{
return
apply
(
op
,
inputs
);
};
auto
const_functor
=
[](
imperative
::
TensorPtr
value
)
{
return
std
::
make_shared
<
Tensor
>
(
interpreter_for_py
->
put
(
value
->
dev_tensor
()));
};
return
graph
.
apply
(
inputs
,
apply_functor
,
const_functor
);
}
template
<
typename
T
>
auto
apply
(
Subgraph
graph
,
T
&&
tensors
)
->
std
::
enable_if_t
<
std
::
is_same_v
<
decltype
(
tensors
[
0
]),
Tensor
*>
,
apply_result_t
>
{
size_t
nargs
=
tensors
.
size
();
Tensor
*
args
[
nargs
];
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
args
[
i
]
=
resolve_arrow
(
tensors
[
i
]);
}
return
apply
(
graph
,
args
,
nargs
);
}
void
init_tensor
(
pybind11
::
module
);
...
...
imperative/python/src/trace.cpp
浏览文件 @
241b35a6
...
...
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t
outputs
;
if
(
ctx
.
backward
)
{
//
call megbrain_graph.py apply(BackwardGraph, *args)
//
reach here when compiled=True
auto
args
=
py
::
tuple
(
ctx
.
nargs
+
1
);
args
[
0
]
=
py
::
cast
(
ctx
.
op
);
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
i
++
)
{
...
...
imperative/src/impl/backward_graph_opt.cpp
浏览文件 @
241b35a6
...
...
@@ -18,24 +18,22 @@ using namespace imperative;
OptimizedBackwardGraphResult
::
OptimizedBackwardGraphResult
(
const
BackwardGraphResult
&
src
)
:
input_has_grad
(
src
.
input_has_grad
)
{
if
(
!
src
.
backward
->
same_type
<
BackwardGraph
>
()
)
{
if
(
src
.
backward
.
exprs
.
size
()
<=
1
)
{
// backward graph only contains a single op
backward
=
src
.
backward
;
save_for_backward
=
src
.
save_for_backward
;
return
;
}
save_for_backward
.
resize
(
src
.
save_for_backward
.
size
(),
false
);
precomp
.
reset
(
new
BackwardGraph
);
backward
.
reset
(
new
BackwardGraph
);
auto
&&
graph
=
src
.
backward
->
cast_final_safe
<
BackwardGraph
>
().
graph
()
;
auto
&&
graph
=
src
.
backward
;
auto
&&
mask
=
src
.
save_for_backward
;
size_t
input_size
=
src
.
input_has_grad
.
size
();
size_t
output_size
=
(
mask
.
size
()
-
input_size
)
/
2
;
mgb_assert
(
input_size
+
output_size
*
2
==
mask
.
size
());
auto
&
fgraph
=
precomp
->
cast_final
<
BackwardGraph
>
().
graph
()
;
auto
&
bgraph
=
backward
->
cast_final
<
BackwardGraph
>
().
graph
()
;
auto
&
fgraph
=
precomp
;
auto
&
bgraph
=
backward
;
// optimization: move ops (e.g. GetVarShape) to forward to
// reduce memory footprint
...
...
@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
}
if
(
!
fgraph
.
outputs
.
size
())
{
precomp
.
reset
()
;
precomp
=
{}
;
}
}
imperative/src/impl/interpreter/interpreter_impl.cpp
浏览文件 @
241b35a6
...
...
@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
op_type
==
RemoteSend
::
typeinfo
()
||
op_type
==
CollectiveComm
::
typeinfo
()
||
op_type
==
opr
::
InputCallback
::
typeinfo
()
||
op_type
==
opr
::
OutputCallback
::
typeinfo
()
||
op_type
==
BackwardGraph
::
typeinfo
())
{
op_type
==
opr
::
OutputCallback
::
typeinfo
())
{
return
m_commands
.
end
();
}
}
else
if
constexpr
(
std
::
is_same_v
<
T
,
GetValue
>
)
{
...
...
imperative/src/impl/op_def.cpp
浏览文件 @
241b35a6
...
...
@@ -10,6 +10,9 @@
*/
#include "megbrain/imperative/op_def.h"
#include <sstream>
#include "megbrain/imperative/ops/opr_attr.h"
#include "./op_trait.h"
...
...
@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const {
return
m_scope
+
"."
+
trait
()
->
make_name
(
*
this
);
}
std
::
string
Subgraph
::
repr
()
const
{
std
::
ostringstream
buf
;
buf
<<
"("
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
inputs
[
i
];
}
buf
<<
") => {
\n
"
;
auto
fmt_const
=
[](
size_t
i
,
const
TensorPtr
&
t
)
{
if
(
t
->
shape
().
ndim
==
1
&&
t
->
shape
()[
0
]
==
1
)
{
auto
&&
v
=
t
->
get_value
();
if
(
v
.
dtype
()
==
dtype
::
Float32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
dt_float32
>
());
}
else
if
(
v
.
dtype
()
==
dtype
::
Int32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
int32_t
>
());
}
}
return
std
::
string
(
"%c"
)
+
std
::
to_string
(
i
);
};
std
::
unordered_map
<
size_t
,
std
::
string
>
const_reps
;
for
(
auto
&&
[
i
,
t
]
:
constants
)
{
const_reps
.
emplace
(
i
,
fmt_const
(
i
,
t
));
}
for
(
auto
&
[
op
,
ins
,
outs
]
:
exprs
)
{
buf
<<
" "
;
if
(
outs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outs
[
i
];
}
buf
<<
" = "
;
}
if
(
auto
*
p
=
op
->
try_cast_final
<
OprAttr
>
())
{
buf
<<
p
->
type
;
}
else
{
buf
<<
op
->
dyn_typeinfo
()
->
name
;
}
for
(
size_t
i
:
ins
)
{
buf
<<
" "
;
auto
&&
it
=
const_reps
.
find
(
i
);
if
(
it
!=
const_reps
.
end
())
{
buf
<<
it
->
second
;
}
else
{
buf
<<
"%"
<<
i
;
}
}
buf
<<
"
\n
"
;
}
buf
<<
" "
;
if
(
outputs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outputs
[
i
];
}
}
else
{
buf
<<
"()"
;
}
buf
<<
"
\n
}
\n
"
;
return
buf
.
str
();
}
}
// namespace imperative
}
// namespace mgb
...
...
imperative/src/impl/ops/backward_graph.cpp
浏览文件 @
241b35a6
...
...
@@ -19,147 +19,6 @@
namespace
mgb
{
namespace
imperative
{
SmallVector
<
TensorPtr
>
BackwardGraph
::
InternalGraph
::
apply
(
const
SmallVector
<
TensorPtr
>&
inputs
)
const
{
return
interpret
<
TensorPtr
>
(
&
OpDef
::
apply_on_physical_tensor
,
[](
const
TensorPtr
&
x
)
{
return
x
;},
inputs
);
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
BackwardGraph
::
InternalGraph
::
infer_attrs
(
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
{
using
TensorAttr
=
LogicalTensorDesc
;
ThinHashMap
<
size_t
,
TensorAttr
>
node2attr
;
auto
&&
input_nodes
=
this
->
inputs
;
auto
&&
output_nodes
=
this
->
outputs
;
mgb_assert
(
inputs
.
size
()
==
input_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
node2attr
[
input_nodes
[
i
]]
=
inputs
[
i
];
}
for
(
auto
&&
i
:
constants
)
{
auto
*
value
=
i
.
second
->
try_get_value
();
mgb_assert
(
value
);
node2attr
[
i
.
first
]
=
TensorAttr
{
i
.
second
->
layout
(),
i
.
second
->
comp_node
(),
value
->
proxy_to_default_cpu
()};
}
bool
validated
=
true
;
for
(
size_t
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
auto
&&
[
expr_op
,
expr_inps
,
expr_oups
]
=
exprs
[
i
];
SmallVector
<
TensorAttr
>
expr_input_descs
;
for
(
auto
&&
inp
:
expr_inps
)
{
expr_input_descs
.
push_back
(
node2attr
.
at
(
inp
));
}
auto
[
expr_output_descs
,
expr_validated
]
=
OpDef
::
infer_output_attrs_fallible
(
*
expr_op
,
expr_input_descs
);
validated
=
validated
&&
expr_validated
;
mgb_assert
(
expr_output_descs
.
size
()
==
expr_oups
.
size
());
for
(
size_t
i
=
0
;
i
<
expr_output_descs
.
size
();
++
i
)
{
node2attr
[
expr_oups
[
i
]]
=
expr_output_descs
[
i
];
}
}
SmallVector
<
TensorAttr
>
ret
;
for
(
auto
&&
i
:
output_nodes
)
{
ret
.
push_back
(
node2attr
.
at
(
i
));
}
return
{
ret
,
validated
};
}
std
::
string
BackwardGraph
::
InternalGraph
::
repr
()
{
std
::
ostringstream
buf
;
buf
<<
"("
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
inputs
[
i
];
}
buf
<<
") => {
\n
"
;
auto
fmt_const
=
[](
size_t
i
,
TensorPtr
&
t
)
{
if
(
t
->
shape
().
ndim
==
1
&&
t
->
shape
()[
0
]
==
1
)
{
auto
&&
v
=
t
->
get_value
();
if
(
v
.
dtype
()
==
dtype
::
Float32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
dt_float32
>
());
}
else
if
(
v
.
dtype
()
==
dtype
::
Int32
{})
{
return
std
::
to_string
(
*
v
.
ptr
<
int32_t
>
());
}
}
return
std
::
string
(
"%c"
)
+
std
::
to_string
(
i
);
};
std
::
unordered_map
<
size_t
,
std
::
string
>
const_reps
;
for
(
auto
&&
[
i
,
t
]
:
constants
)
{
const_reps
.
emplace
(
i
,
fmt_const
(
i
,
t
));
}
for
(
auto
&
[
op
,
ins
,
outs
]
:
exprs
)
{
buf
<<
" "
;
if
(
outs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outs
[
i
];
}
buf
<<
" = "
;
}
if
(
auto
*
p
=
op
->
try_cast_final
<
OprAttr
>
())
{
buf
<<
p
->
type
;
}
else
{
buf
<<
op
->
dyn_typeinfo
()
->
name
;
}
for
(
size_t
i
:
ins
)
{
buf
<<
" "
;
auto
&&
it
=
const_reps
.
find
(
i
);
if
(
it
!=
const_reps
.
end
())
{
buf
<<
it
->
second
;
}
else
{
buf
<<
"%"
<<
i
;
}
}
buf
<<
"
\n
"
;
}
buf
<<
" "
;
if
(
outputs
.
size
())
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
if
(
i
>
0
)
buf
<<
", "
;
buf
<<
"%"
<<
outputs
[
i
];
}
}
else
{
buf
<<
"()"
;
}
buf
<<
"
\n
}
\n
"
;
return
buf
.
str
();
}
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
BackwardGraph
);
namespace
{
SmallVector
<
TensorPtr
>
backward_impl
(
const
OpDef
&
backward_graph
,
const
SmallVector
<
TensorPtr
>&
tensors
)
{
return
backward_graph
.
cast_final_safe
<
BackwardGraph
>
()
.
graph
().
apply
(
tensors
);
}
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_tensor_attrs
(
const
OpDef
&
backward_graph
,
const
SmallVector
<
LogicalTensorDesc
>
inputs
)
{
return
backward_graph
.
cast_final_safe
<
BackwardGraph
>
()
.
graph
().
infer_attrs
(
inputs
);
}
std
::
vector
<
std
::
pair
<
const
char
*
,
std
::
string
>>
props
(
const
OpDef
&
backward_graph
)
{
return
{};
}
OP_TRAIT_REG
(
BackwardGraph
,
BackwardGraph
)
.
apply_on_physical_tensor
(
backward_impl
)
.
infer_output_attrs_fallible
(
infer_tensor_attrs
)
.
props
(
props
)
.
fallback
();
}
// anonymous namespace
}
// namespace imperative
}
// namespace mgb
...
...
imperative/src/impl/proxy_graph.cpp
浏览文件 @
241b35a6
...
...
@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph(
auto
*
gfunc
=
cg
::
lookup_grad_func
(
fwd
->
dyn_typeinfo
());
BackwardGraphResult
result
;
auto
&&
backward
=
BackwardGraph
::
make
();
auto
&&
igraph
=
backward
->
cast_final_safe
<
BackwardGraph
>
().
graph
();
auto
&&
igraph
=
result
.
backward
;
size_t
nr_backward_graph_inputs
=
0
;
auto
gen_expr
=
[
this
,
&
var2idx
,
&
igraph
,
&
push
,
&
fwd
,
...
...
@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph(
++
nr_backward_graph_inputs
;
push
(
op
->
output
(
0
));
}
else
{
std
::
v
ector
<
size_t
>
inputs
,
outputs
;
SmallV
ector
<
size_t
>
inputs
,
outputs
;
for
(
auto
&&
i
:
op
->
input
())
{
if
(
i
->
owner_opr
()
==
fwd
)
{
if
(
var2idx
.
find
(
i
)
==
var2idx
.
end
())
{
...
...
@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph(
for
(
auto
&&
i
:
op
->
usable_output
())
{
outputs
.
push_back
(
push
(
i
));
}
igraph
.
exprs
.
emplace_back
(
OpDef
::
make_from_op_node
(
op
),
inputs
,
outputs
);
igraph
.
exprs
.
push_back
({
OpDef
::
make_from_op_node
(
op
),
inputs
,
outputs
}
);
}
};
...
...
@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph(
write_inputs
(
outputs
);
write_inputs
(
output_grads
);
mgb_assert
(
igraph
.
inputs
.
size
()
==
nr_backward_graph_inputs
);
auto
treat_as_single
=
[](
auto
&&
igraph
)
{
if
(
igraph
.
exprs
.
size
()
!=
1
)
return
false
;
auto
&&
expr
=
igraph
.
exprs
[
0
];
auto
&&
expr_inputs
=
std
::
get
<
1
>
(
expr
);
if
(
expr_inputs
.
size
()
!=
igraph
.
inputs
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
expr_inputs
.
size
();
++
i
)
{
if
(
igraph
.
inputs
[
i
]
!=
expr_inputs
[
i
])
{
return
false
;
}
}
auto
&&
expr_outputs
=
std
::
get
<
2
>
(
expr
);
if
(
expr_outputs
.
size
()
!=
igraph
.
outputs
.
size
())
{
return
false
;
}
for
(
size_t
i
=
0
;
i
<
expr_outputs
.
size
();
++
i
)
{
if
(
igraph
.
outputs
[
i
]
!=
expr_outputs
[
i
])
{
return
false
;
}
}
return
true
;
};
if
(
treat_as_single
(
igraph
))
{
result
.
backward
=
std
::
get
<
0
>
(
igraph
.
exprs
[
0
]);
}
else
{
result
.
backward
=
backward
;
}
return
result
;
}
...
...
imperative/src/impl/proxy_graph.h
浏览文件 @
241b35a6
...
...
@@ -65,7 +65,7 @@ private:
class
InputPlaceholder
;
struct
ProxyGraphInst
;
struct
GradGraph
;
struct
CurOprGuard
;
class
CurOprGuard
;
void
reset
();
...
...
imperative/src/impl/proxy_graph/common.h
浏览文件 @
241b35a6
...
...
@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph {
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct
ProxyGraph
{
struct
InputPlaceholder
;
struct
MiniGraph
;
class
MiniGraph
;
};
}
// namespace mgb::imperative::proxy_graph
imperative/src/impl/proxy_graph_detail.cpp
浏览文件 @
241b35a6
...
...
@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def,
auto
output_descs
=
infer_output_attrs
(
def
,
inputs
);
SmallVector
<
TensorPtr
>
outputs
(
output_descs
.
size
(),
{});
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
i
++
)
{
auto
&
output
=
outputs
[
i
];
auto
&
output_desc
=
output_descs
[
i
];
if
(
def
.
same_type
<
Elemwise
>
())
{
for
(
size_t
j
=
0
;
j
<
inputs
.
size
();
j
++
)
{
// TODO: reindex inputs to support inplace exprs like 'y = x op x'.
auto
&
input
=
inputs
[
j
];
// Because we pass inputs by value, if input and input->blob() are all unique,
// their ownerships are on the stack, thus we can reuse them safely.
// @see: interpreter::intl::ChannelImpl::process_one_task
if
(
input
.
unique
()
&&
input
->
blob
().
unique
()
&&
input
->
blob
()
->
storage
().
unique
()
&&
input
->
layout
().
dtype
==
output_desc
.
layout
.
dtype
&&
input
->
layout
().
eq_layout
(
output_desc
.
layout
)
&&
input
->
comp_node
()
==
output_desc
.
comp_node
)
{
static
std
::
atomic_llong
inplace_count
=
0
;
mgb_log_debug
(
"do inplace for elemwise, layout: %s, count: %lld"
,
output_desc
.
layout
.
to_string
().
c_str
(),
++
inplace_count
);
output
=
Tensor
::
make
(
input
->
blob
(),
input
->
layout
(),
input
->
offset
());
break
;
}
}
}
if
(
!
output
)
{
output
=
Tensor
::
make
(
output_desc
.
layout
,
output_desc
.
comp_node
);
}
outputs
[
i
]
=
Tensor
::
make
(
output_descs
[
i
].
layout
,
output_descs
[
i
].
comp_node
);
}
exec
(
def
,
inputs
,
outputs
);
auto
async_error
=
ProxyGraph
::
get_async_error
();
...
...
imperative/src/include/megbrain/imperative/backward_graph_opt.h
浏览文件 @
241b35a6
...
...
@@ -14,10 +14,10 @@
namespace
mgb
::
imperative
{
struct
OptimizedBackwardGraphResult
{
std
::
shared_ptr
<
OpDef
>
precomp
;
std
::
shared_ptr
<
OpDef
>
backward
;
std
::
v
ector
<
bool
>
save_for_backward
;
std
::
v
ector
<
bool
>
input_has_grad
;
Subgraph
precomp
;
Subgraph
backward
;
SmallV
ector
<
bool
>
save_for_backward
;
SmallV
ector
<
bool
>
input_has_grad
;
OptimizedBackwardGraphResult
(
const
BackwardGraphResult
&
bgraph
);
};
...
...
imperative/src/include/megbrain/imperative/op_def.h
浏览文件 @
241b35a6
...
...
@@ -26,10 +26,60 @@ enum DispatchMode {
KERNEL
=
1
};
using
SharedOp
=
std
::
shared_ptr
<
OpDef
>
;
template
<
typename
T
>
struct
Expr
{
std
::
shared_ptr
<
OpDef
>
op
;
SmallVector
<
T
>
inputs
;
SmallVector
<
T
>
outputs
;
};
struct
Subgraph
{
SmallVector
<
size_t
>
inputs
;
SmallVector
<
std
::
pair
<
size_t
,
TensorPtr
>>
constants
;
SmallVector
<
size_t
>
outputs
;
SmallVector
<
Expr
<
size_t
>>
exprs
;
template
<
typename
T
,
typename
F
,
typename
C
>
SmallVector
<
T
>
apply
(
SmallVector
<
T
>
input_vars
,
F
&&
f
,
C
&&
c
)
const
{
std
::
unordered_map
<
size_t
,
T
>
idx2var
;
mgb_assert
(
inputs
.
size
()
==
input_vars
.
size
(),
"input size mismatch"
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
idx2var
[
inputs
[
i
]]
=
input_vars
[
i
];
}
for
(
auto
&&
[
idx
,
val
]
:
constants
)
{
idx2var
[
idx
]
=
c
(
val
);
}
for
(
auto
&
expr
:
exprs
)
{
SmallVector
<
T
>
expr_inputs
;
for
(
auto
idx
:
expr
.
inputs
)
{
expr_inputs
.
push_back
(
idx2var
[
idx
]);
}
SmallVector
<
T
>
expr_outputs
=
f
(
expr
.
op
,
std
::
move
(
expr_inputs
));
mgb_assert
(
expr_outputs
.
size
()
==
expr
.
outputs
.
size
(),
"output size mismatch"
);
for
(
size_t
i
=
0
;
i
<
expr_outputs
.
size
();
++
i
)
{
idx2var
[
expr
.
outputs
[
i
]]
=
expr_outputs
[
i
];
}
}
SmallVector
<
T
>
output_vars
;
for
(
auto
idx
:
outputs
)
{
output_vars
.
push_back
(
idx2var
[
idx
]);
}
return
output_vars
;
}
bool
empty
()
const
{
return
outputs
.
size
()
==
0
;
}
std
::
string
repr
()
const
;
};
struct
BackwardGraphResult
{
std
::
shared_ptr
<
OpDef
>
backward
;
std
::
v
ector
<
bool
>
save_for_backward
;
std
::
v
ector
<
bool
>
input_has_grad
;
Subgraph
backward
;
SmallV
ector
<
bool
>
save_for_backward
;
SmallV
ector
<
bool
>
input_has_grad
;
};
class
OpDef
:
public
Hashable
,
...
...
imperative/src/include/megbrain/imperative/ops/backward_graph.h
浏览文件 @
241b35a6
...
...
@@ -15,92 +15,6 @@
namespace
mgb
{
namespace
imperative
{
// a special OpDef used for taking gradient on physical tensor
struct
BackwardGraph
final
:
public
OpDefImplBase
<
BackwardGraph
>
{
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
public:
struct
InternalGraph
{
// op, inputs, outputs
using
Expr
=
std
::
tuple
<
std
::
shared_ptr
<
OpDef
>
,
std
::
vector
<
size_t
>
,
std
::
vector
<
size_t
>>
;
std
::
vector
<
Expr
>
exprs
;
// index array of input nodes
std
::
vector
<
size_t
>
inputs
;
// index array of output nodes
std
::
vector
<
size_t
>
outputs
;
// pair of (node index, correspending constant)
std
::
vector
<
std
::
pair
<
size_t
,
TensorPtr
>>
constants
;
SmallVector
<
TensorPtr
>
apply
(
const
SmallVector
<
TensorPtr
>&
inputs
)
const
;
std
::
tuple
<
SmallVector
<
LogicalTensorDesc
>
,
bool
>
infer_attrs
(
const
SmallVector
<
LogicalTensorDesc
>&
inputs
)
const
;
template
<
typename
T
,
typename
F
,
typename
C
>
SmallVector
<
T
>
interpret
(
F
&&
f
,
C
&&
c
,
const
SmallVector
<
T
>&
inputs
)
const
{
ThinHashMap
<
size_t
,
T
>
node2tensor
;
auto
&&
input_nodes
=
this
->
inputs
;
mgb_assert
(
inputs
.
size
()
==
input_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
node2tensor
[
input_nodes
[
i
]]
=
inputs
[
i
];
}
for
(
auto
&&
i
:
constants
)
{
node2tensor
[
i
.
first
]
=
c
(
i
.
second
);
}
for
(
size_t
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
auto
&&
expr
=
exprs
[
i
];
SmallVector
<
T
>
inputs
;
for
(
auto
&&
in
:
std
::
get
<
1
>
(
expr
))
{
inputs
.
push_back
(
node2tensor
.
at
(
in
));
}
auto
&&
outputs
=
f
(
*
std
::
get
<
0
>
(
expr
),
std
::
move
(
inputs
));
auto
&&
output_nodes
=
std
::
get
<
2
>
(
expr
);
mgb_assert
(
outputs
.
size
()
==
output_nodes
.
size
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
node2tensor
[
output_nodes
[
i
]]
=
std
::
move
(
outputs
[
i
]);
}
}
SmallVector
<
T
>
ret
;
for
(
auto
&&
i
:
outputs
)
{
ret
.
push_back
(
node2tensor
.
at
(
i
));
}
return
ret
;
}
std
::
string
repr
();
};
const
InternalGraph
&
graph
()
const
{
return
m_graph
;
}
InternalGraph
&
graph
()
{
return
m_graph
;
}
bool
is_same_st
(
const
Hashable
&
rhs
)
const
override
{
if
(
!
rhs
.
same_type
<
BackwardGraph
>
())
{
return
false
;
}
auto
&
other
=
rhs
.
cast_final_safe
<
BackwardGraph
>
();
if
(
this
==
&
other
)
{
return
true
;
}
// FIXME
return
false
;
}
std
::
string
repr
()
{
return
m_graph
.
repr
();}
private:
InternalGraph
m_graph
;
};
}
// namespace imperative
}
// namespace mgb
...
...
imperative/src/include/megbrain/imperative/ops/utility.h
浏览文件 @
241b35a6
...
...
@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
}
bool
is_same_st
(
const
Hashable
&
rhs
)
const
override
{
return
obj
.
equal
(
static_cast
<
const
GenericPyOp
&>
(
rhs
).
obj
);
return
obj
.
equal
(
rhs
.
cast_final
<
GenericPyOp
>
(
).
obj
);
}
MGB_DYN_TYPE_OBJ_FINAL_DECL
;
...
...
imperative/src/test/backward_graph.cpp
浏览文件 @
241b35a6
...
...
@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons
return
ret
;
}
SmallVector
<
TensorPtr
>
apply_shared_on_physical_tensor
(
std
::
shared_ptr
<
OpDef
>
def
,
SmallVector
<
TensorPtr
>
inputs
)
{
return
OpDef
::
apply_on_physical_tensor
(
*
def
,
inputs
);
}
TEST
(
TestImperative
,
BackwardGraphBasic
)
{
HostTensorGenerator
<>
gen
;
SmallVector
<
HostTensorND
>
hvs
;
...
...
@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs
.
clear
();
auto
input_grads
=
OpDef
::
apply_on_physical_tensor
(
*
(
result
.
backward
),
backward_graph_inputs
);
auto
input_grads
=
result
.
backward
.
apply
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
){
return
x
;
}
);
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
for
(
size_t
i
=
0
;
i
<
input_has_grad
.
size
();
++
i
)
{
mgb_assert
(
input_has_grad
[
i
]
==
static_cast
<
bool
>
(
input_grads
[
i
]));
...
...
@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs
.
clear
();
auto
input_grads
=
OpDef
::
apply_on_physical_tensor
(
*
(
result
.
backward
),
backward_graph_inputs
);
auto
input_grads
=
result
.
backward
.
apply
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
){
return
x
;
}
);
mgb_assert
(
input_grads
.
size
()
==
input_has_grad
.
size
());
for
(
size_t
i
=
0
;
i
<
input_has_grad
.
size
();
++
i
)
{
mgb_assert
(
input_has_grad
[
i
]
==
static_cast
<
bool
>
(
input_grads
[
i
]));
...
...
@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto
c_tn
=
OpDef
::
apply_on_physical_tensor
(
*
op
,
{
a_tn
,
b_tn
})[
0
];
auto
backward_graph_inputs
=
prepare_backward_graph_inputs
<
SmallVector
<
TensorPtr
>>
(
bg
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
auto
grads
=
expand_grads
(
bg
,
OpDef
::
apply_on_physical_tensor
(
*
bg
.
backward
,
backward_graph_inputs
));
auto
grads
=
expand_grads
(
bg
,
bg
.
backward
.
apply
(
backward_graph_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
){
return
x
;
}
));
auto
precomp
=
OpDef
::
apply_on_physical_tensor
(
*
obg
.
precomp
,
{
a_tn
,
b_tn
,
c_tn
});
auto
precomp
=
obg
.
precomp
.
apply
(
SmallVector
<
TensorPtr
>
{
a_tn
,
b_tn
,
c_tn
},
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
){
return
x
;
}
);
ASSERT_EQ
(
precomp
.
size
(),
2
);
ASSERT_EQ
(
precomp
[
0
]
->
shape
().
ndim
,
1
);
ASSERT_LE
(
precomp
[
0
]
->
shape
()[
0
],
2
);
...
...
@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
ASSERT_LE
(
precomp
[
1
]
->
shape
()[
0
],
2
);
auto
backward_inputs
=
prepare_optimized_backward_inputs
<
SmallVector
<
TensorPtr
>>
(
obg
,
precomp
,
{
a_tn
,
b_tn
},
{
c_tn
},
{
dc_tn
});
auto
grads2
=
expand_grads
(
obg
,
OpDef
::
apply_on_physical_tensor
(
*
obg
.
backward
,
backward_inputs
));
auto
grads2
=
expand_grads
(
obg
,
obg
.
backward
.
apply
(
backward_inputs
,
apply_shared_on_physical_tensor
,
[
&
](
auto
&&
x
){
return
x
;
}
));
ASSERT_EQ
(
grads2
.
size
(),
2
);
MGB_ASSERT_TENSOR_EQ
(
grads
[
0
]
->
get_value
(),
grads2
[
0
]
->
get_value
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录