Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2b254d61
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2b254d61
编写于
2月 26, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(dnn): fix softplus bwd kernel
GitOrigin-RevId: 1f01ab5592f29ead271d02f7de15cc1c8a65df44
上级
ee124dd3
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
24 addition
and
16 deletion
+24
-16
dnn/src/common/elemwise/kern_defs.cuh
dnn/src/common/elemwise/kern_defs.cuh
+10
-1
imperative/python/test/unit/functional/test_elemwise.py
imperative/python/test/unit/functional/test_elemwise.py
+1
-1
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+4
-13
src/opr/test/basic_arith/elemwise.cpp
src/opr/test/basic_arith/elemwise.cpp
+8
-0
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
+1
-1
未找到文件。
dnn/src/common/elemwise/kern_defs.cuh
浏览文件 @
2b254d61
...
...
@@ -85,6 +85,15 @@ __device__ __host__ inline float gelu_grad(float x, float dy) {
return
dy
*
(
normcdf_v
+
x
*
phi
);
}
//! grad of softplus
__device__
__host__
inline
float
softplus_grad
(
float
x
,
float
dy
)
{
float
logg
=
-
dy
*
expf
(
-
fabs
(
x
))
/
(
1.
f
+
expf
(
-
fabs
(
x
)));
float
grad0
=
x
>
0.
f
?
logg
:
-
logg
;
float
relux
=
x
<
0.
f
?
0.
f
:
x
;
float
grad1
=
relux
>
0.
f
?
dy
:
0.
f
;
return
grad0
+
grad1
;
}
__device__
__host__
inline
bool
feq
(
float
a
,
float
b
)
{
return
fabsf
(
a
-
b
)
<
1e-6
;
}
...
...
@@ -287,7 +296,7 @@ DEF_KERN_FLOAT(GELU_GRAD, gelu_grad(x, y));
DEF_KERN_FLOAT
(
ASINH_GRAD
,
y
/
sqrt
(
x
*
x
+
1.
f
));
DEF_KERN_FLOAT
(
ACOSH_GRAD
,
y
/
sqrt
(
x
*
x
-
1.
f
));
DEF_KERN_FLOAT
(
ATANH_GRAD
,
y
/
(
1.
f
-
x
*
x
));
DEF_KERN_FLOAT
(
SOFTPLUS_GRAD
,
y
*
expf
(
x
)
/
(
1.
f
+
expf
(
x
)
));
DEF_KERN_FLOAT
(
SOFTPLUS_GRAD
,
softplus_grad
(
x
,
y
));
DEF_KERN_FLOAT
(
RELU6_GRAD
,
x
<=
ctype
(
0
)
?
ctype
(
0
)
:
(
x
>=
ctype
(
6
)
?
ctype
(
0
)
:
y
));
DEF_KERN_FLOAT
(
HSIGMOID_GRAD
,
...
...
imperative/python/test/unit/functional/test_elemwise.py
浏览文件 @
2b254d61
...
...
@@ -397,7 +397,7 @@ def origin_softplus(inp: mge.tensor) -> mge.tensor:
def
test_subgraph_elemwise_mode
():
def
_test_allclose
(
func
,
ori_func
):
targets
=
np
.
array
(
2
)
inp
=
np
.
random
.
randn
(
2
,
256
,
10
,
16
).
astype
(
"float32"
)
inp
=
np
.
random
.
uniform
(
size
=
(
2
,
16
,
10
,
16
)).
astype
(
np
.
float32
)
ori_inp
=
mge
.
tensor
(
inp
)
mge_inp
=
mge
.
tensor
(
inp
)
...
...
src/opr/impl/basic_arith.cpp
浏览文件 @
2b254d61
...
...
@@ -559,21 +559,12 @@ MGB_IMPL_OPR_GRAD(Elemwise) {
}
case
Mode
::
RELU6
:
RET
(
EL2
(
RELU6_GRAD
,
i0
,
og
));
case
Mode
::
SOFTPLUS
:
{
auto
abse
=
EL1
(
EXP
,
EL1
(
NEGATE
,
EL1
(
ABS
,
i0
)));
auto
logg
=
og
*
abse
/
(
1
+
abse
);
auto
absg
=
EL2
(
ABS_GRAD
,
i0
,
EL1
(
NEGATE
,
logg
));
RET
(
EL2
(
ADD
,
absg
,
EL2
(
SWITCH_GT0
,
EL1
(
RELU
,
i0
),
og
)));
}
case
Mode
::
SOFTPLUS
:
RET
(
EL2
(
SOFTPLUS_GRAD
,
i0
,
og
));
case
Mode
::
HSIGMOID
:
RET
(
EL2
(
HSIGMOID_GRAD
,
i0
,
og
));
case
Mode
::
LOGSIGMOID
:
{
og
=
EL1
(
NEGATE
,
og
);
auto
abse
=
EL1
(
EXP
,
EL1
(
NEGATE
,
EL1
(
ABS
,
i0
)));
auto
logg
=
og
*
abse
/
(
1
+
abse
);
auto
absg
=
EL2
(
ABS_GRAD
,
i0
,
EL1
(
NEGATE
,
logg
));
RET
(
EL2
(
SUB
,
absg
,
EL2
(
SWITCH_GT0
,
EL1
(
RELU
,
EL1
(
NEGATE
,
i0
)),
og
)));
}
case
Mode
::
LOGSIGMOID
:
RET
(
EL2
(
SOFTPLUS_GRAD
,
-
i0
,
og
));
case
Mode
::
SQRT
:
RET
(
og
/
EL1
(
SQRT
,
i0
)
/
2
);
case
Mode
::
SQUARE
:
...
...
src/opr/test/basic_arith/elemwise.cpp
浏览文件 @
2b254d61
...
...
@@ -77,6 +77,14 @@ float do_fuse_add_h_swish(float x, float y) {
return
z
*
fmaxf
(
fminf
(
z
+
3.
f
,
6.
f
),
0.
f
)
/
6.
f
;
}
float
do_softplus_grad
(
float
x
,
float
y
)
{
float
logg
=
-
y
*
expf
(
-
fabs
(
x
))
/
(
1.
f
+
expf
(
-
fabs
(
x
)));
float
grad0
=
x
>
0.
f
?
logg
:
-
logg
;
float
relux
=
x
<
0.
f
?
0.
f
:
x
;
float
grad1
=
relux
>
0.
f
?
y
:
0.
f
;
return
grad0
+
grad1
;
}
template
<
typename
T
>
T
do_shl
(
T
,
T
);
// undefined
template
<
typename
T
>
...
...
src/opr/test/basic_arith/elemwise_binary_trait_def.inl
浏览文件 @
2b254d61
...
...
@@ -61,7 +61,7 @@ DEF_TRAIT(GELU_GRAD, do_gelu_grad(x, y))
DEF_TRAIT
(
ASINH_GRAD
,
y
/
std
::
sqrt
(
x
*
x
+
1
))
DEF_TRAIT
(
ACOSH_GRAD
,
y
/
std
::
sqrt
(
x
*
x
-
1
))
DEF_TRAIT
(
ATANH_GRAD
,
y
/
(
1
-
x
*
x
))
DEF_TRAIT
(
SOFTPLUS_GRAD
,
y
*
std
::
exp
(
x
)
/
(
1.
f
+
std
::
exp
(
x
)
))
DEF_TRAIT
(
SOFTPLUS_GRAD
,
do_softplus_grad
(
x
,
y
))
DEF_TRAIT
(
RELU6_GRAD
,
x
<=
0.
f
?
0.
f
:
(
x
>=
6.
f
?
0.
f
:
y
))
DEF_TRAIT
(
HSIGMOID_GRAD
,
x
<=
-
3.
f
?
0.
f
:
(
x
>=
3.
f
?
0.
f
:
(
y
/
6.
f
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录