Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7a1cf277
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7a1cf277
编写于
11月 02, 2022
作者:
S
Siming Dai
提交者:
GitHub
11月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[geometric] Optimize graph sample speed (#47531) (#47548)
上级
61953b90
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
36 addition
and
39 deletion
+36
-39
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
+36
-39
未找到文件。
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
浏览文件 @
7a1cf277
...
...
@@ -58,7 +58,7 @@ struct MaxFunctor {
}
};
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
SampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
const
int64_t
num_nodes
,
...
...
@@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
T
*
output_eids
,
int
*
output_ptr
,
bool
return_eids
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
...
...
@@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
#ifdef PADDLE_WITH_HIP
hiprandState
rng
;
hiprand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
*
WARP
_SIZE
+
threadIdx
.
x
,
threadIdx
.
y
*
CTA
_SIZE
+
threadIdx
.
x
,
0
,
&
rng
);
#else
curandState
rng
;
curandState
Philox4_32_10_t
rng
;
curand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
*
WARP
_SIZE
+
threadIdx
.
x
,
threadIdx
.
y
*
CTA
_SIZE
+
threadIdx
.
x
,
0
,
&
rng
);
#endif
...
...
@@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
T
node
=
nodes
[
out_row
];
if
(
node
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
continue
;
}
T
in_row_start
=
col_ptr
[
node
];
...
...
@@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed,
int
out_row_start
=
output_ptr
[
out_row
];
if
(
deg
<=
k
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
output
[
out_row_start
+
idx
]
=
row
[
in_row_start
+
idx
];
if
(
return_eids
)
{
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
}
}
}
else
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
CTA
_SIZE
)
{
output
[
out_row_start
+
idx
]
=
idx
;
}
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
for
(
int
idx
=
k
+
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
k
+
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
#ifdef PADDLE_WITH_HIP
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
#else
...
...
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
CTA
_SIZE
)
{
T
perm_idx
=
output
[
out_row_start
+
idx
]
+
in_row_start
;
output
[
out_row_start
+
idx
]
=
row
[
perm_idx
];
if
(
return_eids
)
{
...
...
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
...
...
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
thrust
::
exclusive_scan
(
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
BLOCK_
WARPS
=
128
/
WARP
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
WARPS
*
16
;
const
dim3
block
(
WARP_SIZE
,
BLOCK_WARP
S
);
constexpr
int
CTA_SIZE
=
128
;
constexpr
int
BLOCK_
CTAS
=
128
/
CTA
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
CTAS
;
const
dim3
block
(
CTA_SIZE
,
BLOCK_CTA
S
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
SampleKernel
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
SampleKernel
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
sample_size
,
...
...
@@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
return_eids
);
}
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
FisherYatesSampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
const
int64_t
num_rows
,
...
...
@@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
const
T
*
in_rows
,
T
*
src
,
const
T
*
dst_count
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
...
...
@@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
hiprand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
#else
curandState
rng
;
curandState
Philox4_32_10_t
rng
;
curand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
#endif
...
...
@@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
const
T
row
=
in_rows
[
out_row
];
if
(
row
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
continue
;
}
const
T
in_row_start
=
dst_count
[
row
];
...
...
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
}
else
{
split
=
deg
-
k
;
}
for
(
int
idx
=
split
+
threadIdx
.
x
;
idx
<=
deg
-
1
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
split
+
threadIdx
.
x
;
idx
<=
deg
-
1
;
idx
+=
CTA
_SIZE
)
{
#ifdef PADDLE_WITH_HIP
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
#else
...
...
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
src
[
in_row_start
+
idx
])));
}
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
GatherEdge
(
int
k
,
int64_t
num_rows
,
const
T
*
in_rows
,
...
...
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
int
*
output_ptr
,
T
*
perm_data
,
bool
return_eids
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
...
...
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
const
T
out_row_start
=
output_ptr
[
out_row
];
if
(
deg
<=
k
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
outputs
[
out_row_start
+
idx
]
=
src
[
in_row_start
+
idx
];
if
(
return_eids
)
{
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
...
...
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
end
=
deg
;
}
for
(
int
idx
=
begin
+
threadIdx
.
x
;
idx
<
end
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
begin
+
threadIdx
.
x
;
idx
<
end
;
idx
+=
CTA
_SIZE
)
{
outputs
[
out_row_start
+
idx
-
begin
]
=
src
[
perm_data
[
in_row_start
+
idx
]];
if
(
return_eids
)
{
...
...
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
}
}
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
...
...
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
thrust
::
exclusive_scan
(
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
BLOCK_
WARPS
=
128
/
WARP
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
WARPS
*
16
;
const
dim3
block
(
WARP_SIZE
,
BLOCK_WARP
S
);
constexpr
int
CTA_SIZE
=
128
;
constexpr
int
BLOCK_
CTAS
=
128
/
CTA
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
CTAS
;
const
dim3
block
(
CTA_SIZE
,
BLOCK_CTA
S
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
FisherYatesSampleKernel
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
FisherYatesSampleKernel
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
sample_size
,
bs
,
...
...
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
perm_data
,
col_ptr
);
GatherEdge
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
GatherEdge
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
sample_size
,
bs
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录