From c522530a4755f1671568467e265101a735d22a56 Mon Sep 17 00:00:00 2001 From: ShenLiang <1422485404@qq.com> Date: Thu, 1 Jul 2021 17:55:24 +0800 Subject: [PATCH] fix safe bug of scatter/scatter_nd (#33858) * fix safe bug of scatter/scatter_nd --- paddle/fluid/operators/scatter.cu.h | 25 +++++++++++++++++++++++++ paddle/fluid/operators/scatter.h | 27 +++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h index b116a78891a..61e95c2b50e 100644 --- a/paddle/fluid/operators/scatter.cu.h +++ b/paddle/fluid/operators/scatter.cu.h @@ -33,6 +33,14 @@ __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT scatter_i = indices[indices_i]; + + PADDLE_ENFORCE(scatter_i >= 0, + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be greater than or equal to 0, but received [%d]", + scatter_i); + IndexT out_i = scatter_i * slice_size + slice_i; *(output + out_i) = static_cast(0); } @@ -46,6 +54,14 @@ __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT scatter_i = indices[indices_i]; + + PADDLE_ENFORCE(scatter_i >= 0, + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be greater than or equal to 0, but received [%d]", + scatter_i); + IndexT out_i = scatter_i * slice_size + slice_i; if (overwrite) { *(output + out_i) = *(params + i); @@ -67,6 +83,15 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, int64_t temp = slice_size; for (int64_t j = end_size - 1; j >= 0; --j) { IndexT index_value = indices[indices_i * end_size + j]; + + PADDLE_ENFORCE( + index_value >= 0 && index_value < output_dims[j], + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be less than [%d] and greater or equal to 0, but received [%d]", + output_dims[j], index_value); + gather_i += (index_value * temp); temp *= output_dims[j]; } diff --git a/paddle/fluid/operators/scatter.h b/paddle/fluid/operators/scatter.h index 864a94a4235..2589033d2fe 100644 --- a/paddle/fluid/operators/scatter.h +++ b/paddle/fluid/operators/scatter.h @@ -118,6 +118,15 @@ void ScatterAssign(const platform::DeviceContext& ctx, const Tensor& src, for (int i = 0; i < index_size; ++i) { IndexT index_ = p_index[i]; + + PADDLE_ENFORCE_GE(index_, 0, + platform::errors::OutOfRange( + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be greater than or equal to 0, but received [%d]", + index_)); + memcpy(p_output + index_ * slice_size, p_src + i * slice_size, slice_bytes); } } @@ -173,6 +182,15 @@ void ScatterAssignAdd(const framework::ExecutionContext& ctx, const Tensor& src, // if not in overwrite mode, need to init output data for (int i = 0; i < index_size; ++i) { const IndexT& index_ = p_index[i]; + + PADDLE_ENFORCE_GE(index_, 0, + platform::errors::OutOfRange( + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be greater than or equal to 0, but received [%d]", + index_)); + elementwise_inner_add(ctx, p_src, p_output, result_p_output, src, output, i, index_, slice_size, slice_bytes); @@ -233,6 +251,15 @@ void ScatterNdAdd(const framework::ExecutionContext& ctx, const Tensor& update, IndexT temp = 1; for (int64_t j = end_size - 1; j >= 0; --j) { IndexT index_value = p_index[i * end_size + j]; + PADDLE_ENFORCE_EQ( + (index_value >= 0 && index_value < output_dims[j]), true, + platform::errors::OutOfRange( + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be less than [%d] and greater or equal to 0, but received [%d]", + output_dims[j], index_value)); + index_ += (index_value * temp); temp *= output_dims[j]; } -- GitLab