Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
cf3f58cb
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
cf3f58cb
编写于
1月 04, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/autodiff): fix segfault when grad is nullptr
GitOrigin-RevId: 6139212bfdc75ac7af5275436f5557ba487673e7
上级
288c2e08
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
21 addition
and
9 deletion
+21
-9
imperative/python/src/grad_override.cpp
imperative/python/src/grad_override.cpp
+21
-9
未找到文件。
imperative/python/src/grad_override.cpp
浏览文件 @
cf3f58cb
...
...
@@ -59,6 +59,9 @@ apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& make
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
2
);
if
(
!
grad
)
{
return
ret
;
}
for
(
size_t
i
=
0
;
i
<
2
;
++
i
)
{
if
(
shapes
[
i
])
{
ret
[
i
]
=
reduce_to
(
grad
,
shapes
[
i
].
get
());
...
...
@@ -84,6 +87,9 @@ apply_result_t reshape_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
2
);
if
(
!
grad
)
{
return
ret
;
}
for
(
size_t
i
=
0
;
i
<
2
;
++
i
)
{
if
(
shapes
[
i
])
{
ret
[
i
]
=
reshape_to
(
grad
,
shapes
[
i
].
get
());
...
...
@@ -107,10 +113,10 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
inputs
=
std
::
move
(
inputs
),
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
if
(
inputs
[
0
])
{
if
(
grad
&&
inputs
[
0
])
{
SmallVector
<
Tensor
*>
args_
(
inputs
.
size
()
+
1
);
Tensor
*
grad
=
grads
[
0
];
auto
&&
zeros
=
make_tensor
(
grad
->
comp_node
(),
inputs
[
0
].
get
());
args_
[
0
]
=
zeros
.
get
();
args_
[
1
]
=
grad
;
...
...
@@ -137,10 +143,10 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
inputs
=
std
::
move
(
inputs
),
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
if
(
inputs
[
0
])
{
if
(
grad
&&
inputs
[
0
])
{
SmallVector
<
Tensor
*>
args_
(
inputs
.
size
()
+
1
);
Tensor
*
grad
=
grads
[
0
];
auto
&&
zeros
=
make_tensor
(
grad
->
comp_node
(),
inputs
[
0
].
get
());
args_
[
0
]
=
zeros
.
get
();
args_
[
1
]
=
grad
;
...
...
@@ -167,7 +173,7 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
if
(
shapes
[
0
])
{
if
(
grad
&&
shapes
[
0
])
{
ret
[
0
]
=
broadcast_to
(
grad
,
shapes
[
0
].
get
());
}
return
ret
;
...
...
@@ -180,14 +186,17 @@ apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker)
apply_result_t
addAxis_grad_rule
(
ApplyContext
&
ctx
,
CustomBackward
::
Maker
&
maker
)
{
auto
&&
op
=
ctx
.
op
->
cast_final_safe
<
AddAxis
>
();
mgb_assert
(
ctx
.
nargs
==
1
);
bool
flag
=
input_requires_grad
(
ctx
,
0
);
auto
&&
grad_op
=
RemoveAxis
::
make
(
op
.
axis
);
std
::
sort
(
grad_op
->
axis
.
begin
(),
grad_op
->
axis
.
end
(),
std
::
greater
<
int32_t
>
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)
,
flag_
=
flag
](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
ret
[
0
]
=
python
::
apply
(
grad_op_
,
grad
)[
0
];
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
python
::
apply
(
grad_op_
,
grad
)[
0
];
}
return
ret
;
});
return
apply
(
ctx
);
...
...
@@ -196,14 +205,17 @@ apply_result_t addAxis_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker
apply_result_t
removeAxis_grad_rule
(
ApplyContext
&
ctx
,
CustomBackward
::
Maker
&
maker
)
{
auto
&&
op
=
ctx
.
op
->
cast_final_safe
<
RemoveAxis
>
();
mgb_assert
(
ctx
.
nargs
==
1
);
bool
flag
=
input_requires_grad
(
ctx
,
0
);
auto
&&
grad_op
=
AddAxis
::
make
(
op
.
axis
);
std
::
sort
(
grad_op
->
axis
.
begin
(),
grad_op
->
axis
.
end
());
maker
.
output_size
(
1
).
output_captured
(
0
,
false
);
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
maker
.
backward
([
grad_op_
=
std
::
move
(
grad_op
)
,
flag_
=
flag
](
BackwardContext
&
,
Tensor
*
const
*
grads
,
size_t
ngrads
)
{
mgb_assert
(
ngrads
==
1
);
Tensor
*
grad
=
grads
[
0
];
apply_result_t
ret
(
1
);
ret
[
0
]
=
python
::
apply
(
grad_op_
,
grad
)[
0
];
if
(
grad
&&
flag_
)
{
ret
[
0
]
=
python
::
apply
(
grad_op_
,
grad
)[
0
];
}
return
ret
;
});
return
apply
(
ctx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录