Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
288c2e08
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
288c2e08
编写于
1月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/autodiff): fix expand_dims and grad rule fallback
GitOrigin-RevId: 4aae771222aa1e1a0d8bebd589f4e32c59044f4c
上级
a5609f3b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
47 addition
and
7 deletion
+47
-7
imperative/python/src/grad.cpp
imperative/python/src/grad.cpp
+3
-1
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+22
-6
imperative/python/test/unit/core/test_autodiff.py
imperative/python/test/unit/core/test_autodiff.py
+22
-0
未找到文件。
imperative/python/src/grad.cpp
浏览文件 @
288c2e08
...
...
@@ -309,6 +309,8 @@ public:
auto
&
emplace
(
Args
&&
...
args
)
{
return
get
()
->
backward
.
emplace
<
T
>
(
std
::
forward
<
Args
>
(
args
)...);
}
void
reset
()
{
grad_fn
=
nullptr
;
}
};
apply_result_t
backward_graph_grad_rule
(
ApplyContext
&
ctx
,
GradFnHelper
&
ret_grad_fn
)
{
...
...
@@ -398,7 +400,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
maker
.
finalize
();
return
ret
;
}
catch
(
GradRuleFallback
&
)
{
grad_fn_holder
.
emplace
<
std
::
monostate
>
();
grad_fn_holder
.
reset
();
}
}
return
backward_graph_grad_rule
(
ctx
,
grad_fn_holder
);
...
...
imperative/python/src/grad_override.cpp
浏览文件 @
288c2e08
...
...
@@ -177,11 +177,27 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
throw
GradRuleFallback
();
}
template
<
typename
T
,
typename
U
>
apply_result_t
axisAddRemove_grad_rule
(
ApplyContext
&
ctx
,
CustomBackward
::
Maker
&
maker
)
{
auto
&&
op
=
ctx
.
op
->
cast_final_safe
<
T
>
();
apply_result_t
addAxis_grad_rule
(
ApplyContext
&
ctx
,
CustomBackward
::
Maker
&
maker
)
{
auto
&&
op
=
ctx
.
op
->
cast_final_safe
<
AddAxis
>
();
mgb_assert
(
ctx
.
nargs
==
1
);
auto
&&
grad_op
=
U
::
make
(
op
.
axis
);
auto
&&
grad_op
=
RemoveAxis
::
make
(
op
.
axis
);
std
::
sort
(
grad_op
->
axis
.
begin
(),
grad_op
->
axis
.
end
(),
std
::
greater
<
int32_t
>
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
ret
[
0
]
=
python
::
apply
(
grad_op_
,
grad
)[
0
];
return
ret
;
});
return
apply
(
ctx
);
}
apply_result_t
removeAxis_grad_rule
(
ApplyContext
&
ctx
,
CustomBackward
::
Maker
&
maker
)
{
auto
&&
op
=
ctx
.
op
->
cast_final_safe
<
RemoveAxis
>
();
mgb_assert
(
ctx
.
nargs
==
1
);
auto
&&
grad_op
=
AddAxis
::
make
(
op
.
axis
);
std
::
sort
(
grad_op
->
axis
.
begin
(),
grad_op
->
axis
.
end
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
...
...
@@ -201,8 +217,8 @@ struct Init {
reg
.
emplace
(
Subtensor
::
typeinfo
(),
subtensor_grad_rule
);
reg
.
emplace
(
IndexingMultiAxisVec
::
typeinfo
(),
indexingMultiAxisVec_grad_rule
);
reg
.
emplace
(
Reduce
::
typeinfo
(),
reduce_grad_rule
);
reg
.
emplace
(
AddAxis
::
typeinfo
(),
a
xisAddRemove_grad_rule
<
AddAxis
,
RemoveAxis
>
);
reg
.
emplace
(
RemoveAxis
::
typeinfo
(),
axisAddRemove_grad_rule
<
RemoveAxis
,
AddAxis
>
);
reg
.
emplace
(
AddAxis
::
typeinfo
(),
a
ddAxis_grad_rule
);
reg
.
emplace
(
RemoveAxis
::
typeinfo
(),
removeAxis_grad_rule
);
}
}
_
;
...
...
imperative/python/test/unit/core/test_autodiff.py
浏览文件 @
288c2e08
...
...
@@ -335,3 +335,25 @@ def test_Reduce_mean():
grad
(
y
,
F
.
ones_like
(
y
))
np
.
testing
.
assert_equal
(
np
.
ones
((
3
,
3
),
dtype
=
np
.
float32
)
/
3
,
x
.
grad
.
numpy
())
def
test_addAxis
():
x_np
=
np
.
random
.
rand
(
3
,
3
).
astype
(
"float32"
)
x
=
mge
.
Tensor
(
x_np
)
grad
=
Grad
().
wrt
(
x
,
callback
=
save_to
(
x
))
y
=
F
.
expand_dims
(
x
,
[
2
,
3
])
grad
(
y
,
F
.
ones_like
(
y
))
np
.
testing
.
assert_equal
(
np
.
ones
((
3
,
3
),
dtype
=
np
.
float32
),
x
.
grad
.
numpy
())
def
test_removeAxis
():
x_np
=
np
.
random
.
rand
(
3
,
3
,
1
,
1
).
astype
(
"float32"
)
x
=
mge
.
Tensor
(
x_np
)
grad
=
Grad
().
wrt
(
x
,
callback
=
save_to
(
x
))
y
=
F
.
squeeze
(
x
,
[
2
,
3
])
grad
(
y
,
F
.
ones_like
(
y
))
np
.
testing
.
assert_equal
(
np
.
ones
((
3
,
3
,
1
,
1
),
dtype
=
np
.
float32
),
x
.
grad
.
numpy
())
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录