graph_wrapper.py 28.4 KB
Newer Older
Y
yelrose 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This package provides interface to help building static computational graph
for PaddlePaddle.
"""

import warnings
import numpy as np
import paddle.fluid as fluid
Y
Yelrose 已提交
22
import paddle.fluid.layers as L
Y
yelrose 已提交
23 24 25 26 27 28 29 30

from pgl.utils import op
from pgl.utils import paddle_helper
from pgl.utils.logger import log

__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]


Y
Yelrose 已提交
31
def send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst):
    """Send message from src to dst.

    Gathers the node features of every edge's source and destination and
    hands them, together with the edge features, to ``message_func``.

    Args:
        src: index tensor of the source node of every edge.
        dst: index tensor of the destination node of every edge.
        nfeat: dict of node features visible on both endpoints.
        efeat: dict of edge features, passed through unchanged.
        message_func: UDF called as ``message_func(src_feat, dst_feat, efeat)``.
        nfeat_src: dict of node features visible only on the source side.
        nfeat_dst: dict of node features visible only on the destination side.

    Return:
        Whatever ``message_func`` returns.

    Entries of ``nfeat`` override same-named entries of ``nfeat_src`` and
    ``nfeat_dst``. None of the input dicts is modified.
    """
    for key in nfeat_src:
        if key in nfeat:
            log.info("Node-Feature %s both in nfeat_src_list and nfeat_list" % key)

    for key in nfeat_dst:
        if key in nfeat:
            log.info("Node-Feature %s both in nfeat_dst_list and nfeat_list" % key)

    # Merge into fresh dicts so that the caller's dicts are left untouched.
    merged_src = dict(nfeat_src)
    merged_src.update(nfeat)
    merged_dst = dict(nfeat_dst)
    merged_dst.update(nfeat)

    src_feat = op.read_rows(merged_src, src)
    dst_feat = op.read_rows(merged_dst, dst)
    msg = message_func(src_feat, dst_feat, efeat)
    return msg


W
Webbley 已提交
50 51
def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
         num_edges):
    """Recv message from given msg to dst nodes.

    Args:
        dst: destination node index of every message (one per edge).
        uniq_dst: destination node ids that receive at least one message;
            the reduced rows are scattered to these ids.
        bucketing_index: offsets that split ``msg`` into one bucket per
            entry of ``uniq_dst`` (passed to ``op.nested_lod_reset``).
        msg: message tensor, or dict of tensors for a UDF reducer.
        reduce_function: the string "sum" for the built-in scatter_add
            reduction, or a callable applied to the bucketed messages.
        num_nodes: number of nodes; the first dimension of the output.
        num_edges: number of real edges. When it is 0 the output is
            multiplied by 0, so every row is zero.

    Return:
        A tensor with shape (num_nodes, out_dim); nodes that received no
        message get zero rows.

    Raises:
        TypeError: if "sum" is requested with a dict message.
        ValueError: if ``reduce_function`` is a string other than "sum".
    """
    # Reject unknown built-in names early instead of failing later with an
    # obscure "'str' object is not callable" error.
    if isinstance(reduce_function, str) and reduce_function != "sum":
        raise ValueError("Unsupported built-in reduce_function: %r" %
                         reduce_function)

    if reduce_function == "sum":
        if isinstance(msg, dict):
            raise TypeError("The message for build-in function"
                            " should be Tensor not dict.")

        try:
            out_dim = msg.shape[-1]
            init_output = L.fill_constant(
                shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
            init_output.stop_gradient = False
            # Zero every message when the graph has no real edges.
            empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
            msg = msg * empty_msg_flag
            output = paddle_helper.scatter_add(init_output, dst, msg)
            return output
        except TypeError:
            warnings.warn(
                "scatter_add is not supported with paddle version <= 1.5")

            def sum_func(message):
                return L.sequence_pool(message, "sum")

            # Fall through to the generic bucketed path below.
            reduce_function = sum_func

    bucketed_msg = op.nested_lod_reset(msg, bucketing_index)
    output = reduce_function(bucketed_msg)
    output_dim = output.shape[-1]

    empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
    output = output * empty_msg_flag

    init_output = L.fill_constant(
        shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
    init_output.stop_gradient = True
    # One reduced row per unique dst; nodes without messages stay zero.
    final_output = L.scatter(init_output, uniq_dst, output)
    return final_output
Y
yelrose 已提交
89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106


class BaseGraphWrapper(object):
    """This module implements the base class for graph wrappers.

    Currently our PGL is developed based on static computational mode of
    paddle (we'll support dynamic computational model later). We need to build
    model upon a virtual data holder. BaseGraphWrapper provides a virtual
    graph structure that users can build deep learning models
    based on this virtual graph. And then feed real graph data to run
    the models. Moreover, we provide convenient message-passing interface
    (send & recv) for building graph neural networks.

    NOTICE: Don't use this BaseGraphWrapper directly. Use :code:`GraphWrapper`
    and :code:`StaticGraphWrapper` to create graph wrapper instead.
    """

    def __init__(self):
        # Feature name -> tensor, filled in by the subclasses.
        self.node_feat_tensor_dict = {}
        self.edge_feat_tensor_dict = {}
        # Graph-structure tensors; subclasses assign them.
        self._edges_src = None
        self._edges_dst = None
        self._num_nodes = None
        self._indegree = None
        self._edge_uniq_dst = None
        self._edge_uniq_dst_count = None
        self._node_ids = None
        self._graph_lod = None
        self._num_graph = None
        self._num_edges = None
        self._data_name_prefix = ""

    def __repr__(self):
        return self._data_name_prefix

    @staticmethod
    def _collect_feats(feat_list, feat_store):
        """Resolve a feature spec list into a new ``{name: tensor}`` dict.

        Args:
            feat_list: ``None`` or an iterable whose items are either a
                feature name (looked up in ``feat_store``) or a
                ``(name, tensor)`` tuple used as-is.
            feat_store: dict mapping registered feature names to tensors.

        Return:
            A new dict mapping each requested feature name to its tensor.
        """
        feats = {}
        if feat_list is None:
            return feats
        for feat in feat_list:
            if isinstance(feat, str):
                feats[feat] = feat_store[feat]
            else:
                name, tensor = feat
                feats[name] = tensor
        return feats

    def send(self,
             message_func,
             nfeat_list=None,
             efeat_list=None,
             nfeat_list_src=None,
             nfeat_list_dst=None):
        """Send message from all src nodes to dst nodes.

        The UDF message function should have the following format.

        .. code-block:: python

            def message_func(src_feat, dst_feat, edge_feat):
                '''
                    Args:
                        src_feat: the node feat dict attached to the src nodes.
                        dst_feat: the node feat dict attached to the dst nodes.
                        edge_feat: the edge feat dict attached to the
                                   corresponding (src, dst) edges.

                    Return:
                        It should return a tensor or a dictionary of tensor. And each tensor
                        should have a shape of (num_edges, dims).
                '''
                pass

        Args:
            message_func: UDF function.
            nfeat_list: a list of names or tuple (name, tensor)
            efeat_list: a list of names or tuple (name, tensor)
            nfeat_list_src: a list of names or tuple (name, tensor). The node feature only for src
            nfeat_list_dst: a list of names or tuple (name, tensor). The node feature only for dst

        Return:
            A dictionary of tensor representing the message. Each of the values
            in the dictionary has a shape (num_edges, dim) which should be collected
            by :code:`recv` function.
        """
        src, dst = self.edges
        nfeat = self._collect_feats(nfeat_list, self.node_feat)
        nfeat_src = self._collect_feats(nfeat_list_src, self.node_feat)
        nfeat_dst = self._collect_feats(nfeat_list_dst, self.node_feat)
        efeat = self._collect_feats(efeat_list, self.edge_feat)

        # ``send`` here is the module-level function, not this method.
        msg = send(src, dst, nfeat, efeat, message_func, nfeat_src, nfeat_dst)
        return msg

    def recv(self, msg, reduce_function):
        """Recv message and aggregate the message by reduce_function.

        The UDF reduce_function function should have the following format.

        .. code-block:: python

            def reduce_func(msg):
                '''
                    Args:
                        msg: A LodTensor or a dictionary of LodTensor whose batch_size
                             is equals to the number of unique dst nodes.

                    Return:
                        It should return a tensor with shape (batch_size, out_dims). The
                        batch size should be the same as msg.
                '''
                pass

        Args:
            msg: A tensor or a dictionary of tensor created by send function.

            reduce_function: UDF reduce function or strings "sum" as built-in function.
                             The built-in "sum" will use scatter_add to optimized the speed.

        Return:
            A tensor with shape (num_nodes, out_dims). The output for nodes with no message
            will be zeros.
        """
        output = recv(
            dst=self._edges_dst,
            uniq_dst=self._edge_uniq_dst,
            bucketing_index=self._edge_uniq_dst_count,
            msg=msg,
            reduce_function=reduce_function,
            num_edges=self._num_edges,
            num_nodes=self._num_nodes)
        return output

    @property
    def edges(self):
        """Return a tuple of edge Tensor (src, dst).

        Return:
            A tuple of Tensor (src, dst). Src and dst are both
            tensor with shape (num_edges, ) and dtype int64.
        """
        return self._edges_src, self._edges_dst

    @property
    def num_nodes(self):
        """Return a variable of number of nodes

        Return:
            A variable with shape (1,) as the number of nodes in int64.
        """
        return self._num_nodes

    @property
    def graph_lod(self):
        """Return graph index for graphs

        Return:
            A variable with shape [None ]  as the Lod information of multiple-graph.
        """
        return self._graph_lod

    @property
    def num_graph(self):
        """Return a variable of number of graphs

        Return:
            A variable with shape (1,) as the number of Graphs in int64.
        """
        return self._num_graph

    @property
    def edge_feat(self):
        """Return a dictionary of tensor representing edge features.

        Return:
            A dictionary whose keys are the feature names and the values
            are feature tensor.
        """
        return self.edge_feat_tensor_dict

    @property
    def node_feat(self):
        """Return a dictionary of tensor representing node features.

        Return:
            A dictionary whose keys are the feature names and the values
            are feature tensor.
        """
        return self.node_feat_tensor_dict

    def indegree(self):
        """Return the indegree tensor for all nodes.

        Return:
            A tensor of shape (num_nodes, ) in int64.
        """
        return self._indegree


class StaticGraphWrapper(BaseGraphWrapper):
    """Implement a graph wrapper that the data of the graph won't
    be changed and it can be fit into the GPU or CPU memory. This
    can reduce the time of swapping large data from GPU and CPU.

    Args:
        name: The graph data prefix

        graph: The static graph that should be put into memory

        place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
               device to hold the graph data.

    Examples:

        If we have an immutable graph and it can be fit into the GPU or CPU,
        we can just use a :code:`StaticGraphWrapper` to pre-place the graph
        data into devices.

        .. code-block:: python

            import numpy as np
            import paddle.fluid as fluid
            from pgl.graph import Graph
            from pgl.graph_wrapper import StaticGraphWrapper

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)

            num_nodes = 5
            edges = [ (0, 1), (1, 2), (3, 4)]
            feature = np.random.randn(5, 100)
            edge_feature = np.random.randn(3, 100)
            graph = Graph(num_nodes=num_nodes,
                        edges=edges,
                        node_feat={
                            "feature": feature
                        },
                        edge_feat={
                            "edge_feature": edge_feature
                        })

            graph_wrapper = StaticGraphWrapper(name="graph",
                        graph=graph,
                        place=place)

            # build your deep graph model

            # Initialize parameters for deep graph model
            exe.run(fluid.default_startup_program())

            # Initialize graph data
            graph_wrapper.initialize(place)
    """

    def __init__(self, name, graph, place):
        super(StaticGraphWrapper, self).__init__()
        self._data_name_prefix = name
        # Initializers returned by paddle_helper.constant; ``initialize``
        # calls each of them with the target place.
        self._initializers = []
        self.__create_graph_attr(graph)

    def __add_constant(self, suffix, dtype, value, **kwargs):
        """Create a constant named ``<prefix><suffix>`` via
        ``paddle_helper.constant`` and record its initializer.
        """
        var, init = paddle_helper.constant(
            name=self._data_name_prefix + suffix,
            dtype=dtype,
            value=value,
            **kwargs)
        self._initializers.append(init)
        return var

    def __create_graph_attr(self, graph):
        """Create graph attributes for paddlepaddle.

        Edges are sorted by destination, and ``uniq_dst_count`` stores the
        offsets ``[0, c0, c0 + c1, ...]`` of the runs of equal destinations.
        """
        src, dst, eid = graph.sorted_edges(sort_by="dst")
        indegree = graph.indegree()
        nodes = graph.nodes
        uniq_dst = nodes[indegree > 0]
        uniq_dst_count = indegree[indegree > 0]
        uniq_dst_count = np.cumsum(uniq_dst_count, dtype='int32')
        uniq_dst_count = np.insert(uniq_dst_count, 0, 0)
        graph_lod = graph.graph_lod
        num_graph = graph.num_graph

        num_edges = len(src)
        if num_edges == 0:
            # Fake Graph: one placeholder edge keeps the edge tensors
            # non-empty, while ``num_edges`` stays 0 so ``recv`` zeroes
            # the aggregated output.
            src = np.array([0], dtype="int64")
            dst = np.array([0], dtype="int64")
            uniq_dst_count = np.array([0, 1], dtype="int32")
            uniq_dst = np.array([0], dtype="int64")

        edge_feat = {}
        for key, value in graph.edge_feat.items():
            if num_edges == 0:
                # ``value`` has no rows to index, so give the placeholder
                # edge an all-zero feature row.
                edge_feat[key] = np.zeros(
                    (1, ) + value.shape[1:], dtype=value.dtype)
            else:
                edge_feat[key] = value[eid]
        node_feat = graph.node_feat

        self.__create_graph_node_feat(node_feat, self._initializers)
        self.__create_graph_edge_feat(edge_feat, self._initializers)

        self._num_edges = self.__add_constant(
            '/num_edges', "int64", np.array(
                [num_edges], dtype="int64"))
        self._num_graph = self.__add_constant(
            '/num_graph', "int64", np.array(
                [num_graph], dtype="int64"))
        self._edges_src = self.__add_constant('/edges_src', "int64", src)
        self._edges_dst = self.__add_constant('/edges_dst', "int64", dst)
        self._num_nodes = self.__add_constant(
            '/num_nodes',
            "int64",
            np.array(
                [graph.num_nodes], dtype="int64"),
            hide_batch_size=False)
        self._edge_uniq_dst = self.__add_constant("/uniq_dst", "int64",
                                                  uniq_dst)
        self._edge_uniq_dst_count = self.__add_constant(
            "/uniq_dst_count", "int32", uniq_dst_count)
        self._graph_lod = self.__add_constant("/graph_lod", "int32",
                                              graph_lod)
        self._node_ids = self.__add_constant(
            "/node_ids", "int64",
            np.arange(
                0, graph.num_nodes, dtype="int64"))
        self._indegree = self.__add_constant("/indegree", "int64", indegree)

    def __create_graph_node_feat(self, node_feat, collector):
        """Convert node features into paddlepaddle tensors.

        Each feature becomes a constant named ``<prefix>/node_feat/<name>``
        with the feature's own dtype; its initializer goes into ``collector``.
        """
        for node_feat_name, node_feat_value in node_feat.items():
            tensor, init = paddle_helper.constant(
                name=self._data_name_prefix + '/node_feat/' + node_feat_name,
                dtype=node_feat_value.dtype,
                value=node_feat_value)
            self.node_feat_tensor_dict[node_feat_name] = tensor
            collector.append(init)

    def __create_graph_edge_feat(self, edge_feat, collector):
        """Convert edge features into paddlepaddle tensors.

        Each feature becomes a constant named ``<prefix>/edge_feat/<name>``
        with the feature's own dtype; its initializer goes into ``collector``.
        """
        for edge_feat_name, edge_feat_value in edge_feat.items():
            tensor, init = paddle_helper.constant(
                name=self._data_name_prefix + '/edge_feat/' + edge_feat_name,
                dtype=edge_feat_value.dtype,
                value=edge_feat_value)
            self.edge_feat_tensor_dict[edge_feat_name] = tensor
            collector.append(init)

    def initialize(self, place):
        """Placing the graph data into the devices.

        Args:
            place: fluid.CPUPlace or fluid.CUDAPlace(n) indicating the
                   device to hold the graph data.
        """
        log.info(
            "StaticGraphWrapper.initialize must be called after startup program"
        )
        for init_func in self._initializers:
            init_func(place)


class GraphWrapper(BaseGraphWrapper):
    """Implement a graph wrapper that creates a graph data holders
    that attributes and features in the graph are :code:`L.data`.
    And we provide interface :code:`to_feed` to help converting :code:`Graph`
    data into :code:`feed_dict`.

    Args:
        name: The graph data prefix

        node_feat: A list of tuples that describe the details of node
                   feature tensor. Each tuple must be (name, shape, dtype)
                   and the first dimension of the shape must be set unknown
                   (-1 or None) or we can easily use :code:`Graph.node_feat_info()`
                   to get the node_feat settings.

        edge_feat: A list of tuples that describe the details of edge
                   feature tensor. Each tuple must be (name, shape, dtype)
                   and the first dimension of the shape must be set unknown
                   (-1 or None) or we can easily use :code:`Graph.edge_feat_info()`
                   to get the edge_feat settings.

    Examples:

        .. code-block:: python

            import numpy as np
            import paddle.fluid as fluid
            from pgl.graph import Graph
            from pgl.graph_wrapper import GraphWrapper

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)

            num_nodes = 5
            edges = [ (0, 1), (1, 2), (3, 4)]
            feature = np.random.randn(5, 100)
            edge_feature = np.random.randn(3, 100)
            graph = Graph(num_nodes=num_nodes,
                        edges=edges,
                        node_feat={
                            "feature": feature
                        },
                        edge_feat={
                            "edge_feature": edge_feature
                        })

            graph_wrapper = GraphWrapper(name="graph",
                        node_feat=graph.node_feat_info(),
                        edge_feat=graph.edge_feat_info())

            # build your deep graph model
            ...

            # Initialize parameters for deep graph model
            exe.run(fluid.default_startup_program())

            for i in range(10):
                feed_dict = graph_wrapper.to_feed(graph)
                ret = exe.run(fetch_list=[...], feed=feed_dict )
    """

    def __init__(self, name, node_feat=None, edge_feat=None, **kwargs):
        # ``**kwargs`` is accepted and ignored.
        super(GraphWrapper, self).__init__()
        if node_feat is None:
            node_feat = []
        if edge_feat is None:
            edge_feat = []
        # collect holders for PyReader
        self._data_name_prefix = name
        self._holder_list = []
        self.__create_graph_attr_holders()
        for node_feat_name, node_feat_shape, node_feat_dtype in node_feat:
            self.__create_graph_node_feat_holders(
                node_feat_name, node_feat_shape, node_feat_dtype)

        for edge_feat_name, edge_feat_shape, edge_feat_dtype in edge_feat:
            self.__create_graph_edge_feat_holders(
                edge_feat_name, edge_feat_shape, edge_feat_dtype)

    def __create_holder(self, suffix, shape, dtype):
        """Declare an ``L.data`` holder named ``<prefix><suffix>`` with
        ``append_batch_size=False`` and ``stop_gradient=True``.
        """
        return L.data(
            self._data_name_prefix + suffix,
            shape=shape,
            append_batch_size=False,
            dtype=dtype,
            stop_gradient=True)

    def __create_graph_attr_holders(self):
        """Create data holders for graph attributes.
        """
        self._num_edges = self.__create_holder('/num_edges', [1], "int64")
        self._num_graph = self.__create_holder('/num_graph', [1], "int64")
        self._edges_src = self.__create_holder('/edges_src', [None], "int64")
        self._edges_dst = self.__create_holder('/edges_dst', [None], "int64")
        self._num_nodes = self.__create_holder('/num_nodes', [1], 'int64')

        self._edge_uniq_dst = self.__create_holder("/uniq_dst", [None],
                                                   "int64")

        self._graph_lod = self.__create_holder("/graph_lod", [None], "int32")

        self._edge_uniq_dst_count = self.__create_holder("/uniq_dst_count",
                                                         [None], "int32")

        self._node_ids = self.__create_holder("/node_ids", [None], "int64")
        self._indegree = self.__create_holder("/indegree", [None], "int64")

        # The holder order here differs from the creation order above;
        # keep it stable because it is exposed through ``holder_list``.
        self._holder_list.extend([
            self._edges_src,
            self._edges_dst,
            self._num_nodes,
            self._edge_uniq_dst,
            self._edge_uniq_dst_count,
            self._node_ids,
            self._indegree,
            self._graph_lod,
            self._num_graph,
            self._num_edges,
        ])

    def __create_graph_node_feat_holders(self, node_feat_name, node_feat_shape,
                                         node_feat_dtype):
        """Create data holders for node features.
        """
        feat_holder = self.__create_holder('/node_feat/' + node_feat_name,
                                           node_feat_shape, node_feat_dtype)
        self.node_feat_tensor_dict[node_feat_name] = feat_holder
        self._holder_list.append(feat_holder)

    def __create_graph_edge_feat_holders(self, edge_feat_name, edge_feat_shape,
                                         edge_feat_dtype):
        """Create edge holders for edge features.
        """
        feat_holder = self.__create_holder('/edge_feat/' + edge_feat_name,
                                           edge_feat_shape, edge_feat_dtype)
        self.edge_feat_tensor_dict[edge_feat_name] = feat_holder
        self._holder_list.append(feat_holder)

    def to_feed(self, graph):
        """Convert the graph into feed_dict.

        This function helps to convert graph data into feed dict
        for :code:`fluid.Executor` to run the model.

        Args:
            graph: the :code:`Graph` data object

        Return:
            A dictionary contains data holder names and its corresponding
            data.
        """
        feed_dict = {}
        src, dst, eid = graph.sorted_edges(sort_by="dst")
        indegree = graph.indegree()
        nodes = graph.nodes
        num_edges = len(src)
        uniq_dst = nodes[indegree > 0]
        uniq_dst_count = indegree[indegree > 0]
        # Offsets [0, c0, c0 + c1, ...] of the runs of equal destinations.
        uniq_dst_count = np.cumsum(uniq_dst_count, dtype='int32')
        uniq_dst_count = np.insert(uniq_dst_count, 0, 0)
        num_graph = graph.num_graph
        graph_lod = graph.graph_lod

        if num_edges == 0:
            # Fake Graph: one placeholder edge keeps the edge tensors
            # non-empty, while the fed num_edges stays 0 so ``recv``
            # zeroes the aggregated output.
            src = np.array([0], dtype="int64")
            dst = np.array([0], dtype="int64")

            uniq_dst_count = np.array([0, 1], dtype="int32")
            uniq_dst = np.array([0], dtype="int64")

        edge_feat = {}
        for key, value in graph.edge_feat.items():
            if num_edges == 0:
                # ``value`` has no rows to index, so give the placeholder
                # edge an all-zero feature row.
                edge_feat[key] = np.zeros(
                    (1, ) + value.shape[1:], dtype=value.dtype)
            else:
                edge_feat[key] = value[eid]
        node_feat = graph.node_feat

        prefix = self._data_name_prefix
        feed_dict[prefix + '/num_edges'] = np.array(
            [num_edges], dtype="int64")
        feed_dict[prefix + '/edges_src'] = src
        feed_dict[prefix + '/edges_dst'] = dst
        feed_dict[prefix + '/num_nodes'] = np.array(
            [graph.num_nodes], dtype="int64")
        feed_dict[prefix + '/uniq_dst'] = uniq_dst
        feed_dict[prefix + '/uniq_dst_count'] = uniq_dst_count
        feed_dict[prefix + '/node_ids'] = graph.nodes
        feed_dict[prefix + '/indegree'] = indegree
        feed_dict[prefix + '/graph_lod'] = graph_lod
        feed_dict[prefix + '/num_graph'] = np.array(
            [num_graph], dtype="int64")

        for key in self.node_feat_tensor_dict:
            feed_dict[prefix + '/node_feat/' + key] = node_feat[key]

        for key in self.edge_feat_tensor_dict:
            feed_dict[prefix + '/edge_feat/' + key] = edge_feat[key]

        return feed_dict

    @property
    def holder_list(self):
        """Return the holder list.
        """
        return self._holder_list
Y
Yelrose 已提交
763 764 765 766 767 768 769 770 771 772 773 774 775 776


def get_degree(edge, num_nodes):
    """Count how many times every node id occurs in ``edge``.

    Args:
        edge: tensor of node ids, e.g. the destination of every edge.
        num_nodes: number of nodes; the length of the returned tensor.

    Return:
        A float32 tensor of shape (num_nodes, ) holding the per-node counts.
    """
    degree = L.fill_constant(shape=[num_nodes], value=0, dtype="float32")
    degree.stop_gradient = True
    ones = L.full_like(edge, 1, dtype="float32")
    # overwrite=False accumulates the updates of repeated indices, so each
    # occurrence of a node id adds 1 to that node's entry.
    return L.scatter(degree, edge, ones, overwrite=False)

class DropEdgeWrapper(BaseGraphWrapper):
    """Implementation of edge dropout on top of another graph wrapper.

    Each edge of ``graph_wrapper`` is kept only if a uniform random draw
    between 0 and 1 is greater than ``dropout``. Node features,
    ``num_nodes``, ``graph_lod`` and ``num_graph`` are shared with the
    wrapped graph; edge features are filtered with the same mask as edges.

    Args:
        graph_wrapper: the wrapper whose edges are dropped.
        dropout: threshold compared against the per-edge random draw.
        keep_self_loop: if True, edges with ``src == dst`` are always kept.
    """

    def __init__(self, graph_wrapper, dropout, keep_self_loop=True):
        super(DropEdgeWrapper, self).__init__()

        # Copy Node's information
        for key, value in graph_wrapper.node_feat.items():
            self.node_feat_tensor_dict[key] = value

        self._num_nodes = graph_wrapper.num_nodes
        self._graph_lod = graph_wrapper.graph_lod
        self._num_graph = graph_wrapper.num_graph
        # NOTE(review): node ids are int32 here while the other wrappers
        # use int64 node ids — confirm downstream code accepts both.
        self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")

        # Dropout Edges
        src, dst = graph_wrapper.edges
        u = L.uniform_random(
            shape=L.cast(L.shape(src), 'int64'), min=0., max=1.)
        keeped = L.cast(u > dropout, dtype="float32")
        if keep_self_loop:
            # Self loops survive regardless of their random draw.
            keeped = keeped + L.cast(src == dst, dtype="float32")
        # Normalise to a 0/1 mask (a surviving self loop would be 2 here).
        keeped = L.cast(keeped > 0.5, dtype="float32")

        # Count the edges that are really kept. When the wrapped graph has
        # no edges it carries a single placeholder edge (its _num_edges is
        # 0); that edge must never be counted as real.
        source_has_edges = L.cast(
            graph_wrapper._num_edges > 0, dtype="int32")
        self._num_edges = L.reduce_sum(L.cast(keeped,
                                              "int32")) * source_has_edges

        # Avoid Empty Edges: if nothing is kept, keep every edge so the edge
        # tensors stay non-empty. Because _num_edges is 0 in that case,
        # ``recv`` multiplies the aggregated output by zero.
        keeped = keeped + L.cast(self._num_edges == 0, dtype="float32")
        keeped = (keeped > 0.5)

        src = paddle_helper.masked_select(src, keeped)
        dst = paddle_helper.masked_select(dst, keeped)
        src.stop_gradient = True
        dst.stop_gradient = True
        self._edges_src = src
        self._edges_dst = dst

        for key, value in graph_wrapper.edge_feat.items():
            self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(
                value, keeped)

        self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(
            dst, dtype="int32")
        self._edge_uniq_dst.stop_gradient = True
        # Bucket offsets: exclusive prefix sum of the per-dst counts followed
        # by the total, i.e. [0, c0, c0 + c1, ..., sum].
        last = L.reduce_sum(uniq_count, keep_dim=True)
        uniq_count = L.cumsum(uniq_count, exclusive=True)
        self._edge_uniq_dst_count = L.concat([uniq_count, last])
        self._edge_uniq_dst_count.stop_gradient = True
        # NOTE(review): get_degree returns float32, whereas the other
        # wrappers expose an int64 indegree.
        self._indegree = get_degree(self._edges_dst, self._num_nodes)