MegEngine 天元 / MegEngine
Commit 0b8dc2c9, authored Aug 02, 2021 by Megvii Engine Team

refactor(subgraph): add generic encoded_graph

GitOrigin-RevId: 56d90be0e702ed15cafc9e00586a313bf817c9dd
Parent commit: 88b3c842
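
In brief: the commit removes the special-purpose BackwardGraphResult struct from op_def.h and re-expresses backward graphs with the generic EncodedSubraph type added to subgraph.h. The renames applied throughout the diff are backward → graph, save_for_backward → input_mask, and input_has_grad → output_mask. A minimal sketch of the new type's fields, with the old names noted in comments (the full definition appears in the subgraph.h hunk below):

// Sketch only; see imperative/src/include/megbrain/imperative/subgraph.h for the full type.
struct EncodedSubraph {
    Subgraph graph;                 // was BackwardGraphResult::backward
    SmallVector<bool> input_mask;   // was BackwardGraphResult::save_for_backward
    SmallVector<bool> output_mask;  // was BackwardGraphResult::input_has_grad
    // ... plus encode/decode helpers, make/make_single factories, repr(), hash()
};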
Showing 13 changed files with 247 additions and 51 deletions (+247, -51)
imperative/python/src/grad.cpp                                     +1   -1
imperative/python/src/imperative_rt.cpp                            +1   -1
imperative/src/impl/backward_graph_opt.cpp                         +9   -9
imperative/src/impl/op_def.cpp                                     +1   -1
imperative/src/impl/proxy_graph.cpp                                +11  -10
imperative/src/impl/proxy_graph.h                                  +1   -1
imperative/src/impl/proxy_graph_detail.cpp                         +2   -2
imperative/src/impl/subgraph.cpp                                   +21  -0
imperative/src/include/megbrain/imperative/backward_graph_opt.h   +1   -1
imperative/src/include/megbrain/imperative/op_def.h                +1   -7
imperative/src/include/megbrain/imperative/proxy_graph_detail.h   +1   -1
imperative/src/include/megbrain/imperative/subgraph.h              +180 -0
imperative/src/test/backward_graph.cpp                             +17  -17

imperative/python/src/grad.cpp (+1, -1)

@@ -77,7 +77,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     std::shared_ptr<OptimizedBackwardGraphResult> ret;
     auto bg = OpDef::make_backward_graph(*ctx.op, inputs, input_requires_grad, output_has_grad);
-    if (!bg.backward.empty()) {
+    if (!bg.graph.empty()) {
         ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
     }
     backward_graph_cache.emplace(key, ret);

imperative/python/src/imperative_rt.cpp (+1, -1)

@@ -37,7 +37,7 @@ void init_imperative_rt(py::module m) {
             const SmallVector<bool>& input_requires_grad,
             const SmallVector<bool>& output_has_grad) {
         auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-        return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
+        return std::make_tuple("backward_graph", result.input_mask, result.output_mask);
     };
     m.def("make_backward_graph", make_backward_graph);
 }

imperative/src/impl/backward_graph_opt.cpp (+9, -9)

@@ -16,19 +16,19 @@
 using namespace mgb;
 using namespace imperative;

-OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
-        : input_has_grad(src.input_has_grad) {
-    if (src.backward.exprs.size() <= 1) {
+OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const EncodedSubraph& src)
+        : input_has_grad(src.output_mask) {
+    if (src.graph.exprs.size() <= 1) {
         // backward graph only contains a single op
-        backward = src.backward;
-        save_for_backward = src.save_for_backward;
+        backward = src.graph;
+        save_for_backward = src.input_mask;
         return;
     }
-    save_for_backward.resize(src.save_for_backward.size(), false);
+    save_for_backward.resize(src.input_mask.size(), false);

-    auto&& graph = src.backward;
-    auto&& mask = src.save_for_backward;
-    size_t input_size = src.input_has_grad.size();
+    auto&& graph = src.graph;
+    auto&& mask = src.input_mask;
+    size_t input_size = src.output_mask.size();
     size_t output_size = (mask.size() - input_size) / 2;
     mgb_assert(input_size + output_size * 2 == mask.size());
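
The constructor relies on the mask layout that EncodedSubraph uses for backward graphs: input_mask runs over the forward inputs, then the forward outputs, then the output gradients, which is exactly the invariant input_size + output_size * 2 == mask.size() asserted above. A worked instance under that assumption (illustrative numbers only):

#include <cassert>
#include <cstddef>

int main() {
    // Hypothetical op with 2 forward inputs and 1 forward output.
    // input_mask runs over [inputs..., outputs..., output_grads...].
    std::size_t input_size = 2;         // src.output_mask.size(): one flag per forward input
    std::size_t mask_size = 2 + 1 + 1;  // src.input_mask.size()
    std::size_t output_size = (mask_size - input_size) / 2;  // == 1
    assert(input_size + output_size * 2 == mask_size);  // the invariant asserted above
    return 0;
}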

imperative/src/impl/op_def.cpp (+1, -1)

@@ -80,7 +80,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli
     return def.trait()->infer_output_attrs_fallible(def, inputs);
 }

-BackwardGraphResult OpDef::make_backward_graph(
+EncodedSubraph OpDef::make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,

imperative/src/impl/proxy_graph.cpp (+11, -10)

@@ -668,14 +668,14 @@ struct ProxyGraph::GradGraph {
     cg::VarNode* grad;
 };

-BackwardGraphResult ProxyGraph::make_backward_graph(
+EncodedSubraph ProxyGraph::make_backward_graph(
         const OpDef& opdef,
         const SmallVector<LogicalTensorDesc>& input_descs,
         const SmallVector<bool>& input_requires_grad,
         const SmallVector<bool>& output_has_grad) {
     ThinHashMap<VarNode*, size_t> var2idx;
-    auto push = [&var2idx, cnt = 0](VarNode* var) mutable {
+    auto push = [&var2idx, cnt = 1](VarNode* var) mutable {
+        // cnt is always greater than zero
         auto&& ret = var2idx.emplace(var, cnt++);
         mgb_assert(ret.second, "var %s has been already inserted", var->cname());
         return ret.first->second;

@@ -702,8 +702,8 @@ ProxyGraph::make_backward_graph(
     }
     auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

-    BackwardGraphResult result;
-    auto&& igraph = result.backward;
+    EncodedSubraph result;
+    auto&& igraph = result.graph;
     size_t nr_backward_graph_inputs = 0;
     auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,

@@ -735,7 +735,7 @@ ProxyGraph::make_backward_graph(
     // set backward graph outputs
     cg::DepOprIter iter{gen_expr};
     iter.set_visited(fwd);
-    result.input_has_grad.resize(inputs.size());
+    result.output_mask.resize(inputs.size());
     VarNodeArray output_grads_with_unused_var;
     {

@@ -760,6 +760,7 @@ ProxyGraph::make_backward_graph(
         if (grad_results.valid()) {
             grad = grad_results.val()[i];
         } else {
+            mgb_assert(gfunc, "could not find grad function");
             auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
             if (res.from_single()) {
                 grad = res.single();

@@ -776,9 +777,9 @@ ProxyGraph::make_backward_graph(
                     fwd->dyn_typeinfo()->name, i);
             iter.add(grad);
             igraph.outputs.push_back(var2idx.at(grad));
-            result.input_has_grad[i] = true;
+            result.output_mask[i] = true;
         } else {
-            result.input_has_grad[i] = false;
+            result.output_mask[i] = false;
         }
     }
     if (igraph.outputs.empty()) {

@@ -787,15 +788,15 @@ ProxyGraph::make_backward_graph(
     // set backward graph inputs
     igraph.inputs.reserve(nr_backward_graph_inputs);
-    result.save_for_backward.reserve(nr_backward_graph_inputs);
+    result.input_mask.reserve(nr_backward_graph_inputs);
     auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
         for (auto&& i : vars) {
             auto&& iter = var2idx.find(i);
             if (iter != var2idx.end()) {
                 igraph.inputs.push_back(iter->second);
-                result.save_for_backward.push_back(true);
+                result.input_mask.push_back(true);
             } else {
-                result.save_for_backward.push_back(false);
+                result.input_mask.push_back(false);
             }
         }
     };
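
Note the counter change in the push lambda: indices handed out for backward-graph variables now start at 1, leaving 0 free as a sentinel. This matches the rest of the commit, where EncodedSubraph::make_single allocates variables starting at 1 via ++last_var and GradContext::get_grad returns 0 to mean "no gradient".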

imperative/src/impl/proxy_graph.h (+1, -1)

@@ -40,7 +40,7 @@ public:
             const SmallVector<Tensor*>& outputs,
             const SmallVector<Tensor*>& workspace);

-    BackwardGraphResult make_backward_graph(
+    EncodedSubraph make_backward_graph(
             const OpDef& opdef,
             const SmallVector<LogicalTensorDesc>& input_descs,
             const SmallVector<bool>& input_requires_grad,

imperative/src/impl/proxy_graph_detail.cpp (+2, -2)

@@ -133,7 +133,7 @@ size_t get_backward_graph_hash_key(const OpDef& def,
     return state.digest();
 }

-struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, CompNodeDepedentObject {
+struct BackwardGraphCache : std::unordered_map<size_t, EncodedSubraph>, CompNodeDepedentObject {
     std::shared_ptr<void> on_comp_node_finalize() override {
         clear();
         return {};

@@ -142,7 +142,7 @@ struct BackwardGraphCache : std::unordered_map<size_t, BackwardGraphResult>, Com
 } // anonymous namespace

-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,

imperative/src/impl/subgraph.cpp (+21, -0)

@@ -101,5 +101,26 @@ void Subgraph::replace_vars(
     }
 }

+std::string EncodedSubraph::repr() const {
+    std::string buffer;
+    buffer.push_back('|');
+    for (size_t i = 0; i < input_mask.size(); ++i) {
+        buffer.push_back(input_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    buffer.push_back('\n');
+    buffer.append(graph.repr());
+    buffer.push_back('|');
+    for (size_t i = 0; i < output_mask.size(); ++i) {
+        buffer.push_back(output_mask[i] ? '#' : ' ');
+    }
+    buffer.push_back('|');
+    return buffer;
+}
+
+size_t EncodedSubraph::hash() const {
+    return std::hash<std::string>{}(repr());
+}
+
 } // namespace imperative
 } // namespace mgb
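
For intuition about the repr() format: each mask prints between vertical bars, one character per entry, '#' for true and a space for false, with the graph's own repr() sandwiched between the two mask rows. With a hypothetical input_mask of {true, false, true} and output_mask of {true}, the output would look like:

|# #|
<output of graph.repr()>
|#|

hash() then simply hashes that string, so two EncodedSubraph values with identical masks and graph repr hash equal by construction.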

imperative/src/include/megbrain/imperative/backward_graph_opt.h (+1, -1)

@@ -19,7 +19,7 @@ struct OptimizedBackwardGraphResult {
     SmallVector<bool> save_for_backward;
     SmallVector<bool> input_has_grad;

-    OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
+    OptimizedBackwardGraphResult(const EncodedSubraph& bgraph);
 };

 } // namespace mgb::imperative

imperative/src/include/megbrain/imperative/op_def.h (+1, -7)

@@ -29,12 +29,6 @@ enum DispatchMode {
 using SharedOp = std::shared_ptr<OpDef>;

-struct BackwardGraphResult {
-    Subgraph backward;
-    SmallVector<bool> save_for_backward;
-    SmallVector<bool> input_has_grad;
-};
-
 class OpDef : public Hashable,
               public NonCopyableObj,
               public std::enable_shared_from_this<OpDef> {

@@ -91,7 +85,7 @@ public:
             const SmallVector<TensorPtr>& inputs_tensors,
             const SmallVector<MemoryDesc>& inputs_mems);

-    static BackwardGraphResult make_backward_graph(
+    static EncodedSubraph make_backward_graph(
             const OpDef& def,
             const SmallVector<LogicalTensorDesc>& inputs,
             const SmallVector<bool>& input_requires_grad,

imperative/src/include/megbrain/imperative/proxy_graph_detail.h (+1, -1)

@@ -38,7 +38,7 @@ void exec(const OpDef& def,
          const SmallVector<TensorPtr>& inputs,
          const SmallVector<TensorPtr>& outputs);

-BackwardGraphResult
+EncodedSubraph
 make_backward_graph(
         const OpDef& def,
         const SmallVector<LogicalTensorDesc>& inputs,
         const SmallVector<bool>& input_requires_grad,

imperative/src/include/megbrain/imperative/subgraph.h (+180, -0)

@@ -96,5 +96,185 @@ struct Subgraph {
     bool operator==(const Subgraph& rhs) const;
 };

+struct EncodedSubraph {
+    Subgraph graph;
+    SmallVector<bool> input_mask;
+    SmallVector<bool> output_mask;
+
+    template <typename TContainer>
+    TContainer encode_inputs(TContainer inputs) const {
+        TContainer encoded_inputs;
+        size_t index = 0;
+        for (auto&& input : inputs) {
+            mgb_assert(index < input_mask.size(), "index out of range");
+            if (input_mask[index++]) {
+                encoded_inputs.push_back(input);
+            }
+        }
+        mgb_assert(index == input_mask.size(), "mask size mismatch");
+        return encoded_inputs;
+    }
+
+    template <typename TContainer>
+    TContainer encode_outputs(TContainer outputs) const {
+        TContainer encoded_outputs;
+        size_t index = 0;
+        for (auto&& output : outputs) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[index++]) {
+                encoded_outputs.push_back(output);
+            }
+        }
+        mgb_assert(index == output_mask.size(), "mask size mismatch");
+        return encoded_outputs;
+    }
+
+    template <typename TContainer>
+    TContainer decode_outputs(TContainer outputs) const {
+        TContainer decoded_outputs;
+        size_t index = 0;
+        for (size_t i = 0; i < output_mask.size(); i++) {
+            mgb_assert(index < output_mask.size(), "index out of range");
+            if (output_mask[i]) {
+                decoded_outputs.push_back(outputs[index++]);
+            } else {
+                decoded_outputs.emplace_back();
+            }
+        }
+        mgb_assert(decoded_outputs.size() == output_mask.size(), "mask size mismatch");
+        return decoded_outputs;
+    }
+
+    static EncodedSubraph make(Subgraph graph) {
+        EncodedSubraph result;
+        result.input_mask = graph.gen_input_mask();
+        result.output_mask = graph.gen_output_mask();
+        graph.inputs = result.encode_inputs(graph.inputs);
+        graph.outputs = result.encode_outputs(graph.outputs);
+        result.graph = graph;
+        return result;
+    }
+
+    static EncodedSubraph make_single(
+            std::shared_ptr<OpDef> op,
+            SmallVector<bool> input_mask,
+            SmallVector<bool> output_mask) {
+        EncodedSubraph result;
+        result.input_mask = input_mask;
+        result.output_mask = output_mask;
+        Subgraph::var_t last_var = 0;
+        for (auto&& mask : input_mask) {
+            if (mask) {
+                result.graph.inputs.push_back(++last_var);
+            }
+        }
+        for (auto&& mask : output_mask) {
+            if (mask) {
+                result.graph.outputs.push_back(++last_var);
+            }
+        }
+        result.graph.exprs = {Subgraph::expr_t{op, result.graph.inputs, result.graph.outputs}};
+        return result;
+    }
+
+    template <typename T, typename F, typename C>
+    SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
+        auto encoded_inputs = encode_inputs(input_vars);
+        auto encoded_outputs = graph.apply(encoded_inputs, std::forward<F>(f), std::forward<C>(c));
+        return decode_outputs(encoded_outputs);
+    }
+
+    std::string repr() const;
+    size_t hash() const;
+};
+
+template <typename T>
+class GradContext {
+public:
+    using var_t = T;
+    using vars_t = SmallVector<var_t>;
+    using expr_t = Expr<T>;
+private:
+    std::unordered_map<var_t, var_t> m_grads;
+    std::unordered_set<var_t> m_vars_require_grad;
+    std::function<var_t(var_t, var_t)> m_accumulator;
+    std::vector<expr_t> m_exprs;
+public:
+    GradContext(std::function<var_t(var_t, var_t)> accumulator)
+            : m_accumulator{std::move(accumulator)} {}
+    SmallVector<bool> get_require_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest : dests) {
+            mask.push_back(bool(m_vars_require_grad.count(dest)));
+        }
+        return mask;
+    }
+    SmallVector<bool> get_has_grads(vars_t dests) {
+        SmallVector<bool> mask;
+        for (auto&& dest : dests) {
+            mask.push_back(bool(m_grads.count(dest)));
+        }
+        return mask;
+    }
+    void mark_require_grads(vars_t dests) {
+        for (auto&& dest : dests) {
+            m_vars_require_grad.insert(dest);
+        }
+    }
+    var_t accumulate_grad(var_t dest, var_t grad) {
+        if (!m_grads.count(dest)) {
+            return m_grads[dest] = grad;
+        } else {
+            return m_grads[dest] = m_accumulator(m_grads[dest], grad);
+        }
+    }
+    void record_expr(std::shared_ptr<OpDef> op, vars_t inputs, vars_t outputs) {
+        bool require_grad = false;
+        for (auto&& input : inputs) {
+            if (m_vars_require_grad.count(input)) {
+                require_grad = true;
+                break;
+            }
+        }
+        if (require_grad) {
+            m_exprs.push_back({op, inputs, outputs});
+            mark_require_grads(outputs);
+        }
+    }
+    template <typename TFunctor>
+    void backward(vars_t outputs, vars_t output_grads, TFunctor functor) {
+        size_t nr_outputs = outputs.size();
+        for (size_t i = 0; i < nr_outputs; ++i) {
+            m_grads[outputs[i]] = output_grads[i];
+        }
+        auto exprs = m_exprs;
+        std::reverse(exprs.begin(), exprs.end());
+        for (const expr_t& expr : exprs) {
+            size_t nr_inputs = expr.inputs.size();
+            vars_t input_grads = functor(expr, get_grads(expr.outputs));
+            mgb_assert(input_grads.size() == nr_inputs, "input size mismatch");
+            for (size_t i = 0; i < nr_inputs; ++i) {
+                if (input_grads[i] && m_vars_require_grad.count(expr.inputs[i])) {
+                    accumulate_grad(expr.inputs[i], input_grads[i]);
+                }
+            }
+        }
+    }
+    var_t get_grad(var_t dest) {
+        if (m_grads.count(dest)) {
+            return m_grads.at(dest);
+        }
+        return 0;
+    }
+    vars_t get_grads(vars_t dests) {
+        vars_t grads;
+        for (auto&& dest : dests) {
+            grads.push_back(get_grad(dest));
+        }
+        return grads;
+    }
+};
 } // namespace imperative
 } // namespace mgb
\ No newline at end of file
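
The encode/decode pair is the heart of the new type: encode_inputs and encode_outputs compact a container down to the positions whose mask bit is set, and decode_outputs scatters compacted results back to full width with default-constructed holes. A minimal standalone sketch of that round trip, using plain std::vector stand-ins rather than the megbrain types:

#include <cassert>
#include <cstddef>
#include <vector>

// Standalone illustration of the mask logic; the real EncodedSubraph
// operates on MegEngine's SmallVector.
std::vector<int> encode(const std::vector<int>& xs, const std::vector<bool>& mask) {
    std::vector<int> packed;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        if (mask[i]) packed.push_back(xs[i]);  // keep only masked-in entries
    }
    return packed;
}

std::vector<int> decode(const std::vector<int>& packed, const std::vector<bool>& mask) {
    std::vector<int> full;
    std::size_t j = 0;
    for (std::size_t i = 0; i < mask.size(); ++i) {
        full.push_back(mask[i] ? packed[j++] : 0);  // 0 plays the default-constructed hole
    }
    return full;
}

int main() {
    std::vector<bool> mask{true, false, true};
    auto packed = encode({10, 20, 30}, mask);  // {10, 30}
    auto full = decode(packed, mask);          // {10, 0, 30}
    assert(packed.size() == 2 && full.size() == 3 && full[1] == 0);
    return 0;
}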

imperative/src/test/backward_graph.cpp (+17, -17)

@@ -22,22 +22,22 @@ using namespace cg;
 using namespace imperative;
 template <typename T>
-T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
+T prepare_backward_graph_inputs(const EncodedSubraph& bg, const T& inputs,
                                 const T& outputs, const T& grads) {
     T ret;
     size_t i = 0;
     for (auto&& t : inputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : outputs) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }
     for (auto&& t : grads) {
-        if (bg.save_for_backward[i++]) {
+        if (bg.input_mask[i++]) {
             ret.push_back(t);
         }
     }

@@ -45,10 +45,10 @@ T prepare_backward_graph_inputs(const BackwardGraphResult& bg, const T& inputs,
 }
 template <typename T, typename U>
-T expand_grads(const U& bg, const T& outputs) {
-    T ret(bg.input_has_grad.size());
-    for (size_t i = 0, j = 0; i < bg.input_has_grad.size(); ++i) {
-        if (bg.input_has_grad[i]) {
+T expand_grads(const U& mask, const T& outputs) {
+    T ret(mask.size());
+    for (size_t i = 0, j = 0; i < mask.size(); ++i) {
+        if (mask[i]) {
             ret[i] = outputs[j++];
         }
     }

@@ -80,7 +80,7 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg,
 }
 SmallVector<TensorPtr> apply_shared_on_physical_tensor(
-        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
+        std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs, size_t nr_outputs) {
     return OpDef::apply_on_physical_tensor(*def, inputs);
 }

@@ -104,8 +104,8 @@ TEST(TestImperative, BackwardGraphBasic) {
     }
     auto result = OpDef::make_backward_graph(*attr, input_descs, {true, true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
     auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);

@@ -124,7 +124,7 @@ TEST(TestImperative, BackwardGraphBasic) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());

@@ -159,8 +159,8 @@ TEST(TestImperative, BackwardGraphIdentity) {
     input_descs.push_back({a->layout(), a->comp_node()});
     auto result = OpDef::make_backward_graph(*attr, input_descs, {true}, {true});
-    auto&& save_for_backward = result.save_for_backward;
-    auto&& input_has_grad = result.input_has_grad;
+    auto&& save_for_backward = result.input_mask;
+    auto&& input_has_grad = result.output_mask;
    auto outputs = OpDef::apply_on_physical_tensor(*attr, inputs);
     inputs.push_back(outputs[0]);

@@ -178,7 +178,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
         }
     }
     inputs.clear();
-    auto input_grads = result.backward.apply(backward_graph_inputs,
+    auto input_grads = result.graph.apply(backward_graph_inputs,
                                              apply_shared_on_physical_tensor,
                                              [&](auto&& x) { return x; });
     mgb_assert(input_grads.size() == input_has_grad.size());

@@ -245,7 +245,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_backward_graph_inputs<SmallVector<TensorPtr>>(
                     bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
-    auto grads = expand_grads(bg, bg.backward.apply(backward_graph_inputs,
+    auto grads = expand_grads(bg.output_mask, bg.graph.apply(backward_graph_inputs,
                                       apply_shared_on_physical_tensor,
                                       [&](auto&& x) { return x; }));

@@ -262,7 +262,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
             prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(
                     obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
-    auto grads2 = expand_grads(obg, obg.backward.apply(backward_inputs,
+    auto grads2 = expand_grads(obg.input_has_grad, obg.backward.apply(backward_inputs,
                                        apply_shared_on_physical_tensor,
                                        [&](auto&& x) { return x; }));
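
After this change, expand_grads takes the relevant mask directly rather than deducing it from a result object, and scatters the compacted gradient list back to the full input arity: with a hypothetical mask of {true, false, true} and compact outputs {ga, gc}, it yields {ga, <empty>, gc}. The two call sites accordingly pass bg.output_mask for the plain backward graph and obg.input_has_grad for the optimized one.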