Commit b1bf193e
Authored on Jan 27, 2021 by Megvii Engine Team
feat(functional/loss): add reduction choices to loss functions
GitOrigin-RevId: a29e6bb4cfeda8a5d56a50985f9dbc2d1f1be515
Parent: 36b1ba05
Showing 2 changed files with 84 additions and 14 deletions (+84, -14):

imperative/python/megengine/functional/loss.py (+48, -14)
imperative/python/test/unit/functional/test_loss.py (+36, -0)
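What the new keyword looks like in use, as a minimal sketch: the `reduction` argument is taken from the diff below, while exposing `l1_loss` as `F.nn.l1_loss` is an assumption based on how the tests call `F.nn.cross_entropy`.

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

pred = tensor(np.array([[0.5, 1.0], [1.5, 2.0]]), dtype="float32")
label = tensor(np.zeros((2, 2)), dtype="float32")

# "none" keeps the elementwise losses; "sum" and "mean" collapse them to a scalar.
per_elem = F.nn.l1_loss(pred, label, reduction="none")  # shape (2, 2)
total = F.nn.l1_loss(pred, label, reduction="sum")      # 5.0
avg = F.nn.l1_loss(pred, label)                         # 1.25; "mean" is the default
```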
imperative/python/megengine/functional/loss.py @ b1bf193e
```diff
@@ -6,8 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import functools
+
 import numpy as np
 
+from ..core.tensor.array_method import _reduce
 from ..tensor import Tensor
 from .elemwise import abs, log
 from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
@@ -22,7 +25,26 @@ __all__ = [
 ]
 
 
-def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
+def _reduce_output(loss_fn):
+    r"""
+    Wrapper to apply canonical reductions to loss outputs.
+    """
+
+    @functools.wraps(loss_fn)
+    def reduced_loss_fn(*args, reduction="mean", **kwargs):
+        loss = loss_fn(*args, **kwargs)
+        if reduction == "none":
+            return loss
+        elif reduction in ("mean", "sum"):
+            return _reduce(reduction)(loss)
+        else:
+            raise ValueError("{} is not a valid value for reduction".format(reduction))
+
+    return reduced_loss_fn
+
+
+@_reduce_output
+def l1_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     r"""
     Calculates the mean absolute error (MAE) between
     each element in the pred :math:`x` and label :math:`y`.
@@ -43,6 +65,7 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
     :param pred: predicted result from model.
     :param label: ground truth to compare.
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -66,10 +89,11 @@ def l1_loss(pred: Tensor, label: Tensor) -> Tensor:
     """
     diff = pred - label
-    return abs(diff).mean()
+    return abs(diff)
 
 
-def square_loss(pred: Tensor, label: Tensor) -> Tensor:
+@_reduce_output
+def square_loss(pred: Tensor, label: Tensor, reduction: str = "mean") -> Tensor:
     r"""
     Calculates the mean squared error (squared L2 norm) between
     each element in the pred :math:`x` and label :math:`y`.
@@ -90,6 +114,7 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
     :param pred: predicted result from model.
     :param label: ground truth to compare.
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Shape:
@@ -118,15 +143,17 @@ def square_loss(pred: Tensor, label: Tensor) -> Tensor:
     """
     diff = pred - label
-    return (diff ** 2).mean()
+    return diff ** 2
 
 
+@_reduce_output
 def cross_entropy(
     pred: Tensor,
     label: Tensor,
     axis: int = 1,
     with_logits: bool = True,
     label_smooth: float = 0,
+    reduction: str = "mean",
 ) -> Tensor:
     r"""
     Computes the multi-class cross entropy loss (using logits by default).
@@ -148,6 +175,7 @@ def cross_entropy(
     :param axis: an axis along which softmax will be applied. Default: 1
     :param with_logits: whether to apply softmax first. Default: True
     :param label_smooth: a label smoothing of parameter that can re-distribute target distribution. Default: 0
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -182,20 +210,21 @@ def cross_entropy(
     ls = label_smooth
 
     if with_logits:
-        logZ = logsumexp(pred, axis).mean()
-        primary_term = indexing_one_hot(pred, label, axis).mean()
+        logZ = logsumexp(pred, axis)
+        primary_term = indexing_one_hot(pred, label, axis)
     else:
         logZ = 0
-        primary_term = log(indexing_one_hot(pred, label, axis)).mean()
+        primary_term = log(indexing_one_hot(pred, label, axis))
     if ls is None or type(ls) in (int, float) and ls == 0:
         return logZ - primary_term
     if not with_logits:
         pred = log(pred)
-    return logZ - ls * pred.mean() - (1 - ls) * primary_term
+    return logZ - ls * pred.mean(axis) - (1 - ls) * primary_term
 
 
+@_reduce_output
 def binary_cross_entropy(
-    pred: Tensor, label: Tensor, with_logits: bool = True
+    pred: Tensor, label: Tensor, with_logits: bool = True, reduction: str = "mean",
 ) -> Tensor:
     r"""
     Computes the binary cross entropy loss (using logits by default).
@@ -206,6 +235,7 @@ def binary_cross_entropy(
     :param pred: `(N, *)`, where `*` means any number of additional dimensions.
     :param label: `(N, *)`, same shape as the input.
     :param with_logits: bool, whether to apply sigmoid first. Default: True
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -229,13 +259,16 @@ def binary_cross_entropy(
     """
     if not with_logits:
-        return -(label * log(pred) + (1 - label) * log(1 - pred)).mean()
+        return -(label * log(pred) + (1 - label) * log(1 - pred))
     # logsigmoid(pred) and logsigmoid(-pred) has common sub-expression
     # hopefully the backend would optimize this
-    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred)).mean()
+    return -(label * logsigmoid(pred) + (1 - label) * logsigmoid(-pred))
 
 
-def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
+@_reduce_output
+def hinge_loss(
+    pred: Tensor, label: Tensor, norm: str = "L1", reduction: str = "mean"
+) -> Tensor:
     r"""
     Caculates the hinge loss which is often used in SVM.
@@ -246,6 +279,7 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
     :param pred: input tensor representing the predicted probability, shape is `(N, C)`.
     :param label: input tensor representing the binary classification label, shape is `(N, C)`.
     :param norm: specify the norm to caculate the loss, should be "L1" or "L2".
+    :param reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
     :return: loss value.
 
     Examples:
@@ -272,6 +306,6 @@ def hinge_loss(pred: Tensor, label: Tensor, norm: str = "L1") -> Tensor:
     # Converts binary labels to -1/1 labels.
     loss = relu(1.0 - pred * label)
     if norm == "L1":
-        return loss.sum(axis=1).mean()
+        return loss.sum(axis=1)
     else:
-        return (loss ** 2).sum(axis=1).mean()
+        return (loss ** 2).sum(axis=1)
```
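The heart of the change is the `_reduce_output` decorator: each loss body now returns unreduced values, and the wrapper applies the reduction once, uniformly. A self-contained sketch of the same pattern, with plain NumPy stand-ins where the real code delegates to MegEngine's `_reduce`:

```python
import functools

import numpy as np


def _reduce_output(loss_fn):
    """Apply a 'none' | 'mean' | 'sum' reduction to a loss function's raw output."""

    @functools.wraps(loss_fn)
    def reduced_loss_fn(*args, reduction="mean", **kwargs):
        loss = loss_fn(*args, **kwargs)  # unreduced, elementwise loss values
        if reduction == "none":
            return loss
        elif reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        raise ValueError("{} is not a valid value for reduction".format(reduction))

    return reduced_loss_fn


@_reduce_output
def l1_loss(pred, label):
    return np.abs(pred - label)  # the body stays reduction-free


print(l1_loss(np.array([1.0, 3.0]), np.zeros(2)))                   # 2.0 (default "mean")
print(l1_loss(np.array([1.0, 3.0]), np.zeros(2), reduction="sum"))  # 4.0
```

Because `reduction` is consumed as a keyword-only argument of the wrapper, none of the loss bodies needed a signature change beyond documentation, and invalid strings fail in one place.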
imperative/python/test/unit/functional/test_loss.py @ b1bf193e
```diff
@@ -7,6 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np
+import pytest
 
 import megengine.functional as F
 from megengine import tensor
@@ -43,3 +44,38 @@ def test_cross_entropy():
     l_ref = ref(x, y)
     l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
     np.testing.assert_allclose(l.numpy(), l_ref)
+
+
+def test_cross_entropy_reduction():
+    logits = np.random.randn(16, 10)
+    label = np.random.randint(10, size=[16])
+    logits = tensor(logits, dtype="float32")
+    label = tensor(label, dtype="int32")
+
+    perm = np.random.permutation(16)
+    logits_perm = tensor(logits[perm], dtype="float32")
+    label_perm = tensor(label[perm], dtype="int32")
+
+    loss = F.nn.cross_entropy(logits, label, reduction="none")
+    loss_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
+    np.testing.assert_allclose(loss.numpy()[perm], loss_perm.numpy())
+
+    loss_sum = F.nn.cross_entropy(logits, label, reduction="sum")
+    np.testing.assert_allclose(loss.numpy().sum(), loss_sum.numpy(), rtol=2e-7)
+
+    loss_mean = F.nn.cross_entropy(logits, label, reduction="mean")
+    np.testing.assert_allclose(loss_mean.numpy(), loss_sum.numpy() / 16)
+
+    loss_ls = F.nn.cross_entropy(logits, label, reduction="mean", label_smooth=0.1)
+    loss_ls_none_reduce = F.nn.cross_entropy(
+        logits, label, reduction="none", label_smooth=0.1
+    )
+    np.testing.assert_allclose(
+        loss_ls.numpy(), loss_ls_none_reduce.numpy().mean(), rtol=2e-7
+    )
+
+    with pytest.raises(ValueError):
+        F.nn.cross_entropy(logits, label, reduction="MEAN")
+
+    with pytest.raises(ValueError):
+        F.nn.cross_entropy(logits, label, reduction="max")
```
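The assertions above lean on two identities that hold for any per-sample loss vector: `mean` equals `sum` divided by the batch size, and an unreduced (`"none"`) loss permutes along with its inputs while `sum` and `mean` are permutation-invariant. A quick NumPy check of the same identities, using stand-in values rather than real MegEngine calls:

```python
import numpy as np

# Stand-in for what a reduction="none" call returns: one loss per sample.
per_sample = np.abs(np.random.randn(16))
N = per_sample.size

# "mean" and "sum" differ only by the batch size.
assert np.isclose(per_sample.mean(), per_sample.sum() / N)

# Permuting the batch permutes the "none" output but leaves "sum"/"mean" unchanged.
perm = np.random.permutation(N)
assert np.isclose(per_sample[perm].sum(), per_sample.sum())
```

The test also pins down that reduction strings are case-sensitive and validated: both `"MEAN"` and `"max"` must raise `ValueError`.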