From faf6890b6c167e1f29c0acb8ed5f9914e9c8984f Mon Sep 17 00:00:00 2001
From: Liufang Sang
Date: Mon, 5 Aug 2019 10:47:51 +0800
Subject: [PATCH] support tensor input for ctc align op (#18887)

* test=develop support Tensor input for ctc_align_op

* test=develop add some comment
---
 paddle/fluid/operators/ctc_align_op.cc        |  23 +++-
 paddle/fluid/operators/ctc_align_op.cu        | 107 ++++++++++------
 paddle/fluid/operators/ctc_align_op.h         |  97 +++++++++------
 .../fluid/tests/unittests/test_ctc_align.py   | 117 +++++++++++++++---
 4 files changed, 255 insertions(+), 89 deletions(-)

diff --git a/paddle/fluid/operators/ctc_align_op.cc b/paddle/fluid/operators/ctc_align_op.cc
index e7c472f8c0..1b49cf3ce9 100644
--- a/paddle/fluid/operators/ctc_align_op.cc
+++ b/paddle/fluid/operators/ctc_align_op.cc
@@ -45,7 +45,7 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("Input",
-             "(LodTensor, default: LoDTensor<int>), Its shape is "
+             "2-D Tensor or LoDTensor with shape "
              "[Lp, 1], where Lp is the sum of all input sequences' length.");
     AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
     AddAttr<int>("blank",
@@ -56,6 +56,11 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: true), whether to "
                   "merge repeated elements between two blanks. ")
         .SetDefault(true);
+    // attribute used to pad the output when Input is a plain tensor
+    AddAttr<int>("padding_num",
+                 "(int, default: 0), the value used to pad "
+                 "the output tensor. ")
+        .SetDefault(0);
     AddComment(R"DOC(
 CTCAlign op is used to merge repeated elements between two blanks
 and then delete all blanks in sequence.
@@ -75,7 +80,23 @@ Then:
                    6, 7]
     Output.dims = {8, 1}
     Output.LoD = [[0, 6, 8]]
+Or given:
+    Input.data = [[0, 1, 2, 2, 0, 4],
+                  [0, 4, 5, 0, 6, 0],
+                  [0, 7, 7, 7, 0, 0]]
+    Input.dims = {3, 6},
+    Input.LoD = []
+And:
+    blank = 0
+    merge_repeated = True
+    padding_num = 0
+Then:
+    Output.data = [[1, 2, 4, 0, 0, 0],
+                   [4, 5, 6, 0, 0, 0],
+                   [7, 0, 0, 0, 0, 0]]
+    Output.dims = {3, 6},
+    Output.LoD = []
 
 )DOC");
   }
 };
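Note on the dense-input semantics documented above: each row is processed
independently, the merged blank-free tokens are left-aligned, and the tail of
each row is filled with padding_num. A minimal NumPy sketch of that behavior
(illustrative only; the helper name ctc_align_dense is not part of the patch):

    import numpy as np

    def ctc_align_dense(x, blank=0, merge_repeated=True, padding_num=0):
        # x: [batch, time] int array; returns a same-shape array with the
        # merged, blank-free tokens left-aligned and padding_num on the right
        out = np.full_like(x, padding_num)
        for b in range(x.shape[0]):
            prev, k = -1, 0
            for t in range(x.shape[1]):
                tok = x[b, t]
                if tok != blank and not (merge_repeated and tok == prev):
                    out[b, k] = tok
                    k += 1
                prev = tok
        return out

    x = np.array([[0, 1, 2, 2, 0, 4],
                  [0, 4, 5, 0, 6, 0],
                  [0, 7, 7, 7, 0, 0]], dtype=np.int32)
    print(ctc_align_dense(x))
    # [[1 2 4 0 0 0]
    #  [4 5 6 0 0 0]
    #  [7 0 0 0 0 0]]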
diff --git a/paddle/fluid/operators/ctc_align_op.cu b/paddle/fluid/operators/ctc_align_op.cu
index bbad74e96d..4f58668485 100644
--- a/paddle/fluid/operators/ctc_align_op.cu
+++ b/paddle/fluid/operators/ctc_align_op.cu
@@ -42,53 +42,90 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
   }
 }
 
+template <typename T>
+__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
+                                             const T* tokens, const int blank,
+                                             const int merge_repeated,
+                                             const int padding_num,
+                                             const int64_t batch_size,
+                                             T* output) {
+  int ind = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ind >= batch_size) return;
+  // each thread merges and pads one row of the [batch, num_token] input
+  int output_idx = ind * num_token;
+  T prev_token = -1;
+  for (int i = ind * num_token; i < ind * num_token + num_token; i++) {
+    if ((unsigned)tokens[i] != blank &&
+        !(merge_repeated && tokens[i] == prev_token)) {
+      output[output_idx] = tokens[i];
+      ++output_idx;
+    }
+    prev_token = tokens[i];
+  }
+  // fill the remainder of the row with padding_num
+  for (int i = output_idx; i < ind * num_token + num_token; i++) {
+    output[i] = padding_num;
+  }
+}
+
 template <typename T>
 class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use CUDAPlace.");
-    const size_t level = 0;
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* output = ctx.Output<LoDTensor>("Output");
-    auto input_lod = framework::ToAbsOffset(input->lod());
-
-    const T* tokens = input->data<T>();
-    const int64_t num_tokens = input->dims()[0];
-    const size_t num_seq = input_lod[level].size() - 1;
-
     const int blank = ctx.Attr<int>("blank");
     const int merge_repeated =
         static_cast<int>(ctx.Attr<bool>("merge_repeated"));
-
-    // prepare a lod to record lod information while merging elements
-    thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
-    size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
-
-    // merge elements and delete blank
-    T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
-
+    const T* tokens = input->data<T>();
     auto stream = ctx.cuda_device_context().stream();
-    MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
-        num_tokens, tokens, num_seq,
-        input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated,
-        dev_out_lod0_ptr, output_data);
-
-    // set output lod
-    std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end());
-    framework::LoD out_lod;
-    out_lod.push_back(host_out_lod0);
-    output->set_lod(out_lod);
-
-    // resize output dims
-    output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
-
-    if (host_out_lod0.back() == 0) {
-      output->Resize({1, 1});
-      output->mutable_data<T>(ctx.GetPlace());
-      math::SetConstant<platform::CUDADeviceContext, T> set_constant;
-      set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
-                   output, -1);
+
+    // plain tensor input, which carries no LoD information
+    if (input->lod().empty()) {
+      const int padding_num = ctx.Attr<int>("padding_num");
+      auto input_dims = input->dims();
+      T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
+                                               ctx.GetPlace());
+      PaddingMergeAndDelCudaKernel<
+          T><<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
+          input_dims[1], tokens, blank, merge_repeated, padding_num,
+          input_dims[0], output_data);
+    } else {
+      const size_t level = 0;
+      auto input_lod = framework::ToAbsOffset(input->lod());
+
+      const int64_t num_tokens = input->dims()[0];
+      const size_t num_seq = input_lod[level].size() - 1;
+
+      // prepare a lod to record lod information while merging elements
+      thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
+      size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
+
+      // merge elements and delete blank
+      T* output_data =
+          output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
+
+      MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
+          num_tokens, tokens, num_seq,
+          input_lod[level].CUDAMutableData(ctx.GetPlace()), blank,
+          merge_repeated, dev_out_lod0_ptr, output_data);
+
+      // set output lod
+      std::vector<size_t> host_out_lod0(dev_out_lod0.begin(),
+                                        dev_out_lod0.end());
+      framework::LoD out_lod;
+      out_lod.push_back(host_out_lod0);
+      output->set_lod(out_lod);
+
+      // resize output dims
+      output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
+
+      if (host_out_lod0.back() == 0) {
+        output->Resize({1, 1});
+        output->mutable_data<T>(ctx.GetPlace());
+        math::SetConstant<platform::CUDADeviceContext, T> set_constant;
+        set_constant(
+            ctx.template device_context<platform::CUDADeviceContext>(),
+            output, -1);
+      }
     }
   }
 };
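Note on the launch configuration above: the padding kernel runs with a fixed
grid of 32 blocks and ceil(batch_size / 32) threads per block, one thread per
row. This covers every row because 32 * ceil(n / 32) >= n for any n >= 1, and
surplus threads exit early through the `ind >= batch_size` guard. A quick
illustrative check of that arithmetic (not part of the patch):

    def covers(batch_size, num_blocks=32):
        # threads per block mirrors (input_dims[0] + 32 - 1) / 32
        threads_per_block = (batch_size + num_blocks - 1) // num_blocks
        return num_blocks * threads_per_block >= batch_size

    assert all(covers(n) for n in range(1, 4096))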
diff --git a/paddle/fluid/operators/ctc_align_op.h b/paddle/fluid/operators/ctc_align_op.h
index 9c5c6f5aa0..1b5bf32d76 100644
--- a/paddle/fluid/operators/ctc_align_op.h
+++ b/paddle/fluid/operators/ctc_align_op.h
@@ -31,50 +31,73 @@ class CTCAlignKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* output = ctx.Output<LoDTensor>("Output");
-    const size_t level = 0;
-    auto input_lod = framework::ToAbsOffset(input->lod());
-
-    // check input dims and lod
-    auto input_dims = input->dims();
-    PADDLE_ENFORCE_EQ(input_dims[0],
-                      static_cast<int64_t>(input_lod[level].back()),
-                      "The first dimension of Input(Input) should be equal to "
-                      "the sum of all sequences' lengths.");
-
-    const size_t num_sequences = input_lod[level].size() - 1;
     size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
     bool merge_repeated = ctx.Attr<bool>("merge_repeated");
-
-    // merge repeated tokens and delete blank
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    size_t output_idx = 0;
-    std::vector<size_t> output_lod0(1, 0);
+    auto input_dims = input->dims();
     const T* input_data = input->data<T>();
-    for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
-      T prev_token = -1;
-      for (size_t i = input_lod[level][seq_idx];
-           i < input_lod[level][seq_idx + 1]; ++i) {
-        if ((unsigned)input_data[i] != blank &&
-            !(merge_repeated && input_data[i] == prev_token)) {
-          output_data[output_idx] = input_data[i];
-          ++output_idx;
+
+    // support plain tensor input, which carries no LoD information
+    if (input->lod().empty()) {
+      size_t padding_num = static_cast<size_t>(ctx.Attr<int>("padding_num"));
+      for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
+           batch_id++) {
+        T prev_token = -1;
+        size_t output_idx = 0;
+        for (size_t i = 0; i < (unsigned)input_dims[1]; i++) {
+          size_t input_ind = batch_id * input_dims[1] + i;
+          if ((unsigned)input_data[input_ind] != blank &&
+              !(merge_repeated && input_data[input_ind] == prev_token)) {
+            output_data[batch_id * input_dims[1] + output_idx] =
+                input_data[input_ind];
+            ++output_idx;
+          }
+          prev_token = input_data[input_ind];
         }
-        prev_token = input_data[i];
+        // pad the tail of each row with padding_num
+        for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
+          output_data[batch_id * input_dims[1] + j] = padding_num;
       }
-      output_lod0.push_back(output_idx);
-    }
+    } else {
+      const size_t level = 0;
+      auto input_lod = framework::ToAbsOffset(input->lod());
 
-    // set output lod
-    framework::LoD output_lod;
-    output_lod.push_back(output_lod0);
-    output->set_lod(output_lod);
-    // resize output dims
-    output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
-    // for empty sequence
-    if (output_lod0.back() == 0) {
-      output->Resize({1, 1});
-      output_data = output->mutable_data<T>(ctx.GetPlace());
-      output_data[0] = -1;
+      // check input dims and lod
+      PADDLE_ENFORCE_EQ(
+          input_dims[0], static_cast<int64_t>(input_lod[level].back()),
+          "The first dimension of Input(Input) should be equal to "
+          "the sum of all sequences' lengths.");
+
+      const size_t num_sequences = input_lod[level].size() - 1;
+
+      // merge repeated tokens and delete blank
+      size_t output_idx = 0;
+      std::vector<size_t> output_lod0(1, 0);
+      for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
+        T prev_token = -1;
+        for (size_t i = input_lod[level][seq_idx];
+             i < input_lod[level][seq_idx + 1]; ++i) {
+          if ((unsigned)input_data[i] != blank &&
+              !(merge_repeated && input_data[i] == prev_token)) {
+            output_data[output_idx] = input_data[i];
+            ++output_idx;
+          }
+          prev_token = input_data[i];
+        }
+        output_lod0.push_back(output_idx);
+      }
+
+      // set output lod
+      framework::LoD output_lod;
+      output_lod.push_back(output_lod0);
+      output->set_lod(output_lod);
+      // resize output dims
+      output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
+      // for empty sequence
+      if (output_lod0.back() == 0) {
+        output->Resize({1, 1});
+        output_data = output->mutable_data<T>(ctx.GetPlace());
+        output_data[0] = -1;
+      }
     }
   }
 };
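Note on the LoD branch kept above: after each sequence the kernel appends the
running token count to output_lod0, which becomes the output's level-0 LoD,
and an entirely empty result is resized to {1, 1} holding -1. A minimal
Python sketch of that bookkeeping (illustrative; ctc_align_lod and its sample
data are not part of the patch):

    import numpy as np

    def ctc_align_lod(tokens, lod0, blank=0, merge_repeated=True):
        # tokens: flat concatenated sequences; lod0: absolute offsets
        result, out_lod0 = [], [0]
        for seq_idx in range(len(lod0) - 1):
            prev_token = -1
            for i in range(lod0[seq_idx], lod0[seq_idx + 1]):
                token = tokens[i]
                if token != blank and not (merge_repeated and
                                           token == prev_token):
                    result.append(token)
                prev_token = token
            out_lod0.append(len(result))  # running offset per sequence
        if not result:  # empty result maps to shape {1, 1} holding -1
            return np.array([[-1]], dtype=np.int32), out_lod0
        return np.array(result, dtype=np.int32).reshape(-1, 1), out_lod0

    data, lod = ctc_align_lod([0, 1, 2, 2, 0, 4, 6, 0, 0, 7], [0, 6, 10])
    print(data.ravel(), lod)  # [1 2 4 6 7] [0, 3, 5]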
diff --git a/python/paddle/fluid/tests/unittests/test_ctc_align.py b/python/paddle/fluid/tests/unittests/test_ctc_align.py
index 5f17d2d407..cddb2f0c3c 100644
--- a/python/paddle/fluid/tests/unittests/test_ctc_align.py
+++ b/python/paddle/fluid/tests/unittests/test_ctc_align.py
@@ -21,22 +21,39 @@ from op_test import OpTest
 from test_softmax_op import stable_softmax
 
 
-def CTCAlign(input, lod, blank, merge_repeated):
-    lod0 = lod[0]
-    result = []
-    cur_offset = 0
-    for i in range(len(lod0)):
-        prev_token = -1
-        for j in range(cur_offset, cur_offset + lod0[i]):
-            token = input[j][0]
-            if (token != blank) and not (merge_repeated and
-                                         token == prev_token):
-                result.append(token)
-            prev_token = token
-        cur_offset += lod0[i]
-    result = np.array(result).reshape([len(result), 1]).astype("int32")
-    if len(result) == 0:
-        result = np.array([-1])
+def CTCAlign(input, lod, blank, merge_repeated, padding=0):
+    if lod is not None and len(lod) > 0:
+        # LoD path: input is a flat [Lp, 1] array of concatenated sequences
+        lod0 = lod[0]
+        result = []
+        cur_offset = 0
+        for i in range(len(lod0)):
+            prev_token = -1
+            for j in range(cur_offset, cur_offset + lod0[i]):
+                token = input[j][0]
+                if (token != blank) and not (merge_repeated and
+                                             token == prev_token):
+                    result.append(token)
+                prev_token = token
+            cur_offset += lod0[i]
+        result = np.array(result).reshape([len(result), 1]).astype("int32")
+        if len(result) == 0:
+            result = np.array([-1])
+    else:
+        # tensor path: merge each row independently, then pad its tail
+        result = [[] for i in range(len(input))]
+        for i in range(len(input)):
+            prev_token = -1
+            for j in range(len(input[i])):
+                token = input[i][j]
+                if (token != blank) and not (merge_repeated and
+                                             token == prev_token):
+                    result[i].append(token)
+                prev_token = token
+            start = len(result[i])
+            for j in range(start, len(input[i])):
+                result[i].append(padding)
+        result = np.array(result).reshape(
+            [len(input), len(input[0])]).astype("int32")
+
     return result
@@ -87,5 +104,73 @@ class TestCTCAlignOpCase2(TestCTCAlignOp):
         self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32")
 
 
+class TestCTCAlignPaddingOp(OpTest):
+    def config(self):
+        self.op_type = "ctc_align"
+        self.input_lod = []
+        self.blank = 0
+        self.padding_num = 0
+        self.merge_repeated = True
+        self.input = np.array([[0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0],
+                               [1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0]]).reshape(
+                                   [2, 11]).astype("int32")
+
+    def setUp(self):
+        self.config()
+        output = CTCAlign(self.input, self.input_lod, self.blank,
+                          self.merge_repeated, self.padding_num)
+        self.inputs = {"Input": (self.input, self.input_lod), }
+        self.outputs = {"Output": output}
+        self.attrs = {
+            "blank": self.blank,
+            "merge_repeated": self.merge_repeated,
+            "padding_num": self.padding_num
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
+    '''
+    Test tensor input with repeated tokens merged (merge_repeated=True).
+    '''
+
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = True
+        self.padding_num = 0
+        self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
+                               [0, 7, 7, 7, 0, 0]]).reshape(
+                                   [3, 6]).astype("int32")
+
+
+class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
+    '''
+    Test tensor input without merging repeated tokens (merge_repeated=False).
+    '''
+
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = False
+        self.padding_num = 0
+        self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
+                               [0, 7, 7, 7, 0, 0]]).reshape(
+                                   [3, 6]).astype("int32")
+
+
+class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
+    '''
+    Test tensor input with a non-zero padding_num.
+    '''
+
+    def config(self):
+        self.op_type = "ctc_align"
+        self.blank = 0
+        self.input_lod = []
+        self.merge_repeated = False
+        self.padding_num = 1
+        self.input = np.array([[0, 1, 2, 2, 0, 4], [0, 4, 5, 0, 6, 0],
+                               [0, 7, 1, 7, 0, 0]]).reshape(
+                                   [3, 6]).astype("int32")
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
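Usage note: the CTCAlign reference helper added in the test covers both
paths, selected by whether a non-empty LoD is passed. An illustrative run
(assumes CTCAlign as defined in test_ctc_align.py above):

    import numpy as np

    # LoD path: flat [Lp, 1] input plus per-sequence lengths
    seq = np.array([0, 1, 2, 2, 0, 4]).reshape([6, 1]).astype("int32")
    print(CTCAlign(seq, [[6]], 0, True).ravel())      # [1 2 4]

    # tensor path: 2-D input with an empty LoD triggers the padding branch
    batch = np.array([[0, 1, 2, 2, 0, 4]], dtype=np.int32)
    print(CTCAlign(batch, [], 0, True, padding=0))    # [[1 2 4 0 0 0]]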