Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3c49d1d3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3c49d1d3
编写于
5月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/functional): add hinge loss
GitOrigin-RevId: 64c89c1f8c4e4ecbaf6892f9570c5d9db0027a1d
上级
dd8f3ffc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
71 addition
and
1 deletion
+71
-1
python_module/megengine/functional/__init__.py
python_module/megengine/functional/__init__.py
+1
-0
python_module/megengine/functional/loss.py
python_module/megengine/functional/loss.py
+44
-1
python_module/test/unit/functional/test_functional.py
python_module/test/unit/functional/test_functional.py
+26
-0
未找到文件。
python_module/megengine/functional/__init__.py
浏览文件 @
3c49d1d3
...
@@ -43,6 +43,7 @@ from .loss import (
...
@@ -43,6 +43,7 @@ from .loss import (
binary_cross_entropy
,
binary_cross_entropy
,
cross_entropy
,
cross_entropy
,
cross_entropy_with_softmax
,
cross_entropy_with_softmax
,
hinge_loss
,
l1_loss
,
l1_loss
,
nll_loss
,
nll_loss
,
square_loss
,
square_loss
,
...
...
python_module/megengine/functional/loss.py
浏览文件 @
3c49d1d3
...
@@ -9,8 +9,9 @@
...
@@ -9,8 +9,9 @@
import
megengine._internal
as
mgb
import
megengine._internal
as
mgb
from
..core.tensor
import
Tensor
from
..core.tensor
import
Tensor
from
.elemwise
import
abs
,
equal
,
log
,
maximum
,
power
from
.elemwise
import
abs
,
equal
,
log
,
maximum
,
power
,
relu
from
.nn
import
assert_equal
,
indexing_one_hot
from
.nn
import
assert_equal
,
indexing_one_hot
from
.tensor
import
where
from
.utils
import
zero_grad
from
.utils
import
zero_grad
...
@@ -297,3 +298,45 @@ def nll_loss(
...
@@ -297,3 +298,45 @@ def nll_loss(
loss
=
indexing_one_hot
(
pred
,
label
,
axis
)
*
mask
loss
=
indexing_one_hot
(
pred
,
label
,
axis
)
*
mask
return
-
1.0
*
loss
.
sum
()
/
maximum
(
mask
.
sum
(),
1.0
)
return
-
1.0
*
loss
.
sum
()
/
maximum
(
mask
.
sum
(),
1.0
)
def
hinge_loss
(
pred
:
Tensor
,
label
:
Tensor
,
norm
:
str
=
"L1"
)
->
Tensor
:
r
"""
Caculate the hinge loss which is often used in SVMs.
The hinge loss can be described as:
.. math:: loss(x, y) = \frac{1}{N}\sum_i\sum_j(max(0, 1 - x_i_j*y_i_j))
:param pred: The input tensor representing the predicted probability, shape is (N, C).
:param label: The input tensor representing the binary classification label, shape is (N, C).
:param norm: Specify the norm to caculate the loss, should be "L1" or "L2".
Examples:
.. testcode::
from megengine import tensor
import megengine.functional as F
pred = tensor([[0.5, -0.5, 0.1], [-0.6, 0.7, 0.8]])
label = tensor([[1, -1, -1], [-1, 1, 1]])
loss = F.hinge_loss(pred, label)
print(loss.numpy())
Outputs:
.. testoutput::
[1.5]
"""
assert
norm
in
[
"L1"
,
"L2"
],
"norm must be L1 or L2"
# Converts binary labels to -1/1 labels.
loss
=
relu
(
1.0
-
pred
*
label
)
if
norm
==
"L1"
:
return
loss
.
sum
(
axis
=
1
).
mean
()
else
:
return
(
loss
**
2
).
sum
(
axis
=
1
).
mean
()
python_module/test/unit/functional/test_functional.py
浏览文件 @
3c49d1d3
...
@@ -336,6 +336,32 @@ def test_binary_cross_entropy():
...
@@ -336,6 +336,32 @@ def test_binary_cross_entropy():
opr_test
(
cases
,
F
.
binary_cross_entropy
,
compare_fn
=
compare_fn
)
opr_test
(
cases
,
F
.
binary_cross_entropy
,
compare_fn
=
compare_fn
)
def
test_hinge_loss
():
np
.
random
.
seed
(
123
)
# case with L1 norm
cases
=
[]
for
shape
in
[(
2
,
2
),
(
2
,
3
)]:
data
=
np
.
random
.
uniform
(
size
=
shape
).
astype
(
np
.
float32
)
label
=
2
*
np
.
random
.
randint
(
0
,
1
,
size
=
shape
).
astype
(
np
.
int32
)
-
1
expect
=
np
.
clip
(
0
,
np
.
inf
,
1
-
data
*
label
).
sum
(
axis
=
1
).
mean
()
cases
.
append
({
"input"
:
[
data
,
label
],
"output"
:
tensor
(
expect
)})
opr_test
(
cases
,
F
.
hinge_loss
)
# cases with L2 norm
cases
=
[]
for
shape
in
[(
2
,
2
),
(
2
,
3
)]:
data
=
np
.
random
.
uniform
(
size
=
shape
).
astype
(
np
.
float32
)
label
=
2
*
np
.
random
.
randint
(
0
,
1
,
size
=
shape
).
astype
(
np
.
int32
)
-
1
expect
=
((
np
.
clip
(
0
,
np
.
inf
,
1
-
data
*
label
)
**
2
).
sum
(
axis
=
1
)).
mean
()
cases
.
append
({
"input"
:
[
data
,
label
],
"output"
:
tensor
(
expect
)})
def
hinge_loss_with_l2_norm
(
pred
,
label
):
return
F
.
hinge_loss
(
pred
,
label
,
"L2"
)
opr_test
(
cases
,
hinge_loss_with_l2_norm
)
@
pytest
.
mark
.
skip
@
pytest
.
mark
.
skip
def
test_conv_bias
():
def
test_conv_bias
():
inp_scale
=
0.01
inp_scale
=
0.01
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录