提交 10dd6326 编写于 作者: W wanghaoshuang

Rename 'ctc_greedy_decode' to 'ctc_decode'

上级 281e93bc
......@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/operators/ctc_decode_op.h"
namespace paddle {
namespace operators {
class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
class CTCDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input of CTCGreedyDecodeOp should not be null.");
"Input of CTCDecodeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output of CTCGreedyDecodeOp should not be null.");
"Output of CTCDecodeOp should not be null.");
auto input_dims = ctx->GetInputDim("Input");
......@@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
}
};
class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is "
......@@ -86,9 +86,7 @@ Then:
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
ops::CTCGreedyDecodeOpMaker,
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_greedy_decode,
ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, int>);
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/operators/ctc_decode_op.h"
namespace paddle {
namespace operators {
......@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
}
template <typename T>
class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
paddle::operators::CTCGreedyDecodeOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(ctc_decode,
paddle::operators::CTCDecodeOpCUDAKernel<int>);
......@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
class CTCDecodeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
......
......@@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated):
class TestCTCDecodeOp(OpTest):
def config(self):
self.op_type = "ctc_greedy_decode"
self.op_type = "ctc_decode"
self.input_lod = [[0, 11, 18]]
self.blank = 0
self.merge_repeated = False
......@@ -49,7 +49,7 @@ class TestCTCDecodeOp(OpTest):
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
def config(self):
self.op_type = "ctc_greedy_decode"
self.op_type = "ctc_decode"
self.input_lod = [[0, 11, 18]]
self.blank = 0
self.merge_repeated = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册