Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bf1a0fb7
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bf1a0fb7
编写于
2月 14, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(imperative): fix logsigmode bwd implementation
GitOrigin-RevId: 86de18760c1a298a7f5265e0959693c30366dd3f
上级
3a35827d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed files
with
125 additions
and
4 deletions
+125
-4
imperative/python/test/unit/functional/test_elemwise.py
imperative/python/test/unit/functional/test_elemwise.py
+112
-0
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+13
-4
未找到文件。
imperative/python/test/unit/functional/test_elemwise.py
浏览文件 @
bf1a0fb7
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
megengine
as
mge
import
megengine.autodiff
as
ad
import
megengine.autodiff
as
ad
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.functional.elemwise
as
elemwise
import
megengine.functional.elemwise
as
elemwise
from
megengine
import
tensor
from
megengine
import
tensor
from
megengine.core.tensor
import
dtype
from
megengine.core.tensor
import
dtype
from
megengine.core.tensor.utils
import
subgraph_fn
from
megengine.functional.elemwise
import
Elemwise
from
megengine.functional.elemwise
import
Elemwise
from
megengine.jit
import
trace
from
megengine.jit
import
trace
...
@@ -316,3 +318,113 @@ def test_maximum_grad_consistency(is_trace):
...
@@ -316,3 +318,113 @@ def test_maximum_grad_consistency(is_trace):
run
(
trace
(
symbolic
=
symbolic
)(
f
))
run
(
trace
(
symbolic
=
symbolic
)(
f
))
else
:
else
:
run
(
f
)
run
(
f
)
def _get_logsigmoid_op(dtype=None, device=None):
    """Return a subgraph-built LogSigmoid op with a hand-written gradient.

    The op is registered through ``subgraph_fn`` under the name
    ``"LogSigmoid"`` with one input, JIT fusion disabled and a custom
    gradient.  The forward pass evaluates

        -(log1p(exp(-|x|)) + relu(-x))

    and the backward pass is spelled out explicitly after the first
    ``yield``.

    Args:
        dtype: forwarded unchanged to ``subgraph_fn``.
        device: forwarded unchanged to ``subgraph_fn``.

    Returns:
        The decorated ``logsigmoid`` callable.
    """

    @subgraph_fn(
        "LogSigmoid",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=False,
        custom_grad=True,
    )
    def logsigmoid(inputs, f, c):
        (x,) = inputs[0:1]

        # ---- forward: y = -(log1p(exp(-|x|)) + relu(-x)) ----
        exp_neg_abs = f("exp", f("-", f("abs", x)))
        log_term = f("log1p", exp_neg_abs)
        relu_neg_x = f("relu", f("-", x))
        y = f("-", f("+", log_term, relu_neg_x))

        # Hand the output to the caller and receive its gradient back.
        (dy,) = yield (y,)

        # ---- backward ----
        # Undo the final negation of the forward pass.
        dsum = f("-", dy)

        # Branch through relu(-x): gate on relu(-x) > 0, then negate for d(-x)/dx.
        relu_branch = f("-", f("switch_gt0", relu_neg_x, dsum))

        # Branch through log1p(exp(-|x|)): dsum * e / (1 + e), negated for
        # the inner "-", then routed through the gradient of abs.
        one_plus_exp = f("+", exp_neg_abs, c(1))
        log_branch = f("/", dsum, one_plus_exp)
        log_branch = f("*", log_branch, exp_neg_abs)
        log_branch = f("-", log_branch)
        log_branch = f("abs_grad", x, log_branch)

        dx = f("+", relu_branch, log_branch)
        yield (dx,)

    return logsigmoid
def origin_logsigmoid(inp: mge.tensor) -> mge.tensor:
    """Apply the subgraph-based LogSigmoid reference op to ``inp``.

    The op is built for the dtype and device of ``inp`` via
    ``_get_logsigmoid_op`` and its single output is returned.
    """
    op = _get_logsigmoid_op(inp.dtype, inp.device)
    (result,) = op(inp)
    return result
def _get_softplus_op(dtype=None, device=None):
    """Return a subgraph-built Softplus op with a hand-written gradient.

    The op is registered through ``subgraph_fn`` under the name
    ``"Softplus"`` with one input, JIT fusion disabled and a custom
    gradient.  The forward pass evaluates

        log1p(exp(-|x|)) + relu(x)

    and the backward pass is spelled out explicitly after the first
    ``yield``.

    Args:
        dtype: forwarded unchanged to ``subgraph_fn``.
        device: forwarded unchanged to ``subgraph_fn``.

    Returns:
        The decorated ``softplus`` callable.
    """

    @subgraph_fn(
        "Softplus",
        dtype=dtype,
        device=device,
        nr_inputs=1,
        jit_fusion=False,
        custom_grad=True,
    )
    def softplus(inputs, f, c):
        (inp,) = inputs[0:1]

        # ---- forward: oup = log1p(exp(-|inp|)) + relu(inp) ----
        neg_abs = f("-", f("abs", inp))
        exp = f("exp", neg_abs)
        oup0 = f("log1p", exp)
        oup1 = f("relu", inp)
        oup = f("+", oup0, oup1)

        # Hand the output to the caller and receive its gradient back.
        (oup_grad,) = yield (oup,)

        # ---- backward ----
        # Branch through relu(inp): pass the gradient where relu(inp) > 0.
        inp_grad_0 = f("switch_gt0", oup1, oup_grad)

        # Branch through log1p(exp(-|inp|)): oup_grad * e / (1 + e), negated
        # for the inner "-", then routed through the gradient of abs.
        # (The former `inp_grad_1 = oup_grad` pre-assignment was a dead
        # store -- it was overwritten before being read -- and is removed.)
        inp_grad_1 = f("/", oup_grad, f("+", exp, c(1)))
        inp_grad_1 = f("*", inp_grad_1, exp)
        inp_grad_1 = f("-", inp_grad_1)
        inp_grad_1 = f("abs_grad", inp, inp_grad_1)

        inp_grad = f("+", inp_grad_0, inp_grad_1)
        yield (inp_grad,)

    return softplus
def origin_softplus(inp: mge.tensor) -> mge.tensor:
    """Apply the subgraph-based Softplus reference op to ``inp``.

    The op is built for the dtype and device of ``inp`` via
    ``_get_softplus_op`` and its single output is returned.
    """
    op = _get_softplus_op(inp.dtype, inp.device)
    (result,) = op(inp)
    return result
def test_subgraph_elemwise_mode():
    """Compare F.logsigmoid / F.softplus with their subgraph reference ops.

    For each pair, the same random float32 input is fed to both
    implementations, a square loss against a constant target is
    back-propagated, and both the outputs and the input gradients are
    required to agree within ``rtol=1e-06``.
    """

    def _test_allclose(func, ori_func):
        targets = np.array(2)
        data = np.random.randn(2, 256, 10, 16).astype("float32")
        ori_inp = mge.tensor(data)
        mge_inp = mge.tensor(data)

        mge_gm = mge.autodiff.GradManager().attach(mge_inp)
        ori_gm = mge.autodiff.GradManager().attach(ori_inp)

        def _forward_backward(gm, fn, x):
            # One recorded forward pass followed by backward on the square loss.
            with gm:
                out = fn(x)
                loss = F.loss.square_loss(
                    out.sum(), mge.tensor(targets, dtype=np.float32)
                )
                gm.backward(loss)
            return out

        # NOTE(review): nesting was reconstructed from a flattened listing;
        # the assertions are assumed to run on every iteration -- confirm.
        for _ in range(2):
            mge_output = _forward_backward(mge_gm, func, mge_inp)
            ori_output = _forward_backward(ori_gm, ori_func, ori_inp)

            np.testing.assert_allclose(
                mge_output.numpy(), ori_output.numpy(), rtol=1e-06
            )
            np.testing.assert_allclose(
                ori_inp.grad.numpy(), mge_inp.grad.numpy(), rtol=1e-06
            )

    _test_allclose(F.logsigmoid, origin_logsigmoid)
    _test_allclose(F.softplus, origin_softplus)
src/opr/impl/basic_arith.cpp
浏览文件 @
bf1a0fb7
...
@@ -559,12 +559,21 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
...
@@ -559,12 +559,21 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
}
}
case
Mode
::
RELU6
:
case
Mode
::
RELU6
:
RET
(
EL2
(
RELU6_GRAD
,
i0
,
og
));
RET
(
EL2
(
RELU6_GRAD
,
i0
,
og
));
case
Mode
::
SOFTPLUS
:
case
Mode
::
SOFTPLUS
:
{
RET
(
EL2
(
SOFTPLUS_GRAD
,
i0
,
og
));
auto
abse
=
EL1
(
EXP
,
EL1
(
NEGATE
,
EL1
(
ABS
,
i0
)));
auto
logg
=
og
*
abse
/
(
1
+
abse
);
auto
absg
=
EL2
(
ABS_GRAD
,
i0
,
EL1
(
NEGATE
,
logg
));
RET
(
EL2
(
ADD
,
absg
,
EL2
(
SWITCH_GT0
,
EL1
(
RELU
,
i0
),
og
)));
}
case
Mode
::
HSIGMOID
:
case
Mode
::
HSIGMOID
:
RET
(
EL2
(
HSIGMOID_GRAD
,
i0
,
og
));
RET
(
EL2
(
HSIGMOID_GRAD
,
i0
,
og
));
case
Mode
::
LOGSIGMOID
:
case
Mode
::
LOGSIGMOID
:
{
RET
(
EL2
(
SOFTPLUS_GRAD
,
EL1
(
NEGATE
,
i0
),
og
));
og
=
EL1
(
NEGATE
,
og
);
auto
abse
=
EL1
(
EXP
,
EL1
(
NEGATE
,
EL1
(
ABS
,
i0
)));
auto
logg
=
og
*
abse
/
(
1
+
abse
);
auto
absg
=
EL2
(
ABS_GRAD
,
i0
,
EL1
(
NEGATE
,
logg
));
RET
(
EL2
(
SUB
,
absg
,
EL2
(
SWITCH_GT0
,
EL1
(
RELU
,
EL1
(
NEGATE
,
i0
)),
og
)));
}
case
Mode
::
SQRT
:
case
Mode
::
SQRT
:
RET
(
og
/
EL1
(
SQRT
,
i0
)
/
2
);
RET
(
og
/
EL1
(
SQRT
,
i0
)
/
2
);
case
Mode
::
SQUARE
:
case
Mode
::
SQUARE
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录