Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
9fb5581f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
9fb5581f
编写于
12月 29, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge): add specialized grad rule support
GitOrigin-RevId: 141ff0a24f0f843ff5457c06e303332eb4276ef6
上级
645fc6f0
变更
4
显示空白变更内容
内联
并排
Showing
4 changed files
with
253 additions
and
37 deletions
+253
-37
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+30
-35
imperative/python/src/grad.h
imperative/python/src/grad.h
+111
-0
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+63
-0
imperative/python/src/tensor.h
imperative/python/src/tensor.h
+49
-2
未找到文件。
imperative/python/src/grad.cpp
浏览文件 @
9fb5581f
...
...
@@ -70,7 +70,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
inputs
[
i
].
comp_node
=
ctx
.
args
[
i
]
->
comp_node
();
inputs
[
i
].
layout
.
dtype
=
ctx
.
args
[
i
]
->
dtype
();
input_requires_grad
[
i
]
=
bool
(
ctx
.
args
[
i
]
->
m_grad_info
.
grad_fn
);
input_requires_grad
[
i
]
=
python
::
input_requires_grad
(
ctx
,
i
);
}
auto
result
=
std
::
make_shared
<
BackwardGraphResult
>
(
proxy_graph_detail
::
make_backward_graph
(
...
...
@@ -82,21 +82,6 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
return
result
;
}
// Context passed to backward functions. It carries the Python type used when
// wrapping tensors as TensorWrapper objects.
struct BackwardContext {
    // Python type forwarded to TensorWrapper::make; when null, the overload
    // of TensorWrapper::make without a type argument is used instead.
    PyTypeObject* pytype = nullptr;

    // Wraps a shared tensor, forwarding `pytype` when one has been set.
    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        return pytype ? TensorWrapper::make(pytype, std::move(t))
                      : TensorWrapper::make(std::move(t));
    }

    // Raw-pointer overload: obtains a shared_ptr via shared_from_this() and
    // delegates to the overload above.
    auto wrap_tensor(Tensor* t) {
        auto owned = t->shared_from_this();
        return wrap_tensor(std::move(owned));
    }
};
struct
BackwardGraphWithClosure
{
std
::
shared_ptr
<
BackwardGraphResult
>
backward_graph
;
SmallVector
<
std
::
shared_ptr
<
Tensor
>>
closure
;
...
...
@@ -270,7 +255,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
// same length as inputs (of forward op)
SmallVector
<
GradSlotProducerPtr
>
dsts
;
// encapsulates the actual function that computes the gradient
std
::
variant
<
std
::
monostate
,
BackwardGraphWithClosure
,
PythonBackward
>
backward
;
std
::
variant
<
std
::
monostate
,
BackwardGraphWithClosure
,
PythonBackward
,
CustomBackward
>
backward
;
// a flag used during backward
bool
in_ref_keeper
=
false
;
...
...
@@ -335,8 +320,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
pyin
[
i
]
=
TensorWrapper
::
make
(
ctx
.
pytype
,
ctx
.
args
[
i
]
->
shared_from_this
());
}
auto
grad_rule
=
py
::
getattr
(
op
->
obj
,
"_grad_rule"
);
auto
pyret
=
(
scoped_disable
(
Flags
::
GRAD
),
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
grad_rule
.
ptr
(),
pyin
.
ptr
(),
nullptr
)));
// comma expression
auto
pyret
=
py
::
reinterpret_steal
<
py
::
object
>
(
PyObject_Call
(
grad_rule
.
ptr
(),
pyin
.
ptr
(),
nullptr
));
auto
[
outputs
,
backward
]
=
py
::
cast
<
std
::
tuple
<
py
::
object
,
py
::
function
>>
(
pyret
);
ret_grad_fn
.
emplace
<
PythonBackward
>
(
std
::
move
(
backward
),
ctx
.
nargs
);
if
(
auto
*
tw
=
TensorWrapper
::
try_cast
(
outputs
.
ptr
()))
{
...
...
@@ -388,9 +372,25 @@ apply_result_t apply_grad(ApplyContext& ctx) {
}
GradFnHelper
grad_fn_holder
;
auto
outputs
=
ctx
.
op
->
same_type
<
GenericPyOp
>
()
?
python_grad_rule
(
ctx
,
grad_fn_holder
)
:
backward_graph_grad_rule
(
ctx
,
grad_fn_holder
);
auto
outputs
=
[
&
]()
{
auto
_
=
scoped_disable
(
Flags
::
GRAD
);
if
(
ctx
.
op
->
same_type
<
GenericPyOp
>
())
{
return
python_grad_rule
(
ctx
,
grad_fn_holder
);
}
auto
&&
registry
=
grad_rule_registry
();
auto
&&
it
=
registry
.
find
(
ctx
.
op
->
dyn_typeinfo
());
if
(
it
!=
registry
.
end
())
{
auto
&&
maker
=
grad_fn_holder
.
emplace
<
CustomBackward
>
().
maker
(
ctx
);
try
{
auto
ret
=
it
->
second
(
ctx
,
maker
);
maker
.
finalize
();
return
ret
;
}
catch
(
GradRuleFallback
&
)
{
grad_fn_holder
.
emplace
<
std
::
monostate
>
();
}
}
return
backward_graph_grad_rule
(
ctx
,
grad_fn_holder
);
}();
auto
&
grad_fn
=
grad_fn_holder
.
grad_fn
;
if
(
!
grad_fn
)
{
...
...
@@ -407,7 +407,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
mgb_assert
(
0
);
}
else
{
for
(
size_t
i
=
0
;
i
<
ctx
.
nargs
;
++
i
)
{
if
(
backward
.
input_has_grad
(
i
))
{
if
(
backward
.
input_has_grad
(
i
)
&&
input_requires_grad
(
ctx
,
i
)
)
{
auto
&
input_grad_info
=
ctx
.
args
[
i
]
->
m_grad_info
;
grad_fn
->
dsts
.
emplace_back
(
input_grad_info
);
// register as grad producer
...
...
@@ -487,18 +487,8 @@ void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
grad
=
std
::
forward
<
T
>
(
delta
);
return
;
}
static
ApplyContext
ctx
;
if
(
!
ctx
.
op
)
{
ctx
.
op
=
std
::
shared_ptr
<
OpDef
>
(
new
Elemwise
(
Elemwise
::
Mode
::
ADD
));
ctx
.
nargs
=
2
;
}
Tensor
*
args
[
2
]
=
{
grad
.
get
(),
delta
.
get
()};
ctx
.
args
=
args
;
ctx
.
flags
=
grad
->
m_flags
|
delta
->
m_flags
;
if
(
is_tracing
)
{
ctx
.
flags
|=
Flags
::
TRACE
;
}
grad
=
apply
(
ctx
)[
0
];
static
std
::
shared_ptr
<
OpDef
>
op
=
std
::
shared_ptr
<
OpDef
>
(
new
Elemwise
(
Elemwise
::
Mode
::
ADD
));
grad
=
apply
(
op
,
grad
,
std
::
forward
<
T
>
(
delta
))[
0
];
}
void
GradKey
::
backward
(
std
::
vector
<
TensorWrapper
*>
tensors
,
std
::
vector
<
TensorWrapper
*>
grads
)
{
...
...
@@ -582,4 +572,9 @@ GradKey::~GradKey() {
cleanup
();
}
std
::
unordered_map
<
Typeinfo
*
,
GradRuleFn
>&
grad_rule_registry
()
{
static
std
::
unordered_map
<
Typeinfo
*
,
GradRuleFn
>
registry
;
return
registry
;
}
}
// namespace mgb::imperative::python
imperative/python/src/grad.h
浏览文件 @
9fb5581f
...
...
@@ -45,6 +45,117 @@ struct GradKeyWrapper {
void
backward
(
std
::
vector
<
TensorWrapper
*>
,
std
::
vector
<
TensorWrapper
*>
);
};
// Context passed to backward functions. It carries the Python type used when
// wrapping tensors as TensorWrapper objects.
struct BackwardContext {
    // Python type forwarded to TensorWrapper::make; when null, the overload
    // of TensorWrapper::make without a type argument is used instead.
    PyTypeObject* pytype = nullptr;

    // Wraps a shared tensor, forwarding `pytype` when one has been set.
    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        return pytype ? TensorWrapper::make(pytype, std::move(t))
                      : TensorWrapper::make(std::move(t));
    }

    // Raw-pointer overload: obtains a shared_ptr via shared_from_this() and
    // delegates to the overload above.
    auto wrap_tensor(Tensor* t) {
        auto owned = t->shared_from_this();
        return wrap_tensor(std::move(owned));
    }
};
struct
CustomBackward
{
using
BackwardFn
=
std
::
function
<
apply_result_t
(
BackwardContext
&
,
Tensor
*
const
*
,
size_t
)
>
;
BackwardFn
m_backward
;
SmallVector
<
bool
,
8
>
m_input_has_grad
;
struct
OutputAttr
{
bool
requires_grad
=
true
,
captured
=
true
;};
SmallVector
<
OutputAttr
>
m_output_attrs
;
public:
template
<
typename
T
,
typename
R
>
void
operator
()(
BackwardContext
&
ctx
,
T
&&
grads
,
R
&&
receiver
)
{
size_t
nargs
=
grads
.
size
();
Tensor
*
args
[
nargs
];
for
(
size_t
i
=
0
;
i
<
nargs
;
++
i
)
{
args
[
i
]
=
grads
[
i
];
}
auto
ret
=
m_backward
(
ctx
,
args
,
nargs
);
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
if
(
auto
&&
t
=
ret
[
i
])
{
receiver
(
i
,
std
::
move
(
t
));
}
}
}
bool
input_has_grad
(
size_t
i
)
{
return
m_input_has_grad
[
i
];}
bool
output_requires_grad
(
size_t
i
)
{
return
m_output_attrs
[
i
].
requires_grad
;}
bool
output_captured
(
size_t
i
)
{
return
m_output_attrs
[
i
].
captured
;}
class
Maker
{
bool
output_size_set
=
false
,
input_has_grad_initialized
=
false
;
CustomBackward
&
target
;
ApplyContext
&
ctx
;
void
init_input_has_grad
()
{
if
(
!
input_has_grad_initialized
)
{
input_has_grad_initialized
=
true
;
target
.
m_input_has_grad
.
resize
(
ctx
.
nargs
,
true
);
}
}
public:
Maker
(
CustomBackward
&
target_
,
ApplyContext
&
ctx_
)
:
target
(
target_
),
ctx
(
ctx_
)
{}
template
<
typename
F
>
Maker
&
backward
(
F
&&
f
)
{
mgb_assert
(
!
target
.
m_backward
);
target
.
m_backward
=
std
::
forward
<
F
>
(
f
);
return
*
this
;
}
// mandatory
Maker
&
output_size
(
size_t
sz
)
{
mgb_assert
(
!
output_size_set
);
output_size_set
=
true
;
target
.
m_output_attrs
.
resize
(
sz
);
return
*
this
;
}
// optional, defaults to all true
Maker
&
input_has_grad
(
size_t
i
,
bool
v
)
{
init_input_has_grad
();
target
.
m_input_has_grad
.
at
(
i
)
=
v
;
return
*
this
;
}
// optional, defaults to all true
Maker
&
output_requires_grad
(
size_t
i
,
bool
v
)
{
target
.
m_output_attrs
.
at
(
i
).
requires_grad
=
v
;
return
*
this
;
}
// optional, defaults to all true
Maker
&
output_captured
(
size_t
i
,
bool
v
)
{
target
.
m_output_attrs
.
at
(
i
).
captured
=
v
;
return
*
this
;
}
void
finalize
()
{
mgb_assert
(
output_size_set
);
init_input_has_grad
();
}
};
Maker
maker
(
ApplyContext
&
ctx
)
{
return
{
*
this
,
ctx
};}
};
// Signature of a specialized grad rule: it receives the ApplyContext of the
// op being applied and a CustomBackward::Maker to configure, and returns the
// apply_result_t of the forward computation.
using
GradRuleFn
=
std
::
function
<
apply_result_t
(
ApplyContext
&
,
CustomBackward
::
Maker
&
)
>
;
// Accessor for the table that maps an op's Typeinfo* to its GradRuleFn.
// apply_grad looks rules up here via ctx.op->dyn_typeinfo().
std
::
unordered_map
<
Typeinfo
*
,
GradRuleFn
>&
grad_rule_registry
();
// Reports whether the i-th argument of `ctx` has a grad_fn attached to its
// m_grad_info.
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    auto& grad_info = ctx.args[i]->m_grad_info;
    return static_cast<bool>(grad_info.grad_fn);
}
// Exception type with no members of its own. Within the code shown it is
// thrown by elemwise_grad_rule for Elemwise modes other than ADD, and caught
// in apply_grad, which then resets the grad-fn holder and calls
// backward_graph_grad_rule instead.
struct
GradRuleFallback
:
std
::
exception
{};
// Adds `rule` to the grad rule registry under `typeinfo`.
// Returns true when the entry was inserted and false when the registry
// already held a rule for that typeinfo (the existing rule is kept).
template <typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    auto& rules = grad_rule_registry();
    const auto outcome = rules.emplace(typeinfo, std::forward<T>(rule));
    return outcome.second;
}
}
// namespace mgb::imperative::python
namespace
pybind11
::
detail
{
...
...
imperative/python/src/grad_override.cpp
0 → 100644
浏览文件 @
9fb5581f
/**
* \file imperative/python/src/grad_override.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
namespace
mgb
::
imperative
::
python
{
namespace
{
// Applies a GetVarShape op to `x` and returns the first output tensor.
// The op object is created once and reused across calls.
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto shape_op = GetVarShape::make();
    auto outputs = python::apply(shape_op, x);
    return outputs[0];
}
// Applies a default-constructed Reduce op to (`x`, `s`) and returns the first
// output tensor. The op object is created once and reused across calls.
// NOTE(review): `s` is presumably the target shape tensor — confirm against
// the Reduce op definition.
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto reduce_op = Reduce::make();
    auto outputs = python::apply(reduce_op, x, s);
    return outputs[0];
}
// Specialized grad rule for Elemwise ops. Only Mode::ADD is handled; any
// other mode throws GradRuleFallback.
//
// For ADD, the shape of every input that requires grad is recorded up front.
// The backward function then maps the single output gradient to each such
// input through reduce_to(); entries for the other inputs stay null.
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& elemwise = ctx.op->cast_final_safe<Elemwise>();
    if (elemwise.mode != Elemwise::Mode::ADD) {
        throw GradRuleFallback();
    }

    mgb_assert(ctx.nargs == 2);
    std::array<std::shared_ptr<Tensor>, 2> input_shapes;
    for (size_t idx = 0; idx < input_shapes.size(); ++idx) {
        if (!input_requires_grad(ctx, idx)) {
            continue;
        }
        input_shapes[idx] = get_shape(ctx.args[idx]);
    }

    // One output, flagged as not captured.
    maker.output_size(1);
    maker.output_captured(0, false);

    maker.backward([shapes = std::move(input_shapes)](
                           BackwardContext&, Tensor* const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        Tensor* out_grad = grads[0];
        apply_result_t in_grads(2);
        for (size_t idx = 0; idx < shapes.size(); ++idx) {
            if (!shapes[idx]) {
                continue;
            }
            in_grads[idx] = reduce_to(out_grad, shapes[idx].get());
        }
        return in_grads;
    });

    return apply(ctx);
}
// Static registrar: constructing the file-local object `_` inserts
// elemwise_grad_rule into the grad rule registry under Elemwise's typeinfo.
struct Init {
    Init() {
        grad_rule_registry().emplace(Elemwise::typeinfo(), elemwise_grad_rule);
    }
} _;
}
// namespace
}
// namespace mgb::imperative::python
imperative/python/src/tensor.h
浏览文件 @
9fb5581f
...
...
@@ -199,12 +199,59 @@ using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
apply_result_t
apply
(
ApplyContext
&
ctx
);
void
init_tensor
(
pybind11
::
module
);
// Follows `operator->` until a raw pointer is reached.
//
//  - For a raw pointer argument, returns that pointer by value.
//  - For an object that has a member `operator->`, recurses on its result.
//  - Otherwise returns the argument itself, keeping its value category
//    (T& for lvalues, T&& for rvalues).
//
// The last case uses std::forward rather than `return p;`: returning the
// named rvalue-reference parameter directly as T&& is ill-formed before
// C++20, which turned traits built on this function into hard errors for
// rvalue arguments that are not pointer-like.
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        // Copy into a local so decltype(auto) deduces a pointer value rather
        // than a reference to the parameter.
        auto* raw = p;
        return raw;
    } else {
        // Never called: only used to detect whether `p.operator->()` is a
        // valid expression.
        auto arrow_probe = [](auto&& q) -> decltype(q.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(arrow_probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return std::forward<T>(p);
        }
    }
}
template
<
typename
...
Args
>
constexpr
bool
is_all_tensor_ptr
=
(...
&&
std
::
is_same_v
<
decltype
(
resolve_arrow
(
std
::
declval
<
Args
>
())),
Tensor
*>
);
extern
bool
is_tracing
;
extern
bool
is_tracing
;
// FIXME: should use ApplyContext::global_enable
extern
bool
is_symbolic
;
extern
bool
is_compiled
;
// Applies `op` to tensors passed as individual arguments. Each argument must
// resolve to a Tensor* via resolve_arrow (raw or smart pointers).
// The context flags are the OR of every argument's m_flags, plus
// Tensor::Flags::TRACE when is_tracing is set.
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* raw_args[] = {resolve_arrow(args)...};

    ctx.op = std::move(op);
    ctx.args = raw_args;
    ctx.nargs = sizeof...(Args);

    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    for (Tensor* tensor : raw_args) {
        ctx.flags |= tensor->m_flags;
    }
    return apply(ctx);
}
// Applies `op` to a container of tensors. The container must provide size()
// and operator[], and its elements must resolve to Tensor* via resolve_arrow.
// The context flags are the OR of every element's m_flags, plus
// Tensor::Flags::TRACE when is_tracing is set.
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<
                std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                apply_result_t> {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = 0;
    if (is_tracing) {
        ctx.flags = Tensor::Flags::TRACE;
    }
    ctx.nargs = tensors.size();

    // Runtime-sized array: relies on the compiler's variable-length-array
    // extension.
    Tensor* args[ctx.nargs];
    ctx.args = args;
    for (size_t idx = 0; idx < ctx.nargs; ++idx) {
        Tensor* tensor = resolve_arrow(tensors[idx]);
        args[idx] = tensor;
        ctx.flags |= tensor->m_flags;
    }
    return apply(ctx);
}
void
init_tensor
(
pybind11
::
module
);
extern
pybind11
::
object
cpp_apply_with_tracing
,
cpp_apply_compiled_mode
;
extern
pybind11
::
object
cpp_apply_backward_varnode
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录