From edb9aff59e6325f1a043441b732e7d65d427cf6f Mon Sep 17 00:00:00 2001 From: liu zhengxi <380185688@qq.com> Date: Fri, 23 Jul 2021 15:49:32 +0800 Subject: [PATCH] update gather tree error msg (#34322) --- paddle/fluid/operators/gather_tree_op.cu | 8 ++++++++ paddle/fluid/operators/gather_tree_op.h | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/paddle/fluid/operators/gather_tree_op.cu b/paddle/fluid/operators/gather_tree_op.cu index c53f1e81ce..829682764a 100644 --- a/paddle/fluid/operators/gather_tree_op.cu +++ b/paddle/fluid/operators/gather_tree_op.cu @@ -50,6 +50,14 @@ class GatherTreeOpCUDAKernel : public framework::OpKernel { const auto *parents_data = parents->data(); auto *out_data = out->mutable_data(ctx.GetPlace()); + PADDLE_ENFORCE_NOT_NULL( + ids_data, platform::errors::InvalidArgument( + "Input(Ids) of gather_tree should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + parents_data, platform::errors::InvalidArgument( + "Input(Parents) of gather_tree should not be null.")); + auto &ids_dims = ids->dims(); int64_t max_length = ids_dims[0]; int64_t batch_size = ids_dims[1]; diff --git a/paddle/fluid/operators/gather_tree_op.h b/paddle/fluid/operators/gather_tree_op.h index 742a7ffcaa..e035a30e79 100644 --- a/paddle/fluid/operators/gather_tree_op.h +++ b/paddle/fluid/operators/gather_tree_op.h @@ -38,6 +38,14 @@ class GatherTreeOpKernel : public framework::OpKernel { auto batch_size = ids_dims[1]; auto beam_size = ids_dims[2]; + PADDLE_ENFORCE_NOT_NULL( + ids_data, platform::errors::InvalidArgument( + "Input(Ids) of gather_tree should not be null.")); + + PADDLE_ENFORCE_NOT_NULL( + parents_data, platform::errors::InvalidArgument( + "Input(Parents) of gather_tree should not be null.")); + for (int batch = 0; batch < batch_size; batch++) { for (int beam = 0; beam < beam_size; beam++) { auto idx = (max_length - 1) * batch_size * beam_size + -- GitLab