MegEngine 天元 / MegEngine — Commit 522e556b

Authored May 24, 2021 by Megvii Engine Team
Committed by huangxinda on Jul 19, 2021

feat(autodiff): support higher order grad

GitOrigin-RevId: 86390d217940d2240d6908a29a6956b90f3b7b2e

Parent: 5198b783
Showing 10 changed files with 252 additions and 91 deletions (+252 -91):

imperative/python/megengine/autodiff/grad_manager.py   +49   -0
imperative/python/megengine/core/autodiff/grad.py       +8   -0
imperative/python/src/grad.cpp                         +120  -71
imperative/python/src/grad.h                             +9   -1
imperative/python/src/grad_info.h                        +5   -0
imperative/python/src/tensor.cpp                         +1   -1
imperative/python/src/tensor.h                          +33   -2
imperative/python/test/unit/core/test_autodiff.py        +5   -2
imperative/src/impl/ops/specializations.cpp              +0  -14
imperative/src/impl/ops/utility.cpp                     +22   -0
imperative/python/megengine/autodiff/grad_manager.py

@@ -20,6 +20,9 @@ class AttachSpec:
     __slots__ = "tensor", "callbacks"
 
 
+_global_priority = 0
+
+
 class GradManager:
     r"""
     GradManager computes gradients or more generally, vector-Jacobian product, by reverse mode
@@ -118,6 +121,7 @@ class GradManager:
         self._grad = None
         self._after_backward_callback = []
         self._gradients = {}
+        self._priority = None
 
     def attach(self, tensors: Iterable[Tensor], callbacks=None):
         r"""
@@ -293,6 +297,7 @@ class GradManager:
         After this call, you will be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
@@ -300,6 +305,9 @@ class GradManager:
         self._grad = grad
         for spec in self._attach_specs.values():
             self._do_record(spec)
+        if self._priority is None:
+            grad._priority = _global_priority
+            _global_priority -= 1
         grad.__enter__()
 
     def _do_record(self, spec):
@@ -321,11 +329,14 @@ class GradManager:
         After this call, you will not be able to call :meth:`backward`.
         """
+        global _global_priority
         if self._grad is not None:
             self._grad.__exit__(None, None, None)
             self._grad = None
         self._recording = False
         self._gradients = dict()
+        if self._priority is None:
+            _global_priority += 1
 
     def __enter__(self):
         self.record()
@@ -333,3 +344,41 @@ class GradManager:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.release()
+
+    def __and__(self, other):
+        if isinstance(other, GradManager):
+            return GradManagerGroup([self, other])
+        return NotImplemented
+
+    __rand__ = __and__
+
+
+class GradManagerGroup:
+    def __init__(self, gms) -> None:
+        self._gms = list(gms)
+
+    def merge_with(self, other):
+        if isinstance(other, GradManager):
+            other = GradManagerGroup([other])
+        elif not isinstance(other, GradManagerGroup):
+            return NotImplemented
+        return GradManagerGroup([*self._gms, *other._gms])
+
+    __and__ = merge_with
+    __rand__ = merge_with
+    __or__ = merge_with
+    __ror__ = merge_with
+
+    def __enter__(self):
+        global _global_priority
+        _global_priority += 1
+        for gm in self._gms:
+            gm._priority = _global_priority
+            gm.record()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        global _global_priority
+        _global_priority -= 1
+        for gm in self._gms:
+            gm.release()
+            gm._priority = None
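
The Python-side change above is mostly bookkeeping: every plain record() hands out the current module-level _global_priority and decrements it, release() gives the slot back, and GradManagerGroup (built via & / |) assigns one shared priority level to all of its members. Below is a minimal standalone sketch of that counter; the names are hypothetical, it is not the MegEngine API, only an illustration of the bookkeeping under those assumptions.

# Toy model of the priority counter introduced above (hypothetical names):
# recorders opened later receive strictly lower priority, which is what later
# lets an outer recorder keep observing an inner recorder's backward pass.
_global_priority = 0

class ToyRecorder:
    def __init__(self):
        self.priority = None

    def record(self):
        global _global_priority
        if self.priority is None:          # not pre-assigned by a group
            self.priority = _global_priority
            _global_priority -= 1

    def release(self):
        global _global_priority
        if self.priority is not None:
            self.priority = None
            _global_priority += 1

outer, inner = ToyRecorder(), ToyRecorder()
outer.record()                             # gets priority 0
inner.record()                             # gets priority -1
assert (outer.priority, inner.priority) == (0, -1)
inner.release()
outer.release()
assert _global_priority == 0               # counter fully restored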
imperative/python/megengine/core/autodiff/grad.py

@@ -47,6 +47,14 @@ class Grad:
         self._impl = GradKey(name)
         _grad_manager_dict[self._name] = self
 
+    @property
+    def _priority(self):
+        return self._impl.priority
+
+    @_priority.setter
+    def _priority(self, priority):
+        self._impl.priority = priority
+
     @property
     def _name(self):
         return self._impl.name
imperative/python/src/grad.cpp

@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
     for (size_t i = 0; i < ctx.nargs; ++i) {
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
         *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
+        *(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty();
     }
     mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
                bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     for (size_t i = 0; i < ctx.nargs; ++i) {
         inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
-        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+        inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
     }
     ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
 } // namespace
 
 apply_result_t apply_grad(ApplyContext& ctx) {
-    std::shared_ptr<GradKey> grad_key;
+    std::unordered_set<std::shared_ptr<GradKey>> grad_keys;
     for (size_t i = 0; i < ctx.nargs; ++i) {
         auto* tensor = ctx.args[i];
-        if (tensor->m_grad_info.grad_fn) {
-            auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock();
-            // tensor is attached to a live GradKey
-            if (input_grad_key && input_grad_key->active) {
-                if (grad_key) {
-                    if (grad_key != input_grad_key) {
-                        PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-                        throw pyext17::py_err_set();
-                    }
-                } else {
-                    grad_key = std::move(input_grad_key);
-                }
-            } else {
-                // cleanup stale grad info
-                // under what condition?
-                tensor->m_grad_info = {};
-                tensor->m_flags &= ~Flags::GRAD;
-            }
-        } else {
+        if (!tensor->m_grad_info_dict.empty()) {
+            size_t grad_cnt = 0;
+            for (auto&& grad_info: tensor->m_grad_info_dict) {
+                auto input_grad_key = grad_info.grad_fn->key.lock();
+                if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) {
+                    grad_keys.insert(input_grad_key);
+                    grad_cnt++;
+                }
+            }
+            if (!grad_cnt) {
+                tensor->m_flags &= ~Flags::GRAD;
+            }
+        } else {
@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     ctx.flags &= ~Flags::GRAD;
 
-    if (!grad_key) {
+    if (grad_keys.empty()) {
         return apply(ctx);
     }
@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         return backward_graph_grad_rule(ctx, grad_fn_holder);
     }();
 
-    auto& grad_fn = grad_fn_holder.grad_fn;
-    if (!grad_fn) {
+    if (!grad_fn_holder.grad_fn) {
         return outputs;
     }
 
-    grad_fn->key = grad_key;
-    grad_fn->slots.resize(outputs.size());
-    grad_fn->dsts.reserve(ctx.nargs);
+    for (auto&& grad_key: grad_keys) {
+        auto grad_fn = std::make_shared<GradFn>();
+        grad_fn->backward = grad_fn_holder.grad_fn->backward;
+        grad_fn->key = grad_key;
+        grad_fn->slots.resize(outputs.size());
+        grad_fn->dsts.reserve(ctx.nargs);
 
         std::visit([&](auto& backward) {
             using T = std::decay_t<decltype(backward)>;
             if constexpr (std::is_same_v<T, std::monostate>) {
                 mgb_assert(0);
             } else {
                 for (size_t i = 0; i < ctx.nargs; ++i) {
-                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
-                        auto& input_grad_info = ctx.args[i]->m_grad_info;
+                    if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) {
+                        auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get());
                         grad_fn->dsts.emplace_back(input_grad_info);
                         // register as grad producer
                         grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
                     } else {
                         grad_fn->dsts.emplace_back();
                     }
                 }
                 for (size_t i = 0; i < outputs.size(); ++i) {
                     if (backward.output_requires_grad(i)) {
                         if (backward.output_captured(i)) {
                             // avoid reference cycle [Tensor <-> GradFn]
-                            static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>();
+                            static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy());
                             outputs[i] = python::apply(op, outputs[i])[0];
                         }
                         // populate grad info of output tensor
-                        auto& grad_info = outputs[i]->m_grad_info;
+                        auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()];
                         grad_info.grad_fn = grad_fn;
                         grad_info.idx = i;
                         grad_info.insert_after(grad_key->free_vars_head);
                         outputs[i]->m_flags |= Flags::GRAD;
                     }
                 }
             }
         }, grad_fn->backward);
 
         // record forward history
         grad_key->tape.emplace_back(grad_fn);
+    }
 
     return outputs;
 }
 
+PyObject* GradKeyWrapper::get_priority() {
+    return py::cast(m_key->priority).release().ptr();
+}
+
+void GradKeyWrapper::set_priority(pybind11::handle priority) {
+    m_key->name = py::cast<int>(priority);
+}
+
 void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
     if (nargs != 2) {
         throw py::type_error("expect 2 arguments");
@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
         throw py::value_error("grad key finalized");
     }
 
-    if (tensor->m_grad_info.grad_fn) {
-        if (tensor->m_grad_info.grad_fn->key.lock().get() != this) {
-            PyErr_SetString(PyExc_NotImplementedError, "second order grad");
-            throw pyext17::py_err_set();
-        }
-        if (tensor->m_grad_info->callback) {
+    if (tensor->m_grad_info_dict.count(this)) {
+        if (tensor->m_grad_info_dict.at(this)->callback) {
             throw py::value_error("callback already set on this tensor");
         }
     } else {
-        tensor->m_grad_info.idx = 0;
-        auto& grad_fn = tensor->m_grad_info.grad_fn;
+        auto& grad_info = tensor->m_grad_info_dict[this];
+        grad_info.idx = 0;
+        auto& grad_fn = grad_info.grad_fn;
         grad_fn = std::make_shared<GradFn>();
         grad_fn->key = shared_from_this();
         grad_fn->slots.resize(1);
-        tensor->m_grad_info.insert_after(free_vars_head);
+        grad_info.insert_after(free_vars_head);
         tensor->m_flags |= Flags::GRAD;
     }
-    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
+    tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback);
 }
 
 template <typename T>
@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
     active = false;
 
     struct CleanupGuard {
         GradKey* owner;
-        CleanupGuard(GradKey* this_) : owner(this_) {}
-        ~CleanupGuard() { owner->cleanup(); }
+        size_t priority_backup;
+        CleanupGuard(GradKey* this_) : owner(this_) {
+            priority_backup = sm_min_priority;
+            sm_min_priority = owner->priority;
+        }
+        ~CleanupGuard() {
+            owner->cleanup();
+            sm_min_priority = priority_backup;
+        }
     } _cleanup_guard(this);
 
     if (tape.empty()) return;
@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
     }
 
     for (size_t i = 0; i < tensors.size(); ++i) {
-        auto& grad_info = tensors[i]->m_tensor->m_grad_info;
-        if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) {
-            grad_info->grad = grads[i]->m_tensor;
+        if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) {
+            continue;
         }
+        auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this);
+        grad_info->grad = grads[i]->m_tensor;
     }
 
     std::vector<std::shared_ptr<GradFn>> ref_keeper;
     ref_keeper.reserve(tape.size());
 
     // back-propagation in reverse order
     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) {
         PyErr_SetString(PyExc_TypeError, "expect Tensor");
         return nullptr;
     }
-    auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn;
-    if (grad_fn && grad_fn->key.lock() == m_key) {
+    if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) {
         Py_RETURN_TRUE;
     }
     Py_RETURN_FALSE;
 }
 
+int GradKey::sm_min_priority = 0;
+
 GradKey::~GradKey() {
     cleanup();
 }
@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
     return registry;
 }
 
+void GradInfoCollection::_shrink() {
+    auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); };
+    auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred);
+    m_storage.erase(iter, m_storage.end());
+}
+
+bool GradInfoCollection::contains(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return true;
+        }
+    }
+    return false;
+}
+
+GradInfo& GradInfoCollection::operator[](GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    m_storage.emplace_back();
+    return m_storage.back();
+}
+
+GradInfo& GradInfoCollection::at(GradKey* key) {
+    _shrink();
+    for (auto&& grad_info: m_storage) {
+        if (grad_info.grad_fn->key.lock().get() == key) {
+            return grad_info;
+        }
+    }
+    mgb_assert(false);
+}
+
 } // namespace mgb::imperative::python
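
Two pieces of this file work together to allow higher-order gradients: apply_grad() now records one GradFn per active grad key instead of rejecting the case with a "second order grad" NotImplementedError, and GradKey::backward() temporarily raises sm_min_priority to the key's own priority (via the extended CleanupGuard) while also deactivating the key, so that only other, unblocked keys keep recording the backward pass itself. The following is a rough, self-contained Python sketch of that interplay; the names are hypothetical, the initial threshold value is an assumption of this sketch, and the real C++ bookkeeping is more involved.

import math

class ToyKey:
    min_priority = -math.inf      # role of GradKey::sm_min_priority; the initial
                                  # value here is an assumption of this sketch
    def __init__(self, priority):
        self.priority = priority
        self.active = True
        self.tape = []            # rough analogue of the GradFn tape

    def is_blocked(self):
        return self.priority < ToyKey.min_priority

def record(op, keys):
    # rough analogue of apply_grad(): record onto every active, unblocked key
    for key in keys:
        if key.active and not key.is_blocked():
            key.tape.append(op)

def backward(key, all_keys):
    # rough analogue of GradKey::backward(): the key stops being active and all
    # lower-priority keys are blocked while its tape is replayed in reverse
    key.active = False
    saved, ToyKey.min_priority = ToyKey.min_priority, key.priority
    try:
        for op in reversed(key.tape):
            record("grad_of_" + op, all_keys)   # backward ops are recorded too
    finally:
        ToyKey.min_priority = saved

first, second = ToyKey(priority=-1), ToyKey(priority=0)
record("cos", [first, second])
backward(first, [first, second])
# `second` observed both the forward op and its backward, so replaying its own
# tape would yield the second-order gradient, mirroring test_2nd_grad below.
assert second.tape == ["cos", "grad_of_cos"]
assert first.tape == ["cos"]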
imperative/python/src/grad.h

@@ -26,12 +26,18 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool active = true;
     GradInfo::head_t free_vars_head;
     std::vector<std::weak_ptr<GradFn>> tape;
+    int priority = 0;
 
     ~GradKey();
 
     void attach(Tensor* tensor, pybind11::object callback);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     void cleanup();
+    bool is_blocked() const {
+        return priority < sm_min_priority;
+    }
+private:
+    static int sm_min_priority;
 };
 
 struct GradKeyWrapper {
@@ -44,6 +50,8 @@ struct GradKeyWrapper {
     PyObject* get_name();
     void set_name(pybind11::handle name);
+    PyObject* get_priority();
+    void set_priority(pybind11::handle priority);
     void attach(PyObject*const* args, size_t nargs);
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
     PyObject* is_attached_to(PyObject*const* args, size_t nargs);
@@ -150,7 +158,7 @@ using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::M
 std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
 
 inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
-    return bool(ctx.args[i]->m_grad_info.grad_fn);
+    return !ctx.args[i]->m_grad_info_dict.empty();
 }
 
 struct GradRuleFallback : std::exception {};
imperative/python/src/grad_info.h

@@ -15,6 +15,7 @@
 namespace mgb::imperative::python {
 
+struct GradKey;
 struct GradFn;
 struct GradSlot;
@@ -32,6 +33,10 @@ struct GradInfo : GradSlotPtr, intrusive_list::Node<GradInfo, intrusive_list::be
     GradInfo(GradInfo&&) = default;
     GradInfo& operator=(GradInfo&) = default;
     GradInfo& operator=(GradInfo&&) = default;
+    GradInfo(const GradInfo& rhs): GradInfo(const_cast<GradInfo&>(rhs)) {}
+    GradInfo& operator=(const GradInfo& rhs) {
+        return *this = const_cast<GradInfo&>(rhs);
+    }
 };
 
 } // namespace mgb::imperative::python
imperative/python/src/tensor.cpp

@@ -182,7 +182,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
     if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
         SmallVector<cg::VarNode*> vinputs(nargs);
         for (size_t i = 0; i < nargs; ++i) {
             vinputs[i] = py::handle(args[i]).cast<PySymbolVar*>()->m_node;
         }
         auto op = ctx.op.get();
         auto rst = OpDef::apply_on_var_node(*op, vinputs);
imperative/python/src/tensor.h

@@ -17,6 +17,7 @@
 #include "megbrain/imperative/interpreter.h"
 #include "pybind11/pybind11.h"
 
 #include <string>
+#include <unordered_map>
 
 #include "./pyext17.h"
@@ -36,6 +37,8 @@ struct ObjectPtr : B {
 namespace mgb::imperative::python {
 
+struct GradKey;
+
 extern interpreter::Interpreter::Channel* interpreter_for_py;
 
 class SharedHandle {
@@ -58,6 +61,34 @@ public:
 };
 
+// impl in grad.cpp
+class GradInfoCollection {
+private:
+    SmallVector<GradInfo> m_storage;
+protected:
+    void _shrink();
+public:
+    bool contains(GradKey* key);
+    GradInfo& operator[](GradKey* key);
+    GradInfo& at(GradKey* key);
+    bool empty() {
+        _shrink();
+        return m_storage.empty();
+    }
+    auto begin() {
+        _shrink();
+        return m_storage.begin();
+    }
+    auto end() {
+        _shrink();
+        return m_storage.end();
+    }
+    size_t count(GradKey* key) {
+        return contains(key) ? 1 : 0;
+    }
+};
+
 struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     using flags_t = uint64_t;
@@ -69,7 +100,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     flags_t m_flags = 0;
 
-    GradInfo m_grad_info;
+    GradInfoCollection m_grad_info_dict;
     TraceInfo m_trace_info;
     SharedHandle m_handle;
     std::string user_custom_name;
@@ -88,7 +119,7 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
     inline std::shared_ptr<Tensor> copy() {
         auto ret = std::make_shared<Tensor>(m_handle);
         ret->m_flags = m_flags;
-        ret->m_grad_info = m_grad_info;
+        ret->m_grad_info_dict = m_grad_info_dict;
         ret->m_trace_info = m_trace_info;
         ret->m_var = m_var;
         return ret;
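
With this change a tensor no longer carries a single GradInfo slot: m_grad_info_dict is a small per-tensor collection keyed by the owning GradKey, and every accessor first drops entries whose key has expired. Below is a rough Python analogue of that container; the class and member names are hypothetical and only illustrate the lazy-pruning behaviour, not the real data structure.

import weakref

class ToyGradInfoCollection:
    """Rough analogue of GradInfoCollection: a tiny list of
    (weak key reference, info) pairs, pruned lazily when keys die."""

    def __init__(self):
        self._storage = []                 # [(weakref-to-key, info-dict), ...]

    def _shrink(self):
        # drop entries whose grad key is gone, mirroring the key.expired() check
        self._storage = [(r, info) for r, info in self._storage if r() is not None]

    def contains(self, key):
        self._shrink()
        return any(r() is key for r, _ in self._storage)

    def __getitem__(self, key):
        self._shrink()
        for r, info in self._storage:
            if r() is key:
                return info
        info = {}                          # created on first access, like operator[]
        self._storage.append((weakref.ref(key), info))
        return info

    def at(self, key):
        self._shrink()
        for r, info in self._storage:
            if r() is key:
                return info
        raise KeyError(key)

class ToyKey:                              # stand-in for GradKey
    pass

key_a, key_b = ToyKey(), ToyKey()
infos = ToyGradInfoCollection()
infos[key_a]["idx"] = 0
assert infos.contains(key_a) and not infos.contains(key_b)
del key_a                                  # once a key dies, its entry is pruned
assert not infos.contains(key_b)
assert infos._storage == []                # the stale entry for key_a is gone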
imperative/python/test/unit/core/test_autodiff.py

@@ -108,21 +108,24 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
-@pytest.mark.skip(reason="high order gradient was not implemented yet")
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
     ones = as_tensor(np.ones_like(x_np))
 
     grad = Grad().wrt(x, callback=save_to(x))
+    grad._priority = -1
     grad2 = Grad().wrt(x, callback=save_to(x))
+    grad2._priority = 0
 
     y = cos(x)
 
     grad(y, ones)
+    z = x.grad
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
-    grad2(x.grad, ones)
+    x.grad = None
+    grad2(z, ones)
     np.testing.assert_almost_equal(x.grad.numpy(), -np.cos(x_np))
imperative/src/impl/ops/specializations.cpp

@@ -398,20 +398,6 @@ OP_TRAIT_REG(Copy, Copy)
     .fallback();
 }} // copy
 
-namespace { namespace identity {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = def.cast_final_safe<Identity>();
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::Identity::make(inputs[0], config);
-}
-OP_TRAIT_REG(Identity, Identity)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // identity
-
 namespace { namespace assert_equal {
 auto apply_on_var_node(
         const OpDef& def,
imperative/src/impl/ops/utility.cpp

@@ -9,6 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include "megbrain/imperative/ops/autogen.h"
 #include "megbrain/imperative/ops/utility.h"
 #include "megbrain/imperative/ops/opr_attr.h"
 #include "megbrain/opr/utility.h"
@@ -32,4 +33,25 @@ OP_TRAIT_REG(FastpathCopy,FastpathCopy)
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(FastpathCopy);
 
+namespace { namespace identity {
+auto apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
+    auto&& op = def.cast_final_safe<Identity>();
+    mgb_assert(inputs.size() == 1);
+    OperatorNodeConfig config{op.make_name()};
+    return opr::Identity::make(inputs[0], config);
+}
+
+auto apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    return SmallVector<TensorPtr>{inputs[0]};
+}
+
+OP_TRAIT_REG(Identity, Identity)
+    .apply_on_var_node(apply_on_var_node)
+    .apply_on_physical_tensor(apply_on_physical_tensor)
+    .fallback();
+}} // identity
+
 } // namespace mgb::imperative