未验证 提交 d86c26dc 编写于 作者: H Hui Zhang 提交者: GitHub

fix for div zero (#34724)

* fix for div zero

* fix err;test=develop

* fix lod
上级 84eb6757
......@@ -199,6 +199,27 @@ class WarpCTCKernel : public framework::OpKernel<T> {
sequence_width = logits->dims()[2];
max_sequence_length = logits->dims()[0];
PADDLE_ENFORCE_GT(max_sequence_length, 0,
platform::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
max_sequence_length));
PADDLE_ENFORCE_GT(num_sequences, 0,
platform::errors::InvalidArgument(
"The second dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
num_sequences));
PADDLE_ENFORCE_GT(sequence_width, 0,
platform::errors::InvalidArgument(
"The third dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
sequence_width));
auto* logits_length = ctx.Input<framework::Tensor>("LogitsLength");
auto* labels_length = ctx.Input<framework::Tensor>("LabelLength");
framework::Tensor logits_length_cpu;
......@@ -229,6 +250,13 @@ class WarpCTCKernel : public framework::OpKernel<T> {
logits_lod = framework::ToAbsOffset(logits->lod())[0];
auto logits_dims = logits->dims();
PADDLE_ENFORCE_GT(logits_dims[0], 0,
platform::errors::InvalidArgument(
"The first dimension of Input(Logits) should be "
"greater than zero "
"but received %d. ",
logits_dims[0]));
PADDLE_ENFORCE_EQ(
logits_dims[0], static_cast<int64_t>(logits_lod.back()),
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册