You need to sign in or sign up before continuing.
提交 10dd6326 编写于 作者: W wanghaoshuang

Rename 'ctc_greedy_decode' to 'ctc_decode'

上级 281e93bc
...@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/ctc_greedy_decode_op.h" #include "paddle/operators/ctc_decode_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class CTCGreedyDecodeOp : public framework::OperatorWithKernel { class CTCDecodeOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input of CTCGreedyDecodeOp should not be null."); "Input of CTCDecodeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"), PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output of CTCGreedyDecodeOp should not be null."); "Output of CTCDecodeOp should not be null.");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
...@@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { ...@@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
} }
}; };
class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is " "(LodTensor, default: LoDTensor<int>), Its shape is "
...@@ -86,9 +86,7 @@ Then: ...@@ -86,9 +86,7 @@ Then:
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
ops::CTCGreedyDecodeOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
ctc_greedy_decode, ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);
ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, int>);
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h" #include "paddle/operators/ctc_decode_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, ...@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
} }
template <typename T> template <typename T>
class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> { class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> { ...@@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, REGISTER_OP_CUDA_KERNEL(ctc_decode,
paddle::operators::CTCGreedyDecodeOpCUDAKernel<int>); paddle::operators::CTCDecodeOpCUDAKernel<int>);
...@@ -23,7 +23,7 @@ using Tensor = framework::Tensor; ...@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CTCGreedyDecodeKernel : public framework::OpKernel<T> { class CTCDecodeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
......
...@@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated): ...@@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated):
class TestCTCDecodeOp(OpTest): class TestCTCDecodeOp(OpTest):
def config(self): def config(self):
self.op_type = "ctc_greedy_decode" self.op_type = "ctc_decode"
self.input_lod = [[0, 11, 18]] self.input_lod = [[0, 11, 18]]
self.blank = 0 self.blank = 0
self.merge_repeated = False self.merge_repeated = False
...@@ -49,7 +49,7 @@ class TestCTCDecodeOp(OpTest): ...@@ -49,7 +49,7 @@ class TestCTCDecodeOp(OpTest):
class TestCTCDecodeOpCase1(TestCTCDecodeOp): class TestCTCDecodeOpCase1(TestCTCDecodeOp):
def config(self): def config(self):
self.op_type = "ctc_greedy_decode" self.op_type = "ctc_decode"
self.input_lod = [[0, 11, 18]] self.input_lod = [[0, 11, 18]]
self.blank = 0 self.blank = 0
self.merge_repeated = True self.merge_repeated = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册