Commit 579f6846 authored by W wanghaoshuang

Add ctc_greedy_decode_op

Parent 8d253e49
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/ctc_greedy_decode_op.h"
namespace paddle {
namespace operators {
class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input of CTCGreedyDecodeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output of CTCGreedyDecodeOp should not be null.");
auto input_dims = ctx->GetInputDim("Input");
int sequence_width =
static_cast<int>(framework::product(input_dims) / input_dims[0]);
int blank = ctx->Attrs().Get<int>("blank");
PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
"The value of Attr(blank) should be in interval [0, %d).",
sequence_width);
// TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Output", {input_dims[0], 1});
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
ctx.device_context());
}
};
class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
"Tensor with LoD information. It's shape is "
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label).");
AddOutput("Output", "(Tensor, default: Tensor<int>), the decode result ");
AddAttr<int>("blank",
"(int, default: 0), the blank label setted in Connectionist "
"Temporal Classification (CTC) op, and it is in the "
"half-opened interval [0, num_classes + 1).")
.SetDefault(0);
AddAttr<bool>("merge_repeated",
"(bool, default: true), whether to "
"merge repeated elements between two blanks. ")
.SetDefault(true);
AddComment(R"DOC(
CTCGreedyDecoder is an implementation of the simple best path decoding
algorithm, selecting the most likely class at each timestep.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
ops::CTCGreedyDecodeOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(
ctc_greedy_decode,
ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, float>);
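The DOC string above describes best path decoding: take the most likely class at every timestep, drop the blank label, and (optionally) merge consecutive repeats within each sequence. The following is a minimal NumPy sketch of that behavior for reference only; the function name, offsets, and class count are made up for illustration and are not part of the operator itself.

import numpy as np

def greedy_decode(probs, lod0, blank, merge_repeated=True):
    # probs: [Lp, num_classes + 1]; lod0: absolute sequence offsets, e.g. [0, 4, 5, 8, 11]
    tokens = np.argmax(probs, axis=1)  # most likely class at every timestep
    result, out_lod0 = [], [0]
    for i in range(len(lod0) - 1):
        prev = -1
        for t in tokens[lod0[i]:lod0[i + 1]]:
            # drop the blank label and, if requested, consecutive repeats
            if t != blank and not (merge_repeated and t == prev):
                result.append(t)
            prev = t
        out_lod0.append(len(result))
    return np.array(result).reshape([-1, 1]), out_lod0

# Example (made-up sizes): 2 sequences of lengths 4 and 3, 5 real classes plus blank label 5.
probs = np.random.uniform(0.1, 1.0, [7, 6]).astype("float32")
decoded, out_lod0 = greedy_decode(probs, [0, 4, 7], blank=5)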
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
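// CUDA has no native atomicMax for float, so emulate it with atomicCAS on the
// value's bit pattern: keep retrying until the stored value is >= val.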
__device__ static float atomicMaxF(float* address, float val) {
int* address_as_i = (int*)address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
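// One block per timestep: each thread scans a strided slice of the class
// dimension, the block-wide maximum is accumulated into shared memory via
// atomicMaxF, and the thread holding that maximum writes its class index.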
template <typename T, int BlockSize>
__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits,
int* output) {
T local_max_value = 0;
int local_max_index = 0;
__shared__ T max_value;
if (threadIdx.x == 0) {
max_value = 0;
}
__syncthreads();
for (int i = threadIdx.x; i < seq_width; i += BlockSize) {
T value = logits[blockIdx.x * seq_width + i];
if (value > local_max_value) {
local_max_value = value;
local_max_index = i;
}
}
atomicMaxF(&max_value, local_max_value);
__syncthreads();
if (local_max_value == max_value) {
output[blockIdx.x] = local_max_index;
}
}
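// Launched with a single thread: walks every sequence given by lod0, drops
// the blank label, optionally merges consecutive repeats, and records the
// resulting per-sequence offsets in out_lod0.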
template <typename T>
__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens,
const size_t num_seq, size_t* lod0,
const int blank, const int merge_repeated,
size_t* out_lod0, int* output) {
int output_idx = 0;
out_lod0[0] = 0;
for (int i = 0; i < num_seq; ++i) {
int pre_token = -1;
for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
output[output_idx] = tokens[j];
++output_idx;
}
pre_token = tokens[j];
}
out_lod0[i + 1] = output_idx;
}
}
template <typename T>
class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output");
const int64_t num_tokens = input->dims()[0];
const size_t seq_width = input->numel() / num_tokens;
const T* logits = input->data<T>();
Tensor tmp;
int* tokens = tmp.mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
// get argmax
// platform::GpuMemsetAsync(args, 0, sizeof(float), stream);
auto stream = ctx.cuda_device_context().stream();
ArgmaxCudaKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits,
tokens);
const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod());
const size_t num_seq = input_lod[level].size() - 1;
const int blank = ctx.Attr<int>("blank");
const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated"));
thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
int* output_data =
output->mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
num_tokens, tokens, num_seq, input_lod[level].data(), blank,
merge_repeated, dev_out_lod0_ptr, output_data);
thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
dev_out_lod0.end());
framework::LoD out_lod;
out_lod.push_back(host_out_lod0);
output->set_lod(out_lod);
output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
paddle::operators::CTCGreedyDecodeOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string.h>
#include "paddle/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename DeviceContext, typename T>
class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output");
const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod());
auto input_dims = input->dims();
PADDLE_ENFORCE_EQ(input_dims[0],
static_cast<int64_t>(input_lod[level].back()),
"The first dimension of Input(Input) should be equal to "
"the sum of all sequences' lengths.");
const size_t num_sequences = input_lod[level].size() - 1;
const size_t sequence_width = input->numel() / input_dims[0];
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
bool merge_repeated = ctx.Attr<bool>("merge_repeated");
std::vector<std::vector<int>> pathes(num_sequences);
std::vector<size_t> output_lod0(1, 0);
const T* input_data = input->data<T>();
Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
input_mat(const_cast<T*>(input_data), input->numel() / sequence_width,
sequence_width);
size_t max_class_idx;
size_t prev_class_idx = -1;
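// For every timestep of each sequence, take the argmax class; skip the blank
// label and, if merge_repeated is set, classes equal to the previous timestep's.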
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
for (size_t i = input_lod[level][seq_idx];
i < input_lod[level][seq_idx + 1]; ++i) {
input_mat.row(i).maxCoeff(&max_class_idx);
if (max_class_idx != blank &&
!(merge_repeated && max_class_idx == prev_class_idx)) {
pathes[seq_idx].push_back(max_class_idx);
}
prev_class_idx = max_class_idx;
}
output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size());
}
framework::LoD output_lod;
output_lod.push_back(output_lod0);
output->set_lod(output_lod);
int64_t num_step = static_cast<int64_t>(output_lod0.back());
int* output_data = output->mutable_data<int>({num_step, 1}, ctx.GetPlace());
for (int i = 0; i < num_sequences; ++i) {
memcpy(output_data + output_lod0[i], pathes[i].data(),
sizeof(int) * pathes[i].size());
}
}
};
} // namespace operators
} // namespace paddle
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
def CTCGreedyDecode(softmax, blank, merge_repeated):
    """Reference implementation: argmax per timestep, drop the blank label and
    (optionally) merge consecutive repeated tokens."""
    prev_token = -1
    result = []
    for token in np.argmax(softmax, axis=1):
        if (token != blank) and not (merge_repeated and token == prev_token):
            result.append(token)
        prev_token = token
    return np.array(result).reshape([len(result), 1])
class TestCTCGreedyDecodeOp(OpTest):
    def config(self):
        self.op_type = "ctc_greedy_decode"
        self.batch_size = 4
        self.num_classes = 8
        self.input_lod = [[0, 4, 5, 8, 11]]
        self.blank = 7
        self.merge_repeated = True

    def setUp(self):
        self.config()
        input = np.random.uniform(
            0.1, 1.0,
            [self.input_lod[0][-1], self.num_classes]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, input)
        output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated)

        self.inputs = {"Input": (softmax, self.input_lod), }
        self.outputs = {"Output": output}
        self.attrs = {
            "blank": self.blank,
            "merge_repeated": self.merge_repeated
        }

    def test_check_output(self):
        self.check_output()
class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp):
    def config(self):
        self.op_type = "ctc_greedy_decode"
        self.batch_size = 4
        self.num_classes = 1025
        self.input_lod = [[0, 4, 5, 8, 11]]
        self.blank = 0
        self.merge_repeated = True


if __name__ == "__main__":
    unittest.main()