Commit cebda6ff
Authored Sep 24, 2021 by Megvii Engine Team

feat(mge/imperative): add ctc loss

GitOrigin-RevId: e29854a98e9d372c2802b073a03c0fc6f29f25ac
Parent: f5cb21ed

Showing 5 changed files with 307 additions and 2 deletions (+307, -2)
imperative/python/megengine/core/tensor/array_method.py (+1, -0)
imperative/python/megengine/functional/elemwise.py (+7, -0)
imperative/python/megengine/functional/loss.py (+164, -2)
imperative/python/test/unit/functional/test_elemwise.py (+10, -0)
imperative/python/test/unit/functional/test_loss.py (+125, -0)
imperative/python/megengine/core/tensor/array_method.py

```diff
@@ -61,6 +61,7 @@ def _elwise(*args, mode):
         _ElwMod.H_SWISH,
         _ElwMod.SIGMOID,
         _ElwMod.SIN,
+        _ElwMod.LOG_SUM_EXP,
     ) and (
         amp._enabled
         or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
     ):
```
imperative/python/megengine/functional/elemwise.py

```diff
@@ -48,6 +48,7 @@ __all__ = [
     "logical_not",
     "logical_or",
     "logical_xor",
+    "logaddexp",
     "maximum",
     "minimum",
     "mod",
@@ -406,6 +407,12 @@ def logical_xor(x, y):
     return _elwise(x, y, mode=Elemwise.Mode.XOR)
 
 
+def logaddexp(x: Tensor, y: Tensor) -> Tensor:
+    r"""Element-wise numerically stable ``log(exp(x) + exp(y))``.
+    """
+    return _elwise(x, y, mode=Elemwise.Mode.LOG_SUM_EXP)
+
+
 # comparison functions
```
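The new `logaddexp` exists because the naive `log(exp(x) + exp(y))` overflows as soon as either argument is moderately large. A minimal NumPy sketch of the stable form (my own illustration, not part of the commit) shows the idea:

```python
import numpy as np

def logaddexp_ref(x, y):
    # Stable log(exp(x) + exp(y)): factor out the larger argument so the
    # remaining exponent is <= 0 and cannot overflow.
    hi, lo = np.maximum(x, y), np.minimum(x, y)
    return hi + np.log1p(np.exp(lo - hi))

# The naive form overflows (exp(1000) -> inf); the stable form does not.
print(logaddexp_ref(1000.0, 1000.0))  # ~1000.6931 (= 1000 + log 2)
```

NumPy's own `np.logaddexp` implements the same identity and agrees with this sketch on finite inputs.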
imperative/python/megengine/functional/loss.py

```diff
@@ -12,9 +12,9 @@ import numpy as np
 from ..core.tensor.array_method import _reduce
 from ..tensor import Tensor
-from .elemwise import abs, log
+from .elemwise import abs, equal, log, logaddexp, maximum
 from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
-from .tensor import where
+from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros
 
 __all__ = [
     "l1_loss",
@@ -22,6 +22,7 @@ __all__ = [
     "cross_entropy",
     "binary_cross_entropy",
     "hinge_loss",
+    "ctc_loss",
 ]
```
@@ -316,3 +317,164 @@ def hinge_loss(

```python
        return loss.sum(axis=1)
    else:
        return (loss ** 2).sum(axis=1)


def _gen_repeat_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = zeros(inp.sum(), dtype="int32")
    ret[idx[:-1]] = 1
    return cumsum(ret, axis=0)


def _gen_tile_idx(inp: Tensor):
    idx = cumsum(inp, axis=0)
    ret = ones(inp.sum(), dtype="int32")
    ret[idx[:-1]] = -(inp - 1)[:-1]
    return cumsum(ret, axis=0) - 1


def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
    N = label_lengths.shape[0]
    if len(label.shape) == 1:
        L = label_lengths.max()
        unpack_label = zeros((N, L), dtype="int32") + blank
        idx_0 = _gen_repeat_idx(label_lengths)
        idx_1 = _gen_tile_idx(label_lengths)
        unpack_label[idx_0, idx_1] = label
        label = unpack_label
    L = label.shape[1]
    ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
    ex_label[:, 1::2] = label
    return ex_label


def _safelog(x: Tensor) -> Tensor:
    eps = np.finfo(x.dtype).tiny
    return log(maximum(x, eps))
```
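The index-generation trick in these helpers is compact but opaque. A NumPy walk-through (my own illustration, not part of the commit) for a packed label with `label_lengths = [2, 1]` shows what each helper produces: row indices, column indices, and finally the blank-interleaved label matrix.

```python
import numpy as np

def gen_repeat_idx(lengths):
    # Row index for each element of the packed label vector.
    idx = np.cumsum(lengths)
    ret = np.zeros(lengths.sum(), dtype=np.int32)
    ret[idx[:-1]] = 1
    return np.cumsum(ret)

def gen_tile_idx(lengths):
    # Column index for each element of the packed label vector.
    idx = np.cumsum(lengths)
    ret = np.ones(lengths.sum(), dtype=np.int32)
    ret[idx[:-1]] = -(lengths - 1)[:-1]
    return np.cumsum(ret) - 1

def expand_label(label, label_lengths, blank):
    # Unpack the concatenated label into an (N, L) matrix, then interleave
    # blanks: [a, b] -> [blank, a, blank, b, blank].
    N = label_lengths.shape[0]
    L = label_lengths.max()
    unpack = np.full((N, L), blank, dtype=np.int32)
    unpack[gen_repeat_idx(label_lengths), gen_tile_idx(label_lengths)] = label
    ex = np.full((N, 2 * L + 1), blank, dtype=np.int32)
    ex[:, 1::2] = unpack
    return ex

lengths = np.array([2, 1])
label = np.array([3, 5, 7])          # rows [3, 5] and [7], packed together
print(gen_repeat_idx(lengths))       # [0 0 1]
print(gen_tile_idx(lengths))         # [0 1 0]
print(expand_label(label, lengths, 0))
# [[0 3 0 5 0]
#  [0 7 0 0 0]]
```

Short rows are padded with blanks, so `label_lengths` is still needed downstream to know where each expanded label really ends.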
```python
def ctc_loss(
    pred: Tensor,
    pred_lengths: Tensor,
    label: Tensor,
    label_lengths: Tensor,
    blank: int = 0,
    reduction: str = "mean",
) -> Tensor:
    r"""The Connectionist Temporal Classification loss.

    Args:
        pred: the probabilities of the output, shape is (T, N, C),
            where T=input length, N=batch size, and C=number of classes (including blank).
        pred_lengths: number of time steps for each sequence in ``pred``, shape is (N,).
        label: groundtruth labels, containing the indices of groundtruth
            symbols for each sequence at each output time step; the blank
            symbol should not be included. Shape is (N, S) or (sum(label_lengths),).
        label_lengths: number of time steps for each sequence in the groundtruth, shape is (N,).
        blank: the blank symbol number, default 0.
        reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.

    Returns:
        loss value.

    Examples:

        .. testcode::

            from megengine import tensor
            import megengine.functional as F

            pred = tensor([[[0.0614, 0.9386], [0.8812, 0.1188]], [[0.699, 0.301], [0.2572, 0.7428]]])
            pred_length = tensor([2, 2])
            label = tensor([1, 1])
            label_lengths = tensor([1, 1])
            loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths)
            print(loss.numpy())

        Outputs:

        .. testoutput::

            0.1504417
    """
    T, N, C = pred.shape
    assert (
        pred_lengths.size == N
    ), "pred_lengths must be equal to batch_size {}, but got {}".format(
        N, pred_lengths.size
    )
    assert (
        label_lengths.size == N
    ), "label_lengths must be equal to batch_size {}, but got {}".format(
        N, label_lengths.size
    )
    assert (
        blank >= 0 and blank < C
    ), "blank must be in label range [0, {}), but got {}".format(C, blank)
    assert (
        pred_lengths.min() > 0 and pred_lengths.max() <= T
    ), "pred_lengths must be in range ({}, {}], but got min {}, max {}".format(
        0, T, pred_lengths.min(), pred_lengths.max()
    )
    if label.ndim == 1:  # concatenated label
        assert label_lengths.min() > 0, "label lengths must be positive"
        assert (
            label.size == label_lengths.sum()
        ), "label size must be equal to sum(label_lengths)"
    else:
        N, S = label.shape
        assert (
            label_lengths.min() > 0 and label_lengths.max() <= S
        ), "label_lengths must be in range ({}, {}], but got min {}, max {}".format(
            0, S, label_lengths.min(), label_lengths.max()
        )

    label = _expand_label(label, label_lengths, blank)
    label_mask = label[:, 2:] != label[:, :-2]
    L = label.shape[1]

    pred = pred.transpose(1, 0, 2)  # (T, N, C) -> (N, T, C)
    batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
    batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
    match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
        N, L, -1
    )  # (N, T, C) -> (N, L, T)

    log_alpha = zeros((N, L), dtype="float32")
    log_alpha[:, :2] = match_pred[:, :2, 0]
    log_alpha = _safelog(log_alpha)

    ret = -logaddexp(
        log_alpha[batch_idx, label_lengths * 2],
        log_alpha[batch_idx, label_lengths * 2 - 1],
    ) * equal(pred_lengths - 1, 0)
    for t in range(1, T):
        la2 = log_alpha[:, :-2]
        log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
        log_alpha[:, 2:] = (
            log_alpha[:, 2:] * (1 - label_mask)
            + logaddexp(log_alpha[:, 2:], la2) * label_mask
        )
        log_alpha += _safelog(match_pred[:, :, t])

        ret_t = -logaddexp(
            log_alpha[batch_idx, label_lengths * 2],
            log_alpha[batch_idx, label_lengths * 2 - 1],
        )
        ret += ret_t * equal(pred_lengths - 1, t)

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
```
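The docstring example is small enough to verify by hand: with a one-symbol label and T = 2, the only valid CTC alignments are (1, 1), (blank, 1) and (1, blank), so the loss can be brute-forced (my own illustration, not part of the commit):

```python
import numpy as np

# Docstring example: pred has shape (T, N, C) with T=2, N=2, C=2,
# blank = 0 and both labels equal to "1".
pred = np.array([[[0.0614, 0.9386], [0.8812, 0.1188]],
                 [[0.699, 0.301], [0.2572, 0.7428]]])

losses = []
for n in range(2):
    p0, p1 = pred[0, n], pred[1, n]  # class probabilities at t=0 and t=1
    # P(label "1") = P(1,1) + P(blank,1) + P(1,blank)
    total = p0[1] * p1[1] + p0[0] * p1[1] + p0[1] * p1[0]
    losses.append(-np.log(total))

# reduction="mean" divides each loss by its label length (1 here) and averages.
print(np.mean(losses))  # ~0.15044, matching the docstring output 0.1504417
```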
imperative/python/test/unit/functional/test_elemwise.py

@@ -170,6 +170,16 @@ def test_logical_oprs():

```python
    np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())


def test_logaddexp():
    x = np.random.randn(2, 100)
    y = np.random.randn(2, 100)
    xx = tensor(x)
    yy = tensor(y)
    out_np = np.log(np.exp(x) + np.exp(y))
    out_mge = F.logaddexp(xx, yy)
    np.testing.assert_almost_equal(out_np, out_mge.numpy(), decimal=6)


def test_qadd():
    inp_scale = 0.5
    outp_scale = 0.2
```
imperative/python/test/unit/functional/test_loss.py

@@ -79,3 +79,128 @@ def test_cross_entropy_reduction():

```python
    with pytest.raises(ValueError):
        F.nn.cross_entropy(logits, label, reduction="max")


def ctc_nll_naive_npy(
    pred,
    pred_lengths,
    label,
    label_lengths,
    blank=0,
    reduction="mean",
    time_major=False,
):
    """naive :func:`ctc_nll` using numpy arrays. Used for testing and helping
    our users understand how CTC works. Only ``LABEL_COMPACT`` mode is
    supported."""
    pred = np.asarray(pred, dtype=np.float32)
    pred_lengths = np.asarray(pred_lengths, dtype=np.int8)
    label = np.asarray(label, dtype=np.int32)
    label_lengths = np.asarray(label_lengths, dtype=np.int32)
    if time_major:
        pred = np.transpose(pred, (1, 0, 2))
    # pred in (N, T, P) format
    batch_size, time_len, nr_class = pred.shape
    assert pred_lengths.shape == (batch_size,) and pred_lengths.max() <= pred.shape[1]
    assert label_lengths.shape == (batch_size,)
    assert label.shape == (label_lengths.sum(),) and label.max() < nr_class

    ret = np.empty((batch_size,), dtype=np.float32)
    label_start = 0
    for i in range(batch_size):
        label_end = label_start + label_lengths[i]
        ret[i] = _ctc_npy_single_seq(
            pred[i][: pred_lengths[i]], label[label_start:label_end], blank
        )
        label_start = label_end

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))


def _ctc_npy_single_seq(pred, label, blank):
    def safelog(x):
        eps = np.finfo(x.dtype).tiny
        return np.log(np.maximum(x, eps))

    def log_sum_exp(x, y):
        x, y = np.maximum(x, y), np.minimum(x, y)
        return x + np.log1p(np.exp(y - x))

    assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
    len_pred, alphabet_size = pred.shape
    (len_label,) = label.shape
    len_ex_label = len_label * 2 + 1
    ex_label = (np.zeros(len_ex_label)).astype(np.int32) + blank
    ex_label[1::2] = label

    prob = np.zeros(len_ex_label, dtype=np.float32)
    prob[0] = pred[0][ex_label[0]]
    prob[1] = pred[0][ex_label[1]]
    prob = safelog(prob)  # compute on log scale

    ex_label_pmask = ex_label[2:] != ex_label[:-2]
    for t in range(1, len_pred):
        # enter loop: prob[i] = log(p(pred[:t+1], label[:i+1]))
        new_prob = prob.copy()
        new_prob[1:] = log_sum_exp(new_prob[1:], prob[:-1])
        new_prob[2:] = (
            new_prob[2:] * (1 - ex_label_pmask)
            + log_sum_exp(new_prob[2:], prob[:-2]) * ex_label_pmask
        )
        new_prob += safelog(pred[t, ex_label])
        prob = new_prob

    return -log_sum_exp(prob[-1], prob[-2])


def test_ctc_loss():
    def test_func(T, C, N):
        input = np.random.randn(T, N, C)
        input = F.softmax(tensor(input), axis=-1).numpy()
        input_lengths = np.ones(N, dtype=np.int32) * T
        target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
        target = np.random.randint(
            low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
        )
        input_mge = tensor(input)
        input_lengths_mge = tensor(input_lengths)
        target_mge = tensor(target)
        target_lengths_mge = tensor(target_lengths)
        blank = np.random.randint(C)
        for method in ["mean", "sum", "none"]:
            np_out = ctc_nll_naive_npy(
                input,
                input_lengths,
                target,
                target_lengths,
                blank=blank,
                reduction=method,
                time_major=True,
            )
            mge_out = F.nn.ctc_loss(
                input_mge,
                input_lengths_mge,
                target_mge,
                target_lengths_mge,
                blank=blank,
                reduction=method,
            )
            np.testing.assert_allclose(mge_out.numpy(), np_out, rtol=2e-6)

    cases = [[1, 2, 1], [100, 50, 200], [100, 5, 1]]
    for case in cases:
        test_func(*case)
```
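For tiny shapes, the forward dynamic program in the naive reference can itself be cross-checked against exhaustive enumeration: sum the probability of every length-T path whose CTC collapse (merge repeats, then drop blanks) equals the label. A sketch (my own illustration, not part of the commit):

```python
import itertools
import numpy as np

def collapse(path, blank):
    # Standard CTC collapse: merge consecutive repeats, then drop blanks.
    out, prev = [], None
    for c in path:
        if c != prev:
            out.append(c)
        prev = c
    return [c for c in out if c != blank]

def ctc_brute(pred, label, blank=0):
    # Sum the probability of every length-T path that collapses to `label`,
    # then return the negative log-likelihood.
    T, C = pred.shape
    total = 0.0
    for path in itertools.product(range(C), repeat=T):
        if collapse(path, blank) == list(label):
            total += np.prod([pred[t, c] for t, c in enumerate(path)])
    return -np.log(total)

# First sequence of the ctc_loss docstring example: T=2, C=2, label "1", blank 0.
pred = np.array([[0.0614, 0.9386], [0.699, 0.301]])
print(ctc_brute(pred, [1]))  # ~0.0439
```

This enumerates C^T paths, so it is only usable for very small T and C, but it agrees with the DP wherever both are feasible.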