From 10dd632659012374f827ae0208c05b0eb5c17fb6 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Tue, 16 Jan 2018 15:56:52 +0800 Subject: [PATCH] Rename 'ctc_greedy_decode' to 'ctc_decode' --- paddle/operators/ctc_decode_op.cc | 18 ++++++++---------- paddle/operators/ctc_decode_op.cu | 8 ++++---- paddle/operators/ctc_decode_op.h | 2 +- .../paddle/v2/fluid/tests/test_ctc_decode.py | 4 ++-- 4 files changed, 15 insertions(+), 17 deletions(-) diff --git a/paddle/operators/ctc_decode_op.cc b/paddle/operators/ctc_decode_op.cc index b290b11d1d..480c9ae133 100644 --- a/paddle/operators/ctc_decode_op.cc +++ b/paddle/operators/ctc_decode_op.cc @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { -class CTCGreedyDecodeOp : public framework::OperatorWithKernel { +class CTCDecodeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input of CTCGreedyDecodeOp should not be null."); + "Input of CTCDecodeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output of CTCGreedyDecodeOp should not be null."); + "Output of CTCDecodeOp should not be null."); auto input_dims = ctx->GetInputDim("Input"); @@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { } }; -class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { +class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(LodTensor, default: LoDTensor), Its shape is " @@ -86,9 +86,7 @@ Then: } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, - ops::CTCGreedyDecodeOpMaker, +REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_greedy_decode, - ops::CTCGreedyDecodeKernel); + ctc_decode, ops::CTCDecodeKernel); diff --git a/paddle/operators/ctc_decode_op.cu b/paddle/operators/ctc_decode_op.cu index e9cdad7c26..b10db100f7 100644 --- a/paddle/operators/ctc_decode_op.cu +++ b/paddle/operators/ctc_decode_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { @@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, } template -class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { +class CTCDecodeOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, - paddle::operators::CTCGreedyDecodeOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_decode, + paddle::operators::CTCDecodeOpCUDAKernel); diff --git a/paddle/operators/ctc_decode_op.h b/paddle/operators/ctc_decode_op.h index 30bb53e157..bc8dfab9f6 100644 --- a/paddle/operators/ctc_decode_op.h +++ b/paddle/operators/ctc_decode_op.h @@ -23,7 +23,7 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -class CTCGreedyDecodeKernel : public framework::OpKernel { +class CTCDecodeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_decode.py index 3b7486cfb9..6e798a8465 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_decode.py @@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated): class TestCTCDecodeOp(OpTest): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = False @@ -49,7 +49,7 @@ class TestCTCDecodeOp(OpTest): class TestCTCDecodeOpCase1(TestCTCDecodeOp): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = True -- GitLab