Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2a932e55
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2a932e55
编写于
11月 01, 2022
作者:
S
Siming Dai
提交者:
GitHub
11月 01, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[geometric] Optimize graph sample speed (#47531)
上级
32efda3d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
36 addition
and
39 deletion
+36
-39
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
+36
-39
未找到文件。
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
浏览文件 @
2a932e55
...
@@ -58,7 +58,7 @@ struct MaxFunctor {
...
@@ -58,7 +58,7 @@ struct MaxFunctor {
}
}
};
};
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
SampleKernel
(
const
uint64_t
rand_seed
,
__global__
void
SampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
int
k
,
const
int64_t
num_nodes
,
const
int64_t
num_nodes
,
...
@@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -71,8 +71,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
T
*
output_eids
,
T
*
output_eids
,
int
*
output_ptr
,
int
*
output_ptr
,
bool
return_eids
)
{
bool
return_eids
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
const
int64_t
last_row
=
...
@@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -80,13 +79,13 @@ __global__ void SampleKernel(const uint64_t rand_seed,
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
hiprandState
rng
;
hiprandState
rng
;
hiprand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
hiprand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
*
WARP
_SIZE
+
threadIdx
.
x
,
threadIdx
.
y
*
CTA
_SIZE
+
threadIdx
.
x
,
0
,
0
,
&
rng
);
&
rng
);
#else
#else
curandState
rng
;
curandState
Philox4_32_10_t
rng
;
curand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
curand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
*
WARP
_SIZE
+
threadIdx
.
x
,
threadIdx
.
y
*
CTA
_SIZE
+
threadIdx
.
x
,
0
,
0
,
&
rng
);
&
rng
);
#endif
#endif
...
@@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -94,7 +93,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
while
(
out_row
<
last_row
)
{
T
node
=
nodes
[
out_row
];
T
node
=
nodes
[
out_row
];
if
(
node
>
len_col_ptr
-
1
)
{
if
(
node
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
continue
;
continue
;
}
}
T
in_row_start
=
col_ptr
[
node
];
T
in_row_start
=
col_ptr
[
node
];
...
@@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -102,21 +101,21 @@ __global__ void SampleKernel(const uint64_t rand_seed,
int
out_row_start
=
output_ptr
[
out_row
];
int
out_row_start
=
output_ptr
[
out_row
];
if
(
deg
<=
k
)
{
if
(
deg
<=
k
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
output
[
out_row_start
+
idx
]
=
row
[
in_row_start
+
idx
];
output
[
out_row_start
+
idx
]
=
row
[
in_row_start
+
idx
];
if
(
return_eids
)
{
if
(
return_eids
)
{
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
}
}
}
}
}
else
{
}
else
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
CTA
_SIZE
)
{
output
[
out_row_start
+
idx
]
=
idx
;
output
[
out_row_start
+
idx
]
=
idx
;
}
}
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
#endif
for
(
int
idx
=
k
+
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
k
+
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
#else
#else
...
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -129,10 +128,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
}
}
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
#endif
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
k
;
idx
+=
CTA
_SIZE
)
{
T
perm_idx
=
output
[
out_row_start
+
idx
]
+
in_row_start
;
T
perm_idx
=
output
[
out_row_start
+
idx
]
+
in_row_start
;
output
[
out_row_start
+
idx
]
=
row
[
perm_idx
];
output
[
out_row_start
+
idx
]
=
row
[
perm_idx
];
if
(
return_eids
)
{
if
(
return_eids
)
{
...
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -141,7 +140,7 @@ __global__ void SampleKernel(const uint64_t rand_seed,
}
}
}
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
}
}
...
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
...
@@ -181,12 +180,12 @@ void SampleNeighbors(const Context& dev_ctx,
thrust
::
exclusive_scan
(
thrust
::
exclusive_scan
(
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
CTA_SIZE
=
128
;
constexpr
int
BLOCK_
WARPS
=
128
/
WARP
_SIZE
;
constexpr
int
BLOCK_
CTAS
=
128
/
CTA
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
WARPS
*
16
;
constexpr
int
TILE_SIZE
=
BLOCK_
CTAS
;
const
dim3
block
(
WARP_SIZE
,
BLOCK_WARP
S
);
const
dim3
block
(
CTA_SIZE
,
BLOCK_CTA
S
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
SampleKernel
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
SampleKernel
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
0
,
sample_size
,
sample_size
,
...
@@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
...
@@ -202,7 +201,7 @@ void SampleNeighbors(const Context& dev_ctx,
return_eids
);
return_eids
);
}
}
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
FisherYatesSampleKernel
(
const
uint64_t
rand_seed
,
__global__
void
FisherYatesSampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
int
k
,
const
int64_t
num_rows
,
const
int64_t
num_rows
,
...
@@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -210,8 +209,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
const
T
*
in_rows
,
const
T
*
in_rows
,
T
*
src
,
T
*
src
,
const
T
*
dst_count
)
{
const
T
*
dst_count
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
const
int64_t
last_row
=
...
@@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -221,7 +219,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
hiprand_init
(
hiprand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
#else
#else
curandState
rng
;
curandState
Philox4_32_10_t
rng
;
curand_init
(
curand_init
(
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
rand_seed
*
gridDim
.
x
+
blockIdx
.
x
,
threadIdx
.
y
+
threadIdx
.
x
,
0
,
&
rng
);
#endif
#endif
...
@@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -229,7 +227,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
while
(
out_row
<
last_row
)
{
const
T
row
=
in_rows
[
out_row
];
const
T
row
=
in_rows
[
out_row
];
if
(
row
>
len_col_ptr
-
1
)
{
if
(
row
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
continue
;
continue
;
}
}
const
T
in_row_start
=
dst_count
[
row
];
const
T
in_row_start
=
dst_count
[
row
];
...
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -241,7 +239,7 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
}
else
{
}
else
{
split
=
deg
-
k
;
split
=
deg
-
k
;
}
}
for
(
int
idx
=
split
+
threadIdx
.
x
;
idx
<=
deg
-
1
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
split
+
threadIdx
.
x
;
idx
<=
deg
-
1
;
idx
+=
CTA
_SIZE
)
{
#ifdef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
const
int
num
=
hiprand
(
&
rng
)
%
(
idx
+
1
);
#else
#else
...
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -254,14 +252,14 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
src
[
in_row_start
+
idx
])));
src
[
in_row_start
+
idx
])));
}
}
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
__sync
warp
();
__sync
threads
();
#endif
#endif
}
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
}
}
template
<
typename
T
,
int
WARP_SIZE
,
int
BLOCK_WARP
S
,
int
TILE_SIZE
>
template
<
typename
T
,
int
CTA_SIZE
,
int
BLOCK_CTA
S
,
int
TILE_SIZE
>
__global__
void
GatherEdge
(
int
k
,
__global__
void
GatherEdge
(
int
k
,
int64_t
num_rows
,
int64_t
num_rows
,
const
T
*
in_rows
,
const
T
*
in_rows
,
...
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
...
@@ -273,8 +271,7 @@ __global__ void GatherEdge(int k,
int
*
output_ptr
,
int
*
output_ptr
,
T
*
perm_data
,
T
*
perm_data
,
bool
return_eids
)
{
bool
return_eids
)
{
assert
(
blockDim
.
x
==
WARP_SIZE
);
assert
(
blockDim
.
x
==
CTA_SIZE
);
assert
(
blockDim
.
y
==
BLOCK_WARPS
);
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
int64_t
out_row
=
blockIdx
.
x
*
TILE_SIZE
+
threadIdx
.
y
;
const
int64_t
last_row
=
const
int64_t
last_row
=
...
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
...
@@ -287,7 +284,7 @@ __global__ void GatherEdge(int k,
const
T
out_row_start
=
output_ptr
[
out_row
];
const
T
out_row_start
=
output_ptr
[
out_row
];
if
(
deg
<=
k
)
{
if
(
deg
<=
k
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
deg
;
idx
+=
CTA
_SIZE
)
{
outputs
[
out_row_start
+
idx
]
=
src
[
in_row_start
+
idx
];
outputs
[
out_row_start
+
idx
]
=
src
[
in_row_start
+
idx
];
if
(
return_eids
)
{
if
(
return_eids
)
{
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
output_eids
[
out_row_start
+
idx
]
=
eids
[
in_row_start
+
idx
];
...
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
...
@@ -304,7 +301,7 @@ __global__ void GatherEdge(int k,
end
=
deg
;
end
=
deg
;
}
}
for
(
int
idx
=
begin
+
threadIdx
.
x
;
idx
<
end
;
idx
+=
WARP
_SIZE
)
{
for
(
int
idx
=
begin
+
threadIdx
.
x
;
idx
<
end
;
idx
+=
CTA
_SIZE
)
{
outputs
[
out_row_start
+
idx
-
begin
]
=
outputs
[
out_row_start
+
idx
-
begin
]
=
src
[
perm_data
[
in_row_start
+
idx
]];
src
[
perm_data
[
in_row_start
+
idx
]];
if
(
return_eids
)
{
if
(
return_eids
)
{
...
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
...
@@ -313,7 +310,7 @@ __global__ void GatherEdge(int k,
}
}
}
}
}
}
out_row
+=
BLOCK_
WARP
S
;
out_row
+=
BLOCK_
CTA
S
;
}
}
}
}
...
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
...
@@ -337,13 +334,13 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
thrust
::
exclusive_scan
(
thrust
::
exclusive_scan
(
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
output_count
,
output_count
+
bs
,
output_ptr
.
begin
(),
0
);
constexpr
int
WARP_SIZE
=
32
;
constexpr
int
CTA_SIZE
=
128
;
constexpr
int
BLOCK_
WARPS
=
128
/
WARP
_SIZE
;
constexpr
int
BLOCK_
CTAS
=
128
/
CTA
_SIZE
;
constexpr
int
TILE_SIZE
=
BLOCK_
WARPS
*
16
;
constexpr
int
TILE_SIZE
=
BLOCK_
CTAS
;
const
dim3
block
(
WARP_SIZE
,
BLOCK_WARP
S
);
const
dim3
block
(
CTA_SIZE
,
BLOCK_CTA
S
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
const
dim3
grid
((
bs
+
TILE_SIZE
-
1
)
/
TILE_SIZE
);
FisherYatesSampleKernel
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
FisherYatesSampleKernel
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
sample_size
,
sample_size
,
bs
,
bs
,
...
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
...
@@ -352,7 +349,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
perm_data
,
perm_data
,
col_ptr
);
col_ptr
);
GatherEdge
<
T
,
WARP_SIZE
,
BLOCK_WARP
S
,
TILE_SIZE
>
GatherEdge
<
T
,
CTA_SIZE
,
BLOCK_CTA
S
,
TILE_SIZE
>
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
sample_size
,
sample_size
,
bs
,
bs
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录