未验证 提交 b5fda272 编写于 作者: Y Yiqun Liu 提交者: GitHub

Port WarpCTC Operator (#5107)

* Add Seq2BatchFunctor, which will be used in WarpCTCOp.

* Implement WarpCTCFunctor and WarpCTCKernel.

* Add unittest of warpctc_op.

* Modify the check_output interface in python unittest framework to allow checking a subset of outputs.

* Use absolute offset lod in warpctc_op and related functors.

* Refine the comments of warpctc_op.

* The new python unittest supports checking a subset of the outputs, so revoke the previous change.

* Rename the transform from LoDTensor to Tensor with shape [max_sequence_length, num_sequences, sequence_width] to PaddingSequenceFunctor.

* Update to the newest codes.

* Rename the PaddingSequenceFunctor to PaddingLoDTensorFunctor and remove the computation of dimensions out of the functors.
上级 fe341bac
...@@ -63,7 +63,7 @@ ExternalProject_Add( ...@@ -63,7 +63,7 @@ ExternalProject_Add(
MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}") MESSAGE(STATUS "warp-ctc library: ${WARPCTC_LIBRARIES}")
INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR}) INCLUDE_DIRECTORIES(${WARPCTC_INCLUDE_DIR})
ADD_LIBRARY(warpctc STATIC IMPORTED GLOBAL) ADD_LIBRARY(warpctc SHARED IMPORTED GLOBAL)
SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES}) SET_PROPERTY(TARGET warpctc PROPERTY IMPORTED_LOCATION ${WARPCTC_LIBRARIES})
ADD_DEPENDENCIES(warpctc extern_warpctc) ADD_DEPENDENCIES(warpctc extern_warpctc)
......
...@@ -151,6 +151,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute) ...@@ -151,6 +151,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col) op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute) op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor) op_library(recurrent_op DEPS executor)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding math_function)
op_library(cos_sim_op DEPS cos_sim_functor) op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor) op_library(parallel_do_op DEPS executor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions # FIXME(typhoonzero): save/load depends lodtensor serialization functions
......
...@@ -230,7 +230,6 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const { ...@@ -230,7 +230,6 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad, REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
namespace ops = paddle::operators;
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad, REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad); ops::ConvOpGrad);
......
...@@ -12,6 +12,7 @@ if(WITH_GPU) ...@@ -12,6 +12,7 @@ if(WITH_GPU)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
nv_library(sequence_padding SRCS sequence_padding.cc sequence_padding.cu DEPS lod_tensor device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context) nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
...@@ -27,6 +28,7 @@ else() ...@@ -27,6 +28,7 @@ else()
cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor) cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor)
cc_library(context_project SRCS context_project.cc DEPS device_context math_function) cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
cc_library(sequence_padding SRCS sequence_padding.cc DEPS lod_tensor device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context) cc_library(unpooling SRCS unpooling.cc DEPS device_context)
...@@ -38,3 +40,4 @@ cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) ...@@ -38,3 +40,4 @@ cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
...@@ -14,7 +14,6 @@ limitations under the License. */ ...@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream>
template <typename DeviceContext, typename Place> template <typename DeviceContext, typename Place>
void testIm2col() { void testIm2col() {
...@@ -102,6 +101,7 @@ void testIm2col() { ...@@ -102,6 +101,7 @@ void testIm2col() {
Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp); Copy(output_ocf, paddle::platform::CPUPlace(), *context, &output_tmp);
out_ocf_ptr = output_tmp.data<float>(); out_ocf_ptr = output_tmp.data<float>();
} }
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]); EXPECT_EQ(out_ocf_ptr[i], out_ocf_data[i]);
} }
...@@ -154,6 +154,9 @@ void testIm2col() { ...@@ -154,6 +154,9 @@ void testIm2col() {
for (int i = 0; i < 6; ++i) { for (int i = 0; i < 6; ++i) {
EXPECT_EQ(in_ptr[i], col2im_data[i]); EXPECT_EQ(in_ptr[i], col2im_data[i]);
} }
delete place;
delete context;
} }
TEST(math, im2col) { TEST(math, im2col) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class PaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
 public:
  /*
   * CPU specialization: scatters the rows of `seq` into the dense tensor
   * `padding`, laid out as [max_sequence_length, num_sequences,
   * sequence_width]. Slots that lie beyond the end of a sequence are zeroed.
   *
   * \param context        CPU device context (not read by this
   *                       implementation).
   * \param seq            source LoDTensor; level 0 of its LoD delimits the
   *                       sequences.
   * \param padding        destination tensor; its three dimensions are
   *                       validated against `seq` before any data is written.
   * \param norm_by_times  when true every copied value is multiplied by
   *                       1 / (length of the sequence it belongs to).
   */
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::LoDTensor& seq, framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    const size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be the "
                      "maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be the "
                      "number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    const T* src = seq.data<T>();
    T* dst = padding.data<T>();

    // Walk the destination in its storage order: time step first, then
    // sequence, so each destination row is visited exactly once.
    for (size_t step = 0; step < max_sequence_length; ++step) {
      for (size_t s = 0; s < num_sequences; ++s) {
        const size_t begin = abs_offset_lod[level][s];
        const size_t length = abs_offset_lod[level][s + 1] - begin;
        T* dst_row = dst + (step * num_sequences + s) * sequence_width;

        if (!(step < length)) {
          // This sequence is shorter than `step + 1`: emit a zero row.
          memset(dst_row, 0, sequence_width * sizeof(T));
          continue;
        }

        // step < length implies length > 0, so the division is well defined.
        T scale = 1.0f;
        if (norm_by_times) {
          scale = 1.0f / static_cast<T>(length);
        }

        const T* src_row = src + (begin + step) * sequence_width;
        for (size_t k = 0; k < sequence_width; ++k) {
          dst_row[k] = src_row[k] * scale;
        }
      }
    }
  }
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, T> {
 public:
  /*
   * CPU specialization: gathers rows from the dense tensor `padding`
   * ([max_sequence_length, num_sequences, sequence_width]) back into the
   * sequence-major LoDTensor `seq`. Padded slots of `padding` are ignored.
   *
   * \param context        CPU device context (not read by this
   *                       implementation).
   * \param seq            destination LoDTensor; its LoD (level 0) and
   *                       dimensions must already be set, they define where
   *                       every row is written.
   * \param padding        source tensor; its three dimensions are validated
   *                       against `seq` before any data is written.
   * \param norm_by_times  when true every copied value is multiplied by
   *                       1 / (length of the sequence it belongs to).
   */
  void operator()(const platform::CPUDeviceContext& context,
                  framework::LoDTensor& seq, const framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The LoD of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequnece_length, num_sequences, sequence_width].");

    const size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be "
                      "the maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be "
                      "the number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    const T* padding_data = padding.data<T>();
    T* seq_data = seq.data<T>();

    for (size_t i = 0; i < num_sequences; ++i) {
      const size_t start_pos = abs_offset_lod[level][i];
      const size_t sequence_length = abs_offset_lod[level][i + 1] - start_pos;
      if (sequence_length == 0) {
        // Nothing to copy, and skipping keeps the scale below from ever
        // dividing by zero.
        continue;
      }

      // The scale depends only on the sequence, so compute it once per
      // sequence instead of once per time step.
      const T scale =
          norm_by_times ? (1.0f / static_cast<T>(sequence_length)) : 1.0f;

      for (size_t j = 0; j < sequence_length; ++j) {
        // 0 <= j < sequence_length: row j of sequence i lives at time step j,
        // column i of the padded tensor.
        const T* src_row =
            padding_data + (j * num_sequences + i) * sequence_width;
        T* dst_row = seq_data + (start_pos + j) * sequence_width;
        for (size_t k = 0; k < sequence_width; ++k) {
          dst_row[k] = src_row[k] * scale;
        }
      }
    }
  }
};
// Explicit instantiations: the CPU functors above are defined in this source
// file, so the float versions are instantiated here for other translation
// units to link against. float is the only element type instantiated.
template class PaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CPUDeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
namespace paddle {
namespace operators {
namespace math {
/*
 * Copies rows between the sequence-major buffer `sequence` and the dense
 * buffer `padding` ([max_sequence_length, num_sequences, sequence_width]).
 *
 * Work decomposition as read from the indexing below:
 *   - blockIdx.y selects the sequence,
 *   - blockIdx.x * blockDim.y + threadIdx.y selects the time step,
 *   - threadIdx.x (stride blockDim.x) walks the elements of one row.
 *
 * \tparam NormByTimes  when true, copied values are multiplied by
 *                      1 / (length of the sequence).
 * \tparam Padding      true:  sequence -> padding, and rows past the end of
 *                             a sequence (but below max_sequence_length) are
 *                             zero-filled;
 *                      false: padding -> sequence, rows past the end of a
 *                             sequence are left untouched.
 * \param sequence_start_positions  offsets delimiting the sequences; entries
 *                                  [blockIdx.y] and [blockIdx.y + 1] are read.
 */
template <typename T, bool NormByTimes, bool Padding>
__global__ void SequencePaddingKernel(T* padding, T* sequence,
                                      const size_t* sequence_start_positions,
                                      const size_t sequence_width,
                                      const size_t max_sequence_length,
                                      const size_t num_sequences) {
  const size_t seq_id = blockIdx.y;
  const size_t first_row = sequence_start_positions[seq_id];
  const size_t length = sequence_start_positions[seq_id + 1] - first_row;

  const size_t step = blockIdx.x * blockDim.y + threadIdx.y;
  const size_t padded_offset = (step * num_sequences + seq_id) * sequence_width;
  const size_t packed_offset = (first_row + step) * sequence_width;

  if (step < length) {
    // step < length implies length > 0, so the division is well defined.
    const T scale = NormByTimes ? (1.0f / static_cast<T>(length)) : 1.0f;
    for (size_t col = threadIdx.x; col < sequence_width; col += blockDim.x) {
      if (Padding) {
        /* sequence -> padding */
        padding[padded_offset + col] = scale * sequence[packed_offset + col];
      } else {
        /* padding -> sequence */
        sequence[packed_offset + col] = scale * padding[padded_offset + col];
      }
    }
    return;
  }

  // Past the end of this sequence: only the padding direction has something
  // to do, namely clearing the unused slot.
  if (Padding && step < max_sequence_length) {
    for (size_t col = threadIdx.x; col < sequence_width; col += blockDim.x) {
      padding[padded_offset + col] = 0;
    }
  }
}
template <typename T>
class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * CUDA specialization: validates the shapes of `seq` and `padding`, then
   * launches SequencePaddingKernel (Padding = true) on context.stream() to
   * fill `padding` ([max_sequence_length, num_sequences, sequence_width])
   * from `seq`.
   *
   * \param context        CUDA device context supplying the place and stream.
   * \param seq            source LoDTensor; level 0 of its LoD delimits the
   *                       sequences.
   * \param padding        destination tensor.
   * \param norm_by_times  when true every copied value is multiplied by
   *                       1 / (length of the sequence it belongs to).
   */
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::LoDTensor& seq, framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The lod of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequence_length, num_sequences, sequence_width].");

    size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be the "
                      "maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be the "
                      "number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    // With a single sequence and no scaling the two layouts hold the same
    // elements in the same order, so a whole-tensor copy is enough. The
    // Resize afterwards restores the 3-D dims on `padding`.
    const bool plain_copy = !norm_by_times && num_sequences == 1UL;
    if (plain_copy) {
      Copy(seq, context.GetPlace(), context, &padding);
      padding.Resize(padding_dims);
      return;
    }

    // Thread-block shape: x covers one row with roughly 8 elements per
    // thread, rounded up to a multiple of 32 and capped at kBlockSize;
    // y takes whatever is left of the kBlockSize budget.
    const size_t kBlockSize = 512;
    const size_t lanes_wanted = (sequence_width + 7) >> 3;
    const size_t lanes_rounded = ((lanes_wanted + 31) >> 5) << 5;
    const size_t block_dim_x = std::min(lanes_rounded, kBlockSize);
    const size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // Grid shape: x tiles the time steps (ceil division), y is one entry per
    // sequence.
    const size_t grid_dim_x =
        (max_sequence_length + block_dim_y - 1) / block_dim_y;
    const size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* seq_data = seq.data<T>();
    T* padding_data = padding.data<T>();

    // The kernel signature takes a mutable sequence pointer because it is
    // shared with the unpadding direction; with Padding = true it only reads
    // from it.
    T* seq_data_mutable = const_cast<T*>(seq_data);

    if (!norm_by_times) {
      SequencePaddingKernel<T, 0, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, seq_data_mutable, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 1, 1><<<grid, threads, 0, context.stream()>>>(
          padding_data, seq_data_mutable, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    }
  }
};
template <typename T>
class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
 public:
  /*
   * CUDA specialization: validates the shapes of `seq` and `padding`, then
   * launches SequencePaddingKernel (Padding = false) on context.stream() to
   * fill `seq` from `padding` ([max_sequence_length, num_sequences,
   * sequence_width]).
   *
   * \param context        CUDA device context supplying the place and stream.
   * \param seq            destination LoDTensor; its LoD (level 0) and
   *                       dimensions must already be set.
   * \param padding        source tensor.
   * \param norm_by_times  when true every copied value is multiplied by
   *                       1 / (length of the sequence it belongs to).
   */
  void operator()(const platform::CUDADeviceContext& context,
                  framework::LoDTensor& seq, const framework::Tensor& padding,
                  bool norm_by_times) {
    auto lod = seq.lod();
    PADDLE_ENFORCE_GT(lod.size(), 0UL,
                      "The lod of LoDTensor seq should not be null.");

    const size_t level = 0;
    framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);

    auto seq_dims = seq.dims();
    PADDLE_ENFORCE_EQ(seq_dims[0], abs_offset_lod[level].back(),
                      "The first dimension of LoDTensor seq should be "
                      "equal to the sum of all sequences's length.");

    auto padding_dims = padding.dims();
    PADDLE_ENFORCE_EQ(padding_dims.size(), 3UL,
                      "The input padding should be a 3-D Tensor of shape "
                      "[max_sequnece_length, num_sequences, sequence_width].");

    size_t max_sequence_length = MaximumSequenceLength(lod, level);
    PADDLE_ENFORCE_EQ(padding_dims[0], max_sequence_length,
                      "The first dimension of Tensor padding should be "
                      "the maximum length of all sequences in LoDTensor seq.");

    const size_t num_sequences = abs_offset_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(padding_dims[1], num_sequences,
                      "The second dimension of Tensor padding should be "
                      "the number of sequences in LoDTensor seq.");

    const size_t sequence_width = seq.numel() / seq_dims[0];
    PADDLE_ENFORCE_EQ(padding_dims[2], sequence_width,
                      "The third dimension of Tensor padding should be the "
                      "width of sequence in LoDTensor seq.");

    // With a single sequence and no scaling the two layouts hold the same
    // elements in the same order, so a whole-tensor copy is enough. The
    // Resize afterwards restores the 2-D dims on `seq`.
    const bool plain_copy = !norm_by_times && num_sequences == 1UL;
    if (plain_copy) {
      Copy(padding, context.GetPlace(), context, &seq);
      seq.Resize(seq_dims);
      return;
    }

    // Thread-block shape: x covers one row with roughly 8 elements per
    // thread, rounded up to a multiple of 32 and capped at kBlockSize;
    // y takes whatever is left of the kBlockSize budget.
    const size_t kBlockSize = 512;
    const size_t lanes_wanted = (sequence_width + 7) >> 3;
    const size_t lanes_rounded = ((lanes_wanted + 31) >> 5) << 5;
    const size_t block_dim_x = std::min(lanes_rounded, kBlockSize);
    const size_t block_dim_y = kBlockSize / block_dim_x;
    dim3 threads(block_dim_x, block_dim_y);

    // Grid shape: x tiles the time steps (ceil division), y is one entry per
    // sequence.
    const size_t grid_dim_x =
        (max_sequence_length + block_dim_y - 1) / block_dim_y;
    const size_t grid_dim_y = num_sequences;
    dim3 grid(grid_dim_x, grid_dim_y);

    const T* padding_data = padding.data<T>();
    T* seq_data = seq.data<T>();

    // The kernel signature takes a mutable padding pointer because it is
    // shared with the padding direction; with Padding = false it only reads
    // from it.
    T* padding_data_mutable = const_cast<T*>(padding_data);

    if (!norm_by_times) {
      SequencePaddingKernel<T, 0, 0><<<grid, threads, 0, context.stream()>>>(
          padding_data_mutable, seq_data, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    } else {
      SequencePaddingKernel<T, 1, 0><<<grid, threads, 0, context.stream()>>>(
          padding_data_mutable, seq_data, abs_offset_lod[level].data(),
          sequence_width, max_sequence_length, num_sequences);
    }
  }
};
// Explicit instantiations: the CUDA functors above are defined in this source
// file, so the float versions are instantiated here for other translation
// units to link against. float is the only element type instantiated.
template class PaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, float>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/*
 * \brief Length of the longest sequence described by one level of a LoD.
 *
 * The LoD is first converted with framework::ToAbsOffset; the length of
 * sequence i is then offsets[i + 1] - offsets[i] at the requested level.
 *
 * \param lod    level-of-detail information; `level` must be a valid index
 *               into it (callers are expected to have checked this).
 * \param level  which LoD level to inspect.
 * \return the maximum sequence length, or 0 when the level describes no
 *         sequence (fewer than two offsets).
 */
inline static size_t MaximumSequenceLength(const framework::LoD& lod,
                                           const size_t level) {
  // Guard the subtraction: on an empty offset vector `size() - 1` would wrap
  // around to SIZE_MAX and send the loop below far out of bounds.
  const size_t num_offsets = lod[level].size();
  const size_t num_sequences = num_offsets == 0 ? 0 : num_offsets - 1;

  size_t max_sequence_length = 0;
  framework::LoD abs_offset_lod = framework::ToAbsOffset(lod);
  for (size_t i = 0; i < num_sequences; ++i) {
    max_sequence_length =
        std::max(max_sequence_length,
                 abs_offset_lod[level][i + 1] - abs_offset_lod[level][i]);
  }
  return max_sequence_length;
}
/*
* \brief Padding/Unpadding LoDTensor to/from normal Tensor of the shape
* [max_sequence_length, num_sequences, sequence_width].
*
* Padding sequence:
* padding[i] = seq[lod[level][i]]
* Unpadding sequence:
* seq[lod[level][i]] = padding[i]
*
* All sequences will be padded to the same length and stored in a transposed
* shape.
* Example:
* seq (s0, s0, s0, s0; s1, s1; s2, s2, s2; s3)
* padding (s0, s1, s2, s3; s0, s1, s2, 0; s0, 0, s2, 0; s0, 0, 0, 0)
*
* \param context device context of this functor.
* \param seq LoDTensor which is stored in sequence format, the shape
* is [total_sequence_length, sequence_width] where
* total_sequence_length is the sum of all sequences'
* length.
* \param padding Tensor which is padded to the same length, the shape is
* [max_sequence_length, num_sequences, sequence_width].
* \param norm_by_times whether to scale each copied value by 1 / (length of the sequence it belongs to).
*
* \note transposition is also done in this functor.
*/
// Primary template for the sequence -> padded-tensor conversion. Only the call
// operator is declared here; per-device specializations provide the
// definition.
template <typename DeviceContext, typename T>
class PaddingLoDTensorFunctor {
public:
// context: device context to run on; seq: source LoDTensor; padding:
// destination Tensor; norm_by_times: request scaling by the sequence length.
void operator()(const DeviceContext& context, const framework::LoDTensor& seq,
framework::Tensor& padding, bool norm_by_times);
};
// Primary template for the padded-tensor -> sequence conversion, the inverse
// direction of PaddingLoDTensorFunctor. Only the call operator is declared
// here; per-device specializations provide the definition.
template <typename DeviceContext, typename T>
class UnpaddingLoDTensorFunctor {
public:
// context: device context to run on; seq: destination LoDTensor (written);
// padding: source Tensor (read-only); norm_by_times: request scaling by the
// sequence length.
void operator()(const DeviceContext& context, framework::LoDTensor& seq,
const framework::Tensor& padding, bool norm_by_times);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_padding.h"
#include <gtest/gtest.h>
// Round-trip check for the padding functors on one device.
//
// Builds a LoDTensor whose elements are 0, 1, 2, ... on the CPU, moves it to
// `Place`, pads it with PaddingLoDTensorFunctor, unpads the result with
// UnpaddingLoDTensorFunctor (both with norm_by_times = false) and expects the
// tensor that comes back to match the original element by element.
//
// DeviceContext / Place: device under test; T: element type.
// lod:            sequence layout; its last level is the one exercised.
// sequence_width: number of elements per time step.
template <typename DeviceContext, typename Place, typename T>
void TestSequencePadding(const paddle::framework::LoD& lod,
                         const size_t sequence_width) {
  paddle::framework::LoDTensor cpu_seq;
  paddle::framework::LoDTensor cpu_seq_back;
  paddle::framework::LoDTensor seq;
  paddle::framework::LoDTensor seq_back;
  paddle::framework::Tensor padding;

  const size_t level = lod.size() - 1;
  auto seq_dims =
      paddle::framework::make_ddim({static_cast<int64_t>(lod[level].back()),
                                    static_cast<int64_t>(sequence_width)});

  cpu_seq.set_lod(lod);
  cpu_seq.mutable_data<T>(seq_dims, paddle::platform::CPUPlace());
  // numel() is signed; iterate with a signed index to avoid a mixed-sign
  // comparison.
  const int64_t numel = cpu_seq.numel();
  for (int64_t i = 0; i < numel; ++i) {
    cpu_seq.data<T>()[i] = static_cast<T>(i);
  }

  // Automatic storage instead of new/delete: both objects are released even
  // when one of the functors below throws.
  Place place{};
  DeviceContext context(place);

  if (paddle::platform::is_cpu_place(place)) {
    seq = cpu_seq;
  } else {
    Copy(cpu_seq, place, context, &seq);
    seq.set_lod(lod);
  }

  const size_t max_sequence_length =
      paddle::operators::math::MaximumSequenceLength(lod, level);
  const size_t num_sequences = lod[level].size() - 1;
  auto padding_dims =
      paddle::framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                                    static_cast<int64_t>(num_sequences),
                                    static_cast<int64_t>(sequence_width)});
  padding.mutable_data<T>(padding_dims, place);
  paddle::operators::math::PaddingLoDTensorFunctor<DeviceContext, T>()(
      context, seq, padding, false);

  seq_back.set_lod(lod);
  seq_back.mutable_data<T>(seq_dims, place);
  paddle::operators::math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
      context, seq_back, padding, false);

  if (paddle::platform::is_cpu_place(place)) {
    cpu_seq_back = seq_back;
  } else {
    Copy(seq_back, paddle::platform::CPUPlace(), context, &cpu_seq_back);
    cpu_seq_back.set_lod(lod);
  }

  EXPECT_EQ(cpu_seq.numel(), cpu_seq_back.numel());
  EXPECT_EQ(cpu_seq.dims(), cpu_seq_back.dims());
  for (int64_t i = 0; i < numel; ++i) {
    EXPECT_EQ(cpu_seq.data<T>()[i], cpu_seq_back.data<T>()[i]);
  }
}
// CPU round trip of pad -> unpad for two layouts.
// NOTE(review): the suite name here is Seq2BatchPadding while the CUDA test in
// this file uses SequencePadding — confirm whether the difference is intended.
TEST(Seq2BatchPadding, CPU) {
  using paddle::platform::CPUDeviceContext;
  using paddle::platform::CPUPlace;

  // A single sequence of 10 time steps, 16 elements per step.
  paddle::framework::LoD single_sequence;
  single_sequence.push_back(std::vector<size_t>{0, 10});
  TestSequencePadding<CPUDeviceContext, CPUPlace, float>(single_sequence, 16);

  // Three sequences of lengths 2, 5 and 3, 128 elements per step.
  paddle::framework::LoD three_sequences;
  three_sequences.push_back(std::vector<size_t>{0, 2, 7, 10});
  TestSequencePadding<CPUDeviceContext, CPUPlace, float>(three_sequences, 128);
}
#ifdef PADDLE_WITH_CUDA
// CUDA round trip of pad -> unpad for the same two layouts as the CPU test.
TEST(SequencePadding, CUDA) {
  using paddle::platform::CUDADeviceContext;
  using paddle::platform::CUDAPlace;

  // A single sequence of 10 time steps, 16 elements per step.
  paddle::framework::LoD single_sequence;
  single_sequence.push_back(std::vector<size_t>{0, 10});
  TestSequencePadding<CUDADeviceContext, CUDAPlace, float>(single_sequence,
                                                           16);

  // Three sequences of lengths 2, 5 and 3, 128 elements per step.
  paddle::framework::LoD three_sequences;
  three_sequences.push_back(std::vector<size_t>{0, 2, 7, 10});
  TestSequencePadding<CUDADeviceContext, CUDAPlace, float>(three_sequences,
                                                           128);
}
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/warpctc_op.h"
namespace paddle {
namespace operators {
// Forward operator definition for warpctc: validates the presence of its
// inputs/outputs and the `blank` attribute, and declares the shape of Loss.
class WarpCTCOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Checks that Logits, Label, WarpCTCGrad and Loss are all bound, that
  // Attr(blank) lies in [0, sequence_width), and sets the dims of Loss.
  // sequence_width is the number of elements per row of Logits.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Label"),
                   "Input(Label) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("WarpCTCGrad"),
                   "Output(WarpCTCGrad) of WarpCTCOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Loss"),
                   "Output(Loss) of WarpCTCOp should not be null.");

    auto logits_dims = ctx->GetInputDim("Logits");
    const auto total_elements = framework::product(logits_dims);
    int sequence_width = static_cast<int>(total_elements / logits_dims[0]);

    int blank = ctx->Attrs().Get<int>("blank");
    const bool blank_in_range = (blank >= 0) && (blank < sequence_width);
    PADDLE_ENFORCE(blank_in_range,
                   "The value of Attr(blank) should be in interval [0, %d).",
                   sequence_width);

    // TODO(liuyiqun): it is tricky to set the wrong dimension here.
    // (The first dim used below is the row count of Logits, whereas the op
    // description documents Loss as [batch_size, 1].)
    ctx->SetOutputDim("Loss", {logits_dims[0], 1});
  }

 protected:
  // The kernel is selected from the element type of Logits and the device
  // context the op is executed with.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto* logits = ctx.Input<Tensor>("Logits");
    return framework::OpKernelType(framework::ToDataType(logits->type()),
                                   ctx.device_context());
  }
};
// Declares the interface of the warpctc operator: two inputs (Logits, Label),
// two outputs (WarpCTCGrad, marked intermediate, and Loss), two attributes
// (blank, norm_by_times) and the operator's documentation comment.
// NOTE(review): the description strings below contain a few typos (e.g.
// "th sum", "interated", "to to normlize"); they are runtime text and are
// left untouched here.
class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
public:
WarpCTCOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
// Inputs.
AddInput("Logits",
"(LodTensor, default: LoDTensor<float>), the unscaled "
"probabilities of variable-length sequences, which is a 2-D "
"Tensor with LoD information. It's shape is "
"[Lp, num_classes + 1], where Lp is the sum of all input "
"sequences' length and num_classes is the true number of classes "
"(not including the blank label).");
AddInput("Label",
"(LodTensor, default: LoDTensor<int>), the ground truth "
"of variable-length sequence, which is a 2-D Tensor with LoD "
"information. It is of the shape [Lg, 1], where Lg is th sum of "
"all labels' length.");
// Outputs. WarpCTCGrad is flagged with AsIntermediate().
AddOutput("WarpCTCGrad",
"(Tensor, default: Tensor<float>), a temporary "
"output Tensor to store the gradients of warp-ctc, which is "
"computed with loss together in one call. It is a 3-D Tensor of "
"the shape [max_sequence_length, batch_size, num_classes + 1].")
.AsIntermediate();
AddOutput("Loss",
"(Tensor, default: Tensor<float>), the Connectionist "
"Temporal Classification (CTC) loss, which is a 2-D Tensor of "
"the shape [batch_size, 1]");
// Attributes, both with explicit defaults.
AddAttr<int>("blank",
"(int, default: 0), the blank label of Connectionist "
"Temporal Classification (CTC) loss, which is in the "
"half-opened interval [0, num_classes + 1).")
.SetDefault(0);
AddAttr<bool>("norm_by_times",
"(bool, default: false), whether to "
"normalize the gradients by the number of time-step, "
"which is also the sequence's length.")
.SetDefault(false);
// Free-form operator documentation.
AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in
[Deep Speech 2: End-toEnd Speech Recognition in English and Mandarin](
https://arxiv.org/pdf/1512.02595v1.pdf),
to compute Connectionist Temporal Classification (CTC) loss.
It can be aliased as softmax with ctc, since a native softmax activation is
interated to the warp-ctc library, to to normlize values for each row of the
input tensor.
More detail of CTC loss can be found by refering to
[Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with
Recurrent Neural Networks](
http://machinelearning.wustl.edu/mlpapers/paper_files/icml2006_GravesFGS06.pdf).
)DOC");
}
};
// Backward operator definition for warpctc: validates its variables and gives
// Logits@GRAD the same dims and LoD as Logits.
class WarpCTCGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires Input(WarpCTCGrad), Input(Logits) and Output(Logits@GRAD) to be
  // bound, then copies the dims and LoD of Logits onto Logits@GRAD.
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("WarpCTCGrad"),
                   "Input(WarpCTCGrad) of WarpCTCGradOp should not be null.");
    // Logits is read below (dims and LoD) and in GetExpectedKernelType, so
    // fail with a clear message here rather than on the later lookup.
    PADDLE_ENFORCE(ctx->HasInput("Logits"),
                   "Input(Logits) of WarpCTCGradOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Logits")),
                   "Output(Logits@GRAD) of WarpCTCGradOp should not be null.");

    ctx->SetOutputDim(framework::GradVarName("Logits"),
                      ctx->GetInputDim("Logits"));
    ctx->ShareLoD("Logits", /*->*/ framework::GradVarName("Logits"));
  }

 protected:
  // The kernel is selected from the element type of Logits and the device
  // context the op is executed with.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
        ctx.device_context());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// Registers the forward op `warpctc` (WarpCTCOp + WarpCTCOpMaker) together
// with its gradient op `warpctc_grad` (WarpCTCGradOp).
REGISTER_OP(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, warpctc_grad,
ops::WarpCTCGradOp);
// CPU kernels; float is the only element type registered.
REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/warpctc_op.h"
namespace ops = paddle::operators;
// CUDA kernels for warpctc and warpctc_grad; float is the only element type
// registered.
REGISTER_OP_CUDA_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL(
warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_padding.h"
#include "paddle/platform/dynload/warpctc.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
/*
 * Thin wrapper around the dynamically loaded warp-ctc library. One call
 * configures the warp-ctc options for the current place, allocates the
 * scratch workspace warp-ctc asks for, and runs compute_ctc_loss.
 */
template <typename DeviceContext>
class WarpCTCFunctor {
 public:
  /*
   * \brief Compute the connectionist temporal classification loss,
   *        and optionally compute the gradient with respect to the inputs.
   *
   * If gradient is nullptr, only the ctc loss is computed; otherwise both
   * the ctc loss and the gradient are computed.
   *
   * \param ctx               execution context of this functor
   * \param input             batch matrix of input probabilities, in
   *                          max_sequence_length x num_sequences x
   *                          sequence_width, (row-major) format
   * \param gradient          batch matrix of gradient, with the same shape as
   *                          input.
   * \param cpu_labels        labels always in CPU memory.
   * \param cpu_label_lengths length of all labels in CPU memory.
   * \param cpu_input_lengths length of all sequences in CPU memory.
   * \param sequence_width    number of possible output symbols.
   * \param num_sequences     number of sequence.
   * \param blank             blank label used in ctc loss function.
   * \param cpu_loss          cost of each sequence in CPU memory.
   */
  void operator()(const framework::ExecutionContext& ctx, const float* input,
                  float* gradient, const int* cpu_labels,
                  const int* cpu_label_lengths, const int* cpu_input_lengths,
                  const size_t sequence_width, const size_t num_sequences,
                  const size_t blank, float* cpu_loss) {
    // Init warp-ctc options
    init(ctx, blank);

    // Compute the required workspace size.
    // There is no memory allocated operations within warp-ctc.
    size_t workspace_bytes = 0;
    ctcStatus_t status = platform::dynload::get_workspace_size(
        cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width),
        static_cast<int>(num_sequences), options_, &workspace_bytes);
    // The format string needs one conversion per argument: %d for the
    // version and %s for the status text, otherwise the actual warp-ctc
    // error description never reaches the message.
    PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
                      "warp-ctc [version %d] Error in get_workspace_size: %s",
                      warpctc_version_,
                      platform::dynload::ctcGetStatusString(status));
    PADDLE_ENFORCE_GT(workspace_bytes, 0UL,
                      "Bytes of workspace got by warp-ctc function, "
                      "get_workspace_size(), should be larger than 0.");

    // The workspace is held in a float tensor; "+ 1" rounds the element
    // count up so the buffer is never smaller than workspace_bytes.
    Tensor workspace;
    size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL;
    float* workspace_data = workspace.mutable_data<float>(
        framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
        ctx.GetPlace());
    math::SetConstant<DeviceContext, float>()(
        ctx.template device_context<DeviceContext>(), &workspace,
        static_cast<float>(0));

    // compute loss and gradient
    status = platform::dynload::compute_ctc_loss(
        input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
        static_cast<int>(sequence_width), static_cast<int>(num_sequences),
        cpu_loss, workspace_data, options_);
    PADDLE_ENFORCE_EQ(CTC_STATUS_SUCCESS, status,
                      "warp-ctc [version %d] Error in compute_ctc_loss: %s",
                      warpctc_version_,
                      platform::dynload::ctcGetStatusString(status));
  }

 protected:
  /*
   * \brief Query the warp-ctc version and fill in the warp-ctc options
   *        (compute location, stream or thread count, blank label) for the
   *        place the op is running on.
   */
  void init(const framework::ExecutionContext& ctx, const size_t blank) {
    warpctc_version_ = platform::dynload::get_warpctc_version();

    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
      options_.loc = CTC_GPU;
      options_.stream = reinterpret_cast<const platform::CUDADeviceContext&>(
                            ctx.device_context())
                            .stream();
#else
      PADDLE_THROW("[warpctc init] GPU is not enabled.");
#endif
    } else {
      options_.loc = CTC_CPU;
      options_.num_threads = 1;
    }

    options_.blank_label = blank;
  }

 private:
  int warpctc_version_;
  ctcOptions options_;
};
/*
 * Forward kernel of the warpctc op.
 *
 * Reads the LoD inputs "Logits" and "Label", converts the logits into the
 * padded [max_sequence_length, num_sequences, sequence_width] layout that
 * warp-ctc consumes, and calls WarpCTCFunctor, which produces the loss and
 * the gradient in one pass. The per-sequence loss is written to "Loss" and
 * the padded gradient is kept in "WarpCTCGrad" for the backward kernel.
 */
template <typename DeviceContext, typename T>
class WarpCTCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* logits = ctx.Input<LoDTensor>("Logits");
    auto* label = ctx.Input<LoDTensor>("Label");
    auto* warpctc_grad = ctx.Output<Tensor>("WarpCTCGrad");
    auto* loss = ctx.Output<Tensor>("Loss");

    const size_t level = 0;

    auto logits_lod = framework::ToAbsOffset(logits->lod());
    auto logits_dims = logits->dims();
    PADDLE_ENFORCE_EQ(logits_dims[0],
                      static_cast<int64_t>(logits_lod[level].back()),
                      "The first dimension of Input(Logits) should be equal to "
                      "the sum of all sequences' lengths.");

    auto label_lod = framework::ToAbsOffset(label->lod());
    auto label_dims = label->dims();
    PADDLE_ENFORCE_EQ(
        label_dims[0], label->numel(),
        "The width of each timestep in Input(Label) should be 1.");

    const size_t num_sequences = logits_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
                      "The number of sequences of Input(Logits) should be "
                      "equal to that of Input(Label).");

    const size_t sequence_width = logits->numel() / logits_dims[0];
    auto loss_dims =
        framework::make_ddim({static_cast<int64_t>(num_sequences), 1});

    // warpctc needs sequences data stored in transposed padding format
    Tensor warpctc_logits;
    const size_t max_sequence_length =
        math::MaximumSequenceLength(logits_lod, level);
    auto warpctc_logits_dims =
        framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                              static_cast<int64_t>(num_sequences),
                              static_cast<int64_t>(sequence_width)});
    warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());
    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *logits, warpctc_logits,
        false);
    const T* warpctc_logits_data = warpctc_logits.data<T>();

    // warp-ctc takes lengths as int arrays; the LoD offsets are unsigned, so
    // the narrowing is made explicit.
    std::vector<int> warpctc_label_lengths(num_sequences);
    std::vector<int> warpctc_logits_lengths(num_sequences);
    for (size_t i = 0; i < num_sequences; ++i) {
      warpctc_label_lengths[i] =
          static_cast<int>(label_lod[level][i + 1] - label_lod[level][i]);
      warpctc_logits_lengths[i] =
          static_cast<int>(logits_lod[level][i + 1] - logits_lod[level][i]);
    }

    // warpctc computes loss and gradient in one call, gradient data also stored
    // in batch format
    T* warpctc_grad_data =
        warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());
    // The freshly allocated buffer is zero-filled so that every element of
    // the WarpCTCGrad output has a defined value, including positions
    // warp-ctc may not write (presumably the padded time steps beyond a
    // sequence's own length -- confirm against warp-ctc).
    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), warpctc_grad,
        static_cast<T>(0));

    // warpctc accesses labels in CPU memory
    Tensor warpctc_label;
    Copy(*label, platform::CPUPlace(), ctx.device_context(), &warpctc_label);
    const int* warpctc_label_data = warpctc_label.data<int>();

    // warpctc stores loss in CPU memory
    Tensor warpctc_loss;
    T* warpctc_loss_data =
        warpctc_loss.mutable_data<T>(loss_dims, platform::CPUPlace());

    const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));

    WarpCTCFunctor<DeviceContext>()(
        ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data,
        warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
        sequence_width, num_sequences, blank, warpctc_loss_data);

    // Copy the loss back
    Copy(warpctc_loss, ctx.GetPlace(), ctx.device_context(), loss);
  }
};
/*
 * Backward kernel of the warpctc op.
 *
 * The gradient is taken from the "WarpCTCGrad" input in its padded layout;
 * this kernel only converts it back into the LoD layout of Logits@GRAD,
 * passing the "norm_by_times" attribute on to the unpadding functor.
 */
template <typename DeviceContext, typename T>
class WarpCTCGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* padded_grad = ctx.Input<Tensor>("WarpCTCGrad");
    auto* d_logits = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
    const bool normalize = ctx.Attr<bool>("norm_by_times");

    // NOTE(review): the output is not explicitly allocated here and the
    // upstream gradient of Loss is not read -- confirm that the unpadding
    // functor / callers account for both.
    const auto& dev_ctx = ctx.template device_context<DeviceContext>();
    math::UnpaddingLoDTensorFunctor<DeviceContext, T> unpad;
    unpad(dev_ctx, *d_logits, *padded_grad, normalize);
  }
};
} // namespace operators
} // namespace paddle
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce)
nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc nv_library(dynload_cuda SRCS cublas.cc cudnn.cc curand.cc nccl.cc
DEPS dynamic_loader nccl) DEPS dynamic_loader nccl)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <paddle/platform/dynload/cublas.h> #include "paddle/platform/dynload/cublas.h"
namespace paddle { namespace paddle {
namespace platform { namespace platform {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/dynload/warpctc.h"
// Definitions backing the declarations in paddle/platform/dynload/warpctc.h.
namespace paddle {
namespace platform {
namespace dynload {
// Guards the one-time loading of the warp-ctc shared library.
std::once_flag warpctc_dso_flag;
// Handle of the loaded warp-ctc library; stays nullptr until it is loaded.
void* warpctc_dso_handle = nullptr;
// Defines, for every routine listed in WARPCTC_ROUTINE_EACH, the wrapper
// object `DynLoad__<name> <name>` that the header declares as extern.
#define DEFINE_WRAP(__name) DynLoad__##__name __name
WARPCTC_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <dlfcn.h>
#include <mutex>
#include "ctc.h"
#include "paddle/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {

// Set once the warp-ctc shared library has been loaded.
extern std::once_flag warpctc_dso_flag;
// Handle to the loaded warp-ctc shared library.
extern void* warpctc_dso_handle;

/**
 * The following macro definition can generate structs
 * (for each function) to dynamic load warpctc routine
 * via operator overloading.
 *
 * For a routine `foo`, the generated `DynLoad__foo::operator()`:
 *   1. loads the warp-ctc library exactly once (std::call_once),
 *   2. looks the symbol "foo" up with dlsym into a local `p_foo`,
 *   3. forwards the arguments to the resolved function pointer.
 *
 * NOTE(review): the dlsym result is not checked for nullptr before it is
 * called -- a missing symbol would crash at the call site.
 */
#define DYNAMIC_LOAD_WARPCTC_WRAP(__name)                            \
  struct DynLoad__##__name {                                         \
    template <typename... Args>                                      \
    auto operator()(Args... args) -> decltype(__name(args...)) {     \
      using warpctcFunc = decltype(__name(args...)) (*)(Args...);    \
      std::call_once(warpctc_dso_flag,                               \
                     paddle::platform::dynload::GetWarpCTCDsoHandle, \
                     &warpctc_dso_handle);                           \
      void* p_##__name = dlsym(warpctc_dso_handle, #__name);         \
      return reinterpret_cast<warpctcFunc>(p_##__name)(args...);     \
    }                                                                \
  };                                                                 \
  extern DynLoad__##__name __name

#define DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP(__name) \
  DYNAMIC_LOAD_WARPCTC_WRAP(__name)

// The list of warp-ctc routines that get a dynamic-loading wrapper.
#define WARPCTC_ROUTINE_EACH(__macro) \
  __macro(get_warpctc_version);       \
  __macro(ctcGetStatusString);        \
  __macro(compute_ctc_loss);          \
  __macro(get_workspace_size)

WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP);

#undef DYNAMIC_LOAD_WARPCTC_WRAP

}  // namespace dynload
}  // namespace platform
}  // namespace paddle
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
from test_softmax_op import stable_softmax
def stable_softmax(x):
    """Return the softmax of vector x, shifted for numerical stability.

    The shift is the maximum of x, clipped from below at -64.
    """
    peak = np.max(x).clip(-64.)
    weights = np.exp(x - peak)
    return weights / np.sum(weights)
class TestSequenceSoftmaxOp(OpTest): class TestSequenceSoftmaxOp(OpTest):
......
import sys
import unittest
import numpy as np
from op_test import OpTest
from test_softmax_op import stable_softmax
class CTCForward(object):
    """Pure-numpy reference implementation of the CTC forward pass.

    Computes, for every sequence of a LoD batch, the negative log
    probability of its label sequence given per-timestep softmax
    activations. All intermediate quantities are kept in log scale.

    Args:
        softmax: 2-D array [total_timesteps, num_classes] of probabilities.
        softmax_lod: LoD (absolute offsets) splitting `softmax` into
            sequences; only level 0 is used.
        labels: 2-D array [total_labels, 1] of label ids.
        labels_lod: LoD (absolute offsets) splitting `labels`; level 0 only.
        blank: id of the blank label.
        norm_by_times: stored on the instance; not used by the forward pass.
    """

    def __init__(self, softmax, softmax_lod, labels, labels_lod, blank,
                 norm_by_times):
        self.softmax = softmax
        self.softmax_lod = softmax_lod
        assert labels.shape[1] == 1
        self.labels = labels
        self.labels_lod = labels_lod
        self.blank = blank
        self.norm_by_times = norm_by_times

        self.level = 0
        self.num_classes = softmax.shape[1]
        self.batch_size = len(softmax_lod[self.level]) - 1
        assert self.batch_size == len(labels_lod[self.level]) - 1

        self.loss = np.zeros([self.batch_size, 1], dtype="float32")
        self.gradient = np.zeros(self.softmax.shape, dtype="float32")

        # Limits of float64, used to saturate exp/log instead of overflowing.
        self.EXP_MAX = sys.float_info.max
        self.EXP_MIN = sys.float_info.min
        self.LOG_ZERO = np.log(self.EXP_MIN)
        self.LOG_INFINITY = np.log(self.EXP_MAX)

    def safe_exp(self, x):
        """exp(x), saturated to 0.0 / EXP_MAX outside the representable range."""
        if x <= self.LOG_ZERO:
            return 0.0
        if x >= self.LOG_INFINITY:
            return self.EXP_MAX
        return np.exp(x)

    def safe_log(self, x):
        """log(x), returning LOG_ZERO for inputs not larger than EXP_MIN."""
        if x <= self.EXP_MIN:
            return self.LOG_ZERO
        return np.log(x)

    # x = lna and y = lnb are in log scale, ln(a / b) = lna - lnb
    def log_div(self, x, y):
        res = x - y
        if res <= self.LOG_ZERO:
            return self.LOG_ZERO
        if res >= self.LOG_INFINITY:
            return self.LOG_INFINITY
        return res

    # x = lna and y = lnb are in log scale, ln(a * b) = lna + lnb
    def log_mul(self, x, y):
        res = x + y
        if res <= self.LOG_ZERO:
            return self.LOG_ZERO
        if res >= self.LOG_INFINITY:
            return self.LOG_INFINITY
        return res

    # x = lna and y = lnb are in log scale,
    # ln(a + b) = lna + ln(1 + exp(lnb - lna)); the operands are swapped
    # first so that lna is the larger one and the exponent is never positive.
    def log_add(self, x, y):
        if x < y:
            t = y
            y = x
            x = t
        return x + self.safe_log(1 + self.safe_exp(y - x))

    def segment_range(self, time, total_times, total_segments):
        """Return the [start, end) range of segments updated at step `time`."""
        start = max(0, total_segments - (2 * (total_times - time)))
        end = min(total_segments, 2 * (time + 1))
        return start, end

    def forward_a_sequence(self, softmax_a_sequence, labels_a_sequence):
        """Return the CTC loss (-log probability) of a single sequence.

        Returns 0 when the sequence has fewer timesteps than the labels
        require (one per label plus one per pair of equal adjacent labels).
        """
        total_times = softmax_a_sequence.shape[0]
        total_segments = labels_a_sequence.shape[0] * 2 + 1

        required_times = labels_a_sequence.shape[0]
        old_label = -1
        for i in range(labels_a_sequence.shape[0]):
            # two contiguous labels with the same value
            if labels_a_sequence[i, 0] == old_label:
                required_times = required_times + 1
            old_label = labels_a_sequence[i, 0]

        if total_times < required_times:
            return 0

        # calculate the forward and backward variables,
        # reference Chapter 7.3 of "Alex Graves, Supervised Sequence
        # Labelling with Recurrent Neural Networks"
        log_acts = np.zeros([total_times, self.num_classes], dtype="float32")
        for i in range(total_times):
            for j in range(self.num_classes):
                log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j])

        # calculate the forward variables
        forward_vars = np.zeros([total_times, total_segments], dtype="float32")
        for i in range(total_times):
            for j in range(total_segments):
                forward_vars[i, j] = self.LOG_ZERO

        for i in range(total_times):
            # dp initialization at t0
            if i == 0:
                forward_vars[i, 0] = log_acts[0, self.blank]
                if total_segments > 1:
                    forward_vars[i, 1] = log_acts[0, labels_a_sequence[i, 0]]
                continue

            # dp from t1
            start, end = self.segment_range(i, total_times, total_segments)
            for k in range(end - start):
                j = k + start
                if j & 1 == 1:
                    # Odd segments are labels. Floor division keeps the
                    # index an int on Python 3 as well as Python 2.
                    label_idx = j // 2
                    label_val = labels_a_sequence[label_idx, 0]
                    fv = self.log_add(forward_vars[i - 1, j],
                                      forward_vars[i - 1, j - 1])
                    if j > 1 and label_val != labels_a_sequence[label_idx - 1,
                                                                0]:
                        fv = self.log_add(fv, forward_vars[i - 1, j - 2])
                    fv = self.log_mul(fv, log_acts[i, label_val])
                else:
                    # Even segments are blanks.
                    fv = forward_vars[i - 1, j]
                    if j > 0:
                        fv = self.log_add(fv, forward_vars[i - 1, j - 1])
                    fv = self.log_mul(fv, log_acts[i, self.blank])
                forward_vars[i, j] = fv

        # sum the last two value as log_prob
        log_prob = forward_vars[total_times - 1, total_segments - 1]
        if total_segments > 1:
            log_prob = self.log_add(
                log_prob, forward_vars[total_times - 1, total_segments - 2])

        return -log_prob

    def forward(self):
        """Compute the loss of every sequence in the batch and return it."""
        for i in range(self.batch_size):
            softmax_start_i = self.softmax_lod[self.level][i]
            softmax_end_i = self.softmax_lod[self.level][i + 1]
            labels_start_i = self.labels_lod[self.level][i]
            labels_end_i = self.labels_lod[self.level][i + 1]

            softmax_a_sequence = self.softmax[softmax_start_i:softmax_end_i, :]
            labels_a_sequence = self.labels[labels_start_i:labels_end_i, :]
            self.loss[i] = self.forward_a_sequence(softmax_a_sequence,
                                                   labels_a_sequence)
        return self.loss
class TestWarpCTCOp(OpTest):
    """Checks the Loss output of the warpctc op against CTCForward."""

    def setUp(self):
        self.op_type = "warpctc"

        num_classes = 8
        # Four sequences with 4, 1, 3 and 3 timesteps.
        logits_lod = [[0, 4, 5, 8, 11]]
        logits = np.random.uniform(0.1, 1.0,
                                   [11, num_classes]).astype("float32")
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        labels_lod = [[0, 3, 4, 8, 12]]

        # labels should not be blank
        labels = np.random.randint(0, num_classes - 1, [12, 1], dtype="int32")

        blank = num_classes - 1
        norm_by_times = False

        ctc = CTCForward(softmax, logits_lod, labels, labels_lod, blank,
                         norm_by_times)
        loss = ctc.forward()

        self.inputs = {
            "Logits": (logits, logits_lod),
            "Label": (labels, labels_lod)
        }
        self.outputs = {"Loss": loss}
        self.attrs = {"blank": blank, "norm_by_times": norm_by_times}

    def test_check_output(self):
        self.check_output()

    # def test_check_grad(self):
    #     self.outputs["WarpCTCGrad"] = None
    #     self.check_grad(["Logits"], "Loss", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册