未验证 提交 47753a96 编写于 作者: W whs 提交者: GitHub

Merge pull request #7527 from wanghaoshuang/ctc_greedy_decode

Add CTC align op
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_align_op.h"
namespace paddle {
namespace operators {
/**
 * Operator definition for ctc_align.
 *
 * Validates that the "Input" and "Output" variables are present, gives the
 * output the same dimensions as the input at shape-inference time, and picks
 * the kernel according to the element type of "Input".
 */
class CTCAlignOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  /// Checks the presence of Input/Output and propagates the input shape.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Input"),
                   "Input of CTCAlignOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Output"),
                   "Output of CTCAlignOp should not be null.");

    // TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
    // The output is declared with the input's dims at this point; the kernels
    // resize it once the number of kept tokens is known.
    ctx->SetOutputDim("Output", ctx->GetInputDim("Input"));
  }

 protected:
  /// Kernel type = data type of "Input" + the device context of the run.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto* in_tensor = ctx.Input<Tensor>("Input");
    const auto in_type = framework::ToDataType(in_tensor->type());
    return framework::OpKernelType(in_type, ctx.device_context());
  }
};
// Declares the proto of the ctc_align operator: one input ("Input"), one
// output ("Output"), the attributes "blank" (default 0) and "merge_repeated"
// (default true), plus the user-facing operator documentation.
class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CTCAlignOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Token sequences, concatenated along the first dimension.
    AddInput("Input",
             "(LodTensor, default: LoDTensor<int>), Its shape is "
             "[Lp, 1], where Lp is the sum of all input sequences' length.");
    // Tokens that remain after alignment.
    AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
    // Label value that is removed from every sequence.
    AddAttr<int>("blank",
                 "(int, default: 0), the blank label setted in Connectionist "
                 "Temporal Classification (CTC) op.")
        .SetDefault(0);
    // Switch for collapsing runs of identical consecutive tokens.
    AddAttr<bool>("merge_repeated",
                  "(bool, default: true), whether to "
                  "merge repeated elements between two blanks. ")
        .SetDefault(true);
    // NOTE: the raw string below is emitted verbatim as the operator
    // documentation, so its whitespace is part of the runtime text.
    AddComment(R"DOC(
CTCAlign op is used to merge repeated elements between two blanks
and then delete all blanks in sequence.
Given:
Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
6, 0, 0, 7, 7, 7, 0]
Input.dims = {18, 1}
Input.LoD = [[0, 11, 18]]
And:
blank = 0
merge_repeated = True
Then:
Output.data = [1, 2, 4, 4, 5, 6,
6, 7]
Output.dims = {8, 1}
Output.LoD = [[0, 6, 8]]
)DOC");
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Register the ctc_align operator together with its proto maker.
// EmptyGradOpMaker is supplied as the gradient-op maker.
REGISTER_OPERATOR(ctc_align, ops::CTCAlignOp, ops::CTCAlignOpMaker,
                  paddle::framework::EmptyGradOpMaker);
// CPU kernels are instantiated for int and int64_t token types.
REGISTER_OP_CPU_KERNEL(
    ctc_align, ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int>,
    ops::CTCAlignKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_align_op.h"
namespace paddle {
namespace operators {
/**
 * Removes blank tokens (and, optionally, consecutive repeats) from a batch of
 * concatenated token sequences.
 *
 * The body is a plain sequential loop: it does not use any thread or block
 * index, so every thread that runs it performs the complete computation.
 *
 * @param num_token       Total number of tokens (currently not read).
 * @param tokens          Concatenated tokens of all sequences.
 * @param num_seq         Number of sequences.
 * @param lod0            num_seq + 1 absolute offsets; sequence i occupies
 *                        tokens[lod0[i]] .. tokens[lod0[i + 1] - 1].
 * @param blank           Token value that is always dropped.
 * @param merge_repeated  When non-zero, a token equal to the token directly
 *                        before it in the same sequence is dropped as well.
 * @param out_lod0        Receives num_seq + 1 offsets describing the kept
 *                        tokens of each sequence inside `output`.
 * @param output          Receives the kept tokens, densely packed.
 */
template <typename T>
__global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
                                      const size_t num_seq, size_t* lod0,
                                      const int blank, const int merge_repeated,
                                      size_t* out_lod0, T* output) {
  // Offsets are size_t, so all counters are size_t too: this avoids narrowing
  // an offset into an int and comparing signed against unsigned values.
  size_t output_idx = 0;
  out_lod0[0] = 0;
  for (size_t seq = 0; seq < num_seq; ++seq) {
    // The "previous token" is reset for each sequence, so a repeat is never
    // merged across a sequence boundary.
    T pre_token = -1;
    for (size_t pos = lod0[seq]; pos < lod0[seq + 1]; ++pos) {
      const T token = tokens[pos];
      const bool is_repeat = merge_repeated && token == pre_token;
      if (token != blank && !is_repeat) {
        output[output_idx] = token;
        ++output_idx;
      }
      // Blanks also update the previous token, hence equal tokens separated
      // by a blank are both kept.
      pre_token = token;
    }
    out_lod0[seq + 1] = output_idx;
  }
}
/**
 * GPU implementation of ctc_align.
 *
 * Launches MergeAndDelCudaKernel with a single block of a single thread on
 * the device context's stream, copies the resulting offsets back to the host,
 * installs them as the output LoD and shrinks the output to the number of
 * kept tokens.
 */
template <typename T>
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "It must use CUDAPlace.");

    // Operator attributes; the bool is handed to the kernel as an int flag.
    const int blank = ctx.Attr<int>("blank");
    const int merge_repeated = ctx.Attr<bool>("merge_repeated") ? 1 : 0;

    // Input tokens and their level-0 absolute offsets.
    const size_t level = 0;
    auto* input = ctx.Input<LoDTensor>("Input");
    auto input_lod = framework::ToAbsOffset(input->lod());
    const size_t num_offsets = input_lod[level].size();
    const size_t num_seq = num_offsets - 1;
    const int64_t num_tokens = input->dims()[0];
    const T* tokens = input->data<T>();

    // Device-side buffer in which the kernel records the output offsets.
    thrust::device_vector<size_t> dev_out_lod0(num_offsets);
    size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());

    // Allocate the output for the worst case: every input token is kept.
    auto* output = ctx.Output<LoDTensor>("Output");
    T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());

    // NOTE(review): input_lod[level].data() is handed to the kernel as lod0.
    // Whether that pointer is dereferenceable on the device depends on how
    // framework::LoD stores its offsets, which is not visible here - confirm.
    auto stream = ctx.cuda_device_context().stream();
    MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
        num_tokens, tokens, num_seq, input_lod[level].data(), blank,
        merge_repeated, dev_out_lod0_ptr, output_data);

    // Bring the offsets back to the host and publish them as the output LoD.
    thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
                                              dev_out_lod0.end());
    framework::LoD out_lod;
    out_lod.push_back(host_out_lod0);
    output->set_lod(out_lod);

    // The last offset equals the number of kept tokens.
    output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
  }
};
} // namespace operators
} // namespace paddle
// CUDA kernels for ctc_align are instantiated for int and int64_t token types.
REGISTER_OP_CUDA_KERNEL(ctc_align, paddle::operators::CTCAlignOpCUDAKernel<int>,
                        paddle::operators::CTCAlignOpCUDAKernel<int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string.h>
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
// Short aliases for the framework tensor types used by the ctc_align code.
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
/**
 * CPU implementation of ctc_align.
 *
 * For every sequence described by level 0 of the input LoD, copies the tokens
 * to the output while dropping
 *   - every token equal to the "blank" attribute, and
 *   - when "merge_repeated" is true, every token equal to the token directly
 *     before it in the same sequence.
 * The output receives a one-level LoD with the offsets of the kept tokens and
 * is resized to {number of kept tokens, 1}.
 */
template <typename DeviceContext, typename T>
class CTCAlignKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<LoDTensor>("Input");
    auto* output = ctx.Output<LoDTensor>("Output");
    const size_t level = 0;
    auto input_lod = framework::ToAbsOffset(input->lod());

    // Everything below indexes input_lod[level] and reads its last offset, so
    // an input without (non-empty) level-0 LoD must be rejected up front
    // instead of indexing out of bounds.
    PADDLE_ENFORCE(input_lod.size() > level && input_lod[level].size() > 0,
                   "Input(Input) of CTCAlignOp should carry LoD information.");

    // check input dims and lod
    auto input_dims = input->dims();
    PADDLE_ENFORCE_EQ(input_dims[0],
                      static_cast<int64_t>(input_lod[level].back()),
                      "The first dimension of Input(Input) should be equal to "
                      "the sum of all sequences' lengths.");

    const size_t num_sequences = input_lod[level].size() - 1;
    // Keep the blank label in the token type so that tokens are compared
    // without mixing signed and unsigned operands.
    const T blank = static_cast<T>(ctx.Attr<int>("blank"));
    const bool merge_repeated = ctx.Attr<bool>("merge_repeated");

    // merge repeated tokens and delete blank
    T* output_data = output->mutable_data<T>(ctx.GetPlace());
    size_t output_idx = 0;
    std::vector<size_t> output_lod0(1, 0);
    const T* input_data = input->data<T>();
    for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
      // Reset per sequence: repeats are never merged across a boundary.
      T prev_token = -1;
      for (size_t i = input_lod[level][seq_idx];
           i < input_lod[level][seq_idx + 1]; ++i) {
        const T token = input_data[i];
        if (token != blank && !(merge_repeated && token == prev_token)) {
          output_data[output_idx] = token;
          ++output_idx;
        }
        // Blanks also become the previous token, so equal tokens separated
        // by a blank are both kept.
        prev_token = token;
      }
      output_lod0.push_back(output_idx);
    }

    // set output lod
    framework::LoD output_lod;
    output_lod.push_back(output_lod0);
    output->set_lod(output_lod);

    // resize output dims: the last offset is the number of kept tokens
    output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
  }
};
} // namespace operators
} // namespace paddle
...@@ -58,7 +58,7 @@ This operator expands input(X) according to LOD of input(Y). ...@@ -58,7 +58,7 @@ This operator expands input(X) according to LOD of input(Y).
Following are cases to better explain how this works: Following are cases to better explain how this works:
Case 1: Case 1:
Given 2-level a LoDTensor input(X) Given a 2-level LoDTensor input(X)
X.lod = [[0, 2, 3], X.lod = [[0, 2, 3],
[0, 1, 3, 4]] [0, 1, 3, 4]]
X.data = [a, b, c, d] X.data = [a, b, c, d]
...@@ -75,9 +75,8 @@ then we get 2-level LoDTensor ...@@ -75,9 +75,8 @@ then we get 2-level LoDTensor
Case 2: Case 2:
Given a 0-level LoDTensor input(X) Given a common Tensor input(X)
X.data = [a, b, c] X.data = [a, b, c]
X.lod = NULL
X.dims = [3, 1] X.dims = [3, 1]
and input(Y) and input(Y)
Y.lod = [[0, 2, 3, 6]] Y.lod = [[0, 2, 3, 6]]
...@@ -89,9 +88,8 @@ then we get 1-level LoDTensor ...@@ -89,9 +88,8 @@ then we get 1-level LoDTensor
Case 3: Case 3:
Given a 0-level LoDTensor input(X) Given a common Tensor input(X)
X.data = [[a, b], [c, d], [e, f]] X.data = [[a, b], [c, d], [e, f]]
X.lod = NULL
X.dims = [3, 2] X.dims = [3, 2]
and input(Y) and input(Y)
Y.lod = [[0, 2, 3, 6]] Y.lod = [[0, 2, 3, 6]]
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
def CTCAlign(input, lod, blank, merge_repeated):
    """Reference implementation of the ctc_align operator.

    Args:
        input: indexable of rows; the token of row j is ``input[j][0]``.
        lod: list of offset lists; only ``lod[0]`` is used. Sequence ``i``
            covers rows ``lod[0][i]`` up to (excluding) ``lod[0][i + 1]``.
        blank: token value that is always removed.
        merge_repeated: when true, a token equal to the token right before it
            in the same sequence is removed as well.

    Returns:
        An int32 numpy array of shape ``[k, 1]`` holding the ``k`` kept tokens.
    """
    offsets = lod[0]
    kept = []
    for begin, end in zip(offsets[:-1], offsets[1:]):
        # The previous token is reset at the start of every sequence.
        last = -1
        for row in range(begin, end):
            current = input[row][0]
            is_repeat = merge_repeated and current == last
            if current != blank and not is_repeat:
                kept.append(current)
            # Blanks count as "previous token" too.
            last = current
    return np.array(kept).reshape([len(kept), 1]).astype("int32")
class TestCTCAlignOp(OpTest):
    """Checks ctc_align on two sequences with merge_repeated disabled."""

    def config(self):
        """Defines the operator type, attributes and input for this case."""
        self.op_type = "ctc_align"
        self.input_lod = [[0, 11, 18]]
        self.blank = 0
        self.merge_repeated = False
        tokens = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]
        self.input = np.array(tokens).reshape([18, 1]).astype("int32")

    def setUp(self):
        """Builds inputs, attributes and the expected output from config()."""
        self.config()
        expected = CTCAlign(self.input, self.input_lod, self.blank,
                            self.merge_repeated)
        self.inputs = {"Input": (self.input, self.input_lod)}
        self.outputs = {"Output": expected}
        self.attrs = {
            "blank": self.blank,
            "merge_repeated": self.merge_repeated,
        }

    def test_check_output(self):
        self.check_output()
class TestCTCAlignOpCase1(TestCTCAlignOp):
    """Same check with merge_repeated enabled and a longer second sequence."""

    def config(self):
        """Overrides the base configuration; the input ends with two blanks."""
        self.op_type = "ctc_align"
        self.input_lod = [[0, 11, 19]]
        self.blank = 0
        self.merge_repeated = True
        tokens = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0, 0]
        self.input = np.array(tokens).reshape([19, 1]).astype("int32")
# Run the test cases of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册