Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
c2293815
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
c2293815
编写于
3月 17, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(autodiff): proxy_graph_detail::make_backward_graph support multiple opnodes
GitOrigin-RevId: 2c0c8f330da645438f2a5ef17c9acef588f89fb3
上级
335d51b4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
158 addition
and
130 deletion
+158
-130
imperative/src/impl/proxy_graph.cpp
imperative/src/impl/proxy_graph.cpp
+62
-115
imperative/src/impl/subgraph_detail.cpp
imperative/src/impl/subgraph_detail.cpp
+82
-10
imperative/src/include/megbrain/imperative/subgraph.h
imperative/src/include/megbrain/imperative/subgraph.h
+12
-5
imperative/src/include/megbrain/imperative/subgraph_detail.h
imperative/src/include/megbrain/imperative/subgraph_detail.h
+2
-0
未找到文件。
imperative/src/impl/proxy_graph.cpp
浏览文件 @
c2293815
...
@@ -11,10 +11,13 @@
...
@@ -11,10 +11,13 @@
#include "./proxy_graph.h"
#include "./proxy_graph.h"
#include "./blob_manager_impl.h"
#include "./blob_manager_impl.h"
#include "megbrain/graph.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/subgraph_detail.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/tensor_manip.h"
...
@@ -486,139 +489,83 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
...
@@ -486,139 +489,83 @@ EncodedSubgraph ProxyGraph::make_backward_graph(
const
OpDef
&
opdef
,
const
SmallVector
<
LogicalTensorDesc
>&
input_descs
,
const
OpDef
&
opdef
,
const
SmallVector
<
LogicalTensorDesc
>&
input_descs
,
const
SmallVector
<
bool
>&
input_requires_grad
,
const
SmallVector
<
bool
>&
input_requires_grad
,
const
SmallVector
<
bool
>&
output_has_grad
)
{
const
SmallVector
<
bool
>&
output_has_grad
)
{
ThinHashMap
<
VarNode
*
,
size_t
>
var2idx
;
using
op_t
=
OperatorNodeBase
*
;
auto
push
=
[
&
var2idx
,
using
var_t
=
VarNode
*
;
cnt
=
1
](
VarNode
*
var
)
mutable
{
// cnt is always greater non zero
using
vars_t
=
VarNodeArray
;
auto
&&
ret
=
var2idx
.
emplace
(
var
,
cnt
++
);
mgb_assert
(
ret
.
second
,
"var %s has been already inserted"
,
var
->
cname
());
return
ret
.
first
->
second
;
};
auto
inputs
=
make_input_place_holders
(
input_descs
);
auto
inputs
=
make_input_place_holders
(
input_descs
);
auto
fwd
=
OpDef
::
apply_on_var_node
(
opdef
,
inputs
)[
0
]
->
owner_opr
();
auto
outputs
=
OpDef
::
apply_on_var_node
(
opdef
,
inputs
);
auto
&&
outputs
=
fwd
->
usable_output
();
SmallVector
<
LogicalTensorDesc
>
output_descs
;
SmallVector
<
LogicalTensorDesc
>
output_descs
;
for
(
auto
&&
i
:
outputs
)
{
for
(
auto
&&
i
:
outputs
)
{
output_descs
.
push_back
({
TensorLayout
{
i
->
dtype
()},
i
->
comp_node
()});
output_descs
.
push_back
({
TensorLayout
{
i
->
dtype
()},
i
->
comp_node
()});
}
}
GradContext
<
op_t
,
var_t
>
grad_context
{[
&
](
VarNode
*
lhs
,
VarNode
*
rhs
)
->
VarNode
*
{
auto
add
=
opr
::
Elemwise
::
Mode
::
ADD
;
return
opr
::
Elemwise
::
make
(
VarNodeArray
{
lhs
,
rhs
},
add
).
node
();
}};
cg
::
DepOprIter
iter
{[
&
](
OperatorNodeBase
*
op
)
{
grad_context
.
record_expr
(
op
,
op
->
input
(),
op
->
output
());
}};
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
auto
&
input
=
inputs
[
i
];
iter
.
set_visited
(
input
->
owner_opr
());
if
(
input_requires_grad
[
i
])
{
grad_context
.
mark_require_grad
(
input
);
}
}
for
(
auto
&&
output
:
outputs
)
{
iter
.
add
(
output
);
}
auto
output_grads
=
make_input_place_holders
(
output_descs
);
auto
output_grads
=
make_input_place_holders
(
output_descs
);
mgb_assert
(
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
output_grads
.
size
()
==
output_has_grad
.
size
(),
"%d vs %d"
,
output_grads
.
size
(),
output_has_grad
.
size
());
bool
any_input_has_grad
=
false
;
for
(
size_t
i
=
0
;
i
<
output_grads
.
size
();
++
i
)
{
if
(
!
output_has_grad
[
i
])
{
if
(
!
output_has_grad
[
i
])
{
output_grads
[
i
]
=
nullptr
;
output_grads
[
i
]
=
nullptr
;
}
else
{
any_input_has_grad
=
true
;
}
}
}
}
if
(
!
any_input_has_grad
)
{
auto
compute_input_grads
=
[
&
](
op_t
op
,
vars_t
inputs
,
vars_t
outputs
,
return
{};
vars_t
output_grads
)
{
}
auto
*
gfunc
=
cg
::
lookup_grad_func
(
op
->
dyn_typeinfo
());
auto
*
gfunc
=
cg
::
lookup_grad_func
(
fwd
->
dyn_typeinfo
());
vars_t
input_grads
(
inputs
.
size
(),
nullptr
);
bool
any_grad
=
false
;
EncodedSubgraph
result
;
for
(
auto
&&
output_grad
:
output_grads
)
{
auto
&&
igraph
=
result
.
graph
;
if
(
output_grad
)
{
any_grad
=
true
;
size_t
nr_backward_graph_inputs
=
0
;
auto
gen_expr
=
[
this
,
&
var2idx
,
&
igraph
,
&
push
,
&
fwd
,
&
nr_backward_graph_inputs
](
cg
::
OperatorNodeBase
*
op
)
{
if
(
auto
t
=
as_tensor
(
op
))
{
mgb_assert
(
op
->
output
().
size
()
==
1
);
igraph
.
constants
.
emplace_back
(
push
(
op
->
output
(
0
)),
std
::
move
(
t
));
}
else
if
(
op
->
same_type
<
InputPlaceholder
>
())
{
++
nr_backward_graph_inputs
;
push
(
op
->
output
(
0
));
}
else
{
SmallVector
<
size_t
>
inputs
,
outputs
;
for
(
auto
&&
i
:
op
->
input
())
{
if
(
i
->
owner_opr
()
==
fwd
)
{
if
(
var2idx
.
find
(
i
)
==
var2idx
.
end
())
{
++
nr_backward_graph_inputs
;
push
(
i
);
}
}
inputs
.
push_back
(
var2idx
.
at
(
i
));
}
for
(
auto
&&
i
:
op
->
usable_output
())
{
outputs
.
push_back
(
push
(
i
));
}
}
igraph
.
exprs
.
push_back
({
OpDef
::
make_from_op_node
(
op
),
inputs
,
outputs
});
}
}
};
if
(
!
gfunc
||
!
any_grad
)
{
return
input_grads
;
// set backward graph outputs
cg
::
DepOprIter
iter
{
gen_expr
};
iter
.
set_visited
(
fwd
);
result
.
output_mask
.
resize
(
inputs
.
size
());
VarNodeArray
output_grads_with_unused_var
;
{
auto
iter
=
output_grads
.
begin
();
for
(
auto
&&
i
:
fwd
->
output
())
{
if
(
i
->
contain_flag
(
VarNode
::
Flag
::
VOLATILE_CONTENT
))
{
// the var node with VOLATILE_CONTENT(e.g. workspace
// or an empty var) would not be considered as a normal
// output, so its grad is always NULL
output_grads_with_unused_var
.
push_back
(
nullptr
);
}
else
{
output_grads_with_unused_var
.
push_back
(
*
iter
);
++
iter
;
}
}
}
mgb_assert
(
iter
==
output_grads
.
end
());
Maybe
<
VarNodeArray
>
grad_results
;
}
auto
&&
input_requires_grad
=
grad_context
.
get_require_grads
(
inputs
);
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
Maybe
<
VarNodeArray
>
grad_results
;
VarNode
*
grad
;
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
if
(
grad_results
.
valid
())
{
VarNode
*
grad
;
if
(
grad_results
.
valid
())
{
grad
=
grad_results
.
val
()[
i
];
}
else
{
mgb_assert
(
gfunc
,
"could not find grad function"
);
auto
res
=
(
*
gfunc
)(
fwd
,
i
,
output_grads_with_unused_var
);
if
(
res
.
from_single
())
{
grad
=
res
.
single
();
}
else
{
grad_results
.
emplace
(
res
.
all
(
fwd
));
grad
=
grad_results
.
val
()[
i
];
grad
=
grad_results
.
val
()[
i
];
}
}
if
(
grad
&&
!
grad
->
owner_opr
()
->
same_type
<
opr
::
InvalidGrad
>
()
&&
input_requires_grad
[
i
])
{
mgb_assert
(
!
grad
->
owner_opr
()
->
same_type
<
opr
::
InvalidGrad
>
(),
"gradient of operator %s w.r.t. input #%lu is "
"either not well defined or not implemented"
,
fwd
->
dyn_typeinfo
()
->
name
,
i
);
iter
.
add
(
grad
);
igraph
.
outputs
.
push_back
(
var2idx
.
at
(
grad
));
result
.
output_mask
[
i
]
=
true
;
}
else
{
result
.
output_mask
[
i
]
=
false
;
}
}
if
(
igraph
.
outputs
.
empty
())
{
return
{};
}
// set backward graph inputs
auto
write_inputs
=
[
&
igraph
,
&
var2idx
,
&
result
](
const
VarNodeArray
&
vars
)
{
for
(
auto
&&
i
:
vars
)
{
auto
&&
iter
=
var2idx
.
find
(
i
);
if
(
iter
!=
var2idx
.
end
())
{
igraph
.
inputs
.
push_back
(
iter
->
second
);
result
.
input_mask
.
push_back
(
true
);
}
else
{
}
else
{
result
.
input_mask
.
push_back
(
false
);
mgb_assert
(
gfunc
,
"could not find grad function"
);
auto
res
=
(
*
gfunc
)(
op
,
i
,
output_grads
);
if
(
res
.
from_single
())
{
grad
=
res
.
single
();
}
else
{
grad_results
.
emplace
(
res
.
all
(
op
));
grad
=
grad_results
.
val
()[
i
];
}
}
if
(
grad
&&
!
grad
->
owner_opr
()
->
same_type
<
opr
::
InvalidGrad
>
())
{
if
(
input_requires_grad
[
i
])
{
input_grads
[
i
]
=
grad
;
}
}
}
}
}
return
input_grads
;
};
};
write_inputs
(
inputs
);
grad_context
.
backward
(
outputs
,
output_grads
,
compute_input_grads
);
write_inputs
(
outputs
);
auto
input_grads
=
grad_context
.
get_grads
(
inputs
);
write_inputs
(
output_grads
);
VarNodeArray
bgraph_inputs
;
mgb_assert
(
igraph
.
inputs
.
size
()
==
nr_backward_graph_inputs
);
bgraph_inputs
.
insert
(
bgraph_inputs
.
end
(),
inputs
.
begin
(),
inputs
.
end
());
return
result
;
bgraph_inputs
.
insert
(
bgraph_inputs
.
end
(),
outputs
.
begin
(),
outputs
.
end
());
bgraph_inputs
.
insert
(
bgraph_inputs
.
end
(),
output_grads
.
begin
(),
output_grads
.
end
());
auto
graph
=
subgraph_detail
::
make_from_computing_graph
(
bgraph_inputs
,
input_grads
);
return
graph
;
}
}
VarNodeArray
ProxyGraph
::
make_input_place_holders
(
VarNodeArray
ProxyGraph
::
make_input_place_holders
(
...
...
imperative/src/impl/subgraph_detail.cpp
浏览文件 @
c2293815
...
@@ -107,13 +107,16 @@ EncodedSubgraph make_backward_graph_from_forward(
...
@@ -107,13 +107,16 @@ EncodedSubgraph make_backward_graph_from_forward(
Subgraph
::
Builder
<
LogicalTensorDesc
>
builder
(
Subgraph
::
Builder
<
LogicalTensorDesc
>
builder
(
[](
auto
&&
op
,
auto
&&
input_descs
,
size_t
nr_outputs
)
{
[](
auto
&&
op
,
auto
&&
input_descs
,
size_t
nr_outputs
)
{
auto
[
descs
,
_
]
=
OpDef
::
infer_output_attrs_fallible
(
*
op
,
input_descs
);
auto
[
descs
,
_
]
=
OpDef
::
infer_output_attrs_fallible
(
*
op
,
input_descs
);
mgb_assert
(
descs
.
size
()
==
nr_outputs
,
"nr_outputs mismatch for %s"
,
op
->
to_string
().
c_str
());
return
descs
;
return
descs
;
});
});
auto
accum_grad
=
[
&
](
var_t
lhs
,
var_t
rhs
)
{
auto
accum_grad
=
[
&
](
var_t
lhs
,
var_t
rhs
)
{
return
builder
.
write_expr
(
return
builder
.
write_expr
(
Elemwise
::
make
(
Elemwise
::
Mode
::
ADD
),
{
lhs
,
rhs
},
1
)[
0
];
Elemwise
::
make
(
Elemwise
::
Mode
::
ADD
),
{
lhs
,
rhs
},
1
)[
0
];
};
};
GradContext
<
var_t
>
grad_context
{
accum_grad
};
GradContext
<
std
::
shared_ptr
<
OpDef
>
,
var_t
>
grad_context
{
accum_grad
};
auto
input_vars
=
builder
.
write_inputs
(
inputs
);
auto
input_vars
=
builder
.
write_inputs
(
inputs
);
auto
outputs
=
forward_graph
.
apply
<
var_t
>
(
auto
outputs
=
forward_graph
.
apply
<
var_t
>
(
input_vars
,
std
::
bind
(
&
decltype
(
builder
)
::
write_expr
,
&
builder
,
_1
,
_2
,
_3
),
input_vars
,
std
::
bind
(
&
decltype
(
builder
)
::
write_expr
,
&
builder
,
_1
,
_2
,
_3
),
...
@@ -143,19 +146,17 @@ EncodedSubgraph make_backward_graph_from_forward(
...
@@ -143,19 +146,17 @@ EncodedSubgraph make_backward_graph_from_forward(
grad_context
.
backward
(
grad_context
.
backward
(
apply_mask
(
outputs
,
output_has_grad
),
apply_mask
(
outputs
,
output_has_grad
),
apply_mask
(
output_grads
,
output_has_grad
),
apply_mask
(
output_grads
,
output_has_grad
),
[
&
](
Subgraph
::
expr_t
expr
,
vars_t
output_grads
)
{
[
&
](
Subgraph
::
op_t
op
,
vars_t
inputs
,
vars_t
outputs
,
vars_t
output_grads
)
{
auto
bg
=
OpDef
::
make_backward_graph
(
auto
bg
=
OpDef
::
make_backward_graph
(
*
expr
.
op
,
builder
.
get_descs
(
expr
.
inputs
),
*
op
,
builder
.
get_descs
(
inputs
),
grad_context
.
get_require_grads
(
expr
.
inputs
),
grad_context
.
get_require_grads
(
inputs
),
grad_context
.
get_has_grads
(
expr
.
outputs
));
grad_context
.
get_has_grads
(
outputs
));
if
(
bg
.
graph
.
empty
())
{
if
(
bg
.
graph
.
empty
())
{
return
vars_t
(
expr
.
inputs
.
size
(),
0
);
return
vars_t
(
inputs
.
size
(),
0
);
}
}
vars_t
grad_inputs
;
vars_t
grad_inputs
;
grad_inputs
.
insert
(
grad_inputs
.
insert
(
grad_inputs
.
end
(),
inputs
.
begin
(),
inputs
.
end
());
grad_inputs
.
end
(),
expr
.
inputs
.
begin
(),
expr
.
inputs
.
end
());
grad_inputs
.
insert
(
grad_inputs
.
end
(),
outputs
.
begin
(),
outputs
.
end
());
grad_inputs
.
insert
(
grad_inputs
.
end
(),
expr
.
outputs
.
begin
(),
expr
.
outputs
.
end
());
grad_inputs
.
insert
(
grad_inputs
.
insert
(
grad_inputs
.
end
(),
output_grads
.
begin
(),
output_grads
.
end
());
grad_inputs
.
end
(),
output_grads
.
begin
(),
output_grads
.
end
());
auto
apply_functor
=
auto
apply_functor
=
...
@@ -183,6 +184,77 @@ EncodedSubgraph make_backward_graph(
...
@@ -183,6 +184,77 @@ EncodedSubgraph make_backward_graph(
forward_graph
,
inputs
,
input_requires_grad
,
output_has_grad
);
forward_graph
,
inputs
,
input_requires_grad
,
output_has_grad
);
}
}
EncodedSubgraph
make_from_computing_graph
(
const
VarNodeArray
&
inputs
,
const
VarNodeArray
&
outputs
)
{
Subgraph
subgraph
;
std
::
unordered_map
<
VarNode
*
,
size_t
>
var2idx
;
size_t
next_idx
=
0
;
var2idx
[
nullptr
]
=
next_idx
++
;
for
(
auto
&&
input
:
inputs
)
{
if
(
input
)
{
var2idx
[
input
]
=
next_idx
++
;
}
}
auto
is_tensor_holder
=
[](
cg
::
OperatorNodeBase
*
op
)
{
return
op
->
input
().
empty
();
};
auto
as_tensor
=
[](
VarNode
*
var
)
->
TensorPtr
{
auto
*
opr
=
var
->
owner_opr
();
if
(
auto
*
imm_tensor
=
opr
->
try_cast_final
<
opr
::
ImmutableTensor
>
())
{
auto
&&
dv
=
imm_tensor
->
value
();
HostTensorND
hv
(
dv
.
comp_node
(),
dv
.
shape
(),
dv
.
dtype
());
// get host value
auto
&&
cpu_value
=
imm_tensor
->
host_value
();
mgb_assert
(
cpu_value
.
comp_node
()
==
CompNode
::
default_cpu
());
// default_cpu is synchronous with respect to caller
hv
.
proxy_to_default_cpu
().
copy_from_fixlayout
(
cpu_value
);
return
Tensor
::
make
(
dv
,
hv
);
}
else
if
(
auto
*
shared_tensor
=
opr
->
try_cast_final
<
opr
::
SharedDeviceTensor
>
())
{
return
Tensor
::
make
(
shared_tensor
->
get_dev_tensor
());
}
else
{
mgb_assert
(
false
,
"unsupported tensor holder opr %s"
,
opr
->
dyn_typeinfo
()
->
name
);
}
};
cg
::
DepOprIter
iter
{[
&
](
cg
::
OperatorNodeBase
*
op
)
{
// TODO: implement make_backward_graph for mm ops
// mgb_assert(!op->node_prop().contain(cg::OperatorNodeProp::Flag::IMPURE_FUNC));
if
(
is_tensor_holder
(
op
))
{
for
(
auto
&&
output
:
op
->
usable_output
())
{
subgraph
.
constants
.
push_back
(
{
var2idx
[
output
]
=
next_idx
++
,
as_tensor
(
output
)});
}
}
else
{
Subgraph
::
vars_t
inputs
;
Subgraph
::
vars_t
outputs
;
for
(
auto
&&
input
:
op
->
input
())
{
inputs
.
push_back
(
var2idx
.
at
(
input
));
}
// NOTE: use usable_output
for
(
auto
&&
output
:
op
->
usable_output
())
{
outputs
.
push_back
(
var2idx
[
output
]
=
next_idx
++
);
}
auto
opdef
=
OpDef
::
make_from_op_node
(
op
);
subgraph
.
exprs
.
push_back
({
opdef
,
inputs
,
outputs
});
}
}};
for
(
auto
&&
input
:
inputs
)
{
if
(
input
)
{
iter
.
set_visited
(
input
->
owner_opr
());
}
subgraph
.
inputs
.
push_back
(
var2idx
.
at
(
input
));
}
for
(
auto
&&
output
:
outputs
)
{
if
(
output
)
{
iter
.
add
(
output
);
}
subgraph
.
outputs
.
push_back
(
var2idx
.
at
(
output
));
}
return
EncodedSubgraph
::
make
(
subgraph
);
}
}
// namespace subgraph_detail
}
// namespace subgraph_detail
}
// namespace imperative
}
// namespace imperative
}
// namespace mgb
}
// namespace mgb
imperative/src/include/megbrain/imperative/subgraph.h
浏览文件 @
c2293815
...
@@ -189,12 +189,17 @@ struct EncodedSubgraph {
...
@@ -189,12 +189,17 @@ struct EncodedSubgraph {
size_t
hash
()
const
;
size_t
hash
()
const
;
};
};
template
<
typename
T
>
template
<
typename
T
Op
,
typename
TVar
>
class
GradContext
{
class
GradContext
{
public:
public:
using
var_t
=
T
;
using
op_t
=
TOp
;
using
var_t
=
TVar
;
using
vars_t
=
SmallVector
<
var_t
>
;
using
vars_t
=
SmallVector
<
var_t
>
;
using
expr_t
=
Expr
<
T
>
;
struct
expr_t
{
op_t
op
;
vars_t
inputs
;
vars_t
outputs
;
};
private:
private:
std
::
unordered_map
<
var_t
,
var_t
>
m_grads
;
std
::
unordered_map
<
var_t
,
var_t
>
m_grads
;
...
@@ -219,6 +224,7 @@ public:
...
@@ -219,6 +224,7 @@ public:
}
}
return
mask
;
return
mask
;
}
}
void
mark_require_grad
(
var_t
dest
)
{
m_vars_require_grad
.
insert
(
dest
);
}
void
mark_require_grads
(
vars_t
dests
)
{
void
mark_require_grads
(
vars_t
dests
)
{
for
(
auto
&&
dest
:
dests
)
{
for
(
auto
&&
dest
:
dests
)
{
m_vars_require_grad
.
insert
(
dest
);
m_vars_require_grad
.
insert
(
dest
);
...
@@ -231,7 +237,7 @@ public:
...
@@ -231,7 +237,7 @@ public:
return
m_grads
[
dest
]
=
m_accumulator
(
m_grads
[
dest
],
grad
);
return
m_grads
[
dest
]
=
m_accumulator
(
m_grads
[
dest
],
grad
);
}
}
}
}
void
record_expr
(
std
::
shared_ptr
<
OpDef
>
op
,
vars_t
inputs
,
vars_t
outputs
)
{
void
record_expr
(
op_t
op
,
vars_t
inputs
,
vars_t
outputs
)
{
bool
require_grad
=
false
;
bool
require_grad
=
false
;
for
(
auto
&&
input
:
inputs
)
{
for
(
auto
&&
input
:
inputs
)
{
if
(
m_vars_require_grad
.
count
(
input
))
{
if
(
m_vars_require_grad
.
count
(
input
))
{
...
@@ -254,7 +260,8 @@ public:
...
@@ -254,7 +260,8 @@ public:
std
::
reverse
(
exprs
.
begin
(),
exprs
.
end
());
std
::
reverse
(
exprs
.
begin
(),
exprs
.
end
());
for
(
const
expr_t
&
expr
:
exprs
)
{
for
(
const
expr_t
&
expr
:
exprs
)
{
size_t
nr_inputs
=
expr
.
inputs
.
size
();
size_t
nr_inputs
=
expr
.
inputs
.
size
();
vars_t
input_grads
=
functor
(
expr
,
get_grads
(
expr
.
outputs
));
vars_t
input_grads
=
functor
(
expr
.
op
,
expr
.
inputs
,
expr
.
outputs
,
get_grads
(
expr
.
outputs
));
mgb_assert
(
input_grads
.
size
()
==
nr_inputs
,
"input size mismatch"
);
mgb_assert
(
input_grads
.
size
()
==
nr_inputs
,
"input size mismatch"
);
for
(
size_t
i
=
0
;
i
<
nr_inputs
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nr_inputs
;
++
i
)
{
if
(
input_grads
[
i
]
&&
m_vars_require_grad
.
count
(
expr
.
inputs
[
i
]))
{
if
(
input_grads
[
i
]
&&
m_vars_require_grad
.
count
(
expr
.
inputs
[
i
]))
{
...
...
imperative/src/include/megbrain/imperative/subgraph_detail.h
浏览文件 @
c2293815
...
@@ -43,6 +43,8 @@ EncodedSubgraph make_backward_graph_from_forward(
...
@@ -43,6 +43,8 @@ EncodedSubgraph make_backward_graph_from_forward(
const
EncodedSubgraph
&
forward
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
,
const
EncodedSubgraph
&
forward
,
const
SmallVector
<
LogicalTensorDesc
>&
inputs
,
const
SmallVector
<
bool
>&
input_requires_grad
,
const
SmallVector
<
bool
>&
input_requires_grad
,
const
SmallVector
<
bool
>&
output_has_grad
);
const
SmallVector
<
bool
>&
output_has_grad
);
EncodedSubgraph
make_from_computing_graph
(
const
VarNodeArray
&
inputs
,
const
VarNodeArray
&
outputs
);
}
// namespace subgraph_detail
}
// namespace subgraph_detail
}
// namespace imperative
}
// namespace imperative
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录