未验证 提交 11b17c88 编写于 作者: S sneaxiy 提交者: GitHub

Enhance the error message of scatter op (#37429)

* enhance scatter err msg check

* fix ci error
上级 32d9beef
......@@ -169,15 +169,9 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
const size_t& slice_bytes = slice_size * sizeof(T);
// if not in overwrite mode, need to init output data
auto max_index = dst_dims[0];
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i];
memset(p_output + slice_size * index_val, 0, slice_bytes);
}
// if not in overwrite mode, need to init output data
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i];
PADDLE_ENFORCE_GE(index_val, 0,
platform::errors::OutOfRange(
"The index is out of bounds, "
......@@ -185,7 +179,19 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
index_val));
PADDLE_ENFORCE_LT(index_val, max_index,
platform::errors::OutOfRange(
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than %d, but received %d",
max_index, index_val));
memset(p_output + slice_size * index_val, 0, slice_bytes);
}
// if not in overwrite mode, need to init output data
for (int64_t i = 0; i < index_size; ++i) {
const IndexT& index_val = p_index[i];
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, i, index_val,
slice_size);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册