未验证 提交 380bc4e6 编写于 作者: H Haohongxiang 提交者: GitHub

Fix gather_op to avoid cudaErrorLaunchFailure for solov2, test=develop (#34200)

上级 661f4094
......@@ -30,20 +30,13 @@ using platform::DeviceContext;
// Gather kernel: for each flat position i produced by the loop macro, copies
//   params[indices[i / slice_size] * slice_size + (i - (i / slice_size) * slice_size)]
// into output[i].
//
// NOTE(review): this span is a rendered diff whose +/- markers were lost, so
// two alternative parameter lists and a PADDLE_ENFORCE block are interleaved
// below. Presumably the list carrying input_size together with the
// PADDLE_ENFORCE block is one side of the change, and the list without
// input_size is the other side -- confirm against the repository history
// before treating this text as compilable code.
template <typename T, typename IndexT = int>
__global__ void GatherCUDAKernel(const T* params, const IndexT* indices,
// Parameter-list variant that carries input_size; within this kernel
// input_size is referenced only by the PADDLE_ENFORCE check further down.
T* output, size_t input_size,
size_t index_size, size_t slice_size) {
// Parameter-list variant without input_size.
T* output, size_t index_size,
size_t slice_size) {
// CUDA_KERNEL_LOOP is a macro defined outside this view; presumably it
// distributes i over [0, index_size * slice_size) across threads -- confirm.
CUDA_KERNEL_LOOP(i, index_size * slice_size) {
// Which entry of `indices` this output element belongs to.
// NOTE(review): slice_size is size_t but the quotient is stored in an int;
// confirm that output sizes cannot exceed the int range.
int indices_i = i / slice_size;
int slice_i = i - indices_i * slice_size; // offset inside the slice
// Index value read from `indices`, used as a multiplier of slice_size.
IndexT gather_i = indices[indices_i];
// Flat offset into params of the element to copy.
IndexT params_i = gather_i * slice_size + slice_i;
// Device-side validation of gather_i against [0, input_size).
// NOTE(review): the message formats input_size (size_t) and gather_i
// (IndexT) with %d -- check the specifiers if this check is kept.
PADDLE_ENFORCE(
gather_i >= 0 && gather_i < input_size,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d] and greater than or equal to 0, but received [%d]",
input_size, gather_i);
// NOTE(review): in the variant without the PADDLE_ENFORCE check nothing in
// this kernel validates gather_i before this read from params.
*(output + i) = *(params + params_i);
}
}
......@@ -108,8 +101,6 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
// slice size
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
// input size
int input_size = src_dims[0] * slice_size;
const T* p_src = src.data<T>();
const IndexT* p_index = index.data<IndexT>();
......@@ -122,7 +113,7 @@ void GPUGather(const platform::DeviceContext& ctx, const Tensor& src,
GatherCUDAKernel<T, IndexT><<<
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
p_src, p_index, p_output, input_size, index_size, slice_size);
p_src, p_index, p_output, index_size, slice_size);
}
template <typename DeviceContext, typename T, typename IndexT = int>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册