MegEngine 天元 / MegEngine
Commit 58df717e
Authored Apr 29, 2022 by Megvii Engine Team
fix(mge/autodiff): fix attaching tensor already in gradient path
GitOrigin-RevId: da774509cabeb525ba717dbcb0ae88c3b0ad836b
Parent: 05186e7b

Showing 3 changed files with 78 additions and 26 deletions (+78, -26)
imperative/python/test/unit/core/test_autodiff.py                   +40  -0
imperative/src/impl/transformations/grad.cpp                        +21  -25
imperative/src/include/megbrain/imperative/transformations/grad.h   +17  -1
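The commit title describes attaching a tensor that is already on an existing gradient path. That situation can be reproduced from Python, as the new test below does; here is a minimal sketch, assuming Grad is importable from megengine.core.autodiff.grad as in the test suite and substituting megengine.functional.mul for the test file's low-level mul helper:

    import numpy as np
    import megengine as mge
    from megengine.core.autodiff.grad import Grad
    from megengine.functional import mul

    x = mge.Tensor(np.random.rand(10).astype("float32"))
    with Grad() as grad:
        grad.wrt(x, callback=lambda dx: print("dx:", dx))
        y = mul(x, x)
        # y already lies on the gradient path from x; this second wrt()
        # is the case the commit fixes: a fresh grad slot is now created
        # and chained to the existing one instead of tripping an assert.
        grad.wrt(y, callback=lambda dy: print("dy:", dy))
        z = mul(y, y)
        # backward: dx should be 4 * x ** 3 and dy should be 2 * x ** 2
        grad(z, mge.Tensor(np.ones(10, dtype="float32")))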
imperative/python/test/unit/core/test_autodiff.py
@@ -136,6 +136,46 @@ def test_grad_with_tensor_wrapper():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+def test_wrt_intermediate_var():
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    result = {}
+    with Grad() as grad:
+        grad.wrt(x, callback=lambda dx: result.update(dx=dx))
+        y = mul(x, x)
+        grad.wrt(y, callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+
+    np.testing.assert_almost_equal(result["dx"].numpy(), 4 * x_np ** 3, decimal=6)
+    np.testing.assert_almost_equal(result["dy"].numpy(), 2 * (x_np ** 2), decimal=6)
+
+
+@pytest.mark.parametrize("in_path", [False, True])
+def test_wrt_visibility(in_path):
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+
+    def copy(x):
+        xx = mge.Tensor(x)
+        xx._reset(x)
+        return xx
+
+    result = {}
+    with Grad() as grad:
+        if in_path:
+            grad.wrt(x, callback=lambda _: None)
+        y = mul(x, x)
+        grad.wrt(copy(y), callback=lambda dy: result.update(dy=dy))
+        z = mul(y, y)
+        grad(z, mge.Tensor(np.ones_like(x_np)))
+    assert not result
+
+
 def test_release():
     def check(f):
         n = 0
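The expected values in test_wrt_intermediate_var follow from the chain rule: with y = x * x and z = y * y we have z = x ** 4, so dz/dx = 4 * x ** 3 and dz/dy = 2 * y = 2 * x ** 2, which is exactly what the two assertions check. test_wrt_visibility attaches a detached copy of y (rebuilt via _reset) that never participates in the recorded graph, so its callback must not fire and result stays empty whether or not x itself is in the gradient path.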
imperative/src/impl/transformations/grad.cpp
@@ -265,20 +265,21 @@ void GradKey::backward() {
 
 GradValue::ref_t GradKey::attach(
         ValueRef tensor, std::function<void(ValueRef)> callback) {
-    auto grad_value = tensor.as_ref(m_value_type);
-    if (grad_value) {
-        mgb_assert(
-                !tensor.cast(m_value_type).slot()->callback, "callback exists");
-    } else {
-        GradSlotPtr grad_slot;
-        auto& grad_fn = grad_slot.m_fn;
-        grad_fn = LocalPtr<GradFn>::make();
-        grad_fn->m_key = shared_from_this();
-        grad_fn->m_slots.resize(1);
-        grad_slot.m_index = 0;
-        grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
-    }
-    grad_value->slot().m_fn->m_slots[0].callback = callback;
-    return grad_value;
+    // always create a new grad value
+    GradSlotPtr grad_slot;
+    auto& grad_fn = grad_slot.m_fn;
+    grad_fn = LocalPtr<GradFn>::make();
+    grad_fn->m_key = shared_from_this();
+    grad_fn->m_slots.resize(1);
+    grad_fn->m_slots[0].callback = callback;
+    grad_slot.m_index = 0;
+    if (auto&& grad_value = tensor.as_ref(m_value_type)) {
+        grad_fn->m_backward.emplace<IdentityBackward>();
+        grad_fn->m_dests.push_back(grad_value->m_slot);
+        tensor = grad_value->m_value;
+        m_tape.emplace_back(grad_fn, nullptr);
+    }
+    return m_value_type.make(tensor, shared_from_this(), grad_slot);
 }
 
 void GradKey::freeze() {
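In the rewritten GradKey::attach, every call creates a fresh GradSlot carrying the caller's callback; if the tensor already holds a grad value for this key, the new slot is chained to the existing one through an IdentityBackward node pushed onto the tape. A rough Python rendering of that control flow (GradKeySketch and Slot are illustrative stand-ins, not MegEngine internals):

    # Hypothetical sketch of the new attach() control flow.
    class Slot:
        def __init__(self, callback=None):
            self.callback = callback

    class GradKeySketch:
        def __init__(self):
            self.tape = []          # backward nodes, replayed in reverse
            self.grad_values = {}   # tensor id -> currently attached slot

        def attach(self, tensor, callback):
            new_slot = Slot(callback)              # always a fresh slot
            existing = self.grad_values.get(id(tensor))
            if existing is not None:
                # Tensor already on the gradient path: chain the new slot
                # to the old one with an identity node, so the gradient
                # reaching the new slot is forwarded to the old slot too.
                self.tape.append(("identity", new_slot, existing))
            self.grad_values[id(tensor)] = new_slot
            return new_slot

Gradients that reach the new slot during backward are thus also delivered to the previously attached slot, which is how both callbacks in test_wrt_intermediate_var fire.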
@@ -424,22 +425,17 @@ ValueRefList GradTransformation::apply_transformation(
         return outputs;
     } else if (op.is<CreateTensor>()) {
         return imperative::apply(op, inputs);
-    }
-    if (auto* attach_grad = op.as<AttachGrad>()) {
-        auto& tensor = inputs[0];
-        if (auto&& grad_value = tensor.as_ref(m_value_type)) {
-            mgb_assert(!has_key(attach_grad->key()));
-            auto output = fallback()[0];
-            return record_grad(
-                    m_value_type.make(output, m_key, grad_value->slot()));
-        } else if (!has_key(attach_grad->key())) {
+    } else if (auto* attach_grad = op.as<AttachGrad>()) {
+        if (!has_key(attach_grad->key())) {
             return fallback();
         } else {
             GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
-            auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
-                auto ret = callback({&grad, 1});
-                assert(ret.empty());
-            });
+            auto output = attach_grad->key()->attach(inputs[0], [callback](ValueRef grad) {
+                auto ret = callback({&grad, 1});
+                mgb_assert(ret.empty());
+            });
             return {record_grad(output)};
         }
     } else if (auto* grad_backward = op.as<GradBackward>()) {
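Two behavioral points in this hunk: the AttachGrad branch no longer special-cases tensors that already carry a grad value, since GradKey::attach now handles re-attachment itself (see above), and the check on the callback's return value uses mgb_assert rather than plain assert, so it is not compiled out in builds where NDEBUG disables assert.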
imperative/src/include/megbrain/imperative/transformations/grad.h
@@ -83,6 +83,20 @@ public:
     static BackwardRule lookup_grad_rule(Typeinfo* typeinfo);
 };
 
+struct IdentityBackward {
+    bool input_has_grad(size_t i) { mgb_assert(0); }
+    bool output_requires_grad(size_t i) { mgb_assert(0); }
+
+    template <typename F>
+    void operator()(Span<ValueRef> grads, F&& receiver) {
+        for (size_t i = 0; i < grads.size(); ++i) {
+            if (grads[i]) {
+                receiver(i, grads[i]);
+            }
+        }
+    }
+};
+
 class GradSlot;
 class GradSlotPtr;
 class GradSlotProducerPtr;

@@ -165,7 +179,9 @@ private:
     std::weak_ptr<GradKey> m_key;
     SmallVector<GradSlot> m_slots;
     SmallVector<GradSlotProducerPtr> m_dests;
-    std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
+    std::variant<
+            std::monostate, BackwardGraphWithClosure, CustomBackward, IdentityBackward>
+            m_backward;
 
 public:
     void clear() {
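IdentityBackward forwards each non-null gradient to the receiver unchanged; input_has_grad and output_requires_grad assert unconditionally, presumably because those queries are never made for identity nodes. A line-for-line Python rendering of the call operator (illustrative only):

    def identity_backward(grads, receiver):
        # forward every gradient that is actually present, unchanged
        for i, g in enumerate(grads):
            if g is not None:
                receiver(i, g)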