未验证 提交 64b9065d 编写于 作者: H Haohongxiang 提交者: GitHub

Fix gather_op by adding OutOfRangeCheck for param[Index], test=develop (#34096)

* Fix gather_op by adding OutOfRangeCheck for param[Index]

* Code Optimization
上级 d8343f45
...@@ -30,13 +30,20 @@ using platform::DeviceContext; ...@@ -30,13 +30,20 @@ using platform::DeviceContext;
/**
 * Copies whole slices of `params` into `output`, one slice per entry of
 * `indices`.
 *
 * Output element `i` is read from slice `indices[i / slice_size]` of `params`
 * at offset `i - (i / slice_size) * slice_size` inside that slice.
 *
 * @param params      Source buffer, addressed as consecutive slices of
 *                    `slice_size` elements.
 * @param indices     Slice numbers to gather; `index_size` entries are read.
 * @param output      Destination buffer; `index_size * slice_size` elements
 *                    are written.
 * @param input_size  Total number of elements in `params`. The launch site
 *                    visible in this file passes `src_dims[0] * slice_size`,
 *                    i.e. an element count rather than a slice count.
 * @param index_size  Number of entries in `indices`.
 * @param slice_size  Number of elements in one slice.
 *
 * Every index is validated with PADDLE_ENFORCE before it is used to address
 * `params`.
 */
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
                                 T* output, size_t input_size,
                                 size_t index_size, size_t slice_size) {
  CUDA_KERNEL_LOOP(i, index_size * slice_size) {
    int indices_i = i / slice_size;
    int slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT gather_i = indices[indices_i];
    // `input_size` counts elements, so comparing the slice index against it
    // would accept indices up to slice_size times too large and let the read
    // below run past the end of `params`. Compare against the number of
    // slices instead. slice_size cannot be zero here: the loop bound
    // index_size * slice_size would be zero and the body would never run.
    const size_t slice_count = input_size / slice_size;
    PADDLE_ENFORCE(
        gather_i >= 0 && static_cast<size_t>(gather_i) < slice_count,
        "The index is out of bounds, "
        "please check whether the dimensions of index and "
        "input meet the requirements. It should "
        "be less than [%d] and greater than or equal to 0, but received [%d]",
        slice_count, gather_i);
    // Only form the element offset once the index is known to be valid.
    IndexT params_i = gather_i * slice_size + slice_i;
    *(output + i) = *(params + params_i);
  }
}
...@@ -58,7 +65,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int* input_dims, ...@@ -58,7 +65,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int* input_dims,
"The index is out of bounds, " "The index is out of bounds, "
"please check whether the dimensions of index and " "please check whether the dimensions of index and "
"input meet the requirements. It should " "input meet the requirements. It should "
"be less than [%d] and greater or equal to 0, but received [%d]", "be less than [%d] and greater than or equal to 0, but received [%d]",
input_dims[j], index_value); input_dims[j], index_value);
gather_i += (index_value * temp); gather_i += (index_value * temp);
temp *= input_dims[j]; temp *= input_dims[j];
...@@ -91,6 +98,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -91,6 +98,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
" the second dimension should be 1.")); " the second dimension should be 1."));
} }
// index size
int index_size = index.dims()[0]; int index_size = index.dims()[0];
auto src_dims = src.dims(); auto src_dims = src.dims();
...@@ -100,6 +108,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -100,6 +108,8 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// slice size // slice size
int slice_size = 1; int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// input size
int input_size = src_dims[0] * slice_size;
const T* p_src = src.data<T>(); const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>(); const IndexT* p_index = index.data<IndexT>();
...@@ -112,7 +122,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -112,7 +122,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
GatherCUDAKernel<T, IndexT><<< GatherCUDAKernel<T, IndexT><<<
grid, block, 0, grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>( reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, index_size, slice_size); p_src, p_index, p_output, input_size, index_size, slice_size);
} }
template <typename DeviceContext, typename T, typename IndexT = int> template <typename DeviceContext, typename T, typename IndexT = int>
...@@ -177,6 +187,15 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out, ...@@ -177,6 +187,15 @@ __global__ void GatherGPUKernel(const T* input, const U* index, T* out,
int next_idx = idx - outer_size * inner_dim_index; int next_idx = idx - outer_size * inner_dim_index;
int index_dim_index = next_idx / outer_dim_size; int index_dim_index = next_idx / outer_dim_size;
int index_val = index[index_dim_index]; int index_val = index[index_dim_index];
PADDLE_ENFORCE(
index_val >= 0 && index_val < input_index_dim_size,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
input_index_dim_size, index_val);
int out_dim_index = next_idx - outer_dim_size * index_dim_index; int out_dim_index = next_idx - outer_dim_size * index_dim_index;
int input_index = int input_index =
inner_dim_index * (outer_dim_size * input_index_dim_size) + inner_dim_index * (outer_dim_size * input_index_dim_size) +
......
...@@ -67,11 +67,25 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src, ...@@ -67,11 +67,25 @@ void CPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// slice size // slice size
int slice_size = 1; int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// input size
int input_size = src_dims[0] * slice_size;
const size_t slice_bytes = slice_size * sizeof(T); const size_t slice_bytes = slice_size * sizeof(T);
for (int64_t i = 0; i < index_size; ++i) { for (int64_t i = 0; i < index_size; ++i) {
IndexT index_ = p_index[i]; IndexT index_ = p_index[i];
PADDLE_ENFORCE_LT(p_index[i], input_size,
platform::errors::OutOfRange(
"The element of Index must be less than the size of "
"input dim size of axis which is %d, but received "
"index element which is %d in the %d index.",
input_size, p_index[i], i));
PADDLE_ENFORCE_GE(p_index[i], 0UL,
platform::errors::OutOfRange(
"The element of Index must be greater than or equal "
"to 0, but received index element which is %d in the "
"%d index.",
p_index[i], i));
memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes); memcpy(p_output + i * slice_size, p_src + index_ * slice_size, slice_bytes);
} }
} }
...@@ -141,11 +155,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis, ...@@ -141,11 +155,17 @@ void GatherV2Function(const Tensor* input, const Tensor* index, int axis,
int input_index_dim_size = input_dim[axis_index]; int input_index_dim_size = input_dim[axis_index];
for (int i = 0; i < index_size; i++) { for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size, PADDLE_ENFORCE_LT(index_data[i], input_index_dim_size,
platform::errors::InvalidArgument( platform::errors::OutOfRange(
"The element of Index must be less than the size of " "The element of Index must be less than the size of "
"input dim size of axis which is %d, but received " "input dim size of axis which is %d, but received "
"index element which is %d in the %d index.", "index element which is %d in the %d index.",
input_index_dim_size, index_data[i], i)); input_index_dim_size, index_data[i], i));
PADDLE_ENFORCE_GE(index_data[i], 0UL,
platform::errors::OutOfRange(
"The element of Index must be greater than or equal "
"to 0, but received index element which is %d in the "
"%d index.",
index_data[i], i));
} }
int inner_dim_size = 1; int inner_dim_size = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册