未验证 提交 1501a80f 编写于 作者: L Li Fuchen 提交者: GitHub

add support to float64 input of warpctc op. (#27399)

* add float64 input to ctc_loss

* modified error message of warpctc

* update repo and tag of warpctc

* add test for warpctc with float64 input

* modified warpctc.cmake to make sure warpctc is always rebuilt

* resolved sample code bug of warpctc

* add core.ops in warpctc dygraph

* fix a bug of test
上级 3f170dd8
...@@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc) ...@@ -18,7 +18,7 @@ SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc) SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc) SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git) set(WARPCTC_REPOSITORY https://github.com/baidu-research/warp-ctc.git)
set(WARPCTC_TAG fc7f226b93758216a03b1be9d24593a12819b984) set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include" SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE) CACHE PATH "Warp-ctc Directory" FORCE)
...@@ -44,8 +44,9 @@ ExternalProject_Add( ...@@ -44,8 +44,9 @@ ExternalProject_Add(
"${WARPCTC_DOWNLOAD_CMD}" "${WARPCTC_DOWNLOAD_CMD}"
PREFIX ${WARPCTC_PREFIX_DIR} PREFIX ${WARPCTC_PREFIX_DIR}
SOURCE_DIR ${WARPCTC_SOURCE_DIR} SOURCE_DIR ${WARPCTC_SOURCE_DIR}
UPDATE_COMMAND "" #UPDATE_COMMAND ""
PATCH_COMMAND "" PATCH_COMMAND ""
BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
......
...@@ -46,6 +46,7 @@ class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> { ...@@ -46,6 +46,7 @@ class ScaleLoDTensorFunctor<platform::CPUDeviceContext, T> {
}; };
template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, float>; template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, float>;
template class ScaleLoDTensorFunctor<platform::CPUDeviceContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -52,6 +52,7 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> { ...@@ -52,6 +52,7 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
}; };
template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, float>; template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, float>;
template class ScaleLoDTensorFunctor<platform::CUDADeviceContext, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -103,13 +103,13 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -103,13 +103,13 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
"Target sequence length for Label when Label is a 2-D tensor.") "Target sequence length for Label when Label is a 2-D tensor.")
.AsDispensable(); .AsDispensable();
AddOutput("WarpCTCGrad", AddOutput("WarpCTCGrad",
"(Tensor, default: Tensor<float>), a temporary " "(Tensor), a temporary "
"output Tensor to store the gradients of warp-ctc, which is " "output Tensor to store the gradients of warp-ctc, which is "
"computed with loss together in one call. It is a 3-D Tensor of " "computed with loss together in one call. It is a 3-D Tensor of "
"the shape [max_sequence_length, batch_size, num_classes + 1].") "the shape [max_sequence_length, batch_size, num_classes + 1].")
.AsIntermediate(); .AsIntermediate();
AddOutput("Loss", AddOutput("Loss",
"(Tensor, default: Tensor<float>), the Connectionist " "(Tensor), the Connectionist "
"Temporal Classification (CTC) loss, which is a 2-D Tensor of " "Temporal Classification (CTC) loss, which is a 2-D Tensor of "
"the shape [batch_size, 1]"); "the shape [batch_size, 1]");
AddAttr<int>("blank", AddAttr<int>("blank",
...@@ -197,7 +197,9 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker, ...@@ -197,7 +197,9 @@ REGISTER_OPERATOR(warpctc, ops::WarpCTCOp, ops::WarpCTCOpMaker,
REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp, REGISTER_OPERATOR(warpctc_grad, ops::WarpCTCGradOp,
ops::WarpCTCGradOpNoNeedBufferVarInferer); ops::WarpCTCGradOpNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>); warpctc, ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
warpctc_grad, warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>); ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -16,7 +16,9 @@ limitations under the License. */ ...@@ -16,7 +16,9 @@ limitations under the License. */
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>); warpctc, ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, float>,
ops::WarpCTCKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
warpctc_grad, warpctc_grad,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>); ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::WarpCTCGradKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -27,7 +27,52 @@ namespace operators { ...@@ -27,7 +27,52 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
/**
 * Generic dispatcher used for element types that have no dedicated
 * specialization in this file (only float and double are specialized).
 *
 * It never calls into warp-ctc: every argument is ignored and the call
 * reports CTC_STATUS_EXECUTION_FAILED to the caller.
 */
template <typename DeviceContext, typename T>
class ComputeCtcLossFunctor {
 public:
  ctcStatus_t operator()(const T* const /*activations*/, T* /*gradients*/,
                         const int* const /*flat_labels*/,
                         const int* const /*label_lengths*/,
                         const int* const /*input_lengths*/,
                         int /*alphabet_size*/, int /*minibatch*/,
                         T* /*costs*/, void* /*workspace*/,
                         ctcOptions /*options*/) {
    // No warp-ctc entry point is wired up for this T.
    return CTC_STATUS_EXECUTION_FAILED;
  }
};
/**
 * Single-precision dispatcher: forwards every argument unchanged to
 * platform::dynload::compute_ctc_loss and returns the status it reports.
 */
template <typename DeviceContext>
class ComputeCtcLossFunctor<DeviceContext, float> {
 public:
  ctcStatus_t operator()(const float* const activations, float* gradients,
                         const int* const flat_labels,
                         const int* const label_lengths,
                         const int* const input_lengths, int alphabet_size,
                         int minibatch, float* costs, void* workspace,
                         ctcOptions options) {
    // alphabet_size and minibatch are already int, so they are passed as-is.
    const ctcStatus_t status = platform::dynload::compute_ctc_loss(
        activations, gradients, flat_labels, label_lengths, input_lengths,
        alphabet_size, minibatch, costs, workspace, options);
    return status;
  }
};
template <typename DeviceContext> template <typename DeviceContext>
class ComputeCtcLossFunctor<DeviceContext, double> {
public:
ctcStatus_t operator()(const double* const activations, double* gradients,
const int* const flat_labels,
const int* const label_lengths,
const int* const input_lengths, int alphabet_size,
int minibatch, double* costs, void* workspace,
ctcOptions options) {
return platform::dynload::compute_ctc_loss_double(
activations, gradients, flat_labels, label_lengths, input_lengths,
static_cast<int>(alphabet_size), static_cast<int>(minibatch), costs,
workspace, options);
}
};
template <typename DeviceContext, typename T>
class WarpCTCFunctor { class WarpCTCFunctor {
public: public:
/* /*
...@@ -51,21 +96,29 @@ class WarpCTCFunctor { ...@@ -51,21 +96,29 @@ class WarpCTCFunctor {
* \param blank blank label used in ctc loss function. * \param blank blank label used in ctc loss function.
* \param cpu_loss cost of each sequence in CPU memory. * \param cpu_loss cost of each sequence in CPU memory.
*/ */
void operator()(const framework::ExecutionContext& ctx, const float* input, void operator()(const framework::ExecutionContext& ctx, const T* input,
float* gradient, const int* cpu_labels, T* gradient, const int* cpu_labels,
const int* cpu_label_lengths, const int* cpu_input_lengths, const int* cpu_label_lengths, const int* cpu_input_lengths,
const size_t sequence_width, const size_t num_sequences, const size_t sequence_width, const size_t num_sequences,
const size_t blank, float* cpu_loss) { const size_t blank, T* cpu_loss) {
// Init warp-ctc options // Init warp-ctc options
init(ctx, blank); init(ctx, blank);
// Compute the required workspace size. // Compute the required workspace size.
// There are no memory allocation operations within warp-ctc. // There are no memory allocation operations within warp-ctc.
size_t workspace_bytes = 0; size_t workspace_bytes = 0;
ctcStatus_t status = platform::dynload::get_workspace_size( ctcStatus_t status = CTC_STATUS_UNKNOWN_ERROR;
cpu_label_lengths, cpu_input_lengths, static_cast<int>(sequence_width), if (sizeof(T) == 4) {
static_cast<int>(num_sequences), options_, &workspace_bytes); status = platform::dynload::get_workspace_size(
cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
options_, &workspace_bytes);
} else {
status = platform::dynload::get_workspace_size_double(
cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences),
options_, &workspace_bytes);
}
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
CTC_STATUS_SUCCESS, status, CTC_STATUS_SUCCESS, status,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
...@@ -79,17 +132,17 @@ class WarpCTCFunctor { ...@@ -79,17 +132,17 @@ class WarpCTCFunctor {
workspace_bytes)); workspace_bytes));
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
size_t workspace_elements = workspace_bytes / sizeof(float) + 1UL; size_t workspace_elements = workspace_bytes / sizeof(T) + 1UL;
Tensor workspace = ctx.AllocateTmpTensor<float, DeviceContext>( Tensor workspace = ctx.AllocateTmpTensor<T, DeviceContext>(
framework::make_ddim({static_cast<int64_t>(workspace_elements)}), framework::make_ddim({static_cast<int64_t>(workspace_elements)}),
dev_ctx); dev_ctx);
float* workspace_data = workspace.data<float>(); T* workspace_data = workspace.data<T>();
math::SetConstant<DeviceContext, float>()( math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), &workspace, ctx.template device_context<DeviceContext>(), &workspace,
static_cast<float>(0)); static_cast<T>(0));
// compute loss and gradient // compute loss and gradient
status = platform::dynload::compute_ctc_loss( status = ComputeCtcLossFunctor<DeviceContext, T>()(
input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths, input, gradient, cpu_labels, cpu_label_lengths, cpu_input_lengths,
static_cast<int>(sequence_width), static_cast<int>(num_sequences), static_cast<int>(sequence_width), static_cast<int>(num_sequences),
cpu_loss, workspace_data, options_); cpu_loss, workspace_data, options_);
...@@ -112,7 +165,8 @@ class WarpCTCFunctor { ...@@ -112,7 +165,8 @@ class WarpCTCFunctor {
ctx.device_context()) ctx.device_context())
.stream(); .stream();
#else #else
PADDLE_THROW("[warpctc init] GPU is not enabled."); PADDLE_THROW(platform::errors::PreconditionNotMet(
"[warpctc init] GPU is not enabled."));
#endif #endif
} else { } else {
options_.loc = CTC_CPU; options_.loc = CTC_CPU;
...@@ -292,7 +346,7 @@ class WarpCTCKernel : public framework::OpKernel<T> { ...@@ -292,7 +346,7 @@ class WarpCTCKernel : public framework::OpKernel<T> {
const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank")); const size_t blank = static_cast<size_t>(ctx.Attr<int>("blank"));
WarpCTCFunctor<DeviceContext>()( WarpCTCFunctor<DeviceContext, T>()(
ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data, ctx, warpctc_logits_data, warpctc_grad_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(), warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
sequence_width, num_sequences, blank, warpctc_loss_data); sequence_width, num_sequences, blank, warpctc_loss_data);
......
...@@ -53,7 +53,9 @@ extern void* warpctc_dso_handle; ...@@ -53,7 +53,9 @@ extern void* warpctc_dso_handle;
__macro(get_warpctc_version); \ __macro(get_warpctc_version); \
__macro(ctcGetStatusString); \ __macro(ctcGetStatusString); \
__macro(compute_ctc_loss); \ __macro(compute_ctc_loss); \
__macro(get_workspace_size) __macro(compute_ctc_loss_double); \
__macro(get_workspace_size); \
__macro(get_workspace_size_double)
WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP); WARPCTC_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_WARPCTC_WRAP);
......
...@@ -48,6 +48,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -48,6 +48,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"collect_fpn_proposals", {"collect_fpn_proposals",
{"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}},
{"distribute_fpn_proposals", {"FpnRois", "RoisNum"}}, {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}},
{"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
......
...@@ -541,7 +541,7 @@ def warpctc(input, ...@@ -541,7 +541,7 @@ def warpctc(input,
(not including the blank label). When it is a 3-D Tensor, its shape (not including the blank label). When it is a 3-D Tensor, its shape
is `[max_logit_length, batch_size, num_classes + 1]`, is `[max_logit_length, batch_size, num_classes + 1]`,
where `max_logit_length` is the longest length of where `max_logit_length` is the longest length of
input logit sequence. The data type must be float32. input logit sequence. The data type should be float32 or float64.
label (Variable): The ground truth of variable-length sequence, label (Variable): The ground truth of variable-length sequence,
which must be a 2-D Tensor with LoD information or a 3-D Tensor without which must be a 2-D Tensor with LoD information or a 3-D Tensor without
LoD information, needs to be consistent with the corresponding input. LoD information, needs to be consistent with the corresponding input.
...@@ -571,6 +571,7 @@ def warpctc(input, ...@@ -571,6 +571,7 @@ def warpctc(input,
.. code-block:: python .. code-block:: python
# using LoDTensor # using LoDTensor
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
...@@ -581,6 +582,7 @@ def warpctc(input, ...@@ -581,6 +582,7 @@ def warpctc(input,
# class num # class num
class_num = 5 class_num = 5
paddle.enable_static()
logits = fluid.data(name='logits',shape=[None, class_num+1], logits = fluid.data(name='logits',shape=[None, class_num+1],
dtype='float32',lod_level=1) dtype='float32',lod_level=1)
label = fluid.data(name='label', shape=[None, 1], label = fluid.data(name='label', shape=[None, 1],
...@@ -602,6 +604,7 @@ def warpctc(input, ...@@ -602,6 +604,7 @@ def warpctc(input,
.. code-block:: python .. code-block:: python
# using Tensor # using Tensor
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
...@@ -613,6 +616,7 @@ def warpctc(input, ...@@ -613,6 +616,7 @@ def warpctc(input,
batch_size = 16 batch_size = 16
# class num # class num
class_num = 5 class_num = 5
paddle.enable_static()
logits = fluid.data(name='logits', logits = fluid.data(name='logits',
shape=[max_seq_length, batch_size, class_num+1], shape=[max_seq_length, batch_size, class_num+1],
dtype='float32') dtype='float32')
...@@ -637,8 +641,23 @@ def warpctc(input, ...@@ -637,8 +641,23 @@ def warpctc(input,
fetch_list=[cost.name]) fetch_list=[cost.name])
print(output) print(output)
""" """
if in_dygraph_mode():
if input_length is None or label_length is None:
raise ValueError(
"input_length and label_length must not be None in dygraph mode!"
)
grad, loss_out = core.ops.warpctc(
input,
label,
input_length,
label_length,
'blank',
blank,
'norm_by_times',
norm_by_times, )
return loss_out
helper = LayerHelper('warpctc', **locals()) helper = LayerHelper('warpctc', **locals())
check_variable_and_dtype(input, 'input', ['float32'], "warpctc") check_variable_and_dtype(input, 'input', ['float32', 'float64'], "warpctc")
check_variable_and_dtype(label, 'label', ['int32'], "warpctc") check_variable_and_dtype(label, 'label', ['int32'], "warpctc")
this_inputs = {'Logits': [input], 'Label': [label]} this_inputs = {'Logits': [input], 'Label': [label]}
if input_length is not None and label_length is not None: if input_length is not None and label_length is not None:
......
...@@ -24,7 +24,7 @@ from paddle.fluid import Program, program_guard ...@@ -24,7 +24,7 @@ from paddle.fluid import Program, program_guard
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
CUDA_BLOCK_SIZE = 512 CUDA_BLOCK_SIZE = 32
class CTCForward(object): class CTCForward(object):
...@@ -41,8 +41,8 @@ class CTCForward(object): ...@@ -41,8 +41,8 @@ class CTCForward(object):
self.num_classes = num_classes self.num_classes = num_classes
self.batch_size = batch_size self.batch_size = batch_size
self.loss = np.zeros([self.batch_size, 1], dtype="float32") self.loss = np.zeros([self.batch_size, 1], dtype=softmax.dtype)
self.gradient = np.zeros(self.softmax.shape, dtype="float32") self.gradient = np.zeros(self.softmax.shape, dtype=softmax.dtype)
# float64 # float64
self.EXP_MAX = sys.float_info.max self.EXP_MAX = sys.float_info.max
...@@ -112,13 +112,15 @@ class CTCForward(object): ...@@ -112,13 +112,15 @@ class CTCForward(object):
# calculate the forward and backward variables, # calculate the forward and backward variables,
# reference Chapter 7.3 of "Alex Grave, Supervised Sequence # reference Chapter 7.3 of "Alex Grave, Supervised Sequence
# Labelling with Recurrent Neural Networks" # Labelling with Recurrent Neural Networks"
log_acts = np.zeros([total_times, self.num_classes], dtype="float32") log_acts = np.zeros(
[total_times, self.num_classes], dtype=softmax_a_sequence.dtype)
for i in range(total_times): for i in range(total_times):
for j in range(self.num_classes): for j in range(self.num_classes):
log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j]) log_acts[i, j] = self.safe_log(softmax_a_sequence[i, j])
# calculate the forward variables # calculate the forward variables
forward_vars = np.zeros([total_times, total_segments], dtype="float32") forward_vars = np.zeros(
[total_times, total_segments], dtype=softmax_a_sequence.dtype)
for i in range(total_times): for i in range(total_times):
for j in range(total_segments): for j in range(total_segments):
forward_vars[i, j] = self.LOG_ZERO forward_vars[i, j] = self.LOG_ZERO
...@@ -219,7 +221,7 @@ class TestWarpCTCOp(OpTest): ...@@ -219,7 +221,7 @@ class TestWarpCTCOp(OpTest):
self.logits_lod[0][i]) self.logits_lod[0][i])
self.gradient = np.zeros( self.gradient = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes], [max_sequence_length, self.batch_size, self.num_classes],
dtype="float32") dtype=logits.dtype)
self.inputs = { self.inputs = {
"Logits": (logits, self.logits_lod), "Logits": (logits, self.logits_lod),
...@@ -287,7 +289,7 @@ class TestWarpCTCOpWithPadding(OpTest): ...@@ -287,7 +289,7 @@ class TestWarpCTCOpWithPadding(OpTest):
# reshape logits to T*N*S # reshape logits to T*N*S
new_logits = np.zeros( new_logits = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes], [max_sequence_length, self.batch_size, self.num_classes],
dtype="float32") dtype=logits.dtype)
cur = 0 cur = 0
for batch_id in range(self.batch_size): for batch_id in range(self.batch_size):
...@@ -312,7 +314,7 @@ class TestWarpCTCOpWithPadding(OpTest): ...@@ -312,7 +314,7 @@ class TestWarpCTCOpWithPadding(OpTest):
self.gradient = np.zeros( self.gradient = np.zeros(
[max_sequence_length, self.batch_size, self.num_classes], [max_sequence_length, self.batch_size, self.num_classes],
dtype="float32") dtype=logits.dtype)
self.inputs = { self.inputs = {
"Logits": new_logits, "Logits": new_logits,
...@@ -347,6 +349,90 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding): ...@@ -347,6 +349,90 @@ class TestWarpCTCOpWithPaddingCase1(TestWarpCTCOpWithPadding):
self.norm_by_times = False self.norm_by_times = False
class TestWarpCTCOpFp64(OpTest):
    """Checks the warpctc op on float64 logits supplied as a padded tensor.

    The expected loss comes from the numpy reference ``CTCForward``; the
    inputs fed to the op are the padded logits/labels plus explicit
    ``LogitsLength`` / ``LabelLength`` arrays.
    """

    def config(self):
        """Set batch layout, class count, blank index and normalization flag."""
        self.batch_size = 4
        self.num_classes = 8
        self.logits_lod = [[4, 1, 5, 5]]
        self.labels_lod = [[3, 1, 4, 2]]
        self.logits_length = np.array([4, 1, 5, 5], dtype=np.int64)
        self.labels_length = np.array([3, 1, 4, 2], dtype=np.int64)
        self.blank = self.num_classes - 1
        self.norm_by_times = False

    def setUp(self):
        """Build random float64 inputs, the reference loss and the op inputs."""
        self.op_type = "warpctc"
        self.config()

        logits = np.random.uniform(
            0.1, 1.0,
            [sum(self.logits_length), self.num_classes]).astype("float64")
        softmax = np.apply_along_axis(stable_softmax, 1, logits)
        # labels should not be blank
        labels = np.random.randint(
            0,
            self.num_classes - 1, [sum(self.labels_length), 1],
            dtype="int32")

        ctc = CTCForward(softmax, self.logits_lod, labels, self.labels_lod,
                         self.num_classes, self.batch_size, self.blank,
                         self.norm_by_times)
        loss = ctc.forward()

        # reshape logits to T*N*S
        max_sequence_length = int(self.logits_length.max())
        new_logits = np.zeros(
            [max_sequence_length, self.batch_size, self.num_classes],
            dtype=logits.dtype)
        cur = 0
        for batch_id in range(self.batch_size):
            length = int(self.logits_length[batch_id])
            new_logits[:length, batch_id, :] = logits[cur:cur + length, :]
            cur += length

        # reshape labels to N*S
        max_target_seq_length = int(self.labels_length.max())
        new_labels = np.zeros(
            [self.batch_size, max_target_seq_length], dtype="int32")
        cur = 0
        for batch_id in range(self.batch_size):
            length = int(self.labels_length[batch_id])
            # Take column 0 explicitly: labels has shape [N, 1], and writing
            # a shape-(1,) array into a scalar slot is deprecated in NumPy.
            new_labels[batch_id, :length] = labels[cur:cur + length, 0]
            cur += length

        self.gradient = np.zeros(
            [max_sequence_length, self.batch_size, self.num_classes],
            dtype=logits.dtype)

        self.inputs = {
            "Logits": new_logits,
            "Label": new_labels,
            "LogitsLength": self.logits_length,
            "LabelLength": self.labels_length
        }
        self.outputs = {"Loss": loss}
        self.attrs = {
            "blank": self.blank,
            "norm_by_times": self.norm_by_times,
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # WarpCTCGrad is registered as an output before the gradient check.
        self.outputs['WarpCTCGrad'] = self.gradient
        self.check_grad(["Logits"], "Loss")
class TestWarpCTCOpError(unittest.TestCase): class TestWarpCTCOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
...@@ -359,7 +445,7 @@ class TestWarpCTCOpError(unittest.TestCase): ...@@ -359,7 +445,7 @@ class TestWarpCTCOpError(unittest.TestCase):
name='labels_length', shape=[None], dtype='int64') name='labels_length', shape=[None], dtype='int64')
def test_logits_Variable(): def test_logits_Variable():
logits_data = np.random.rand(5, 16, 6).astype("float32") logits_data = np.random.rand(5, 16, 6).astype(logits.dtype)
fluid.layers.warpctc( fluid.layers.warpctc(
input=logits_data, input=logits_data,
label=label, label=label,
...@@ -398,6 +484,21 @@ class TestWarpCTCOpError(unittest.TestCase): ...@@ -398,6 +484,21 @@ class TestWarpCTCOpError(unittest.TestCase):
self.assertRaises(TypeError, test_label_len_Variable) self.assertRaises(TypeError, test_label_len_Variable)
def test_dygraph_errors(self):
    """warpctc must raise ValueError in dygraph mode when the length
    arguments are omitted."""

    def test_dygraph_with_lod():
        logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32")
        # labels should not be blank
        labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
        softmax = paddle.to_variable(logits)
        labels = paddle.to_variable(labels)

        # input_length / label_length are deliberately left out.
        fluid.layers.warpctc(input=softmax, label=labels)

    paddle.disable_static()
    try:
        self.assertRaises(ValueError, test_dygraph_with_lod)
    finally:
        # Restore static mode even if the assertion fails, so a failure
        # here does not leave later tests running in dygraph mode.
        paddle.enable_static()
class TestCTCLossAPICase(unittest.TestCase): class TestCTCLossAPICase(unittest.TestCase):
def test_functinal_api(self): def test_functinal_api(self):
......
...@@ -933,7 +933,7 @@ def ctc_loss(log_probs, ...@@ -933,7 +933,7 @@ def ctc_loss(log_probs,
is integrated into the Warp-CTC library to normalize values for each row of the input tensor. is integrated into the Warp-CTC library to normalize values for each row of the input tensor.
Parameters: Parameters:
log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32. log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32. labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64. input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64. label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
......
...@@ -773,7 +773,7 @@ class CTCLoss(fluid.dygraph.Layer): ...@@ -773,7 +773,7 @@ class CTCLoss(fluid.dygraph.Layer):
reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``. reduction (string, optional): Indicate how to average the loss, the candidates are ``'none'`` | ``'mean'`` | ``'sum'``. If :attr:`reduction` is ``'mean'``, the output loss will be divided by the label_lengths, and then return the mean of quotient; If :attr:`reduction` is ``'sum'``, return the sum of loss; If :attr:`reduction` is ``'none'``, no reduction will be applied. Default is ``'mean'``.
Shape: Shape:
log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type must be float32. log_probs (Tensor): The unscaled probability sequence with padding, which is a 3-D Tensor. The tensor shape is [max_logit_length, batch_size, num_classes + 1], where max_logit_length is the longest length of input logit sequence. The data type should be float32 or float64.
labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32. labels (Tensor): The ground truth sequence with padding, which must be a 2-D Tensor. The tensor shape is [batch_size, max_label_length], where max_label_length is the longest length of label sequence. The data type must be int32.
input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64. input_lengths (Tensor): The length for each input sequence, it should have shape [batch_size] and dtype int64.
label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64. label_lengths (Tensor): The length for each label sequence, it should have shape [batch_size] and dtype int64.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册