Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
59be2f3b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
59be2f3b
编写于
8月 09, 2022
作者:
S
Siming Dai
提交者:
GitHub
8月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[GNN] Fix graph sample and data type bug (#45001)
上级
125e48c3
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
35 addition
and
10 deletion
+35
-10
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+1
-1
paddle/fluid/pybind/slice_utils.h
paddle/fluid/pybind/slice_utils.h
+2
-2
paddle/fluid/pybind/tensor_py.h
paddle/fluid/pybind/tensor_py.h
+3
-3
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
+29
-4
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
59be2f3b
...
@@ -670,7 +670,7 @@ void BindImperative(py::module *m_ptr) {
...
@@ -670,7 +670,7 @@ void BindImperative(py::module *m_ptr) {
.
def
(
"__init__"
,
.
def
(
"__init__"
,
[](
imperative
::
VarBase
&
self
,
[](
imperative
::
VarBase
&
self
,
framework
::
proto
::
VarType
::
Type
dtype
,
framework
::
proto
::
VarType
::
Type
dtype
,
const
std
::
vector
<
int
>
&
dims
,
const
std
::
vector
<
int
64_t
>
&
dims
,
const
py
::
handle
&
name
,
const
py
::
handle
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
framework
::
proto
::
VarType
::
Type
type
,
bool
persistable
)
{
bool
persistable
)
{
...
...
paddle/fluid/pybind/slice_utils.h
浏览文件 @
59be2f3b
...
@@ -191,10 +191,10 @@ static void ParseIndexingSlice(framework::LoDTensor* tensor,
...
@@ -191,10 +191,10 @@ static void ParseIndexingSlice(framework::LoDTensor* tensor,
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
PyObject
*
slice_item
=
PyTuple_GetItem
(
index
,
i
);
infer_flags
->
push_back
(
1
);
infer_flags
->
push_back
(
1
);
int
dim_len
=
shape
[
dim
];
int
64_t
dim_len
=
shape
[
dim
];
if
(
PyCheckInteger
(
slice_item
)
||
IsNumpyType
(
slice_item
))
{
if
(
PyCheckInteger
(
slice_item
)
||
IsNumpyType
(
slice_item
))
{
// integer, PyLong_AsLong supports both int and long
// integer, PyLong_AsLong supports both int and long
int
start
=
static_cast
<
in
t
>
(
PyLong_AsLong
(
slice_item
));
int
64_t
start
=
static_cast
<
int64_
t
>
(
PyLong_AsLong
(
slice_item
));
auto
s_t
=
start
;
auto
s_t
=
start
;
start
=
start
<
0
?
start
+
dim_len
:
start
;
start
=
start
<
0
?
start
+
dim_len
:
start
;
...
...
paddle/fluid/pybind/tensor_py.h
浏览文件 @
59be2f3b
...
@@ -368,7 +368,7 @@ void SetTensorFromPyArrayT(
...
@@ -368,7 +368,7 @@ void SetTensorFromPyArrayT(
std
::
vector
<
int64_t
>
dims
;
std
::
vector
<
int64_t
>
dims
;
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
push_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
push_back
(
static_cast
<
int
64_t
>
(
array
.
shape
()[
i
]));
}
}
self
->
Resize
(
phi
::
make_ddim
(
dims
));
self
->
Resize
(
phi
::
make_ddim
(
dims
));
...
@@ -612,8 +612,8 @@ void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor,
...
@@ -612,8 +612,8 @@ void SetUVATensorFromPyArrayImpl(framework::LoDTensor *self_tensor,
dims
.
reserve
(
array
.
ndim
());
dims
.
reserve
(
array
.
ndim
());
int64_t
numel
=
1
;
int64_t
numel
=
1
;
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
for
(
decltype
(
array
.
ndim
())
i
=
0
;
i
<
array
.
ndim
();
++
i
)
{
dims
.
emplace_back
(
static_cast
<
int
>
(
array
.
shape
()[
i
]));
dims
.
emplace_back
(
static_cast
<
int
64_t
>
(
array
.
shape
()[
i
]));
numel
*=
static_cast
<
int
>
(
array
.
shape
()[
i
]);
numel
*=
static_cast
<
int
64_t
>
(
array
.
shape
()[
i
]);
}
}
self_tensor
->
Resize
(
phi
::
make_ddim
(
dims
));
self_tensor
->
Resize
(
phi
::
make_ddim
(
dims
));
...
...
paddle/phi/kernels/gpu/graph_sample_neighbors_kernel.cu
浏览文件 @
59be2f3b
...
@@ -37,9 +37,13 @@ namespace phi {
...
@@ -37,9 +37,13 @@ namespace phi {
template
<
typename
T
>
template
<
typename
T
>
struct
DegreeFunctor
{
struct
DegreeFunctor
{
const
T
*
col_ptr
;
const
T
*
col_ptr
;
HOSTDEVICE
explicit
inline
DegreeFunctor
(
const
T
*
x
)
{
this
->
col_ptr
=
x
;
}
int64_t
len_col_ptr
;
HOSTDEVICE
explicit
inline
DegreeFunctor
(
const
T
*
x
,
int64_t
len_col_ptr
)
{
this
->
col_ptr
=
x
;
this
->
len_col_ptr
=
len_col_ptr
;
}
HOSTDEVICE
inline
int
operator
()(
T
i
)
const
{
HOSTDEVICE
inline
int
operator
()(
T
i
)
const
{
return
col_ptr
[
i
+
1
]
-
col_ptr
[
i
];
return
i
>
len_col_ptr
-
1
?
0
:
col_ptr
[
i
+
1
]
-
col_ptr
[
i
];
}
}
};
};
...
@@ -58,6 +62,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
...
@@ -58,6 +62,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__
void
SampleKernel
(
const
uint64_t
rand_seed
,
__global__
void
SampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
int
k
,
const
int64_t
num_nodes
,
const
int64_t
num_nodes
,
const
int64_t
len_col_ptr
,
const
T
*
nodes
,
const
T
*
nodes
,
const
T
*
row
,
const
T
*
row
,
const
T
*
col_ptr
,
const
T
*
col_ptr
,
...
@@ -88,6 +93,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -88,6 +93,10 @@ __global__ void SampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
while
(
out_row
<
last_row
)
{
T
node
=
nodes
[
out_row
];
T
node
=
nodes
[
out_row
];
if
(
node
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_WARPS
;
continue
;
}
T
in_row_start
=
col_ptr
[
node
];
T
in_row_start
=
col_ptr
[
node
];
int
deg
=
col_ptr
[
node
+
1
]
-
in_row_start
;
int
deg
=
col_ptr
[
node
+
1
]
-
in_row_start
;
int
out_row_start
=
output_ptr
[
out_row
];
int
out_row_start
=
output_ptr
[
out_row
];
...
@@ -139,10 +148,12 @@ __global__ void SampleKernel(const uint64_t rand_seed,
...
@@ -139,10 +148,12 @@ __global__ void SampleKernel(const uint64_t rand_seed,
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
int
GetTotalSampleNum
(
const
thrust
::
device_ptr
<
const
T
>
input
,
int
GetTotalSampleNum
(
const
thrust
::
device_ptr
<
const
T
>
input
,
const
T
*
col_ptr
,
const
T
*
col_ptr
,
int64_t
len_col_ptr
,
thrust
::
device_ptr
<
int
>
output_count
,
thrust
::
device_ptr
<
int
>
output_count
,
int
sample_size
,
int
sample_size
,
int
bs
)
{
int
bs
)
{
thrust
::
transform
(
input
,
input
+
bs
,
output_count
,
DegreeFunctor
<
T
>
(
col_ptr
));
thrust
::
transform
(
input
,
input
+
bs
,
output_count
,
DegreeFunctor
<
T
>
(
col_ptr
,
len_col_ptr
));
if
(
sample_size
>=
0
)
{
if
(
sample_size
>=
0
)
{
thrust
::
transform
(
thrust
::
transform
(
output_count
,
output_count
+
bs
,
output_count
,
MaxFunctor
(
sample_size
));
output_count
,
output_count
+
bs
,
output_count
,
MaxFunctor
(
sample_size
));
...
@@ -163,6 +174,7 @@ void SampleNeighbors(const Context& dev_ctx,
...
@@ -163,6 +174,7 @@ void SampleNeighbors(const Context& dev_ctx,
int
sample_size
,
int
sample_size
,
int
bs
,
int
bs
,
int
total_sample_num
,
int
total_sample_num
,
int64_t
len_col_ptr
,
bool
return_eids
)
{
bool
return_eids
)
{
thrust
::
device_vector
<
int
>
output_ptr
;
thrust
::
device_vector
<
int
>
output_ptr
;
output_ptr
.
resize
(
bs
);
output_ptr
.
resize
(
bs
);
...
@@ -179,6 +191,7 @@ void SampleNeighbors(const Context& dev_ctx,
...
@@ -179,6 +191,7 @@ void SampleNeighbors(const Context& dev_ctx,
0
,
0
,
sample_size
,
sample_size
,
bs
,
bs
,
len_col_ptr
,
thrust
::
raw_pointer_cast
(
input
),
thrust
::
raw_pointer_cast
(
input
),
row
,
row
,
col_ptr
,
col_ptr
,
...
@@ -193,6 +206,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
...
@@ -193,6 +206,7 @@ template <typename T, int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__
void
FisherYatesSampleKernel
(
const
uint64_t
rand_seed
,
__global__
void
FisherYatesSampleKernel
(
const
uint64_t
rand_seed
,
int
k
,
int
k
,
const
int64_t
num_rows
,
const
int64_t
num_rows
,
const
int64_t
len_col_ptr
,
const
T
*
in_rows
,
const
T
*
in_rows
,
T
*
src
,
T
*
src
,
const
T
*
dst_count
)
{
const
T
*
dst_count
)
{
...
@@ -214,6 +228,10 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
...
@@ -214,6 +228,10 @@ __global__ void FisherYatesSampleKernel(const uint64_t rand_seed,
while
(
out_row
<
last_row
)
{
while
(
out_row
<
last_row
)
{
const
T
row
=
in_rows
[
out_row
];
const
T
row
=
in_rows
[
out_row
];
if
(
row
>
len_col_ptr
-
1
)
{
out_row
+=
BLOCK_WARPS
;
continue
;
}
const
T
in_row_start
=
dst_count
[
row
];
const
T
in_row_start
=
dst_count
[
row
];
const
int
deg
=
dst_count
[
row
+
1
]
-
in_row_start
;
const
int
deg
=
dst_count
[
row
+
1
]
-
in_row_start
;
int
split
;
int
split
;
...
@@ -312,6 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
...
@@ -312,6 +330,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
int
sample_size
,
int
sample_size
,
int
bs
,
int
bs
,
int
total_sample_num
,
int
total_sample_num
,
int64_t
len_col_ptr
,
bool
return_eids
)
{
bool
return_eids
)
{
thrust
::
device_vector
<
int
>
output_ptr
;
thrust
::
device_vector
<
int
>
output_ptr
;
output_ptr
.
resize
(
bs
);
output_ptr
.
resize
(
bs
);
...
@@ -328,6 +347,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
...
@@ -328,6 +347,7 @@ void FisherYatesSampleNeighbors(const Context& dev_ctx,
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
<<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
0
,
sample_size
,
sample_size
,
bs
,
bs
,
len_col_ptr
,
thrust
::
raw_pointer_cast
(
input
),
thrust
::
raw_pointer_cast
(
input
),
perm_data
,
perm_data
,
col_ptr
);
col_ptr
);
...
@@ -365,6 +385,7 @@ void GraphSampleNeighborsKernel(
...
@@ -365,6 +385,7 @@ void GraphSampleNeighborsKernel(
auto
*
col_ptr_data
=
col_ptr
.
data
<
T
>
();
auto
*
col_ptr_data
=
col_ptr
.
data
<
T
>
();
auto
*
x_data
=
x
.
data
<
T
>
();
auto
*
x_data
=
x
.
data
<
T
>
();
int
bs
=
x
.
dims
()[
0
];
int
bs
=
x
.
dims
()[
0
];
int64_t
len_col_ptr
=
col_ptr
.
dims
()[
0
];
const
thrust
::
device_ptr
<
const
T
>
input
(
x_data
);
const
thrust
::
device_ptr
<
const
T
>
input
(
x_data
);
...
@@ -373,7 +394,7 @@ void GraphSampleNeighborsKernel(
...
@@ -373,7 +394,7 @@ void GraphSampleNeighborsKernel(
thrust
::
device_ptr
<
int
>
output_count
(
out_count_data
);
thrust
::
device_ptr
<
int
>
output_count
(
out_count_data
);
int
total_sample_size
=
GetTotalSampleNum
<
T
,
Context
>
(
int
total_sample_size
=
GetTotalSampleNum
<
T
,
Context
>
(
input
,
col_ptr_data
,
output_count
,
sample_size
,
bs
);
input
,
col_ptr_data
,
len_col_ptr
,
output_count
,
sample_size
,
bs
);
out
->
Resize
({
static_cast
<
int
>
(
total_sample_size
)});
out
->
Resize
({
static_cast
<
int
>
(
total_sample_size
)});
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
T
*
out_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
...
@@ -396,6 +417,7 @@ void GraphSampleNeighborsKernel(
...
@@ -396,6 +417,7 @@ void GraphSampleNeighborsKernel(
sample_size
,
sample_size
,
bs
,
bs
,
total_sample_size
,
total_sample_size
,
len_col_ptr
,
return_eids
);
return_eids
);
}
else
{
}
else
{
DenseTensor
perm_buffer_out
(
perm_buffer
->
type
());
DenseTensor
perm_buffer_out
(
perm_buffer
->
type
());
...
@@ -414,6 +436,7 @@ void GraphSampleNeighborsKernel(
...
@@ -414,6 +436,7 @@ void GraphSampleNeighborsKernel(
sample_size
,
sample_size
,
bs
,
bs
,
total_sample_size
,
total_sample_size
,
len_col_ptr
,
return_eids
);
return_eids
);
}
}
}
else
{
}
else
{
...
@@ -431,6 +454,7 @@ void GraphSampleNeighborsKernel(
...
@@ -431,6 +454,7 @@ void GraphSampleNeighborsKernel(
sample_size
,
sample_size
,
bs
,
bs
,
total_sample_size
,
total_sample_size
,
len_col_ptr
,
return_eids
);
return_eids
);
}
else
{
}
else
{
DenseTensor
perm_buffer_out
(
perm_buffer
->
type
());
DenseTensor
perm_buffer_out
(
perm_buffer
->
type
());
...
@@ -449,6 +473,7 @@ void GraphSampleNeighborsKernel(
...
@@ -449,6 +473,7 @@ void GraphSampleNeighborsKernel(
sample_size
,
sample_size
,
bs
,
bs
,
total_sample_size
,
total_sample_size
,
len_col_ptr
,
return_eids
);
return_eids
);
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录