From 7150289b5cad76d3347a268b54c31e13a0e49f42 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 17 Jan 2018 16:34:38 +0800 Subject: [PATCH] Refine CPU kernel 1. Allocate memory for output before compute. 2. Rename 'ctc_decode' to 'ctc_align' --- .../{ctc_decode_op.cc => ctc_align_op.cc} | 20 +++++++++---------- .../{ctc_decode_op.cu => ctc_align_op.cu} | 8 ++++---- .../{ctc_decode_op.h => ctc_align_op.h} | 19 +++++++----------- .../{test_ctc_decode.py => test_ctc_align.py} | 14 ++++++------- 4 files changed, 28 insertions(+), 33 deletions(-) rename paddle/operators/{ctc_decode_op.cc => ctc_align_op.cc} (78%) rename paddle/operators/{ctc_decode_op.cu => ctc_align_op.cu} (93%) rename paddle/operators/{ctc_decode_op.h => ctc_align_op.h} (80%) rename python/paddle/v2/fluid/tests/{test_ctc_decode.py => test_ctc_align.py} (82%) diff --git a/paddle/operators/ctc_decode_op.cc b/paddle/operators/ctc_align_op.cc similarity index 78% rename from paddle/operators/ctc_decode_op.cc rename to paddle/operators/ctc_align_op.cc index 480c9ae133c..3fa8d2af742 100644 --- a/paddle/operators/ctc_decode_op.cc +++ b/paddle/operators/ctc_align_op.cc @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_decode_op.h" +#include "paddle/operators/ctc_align_op.h" namespace paddle { namespace operators { -class CTCDecodeOp : public framework::OperatorWithKernel { +class CTCAlignOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input of CTCDecodeOp should not be null."); + "Input of CTCAlignOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output of CTCDecodeOp should not be null."); + "Output of CTCAlignOp should not be null."); auto input_dims = ctx->GetInputDim("Input"); @@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel { } }; -class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { +class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(LodTensor, default: LoDTensor), Its shape is " "[Lp, 1], where Lp is the sum of all input sequences' length."); - AddOutput("Output", "(Tensor, default: Tensor), The decode result."); + AddOutput("Output", "(Tensor, default: Tensor), The align result."); AddAttr("blank", "(int, default: 0), the blank label setted in Connectionist " "Temporal Classification (CTC) op.") @@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { "merge repeated elements between two blanks. ") .SetDefault(true); AddComment(R"DOC( -CTCDecoder is used to merge repeated elements between two blanks +CTCAlign op is used to merge repeated elements between two blanks and then delete all blanks in sequence. Given: @@ -86,7 +86,7 @@ Then: } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, +REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_decode, ops::CTCDecodeKernel); + ctc_align, ops::CTCAlignKernel); diff --git a/paddle/operators/ctc_decode_op.cu b/paddle/operators/ctc_align_op.cu similarity index 93% rename from paddle/operators/ctc_decode_op.cu rename to paddle/operators/ctc_align_op.cu index b10db100f7f..99e716e989f 100644 --- a/paddle/operators/ctc_decode_op.cu +++ b/paddle/operators/ctc_align_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include "paddle/operators/ctc_decode_op.h" +#include "paddle/operators/ctc_align_op.h" namespace paddle { namespace operators { @@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, } template -class CTCDecodeOpCUDAKernel : public framework::OpKernel { +class CTCAlignOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_decode, - paddle::operators::CTCDecodeOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_align, + paddle::operators::CTCAlignOpCUDAKernel); diff --git a/paddle/operators/ctc_decode_op.h b/paddle/operators/ctc_align_op.h similarity index 80% rename from paddle/operators/ctc_decode_op.h rename to paddle/operators/ctc_align_op.h index bc8dfab9f62..589413feb3d 100644 --- a/paddle/operators/ctc_decode_op.h +++ b/paddle/operators/ctc_align_op.h @@ -23,7 +23,7 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -class CTCDecodeKernel : public framework::OpKernel { +class CTCAlignKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); @@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel { bool merge_repeated = ctx.Attr("merge_repeated"); // merge repeated tokens and delete blank - std::vector> pathes(num_sequences); + T* output_data = output->mutable_data(ctx.GetPlace()); + size_t output_idx = 0; std::vector output_lod0(1, 0); const T* input_data = input->data(); for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { @@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel { i < input_lod[level][seq_idx + 1]; ++i) { if (input_data[i] != blank && !(merge_repeated && input_data[i] == prev_token)) { - pathes[seq_idx].push_back(input_data[i]); + output_data[output_idx] = input_data[i]; + ++output_idx; } prev_token = input_data[i]; } - output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); + output_lod0.push_back(output_idx); } // set output lod @@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel { output->set_lod(output_lod); // resize output dims - T* output_data = output->mutable_data( - {static_cast(output_lod0.back()), 1}, ctx.GetPlace()); - - // copy result to output - for (int i = 0; i < num_sequences; ++i) { - memcpy(output_data + output_lod0[i], pathes[i].data(), - sizeof(int) * pathes[i].size()); - } + output->Resize({static_cast(output_lod0.back()), 1}); } }; diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_align.py similarity index 82% rename from python/paddle/v2/fluid/tests/test_ctc_decode.py rename to python/paddle/v2/fluid/tests/test_ctc_align.py index 1efacab4b3b..96f45890ee9 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_align.py @@ -5,7 +5,7 @@ from op_test import OpTest from test_softmax_op import stable_softmax -def CTCDecode(input, lod, blank, merge_repeated): +def CTCAlign(input, lod, blank, merge_repeated): lod0 = lod[0] result = [] for i in range(len(lod0) - 1): @@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated): return result -class TestCTCDecodeOp(OpTest): +class TestCTCAlignOp(OpTest): def config(self): - self.op_type = "ctc_decode" + self.op_type = "ctc_align" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = False @@ -32,8 +32,8 @@ class TestCTCDecodeOp(OpTest): def setUp(self): self.config() - output = CTCDecode(self.input, self.input_lod, self.blank, - self.merge_repeated) + output = CTCAlign(self.input, self.input_lod, self.blank, + self.merge_repeated) self.inputs = {"Input": (self.input, self.input_lod), } self.outputs = {"Output": output} @@ -47,9 +47,9 @@ class TestCTCDecodeOp(OpTest): pass -class TestCTCDecodeOpCase1(TestCTCDecodeOp): +class TestCTCAlignOpCase1(TestCTCAlignOp): def config(self): - self.op_type = "ctc_decode" + self.op_type = "ctc_align" self.input_lod = [[0, 11, 19]] self.blank = 0 self.merge_repeated = True -- GitLab