reindex.py 11.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode, Variable
from paddle.fluid.data_feeder import check_variable_and_dtype
19
from paddle import _legacy_C_ops
20 21 22 23

__all__ = []


24 25 26
def reindex_graph(
    x, neighbors, count, value_buffer=None, index_buffer=None, name=None
):
    """

    Reindex Graph API.

    This API is mainly used in Graph Learning domain, which should be used
    in conjunction with `paddle.geometric.sample_neighbors` API. And the main purpose
    is to reindex the ids information of the input nodes, and return the
    corresponding graph edges after reindex.

    Take input nodes x = [0, 1, 2] as an example. If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
    then we know that the neighbors of 0 are [8, 9], the neighbors of 1 are [0, 4, 7], and the neighbors of 2 are [6, 7].
    Then after graph_reindex, we will have 3 different outputs: reindex_src: [3, 4, 0, 5, 6, 7, 6], reindex_dst: [0, 0, 1, 1, 1, 2, 2]
    and out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]. We can see that the numbers in `reindex_src` and `reindex_dst` are the corresponding indices
    of nodes in `out_nodes`.

    Note:
        The number in x should be unique, otherwise it would cause potential errors. We will reindex all the nodes from 0.

    Args:
        x (Tensor): The input nodes which we sample neighbors for. The available
                    data type is int32, int64.
        neighbors (Tensor): The neighbors of the input nodes `x`. The data type
                            should be the same with `x`.
        count (Tensor): The neighbor count of the input nodes `x`. And the
                        data type should be int32.
        value_buffer (Tensor, optional): Value buffer for hashtable. The data type should be int32,
                                    and should be filled with -1. Only useful for gpu version. Default is None.
        index_buffer (Tensor, optional): Index buffer for hashtable. The data type should be int32,
                                    and should be filled with -1. Only useful for gpu version.
                                    `value_buffer` and `index_buffer` should be both not None
                                    if you want to speed up by using hashtable buffer. Default is None.
        name (str, optional): Name for the operation (optional, default is None).
                              For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        - reindex_src (Tensor), the source node index of graph edges after reindex.

        - reindex_dst (Tensor), the destination node index of graph edges after reindex.

        - out_nodes (Tensor), the index of unique input nodes and neighbors before reindex, where we put the input nodes `x` in the front, and put neighbor nodes in the back.

    Examples:
        .. code-block:: python

            import paddle

            x = [0, 1, 2]
            neighbors = [8, 9, 0, 4, 7, 6, 7]
            count = [2, 3, 2]
            x = paddle.to_tensor(x, dtype="int64")
            neighbors = paddle.to_tensor(neighbors, dtype="int64")
            count = paddle.to_tensor(count, dtype="int32")
            reindex_src, reindex_dst, out_nodes = paddle.geometric.reindex_graph(x, neighbors, count)
            # reindex_src: [3, 4, 0, 5, 6, 7, 6]
            # reindex_dst: [0, 0, 1, 1, 1, 2, 2]
            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]

    """
    # The hashtable flag is set only when BOTH buffers are supplied.
    use_buffer_hashtable = (
        value_buffer is not None and index_buffer is not None
    )

    if _non_static_mode():
        # Non-static mode: call the operator directly; no dtype checks are
        # performed on this path.
        reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
            x,
            neighbors,
            count,
            value_buffer,
            index_buffer,
            "flag_buffer_hashtable",
            use_buffer_hashtable,
        )
        return reindex_src, reindex_dst, out_nodes

    # Static mode: validate inputs before appending the op.
    check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
    check_variable_and_dtype(
        neighbors, "Neighbors", ("int32", "int64"), "graph_reindex"
    )
    # NOTE: the trailing comma matters -- ("int32") is just the string
    # "int32", which would make the allowed-dtype test a substring match.
    check_variable_and_dtype(count, "Count", ("int32",), "graph_reindex")

    if use_buffer_hashtable:
        check_variable_and_dtype(
            value_buffer, "HashTable_Value", ("int32",), "graph_reindex"
        )
        check_variable_and_dtype(
            index_buffer, "HashTable_Index", ("int32",), "graph_reindex"
        )

    # Every local (including `name`) is forwarded to LayerHelper as a kwarg.
    helper = LayerHelper("reindex_graph", **locals())
    # All three outputs are created with the dtype of `x`.
    reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
    reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
    out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type="graph_reindex",
        inputs={
            "X": x,
            "Neighbors": neighbors,
            "Count": count,
            # Buffers are only wired in when both were provided.
            "HashTable_Value": value_buffer if use_buffer_hashtable else None,
            "HashTable_Index": index_buffer if use_buffer_hashtable else None,
        },
        outputs={
            "Reindex_Src": reindex_src,
            "Reindex_Dst": reindex_dst,
            "Out_Nodes": out_nodes,
        },
        attrs={"flag_buffer_hashtable": use_buffer_hashtable},
    )
    return reindex_src, reindex_dst, out_nodes


138 139 140
def reindex_heter_graph(
    x, neighbors, count, value_buffer=None, index_buffer=None, name=None
):
    """

    Reindex HeterGraph API.

    This API is mainly used in Graph Learning domain, which should be used
    in conjunction with `paddle.geometric.sample_neighbors` API. And the main purpose
    is to reindex the ids information of the input nodes, and return the
    corresponding graph edges after reindex.

    Take input nodes x = [0, 1, 2] as an example. For graph A, suppose we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
    then we know that the neighbors of 0 are [8, 9], the neighbors of 1 are [0, 4, 7], and the neighbors of 2 are [6, 7]. For graph B,
    suppose we have neighbors = [0, 2, 3, 5, 1], and count = [1, 3, 1], then we know that the neighbors of 0 are [0], the neighbors of 1 are [2, 3, 5],
    and the neighbors of 2 are [1]. We will get following outputs: reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1], reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2]
    and out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5].

    Note:
        The number in x should be unique, otherwise it would cause potential errors. We support multi-edge-types neighbors reindexing in reindex_heter_graph api. We will reindex all the nodes from 0.

    Args:
        x (Tensor): The input nodes which we sample neighbors for. The available
                    data type is int32, int64.
        neighbors (list|tuple): The neighbors of the input nodes `x` from different graphs.
                                The data type should be the same with `x`.
        count (list|tuple): The neighbor counts of the input nodes `x` from different graphs.
                            And the data type should be int32.
        value_buffer (Tensor, optional): Value buffer for hashtable. The data type should be int32,
                                    and should be filled with -1. Only useful for gpu version. Default is None.
        index_buffer (Tensor, optional): Index buffer for hashtable. The data type should be int32,
                                    and should be filled with -1. Only useful for gpu version.
                                    `value_buffer` and `index_buffer` should be both not None
                                    if you want to speed up by using hashtable buffer. Default is None.
        name (str, optional): Name for the operation (optional, default is None).
                              For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        - reindex_src (Tensor), the source node index of graph edges after reindex.

        - reindex_dst (Tensor), the destination node index of graph edges after reindex.

        - out_nodes (Tensor), the index of unique input nodes and neighbors before reindex,
                              where we put the input nodes `x` in the front, and put neighbor
                              nodes in the back.

    Examples:
        .. code-block:: python

            import paddle

            x = [0, 1, 2]
            neighbors_a = [8, 9, 0, 4, 7, 6, 7]
            count_a = [2, 3, 2]
            x = paddle.to_tensor(x, dtype="int64")
            neighbors_a = paddle.to_tensor(neighbors_a, dtype="int64")
            count_a = paddle.to_tensor(count_a, dtype="int32")
            neighbors_b = [0, 2, 3, 5, 1]
            count_b = [1, 3, 1]
            neighbors_b = paddle.to_tensor(neighbors_b, dtype="int64")
            count_b = paddle.to_tensor(count_b, dtype="int32")
            neighbors = [neighbors_a, neighbors_b]
            count = [count_a, count_b]
            reindex_src, reindex_dst, out_nodes = paddle.geometric.reindex_heter_graph(x, neighbors, count)
            # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1]
            # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2]
            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]

    """
    # The hashtable flag is set only when BOTH buffers are supplied.
    use_buffer_hashtable = (
        value_buffer is not None and index_buffer is not None
    )

    # Accept a bare tensor as shorthand for a one-element list. This is a
    # no-op for the documented list/tuple inputs.
    if isinstance(neighbors, Variable):
        neighbors = [neighbors]
    if isinstance(count, Variable):
        count = [count]

    # Merge the per-graph neighbors/counts into single tensors -- exactly
    # once -- so the same `graph_reindex` op used by `reindex_graph` can
    # consume them.
    neighbors = paddle.concat(neighbors, axis=0)
    count = paddle.concat(count, axis=0)

    if _non_static_mode():
        # Non-static mode: call the operator directly; no dtype checks are
        # performed on this path.
        reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
            x,
            neighbors,
            count,
            value_buffer,
            index_buffer,
            "flag_buffer_hashtable",
            use_buffer_hashtable,
        )
        return reindex_src, reindex_dst, out_nodes

    # Static mode: validate inputs before appending the op.
    check_variable_and_dtype(x, "X", ("int32", "int64"), "heter_graph_reindex")
    check_variable_and_dtype(
        neighbors, "Neighbors", ("int32", "int64"), "graph_reindex"
    )
    # NOTE: the trailing comma matters -- ("int32") is just the string
    # "int32", which would make the allowed-dtype test a substring match.
    check_variable_and_dtype(count, "Count", ("int32",), "graph_reindex")

    if use_buffer_hashtable:
        check_variable_and_dtype(
            value_buffer, "HashTable_Value", ("int32",), "graph_reindex"
        )
        check_variable_and_dtype(
            index_buffer, "HashTable_Index", ("int32",), "graph_reindex"
        )

    # Every local (including `name`) is forwarded to LayerHelper as a kwarg.
    helper = LayerHelper("reindex_heter_graph", **locals())
    # All three outputs are created with the dtype of `x`.
    reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
    reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
    out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type="graph_reindex",
        inputs={
            "X": x,
            "Neighbors": neighbors,
            "Count": count,
            # Buffers are only wired in when both were provided.
            "HashTable_Value": value_buffer if use_buffer_hashtable else None,
            "HashTable_Index": index_buffer if use_buffer_hashtable else None,
        },
        outputs={
            "Reindex_Src": reindex_src,
            "Reindex_Dst": reindex_dst,
            "Out_Nodes": out_nodes,
        },
        attrs={"flag_buffer_hashtable": use_buffer_hashtable},
    )
    return reindex_src, reindex_dst, out_nodes