Unverified commit 380bc4e6, authored by Haohongxiang, committed by GitHub

Fix gather_op to avoid cudaErrorLaunchFailure for solov2, test=develop (#34200)

Parent 661f4094
@@ -30,20 +30,13 @@ using platform::DeviceContext;
 template <typename T, typename IndexT = int>
 __global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
-                                 T* output, size_t input_size,
-                                 size_t index_size, size_t slice_size) {
+                                 T* output, size_t index_size,
+                                 size_t slice_size) {
   CUDA_KERNEL_LOOP(i, index_size * slice_size) {
     int indices_i = i / slice_size;
     int slice_i = i - indices_i * slice_size;  // offset inside the slice
     IndexT gather_i = indices[indices_i];
     IndexT params_i = gather_i * slice_size + slice_i;
-    PADDLE_ENFORCE(
-        gather_i >= 0 && gather_i < input_size,
-        "The index is out of bounds, "
-        "please check whether the dimensions of index and "
-        "input meet the requirements. It should "
-        "be less than [%d] and greater than or equal to 0, but received [%d]",
-        input_size, gather_i);
     *(output + i) = *(params + params_i);
   }
 }

@@ -108,8 +101,6 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   // slice size
   int slice_size = 1;
   for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
-  // input size
-  int input_size = src_dims[0] * slice_size;

   const T* p_src = src.data<T>();
   const IndexT* p_index = index.data<IndexT>();
@@ -122,7 +113,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
   GatherCUDAKernel<T, IndexT><<<
       grid, block, 0,
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
-      p_src, p_index, p_output, input_size, index_size, slice_size);
+      p_src, p_index, p_output, index_size, slice_size);
 }

 template <typename DeviceContext, typename T, typename IndexT = int>
...
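For context, the sketch below shows the gather indexing scheme the kernel uses after this change as a minimal standalone CUDA program. It is not PaddlePaddle code: the grid-stride loop stands in for Paddle's CUDA_KERNEL_LOOP macro, and the kernel name GatherKernelSketch, the test tensors, and the launch configuration are assumptions made purely for illustration. With the device-side PADDLE_ENFORCE removed by this commit, the kernel performs no bounds check, so the sketch assumes the indices have already been validated before launch.

// Minimal standalone sketch of the gather indexing used by GatherCUDAKernel
// after this change. Each of the index_size * slice_size threads copies one
// element from the selected source row into the output. Illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

template <typename T, typename IndexT = int>
__global__ void GatherKernelSketch(const T* params, const IndexT* indices,
                                   T* output, size_t index_size,
                                   size_t slice_size) {
  // Grid-stride loop over all output elements (stand-in for CUDA_KERNEL_LOOP).
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x;
       i < index_size * slice_size; i += blockDim.x * gridDim.x) {
    size_t indices_i = i / slice_size;            // which gathered row
    size_t slice_i = i - indices_i * slice_size;  // offset inside the slice
    IndexT gather_i = indices[indices_i];         // source row to copy from
    size_t params_i = gather_i * slice_size + slice_i;
    output[i] = params[params_i];  // no bounds check; indices assumed valid
  }
}

int main() {
  // Gather rows {2, 0} from a 3 x 4 source tensor.
  const size_t slice_size = 4, index_size = 2, src_rows = 3;
  float h_src[src_rows * slice_size];
  for (size_t i = 0; i < src_rows * slice_size; ++i) h_src[i] = float(i);
  int h_idx[index_size] = {2, 0};
  float h_out[index_size * slice_size];

  float *d_src, *d_out;
  int* d_idx;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_idx, sizeof(h_idx));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemcpy(d_idx, h_idx, sizeof(h_idx), cudaMemcpyHostToDevice);

  const int block = 256;
  const int grid = (index_size * slice_size + block - 1) / block;
  GatherKernelSketch<float, int><<<grid, block>>>(d_src, d_idx, d_out,
                                                  index_size, slice_size);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (size_t r = 0; r < index_size; ++r) {
    for (size_t c = 0; c < slice_size; ++c)
      printf("%.0f ", h_out[r * slice_size + c]);
    printf("\n");
  }
  cudaFree(d_src);
  cudaFree(d_idx);
  cudaFree(d_out);
  return 0;
}

Compiled with nvcc and run, the sketch prints rows 2 and 0 of the source tensor ("8 9 10 11" and "0 1 2 3").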