提交 482ce818 编写于 作者: V vincentXiyu 提交者: whs

Support Tensor input with padding for warpctc op (#19322)

* support tensor input with padding for warpctc op

* merge with develop

* test=develop

* modified python API examples test=develop

* nn.py is modified for code coverage test=develop

* update documents info about warpctc op in API.spec test=develop

* add test_warpctc_with_padding in test_layers test=develop

* add warning log for cuda_version back to warpctc_op.cc

* modify API.spec for warpctc op test=develop

* modify API.spec

* update warpctc test to new CompiledProgram API test=develop

* modify code examples for warpctc op test=develop

* modify API.spec for warpctc op test=develop

* modify API.spec for warpctc op test=develop
上级 bfb6ac81
......@@ -159,7 +159,7 @@ paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized',
paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', 'c1df110ea65998984f564c5c10abc54a'))
paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', '3720b4a386585094435993deb028b592'))
paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'e50940f3ce5a08cc477b72f517491bf3'))
paddle.fluid.layers.warpctc (ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False)), ('document', '4aa9df890b47eb67d5442f04aaf9eeec'))
paddle.fluid.layers.warpctc (ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn', 'input_length', 'label_length'], varargs=None, keywords=None, defaults=(0, False, False, None, None)), ('document', 'ba27f25141adf24706536d179fabdf17'))
paddle.fluid.layers.sequence_reshape (ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None), ('document', 'f568714a876425004aca4ea2d4a27701'))
paddle.fluid.layers.transpose (ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '8e72db173d4c082e27cb11f31d8c9bfa'))
paddle.fluid.layers.im2sequence (ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None)), ('document', '33134416fc27dd65a767e5f15116ee16'))
......
......@@ -38,12 +38,19 @@ class WarpCTCOp : public framework::OperatorWithKernel {
"Output(Loss) of WarpCTCOp should not be null.");
auto logits_dims = ctx->GetInputDim("Logits");
int sequence_width =
static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
int blank = ctx->Attrs().Get<int>("blank");
int sequence_width = 0;
if (ctx->HasInput("LogitsLength")) {
sequence_width = logits_dims[2];
} else {
sequence_width =
static_cast<int>(framework::product(logits_dims) / logits_dims[0]);
}
PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
"The value of Attr(blank) should be in interval [0, %d).",
sequence_width);
// TODO(liuyiqun): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
}
......@@ -76,17 +83,32 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
"Tensor with LoD information. It's shape is "
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label).");
"(2-D LoDTensor<float>) or (3-D Tensor<float>), the "
"unscaled probabilities of variable-length sequences."
"When is a 2-D Tensor with LoD information, "
"it's shape is [Lp, num_classes + 1], "
"where Lp is the sum of all input sequences' length "
"and num_classes is the true number of classes "
"(not including the blank label)."
"When it is 3-D Tensor, it's shape is "
"[max_logit_length, batch_size, num_classes + 1], "
"where max_logit_length is the length of the longest "
"logit sequence.");
AddInput("Label",
"(LodTensor, default: LoDTensor<int>), the ground truth "
"of variable-length sequence, which is a 2-D Tensor with LoD "
"information. It is of the shape [Lg, 1], where Lg is th sum of "
"all labels' length.");
"(2-D LoDTensor<int>) or (2-D Tensor<int>), the "
"ground truth of variable-length sequence. "
"When it is a 2-D Tensor with LoD information, "
"it is of the shape [Lg, 1], where Lg is th sum of "
"all labels' length."
"When it is a 2-D Tensor<int>, it's shape is also [Lg, 1].");
AddInput("LogitsLength",
"1-D Tensor<int64_t>. "
"Input sequence length for Logits when Logits is a 3-D tensor.")
.AsDispensable();
AddInput("LabelLength",
"1-D Tensor<int64_t>. "
"Target sequence length for Label when Label is a 2-D tensor.")
.AsDispensable();
AddOutput("WarpCTCGrad",
"(Tensor, default: Tensor<float>), a temporary "
"output Tensor to store the gradients of warp-ctc, which is "
......@@ -143,6 +165,8 @@ class WarpCTCGradOpDescMaker : public framework::SingleGradOpDescMaker {
op->SetInput("Logits", Input("Logits"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetInput("LogitsLength", Input("LogitsLength"));
op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
op->SetAttrMap(Attrs());
......
......@@ -128,63 +128,93 @@ class WarpCTCKernel : public framework::OpKernel<T> {
auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
auto* loss = ctx.Output<Tensor>("Loss");
const size_t level = 0;
auto logits_lod = framework::ToAbsOffset(logits->lod());
auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ(logits_dims[0],
static_cast<int64_t>(logits_lod[level].back()),
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths.");
auto label_lod = framework::ToAbsOffset(label->lod());
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(
label_dims[0], label->numel(),
"The width of each timestep in Input(Label) should be 1.");
const size_t num_sequences = logits_lod[level].size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label).");
const size_t sequence_width = logits->numel() / logits_dims[0];
size_t num_sequences, sequence_width, max_sequence_length;
framework::Vector<size_t> logits_lod;
framework::Vector<size_t> label_lod;
if (ctx.HasInput("LogitsLength") && ctx.HasInput("LabelLength")) {
num_sequences = logits->dims()[1];
sequence_width = logits->dims()[2];
max_sequence_length = logits->dims()[0];
auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
auto* labels_length = ctx.Input<framework::Tensor>("LabelLength");
framework::Tensor logits_length_cpu;
framework::Tensor labels_length_cpu;
framework::TensorCopy(*logits_length, platform::CPUPlace(),
&logits_length_cpu);
framework::TensorCopy(*labels_length, platform::CPUPlace(),
&labels_length_cpu);
logits_lod.push_back(0);
label_lod.push_back(0);
for (auto i = 0; i < num_sequences; i++) {
logits_lod.push_back(logits_lod[i] +
logits_length_cpu.data<int64_t>()[i]);
label_lod.push_back(label_lod[i] +
labels_length_cpu.data<int64_t>()[i]);
}
} else {
logits_lod = framework::ToAbsOffset(logits->lod())[0];
auto logits_dims = logits->dims();
PADDLE_ENFORCE_EQ(
logits_dims[0], static_cast<int64_t>(logits_lod.back()),
"The first dimension of Input(Logits) should be equal to "
"the sum of all sequences' lengths.");
label_lod = framework::ToAbsOffset(label->lod())[0];
auto label_dims = label->dims();
PADDLE_ENFORCE_EQ(
label_dims[0], label->numel(),
"The width of each timestep in Input(Label) should be 1.");
num_sequences = logits_lod.size() - 1;
PADDLE_ENFORCE_EQ(num_sequences, label_lod.size() - 1,
"The number of sequences of Input(Logits) should be "
"equal to that of Input(Label).");
sequence_width = logits->numel() / logits_dims[0];
max_sequence_length = math::MaximumSequenceLength(logits_lod);
}
auto loss_dims =
framework::make_ddim({static_cast<int64_t>(num_sequences), 1});
// warpctc needs sequences data stored in transposed padding format
LoDTensor warpctc_logits;
const size_t max_sequence_length =
math::MaximumSequenceLength(logits_lod[level]);
auto warpctc_logits_dims =
framework::make_ddim({static_cast<int64_t>(max_sequence_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(sequence_width)});
warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
if (ctx.HasInput("LogitsLength")) {
TensorCopySync(*logits, ctx.GetPlace(), &warpctc_logits);
} else {
TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
LoDTensor cpu_pad_value;
T* pad_value_data =
cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
*pad_value_data = static_cast<T>(0);
LoDTensor pad_value;
if (platform::is_cpu_place(ctx.GetPlace())) {
pad_value = cpu_pad_value;
} else {
TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits,
&warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
math::kLengthBatchWidth);
}
math::PaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *logits, &warpctc_logits,
pad_value, -1, 0, false /* norm_by_times */, math::kLengthBatchWidth);
const T* warpctc_logits_data = warpctc_logits.data<T>();
std::vector<int> warpctc_label_lengths(num_sequences);
std::vector<int> warpctc_logits_lengths(num_sequences);
for (size_t i = 0; i < num_sequences; ++i) {
warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
warpctc_logits_lengths[i] =
logits_lod[level][i + 1] - logits_lod[level][i];
warpctc_label_lengths[i] = label_lod[i + 1] - label_lod[i];
warpctc_logits_lengths[i] = logits_lod[i + 1] - logits_lod[i];
}
// warpctc computes loss and gradient in one call, gradient data also stored
......@@ -199,6 +229,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
// warpctc accesses labels in CPU memory
Tensor warpctc_label;
TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
const int* warpctc_label_data = warpctc_label.data<int>();
// warpctc stores loss in CPU memory
Tensor warpctc_loss;
......@@ -227,14 +258,53 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
logits_grad->mutable_data<T>(ctx.GetPlace());
bool norm_by_times = ctx.Attr<bool>("norm_by_times");
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss_grad_data,
logits_grad);
if (ctx.HasInput("LogitsLength")) {
size_t max_seq_length = warpctc_grad->dims()[0];
size_t num_sequences = warpctc_grad->dims()[1];
size_t seq_width = warpctc_grad->dims()[2];
LoDTensor logits_grad_with_lod;
auto logits_grad_dims =
framework::make_ddim({static_cast<int64_t>(max_seq_length),
static_cast<int64_t>(num_sequences),
static_cast<int64_t>(seq_width)});
T* logits_grad_cpu_data = logits_grad_with_lod.mutable_data<T>(
logits_grad_dims, platform::CPUPlace());
TensorCopySync(*warpctc_grad, platform::CPUPlace(),
&logits_grad_with_lod);
Tensor loss_grad_cpu;
loss_grad_cpu.mutable_data<T>(loss_grad->dims(), platform::CPUPlace());
TensorCopySync(*loss_grad, platform::CPUPlace(), &loss_grad_cpu);
LoDTensor scaled_logits;
T* scaled_logits_data =
scaled_logits.mutable_data<T>(logits_grad_dims, platform::CPUPlace());
const T* loss_grad_data = loss_grad_cpu.data<T>();
for (size_t i = 0; i < max_seq_length; ++i) {
for (size_t j = 0; j < num_sequences; ++j) {
for (size_t k = 0; k < seq_width; ++k) {
size_t idx = i * (num_sequences * seq_width) + j * seq_width + k;
scaled_logits_data[idx] =
logits_grad_cpu_data[idx] * loss_grad_data[j];
}
}
}
TensorCopySync(scaled_logits, ctx.GetPlace(), logits_grad);
} else {
math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), *warpctc_grad,
logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);
const T* loss_grad_data = loss_grad->data<T>();
math::ScaleLoDTensorFunctor<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss_grad_data,
logits_grad);
}
}
};
......
......@@ -5644,7 +5644,13 @@ def ctc_greedy_decoder(input, blank, name=None):
return ctc_out
def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
def warpctc(input,
label,
blank=0,
norm_by_times=False,
use_cudnn=False,
input_length=None,
label_length=None):
"""
An operator integrating the open source Warp-CTC library
(https://github.com/baidu-research/warp-ctc)
......@@ -5655,13 +5661,18 @@ def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
Args:
input (Variable): The unscaled probabilities of variable-length sequences,
which is a 2-D Tensor with LoD information.
It's shape is [Lp, num_classes + 1], where Lp is the sum of all input
which is a 2-D Tensor with LoD information, or a 3-D Tensor without Lod
information. When it is a 2-D LodTensor, it's shape is
[Lp, num_classes + 1], where Lp is the sum of all input
sequences' length and num_classes is the true number of classes.
(not including the blank label).
(not including the blank label). When it is a 3-D Tensor, it's shape
is [max_logit_length, batch_size, num_classes + 1],
where max_logit_length is the length of the longest
input logit sequence.
label (Variable): The ground truth of variable-length sequence,
which is a 2-D Tensor with LoD information. It is of the shape [Lg, 1],
where Lg is th sum of all labels' length.
which is a 2-D Tensor with LoD information or a 2-D Tensor without
LoD information. When it is a 2-D LoDTensor or 2-D Tensor,
it is of the shape [Lg, 1], where Lg is th sum of all labels' length.
blank (int, default 0): The blank label index of Connectionist
Temporal Classification (CTC) loss, which is in the
half-opened interval [0, num_classes + 1).
......@@ -5670,30 +5681,60 @@ def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
There is no need to normalize the gradients if warpctc layer was
follewed by a mean_op.
use_cudnn (bool, default false): Whether to use cudnn.
input_length(Variable): The length for each input sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
label_length(Variable): The length for each label sequence if it is
of Tensor type, it should have shape `[batch_size]` and dtype int64.
Returns:
Variable: The Connectionist Temporal Classification (CTC) loss,
which is a 2-D Tensor of the shape [batch_size, 1].
Examples:
.. code-block:: python
# using LoDTensor
import paddle.fluid as fluid
label = fluid.layers.data(name='label', shape=[11, 8],
import numpy as np
label = fluid.layers.data(name='label', shape=[12, 1],
dtype='float32', lod_level=1)
predict = fluid.layers.data(name='predict', shape=[11, 1],
dtype='float32')
predict = fluid.layers.data(name='predict',
shape=[11, 8],
dtype='float32',lod_level=1)
cost = fluid.layers.warpctc(input=predict, label=label)
# using Tensor
input_length = fluid.layers.data(name='logits_length', shape=[11],
dtype='int64')
label_length = fluid.layers.data(name='labels_length', shape=[12],
dtype='int64')
target = fluid.layers.data(name='target', shape=[12, 1],
dtype='int32')
# length of the longest logit sequence
max_seq_length = 4
# number of logit sequences
batch_size = 4
output = fluid.layers.data(name='output',
shape=[max_seq_length, batch_size, 8],
dtype='float32')
loss = fluid.layers.warpctc(input=output,label=target,
input_length=input_length,
label_length=label_length)
"""
helper = LayerHelper('warpctc', **locals())
this_inputs = {'Logits': [input], 'Label': [label]}
if input_length and label_length:
this_inputs['LogitsLength'] = [input_length]
this_inputs['LabelLength'] = [label_length]
loss_out = helper.create_variable_for_type_inference(dtype=input.dtype)
grad_out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
type='warpctc',
inputs={'Logits': [input],
'Label': [label]},
inputs=this_inputs,
outputs={'WarpCTCGrad': [grad_out],
'Loss': [loss_out]},
attrs={
......
......@@ -2251,6 +2251,23 @@ class TestBook(LayerTest):
nms_eta=1.)
return (nmsed_outs)
def test_warpctc_with_padding(self):
# TODO(minqiyang): dygraph do not support lod now
with self.static_graph():
input_length = layers.data(
name='logits_length', shape=[11], dtype='int64')
label_length = layers.data(
name='labels_length', shape=[12], dtype='int64')
label = layers.data(name='label', shape=[12, 1], dtype='int32')
predict = layers.data(
name='predict', shape=[4, 4, 8], dtype='float32')
output = layers.warpctc(
input=predict,
label=label,
input_length=input_length,
label_length=label_length)
return (output)
if __name__ == '__main__':
unittest.main()
......@@ -241,6 +241,104 @@ class TestWarpCTCOpCase1(TestWarpCTCOp):
self.use_cudnn = False
class TestWarpCTCOpWithPadding(OpTest):
def config(self):
self.batch_size = 4
self.num_classes = 8
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
self.blank = self.num_classes - 1
self.norm_by_times = False
self.use_cudnn = False
def setUp(self):
self.op_type = "warpctc"
self.config()
logits = np.random.uniform(
0.1, 1.0,
[sum(self.logits_length), self.num_classes]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, logits)
# labels should not be blank
labels = np.random.randint(
0,
self.num_classes - 1, [sum(self.labels_length), 1],
dtype="int32")
ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
self.blank, self.norm_by_times)
loss = ctc.forward()
max_sequence_length = 0
for i in range(self.batch_size):
max_sequence_length = max(max_sequence_length,
self.logits_length[i])
# reshape logits to T*N*S
new_logits = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes],
dtype="float32")
cur = 0
for batch_id in range(self.batch_size):
for i in range(self.logits_length[batch_id]):
for j in range(self.num_classes):
new_logits[i, batch_id, j] = logits[cur + i, j]
cur = cur + self.logits_length[batch_id]
# reshape labels to N*S
max_target_seq_length = 0
for i in range(self.batch_size):
max_target_seq_length = max(max_target_seq_length,
self.labels_length[i])
new_labels = np.zeros(
[self.batch_size, max_target_seq_length], dtype="int32")
cur = 0
for batch_id in range(self.batch_size):
for i in range(self.labels_length[batch_id]):
new_labels[batch_id, i] = labels[cur + i]
cur = cur + self.labels_length[batch_id]
self.gradient = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes],
dtype="float32")
self.inputs = {
"Logits": new_logits,
"Label": labels,
"LogitsLength": self.logits_length,
"LabelLength": self.labels_length
}
self.outputs = {"Loss": loss}
self.attrs = {
"blank": self.blank,
"norm_by_times": self.norm_by_times,
"use_cudnn": self.use_cudnn
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.outputs['WarpCTCGrad'] = self.gradient
self.check_grad(["Logits"], "Loss", max_relative_error=0.007)
class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
def config(self):
self.batch_size = 4
self.num_classes = CUDA_BLOCK_SIZE + 2
self.logits_lod = [[4, 1, 3, 3]]
self.labels_lod = [[3, 1, 4, 4]]
self.logits_length = np.array([4, 1, 3, 3], dtype=np.int64)
self.labels_length = np.array([3, 1, 4, 4], dtype=np.int64)
self.blank = 0
self.norm_by_times = False
self.use_cudnn = False
# TODO: fix this test failed cuda9/10 manylinux images
# class TestCudnnCTCOp(TestWarpCTCOp):
# def config(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册