提交 7150289b 编写于 作者: W wanghaoshuang

Refine CPU kernel

1. Allocate memory for output before compute.
2. Rename 'ctc_decode' to 'ctc_align'
上级 adcfde3e
...@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/ctc_decode_op.h" #include "paddle/operators/ctc_align_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class CTCDecodeOp : public framework::OperatorWithKernel { class CTCAlignOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input of CTCDecodeOp should not be null."); "Input of CTCAlignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"), PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output of CTCDecodeOp should not be null."); "Output of CTCAlignOp should not be null.");
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
...@@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel { ...@@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel {
} }
}; };
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is " "(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length."); "[Lp, 1], where Lp is the sum of all input sequences' length.");
AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result."); AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
AddAttr<int>("blank", AddAttr<int>("blank",
"(int, default: 0), the blank label setted in Connectionist " "(int, default: 0), the blank label setted in Connectionist "
"Temporal Classification (CTC) op.") "Temporal Classification (CTC) op.")
...@@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
"merge repeated elements between two blanks. ") "merge repeated elements between two blanks. ")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CTCDecoder is used to merge repeated elements between two blanks CTCAlign op is used to merge repeated elements between two blanks
and then delete all blanks in sequence. and then delete all blanks in sequence.
Given: Given:
...@@ -86,7 +86,7 @@ Then: ...@@ -86,7 +86,7 @@ Then:
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker, REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>); ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>);
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/operators/ctc_decode_op.h" #include "paddle/operators/ctc_align_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, ...@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
} }
template <typename T> template <typename T>
class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> { class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> { ...@@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_decode, REGISTER_OP_CUDA_KERNEL(ctc_align,
paddle::operators::CTCDecodeOpCUDAKernel<int>); paddle::operators::CTCAlignOpCUDAKernel<int>);
...@@ -23,7 +23,7 @@ using Tensor = framework::Tensor; ...@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CTCDecodeKernel : public framework::OpKernel<T> { class CTCAlignKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
...@@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel<T> { ...@@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
bool merge_repeated = ctx.Attr<bool>("merge_repeated"); bool merge_repeated = ctx.Attr<bool>("merge_repeated");
// merge repeated tokens and delete blank // merge repeated tokens and delete blank
std::vector<std::vector<int>> pathes(num_sequences); T* output_data = output->mutable_data<T>(ctx.GetPlace());
size_t output_idx = 0;
std::vector<size_t> output_lod0(1, 0); std::vector<size_t> output_lod0(1, 0);
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
...@@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel<T> { ...@@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
i < input_lod[level][seq_idx + 1]; ++i) { i < input_lod[level][seq_idx + 1]; ++i) {
if (input_data[i] != blank && if (input_data[i] != blank &&
!(merge_repeated && input_data[i] == prev_token)) { !(merge_repeated && input_data[i] == prev_token)) {
pathes[seq_idx].push_back(input_data[i]); output_data[output_idx] = input_data[i];
++output_idx;
} }
prev_token = input_data[i]; prev_token = input_data[i];
} }
output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); output_lod0.push_back(output_idx);
} }
// set output lod // set output lod
...@@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel<T> { ...@@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
output->set_lod(output_lod); output->set_lod(output_lod);
// resize output dims // resize output dims
T* output_data = output->mutable_data<T>( output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
{static_cast<int64_t>(output_lod0.back()), 1}, ctx.GetPlace());
// copy result to output
for (int i = 0; i < num_sequences; ++i) {
memcpy(output_data + output_lod0[i], pathes[i].data(),
sizeof(int) * pathes[i].size());
}
} }
}; };
......
...@@ -5,7 +5,7 @@ from op_test import OpTest ...@@ -5,7 +5,7 @@ from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
def CTCDecode(input, lod, blank, merge_repeated): def CTCAlign(input, lod, blank, merge_repeated):
lod0 = lod[0] lod0 = lod[0]
result = [] result = []
for i in range(len(lod0) - 1): for i in range(len(lod0) - 1):
...@@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated): ...@@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated):
return result return result
class TestCTCDecodeOp(OpTest): class TestCTCAlignOp(OpTest):
def config(self): def config(self):
self.op_type = "ctc_decode" self.op_type = "ctc_align"
self.input_lod = [[0, 11, 18]] self.input_lod = [[0, 11, 18]]
self.blank = 0 self.blank = 0
self.merge_repeated = False self.merge_repeated = False
...@@ -32,8 +32,8 @@ class TestCTCDecodeOp(OpTest): ...@@ -32,8 +32,8 @@ class TestCTCDecodeOp(OpTest):
def setUp(self): def setUp(self):
self.config() self.config()
output = CTCDecode(self.input, self.input_lod, self.blank, output = CTCAlign(self.input, self.input_lod, self.blank,
self.merge_repeated) self.merge_repeated)
self.inputs = {"Input": (self.input, self.input_lod), } self.inputs = {"Input": (self.input, self.input_lod), }
self.outputs = {"Output": output} self.outputs = {"Output": output}
...@@ -47,9 +47,9 @@ class TestCTCDecodeOp(OpTest): ...@@ -47,9 +47,9 @@ class TestCTCDecodeOp(OpTest):
pass pass
class TestCTCDecodeOpCase1(TestCTCDecodeOp): class TestCTCAlignOpCase1(TestCTCAlignOp):
def config(self): def config(self):
self.op_type = "ctc_decode" self.op_type = "ctc_align"
self.input_lod = [[0, 11, 19]] self.input_lod = [[0, 11, 19]]
self.blank = 0 self.blank = 0
self.merge_repeated = True self.merge_repeated = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册