BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 9e9b705a (unverified)
Authored by Vvsmile on Nov 29, 2022; committed via GitHub on Nov 29, 2022

Optimize the implementation of the argsort operator. (#47738)
Parent commit: de443726

Showing 1 changed file with 314 additions and 110 deletions (+314 −110)

paddle/phi/kernels/gpu/argsort_kernel.cu @ 9e9b705a
@@ -64,8 +64,10 @@ struct SegmentOffsetIter {
   int num_cols_;
 };
 
+#define PADDLE_CUDA_NUM_THREADS 1024
+
 template <typename T>
 static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
   int col_id = threadIdx.x;
   int row_id = blockIdx.x;
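The only functional addition in this hunk is the PADDLE_CUDA_NUM_THREADS constant, which the new kernels below use to size their 1-D launches. As a minimal sketch (not part of the patch; the helper name BlocksFor is hypothetical), the grid size is just the element count divided by the block size, rounded up:

#include <cstdint>

constexpr int64_t kCudaNumThreads = 1024;  // mirrors PADDLE_CUDA_NUM_THREADS

// Number of blocks so that blocks * kCudaNumThreads >= numel (numel > 0).
inline int64_t BlocksFor(int64_t numel) {
  return (numel - 1) / kCudaNumThreads + 1;  // ceiling division
}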
@@ -78,23 +80,246 @@ static __global__ void FillIndex(T* indices, T num_rows, T num_cols) {
 // Sort by flag descending, True: descending. False: Ascending.
 // Default is false.
 template <typename T, typename IndType>
-void ArgFullSort(const phi::GPUContext& ctx,
-                 const DenseTensor* input,
-                 DenseTensor* output,
-                 DenseTensor* indices,
-                 const IndType num_rows,
-                 const IndType num_cols,
+static __global__ void FillIndexAndSegmentKernel(int2* data,
+                                                 int numel,
+                                                 int nsort) {
+  CUDA_KERNEL_LOOP(idx, numel) {
+    auto segment = static_cast<int>(idx / nsort);
+    auto sort = static_cast<int>(idx % nsort);
+    data[idx] = int2{segment, sort};
+  }
+}
+
+#define CUB_WRAPPER(func, ctx, ...)                                            \
+  do {                                                                         \
+    size_t temp_storage_bytes = 0;                                             \
+    gpuError_t err;                                                            \
+    err = func(nullptr, temp_storage_bytes, __VA_ARGS__);                      \
+    PADDLE_ENFORCE_GPU_SUCCESS(err);                                           \
+    DenseTensor temp_storage;                                                  \
+    int64_t temp_size = temp_storage_bytes;                                    \
+    temp_storage.Resize({temp_size});                                          \
+    ctx.template Alloc<uint8_t>(&temp_storage);                                \
+    err = func(temp_storage.data<uint8_t>(), temp_storage_bytes, __VA_ARGS__); \
+    PADDLE_ENFORCE_GPU_SUCCESS(err);                                           \
+  } while (false)
+
+template <typename KT, typename VT>
+static void RadixSortPairs(const phi::GPUContext& ctx,
+                           const KT* keys_in,
+                           const VT* values_in,
+                           KT* keys_out,
+                           VT* values_out,
+                           int64_t n,
+                           bool descending = false,
+                           int64_t begin_bit = 0,
+                           int64_t end_bit = sizeof(KT) * 8) {
+  if (keys_out == nullptr) {
+    DenseTensor key_out_owner;
+    key_out_owner.Resize({n});
+    ctx.template Alloc<KT>(&key_out_owner);
+    keys_out = key_out_owner.data<KT>();
+  }
+  if (descending) {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
+                ctx,
+                keys_in,
+                keys_out,
+                values_in,
+                values_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  } else {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
+                ctx,
+                keys_in,
+                keys_out,
+                values_in,
+                values_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  }
+}
+
+template <typename KT>
+static void RadixSortKeys(const phi::GPUContext& ctx,
+                          const KT* keys_in,
+                          KT* keys_out,
+                          int64_t n,
+                          bool descending,
+                          int64_t begin_bit,
+                          int64_t end_bit) {
+  if (descending) {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortKeysDescending,
+                ctx,
+                keys_in,
+                keys_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  } else {
+    CUB_WRAPPER(cub::DeviceRadixSort::SortKeys,
+                ctx,
+                keys_in,
+                keys_out,
+                n,
+                begin_bit,
+                end_bit,
+                ctx.stream());
+  }
+}
+
+template <typename T>
+static __global__ void SortPostprocessKernel(const T* in,
+                                             const int2* i_s_ptr,
+                                             T* out,
+                                             int64_t* index,
+                                             int nsegments,
+                                             int nsort) {
+  CUDA_KERNEL_LOOP(i, nsegments * nsort) {
+    int segment = i / nsort;  // segment_id
+    int j = i % nsort;
+
+    int offset = segment * nsort;
+    const T* in_ = in + offset;
+    T* out_ = out + offset;
+    int64_t* index_ = index + offset;
+    const int2* i_s_ptr_ = i_s_ptr + offset;
+
+    int idx = i_s_ptr_[j].y;
+    index_[j] = idx;
+    out_[j] = in_[idx];
+  }
+}
+
+template <typename T>
+inline void SegmentedSortPairsByFullSort(const phi::GPUContext& ctx,
+                                         const T* const self_ptr,
+                                         T* const values_ptr,
+                                         int64_t* const indices_ptr,
+                                         const int64_t nsegments,
+                                         const int64_t nsort,
+                                         const int64_t n,
+                                         const bool descending) {
+  int64_t segment_bits = std::max<int64_t>(
+      1L, static_cast<int64_t>(std::ceil(std::log2(nsegments))));
+
+  const auto numel = nsort * nsegments;
+
+  DenseTensor indices_and_segment;
+  int64_t indices_and_segment_size = numel;
+  indices_and_segment.Resize({indices_and_segment_size * 2});
+  ctx.template Alloc<int64_t>(&indices_and_segment);
+  auto i_s_ptr_base = indices_and_segment.data<int64_t>();
+  auto i_s_ptr = reinterpret_cast<int2*>(i_s_ptr_base);
+
+  dim3 block = PADDLE_CUDA_NUM_THREADS;
+  auto block_num = (numel - 1) / PADDLE_CUDA_NUM_THREADS + 1;
+  dim3 grid = static_cast<int>(block_num);
+  auto cu_stream = ctx.stream();
+
+  FillIndexAndSegmentKernel<<<grid, block, 0, cu_stream>>>(
+      i_s_ptr, numel, nsort);
+
+  DenseTensor indices_and_segment2;
+  int64_t indices_and_segment2_size = numel;
+  indices_and_segment2.Resize({indices_and_segment2_size * 2});
+  ctx.template Alloc<int64_t>(&indices_and_segment2);
+  auto i_s_ptr2_base = indices_and_segment2.data<int64_t>();
+  auto i_s_ptr2 = reinterpret_cast<int2*>(i_s_ptr2_base);
+
+  RadixSortPairs<T, int2>(
+      ctx, self_ptr, i_s_ptr, nullptr, i_s_ptr2, n, descending);
+
+  RadixSortKeys<int64_t>(ctx,
+                         reinterpret_cast<int64_t*>(i_s_ptr2),
+                         reinterpret_cast<int64_t*>(i_s_ptr),
+                         n,
+                         false,
+                         0,
+                         segment_bits);
+
+  SortPostprocessKernel<<<grid, block, 0, cu_stream>>>(
+      self_ptr, i_s_ptr, values_ptr, indices_ptr, nsegments, nsort);
+}
+
+// The method is called when # of the rows of the input is less than or equal to
+// 4
+template <typename T, typename IndexType>
+void ArgFullSortForTinyRows(const phi::GPUContext& ctx,
+                            const DenseTensor* input,
+                            DenseTensor* output,
+                            DenseTensor* indices,
+                            const IndexType num_rows,
+                            const IndexType num_cols,
+                            const bool descending) {
+  auto gpu_stream = ctx.stream();
+  size_t temp_storage_bytes = -1;
+
+  IndexType numel = num_rows * num_cols;
+  if (numel == 0) {
+    return;
+  }
+
+  IndexType numel_or_intmax =
+      std::min(numel, static_cast<int64_t>(std::numeric_limits<int>::max()));
+  IndexType nsort = num_cols;
+  IndexType nbatch = (numel_or_intmax / nsort) * nsort;
+
+  T* sorted_out_ptr;
+  IndexType* sorted_indices_ptr;
+  const T* input_data = input->data<T>();
+  T* out = ctx.template Alloc<T>(output);
+  IndexType* ind = ctx.template Alloc<IndexType>(indices);
+  sorted_out_ptr = out;
+  sorted_indices_ptr = ind;
+
+  int64_t remaining = numel;
+
+  while (remaining > 0) {
+    int64_t n = std::min(remaining, nbatch);
+    IndexType nsegments = n / nsort;
+
+    SegmentedSortPairsByFullSort(ctx,
+                                 input_data,
+                                 sorted_out_ptr,
+                                 sorted_indices_ptr,
+                                 nsegments,
+                                 nsort,
+                                 n,
+                                 descending);
+
+    remaining -= n;
+    input_data += n;
+    sorted_out_ptr += n;
+    sorted_indices_ptr += n;
+  }
+}
+
+template <typename T, typename IndexType>
+void ArgFullSort(const phi::GPUContext& ctx,
+                 const DenseTensor* input,
+                 DenseTensor* output,
+                 DenseTensor* indices,
+                 const IndexType num_rows,
+                 const IndexType num_cols,
                  const bool descending) {
   auto cu_stream = ctx.stream();
   DenseTensor input_indices;
-  const std::vector<IndType> dims = {num_rows, num_cols};
+  const std::vector<IndexType> dims = {num_rows, num_cols};
   auto dim = phi::make_ddim(dims);
   input_indices.Resize(dim);
-  ctx.template Alloc<IndType>(&input_indices);
+  ctx.template Alloc<IndexType>(&input_indices);
   size_t temp_storage_bytes = -1;
-  auto ComputeBlockSize = [](IndType col) {
+  auto ComputeBlockSize = [](IndexType col) {
     if (col > 512)
       return 1024;
     else if (col > 256 && col <= 512)
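The heart of the patch is SegmentedSortPairsByFullSort: instead of running a segmented sort over very few segments, it tags every element with an int2{segment, position} pair, radix-sorts all tags by element value across the whole batch, then stable-sorts the tags again on the segment id (the low segment_bits of the int2 viewed as a 64-bit key). Because both passes are stable, after the second pass each segment's positions are contiguous and already in value order, and SortPostprocessKernel scatters them into the output and index tensors. The following host-side sketch is an illustration only (it uses std::stable_sort in place of cub::DeviceRadixSort) of why the two stable passes yield a per-row argsort:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  const int nsegments = 2, nsort = 3;
  std::vector<float> data = {3.f, 1.f, 2.f,    // row 0
                             0.5f, 9.f, 4.f};  // row 1
  struct Tag { int segment, pos; };
  std::vector<Tag> tags(nsegments * nsort);
  for (int i = 0; i < nsegments * nsort; ++i)
    tags[i] = {i / nsort, i % nsort};  // what FillIndexAndSegmentKernel does

  // Pass 1: stable sort of all tags by element value (RadixSortPairs).
  std::stable_sort(tags.begin(), tags.end(), [&](Tag a, Tag b) {
    return data[a.segment * nsort + a.pos] < data[b.segment * nsort + b.pos];
  });
  // Pass 2: stable sort by segment id (RadixSortKeys on segment_bits);
  // the value order survives inside each segment.
  std::stable_sort(tags.begin(), tags.end(),
                   [](Tag a, Tag b) { return a.segment < b.segment; });

  for (int i = 0; i < nsegments * nsort; ++i)  // prints: 1 2 0 0 2 1
    std::printf("%d ", tags[i].pos);
  std::printf("\n");
  return 0;
}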
@@ -113,111 +338,70 @@ void ArgFullSort(const phi::GPUContext& ctx,
   int grid_size = num_rows < maxGridDimX ? num_rows : maxGridDimX;
   // Init a index array
   FillIndex<<<grid_size, block_size, 0, cu_stream>>>(
-      input_indices.data<IndType>(), num_rows, num_cols);
+      input_indices.data<IndexType>(), num_rows, num_cols);
 
   T* sorted_out_ptr;
-  IndType* sorted_indices_ptr;
+  IndexType* sorted_indices_ptr;
   const T* inp = input->data<T>();
   T* out = ctx.template Alloc<T>(output);
-  IndType* ind = ctx.template Alloc<IndType>(indices);
+  IndexType* ind = ctx.template Alloc<IndexType>(indices);
   sorted_out_ptr = out;
   sorted_indices_ptr = ind;
 
   // create iter for counting input
-  cub::CountingInputIterator<IndType> counting_iter(0);
+  cub::CountingInputIterator<IndexType> counting_iter(0);
   // segment_offset is used for move to next row
-  cub::TransformInputIterator<IndType,
+  cub::TransformInputIterator<IndexType,
                               SegmentOffsetIter,
-                              cub::CountingInputIterator<IndType>>
+                              cub::CountingInputIterator<IndexType>>
       segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
 
   gpuError_t err;
   if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
+    CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairsDescending,
+                ctx,
+                inp,
+                sorted_out_ptr,
+                input_indices.data<IndexType>(),
+                sorted_indices_ptr,
+                num_cols * num_rows,
+                num_rows,
+                segment_offsets_t,
+                segment_offsets_t + 1,
+                0,
+                sizeof(T) * 8,
+                ctx.stream());
   } else {
-    err = cub::DeviceSegmentedRadixSort::SortPairs(
-        nullptr,
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
+    CUB_WRAPPER(cub::DeviceSegmentedRadixSort::SortPairs,
+                ctx,
+                inp,
+                sorted_out_ptr,
+                input_indices.data<IndexType>(),
+                sorted_indices_ptr,
+                num_cols * num_rows,
+                num_rows,
+                segment_offsets_t,
+                segment_offsets_t + 1,
+                0,
+                sizeof(T) * 8,
+                ctx.stream());
   }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
-
-  DenseTensor temp_storage;
-  int64_t temp_size = temp_storage_bytes;
-  temp_storage.Resize({temp_size});
-  ctx.template Alloc<uint8_t>(&temp_storage);
-
-  if (descending) {
-    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  } else {
-    err = cub::DeviceSegmentedRadixSort::SortPairs(
-        temp_storage.data<uint8_t>(),
-        temp_storage_bytes,
-        inp,
-        sorted_out_ptr,
-        input_indices.data<IndType>(),
-        sorted_indices_ptr,
-        num_cols * num_rows,
-        num_rows,
-        segment_offsets_t,
-        segment_offsets_t + 1,
-        0,
-        sizeof(T) * 8,
-        cu_stream);
-  }
-  PADDLE_ENFORCE_GPU_SUCCESS(err);
 }
 
 template <typename T, typename Context>
 void ArgsortKernel(const Context& dev_ctx,
                    const DenseTensor& input,
                    int axis,
                    bool descending,
                    DenseTensor* output,
                    DenseTensor* indices) {
   auto in_dims = input.dims();
   axis = (axis < 0) ? (in_dims.size() + axis) : axis;
   const T* in_data = input.data<T>();
   auto size = input.numel();
   T* out_data = dev_ctx.template Alloc<T>(output);
   int64_t* ids_data = dev_ctx.template Alloc<int64_t>(indices);
   // Use thrust for parallel acceleration when the input size is equal to the
   // length of the ‘axis’ dimension.
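The removed half of this hunk is the two-phase CUB calling convention written out by hand: call the sort once with a null workspace pointer to learn temp_storage_bytes, allocate the workspace, then call it again to perform the sort. CUB_WRAPPER folds both phases, the error checks, and the DenseTensor-backed workspace into one statement. A standalone sketch of the same pattern, under the assumption of raw cudaMalloc rather than Paddle's allocator, looks like this:

#include <cub/cub.cuh>

// Two-phase CUB call: pass 1 (null workspace) only reports the required
// temp-storage size, pass 2 performs the sort with the allocated workspace.
template <typename T>
cudaError_t SortKeysTwoPhase(const T* d_in, T* d_out, int n,
                             cudaStream_t stream) {
  size_t temp_bytes = 0;
  cudaError_t err = cub::DeviceRadixSort::SortKeys(
      nullptr, temp_bytes, d_in, d_out, n, 0, sizeof(T) * 8, stream);
  if (err != cudaSuccess) return err;
  void* d_temp = nullptr;
  err = cudaMalloc(&d_temp, temp_bytes);
  if (err != cudaSuccess) return err;
  err = cub::DeviceRadixSort::SortKeys(
      d_temp, temp_bytes, d_in, d_out, n, 0, sizeof(T) * 8, stream);
  cudaFree(d_temp);
  return err;
}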
@@ -239,13 +423,23 @@ void ArgsortKernel(const Context& dev_ctx,
     const int64_t input_height =
         phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
     const int64_t input_width = in_dims[in_dims.size() - 1];
-    ArgFullSort<T, int64_t>(dev_ctx,
-                            &input,
-                            output,
-                            indices,
-                            input_height,
-                            input_width,
-                            descending);
+    if (input_height <= 4) {
+      ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
+                                         &input,
+                                         output,
+                                         indices,
+                                         input_height,
+                                         input_width,
+                                         descending);
+    } else {
+      ArgFullSort<T, int64_t>(dev_ctx,
+                              &input,
+                              output,
+                              indices,
+                              input_height,
+                              input_width,
+                              descending);
+    }
   } else {
     // if not full sort, do transpose first
     std::vector<int> trans;
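This hunk wires the new path into the full-sort branch of ArgsortKernel: when the flattened input has at most 4 rows, DeviceSegmentedRadixSort would presumably be handed too few segments to keep the GPU busy, so those shapes are routed to ArgFullSortForTinyRows, which reduces the problem to device-wide sorts. A trivial sketch of the dispatch predicate (the helper name is hypothetical; the threshold value 4 is taken from this diff):

// True when the tiny-rows path should handle the sort.
inline bool UseTinyRowsPath(int64_t input_height) {
  constexpr int64_t kTinyRowThreshold = 4;  // threshold used in this commit
  return input_height <= kTinyRowThreshold;
}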
@@ -264,7 +458,7 @@ void ArgsortKernel(const Context& dev_ctx,
     DenseTensor trans_inp;
     trans_inp.Resize(trans_dims);
     T* trans_inp_data = dev_ctx.template Alloc<T>(&trans_inp);
     // Do transpose
     TransposeKernel<T, Context>(dev_ctx, input, trans, &trans_inp);
@@ -282,13 +476,23 @@ void ArgsortKernel(const Context& dev_ctx,
     dev_ctx.template Alloc<int64_t>(&tmp_indices);
     dev_ctx.template Alloc<int64_t>(indices);
-    ArgFullSort<T, int64_t>(dev_ctx,
-                            &trans_inp,
-                            &tmp_out,
-                            &tmp_indices,
-                            input_height,
-                            input_width,
-                            descending);
+    if (input_height <= 4) {
+      ArgFullSortForTinyRows<T, int64_t>(dev_ctx,
+                                         &trans_inp,
+                                         &tmp_out,
+                                         &tmp_indices,
+                                         input_height,
+                                         input_width,
+                                         descending);
+    } else {
+      ArgFullSort<T, int64_t>(dev_ctx,
+                              &trans_inp,
+                              &tmp_out,
+                              &tmp_indices,
+                              input_height,
+                              input_width,
+                              descending);
+    }
     TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
     // transpose back
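The same dispatch is repeated in the transpose branch, which handles sorting along an axis other than the last one: the input is permuted so the sort axis becomes the innermost dimension, sorted with the same row-wise routines, and the results are permuted back. A small sketch (illustration only; the helper name is hypothetical) of such a permutation, which swaps the sort axis with the last dimension and is therefore its own inverse:

#include <utility>
#include <vector>

// Permutation that moves `axis` to the innermost position (and back).
std::vector<int> BuildArgsortPerm(int rank, int axis) {
  std::vector<int> trans(rank);
  for (int i = 0; i < rank; ++i) trans[i] = i;  // identity
  std::swap(trans[axis], trans[rank - 1]);      // swap axis with last dim
  return trans;
}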