From d86c26dc457d183e94d81a84008424dce5aff7c4 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Mon, 9 Aug 2021 22:08:20 -0500 Subject: [PATCH] fix for div zero (#34724) * fix for div zero * fix err;test=develop * fix lod --- paddle/fluid/operators/warpctc_op.h | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/paddle/fluid/operators/warpctc_op.h b/paddle/fluid/operators/warpctc_op.h index e90eefd72d4..f5b51da3d85 100644 --- a/paddle/fluid/operators/warpctc_op.h +++ b/paddle/fluid/operators/warpctc_op.h @@ -199,6 +199,27 @@ class WarpCTCKernel : public framework::OpKernel { sequence_width = logits->dims()[2]; max_sequence_length = logits->dims()[0]; + PADDLE_ENFORCE_GT(max_sequence_length, 0, + platform::errors::InvalidArgument( + "The first dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + max_sequence_length)); + + PADDLE_ENFORCE_GT(num_sequences, 0, + platform::errors::InvalidArgument( + "The second dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + num_sequences)); + + PADDLE_ENFORCE_GT(sequence_width, 0, + platform::errors::InvalidArgument( + "The third dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + sequence_width)); + auto* logits_length = ctx.Input("LogitsLength"); auto* labels_length = ctx.Input("LabelLength"); framework::Tensor logits_length_cpu; @@ -229,6 +250,13 @@ class WarpCTCKernel : public framework::OpKernel { logits_lod = framework::ToAbsOffset(logits->lod())[0]; auto logits_dims = logits->dims(); + PADDLE_ENFORCE_GT(logits_dims[0], 0, + platform::errors::InvalidArgument( + "The first dimension of Input(Logits) should be " + "greater than zero " + "but received %d. ", + logits_dims[0])); + PADDLE_ENFORCE_EQ( logits_dims[0], static_cast(logits_lod.back()), platform::errors::InvalidArgument( -- GitLab