提交 cebda6ff 编写于 作者: M Megvii Engine Team

feat(mge/imperative): add ctc loss

GitOrigin-RevId: e29854a98e9d372c2802b073a03c0fc6f29f25ac
上级 f5cb21ed
......@@ -61,6 +61,7 @@ def _elwise(*args, mode):
_ElwMod.H_SWISH,
_ElwMod.SIGMOID,
_ElwMod.SIN,
_ElwMod.LOG_SUM_EXP,
) and (
amp._enabled or np.all([np.issubdtype(arg.dtype, np.integer) for arg in args])
):
......
......@@ -48,6 +48,7 @@ __all__ = [
"logical_not",
"logical_or",
"logical_xor",
"logaddexp",
"maximum",
"minimum",
"mod",
......@@ -406,6 +407,12 @@ def logical_xor(x, y):
return _elwise(x, y, mode=Elemwise.Mode.XOR)
def logaddexp(x: Tensor, y: Tensor) -> Tensor:
r"""Element-wise `numerically stable log(exp(x) + exp(y)`
"""
return _elwise(x, y, mode=Elemwise.Mode.LOG_SUM_EXP)
# comparison functions
......
......@@ -12,9 +12,9 @@ import numpy as np
from ..core.tensor.array_method import _reduce
from ..tensor import Tensor
from .elemwise import abs, log
from .elemwise import abs, equal, log, logaddexp, maximum
from .nn import indexing_one_hot, logsigmoid, logsumexp, relu
from .tensor import where
from .tensor import broadcast_to, cumsum, linspace, ones, where, zeros
__all__ = [
"l1_loss",
......@@ -22,6 +22,7 @@ __all__ = [
"cross_entropy",
"binary_cross_entropy",
"hinge_loss",
"ctc_loss",
]
......@@ -316,3 +317,164 @@ def hinge_loss(
return loss.sum(axis=1)
else:
return (loss ** 2).sum(axis=1)
def _gen_repeat_idx(inp: Tensor):
idx = cumsum(inp, axis=0)
ret = zeros(inp.sum(), dtype="int32")
ret[idx[:-1]] = 1
return cumsum(ret, axis=0)
def _gen_tile_idx(inp: Tensor):
idx = cumsum(inp, axis=0)
ret = ones(inp.sum(), dtype="int32")
ret[idx[:-1]] = -(inp - 1)[:-1]
return cumsum(ret, axis=0) - 1
def _expand_label(label: Tensor, label_lengths: Tensor, blank: int) -> Tensor:
N = label_lengths.shape[0]
if len(label.shape) == 1:
L = label_lengths.max()
unpack_label = zeros((N, L), dtype="int32") + blank
idx_0 = _gen_repeat_idx(label_lengths)
idx_1 = _gen_tile_idx(label_lengths)
unpack_label[idx_0, idx_1] = label
label = unpack_label
L = label.shape[1]
ex_label = zeros((N, L * 2 + 1), dtype="int32") + blank
ex_label[:, 1::2] = label
return ex_label
def _safelog(x: Tensor) -> Tensor:
eps = np.finfo(x.dtype).tiny
return log(maximum(x, eps))
def ctc_loss(
pred: Tensor,
pred_lengths: Tensor,
label: Tensor,
label_lengths: Tensor,
blank: int = 0,
reduction: str = "mean",
) -> Tensor:
r"""The Connectionist Temporal Classification loss.
Args:
pred: The probabilities of the output, shape is (T, N, C) ,
where T=input length, N=batch size, and C=number of classes (including blank).
pred_lengths: number of time steps for each sequence in ``pred``, shape is (N, )
label: groundtruth labels, containing the indices of groundtruth
symbols for each sequence at each output time step, and the blank
symbol should not be included. shape is (N, S) or (sum(label_lengths)).
label_lengths: number of time steps for each sequence in the groundtruth, shape is (N, )
blank: the blank symbol number, default 0
reduction: the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'
Returns:
loss value.
Examples:
.. testcode::
from megengine import tensor
import megengine.functional as F
pred = tensor([[[0.0614, 0.9386],[0.8812, 0.1188]],[[0.699, 0.301 ],[0.2572, 0.7428]]])
pred_length = tensor([2,2])
label = tensor([1,1])
label_lengths = tensor([1,1])
loss = F.nn.ctc_loss(pred, pred_length, label, label_lengths)
print(loss.numpy())
Outputs:
.. testoutput::
0.1504417
"""
T, N, C = pred.shape
assert (
pred_lengths.size == N
), "pred_lengths must be equal to batch_size {}, but got {}".format(
N, pred_lengths.size
)
assert (
label_lengths.size == N
), "label_lengths must be euqal to batch_size {}, but got {}".format(
N, label_lengths.size
)
assert (
blank >= 0 and blank < C
), "blank must be in label range [0, {}), but got {}".format(C, blank)
assert (
pred_lengths.min() > 0 and pred_lengths.max() <= T
), "pred_lengths must be in range ({}, {}], bug got min {}, max {}".format(
0, T, pred_lengths.min(), pred_lengths.max()
)
if label.ndim == 1: # concatenated label
assert label_lengths.min() > 0, "label lengths muse be positive"
assert (
label.size == label_lengths.sum()
), "label size must be equal to sum(label_lengths)"
else:
N, S = label.shape
assert (
label_lengths.min() > 0 and label_lengths.max() <= S
), "label_lengths must be in range ({}, {}], bug got min {}, max {}".format(
0, S, label_lengths.min(), label_lengths.max()
)
label = _expand_label(label, label_lengths, blank)
label_mask = label[:, 2:] != label[:, :-2]
L = label.shape[1]
pred = pred.transpose(1, 0, 2) # (T, N, C) -> (N, T, C)
batch_idx = linspace(0, N - 1, N).astype("int32").reshape(-1)
batch_idx_NL = broadcast_to(batch_idx.reshape(N, 1), (N, L)).reshape(-1)
match_pred = pred[batch_idx_NL, :, label.reshape(-1)].reshape(
N, L, -1
) # (N, T, C) -> (N, L, T)
log_alpha = zeros((N, L), dtype="float32")
log_alpha[:, :2] = match_pred[:, :2, 0]
log_alpha = _safelog(log_alpha)
ret = -logaddexp(
log_alpha[batch_idx, label_lengths * 2],
log_alpha[batch_idx, label_lengths * 2 - 1],
) * equal(pred_lengths - 1, 0)
for t in range(1, T):
la2 = log_alpha[:, :-2]
log_alpha[:, 1:] = logaddexp(log_alpha[:, 1:], log_alpha[:, :-1])
log_alpha[:, 2:] = (
log_alpha[:, 2:] * (1 - label_mask)
+ logaddexp(log_alpha[:, 2:], la2) * label_mask
)
log_alpha += _safelog(match_pred[:, :, t])
ret_t = -logaddexp(
log_alpha[batch_idx, label_lengths * 2],
log_alpha[batch_idx, label_lengths * 2 - 1],
)
ret += ret_t * equal(pred_lengths - 1, t)
if reduction == "mean":
return (ret / label_lengths).mean()
elif reduction == "sum":
return ret.sum()
elif reduction == "none":
return ret
else:
raise ValueError("{} is not a valid value for reduction".format(reduction))
......@@ -170,6 +170,16 @@ def test_logical_oprs():
np.testing.assert_equal(x ^ y, F.logical_xor(xx, yy).numpy())
def test_logaddexp():
x = np.random.randn(2, 100)
y = np.random.randn(2, 100)
xx = tensor(x)
yy = tensor(y)
out_np = np.log(np.exp(x) + np.exp(y))
out_mge = F.logaddexp(xx, yy)
np.testing.assert_almost_equal(out_np, out_mge.numpy(), decimal=6)
def test_qadd():
inp_scale = 0.5
outp_scale = 0.2
......
......@@ -79,3 +79,128 @@ def test_cross_entropy_reduction():
with pytest.raises(ValueError):
F.nn.cross_entropy(logits, label, reduction="max")
def ctc_nll_naive_npy(
pred,
pred_lengths,
label,
label_lengths,
blank=0,
reduction="mean",
time_major=False,
):
"""naive :func:`ctc_nll` using numpy arrays. Used for testing and helping
our user to understand how CTC works. Only ``LABEL_COMPACT`` mode is
supported."""
pred = np.asarray(pred, dtype=np.float32)
pred_lengths = np.asarray(pred_lengths, dtype=np.int8)
label = np.asarray(label, dtype=np.int32)
label_lengths = np.asarray(label_lengths, dtype=np.int32)
if time_major:
pred = np.transpose(pred, (1, 0, 2))
# pred in (N, T, P) format
batch_size, time_len, nr_class = pred.shape
assert pred_lengths.shape == (batch_size,) and pred_lengths.max() <= pred.shape[1]
assert label_lengths.shape == (batch_size,)
assert label.shape == (label_lengths.sum(),) and label.max() < nr_class
ret = np.empty((batch_size,), dtype=np.float32)
label_start = 0
for i in range(batch_size):
label_end = label_start + label_lengths[i]
ret[i] = _ctc_npy_single_seq(
pred[i][: pred_lengths[i]], label[label_start:label_end], blank
)
label_start = label_end
if reduction == "mean":
return (ret / label_lengths).mean()
elif reduction == "sum":
return ret.sum()
elif reduction == "none":
return ret
else:
raise ValueError("{} is not a valid value for reduction".format(reduction))
def _ctc_npy_single_seq(pred, label, blank):
def safelog(x):
eps = np.finfo(x.dtype).tiny
return np.log(np.maximum(x, eps))
def log_sum_exp(x, y):
x, y = np.maximum(x, y), np.minimum(x, y)
return x + np.log1p(np.exp(y - x))
assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
len_pred, alphabet_size = pred.shape
(len_label,) = label.shape
len_ex_label = len_label * 2 + 1
ex_label = (np.zeros(len_ex_label)).astype(np.int32) + blank
ex_label[1::2] = label
prob = np.zeros(len_ex_label, dtype=np.float32)
prob[0] = pred[0][ex_label[0]]
prob[1] = pred[0][ex_label[1]]
prob = safelog(prob) # compute on log scale
ex_label_pmask = ex_label[2:] != ex_label[:-2]
for t in range(1, len_pred):
# enter loop: prob[i] = log(p(pred[:t+1], label[:i+1]))
new_prob = prob.copy()
new_prob[1:] = log_sum_exp(new_prob[1:], prob[:-1])
new_prob[2:] = (
new_prob[2:] * (1 - ex_label_pmask)
+ log_sum_exp(new_prob[2:], prob[:-2]) * ex_label_pmask
)
new_prob += safelog(pred[t, ex_label])
prob = new_prob
return -log_sum_exp(prob[-1], prob[-2])
def test_ctc_loss():
def test_func(T, C, N):
input = np.random.randn(T, N, C)
input = F.softmax(tensor(input), axis=-1).numpy()
input_lengths = np.ones(N, dtype=np.int32) * T
target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)
target = np.random.randint(
low=1, high=C, size=(sum(target_lengths)), dtype=np.int32
)
input_mge = tensor(input)
input_lengths_mge = tensor(input_lengths)
target_mge = tensor(target)
target_lengths_mge = tensor(target_lengths)
blank = np.random.randint(C)
for method in ["mean", "sum", "none"]:
np_out = ctc_nll_naive_npy(
input,
input_lengths,
target,
target_lengths,
blank=blank,
reduction=method,
time_major=True,
)
mge_out = F.nn.ctc_loss(
input_mge,
input_lengths_mge,
target_mge,
target_lengths_mge,
blank=blank,
reduction=method,
)
np.testing.assert_allclose(mge_out.numpy(), np_out, rtol=2e-6)
cases = [[1, 2, 1], [100, 50, 200], [100, 5, 1]]
for case in cases:
test_func(*case)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册