Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0075e6ac
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0075e6ac
编写于
10月 14, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge): refactor cross_entropy
GitOrigin-RevId: 1fac5b5b14e6de742f1373e6834384c12718ec25
上级
f4b16932
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
15 addition
and
26 deletion
+15
-26
imperative/python/megengine/functional/loss.py
imperative/python/megengine/functional/loss.py
+10
-21
imperative/python/test/unit/functional/test_loss.py
imperative/python/test/unit/functional/test_loss.py
+5
-5
未找到文件。
imperative/python/megengine/functional/loss.py
浏览文件 @
0075e6ac
...
...
@@ -176,30 +176,19 @@ def cross_entropy(
"target_ndim={}"
.
format
(
n0
,
n1
)
)
num_classes
=
pred
.
shape
[
axis
]
no_label_smooth
=
(
label_smooth
is
None
or
type
(
label_smooth
)
in
(
int
,
float
)
and
label_smooth
==
0
)
ls
=
label_smooth
if
with_logits
:
logZ
=
logsumexp
(
pred
,
axis
).
mean
()
primary_term
=
indexing_one_hot
(
pred
,
label
,
axis
).
mean
()
else
:
logZ
=
0
primary_term
=
log
(
indexing_one_hot
(
pred
,
label
,
axis
)).
mean
()
if
ls
is
None
or
type
(
ls
)
in
(
int
,
float
)
and
ls
==
0
:
return
logZ
-
primary_term
if
not
with_logits
:
if
no_label_smooth
:
return
-
log
(
indexing_one_hot
(
pred
,
label
,
axis
)).
mean
()
pred
=
log
(
pred
)
return
(
label_smooth
*
pred
.
mean
()
-
(
1
-
label_smooth
)
*
indexing_one_hot
(
pred
,
label
,
axis
).
mean
()
)
# Denominator of the softmax
down
=
logsumexp
(
pred
,
axis
=
axis
,
keepdims
=
True
)
up
=
indexing_one_hot
(
pred
,
label
,
axis
)
if
not
no_label_smooth
:
factor
=
label_smooth
/
num_classes
up
=
up
*
(
1
-
label_smooth
)
+
pred
.
sum
(
axis
=
axis
,
keepdims
=
True
)
*
factor
return
(
down
-
up
).
mean
()
return
logZ
-
ls
*
pred
.
mean
()
-
(
1
-
ls
)
*
primary_term
def
binary_cross_entropy
(
...
...
imperative/python/test/unit/functional/test_loss.py
浏览文件 @
0075e6ac
...
...
@@ -13,15 +13,15 @@ from megengine import tensor
def
test_cross_entropy_with_logits
():
data
=
tensor
([
1
,
100
]).
astype
(
np
.
float32
).
reshape
((
1
,
2
)
)
label
=
tensor
([
1
]).
astype
(
np
.
int32
)
data
=
tensor
([
[
0
,
50
],
[
0
,
-
150
]]).
astype
(
np
.
float32
)
label
=
tensor
([
1
,
0
]).
astype
(
np
.
int32
)
loss
=
F
.
nn
.
cross_entropy
(
data
,
label
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
0.0
)
label
=
tensor
([
0
]).
astype
(
np
.
int32
)
label
=
tensor
([
0
,
1
]).
astype
(
np
.
int32
)
loss
=
F
.
nn
.
cross_entropy
(
data
,
label
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
100
-
1
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
100
)
label
=
np
.
array
([
1
])
label
=
np
.
array
([
1
,
0
])
loss
=
F
.
nn
.
cross_entropy
(
data
,
label
)
np
.
testing
.
assert_allclose
(
loss
.
numpy
(),
0.0
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录