diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index 802d9ae76d3b5894100b8f86c979f6b4fb12a161..c4f9c628dbb04ebba860931ad9f2ca3b11a3473e 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -169,15 +169,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, const size_t& slice_bytes = slice_size * sizeof(T); // if not in overwrite mode, need to init output data + auto max_index = dst_dims[0]; for (int64_t i = 0; i < index_size; ++i) { const IndexT& index_val = p_index[i]; - memset(p_output + slice_size * index_val, 0, slice_bytes); - } - - // if not in overwrite mode, need to init output data - for (int64_t i = 0; i < index_size; ++i) { - const IndexT& index_val = p_index[i]; - PADDLE_ENFORCE_GE(index_val, 0, platform::errors::OutOfRange( "The index is out of bounds, " @@ -185,7 +179,19 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, "input meet the requirements. It should " "be greater than or equal to 0, but received [%d]", index_val)); + PADDLE_ENFORCE_LT(index_val, max_index, + platform::errors::OutOfRange( + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be less than %d, but received %d", + max_index, index_val)); + memset(p_output + slice_size * index_val, 0, slice_bytes); + } + // if not in overwrite mode, need to init output data + for (int64_t i = 0; i < index_size; ++i) { + const IndexT& index_val = p_index[i]; elementwise_inner_add(ctx, p_src, p_output, i, index_val, slice_size); }