未验证 提交 dbf232a9 编写于 作者: L Li Fuchen 提交者: GitHub

add functional ctc_loss and CTCLoss class. (#26384)

* add functional ctc_loss and CTCLoss class.

* modified docstring of ctc_loss and CTCLoss
上级 b6eb37f5
......@@ -21,25 +21,25 @@ from op_test import OpTest
from test_softmax_op import stable_softmax
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
import paddle.nn.functional as F
CUDA_BLOCK_SIZE = 512
class CTCForward(object):
def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
norm_by_times):
def __init__(self, softmax, softmax_lod, labels, labels_lod, num_classes,
batch_size, blank, norm_by_times):
self.softmax = softmax
self.softmax_lod = softmax_lod
assert labels.shape[1] == 1
self.labels = labels
self.labels_lod = labels_lod
self.blank = blank
self.norm_by_times = norm_by_times
self.level = 0
self.num_classes = softmax.shape[1]
self.batch_size = len(softmax_lod[self.level])
assert self.batch_size == len(labels_lod[self.level])
self.num_classes = num_classes
self.batch_size = batch_size
self.loss = np.zeros([self.batch_size, 1], dtype="float32")
self.gradient = np.zeros(self.softmax.shape, dtype="float32")
......@@ -163,17 +163,25 @@ class CTCForward(object):
softmax_offset = 0
labels_offset = 0
for i in range(self.batch_size):
softmax_start_i = softmax_offset
softmax_end_i = softmax_offset + self.softmax_lod[self.level][i]
labels_start_i = labels_offset
labels_end_i = labels_offset + self.labels_lod[self.level][i]
softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :]
labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
softmax_offset += self.softmax_lod[self.level][i]
labels_offset += self.labels_lod[self.level][i]
if self.labels.shape[1] == 1:
softmax_start_i = softmax_offset
softmax_end_i = softmax_offset + self.softmax_lod[self.level][i]
labels_start_i = labels_offset
labels_end_i = labels_offset + self.labels_lod[self.level][i]
softmax_a_sequence = self.softmax[softmax_start_i:
softmax_end_i, :]
labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
softmax_offset += self.softmax_lod[self.level][i]
labels_offset += self.labels_lod[self.level][i]
else:
softmax_a_sequence = self.softmax[:self.softmax_lod[i], i, :]
labels_a_sequence = self.labels[:self.labels_lod[i], :]
self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
labels_a_sequence)
return self.loss
......@@ -201,7 +209,8 @@ class TestWarpCTCOp(OpTest):
dtype="int32")
ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
self.num_classes, self.batch_size, self.blank,
self.norm_by_times)
loss = ctc.forward()
max_sequence_length = 0
......@@ -223,7 +232,7 @@ class TestWarpCTCOp(OpTest):
}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
......@@ -237,7 +246,7 @@ class TestWarpCTCOpCase1(TestWarpCTCOp):
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.blank = 0
self.blank = self.num_classes - 1
self.norm_by_times = False
......@@ -267,7 +276,8 @@ class TestWarpCTCOpWithPadding(OpTest):
dtype="int32")
ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
self.num_classes, self.batch_size, self.blank,
self.norm_by_times)
loss = ctc.forward()
max_sequence_length = 0
......@@ -317,7 +327,7 @@ class TestWarpCTCOpWithPadding(OpTest):
}
def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output()
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
......@@ -333,7 +343,7 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
self.labels_lod = [[3, 1, 4, 4]]
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
self.blank = 0
self.blank = self.num_classes - 1
self.norm_by_times = False
......@@ -389,5 +399,97 @@ class TestWarpCTCOpError(unittest.TestCase):
self.assertRaises(TypeError, test_label_len_Variable)
class TestCTCLossAPICase(unittest.TestCase):
    """Compare ``F.ctc_loss`` / ``paddle.nn.CTCLoss`` with the NumPy ``CTCForward`` reference.

    Each test stores its configuration on ``self`` (batch size, class count,
    sequence lengths, blank index), builds random padded inputs, computes the
    reference loss with ``CTCForward`` and checks the Paddle result against it.
    """

    def _numpy_reference(self, label_low, label_high):
        """Draw random padded inputs and compute the reference loss.

        Reads ``batch_size``, ``num_classes``, ``logits_length``,
        ``labels_length``, ``blank`` and ``norm_by_times`` from ``self``.

        Args:
            label_low (int): Inclusive lower bound for the random label ids.
            label_high (int): Exclusive upper bound for the random label ids.
                The caller picks the range so that it excludes ``self.blank``.

        Returns:
            tuple: ``(logits, labels, loss_np)`` where ``logits`` is float32
            with shape [max(logits_length), batch_size, num_classes],
            ``labels`` is int32 with shape [batch_size, max(labels_length)]
            and ``loss_np`` is the reference loss with its trailing axis
            squeezed away.
        """
        logits = np.random.uniform(0.1, 1.0, [
            max(self.logits_length), self.batch_size, self.num_classes
        ]).astype("float32")
        # The reference implementation consumes probabilities, while the
        # Paddle API under test is fed the raw logits.
        softmax = np.apply_along_axis(stable_softmax, -1, logits)
        labels = np.random.randint(
            label_low,
            label_high, [self.batch_size, max(self.labels_length)],
            dtype="int32")

        ctc = CTCForward(softmax, self.logits_length, labels,
                         self.labels_length, self.num_classes, self.batch_size,
                         self.blank, self.norm_by_times)
        loss_np = np.squeeze(ctc.forward(), axis=-1)
        return logits, labels, loss_np

    # NOTE: the method name (typo included) is kept unchanged so that existing
    # test selectors keep working.
    def test_functinal_api(self):
        """``F.ctc_loss`` with 'mean' and 'sum' reduction; blank is the last class."""
        self.batch_size = 4
        self.num_classes = CUDA_BLOCK_SIZE + 2
        self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
        self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
        self.blank = self.num_classes - 1
        self.norm_by_times = False

        # labels should not be blank: blank is num_classes - 1, so draw from
        # [0, num_classes - 1).
        logits, labels, loss_np = self._numpy_reference(0,
                                                        self.num_classes - 1)

        paddle.disable_static()
        try:
            logits_pd = paddle.to_variable(logits)
            labels_pd = paddle.to_variable(labels)
            logits_length = paddle.to_variable(self.logits_length)
            labels_length = paddle.to_variable(self.labels_length)
            loss_pd_mean = F.ctc_loss(
                logits_pd,
                labels_pd,
                logits_length,
                labels_length,
                blank=self.blank,
                reduction='mean').numpy()
            loss_pd_sum = F.ctc_loss(
                logits_pd,
                labels_pd,
                logits_length,
                labels_length,
                blank=self.blank,
                reduction='sum').numpy()
        finally:
            # Always restore static mode so that a failure above cannot leak
            # dygraph mode into the tests that run afterwards.
            paddle.enable_static()

        # 'mean' divides each sequence loss by its label length first.
        loss_np_mean = (loss_np / self.labels_length).mean()
        loss_np_sum = loss_np.sum()

        # NOTE(review): atol=1 is a very loose tolerance -- TODO confirm
        # whether it can be tightened.
        self.assertTrue(np.allclose(loss_pd_mean, loss_np_mean, atol=1))
        self.assertTrue(np.allclose(loss_pd_sum, loss_np_sum, atol=1))

    def test_class_api(self):
        """``paddle.nn.CTCLoss`` with 'none' reduction; blank is class 0.

        ``labels_length`` includes a zero-length label sequence.
        """
        self.batch_size = 3
        self.num_classes = 15
        self.logits_length = np.array([3, 3, 3], dtype=np.int64)
        self.labels_length = np.array([0, 1, 2], dtype=np.int64)
        self.blank = 0
        self.norm_by_times = False

        # labels should not be blank: blank is 0, so draw from
        # [1, num_classes).
        logits, labels, loss_np = self._numpy_reference(1, self.num_classes)

        paddle.disable_static()
        try:
            logits_pd = paddle.to_variable(logits)
            labels_pd = paddle.to_variable(labels)
            logits_length = paddle.to_variable(self.logits_length)
            labels_length = paddle.to_variable(self.labels_length)
            loss_pd = paddle.nn.CTCLoss(self.blank, 'none')(
                logits_pd, labels_pd, logits_length, labels_length).numpy()
        finally:
            # Always restore static mode, even if the dygraph section fails.
            paddle.enable_static()

        # NOTE(review): atol=1 is a very loose tolerance -- TODO confirm
        # whether it can be tightened.
        self.assertTrue(np.allclose(loss_pd, loss_np, atol=1))
if __name__ == "__main__":
unittest.main()
......@@ -111,6 +111,7 @@ from .layer.loss import NLLLoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.loss import KLDivLoss #DEFINE_ALIAS
from .layer.loss import MarginRankingLoss #DEFINE_ALIAS
from .layer.loss import CTCLoss #DEFINE_ALIAS
from .layer.loss import SmoothL1Loss #DEFINE_ALIAS
from .layer.norm import BatchNorm #DEFINE_ALIAS
from .layer.norm import SyncBatchNorm #DEFINE_ALIAS
......
......@@ -25,6 +25,8 @@ from . import extension
__all__ += extension.__all__
from . import common
__all__ += common.__all__
from . import loss
__all__ += loss.__all__
from .activation import brelu #DEFINE_ALIAS
from .activation import elu #DEFINE_ALIAS
from .activation import erf #DEFINE_ALIAS
......@@ -147,6 +149,7 @@ from .loss import softmax_with_cross_entropy #DEFINE_ALIAS
from .loss import square_error_cost #DEFINE_ALIAS
from .loss import ssd_loss #DEFINE_ALIAS
from .loss import teacher_student_sigmoid_loss #DEFINE_ALIAS
from .loss import ctc_loss #DEFINE_ALIAS
# from .norm import batch_norm #DEFINE_ALIAS
# from .norm import data_norm #DEFINE_ALIAS
# from .norm import group_norm #DEFINE_ALIAS
......
......@@ -13,6 +13,9 @@
# limitations under the License.
import paddle
from ...fluid.layer_helper import LayerHelper
from ...fluid.data_feeder import check_variable_and_dtype
import paddle.fluid as fluid
# TODO: define loss functions of neural network
import numpy as np
......@@ -70,7 +73,8 @@ __all__ = [
'softmax_with_cross_entropy',
'square_error_cost',
'ssd_loss',
'teacher_student_sigmoid_loss'
'teacher_student_sigmoid_loss',
'ctc_loss',
]
......@@ -791,6 +795,102 @@ def mse_loss(input, label, reduction='mean', name=None):
name=name)
def ctc_loss(log_probs,
             labels,
             input_lengths,
             label_lengths,
             blank=0,
             reduction='mean'):
    """
    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Raises:
        ValueError: If ``reduction`` is not one of ``'mean'``, ``'sum'`` or ``'none'``.

    Examples:

        .. code-block:: python

            # imperative (dygraph) mode
            import paddle.nn.functional as F
            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            # length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            paddle.disable_static()
            log_probs = paddle.to_variable(log_probs)
            labels = paddle.to_variable(labels)
            input_lengths = paddle.to_variable(input_lengths)
            label_lengths = paddle.to_variable(label_lengths)

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='none')
            print(loss.numpy())  #[3.9179852 2.9076521]

            loss = F.ctc_loss(log_probs, labels,
                input_lengths,
                label_lengths,
                blank=0,
                reduction='mean')
            print(loss.numpy())  #[1.1376063]

    """
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O`` and would let an unknown reduction silently fall through
    # as 'none'. Checking first also avoids building the op for a bad call.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            "'reduction' in 'ctc_loss' should be 'mean', 'sum' or 'none', "
            "but received {!r}.".format(reduction))

    # The fourth positional argument of warpctc (norm_by_times) is fixed to
    # False here.
    loss_out = fluid.layers.warpctc(log_probs, labels, blank, False,
                                    input_lengths, label_lengths)

    # Drop the trailing axis of the op output before any reduction.
    loss_out = fluid.layers.squeeze(loss_out, [-1])
    if reduction == 'mean':
        # Normalize each sequence loss by its label length, then average.
        loss_out = paddle.mean(loss_out / label_lengths)
    elif reduction == 'sum':
        loss_out = paddle.sum(loss_out)
    return loss_out
def cross_entropy(input,
label,
weight=None,
......
......@@ -76,6 +76,7 @@ from .loss import NLLLoss #DEFINE_ALIAS
from .loss import BCELoss #DEFINE_ALIAS
from .loss import KLDivLoss #DEFINE_ALIAS
from .loss import MarginRankingLoss #DEFINE_ALIAS
from .loss import CTCLoss #DEFINE_ALIAS
from .loss import SmoothL1Loss #DEFINE_ALIAS
from .norm import BatchNorm #DEFINE_ALIAS
from .norm import SyncBatchNorm #DEFINE_ALIAS
......
......@@ -28,6 +28,7 @@ __all__ = [
'BCELoss',
'KLDivLoss',
'MarginRankingLoss',
'CTCLoss',
'SmoothL1Loss',
]
......@@ -672,6 +673,94 @@ class MarginRankingLoss(fluid.dygraph.Layer):
return out
class CTCLoss(fluid.dygraph.Layer):
    """
    :alias_main: paddle.nn.CTCLoss
    :alias: paddle.nn.CTCLoss, paddle.nn.layer.CTCLoss, paddle.nn.layer.loss.CTCLoss

    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
    to compute Connectionist Temporal Classification (CTC) loss.
    It can be aliased as softmax with CTC, since a native softmax activation
    is integrated into the Warp-CTC library to normalize values for each row of the input tensor.

    Parameters:
        blank (int, optional): The blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-opened interval [0, num_classes + 1). The data type must be int32. Default is 0.
        reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.

    Shape:
        log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32.
        labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
        input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
        label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.

    Returns:
        Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If :attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.

    Examples:

        .. code-block:: python

            # imperative (dygraph) mode
            import numpy as np
            import paddle

            # length of the longest logit sequence
            max_seq_length = 5
            # length of the longest label sequence
            max_label_length = 3
            # number of logit sequences
            batch_size = 2
            # class num
            class_num = 3

            np.random.seed(1)
            log_probs = np.array([[[4.17021990e-01, 7.20324516e-01, 1.14374816e-04],
                                    [3.02332580e-01, 1.46755889e-01, 9.23385918e-02]],

                                    [[1.86260208e-01, 3.45560730e-01, 3.96767467e-01],
                                    [5.38816750e-01, 4.19194520e-01, 6.85219526e-01]],

                                    [[2.04452246e-01, 8.78117442e-01, 2.73875929e-02],
                                    [6.70467496e-01, 4.17304814e-01, 5.58689833e-01]],

                                    [[1.40386939e-01, 1.98101491e-01, 8.00744593e-01],
                                    [9.68261600e-01, 3.13424170e-01, 6.92322612e-01]],

                                    [[8.76389146e-01, 8.94606650e-01, 8.50442126e-02],
                                    [3.90547849e-02, 1.69830427e-01, 8.78142476e-01]]]).astype("float32")
            labels = np.array([[1, 2, 2],
                            [1, 2, 2]]).astype("int32")
            input_lengths = np.array([5, 5]).astype("int64")
            label_lengths = np.array([3, 3]).astype("int64")

            paddle.disable_static()
            log_probs = paddle.to_variable(log_probs)
            labels = paddle.to_variable(labels)
            input_lengths = paddle.to_variable(input_lengths)
            label_lengths = paddle.to_variable(label_lengths)

            loss = paddle.nn.CTCLoss(blank=0, reduction='none')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss.numpy())  #[3.9179852 2.9076521]

            loss = paddle.nn.CTCLoss(blank=0, reduction='mean')(log_probs, labels,
                input_lengths,
                label_lengths)
            print(loss.numpy())  #[1.1376063]
    """

    def __init__(self, blank=0, reduction='mean'):
        super(CTCLoss, self).__init__()
        # Stored as given and forwarded to the functional API on every call.
        self.blank = blank
        self.reduction = reduction

    def forward(self, log_probs, labels, input_lengths, label_lengths):
        """Delegate to ``paddle.nn.functional.ctc_loss`` with the stored options."""
        # Pass the options by keyword so the call stays correct even if the
        # functional signature gains parameters later.
        return paddle.nn.functional.ctc_loss(
            log_probs,
            labels,
            input_lengths,
            label_lengths,
            blank=self.blank,
            reduction=self.reduction)
class SmoothL1Loss(fluid.dygraph.Layer):
"""
This operator calculates smooth_l1_loss. Creates a criterion that uses a squared
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册