提交 01d568e5 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into ctc_evaluator_py

......@@ -32,7 +32,8 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
......@@ -41,32 +42,32 @@ class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
const size_t max_sequence_length = MaximumSequenceLength(lod, level);
const int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* seq_data = seq.data<T>();
T* padding_data = padding.data<T>();
for (size_t i = 0; i < max_sequence_length; ++i) {
for (size_t j = 0; j < num_sequences; ++j) {
size_t start_pos = abs_offset_lod[level][j];
size_t sequence_length = abs_offset_lod[level][j + 1] - start_pos;
for (int64_t i = 0; i < max_sequence_length; ++i) {
for (int64_t j = 0; j < num_sequences; ++j) {
int64_t start_pos = abs_offset_lod[level][j];
int64_t sequence_length = abs_offset_lod[level][j + 1] - start_pos;
if (i < sequence_length) {
// i > 0 => sequence_length > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (size_t k = 0; k < sequence_width; ++k) {
for (int64_t k = 0; k < sequence_width; ++k) {
padding_data[(i * num_sequences + j) * sequence_width + k] =
seq_data[(start_pos + i) * sequence_width + k] * scale;
}
......@@ -93,7 +94,8 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
......@@ -102,31 +104,31 @@ class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
const size_t max_sequence_length = MaximumSequenceLength(lod, level);
const int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
const T* padding_data = padding.data<T>();
T* seq_data = seq.data<T>();
for (size_t i = 0; i < num_sequences; ++i) {
size_t start_pos = abs_offset_lod[level][i];
size_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
for (size_t j = 0; j < sequence_length; ++j) {
for (int64_t i = 0; i < num_sequences; ++i) {
int64_t start_pos = abs_offset_lod[level][i];
int64_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
for (int64_t j = 0; j < sequence_length; ++j) {
// sequence_width > j > 0
T scale =
norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;
for (size_t k = 0; k < sequence_width; ++k) {
for (int64_t k = 0; k < sequence_width; ++k) {
seq_data[(start_pos + j) * sequence_width + k] =
padding_data[(j * num_sequences + i) * sequence_width + k] *
scale;
......
......@@ -71,7 +71,8 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
......@@ -80,17 +81,17 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The input padding should be a 3-D Tensor of shape "
"[max_sequence_length, num_sequences, sequence_width].");
size_t max_sequence_length = MaximumSequenceLength(lod, level);
int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be the "
"maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be the "
"number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
......@@ -101,7 +102,7 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
return;
}
const size_t kBlockSize = 512;
const int64_t kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
......@@ -143,7 +144,8 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
auto seq_dims = seq.dims();
PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
PADDLE_ENFORCE_EQ(seq_dims[0],
static_cast<int64_t>(abs_offset_lod[level].back()),
"The first dimension of LoDTensor seq should be "
"equal to the sum of all sequences's length.");
......@@ -152,17 +154,17 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
"The input padding should be a 3-D Tensor of shape "
"[max_sequnece_length, num_sequences, sequence_width].");
size_t max_sequence_length = MaximumSequenceLength(lod, level);
int64_t max_sequence_length = MaximumSequenceLength(lod, level);
PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
"The first dimension of Tensor padding should be "
"the maximum length of all sequences in LoDTensor seq.");
const size_t num_sequences = abs_offset_lod[level].size() - 1;
const int64_t num_sequences = abs_offset_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
"The second dimension of Tensor padding should be "
"the number of sequences in LoDTensor seq.");
const size_t sequence_width = seq.numel() / seq_dims[0];
const int64_t sequence_width = seq.numel() / seq_dims[0];
PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
"The third dimension of Tensor padding should be the "
"width of sequence in LoDTensor seq.");
......@@ -173,7 +175,7 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
return;
}
const size_t kBlockSize = 512;
const int64_t kBlockSize = 512;
/* At least use 32 threads to copy sequence_width elements,
* and at least 8 elements for each thread.
......
......@@ -22,38 +22,14 @@ from ..param_attr import ParamAttr
from tensor import concat
__all__ = [
'fc',
'embedding',
'dynamic_lstm',
'gru_unit',
'linear_chain_crf',
'crf_decoding',
'cos_sim',
'cross_entropy',
'square_error_cost',
'accuracy',
'chunk_eval',
'sequence_conv',
'conv2d',
'sequence_pool',
'pool2d',
'batch_norm',
'beam_search_decode',
'conv2d_transpose',
'sequence_expand',
'lstm_unit',
'reduce_sum',
'reduce_mean',
'reduce_max',
'reduce_min',
'sequence_first_step',
'sequence_last_step',
'dropout',
'split',
'ctc_greedy_decoder',
'edit_distance_error',
'l2_normalize',
'matmul',
'fc', 'embedding', 'dynamic_lstm', 'gru_unit', 'linear_chain_crf',
'crf_decoding', 'cos_sim', 'cross_entropy', 'square_error_cost', 'accuracy',
'chunk_eval', 'sequence_conv', 'conv2d', 'sequence_pool', 'pool2d',
'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
'ctc_greedy_decoder', 'edit_distance_error', 'l2_normalize', 'matmul',
'warpctc'
]
......@@ -1903,3 +1879,56 @@ def ctc_greedy_decoder(input, blank, name=None):
attrs={"merge_repeated": True,
"blank": blank})
return ctc_out
def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
"""
An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc)
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with CTC, since a native softmax activation is
interated to the Warp-CTC library, to to normlize values for each row of the
input tensor.
Args:
input(Variable): (LodTensor, default: LoDTensor<float>),
the unscaled probabilities of variable-length sequences,
which is a 2-D Tensor with LoD information.
It's shape is [Lp, num_classes + 1], where Lp is the sum of all input
sequences' length and num_classes is the true number of classes.
(not including the blank label).
label(Variable): (LodTensor, default: LoDTensor<int>), the ground truth
of variable-length sequence, which is a 2-D Tensor with LoD
information. It is of the shape [Lg, 1], where Lg is th sum of
all labels' length.
blank: (int, default: 0), the blank label index of Connectionist
Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1).
norm_by_times: (bool, default: false), whether to normalize
the gradients by the number of time-step,which is also the
sequence's length. There is no need to normalize the gradients
if warpctc layer was follewed by a mean_op.
Returns:
Variable: The Connectionist Temporal Classification (CTC) loss,
which is a 2-D Tensor of the shape [batch_size, 1].
Examples:
.. code-block:: python
y = layers.data(name='y', shape=[11, 8], dtype='float32', lod_level=1)
y_predict = layers.data(name='y_predict', shape=[11, 1], dtype='float32')
cost = layers.warpctc(input=y_predict, label=y)
"""
helper = LayerHelper('warpctc', **kwargs)
loss_out = helper.create_tmp_variable(dtype=input.dtype)
grad_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='warpctc',
inputs={'Logits': [input],
'Label': [label]},
outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]},
attrs={'blank': blank,
'norm_by_times': norm_by_times})
return loss_out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册