BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit f4e74887 (unverified)
Authored on Feb 22, 2022 by niuliling123; committed via GitHub on Feb 22, 2022.
Add Sort API for Kernel Primitive API (#39734)
* Add Sort API for Kernel Primitive API
* update & -> ptr
Parent: de760d2c
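This commit adds two block-level sorting primitives to the Kernel Primitive API in phi::kps: Sort, and an overload that carries an index array alongside the values. Both are built on a bitonic network held in shared memory, with each thread contributing two elements from registers. Below is a minimal caller sketch, not part of the commit itself: the kernel name, launch configuration, and global-memory indexing are assumptions, and it presumes Paddle's include paths, blockDim.x == 512, and num == 1024 (a power of two matching SHARED_SIZE_LIMIT in the diff below).

#include "paddle/phi/kernels/primitive/compute_primitives.h"

// Hypothetical caller (illustration only): each block sorts num == 1024
// contiguous elements of `in`, with blockDim.x == 512 threads each holding
// two elements in registers.
template <typename T>
__global__ void BlockSortKernel(T* out, const T* in, int num) {
  const T* block_in = in + blockIdx.x * num;
  T* block_out = out + blockIdx.x * num;

  // Two register values per thread, matching Sort's src_data[0] / src_data[1].
  T src[2];
  src[0] = block_in[threadIdx.x];
  src[1] = block_in[threadIdx.x + num / 2];

  T dst[2];
  // monotonic_type == 1 sorts the block's elements in increasing order.
  phi::kps::Sort<T>(dst, src, num, 1);

  // Sort returns dst[0] = value[threadIdx.x] and
  // dst[1] = value[threadIdx.x + SHARED_SIZE_LIMIT / 2], so this writes the
  // fully sorted block back contiguously.
  block_out[threadIdx.x] = dst[0];
  block_out[threadIdx.x + num / 2] = dst[1];
}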
Showing 1 changed file with 123 additions and 0 deletions.

paddle/phi/kernels/primitive/compute_primitives.h (+123, -0)
@@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
  return shared_memory[threadIdx.x];
}

// Swap data
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
  T t_value;
  t_value = (*first_value);
  (*first_value) = (*second_value);
  (*second_value) = t_value;
}

// swap with monotonic_type
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
                                           T* second_value,
                                           int monotonic_type) {
  if (((*first_value) > (*second_value)) == monotonic_type) {
    Swap<T>(first_value, second_value);
  }
}

template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,
                                                    T* second_value,
                                                    IndexType* first_index,
                                                    IndexType* second_index,
                                                    int monotonic_type) {
  if ((*first_value > (*second_value)) == monotonic_type) {
    // swap value
    Swap<T>(first_value, second_value);
    // swap index
    Swap<IndexType>(first_index, second_index);
  }
}

}  // namespace details
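These helpers are the building blocks of the bitonic network added further down: Swap exchanges two values through pointers, while Comparator and ComparatorWithIndex perform a direction-aware compare-and-swap, where monotonic_type selects whether the pair ends up increasing (1) or decreasing (0). A small host-side C++ sketch of the same compare-and-swap semantics (the names and driver here are illustrative, not part of the commit):

#include <cstdio>
#include <utility>

// Host-side analogue of details::Comparator (illustration only): swap when
// the comparison (a > b) matches monotonic_type, so 1 leaves the pair in
// increasing order and 0 leaves it in decreasing order.
template <typename T>
void ComparatorHost(T* first_value, T* second_value, int monotonic_type) {
  if (((*first_value) > (*second_value)) == monotonic_type) {
    std::swap(*first_value, *second_value);
  }
}

int main() {
  int x = 7, y = 3;
  ComparatorHost(&x, &y, 1);  // increasing: x == 3, y == 7
  std::printf("increasing: %d %d\n", x, y);

  x = 3;
  y = 7;
  ComparatorHost(&x, &y, 0);  // decreasing: x == 7, y == 3
  std::printf("decreasing: %d %d\n", x, y);
  return 0;
}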
@@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out,
      static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}

#define SHARED_SIZE_LIMIT \
  1024  // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must
        // larger than blockDim.x * 2
// if monotonic_type = 1 then increase
// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2
// == 1 the increase
template <typename T>
__device__ __forceinline__ void Sort(T* dst,
                                     const T* src_data,
                                     int num,
                                     int monotonic_type) {
  // todo: set num = Pow2(num)
  // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
  __shared__ T value[SHARED_SIZE_LIMIT];  // shareMem's size must larger than
                                          // blockDim * 2
  // Copy value and index from src and src_index
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
  // make bitonicSort
  for (int size = 2; size < num; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
    }
  }
  // last sort
  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
  }
  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

template <typename T, typename IndexType>
__device__ __forceinline__ void Sort(T* dst,
                                     IndexType* dst_index,
                                     const T* src_data,
                                     IndexType* src_index,
                                     int num,
                                     int monotonic_type) {
  // todo: set num = Pow2(num)
  // shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
  __shared__ T value[SHARED_SIZE_LIMIT];  // shareMem's size must larger than
                                          // blockDim * 2
  __shared__ IndexType index[SHARED_SIZE_LIMIT];
  // Copy value and index from src and src_index
  value[threadIdx.x] = src_data[0];
  value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
  // index
  index[threadIdx.x] = src_index[0];
  index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
  // make bitonicSort
  for (int size = 2; size < num; size <<= 1) {
    int bitonic_type = (threadIdx.x & (size / 2)) != 0;
    for (int stride = size / 2; stride > 0; stride >>= 1) {
      __syncthreads();
      int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
      details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                                 &value[pos + stride],
                                                 &index[pos],
                                                 &index[pos + stride],
                                                 bitonic_type);
    }
  }

  for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
    __syncthreads();
    int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
    // last sort when monotonic_type = 1 then increase
    details::ComparatorWithIndex<T, IndexType>(&value[pos],
                                               &value[pos + stride],
                                               &index[pos],
                                               &index[pos + stride],
                                               monotonic_type);
  }
  __syncthreads();
  dst[0] = value[threadIdx.x];
  dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
  dst_index[0] = index[threadIdx.x];
  dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

}  // namespace kps
}  // namespace phi
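The index-carrying overload permutes an index array alongside the values, which yields an argsort-style result per block. A hypothetical caller, under the same assumptions as the earlier sketch (Paddle's include paths, blockDim.x == 512, num == 1024); the kernel name and output layout are illustrative:

#include "paddle/phi/kernels/primitive/compute_primitives.h"

// Hypothetical argsort-style caller (illustration only): values are sorted
// increasingly within each block and their original positions travel with them.
template <typename T>
__global__ void BlockArgSortKernel(T* out_val, int* out_idx, const T* in,
                                   int num) {
  const T* block_in = in + blockIdx.x * num;

  // Two values and their original in-block positions per thread.
  T src[2];
  int idx[2];
  src[0] = block_in[threadIdx.x];
  src[1] = block_in[threadIdx.x + num / 2];
  idx[0] = threadIdx.x;
  idx[1] = threadIdx.x + num / 2;

  T dst[2];
  int dst_idx[2];
  phi::kps::Sort<T, int>(dst, dst_idx, src, idx, num, 1);

  out_val[blockIdx.x * num + threadIdx.x] = dst[0];
  out_val[blockIdx.x * num + threadIdx.x + num / 2] = dst[1];
  out_idx[blockIdx.x * num + threadIdx.x] = dst_idx[0];
  out_idx[blockIdx.x * num + threadIdx.x + num / 2] = dst_idx[1];
}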