未验证 提交 c522530a 编写于 作者: S ShenLiang 提交者: GitHub

fix safe bug of scatter/scatter_nd (#33858)

* fix safe bug of scatter/scatter_nd
上级 57aabbab
......@@ -33,6 +33,14 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output,
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i];
PADDLE_ENFORCE(scatter_i >= 0,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
scatter_i);
IndexT out_i = scatter_i * slice_size + slice_i;
*(output + out_i) = static_cast<T>(0);
}
......@@ -46,6 +54,14 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices,
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
IndexT scatter_i = indices[indices_i];
PADDLE_ENFORCE(scatter_i >= 0,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
scatter_i);
IndexT out_i = scatter_i * slice_size + slice_i;
if (overwrite) {
*(output + out_i) = *(params + i);
......@@ -67,6 +83,15 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices,
int64_t temp = slice_size;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = indices[indices_i * end_size + j];
PADDLE_ENFORCE(
index_value >= 0 && index_value < output_dims[j],
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]",
output_dims[j], index_value);
gather_i += (index_value * temp);
temp *= output_dims[j];
}
......
......@@ -118,6 +118,15 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src,
for (int i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i];
PADDLE_ENFORCE_GE(index_, 0,
platform::errors::OutOfRange(
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
index_));
memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes);
}
}
......@@ -173,6 +182,15 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src,
// if not in overwrite mode, need to init output data
for (int i = 0; i < index_size; ++i) {
const IndexT& index_ = p_index[i];
PADDLE_ENFORCE_GE(index_, 0,
platform::errors::OutOfRange(
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be greater than or equal to 0, but received [%d]",
index_));
elementwise_inner_add<T, IndexT>(ctx, p_src, p_output, result_p_output, src,
output, i, index_, slice_size,
slice_bytes);
......@@ -233,6 +251,15 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update,
IndexT temp = 1;
for (int64_t j = end_size - 1; j >= 0; --j) {
IndexT index_value = p_index[i * end_size + j];
PADDLE_ENFORCE_EQ(
(index_value >= 0 && index_value < output_dims[j]), true,
platform::errors::OutOfRange(
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]",
output_dims[j], index_value));
index_ += (index_value * temp);
temp *= output_dims[j];
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册