未验证 提交 1c67cf0c 编写于 作者: Z zlsh80826 提交者: GitHub

run radix sort of proposals layer on context stream (#31631)

上级 e429deb0
......@@ -66,7 +66,8 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num);
nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num, 0,
sizeof(T) * 8, ctx.stream());
// Allocate temporary storage
auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
......@@ -74,7 +75,7 @@ static void SortDescending(const platform::CUDADeviceContext &ctx,
// Run sorting operation
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in,
idx_out, num);
idx_out, num, 0, sizeof(T) * 8, ctx.stream());
}
template <typename T>
......
......@@ -144,7 +144,7 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairsDescending<T, int>(
nullptr, temp_storage_bytes, concat_scores.data<T>(), keys_out, idx_in,
idx_out, total_roi_num);
idx_out, total_roi_num, 0, sizeof(T) * 8, dev_ctx.stream());
// Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
......@@ -152,7 +152,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort score to get corresponding index
cub::DeviceRadixSort::SortPairsDescending<T, int>(
d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data<T>(),
keys_out, idx_in, idx_out, total_roi_num);
keys_out, idx_in, idx_out, total_roi_num, 0, sizeof(T) * 8,
dev_ctx.stream());
index_out_t.Resize({real_post_num});
Tensor sorted_rois;
sorted_rois.mutable_data<T>({real_post_num, kBBoxSize}, dev_ctx.GetPlace());
......@@ -176,7 +177,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(
nullptr, temp_storage_bytes, sorted_batch_id.data<int>(), out_id_data,
batch_idx_in, index_out_t.data<int>(), real_post_num);
batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
sizeof(int) * 8, dev_ctx.stream());
// Allocate temporary storage
d_temp_storage = memory::Alloc(place, temp_storage_bytes);
......@@ -184,7 +186,8 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort batch_id to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data<int>(),
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num);
out_id_data, batch_idx_in, index_out_t.data<int>(), real_post_num, 0,
sizeof(int) * 8, dev_ctx.stream());
GPUGather<T>(dev_ctx, sorted_rois, index_out_t, fpn_rois);
......
......@@ -149,9 +149,9 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// Determine temporary device storage requirements
size_t temp_storage_bytes = 0;
cub::DeviceRadixSort::SortPairs<int, int>(nullptr, temp_storage_bytes,
target_lvls_data, keys_out,
idx_in, idx_out, roi_num);
cub::DeviceRadixSort::SortPairs<int, int>(
nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in,
idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
// Allocate temporary storage
auto d_temp_storage = memory::Alloc(place, temp_storage_bytes);
......@@ -159,14 +159,14 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel<T> {
// sort target level to get corresponding index
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out,
idx_in, idx_out, roi_num);
idx_in, idx_out, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
int* restore_idx_data =
restore_index->mutable_data<int>({roi_num, 1}, dev_ctx.GetPlace());
// sort current index to get restore index
cub::DeviceRadixSort::SortPairs<int, int>(
d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in,
restore_idx_data, roi_num);
restore_idx_data, roi_num, 0, sizeof(int) * 8, dev_ctx.stream());
int start = 0;
auto multi_rois_num = ctx.MultiOutput<Tensor>("MultiLevelRoIsNum");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册