diff --git a/paddle/operators/ctc_decode_op.cc b/paddle/operators/ctc_decode_op.cc index b290b11d1d1e80ec403429211ab6ea2bde9fc934..480c9ae133ce1d7853652d9f527bcf51159d8f1e 100644 --- a/paddle/operators/ctc_decode_op.cc +++ b/paddle/operators/ctc_decode_op.cc @@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { -class CTCGreedyDecodeOp : public framework::OperatorWithKernel { +class CTCDecodeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), - "Input of CTCGreedyDecodeOp should not be null."); + "Input of CTCDecodeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Output"), - "Output of CTCGreedyDecodeOp should not be null."); + "Output of CTCDecodeOp should not be null."); auto input_dims = ctx->GetInputDim("Input"); @@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { } }; -class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { +class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { public: - CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) + CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Input", "(LodTensor, default: LoDTensor), Its shape is " @@ -86,9 +86,7 @@ Then: } // namespace paddle namespace ops = paddle::operators; -REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, - ops::CTCGreedyDecodeOpMaker, +REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL( - ctc_greedy_decode, - ops::CTCGreedyDecodeKernel); + ctc_decode, ops::CTCDecodeKernel); diff --git a/paddle/operators/ctc_decode_op.cu b/paddle/operators/ctc_decode_op.cu index e9cdad7c26b17999c08cae4fe46122cf534fb891..b10db100f7fb71ccef1e93fd7dcb651e70815b43 100644 --- a/paddle/operators/ctc_decode_op.cu +++ b/paddle/operators/ctc_decode_op.cu @@ -15,7 +15,7 @@ limitations under the License. */ #include #include #include -#include "paddle/operators/ctc_greedy_decode_op.h" +#include "paddle/operators/ctc_decode_op.h" namespace paddle { namespace operators { @@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, } template -class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { +class CTCDecodeOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), @@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, - paddle::operators::CTCGreedyDecodeOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(ctc_decode, + paddle::operators::CTCDecodeOpCUDAKernel); diff --git a/paddle/operators/ctc_decode_op.h b/paddle/operators/ctc_decode_op.h index 30bb53e157f19e1c98d9945deee01a5cb2595952..bc8dfab9f621531273beda313dd3af79efd7fd1e 100644 --- a/paddle/operators/ctc_decode_op.h +++ b/paddle/operators/ctc_decode_op.h @@ -23,7 +23,7 @@ using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; template -class CTCGreedyDecodeKernel : public framework::OpKernel { +class CTCDecodeKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); diff --git a/python/paddle/v2/fluid/tests/test_ctc_decode.py b/python/paddle/v2/fluid/tests/test_ctc_decode.py index 3b7486cfb98098facf05106eb2ff668959bb35d5..6e798a8465c36affcdc222f3cdd12cec73526d07 100644 --- a/python/paddle/v2/fluid/tests/test_ctc_decode.py +++ b/python/paddle/v2/fluid/tests/test_ctc_decode.py @@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated): class TestCTCDecodeOp(OpTest): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = False @@ -49,7 +49,7 @@ class TestCTCDecodeOp(OpTest): class TestCTCDecodeOpCase1(TestCTCDecodeOp): def config(self): - self.op_type = "ctc_greedy_decode" + self.op_type = "ctc_decode" self.input_lod = [[0, 11, 18]] self.blank = 0 self.merge_repeated = True