From d5e73f079c758b3ca756f9cfd6b54932a5b71dda Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 23 Nov 2021 11:48:58 +0800 Subject: [PATCH] [Cherry-pick 2.2]Enhance error message of scatter op (#37431) * enhance scatter err msg check * fix ci error --- paddle/fluid/operators/scatter.h | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index 802d9ae76d3..c4f9c628dbb 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -169,15 +169,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, const size_t& slice_bytes = slice_size * sizeof(T); // if not in overwrite mode, need to init output data + auto max_index = dst_dims[0]; for (int64_t i = 0; i < index_size; ++i) { const IndexT& index_val = p_index[i]; - memset(p_output + slice_size * index_val, 0, slice_bytes); - } - - // if not in overwrite mode, need to init output data - for (int64_t i = 0; i < index_size; ++i) { - const IndexT& index_val = p_index[i]; - PADDLE_ENFORCE_GE(index_val, 0, platform::errors::OutOfRange( "The index is out of bounds, " @@ -185,7 +179,19 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, "input meet the requirements. It should " "be greater than or equal to 0, but received [%d]", index_val)); + PADDLE_ENFORCE_LT(index_val, max_index, + platform::errors::OutOfRange( + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be less than %d, but received %d", + max_index, index_val)); + memset(p_output + slice_size * index_val, 0, slice_bytes); + } + // if not in overwrite mode, need to init output data + for (int64_t i = 0; i < index_size; ++i) { + const IndexT& index_val = p_index[i]; elementwise_inner_add(ctx, p_src, p_output, i, index_val, slice_size); } -- GitLab