未验证 提交 ceb20406 编写于 作者: S Siming Dai 提交者: GitHub

Support hetergraph reindex (#43128)

* support heter reindex

* add unittest, fix bug

* add comment

* delete empty line

* refine example

* fix codestyle

* add disable static
上级 2bfe8b2c
......@@ -59,11 +59,15 @@ void GraphReindexKernel(const Context& dev_ctx,
src[i] = node_map[node];
}
// Reindex Dst
// Add support for multi-type edges reindex
int num_edge_types = count.dims()[0] / bs;
int cnt = 0;
for (int i = 0; i < bs; i++) {
for (int j = 0; j < count_data[i]; j++) {
T node = x_data[i];
dst[cnt++] = node_map[node];
for (int i = 0; i < num_edge_types; i++) {
for (int j = 0; j < bs; j++) {
for (int k = 0; k < count_data[i * bs + j]; k++) {
T node = x_data[j];
dst[cnt++] = node_map[node];
}
}
}
......
......@@ -331,26 +331,37 @@ void GraphReindexKernel(const Context& dev_ctx,
}
// Get reindex dst edge.
// Add support for multi-type edges reindex.
int num_ac_count = count.dims()[0];
int num_edge_types = num_ac_count / bs;
thrust::device_vector<int> unique_dst_reindex(bs);
thrust::sequence(unique_dst_reindex.begin(), unique_dst_reindex.end());
thrust::device_vector<int> dst_ptr(bs);
thrust::exclusive_scan(count_data, count_data + bs, dst_ptr.begin());
constexpr int BLOCK_WARPS = 128 / WARP_SIZE;
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((bs + TILE_SIZE - 1) / TILE_SIZE);
reindex_dst->Resize({num_edges});
T* reindex_dst_data = dev_ctx.template Alloc<T>(reindex_dst);
GetDstEdgeCUDAKernel<T,
BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
bs,
thrust::raw_pointer_cast(unique_dst_reindex.data()),
count_data,
thrust::raw_pointer_cast(dst_ptr.data()),
reindex_dst_data);
int begin = 0;
for (int i = 0; i < num_edge_types; i++) {
thrust::device_vector<int> dst_ptr(bs);
thrust::exclusive_scan(
count_data + i * bs, count_data + (i + 1) * bs, dst_ptr.begin());
GetDstEdgeCUDAKernel<T,
BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, dev_ctx.stream()>>>(
bs,
thrust::raw_pointer_cast(unique_dst_reindex.data()),
count_data + i * bs,
thrust::raw_pointer_cast(dst_ptr.data()),
reindex_dst_data + begin);
int count_i =
thrust::reduce(thrust::device_pointer_cast(count_data) + i * bs,
thrust::device_pointer_cast(count_data) + (i + 1) * bs);
begin += count_i;
}
out_nodes->Resize({static_cast<int>(unique_nodes.size())});
T* out_nodes_data = dev_ctx.template Alloc<T>(out_nodes);
......
......@@ -62,6 +62,63 @@ class TestGraphReindex(unittest.TestCase):
self.assertTrue(np.allclose(self.reindex_dst, reindex_dst))
self.assertTrue(np.allclose(self.out_nodes, out_nodes))
def test_heter_reindex_result(self):
    # Duplicate the homogeneous neighbors/count tensors to mimic two
    # identical edge types, then verify that each half of the reindexed
    # output matches the precomputed single-edge-type result.
    paddle.disable_static()
    nodes = paddle.to_tensor(self.x)
    doubled_neighbors = paddle.concat(
        [paddle.to_tensor(self.neighbors),
         paddle.to_tensor(self.neighbors)])
    doubled_count = paddle.concat(
        [paddle.to_tensor(self.count),
         paddle.to_tensor(self.count)])
    reindex_src, reindex_dst, out_nodes = \
        paddle.incubate.graph_reindex(nodes, doubled_neighbors,
                                      doubled_count)
    split = self.neighbors.shape[0]
    # Both edge-type segments of src must equal the single-type result.
    for segment in (reindex_src[:split], reindex_src[split:]):
        self.assertTrue(np.allclose(self.reindex_src, segment))
    # Likewise for dst.
    for segment in (reindex_dst[:split], reindex_dst[split:]):
        self.assertTrue(np.allclose(self.reindex_dst, segment))
    # Duplicated neighbors introduce no new nodes.
    self.assertTrue(np.allclose(self.out_nodes, out_nodes))
def test_heter_reindex_result_v2(self):
    # Cross-check multi-edge-type reindexing against a pure-numpy
    # reference built from two distinct neighbor/count groups.
    paddle.disable_static()
    x = np.arange(5).astype("int64")
    neighbors1 = np.random.randint(100, size=20).astype("int64")
    count1 = np.array([2, 8, 4, 3, 3], dtype="int32")
    neighbors2 = np.random.randint(100, size=20).astype("int64")
    count2 = np.array([4, 5, 1, 6, 4], dtype="int32")
    neighbors = np.concatenate([neighbors1, neighbors2])
    counts = np.concatenate([count1, count2])

    # Reference result: unique nodes in first-seen order, seeded with x.
    out_nodes = list(x)
    for neighbor in neighbors:
        if neighbor not in out_nodes:
            out_nodes.append(neighbor)
    out_nodes = np.array(out_nodes, dtype="int64")
    node_to_index = {node: ind for ind, node in enumerate(out_nodes)}
    reindex_src = np.array([node_to_index[node] for node in neighbors])
    # dst repeats each source node once per sampled neighbor, grouped
    # by edge type (count1 first, then count2).
    reindex_dst = np.array(
        [node_to_index[node]
         for per_type_count in (count1, count2)
         for node, c in zip(x, per_type_count)
         for _ in range(c)],
        dtype="int64")

    actual_src, actual_dst, actual_nodes = \
        paddle.incubate.graph_reindex(paddle.to_tensor(x),
                                      paddle.to_tensor(neighbors),
                                      paddle.to_tensor(counts))
    self.assertTrue(np.allclose(reindex_src, actual_src))
    self.assertTrue(np.allclose(reindex_dst, actual_dst))
    self.assertTrue(np.allclose(out_nodes, actual_nodes))
def test_reindex_result_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
......
......@@ -38,10 +38,6 @@ def graph_khop_sampler(row,
and `sample_sizes` means the number of neighbors and number of layers we want
to sample.
**Note**:
Currently the API will reindex the output edges after finishing sampling. We
will add a choice or a new API for whether to reindex the edges in the near future.
Args:
row (Tensor): One of the components of the CSC format of the input graph, and
the shape should be [num_edges, 1] or [num_edges]. The available
......
......@@ -35,6 +35,12 @@ def graph_reindex(x,
is to reindex the ids information of the input nodes, and return the
corresponding graph edges after reindex.
**Notes**:
The values in x should be unique; otherwise, duplicate entries may cause incorrect reindexing results.
Besides, we also support multi-edge-types neighbors reindexing. If we have different
edge_type neighbors for x, we should concatenate all the neighbors and count of x.
We will reindex all the nodes from 0.
Take input nodes x = [0, 1, 2] as an example.
If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
then we know that the neighbors of 0 is [8, 9], the neighbors of 1
......@@ -70,18 +76,31 @@ def graph_reindex(x,
import paddle
x = [0, 1, 2]
neighbors = [8, 9, 0, 4, 7, 6, 7]
count = [2, 3, 2]
neighbors_e1 = [8, 9, 0, 4, 7, 6, 7]
count_e1 = [2, 3, 2]
x = paddle.to_tensor(x, dtype="int64")
neighbors = paddle.to_tensor(neighbors, dtype="int64")
count = paddle.to_tensor(count, dtype="int32")
neighbors_e1 = paddle.to_tensor(neighbors_e1, dtype="int64")
count_e1 = paddle.to_tensor(count_e1, dtype="int32")
reindex_src, reindex_dst, out_nodes = \
paddle.incubate.graph_reindex(x, neighbors, count)
paddle.incubate.graph_reindex(x, neighbors_e1, count_e1)
# reindex_src: [3, 4, 0, 5, 6, 7, 6]
# reindex_dst: [0, 0, 1, 1, 1, 2, 2]
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]
neighbors_e2 = [0, 2, 3, 5, 1]
count_e2 = [1, 3, 1]
neighbors_e2 = paddle.to_tensor(neighbors_e2, dtype="int64")
count_e2 = paddle.to_tensor(count_e2, dtype="int32")
neighbors = paddle.concat([neighbors_e1, neighbors_e2])
count = paddle.concat([count_e1, count_e2])
reindex_src, reindex_dst, out_nodes = \
paddle.incubate.graph_reindex(x, neighbors, count)
# reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1]
# reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2]
# out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]
"""
if flag_buffer_hashtable:
if value_buffer is None or index_buffer is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册