提交 7150289b 编写于 作者: W wanghaoshuang

Refine CPU kernel

1. Allocate memory for the output before computing.
2. Rename 'ctc_decode' to 'ctc_align'
上级 adcfde3e
......@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_decode_op.h"
#include "paddle/operators/ctc_align_op.h"
namespace paddle {
namespace operators {
class CTCDecodeOp : public framework::OperatorWithKernel {
class CTCAlignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input of CTCDecodeOp should not be null.");
"Input of CTCAlignOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output of CTCDecodeOp should not be null.");
"Output of CTCAlignOp should not be null.");
auto input_dims = ctx->GetInputDim("Input");
......@@ -42,14 +42,14 @@ class CTCDecodeOp : public framework::OperatorWithKernel {
}
};
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is "
"[Lp, 1], where Lp is the sum of all input sequences' length.");
AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result.");
AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
AddAttr<int>("blank",
"(int, default: 0), the blank label setted in Connectionist "
"Temporal Classification (CTC) op.")
......@@ -59,7 +59,7 @@ class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
"merge repeated elements between two blanks. ")
.SetDefault(true);
AddComment(R"DOC(
CTCDecoder is used to merge repeated elements between two blanks
CTCAlign op is used to merge repeated elements between two blanks
and then delete all blanks in sequence.
Given:
......@@ -86,7 +86,7 @@ Then:
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);
ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>);
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_decode_op.h"
#include "paddle/operators/ctc_align_op.h"
namespace paddle {
namespace operators {
......@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
}
template <typename T>
class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -87,5 +87,5 @@ class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_decode,
paddle::operators::CTCDecodeOpCUDAKernel<int>);
REGISTER_OP_CUDA_KERNEL(ctc_align,
paddle::operators::CTCAlignOpCUDAKernel<int>);
......@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class CTCDecodeKernel : public framework::OpKernel<T> {
class CTCAlignKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
......@@ -43,7 +43,8 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
// merge repeated tokens and delete blank
std::vector<std::vector<int>> pathes(num_sequences);
T* output_data = output->mutable_data<T>(ctx.GetPlace());
size_t output_idx = 0;
std::vector<size_t> output_lod0(1, 0);
const T* input_data = input->data<T>();
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
......@@ -52,11 +53,12 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
i < input_lod[level][seq_idx + 1]; ++i) {
if (input_data[i] != blank &&
!(merge_repeated && input_data[i] == prev_token)) {
pathes[seq_idx].push_back(input_data[i]);
output_data[output_idx] = input_data[i];
++output_idx;
}
prev_token = input_data[i];
}
output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size());
output_lod0.push_back(output_idx);
}
// set output lod
......@@ -65,14 +67,7 @@ class CTCDecodeKernel : public framework::OpKernel<T> {
output->set_lod(output_lod);
// resize output dims
T* output_data = output->mutable_data<T>(
{static_cast<int64_t>(output_lod0.back()), 1}, ctx.GetPlace());
// copy result to output
for (int i = 0; i < num_sequences; ++i) {
memcpy(output_data + output_lod0[i], pathes[i].data(),
sizeof(int) * pathes[i].size());
}
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
}
};
......
......@@ -5,7 +5,7 @@ from op_test import OpTest
from test_softmax_op import stable_softmax
def CTCDecode(input, lod, blank, merge_repeated):
def CTCAlign(input, lod, blank, merge_repeated):
lod0 = lod[0]
result = []
for i in range(len(lod0) - 1):
......@@ -20,9 +20,9 @@ def CTCDecode(input, lod, blank, merge_repeated):
return result
class TestCTCDecodeOp(OpTest):
class TestCTCAlignOp(OpTest):
def config(self):
self.op_type = "ctc_decode"
self.op_type = "ctc_align"
self.input_lod = [[0, 11, 18]]
self.blank = 0
self.merge_repeated = False
......@@ -32,7 +32,7 @@ class TestCTCDecodeOp(OpTest):
def setUp(self):
self.config()
output = CTCDecode(self.input, self.input_lod, self.blank,
output = CTCAlign(self.input, self.input_lod, self.blank,
self.merge_repeated)
self.inputs = {"Input": (self.input, self.input_lod), }
......@@ -47,9 +47,9 @@ class TestCTCDecodeOp(OpTest):
pass
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
class TestCTCAlignOpCase1(TestCTCAlignOp):
def config(self):
self.op_type = "ctc_decode"
self.op_type = "ctc_align"
self.input_lod = [[0, 11, 19]]
self.blank = 0
self.merge_repeated = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册