提交 281e93bc 编写于 作者: W wanghaoshuang

Remove 'top 1' from the CPU and GPU kernels

1. Remove 'top 1' (or argmax) from the CPU and GPU kernels
2. Add a new test case
3. Refine doc
上级 579f6846
...@@ -29,14 +29,8 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel { ...@@ -29,14 +29,8 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
auto input_dims = ctx->GetInputDim("Input"); auto input_dims = ctx->GetInputDim("Input");
int sequence_width =
static_cast<int>(framework::product(input_dims) / input_dims[0]);
int blank = ctx->Attrs().Get<int>("blank");
PADDLE_ENFORCE((blank >= 0) && (blank < sequence_width),
"The value of Attr(blank) should be in interval [0, %d).",
sequence_width);
// TODO(wanghaoshuang): it is tricky to set the wrong dimension here. // TODO(wanghaoshuang): it is tricky to set the wrong dimension here.
ctx->SetOutputDim("Output", {input_dims[0], 1}); ctx->SetOutputDim("Output", input_dims);
} }
protected: protected:
...@@ -53,25 +47,37 @@ class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -53,25 +47,37 @@ class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker) CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input", AddInput("Input",
"(LodTensor, default: LoDTensor<float>), the unscaled " "(LodTensor, default: LoDTensor<int>), Its shape is "
"probabilities of variable-length sequences, which is a 2-D " "[Lp, 1], where Lp is the sum of all input sequences' length.");
"Tensor with LoD information. It's shape is " AddOutput("Output", "(Tensor, default: Tensor<int>), The decode result.");
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label).");
AddOutput("Output", "(Tensor, default: Tensor<int>), the decode result ");
AddAttr<int>("blank", AddAttr<int>("blank",
"(int, default: 0), the blank label setted in Connectionist " "(int, default: 0), the blank label setted in Connectionist "
"Temporal Classification (CTC) op, and it is in the " "Temporal Classification (CTC) op.")
"half-opened interval [0, num_classes + 1).")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>("merge_repeated", AddAttr<bool>("merge_repeated",
"(bool, default: true), whether to " "(bool, default: true), whether to "
"merge repeated elements between two blanks. ") "merge repeated elements between two blanks. ")
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
CTCGreedyDecoder is an implementation of the simple best path decoding CTCDecoder is used to merge repeated elements between two blanks
algorithm, selecting at each timestep the most likely class at each timestep. and then delete all blanks in sequence.
Given:
Input.data = [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6,
6, 0, 0, 7, 7, 7, 0]
Input.dims = {18, 1}
Input.LoD = [[0, 11, 18]]
And:
blank = 0
merge_repeated = True
Then:
Output.data = [1, 2, 4, 4, 5, 6,
6, 7]
Output.dims = {8, 1}
Output.LoD = [[0, 6, 8]]
)DOC"); )DOC");
} }
}; };
...@@ -85,4 +91,4 @@ REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp, ...@@ -85,4 +91,4 @@ REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
ctc_greedy_decode, ctc_greedy_decode,
ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, float>); ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, int>);
...@@ -16,62 +16,20 @@ limitations under the License. */ ...@@ -16,62 +16,20 @@ limitations under the License. */
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "paddle/operators/ctc_greedy_decode_op.h" #include "paddle/operators/ctc_greedy_decode_op.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/gpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
__device__ static float atomicMaxF(float* address, float val) {
int* address_as_i = (int*)address;
int old = *address_as_i, assumed;
do {
assumed = old;
old = ::atomicCAS(address_as_i, assumed,
__float_as_int(::fmaxf(val, __int_as_float(assumed))));
} while (assumed != old);
return __int_as_float(old);
}
template <typename T, int BlockSize>
__global__ void ArgmaxCudaKernel(const size_t seq_width, const T* logits,
int* output) {
T local_max_value = 0;
int local_max_index = 0;
__shared__ T max_value;
if (threadIdx.x == 0) {
max_value = 0;
}
__syncthreads();
for (int i = threadIdx.x; i < seq_width; i += BlockSize) {
T value = logits[blockIdx.x * seq_width + i];
if (value > local_max_value) {
local_max_value = value;
local_max_index = i;
}
}
atomicMaxF(&max_value, local_max_value);
__syncthreads();
if (local_max_value == max_value) {
output[blockIdx.x] = local_max_index;
}
}
template <typename T> template <typename T>
__global__ void MergeAndDelCudaKernel(const int64_t num_token, int* tokens, __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
const size_t num_seq, size_t* lod0, const size_t num_seq, size_t* lod0,
const int blank, const int merge_repeated, const int blank, const int merge_repeated,
size_t* out_lod0, int* output) { size_t* out_lod0, T* output) {
int ouput_idx = 0; int ouput_idx = 0;
out_lod0[0] = 0; out_lod0[0] = 0;
for (int i = 0; i < num_seq; ++i) { for (int i = 0; i < num_seq; ++i) {
int pre_token = -1; T pre_token = -1;
for (int j = lod0[i]; j < lod0[i + 1]; ++j) { for (int j = lod0[i]; j < lod0[i + 1]; ++j) {
if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) { if (tokens[j] != blank && !(merge_repeated && tokens[j] == pre_token)) {
output[ouput_idx] = tokens[j]; output[ouput_idx] = tokens[j];
...@@ -89,44 +47,39 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> { ...@@ -89,44 +47,39 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace."); "It must use CUDAPlace.");
const size_t level = 0;
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output"); auto* output = ctx.Output<LoDTensor>("Output");
auto input_lod = framework::ToAbsOffset(input->lod());
const T* tokens = input->data<T>();
const int64_t num_tokens = input->dims()[0]; const int64_t num_tokens = input->dims()[0];
const size_t seq_width = input->numel() / num_tokens;
const T* logits = input->data<T>();
Tensor tmp;
int* tokens = tmp.mutable_data<int>({num_tokens, 1}, ctx.GetPlace());
// get argmax
// platform::GpuMemsetAsync(args, 0, sizeof(float), stream);
auto stream = ctx.cuda_device_context().stream();
ArgmaxCudaKernel<T, PADDLE_CUDA_NUM_THREADS><<<
num_tokens, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(seq_width, logits,
tokens);
const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod());
const size_t num_seq = input_lod[level].size() - 1; const size_t num_seq = input_lod[level].size() - 1;
const int blank = ctx.Attr<int>("blank"); const int blank = ctx.Attr<int>("blank");
const int merge_repeated = const int merge_repeated =
static_cast<int>(ctx.Attr<bool>("merge_repeated")); static_cast<int>(ctx.Attr<bool>("merge_repeated"));
// prepare a lod to record lod information while merging elements
thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size()); thrust::device_vector<size_t> dev_out_lod0(input_lod[level].size());
size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data()); size_t* dev_out_lod0_ptr = thrust::raw_pointer_cast(dev_out_lod0.data());
int* output_data = // merge elements and delete blank
output->mutable_data<int>({num_tokens, 1}, ctx.GetPlace()); T* output_data = output->mutable_data<T>({num_tokens, 1}, ctx.GetPlace());
auto stream = ctx.cuda_device_context().stream();
MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>( MergeAndDelCudaKernel<T><<<1, 1, 0, stream>>>(
num_tokens, tokens, num_seq, input_lod[level].data(), blank, num_tokens, tokens, num_seq, input_lod[level].data(), blank,
merge_repeated, dev_out_lod0_ptr, output_data); merge_repeated, dev_out_lod0_ptr, output_data);
// set output lod
thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(), thrust::host_vector<size_t> host_out_lod0(dev_out_lod0.begin(),
dev_out_lod0.end()); dev_out_lod0.end());
framework::LoD out_lod; framework::LoD out_lod;
out_lod.push_back(host_out_lod0); out_lod.push_back(host_out_lod0);
output->set_lod(out_lod); output->set_lod(out_lod);
// resize output dims
output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1}); output->Resize({static_cast<int64_t>(host_out_lod0.back()), 1});
} }
}; };
...@@ -135,4 +88,4 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> { ...@@ -135,4 +88,4 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode, REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
paddle::operators::CTCGreedyDecodeOpCUDAKernel<float>); paddle::operators::CTCGreedyDecodeOpCUDAKernel<int>);
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <string.h> #include <string.h>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -30,8 +29,9 @@ class CTCGreedyDecodeKernel : public framework::OpKernel<T> { ...@@ -30,8 +29,9 @@ class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* output = ctx.Output<LoDTensor>("Output"); auto* output = ctx.Output<LoDTensor>("Output");
const size_t level = 0; const size_t level = 0;
auto input_lod = framework::ToAbsOffset(input->lod()); auto input_lod = framework::ToAbsOffset(input->lod());
// check input dims and lod
auto input_dims = input->dims(); auto input_dims = input->dims();
PADDLE_ENFORCE_EQ(input_dims[0], PADDLE_ENFORCE_EQ(input_dims[0],
static_cast<int64_t>(input_lod[level].back()), static_cast<int64_t>(input_lod[level].back()),
...@@ -39,38 +39,36 @@ class CTCGreedyDecodeKernel : public framework::OpKernel<T> { ...@@ -39,38 +39,36 @@ class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
"the sum of all sequences' lengths."); "the sum of all sequences' lengths.");
const size_t num_sequences = input_lod[level].size() - 1; const size_t num_sequences = input_lod[level].size() - 1;
const size_t sequence_width = input->numel() / input_dims[0];
size_t blank = static_cast<size_t>(ctx.Attr<int>("blank")); size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
bool merge_repeated = ctx.Attr<bool>("merge_repeated"); bool merge_repeated = ctx.Attr<bool>("merge_repeated");
// merge repeated tokens and delete blank
std::vector<std::vector<int>> pathes(num_sequences); std::vector<std::vector<int>> pathes(num_sequences);
std::vector<size_t> output_lod0(1, 0); std::vector<size_t> output_lod0(1, 0);
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
input_mat(const_cast<T*>(input_data), input->numel() / sequence_width,
sequence_width);
size_t max_class_idx;
size_t prev_class_idx = -1;
for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) { for (size_t seq_idx = 0; seq_idx < num_sequences; ++seq_idx) {
T prev_token = -1;
for (size_t i = input_lod[level][seq_idx]; for (size_t i = input_lod[level][seq_idx];
i < input_lod[level][seq_idx + 1]; ++i) { i < input_lod[level][seq_idx + 1]; ++i) {
input_mat.row(i).maxCoeff(&max_class_idx); if (input_data[i] != blank &&
if (max_class_idx != blank && !(merge_repeated && input_data[i] == prev_token)) {
!(merge_repeated && max_class_idx == prev_class_idx)) { pathes[seq_idx].push_back(input_data[i]);
pathes[seq_idx].push_back(max_class_idx);
} }
prev_class_idx = max_class_idx; prev_token = input_data[i];
} }
output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size()); output_lod0.push_back(output_lod0.back() + pathes[seq_idx].size());
} }
// set output lod
framework::LoD output_lod; framework::LoD output_lod;
output_lod.push_back(output_lod0); output_lod.push_back(output_lod0);
output->set_lod(output_lod); output->set_lod(output_lod);
int64_t num_step = static_cast<int64_t>(output_lod0.back());
int* output_data = output->mutable_data<int>({num_step, 1}, ctx.GetPlace());
// resize output dims
T* output_data = output->mutable_data<T>(
{static_cast<int64_t>(output_lod0.back()), 1}, ctx.GetPlace());
// copy result to output
for (int i = 0; i < num_sequences; ++i) { for (int i = 0; i < num_sequences; ++i) {
memcpy(output_data + output_lod0[i], pathes[i].data(), memcpy(output_data + output_lod0[i], pathes[i].data(),
sizeof(int) * pathes[i].size()); sizeof(int) * pathes[i].size());
......
...@@ -5,33 +5,37 @@ from op_test import OpTest ...@@ -5,33 +5,37 @@ from op_test import OpTest
from test_softmax_op import stable_softmax from test_softmax_op import stable_softmax
def CTCGreedyDecode(softmax, blank, merge_repeated): def CTCDecode(input, lod, blank, merge_repeated):
prev_token = -1 lod0 = lod[0]
result = [] result = []
for token in np.argmax(softmax, axis=1): for i in range(len(lod0) - 1):
if (token != blank) and not (merge_repeated and token == prev_token): prev_token = -1
result.append(token) for j in range(lod0[i], lod0[i + 1]):
return np.array(result).reshape([len(result), 1]) token = input[j][0]
if (token != blank) and not (merge_repeated and
token == prev_token):
class TestCTCGreedyDecodeOp(OpTest): result.append(token)
prev_token = token
result = np.array(result).reshape([len(result), 1]).astype("int32")
return result
class TestCTCDecodeOp(OpTest):
def config(self): def config(self):
self.op_type = "ctc_greedy_decode" self.op_type = "ctc_greedy_decode"
self.batch_size = 4 self.input_lod = [[0, 11, 18]]
self.num_classes = 8 self.blank = 0
self.input_lod = [[0, 4, 5, 8, 11]] self.merge_repeated = False
self.blank = 7 self.input = np.array(
self.merge_repeated = True [0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
[18, 1]).astype("int32")
def setUp(self): def setUp(self):
self.config() self.config()
input = np.random.uniform( output = CTCDecode(self.input, self.input_lod, self.blank,
0.1, 1.0, self.merge_repeated)
[self.input_lod[0][-1], self.num_classes]).astype("float32")
softmax = np.apply_along_axis(stable_softmax, 1, input)
output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated)
self.inputs = {"Input": (softmax, self.input_lod), } self.inputs = {"Input": (self.input, self.input_lod), }
self.outputs = {"Output": output} self.outputs = {"Output": output}
self.attrs = { self.attrs = {
"blank": self.blank, "blank": self.blank,
...@@ -40,16 +44,18 @@ class TestCTCGreedyDecodeOp(OpTest): ...@@ -40,16 +44,18 @@ class TestCTCGreedyDecodeOp(OpTest):
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
pass
class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp): class TestCTCDecodeOpCase1(TestCTCDecodeOp):
def config(self): def config(self):
self.op_type = "ctc_greedy_decode" self.op_type = "ctc_greedy_decode"
self.batch_size = 4 self.input_lod = [[0, 11, 18]]
self.num_classes = 1025
self.input_lod = [[0, 4, 5, 8, 11]]
self.blank = 0 self.blank = 0
self.merge_repeated = True self.merge_repeated = True
self.input = np.array(
[0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
[18, 1]).astype("int32")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册