diff --git a/paddle/operators/ctc_align_op.cc b/paddle/operators/ctc_align_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eeecbd32127d2cf9756432817fc5d36673685aa7 --- /dev/null +++ b/paddle/operators/ctc_align_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/ctc_align_op.h" + +namespace paddle { +namespace operators { + +class CTCAlignOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input of CTCAlignOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Output"), + "Output of CTCAlignOp should not be null."); + + auto input_dims = ctx->GetInputDim("Input"); + + // TODO(wanghaoshuang): it is tricky to set the wrong dimension here. + ctx->SetOutputDim("Output", input_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("Input")->type()), + ctx.device_context()); + } +}; + +class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LodTensor, default: LoDTensor), Its shape is " + "[Lp, 1], where Lp is the sum of all input sequences' length."); + AddOutput("Output", "(Tensor, default: Tensor), The align result."); + AddAttr("blank", + "(int, default: 0), the blank label setted in Connectionist " + "Temporal Classification (CTC) op.") + .SetDefault(0); + AddAttr("merge_repeated", + "(bool, default: true), whether to " + "merge repeated elements between two blanks. ") + .SetDefault(true); + AddComment(R"DOC( +CTCAlign op is used to merge repeated elements between two blanks +and then delete all blanks in sequence. + +Given: + Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, + 6, 0, 0, 7, 7, 7, 0] + Input.dims = {18, 1} + Input.LoD = [[0, 11, 18]] + +And: + blank = 0 + merge_repeated = True + +Then: + Output.data = [1, 2, 4, 4, 5, 6, + 6, 7] + Output.dims = {8, 1} + Output.LoD = [[0, 6, 8]] + +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker, + paddle::framework::EmptyGradOpMaker); +REGISTER_OP_CPU_KERNEL( + ctc_align, ops::CTCAlignKernel, + ops::CTCAlignKernel); diff --git a/paddle/operators/ctc_align_op.cu b/paddle/operators/ctc_align_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..45635f16745346b08f7e31db2f25905bdbc3aeeb --- /dev/null +++ b/paddle/operators/ctc_align_op.cu @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/operators/ctc_align_op.h" + +namespace paddle { +namespace operators { + +template +__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, + const size_t num_seq, size_t* lod0, + const int blank, const int merge_repeated, + size_t* out_lod0, T* output) { + int ouput_idx = 0; + out_lod0[0] = 0; + + for (int i = 0; i < num_seq; ++i) { + T pre_token = -1; + for (int j = lod0[i]; j < lod0[i + 1]; ++j) { + if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) { + output[ouput_idx] = tokens[j]; + ++ouput_idx; + } + pre_token = tokens[j]; + } + out_lod0[i + 1] = ouput_idx; + } +} + +template +class CTCAlignOpCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "It must use CUDAPlace."); + const size_t level = 0; + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + auto input_lod = framework::ToAbsOffset(input->lod()); + + const T* tokens = input->data(); + const int64_t num_tokens = input->dims()[0]; + const size_t num_seq = input_lod[level].size() - 1; + + const int blank = ctx.Attr("blank"); + const int merge_repeated = + static_cast(ctx.Attr("merge_repeated")); + + // prepare a lod to record lod information while merging elements + thrust::device_vector dev_out_lod0(input_lod[level].size()); + size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); + + // merge elements and delete blank + T* output_data = output->mutable_data({num_tokens, 1}, ctx.GetPlace()); + + auto stream = ctx.cuda_device_context().stream(); + MergeAndDelCudaKernel<<<1, 1, 0, stream>>>( + num_tokens, tokens, num_seq, input_lod[level].data(), blank, + merge_repeated, dev_out_lod0_ptr, output_data); + + // set output lod + thrust::host_vector host_out_lod0(dev_out_lod0.begin(), + dev_out_lod0.end()); + framework::LoD out_lod; + out_lod.push_back(host_out_lod0); + output->set_lod(out_lod); + + // resize output dims + output->Resize({static_cast(host_out_lod0.back()), 1}); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_CUDA_KERNEL(ctc_align, paddle::operators::CTCAlignOpCUDAKernel, + paddle::operators::CTCAlignOpCUDAKernel); diff --git a/paddle/operators/ctc_align_op.h b/paddle/operators/ctc_align_op.h new file mode 100644 index 0000000000000000000000000000000000000000..589413feb3dcbb7fea1f0a878b35d4bf714b5318 --- /dev/null +++ b/paddle/operators/ctc_align_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/op_registry.h" +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +class CTCAlignKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("Input"); + auto* output = ctx.Output("Output"); + const size_t level = 0; + auto input_lod = framework::ToAbsOffset(input->lod()); + + // check input dims and lod + auto input_dims = input->dims(); + PADDLE_ENFORCE_EQ(input_dims[0], + static_cast(input_lod[level].back()), + "The first dimension of Input(Input) should be equal to " + "the sum of all sequences' lengths."); + + const size_t num_sequences = input_lod[level].size() - 1; + size_t blank = static_cast(ctx.Attr("blank")); + bool merge_repeated = ctx.Attr("merge_repeated"); + + // merge repeated tokens and delete blank + T* output_data = output->mutable_data(ctx.GetPlace()); + size_t output_idx = 0; + std::vector output_lod0(1, 0); + const T* input_data = input->data(); + for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { + T prev_token = -1; + for (size_t i = input_lod[level][seq_idx]; + i < input_lod[level][seq_idx + 1]; ++i) { + if (input_data[i] != blank && + !(merge_repeated && input_data[i] == prev_token)) { + output_data[output_idx] = input_data[i]; + ++output_idx; + } + prev_token = input_data[i]; + } + output_lod0.push_back(output_idx); + } + + // set output lod + framework::LoD output_lod; + output_lod.push_back(output_lod0); + output->set_lod(output_lod); + + // resize output dims + output->Resize({static_cast(output_lod0.back()), 1}); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sequence_expand_op.cc b/paddle/operators/sequence_expand_op.cc index b40ec617e42110e0ab5168a8ac675adaf760fb3c..d34dbd35b6df2dac275fbe2c41f99b8549217d5b 100644 --- a/paddle/operators/sequence_expand_op.cc +++ b/paddle/operators/sequence_expand_op.cc @@ -58,7 +58,7 @@ This operator expands input(X) according to LOD of input(Y). Following are cases to better explain how this works: Case 1: -Given 2-level a LoDTensor input(X) +Given a 2-level LoDTensor input(X) X.lod = [[0, 2, 3], [0, 1, 3, 4]] X.data = [a, b, c, d] @@ -75,9 +75,8 @@ then we get 2-level LoDTensor Case 2: -Given a 0-level LoDTensor input(X) +Given a common Tensor input(X) X.data = [a, b, c] - X.lod = NULL X.dims = [3, 1] and input(Y) Y.lod = [[0, 2, 3, 6]] @@ -89,9 +88,8 @@ then we get 1-level LoDTensor Case 3: -Given a 0-level LoDTensor input(X) +Given a common Tensor input(X) X.data = [[a, b], [c, d], [e, f]] - X.lod = NULL X.dims = [3, 2] and input(Y) Y.lod = [[0, 2, 3, 6]] diff --git a/python/paddle/v2/fluid/tests/test_ctc_align.py b/python/paddle/v2/fluid/tests/test_ctc_align.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7c16997c19fed1ed231ef1cb76875eb48cb1c1 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_ctc_align.py @@ -0,0 +1,76 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +import sys +import unittest +import numpy as np +from op_test import OpTest +from test_softmax_op import stable_softmax + + +def CTCAlign(input, lod, blank, merge_repeated): + lod0 = lod[0] + result = [] + for i in range(len(lod0) - 1): + prev_token = -1 + for j in range(lod0[i], lod0[i + 1]): + token = input[j][0] + if (token != blank) and not (merge_repeated and + token == prev_token): + result.append(token) + prev_token = token + result = np.array(result).reshape([len(result), 1]).astype("int32") + return result + + +class TestCTCAlignOp(OpTest): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[0, 11, 18]] + self.blank = 0 + self.merge_repeated = False + self.input = np.array( + [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape( + [18, 1]).astype("int32") + + def setUp(self): + self.config() + output = CTCAlign(self.input, self.input_lod, self.blank, + self.merge_repeated) + + self.inputs = {"Input": (self.input, self.input_lod), } + self.outputs = {"Output": output} + self.attrs = { + "blank": self.blank, + "merge_repeated": self.merge_repeated + } + + def test_check_output(self): + self.check_output() + pass + + +class TestCTCAlignOpCase1(TestCTCAlignOp): + def config(self): + self.op_type = "ctc_align" + self.input_lod = [[0, 11, 19]] + self.blank = 0 + self.merge_repeated = True + self.input = np.array( + [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0]).reshape( + [19, 1]).astype("int32") + + +if __name__ == "__main__": + unittest.main()