graph_reindex.py 6.9 KB
Newer Older
S
Siming Dai 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid import core
20
from paddle import _C_ops, _legacy_C_ops
21
import paddle.utils.deprecated as deprecated
S
Siming Dai 已提交
22 23


24 25 26 27
@deprecated(since="2.4.0",
            update_to="paddle.geometric.reindex_graph",
            level=1,
            reason="paddle.incubate.graph_reindex will be removed in future")
def graph_reindex(x,
                  neighbors,
                  count,
                  value_buffer=None,
                  index_buffer=None,
                  flag_buffer_hashtable=False,
                  name=None):
    """
    Graph Reindex API.

    This API is mainly used in Graph Learning domain, which should be used
    in conjunction with `graph_sample_neighbors` API. And the main purpose
    is to reindex the ids information of the input nodes, and return the
    corresponding graph edges after reindex.

    **Notes**:
        The number in x should be unique, otherwise it would cause potential
        errors. Besides, we also support multi-edge-types neighbors
        reindexing. If we have different edge_type neighbors for x, we should
        concatenate all the neighbors and count of x. We will reindex all the
        nodes from 0.

    Take input nodes x = [0, 1, 2] as an example.
    If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2],
    then we know that the neighbors of 0 is [8, 9], the neighbors of 1
    is [0, 4, 7], and the neighbors of 2 is [6, 7].

    Args:
        x (Tensor): The input nodes which we sample neighbors for. The available
                    data type is int32, int64.
        neighbors (Tensor): The neighbors of the input nodes `x`. The data type
                            should be the same with `x`.
        count (Tensor): The neighbor count of the input nodes `x`. And the
                        data type should be int32.
        value_buffer (Tensor|None): Value buffer for hashtable. The data type should
                                    be int32, and should be filled with -1.
        index_buffer (Tensor|None): Index buffer for hashtable. The data type should
                                    be int32, and should be filled with -1.
        flag_buffer_hashtable (bool): Whether to use buffer for hashtable to speed up.
                                      Default is False. Only useful for gpu version currently.
        name (str, optional): Name for the operation (optional, default is None).
                              For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        reindex_src (Tensor): The source node index of graph edges after reindex.
        reindex_dst (Tensor): The destination node index of graph edges after reindex.
        out_nodes (Tensor): The index of unique input nodes and neighbors before reindex,
                            where we put the input nodes `x` in the front, and put neighbor
                            nodes in the back.

    Raises:
        ValueError: If `flag_buffer_hashtable` is True while `value_buffer`
                    or `index_buffer` is None.

    Examples:

        .. code-block:: python

            import paddle

            x = [0, 1, 2]
            neighbors_e1 = [8, 9, 0, 4, 7, 6, 7]
            count_e1 = [2, 3, 2]
            x = paddle.to_tensor(x, dtype="int64")
            neighbors_e1 = paddle.to_tensor(neighbors_e1, dtype="int64")
            count_e1 = paddle.to_tensor(count_e1, dtype="int32")

            reindex_src, reindex_dst, out_nodes = paddle.incubate.graph_reindex(
                x, neighbors_e1, count_e1)
            # reindex_src: [3, 4, 0, 5, 6, 7, 6]
            # reindex_dst: [0, 0, 1, 1, 1, 2, 2]
            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]

            neighbors_e2 = [0, 2, 3, 5, 1]
            count_e2 = [1, 3, 1]
            neighbors_e2 = paddle.to_tensor(neighbors_e2, dtype="int64")
            count_e2 = paddle.to_tensor(count_e2, dtype="int32")

            neighbors = paddle.concat([neighbors_e1, neighbors_e2])
            count = paddle.concat([count_e1, count_e2])
            reindex_src, reindex_dst, out_nodes = paddle.incubate.graph_reindex(
                x, neighbors, count)
            # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1]
            # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2]
            # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]

    """
    # Both buffers are mandatory once the hashtable-buffer path is requested.
    if flag_buffer_hashtable and (value_buffer is None
                                  or index_buffer is None):
        raise ValueError("`value_buffer` and `index_buffer` should not "
                         "be None if `flag_buffer_hashtable` is True.")

    # Outside static-graph mode (as reported by `_non_static_mode()`), call
    # the legacy C op directly and return its three outputs.
    if _non_static_mode():
        reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex(
            x, neighbors, count, value_buffer, index_buffer,
            "flag_buffer_hashtable", flag_buffer_hashtable)
        return reindex_src, reindex_dst, out_nodes

    # Static-graph path: validate input dtypes before appending the op.
    check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex")
    check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"),
                             "graph_reindex")
    # NOTE: `("int32",)` is a real one-element tuple; `("int32")` is just a
    # string and would turn the dtype membership test into a substring test.
    check_variable_and_dtype(count, "Count", ("int32",), "graph_reindex")

    if flag_buffer_hashtable:
        check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32",),
                                 "graph_reindex")
        check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32",),
                                 "graph_reindex")

    # NOTE: `locals()` is forwarded to LayerHelper, so no extra local
    # variables should be introduced above this line.
    helper = LayerHelper("graph_reindex", **locals())
    # All three outputs share the dtype of `x`.
    reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype)
    reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype)
    out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(type="graph_reindex",
                     inputs={
                         "X":
                         x,
                         "Neighbors":
                         neighbors,
                         "Count":
                         count,
                         # The buffers are only passed to the op when the
                         # hashtable-buffer path is enabled.
                         "HashTable_Value":
                         value_buffer if flag_buffer_hashtable else None,
                         "HashTable_Index":
                         index_buffer if flag_buffer_hashtable else None,
                     },
                     outputs={
                         "Reindex_Src": reindex_src,
                         "Reindex_Dst": reindex_dst,
                         "Out_Nodes": out_nodes
                     },
                     attrs={"flag_buffer_hashtable": flag_buffer_hashtable})
    return reindex_src, reindex_dst, out_nodes