提交 faf6890b 编写于 作者: L Liufang Sang 提交者: whs

support tensor input for ctc align op (#18887)

* test=develop support Tensor input for ctc_align_op

* test=develop add some comments
上级 c97ea53c
...@@ -45,7 +45,7 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -45,7 +45,7 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<int>), Its shape is " "2-D Tensor or LodTensor with shape "
"[Lp, 1], where Lp is the sum of all input sequences' length."); "[Lp, 1], where Lp is the sum of all input sequences' length.");
AddOutput("Output", "(Tensor, default: Tensor<int>), The align result."); AddOutput("Output", "(Tensor, default: Tensor<int>), The align result.");
AddAttr<int>("blank", AddAttr<int>("blank",
...@@ -56,6 +56,11 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,6 +56,11 @@ class CTCAlignOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, default: true), whether to " "(bool, default: true), whether to "
"merge repeated elements between two blanks. ") "merge repeated elements between two blanks. ")
.SetDefault(true); .SetDefault(true);
// add attr padding number for tensor input
AddAttr<int>("padding_num",
"(int, default: 0), padding number "
"use to padding tensor. ")
.SetDefault(0);
AddComment(R"DOC( AddComment(R"DOC(
CTCAlign op is used to merge repeated elements between two blanks CTCAlign op is used to merge repeated elements between two blanks
and then delete all blanks in sequence. and then delete all blanks in sequence.
...@@ -75,7 +80,23 @@ Then: ...@@ -75,7 +80,23 @@ Then:
6, 7] 6, 7]
Output.dims = {8, 1} Output.dims = {8, 1}
Output.LoD = [[0, 6, 8]] Output.LoD = [[0, 6, 8]]
or Given:
Input.data = [[0, 1, 2, 2, 0, 4],
[0, 4, 5, 0, 6, 0],
[0, 7, 7, 7, 0, 0]]
Input.dims = {3, 6},
Input.Lod = []
And:
blank = 0
merge_repeated = True
padding_num = 0
Then:
Output.data = [[1, 2, 4, 0, 0, 0],
[4, 5, 6, 0, 0, 0],
[7, 0, 0, 0, 0, 0]]
Output.dims = {3, 6},
Output.Lod = []
)DOC"); )DOC");
} }
}; };
......
...@@ -42,25 +42,61 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens, ...@@ -42,25 +42,61 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
} }
} }
// CUDA kernel for the no-LoD (padded tensor) input path of ctc_align.
// Each thread processes one row (one sequence) of a [batch_size, num_token]
// tensor: it drops `blank` tokens, optionally collapses consecutive repeats
// (when merge_repeated != 0), compacts the survivors to the front of the
// row in `output`, and fills the remaining tail slots with `padding_num`.
// NOTE(review): the launch site passes <<<32, ceil(batch/32)>>>; indexing
// below still covers all rows because of the batch_size guard, but the
// grid/block arguments look swapped relative to the usual convention —
// worth confirming against the caller.
template <typename T>
__global__ void PaddingMergeAndDelCudaKernel(const int64_t num_token,
                                             const T* tokens, const int blank,
                                             const int merge_repeated,
                                             const int padding_num,
                                             const int64_t batch_size,
                                             T* output) {
  // Global thread id = row (batch) index; excess threads exit immediately.
  int ind = blockIdx.x * blockDim.x + threadIdx.x;
  if (ind >= batch_size) return;
  // Next write position within this row of the flattened output buffer.
  int output_idx = ind * num_token;
  // Sentinel: -1 never equals a real token, so the first token is kept
  // even when merge_repeated is set.
  T prev_token = -1;
  for (int i = ind * num_token; i < ind * num_token + num_token; i++) {
    // Keep a token unless it is blank, or it repeats the previous token
    // while merging is enabled. The (unsigned) cast mirrors the CPU kernel's
    // comparison against the (non-negative) blank attribute.
    if ((unsigned)tokens[i] != blank &&
        !(merge_repeated && tokens[i] == prev_token)) {
      output[output_idx] = tokens[i];
      ++output_idx;
    }
    prev_token = tokens[i];
  }
  // Pad the unused tail of the row so the output keeps shape
  // [batch_size, num_token].
  for (int i = output_idx; i < ind * num_token + num_token; i++) {
    output[i] = padding_num;
  }
}
template <typename T> template <typename T>
class CTCAlignOpCUDAKernel : public framework::OpKernel<T> { class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace."); "It must use CUDAPlace.");
const size_t level = 0;
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output"); auto* output = ctx.Output<LoDTensor>("Output");
const int blank = ctx.Attr<int>("blank");
const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated"));
const T* tokens = input->data<T>();
auto stream = ctx.cuda_device_context().stream();
// tensor input which has no lod
if (input->lod().empty()) {
const int padding_num = ctx.Attr<int>("padding_num");
auto input_dims = input->dims();
T* output_data = output->mutable_data<T>({input_dims[0], input_dims[1]},
ctx.GetPlace());
PaddingMergeAndDelCudaKernel<
T><<<32, (input_dims[0] + 32 - 1) / 32, 0, stream>>>(
input_dims[1], tokens, blank, merge_repeated, padding_num,
input_dims[0], output_data);
} else {
const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod()); auto input_lod = framework::ToAbsOffset(input->lod());
const T* tokens = input->data<T>();
const int64_t num_tokens = input->dims()[0]; const int64_t num_tokens = input->dims()[0];
const size_t num_seq = input_lod[level].size() - 1; const size_t num_seq = input_lod[level].size() - 1;
const int blank = ctx.Attr<int>("blank");
const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated"));
// prepare a lod to record lod information while merging elements // prepare a lod to record lod information while merging elements
thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size()); thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
...@@ -68,14 +104,14 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> { ...@@ -68,14 +104,14 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
// merge elements and delete blank // merge elements and delete blank
T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace()); T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
auto stream = ctx.cuda_device_context().stream();
MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>( MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
num_tokens, tokens, num_seq, num_tokens, tokens, num_seq,
input_lod[level].CUDAMutableData(ctx.GetPlace()), blank, merge_repeated, input_lod[level].CUDAMutableData(ctx.GetPlace()), blank,
dev_out_lod0_ptr, output_data); merge_repeated, dev_out_lod0_ptr, output_data);
// set output lod // set output lod
std::vector<size_t> host_out_lod0(dev_out_lod0.begin(), dev_out_lod0.end()); std::vector<size_t> host_out_lod0(dev_out_lod0.begin(),
dev_out_lod0.end());
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(host_out_lod0); out_lod.push_back(host_out_lod0);
output->set_lod(out_lod); output->set_lod(out_lod);
...@@ -91,6 +127,7 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> { ...@@ -91,6 +127,7 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
output, -1); output, -1);
} }
} }
}
}; };
} // namespace operators } // namespace operators
......
...@@ -31,25 +31,47 @@ class CTCAlignKernel : public framework::OpKernel<T> { ...@@ -31,25 +31,47 @@ class CTCAlignKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output"); auto* output = ctx.Output<LoDTensor>("Output");
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto input_dims = input->dims();
const T* input_data = input->data<T>();
// support tensor input, no lod information
if (input->lod().empty()) {
size_t padding_num = static_cast<size_t>(ctx.Attr<int>("padding_num"));
for (size_t batch_id = 0; batch_id < (unsigned)input_dims[0];
batch_id++) {
T prev_token = -1;
size_t output_idx = 0;
for (size_t i = 0; i < (unsigned)input_dims[1]; i++) {
size_t input_ind = batch_id * input_dims[1] + i;
if ((unsigned)input_data[input_ind] != blank &&
!(merge_repeated && input_data[input_ind] == prev_token)) {
output_data[batch_id * input_dims[1] + output_idx] =
input_data[input_ind];
++output_idx;
}
prev_token = input_data[input_ind];
}
for (size_t j = output_idx; j < (unsigned)input_dims[1]; j++)
output_data[batch_id * input_dims[1] + j] = padding_num;
}
} else {
const size_t level = 0; const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod()); auto input_lod = framework::ToAbsOffset(input->lod());
// check input dims and lod // check input dims and lod
auto input_dims = input->dims(); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(input_dims[0], input_dims[0], static_cast<int64_t>(input_lod[level].back()),
static_cast<int64_t>(input_lod[level].back()),
"The first dimension of Input(Input) should be equal to " "The first dimension of Input(Input) should be equal to "
"the sum of all sequences' lengths."); "the sum of all sequences' lengths.");
const size_t num_sequences = input_lod[level].size() - 1; const size_t num_sequences = input_lod[level].size() - 1;
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
// merge repeated tokens and delete blank // merge repeated tokens and delete blank
T* output_data = output->mutable_data<T>(ctx.GetPlace());
size_t output_idx = 0; size_t output_idx = 0;
std::vector<size_t> output_lod0(1, 0); std::vector<size_t> output_lod0(1, 0);
const T* input_data = input->data<T>();
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
T prev_token = -1; T prev_token = -1;
for (size_t i = input_lod[level][seq_idx]; for (size_t i = input_lod[level][seq_idx];
...@@ -77,6 +99,7 @@ class CTCAlignKernel : public framework::OpKernel<T> { ...@@ -77,6 +99,7 @@ class CTCAlignKernel : public framework::OpKernel<T> {
output_data[0] = -1; output_data[0] = -1;
} }
} }
}
}; };
} // namespace operators } // namespace operators
......
...@@ -21,7 +21,8 @@ from op_test import OpTest ...@@ -21,7 +21,8 @@ from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
def CTCAlign(input, lod, blank, merge_repeated): def CTCAlign(input, lod, blank, merge_repeated, padding=0):
if lod is not None and len(lod) > 0:
lod0 = lod[0] lod0 = lod[0]
result = [] result = []
cur_offset = 0 cur_offset = 0
...@@ -37,6 +38,22 @@ def CTCAlign(input, lod, blank, merge_repeated): ...@@ -37,6 +38,22 @@ def CTCAlign(input, lod, blank, merge_repeated):
result = np.array(result).reshape([len(result), 1]).astype("int32") result = np.array(result).reshape([len(result), 1]).astype("int32")
if len(result) == 0: if len(result) == 0:
result = np.array([-1]) result = np.array([-1])
else:
result = [[] for i in range(len(input))]
for i in range(len(input)):
prev_token = -1
for j in range(len(input[i])):
token = input[i][j]
if (token != blank) and not (merge_repeated and
token == prev_token):
result[i].append(token)
prev_token = token
start = len(result[i])
for j in range(start, len(input[i])):
result[i].append(padding)
result = np.array(result).reshape(
[len(input), len(input[0])]).astype("int32")
return result return result
...@@ -87,5 +104,73 @@ class TestCTCAlignOpCase2(TestCTCAlignOp): ...@@ -87,5 +104,73 @@ class TestCTCAlignOpCase2(TestCTCAlignOp):
self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32") self.input = np.array([0, 0, 0, 0]).reshape([4, 1]).astype("int32")
class TestCTCAlignPaddingOp(OpTest):
    """Base test for ctc_align with a padded 2-D Tensor input (empty LoD).

    With an empty LoD the op treats each row of ``input`` as one sequence,
    merges/strips tokens row-wise, and fills the tail of every output row
    with ``padding_num``. Subclasses override :meth:`config` to vary the
    attributes and the input tensor.
    """

    def config(self):
        # Default configuration: merge repeats, blank token 0, pad with 0.
        self.op_type = "ctc_align"
        self.input_lod = []  # empty LoD selects the tensor (padding) path
        self.blank = 0
        self.padding_num = 0
        self.merge_repeated = True
        self.input = np.array(
            [[0, 2, 4, 4, 0, 6, 3, 6, 6, 0, 0],
             [1, 1, 3, 0, 0, 4, 5, 6, 0, 0, 0]]).reshape(
                 [2, 11]).astype("int32")

    def setUp(self):
        self.config()
        # Expected output comes from the pure-Python reference helper.
        output = CTCAlign(self.input, self.input_lod, self.blank,
                          self.merge_repeated, self.padding_num)
        self.inputs = {"Input": (self.input, self.input_lod)}
        self.outputs = {"Output": output}
        self.attrs = {
            "blank": self.blank,
            "merge_repeated": self.merge_repeated,
            "padding_num": self.padding_num
        }

    def test_check_output(self):
        # Removed the dead `pass` that followed this call in the original.
        self.check_output()
class TestCTCAlignOpCase3(TestCTCAlignPaddingOp):
    """Padded-tensor case: merge repeated tokens, blank = 0, pad with 0."""

    def config(self):
        self.op_type = "ctc_align"
        self.input_lod = []
        self.blank = 0
        self.padding_num = 0
        self.merge_repeated = True
        rows = [[0, 1, 2, 2, 0, 4],
                [0, 4, 5, 0, 6, 0],
                [0, 7, 7, 7, 0, 0]]
        self.input = np.array(rows).reshape([3, 6]).astype("int32")
class TestCTCAlignOpCase4(TestCTCAlignPaddingOp):
    """Padded-tensor case with ``merge_repeated = False``.

    Original docstring said this case tests the ``padding_num`` attribute,
    but ``padding_num`` is 0 here (same as Case3); what actually differs
    from Case3 is that repeated tokens are NOT merged.
    """

    def config(self):
        self.op_type = "ctc_align"
        self.input_lod = []
        self.blank = 0
        self.merge_repeated = False  # keep consecutive duplicates
        self.padding_num = 0
        self.input = np.array(
            [[0, 1, 2, 2, 0, 4],
             [0, 4, 5, 0, 6, 0],
             [0, 7, 7, 7, 0, 0]]).reshape([3, 6]).astype("int32")
class TestCTCAlignOpCase5(TestCTCAlignPaddingOp):
    """Padded-tensor case: no merging and a non-zero pad value (1)."""

    def config(self):
        self.op_type = "ctc_align"
        self.input_lod = []
        self.blank = 0
        self.padding_num = 1
        self.merge_repeated = False
        rows = [[0, 1, 2, 2, 0, 4],
                [0, 4, 5, 0, 6, 0],
                [0, 7, 1, 7, 0, 0]]
        self.input = np.array(rows).reshape([3, 6]).astype("int32")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册