Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
03ef0bdc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
03ef0bdc
编写于
8月 23, 2022
作者:
S
Siming Dai
提交者:
GitHub
8月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Geometric] Fix cuda configuration error for message_passing api (#45315)
上级
0e384ade
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
48 addition
and
11 deletion
+48
-11
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
+1
-1
paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
+37
-1
paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
+4
-4
paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
+2
-1
paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
+3
-3
paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
+1
-1
未找到文件。
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
浏览文件 @
03ef0bdc
...
@@ -54,7 +54,7 @@ struct MaxFunctor {
...
@@ -54,7 +54,7 @@ struct MaxFunctor {
if
(
x
>
cap
)
{
if
(
x
>
cap
)
{
return
cap
;
return
cap
;
}
}
return
x
;
return
x
>=
0
?
x
:
0
;
}
}
};
};
...
...
paddle/phi/kernels/gpu/graph_send_ue_recv_funcs.h
浏览文件 @
03ef0bdc
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Copyright 2022 The DGL team for some useful functions.
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// You may obtain a copy of the License at
...
@@ -23,6 +24,10 @@
...
@@ -23,6 +24,10 @@
namespace
phi
{
namespace
phi
{
#define CUDA_MAX_NUM_BLOCKS_X 0x7FFFFFFF
#define CUDA_MAX_NUM_BLOCKS_Y 0xFFFF
#define CUDA_MAX_NUM_BLOCKS_Z 0xFFFF
inline
void
CopyBCastOff
(
const
BroadCastInfo
&
bcast_info
,
inline
void
CopyBCastOff
(
const
BroadCastInfo
&
bcast_info
,
thrust
::
device_vector
<
int64_t
>&
l_bcastoff
,
thrust
::
device_vector
<
int64_t
>&
l_bcastoff
,
thrust
::
device_vector
<
int64_t
>&
r_bcastoff
)
{
thrust
::
device_vector
<
int64_t
>&
r_bcastoff
)
{
...
@@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) {
...
@@ -63,6 +68,37 @@ inline int FindNumThreads(int dim, int max_num_threads) {
return
res
;
return
res
;
}
}
inline
int
FindNumBlocks
(
char
axis
,
int
nblocks
,
int
max_num_blocks
=
-
1
)
{
int
default_max_num_blocks
=
-
1
;
switch
(
axis
)
{
case
'x'
:
default_max_num_blocks
=
CUDA_MAX_NUM_BLOCKS_X
;
break
;
case
'y'
:
default_max_num_blocks
=
CUDA_MAX_NUM_BLOCKS_Y
;
break
;
case
'z'
:
default_max_num_blocks
=
CUDA_MAX_NUM_BLOCKS_Z
;
break
;
default:
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"%c axis is not recognized"
,
axis
));
}
if
(
max_num_blocks
==
-
1
)
{
max_num_blocks
=
default_max_num_blocks
;
}
PADDLE_ENFORCE_GT
(
max_num_blocks
,
0
,
phi
::
errors
::
InvalidArgument
(
"max_num_blocks should be larger than 0, "
"but received %d"
,
max_num_blocks
));
if
(
nblocks
<
max_num_blocks
)
{
return
nblocks
;
}
return
max_num_blocks
;
}
template
<
typename
T
>
template
<
typename
T
>
struct
GraphSendUERecvSumCUDAFunctor
{
struct
GraphSendUERecvSumCUDAFunctor
{
DEVICE
inline
void
operator
()(
T
*
output
,
T
val
)
{
DEVICE
inline
void
operator
()(
T
*
output
,
T
val
)
{
...
...
paddle/phi/kernels/gpu/graph_send_ue_recv_grad_kernel.cu
浏览文件 @
03ef0bdc
...
@@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx,
...
@@ -52,7 +52,7 @@ void CalculateXEGradForMinMax(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid
(
nbx
,
nby
);
const
dim3
grid
(
nbx
,
nby
);
const
dim3
block
(
ntx
,
nty
);
const
dim3
block
(
ntx
,
nty
);
...
@@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx,
...
@@ -183,7 +183,7 @@ void CalculateXGrad(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
block_
(
ntx
,
nty
);
const
dim3
block_
(
ntx
,
nty
);
funcs
::
MultiplyFunctor
<
T
>
mul_functor
;
funcs
::
MultiplyFunctor
<
T
>
mul_functor
;
...
@@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx,
...
@@ -306,7 +306,7 @@ void CalculateXGrad(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
block_
(
ntx
,
nty
);
const
dim3
block_
(
ntx
,
nty
);
if
(
!
reduce
)
{
if
(
!
reduce
)
{
...
@@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx,
...
@@ -392,7 +392,7 @@ void CalculateEGrad(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid
(
nbx
,
nby
);
const
dim3
grid
(
nbx
,
nby
);
const
dim3
block
(
ntx
,
nty
);
const
dim3
block
(
ntx
,
nty
);
if
(
reduce_op
==
"SUM"
)
{
if
(
reduce_op
==
"SUM"
)
{
...
...
paddle/phi/kernels/gpu/graph_send_ue_recv_kernel.cu
浏览文件 @
03ef0bdc
...
@@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
...
@@ -81,6 +81,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
if
(
index_size
==
0
)
return
;
if
(
index_size
==
0
)
return
;
const
auto
&
bcast_info
=
phi
::
CalcBCastInfo
(
x
.
dims
(),
e
.
dims
());
const
auto
&
bcast_info
=
phi
::
CalcBCastInfo
(
x
.
dims
(),
e
.
dims
());
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
x_data
=
x
.
data
<
T
>
();
const
T
*
e_data
=
e
.
data
<
T
>
();
const
T
*
e_data
=
e
.
data
<
T
>
();
const
IndexT
*
s_index
=
src_index
.
data
<
IndexT
>
();
const
IndexT
*
s_index
=
src_index
.
data
<
IndexT
>
();
...
@@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
...
@@ -95,7 +96,7 @@ void GraphSendUERecvOpCUDAKernelLaunchHelper(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid
(
nbx
,
nby
);
const
dim3
grid
(
nbx
,
nby
);
const
dim3
block
(
ntx
,
nty
);
const
dim3
block
(
ntx
,
nty
);
int64_t
input_size
=
x
.
dims
()[
0
];
int64_t
input_size
=
x
.
dims
()[
0
];
...
...
paddle/phi/kernels/gpu/graph_send_uv_grad_kernel.cu
浏览文件 @
03ef0bdc
...
@@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx,
...
@@ -73,7 +73,7 @@ void CalculateGrad(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
slice_size
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
slice_size
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
slice_size
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
slice_size
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid_tmp
(
nbx
,
nby
);
const
dim3
grid_tmp
(
nbx
,
nby
);
const
dim3
block_tmp
(
ntx
,
nty
);
const
dim3
block_tmp
(
ntx
,
nty
);
GraphSendUVGradCUDAKernel
<
T
,
IndexT
>
GraphSendUVGradCUDAKernel
<
T
,
IndexT
>
...
@@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx,
...
@@ -93,7 +93,7 @@ void CalculateGrad(const Context& ctx,
FindNumThreads
(
bcast_info
.
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
FindNumThreads
(
bcast_info
.
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
bcast_info
.
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
bcast_info
.
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid_tmp
(
nbx
,
nby
);
const
dim3
grid_tmp
(
nbx
,
nby
);
const
dim3
block_tmp
(
ntx
,
nty
);
const
dim3
block_tmp
(
ntx
,
nty
);
GraphSendUVGradCUDAKernel
<
T
,
IndexT
>
GraphSendUVGradCUDAKernel
<
T
,
IndexT
>
...
@@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx,
...
@@ -133,7 +133,7 @@ void CalculateGrad(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
grid_
(
nbx
,
nby
);
const
dim3
block_
(
ntx
,
nty
);
const
dim3
block_
(
ntx
,
nty
);
funcs
::
MultiplyFunctor
<
T
>
mul_functor
;
funcs
::
MultiplyFunctor
<
T
>
mul_functor
;
...
...
paddle/phi/kernels/gpu/graph_send_uv_kernel.cu
浏览文件 @
03ef0bdc
...
@@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
...
@@ -101,7 +101,7 @@ void GraphSendUVOpCUDAKernelLaunchHelper(const Context& ctx,
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
ntx
=
FindNumThreads
(
out_len
,
ctx
.
GetMaxThreadsPerBlock
());
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nty
=
ctx
.
GetMaxThreadsPerBlock
()
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nbx
=
(
out_len
+
ntx
-
1
)
/
ntx
;
const
int
nby
=
(
index_size
+
nty
-
1
)
/
nty
;
const
int
nby
=
FindNumBlocks
(
'y'
,
(
index_size
+
nty
-
1
)
/
nty
)
;
const
dim3
grid
(
nbx
,
nby
);
const
dim3
grid
(
nbx
,
nby
);
const
dim3
block
(
ntx
,
nty
);
const
dim3
block
(
ntx
,
nty
);
if
(
message_op
==
"ADD"
)
{
if
(
message_op
==
"ADD"
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录