diff --git a/python/paddle/geometric/math.py b/python/paddle/geometric/math.py index 7a6db7d10aa9919548eca00e196fc529698dd8f8..22b0f32114a2213cb4e5768fe5db26cecac7fed1 100644 --- a/python/paddle/geometric/math.py +++ b/python/paddle/geometric/math.py @@ -32,16 +32,15 @@ def segment_sum(data, segment_ids, name=None): Args: data (Tensor): A tensor, available data type float32, float64, int32, int64, float16. segment_ids (Tensor): A 1-D tensor, which have the same size - with the first dimension of input data. + with the first dimension of input data. Available data type is int32, int64. - name (str, optional): Name for the operation (optional, default is None). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): the reduced result. + - output (Tensor), the reduced result. Examples: - .. code-block:: python import paddle @@ -54,29 +53,30 @@ def segment_sum(data, segment_ids, name=None): if in_dygraph_mode(): return _C_ops.segment_pool(data, segment_ids, "SUM")[0] if _in_legacy_dygraph(): - out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', - "SUM") + out, tmp = _legacy_C_ops.segment_pool( + data, segment_ids, 'pooltype', "SUM" + ) return out check_variable_and_dtype( - data, "X", ("float32", "float64", "int32", "int64", "float16"), - "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") + data, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "segment_pool", + ) + check_variable_and_dtype( + segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool" + ) helper = LayerHelper("segment_sum", **locals()) out = helper.create_variable_for_type_inference(dtype=data.dtype) summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op(type="segment_pool", - inputs={ - "X": data, - "SegmentIds": segment_ids - }, - outputs={ - "Out": out, - "SummedIds": summed_ids - }, - attrs={"pooltype": "SUM"}) + helper.append_op( + type="segment_pool", + inputs={"X": data, "SegmentIds": segment_ids}, + outputs={"Out": out, "SummedIds": summed_ids}, + attrs={"pooltype": "SUM"}, + ) return out @@ -84,7 +84,7 @@ def segment_mean(data, segment_ids, name=None): r""" Segment mean Operator. - Ihis operator calculate the mean value of input `data` which + This operator calculate the mean value of input `data` which with the same index in `segment_ids`. It computes a tensor such that $out_i = \\frac{1}{n_i} \\sum_{j} data[j]$ where sum is over j such that 'segment_ids[j] == i' and $n_i$ is the number @@ -92,17 +92,16 @@ def segment_mean(data, segment_ids, name=None): Args: data (tensor): a tensor, available data type float32, float64, int32, int64, float16. - segment_ids (tensor): a 1-d tensor, which have the same size - with the first dimension of input data. + segment_ids (tensor): a 1-d tensor, which have the same size + with the first dimension of input data. available data type is int32, int64. - name (str, optional): Name for the operation (optional, default is None). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): the reduced result. + - output (Tensor), the reduced result. Examples: - .. code-block:: python import paddle @@ -116,29 +115,30 @@ def segment_mean(data, segment_ids, name=None): if in_dygraph_mode(): return _C_ops.segment_pool(data, segment_ids, "MEAN")[0] if _in_legacy_dygraph(): - out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', - "MEAN") + out, tmp = _legacy_C_ops.segment_pool( + data, segment_ids, 'pooltype', "MEAN" + ) return out check_variable_and_dtype( - data, "X", ("float32", "float64", "int32", "int64", "float16"), - "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") + data, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "segment_pool", + ) + check_variable_and_dtype( + segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool" + ) helper = LayerHelper("segment_mean", **locals()) out = helper.create_variable_for_type_inference(dtype=data.dtype) summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op(type="segment_pool", - inputs={ - "X": data, - "SegmentIds": segment_ids - }, - outputs={ - "Out": out, - "SummedIds": summed_ids - }, - attrs={"pooltype": "MEAN"}) + helper.append_op( + type="segment_pool", + inputs={"X": data, "SegmentIds": segment_ids}, + outputs={"Out": out, "SummedIds": summed_ids}, + attrs={"pooltype": "MEAN"}, + ) return out @@ -154,16 +154,15 @@ def segment_min(data, segment_ids, name=None): Args: data (tensor): a tensor, available data type float32, float64, int32, int64, float16. segment_ids (tensor): a 1-d tensor, which have the same size - with the first dimension of input data. + with the first dimension of input data. available data type is int32, int64. - name (str, optional): Name for the operation (optional, default is None). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): the reduced result. + - output (Tensor), the reduced result. Examples: - .. code-block:: python import paddle @@ -177,29 +176,30 @@ def segment_min(data, segment_ids, name=None): if in_dygraph_mode(): return _C_ops.segment_pool(data, segment_ids, "MIN")[0] if _in_legacy_dygraph(): - out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', - "MIN") + out, tmp = _legacy_C_ops.segment_pool( + data, segment_ids, 'pooltype', "MIN" + ) return out check_variable_and_dtype( - data, "X", ("float32", "float64", "int32", "int64", "float16"), - "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") + data, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "segment_pool", + ) + check_variable_and_dtype( + segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool" + ) helper = LayerHelper("segment_min", **locals()) out = helper.create_variable_for_type_inference(dtype=data.dtype) summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op(type="segment_pool", - inputs={ - "X": data, - "SegmentIds": segment_ids - }, - outputs={ - "Out": out, - "SummedIds": summed_ids - }, - attrs={"pooltype": "MIN"}) + helper.append_op( + type="segment_pool", + inputs={"X": data, "SegmentIds": segment_ids}, + outputs={"Out": out, "SummedIds": summed_ids}, + attrs={"pooltype": "MIN"}, + ) return out @@ -215,16 +215,15 @@ def segment_max(data, segment_ids, name=None): Args: data (tensor): a tensor, available data type float32, float64, int32, int64, float16. segment_ids (tensor): a 1-d tensor, which have the same size - with the first dimension of input data. + with the first dimension of input data. available data type is int32, int64. - name (str, optional): Name for the operation (optional, default is None). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - output (Tensor): the reduced result. + - output (Tensor), the reduced result. Examples: - .. code-block:: python import paddle @@ -238,27 +237,28 @@ def segment_max(data, segment_ids, name=None): if in_dygraph_mode(): return _C_ops.segment_pool(data, segment_ids, "MAX")[0] if _in_legacy_dygraph(): - out, tmp = _legacy_C_ops.segment_pool(data, segment_ids, 'pooltype', - "MAX") + out, tmp = _legacy_C_ops.segment_pool( + data, segment_ids, 'pooltype', "MAX" + ) return out check_variable_and_dtype( - data, "X", ("float32", "float64", "int32", "int64", "float16"), - "segment_pool") - check_variable_and_dtype(segment_ids, "SegmentIds", ("int32", "int64"), - "segment_pool") + data, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "segment_pool", + ) + check_variable_and_dtype( + segment_ids, "SegmentIds", ("int32", "int64"), "segment_pool" + ) helper = LayerHelper("segment_max", **locals()) out = helper.create_variable_for_type_inference(dtype=data.dtype) summed_ids = helper.create_variable_for_type_inference(dtype=data.dtype) - helper.append_op(type="segment_pool", - inputs={ - "X": data, - "SegmentIds": segment_ids - }, - outputs={ - "Out": out, - "SummedIds": summed_ids - }, - attrs={"pooltype": "MAX"}) + helper.append_op( + type="segment_pool", + inputs={"X": data, "SegmentIds": segment_ids}, + outputs={"Out": out, "SummedIds": summed_ids}, + attrs={"pooltype": "MAX"}, + ) return out diff --git a/python/paddle/geometric/message_passing/send_recv.py b/python/paddle/geometric/message_passing/send_recv.py index 03a272aa6af08fc9d8e92eca54882901251e1a3b..b397a863702b7a94fc59c49df1832eb50b134b2c 100644 --- a/python/paddle/geometric/message_passing/send_recv.py +++ b/python/paddle/geometric/message_passing/send_recv.py @@ -14,29 +14,38 @@ import numpy as np from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.framework import _non_static_mode, _in_legacy_dygraph, in_dygraph_mode +from paddle.fluid.framework import ( + _non_static_mode, + _in_legacy_dygraph, + in_dygraph_mode, +) from paddle.fluid.framework import Variable -from paddle.fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, convert_dtype +from paddle.fluid.data_feeder import ( + check_variable_and_dtype, + check_type, + check_dtype, + convert_dtype, +) from paddle import _C_ops, _legacy_C_ops -from .utils import convert_out_size_to_list, get_out_size_tensor_inputs, reshape_lhs_rhs +from .utils import ( + convert_out_size_to_list, + get_out_size_tensor_inputs, + reshape_lhs_rhs, +) __all__ = [] -def send_u_recv(x, - src_index, - dst_index, - reduce_op="sum", - out_size=None, - name=None): +def send_u_recv( + x, src_index, dst_index, reduce_op="sum", out_size=None, name=None +): """ - Graph Learning message passing api. - This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` - to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor + to gather the corresponding data, and then use `dst_index` to update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. .. code-block:: text @@ -65,21 +74,20 @@ def send_u_recv(x, x (Tensor): The input tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. - dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. - The available data type is int32, int64. + dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. + The available data type is int32, int64. reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`. Default value is `sum`. - out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or + out_size (int|Tensor|None): We can set `out_size` to get necessary output shape. If not set or out_size is smaller or equal to 0, then this input will not be used. - Otherwise, `out_size` should be equal with or larger than + Otherwise, `out_size` should be equal with or larger than max(dst_index) + 1. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`. - If `out_size` is set correctly, then it should have the same shape as `x` except - the 0th dimension. + - out (Tensor), the output tensor, should have the same shape and same dtype as input tensor `x`. + If `out_size` is set correctly, then it should have the same shape as `x` except the 0th dimension. Examples: .. code-block:: python @@ -110,74 +118,93 @@ def send_u_recv(x, if reduce_op not in ["sum", "mean", "max", "min"]: raise ValueError( "reduce_op should be `sum`, `mean`, `max` or `min`, but received %s" - % reduce_op) + % reduce_op + ) # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) - out, tmp = _legacy_C_ops.graph_send_recv(x, src_index, dst_index, - None, 'reduce_op', - reduce_op.upper(), 'out_size', - out_size) + out, tmp = _legacy_C_ops.graph_send_recv( + x, + src_index, + dst_index, + None, + 'reduce_op', + reduce_op.upper(), + 'out_size', + out_size, + ) return out if in_dygraph_mode(): out_size = convert_out_size_to_list(out_size) - return _C_ops.graph_send_recv(x, src_index, dst_index, - reduce_op.upper(), out_size) + return _C_ops.graph_send_recv( + x, src_index, dst_index, reduce_op.upper(), out_size + ) check_variable_and_dtype( - x, "X", ("float32", "float64", "int32", "int64", "float16"), - "graph_send_recv") - check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), - "graph_send_recv") - check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), - "graph_send_recv") + x, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "graph_send_recv", + ) + check_variable_and_dtype( + src_index, "Src_index", ("int32", "int64"), "graph_send_recv" + ) + check_variable_and_dtype( + dst_index, "Dst_index", ("int32", "int64"), "graph_send_recv" + ) if out_size: - check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), - 'graph_send_recv') + check_type( + out_size, + 'out_size', + (int, np.int32, np.int64, Variable), + 'graph_send_recv', + ) if isinstance(out_size, Variable): - check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], - 'graph_send_recv') + check_dtype( + out_size.dtype, 'out_size', ['int32', 'int64'], 'graph_send_recv' + ) helper = LayerHelper("send_u_recv", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) - dst_count = helper.create_variable_for_type_inference(dtype="int32", - stop_gradient=True) + dst_count = helper.create_variable_for_type_inference( + dtype="int32", stop_gradient=True + ) inputs = {"X": x, "Src_index": src_index, "Dst_index": dst_index} attrs = {"reduce_op": reduce_op.upper()} - get_out_size_tensor_inputs(inputs=inputs, - attrs=attrs, - out_size=out_size, - op_type='graph_send_recv') - - helper.append_op(type="graph_send_recv", - inputs=inputs, - outputs={ - "Out": out, - "Dst_count": dst_count - }, - attrs=attrs) + get_out_size_tensor_inputs( + inputs=inputs, attrs=attrs, out_size=out_size, op_type='graph_send_recv' + ) + + helper.append_op( + type="graph_send_recv", + inputs=inputs, + outputs={"Out": out, "Dst_count": dst_count}, + attrs=attrs, + ) return out -def send_ue_recv(x, - y, - src_index, - dst_index, - message_op="add", - reduce_op="sum", - out_size=None, - name=None): +def send_ue_recv( + x, + y, + src_index, + dst_index, + message_op="add", + reduce_op="sum", + out_size=None, + name=None, +): """ Graph Learning message passing api. - This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory consumption in the process of message passing. Take `x` as the input tensor, we first use `src_index` - to gather the corresponding data, after computing with `y` in different message ops like add/sub/mul/div, then use `dst_index` to - update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min. + to gather the corresponding data, after computing with `y` in different message ops like add/sub/mul/div, then use `dst_index` to + update the corresponding position of output tensor in different reduce ops, like sum, mean, max, or min. Besides, we can use `out_size` to set necessary output shape. .. code-block:: text @@ -205,13 +232,14 @@ def send_ue_recv(x, out = [[1, 3, 4], [4, 10, 12], [2, 5, 6]] + Args: x (Tensor): The input node feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. y (Tensor): The input edge feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. - dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. + dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. The available data type is int32, int64. message_op (str): Different message ops for x and e, including `add`, `sub`, `mul`, `div`. reduce_op (str): Different reduce ops, including `sum`, `mean`, `max`, `min`. @@ -224,9 +252,8 @@ def send_ue_recv(x, For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor): The output tensor, should have the same shape and same dtype as input tensor `x`. - If `out_size` is set correctly, then it should have the same shape as `x` except - the 0th dimension. + - out (Tensor), the output tensor, should have the same shape and same dtype as input tensor `x`. + If `out_size` is set correctly, then it should have the same shape as `x` except the 0th dimension. Examples: .. code-block:: python @@ -259,13 +286,15 @@ def send_ue_recv(x, if message_op not in ["add", "sub", "mul", "div"]: raise ValueError( - "message_op should be `add`, `sub`, `mul`, `div`, but received %s" % - message_op) + "message_op should be `add`, `sub`, `mul`, `div`, but received %s" + % message_op + ) if reduce_op not in ["sum", "mean", "max", "min"]: raise ValueError( "reduce_op should be `sum`, `mean`, `max` or `min`, but received %s" - % reduce_op) + % reduce_op + ) x, y = reshape_lhs_rhs(x, y) @@ -274,61 +303,89 @@ def send_ue_recv(x, y = -y if message_op == "div": message_op = 'mul' - y = 1. / (y + 1e-12) + y = 1.0 / (y + 1e-12) # TODO(daisiming): Should we add judgement for out_size: max(dst_index) + 1. if _in_legacy_dygraph(): out_size = convert_out_size_to_list(out_size) - out, tmp = _legacy_C_ops.graph_send_ue_recv(x, y, src_index, dst_index, - None, 'message_op', - message_op.upper(), - 'reduce_op', - reduce_op.upper(), - 'out_size', out_size) + out, tmp = _legacy_C_ops.graph_send_ue_recv( + x, + y, + src_index, + dst_index, + None, + 'message_op', + message_op.upper(), + 'reduce_op', + reduce_op.upper(), + 'out_size', + out_size, + ) return out if in_dygraph_mode(): out_size = convert_out_size_to_list(out_size) - return _C_ops.graph_send_ue_recv(x, y, src_index, dst_index, - message_op.upper(), reduce_op.upper(), - out_size) + return _C_ops.graph_send_ue_recv( + x, + y, + src_index, + dst_index, + message_op.upper(), + reduce_op.upper(), + out_size, + ) check_variable_and_dtype( - x, "X", ("float32", "float64", "int32", "int64", "float16"), - "graph_send_ue_recv") + x, + "X", + ("float32", "float64", "int32", "int64", "float16"), + "graph_send_ue_recv", + ) check_variable_and_dtype( - y, "Y", ("float32", "float64", "int32", "int64", "float16"), - "graph_send_ue_recv") - check_variable_and_dtype(src_index, "Src_index", ("int32", "int64"), - "graph_send_ue_recv") - check_variable_and_dtype(dst_index, "Dst_index", ("int32", "int64"), - "graph_send_ue_recv") + y, + "Y", + ("float32", "float64", "int32", "int64", "float16"), + "graph_send_ue_recv", + ) + check_variable_and_dtype( + src_index, "Src_index", ("int32", "int64"), "graph_send_ue_recv" + ) + check_variable_and_dtype( + dst_index, "Dst_index", ("int32", "int64"), "graph_send_ue_recv" + ) if out_size: - check_type(out_size, 'out_size', (int, np.int32, np.int64, Variable), - 'graph_send_ue_recv') + check_type( + out_size, + 'out_size', + (int, np.int32, np.int64, Variable), + 'graph_send_ue_recv', + ) if isinstance(out_size, Variable): - check_dtype(out_size.dtype, 'out_size', ['int32', 'int64'], - 'graph_send_ue_recv') + check_dtype( + out_size.dtype, 'out_size', ['int32', 'int64'], 'graph_send_ue_recv' + ) helper = LayerHelper("send_ue_recv", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) - dst_count = helper.create_variable_for_type_inference(dtype="int32", - stop_gradient=True) + dst_count = helper.create_variable_for_type_inference( + dtype="int32", stop_gradient=True + ) inputs = {"X": x, "Y": y, "Src_index": src_index, "Dst_index": dst_index} attrs = {"message_op": message_op.upper(), "reduce_op": reduce_op.upper()} - get_out_size_tensor_inputs(inputs=inputs, - attrs=attrs, - out_size=out_size, - op_type='graph_send_ue_recv') - - helper.append_op(type="graph_send_ue_recv", - inputs=inputs, - outputs={ - "Out": out, - "Dst_count": dst_count - }, - attrs=attrs) + get_out_size_tensor_inputs( + inputs=inputs, + attrs=attrs, + out_size=out_size, + op_type='graph_send_ue_recv', + ) + + helper.append_op( + type="graph_send_ue_recv", + inputs=inputs, + outputs={"Out": out, "Dst_count": dst_count}, + attrs=attrs, + ) return out @@ -337,8 +394,8 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): Graph Learning message passing api. - This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory - consumption in the process of message passing. Take `x` as the source node feature tensor, take `y` as + This api is mainly used in Graph Learning domain, and the main purpose is to reduce intermediate memory + consumption in the process of message passing. Take `x` as the source node feature tensor, take `y` as the destination node feature tensor. Then we use `src_index` and `dst_index` to gather the corresponding data, and then compute the edge features in different message_ops like `add`, `sub`, `mul`, `div`. @@ -371,16 +428,17 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): x (Tensor): The source node feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. y (Tensor): The destination node feature tensor, and the available data type is float32, float64, int32, int64. And we support float16 in gpu version. src_index (Tensor): An 1-D tensor, and the available data type is int32, int64. - dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. - The available data type is int32, int64. + dst_index (Tensor): An 1-D tensor, and should have the same shape as `src_index`. + The available data type is int32, int64. message_op (str): Different message ops for x and y, including `add`, `sub`, `mul` and `div`. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: - out (Tensor): The output tensor. + - out (Tensor), the output tensor. Examples: + .. code-block:: python import paddle @@ -397,8 +455,9 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): if message_op not in ['add', 'sub', 'mul', 'div']: raise ValueError( - "message_op should be `add`, `sub`, `mul`, `div`, but received %s" % - message_op) + "message_op should be `add`, `sub`, `mul`, `div`, but received %s" + % message_op + ) x, y = reshape_lhs_rhs(x, y) @@ -407,38 +466,50 @@ def send_uv(x, y, src_index, dst_index, message_op="add", name=None): y = -y if message_op == 'div': message_op = 'mul' - y = 1. / (y + 1e-12) + y = 1.0 / (y + 1e-12) if in_dygraph_mode(): - return _C_ops.graph_send_uv(x, y, src_index, dst_index, - message_op.upper()) + return _C_ops.graph_send_uv( + x, y, src_index, dst_index, message_op.upper() + ) else: if _in_legacy_dygraph(): - return _legacy_C_ops.graph_send_uv(x, y, src_index, dst_index, - "message_op", message_op.upper()) + return _legacy_C_ops.graph_send_uv( + x, y, src_index, dst_index, "message_op", message_op.upper() + ) else: helper = LayerHelper("send_uv", **locals()) check_variable_and_dtype( - x, 'x', ['int32', 'int64', 'float32', 'float64', 'float16'], - 'graph_send_uv') + x, + 'x', + ['int32', 'int64', 'float32', 'float64', 'float16'], + 'graph_send_uv', + ) + check_variable_and_dtype( + y, + 'y', + ['int32', 'int64', 'float32', 'float64', 'float16'], + 'graph_send_uv', + ) + check_variable_and_dtype( + src_index, 'src_index', ['int32', 'int64'], 'graph_send_uv' + ) check_variable_and_dtype( - y, 'y', ['int32', 'int64', 'float32', 'float64', 'float16'], - 'graph_send_uv') - check_variable_and_dtype(src_index, 'src_index', ['int32', 'int64'], - 'graph_send_uv') - check_variable_and_dtype(dst_index, 'dst_index', ['int32', 'int64'], - 'graph_send_uv') + dst_index, 'dst_index', ['int32', 'int64'], 'graph_send_uv' + ) out = helper.create_variable_for_type_inference(dtype=x.dtype) inputs = { 'x': x, 'y': y, 'src_index': src_index, - 'dst_index': dst_index + 'dst_index': dst_index, } attrs = {'message_op': message_op.upper()} - helper.append_op(type="graph_send_uv", - inputs=inputs, - attrs=attrs, - outputs={"out": out}) + helper.append_op( + type="graph_send_uv", + inputs=inputs, + attrs=attrs, + outputs={"out": out}, + ) return out diff --git a/python/paddle/geometric/reindex.py b/python/paddle/geometric/reindex.py index 9580ff5c4ee1f2b93a850b5691a728c1a5cfd7cd..94e9dbec4a5b2296cd034a2467b2f7fff6a2b64c 100644 --- a/python/paddle/geometric/reindex.py +++ b/python/paddle/geometric/reindex.py @@ -22,161 +22,144 @@ from paddle import _C_ops, _legacy_C_ops __all__ = [] -def reindex_graph(x, - neighbors, - count, - value_buffer=None, - index_buffer=None, - name=None): +def reindex_graph( + x, neighbors, count, value_buffer=None, index_buffer=None, name=None +): """ Reindex Graph API. This API is mainly used in Graph Learning domain, which should be used - in conjunction with `graph_sample_neighbors` API. And the main purpose - is to reindex the ids information of the input nodes, and return the + in conjunction with `paddle.geometric.sample_neighbors` API. And the main purpose + is to reindex the ids information of the input nodes, and return the corresponding graph edges after reindex. - **Notes**: - The number in x should be unique, otherwise it would cause potential errors. - We will reindex all the nodes from 0. - - Take input nodes x = [0, 1, 2] as an example. - If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], - then we know that the neighbors of 0 is [8, 9], the neighbors of 1 - is [0, 4, 7], and the neighbors of 2 is [6, 7]. - Then after graph_reindex, we will have 3 different outputs: - 1. reindex_src: [3, 4, 0, 5, 6, 7, 6] - 2. reindex_dst: [0, 0, 1, 1, 1, 2, 2] - 3. out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] - We can see that the numbers in `reindex_src` and `reindex_dst` is the corresponding index + Take input nodes x = [0, 1, 2] as an example. If we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], + then we know that the neighbors of 0 is [8, 9], the neighbors of 1 is [0, 4, 7], and the neighbors of 2 is [6, 7]. + Then after graph_reindex, we will have 3 different outputs: reindex_src: [3, 4, 0, 5, 6, 7, 6], reindex_dst: [0, 0, 1, 1, 1, 2, 2] + and out_nodes: [0, 1, 2, 8, 9, 4, 7, 6]. We can see that the numbers in `reindex_src` and `reindex_dst` is the corresponding index of nodes in `out_nodes`. + Note: + The number in x should be unique, otherwise it would cause potential errors. We will reindex all the nodes from 0. + Args: x (Tensor): The input nodes which we sample neighbors for. The available data type is int32, int64. neighbors (Tensor): The neighbors of the input nodes `x`. The data type should be the same with `x`. - count (Tensor): The neighbor count of the input nodes `x`. And the + count (Tensor): The neighbor count of the input nodes `x`. And the data type should be int32. value_buffer (Tensor|None): Value buffer for hashtable. The data type should be int32, and should be filled with -1. Only useful for gpu version. index_buffer (Tensor|None): Index buffer for hashtable. The data type should be int32, and should be filled with -1. Only useful for gpu version. - `value_buffer` and `index_buffer` should be both not None + `value_buffer` and `index_buffer` should be both not None if you want to speed up by using hashtable buffer. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. - + Returns: - reindex_src (Tensor): The source node index of graph edges after reindex. - reindex_dst (Tensor): The destination node index of graph edges after reindex. - out_nodes (Tensor): The index of unique input nodes and neighbors before reindex, - where we put the input nodes `x` in the front, and put neighbor - nodes in the back. + - reindex_src (Tensor), the source node index of graph edges after reindex. - Examples: - - .. code-block:: python + - reindex_dst (Tensor), the destination node index of graph edges after reindex. - import paddle + - out_nodes (Tensor), the index of unique input nodes and neighbors before reindex, where we put the input nodes `x` in the front, and put neighbor nodes in the back. - x = [0, 1, 2] - neighbors = [8, 9, 0, 4, 7, 6, 7] - count = [2, 3, 2] - x = paddle.to_tensor(x, dtype="int64") - neighbors = paddle.to_tensor(neighbors, dtype="int64") - count = paddle.to_tensor(count, dtype="int32") + Examples: + .. code-block:: python - reindex_src, reindex_dst, out_nodes = \ - paddle.geometric.reindex_graph(x, neighbors, count) - # reindex_src: [3, 4, 0, 5, 6, 7, 6] - # reindex_dst: [0, 0, 1, 1, 1, 2, 2] - # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] + import paddle + x = [0, 1, 2] + neighbors = [8, 9, 0, 4, 7, 6, 7] + count = [2, 3, 2] + x = paddle.to_tensor(x, dtype="int64") + neighbors = paddle.to_tensor(neighbors, dtype="int64") + count = paddle.to_tensor(count, dtype="int32") + reindex_src, reindex_dst, out_nodes = paddle.geometric.reindex_graph(x, neighbors, count) + # reindex_src: [3, 4, 0, 5, 6, 7, 6] + # reindex_dst: [0, 0, 1, 1, 1, 2, 2] + # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6] """ - use_buffer_hashtable = True if value_buffer is not None \ - and index_buffer is not None else False + use_buffer_hashtable = ( + True if value_buffer is not None and index_buffer is not None else False + ) if _non_static_mode(): - reindex_src, reindex_dst, out_nodes = \ - _legacy_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer, - "flag_buffer_hashtable", use_buffer_hashtable) + reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex( + x, + neighbors, + count, + value_buffer, + index_buffer, + "flag_buffer_hashtable", + use_buffer_hashtable, + ) return reindex_src, reindex_dst, out_nodes check_variable_and_dtype(x, "X", ("int32", "int64"), "graph_reindex") - check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"), - "graph_reindex") + check_variable_and_dtype( + neighbors, "Neighbors", ("int32", "int64"), "graph_reindex" + ) check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex") if use_buffer_hashtable: - check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"), - "graph_reindex") - check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"), - "graph_reindex") + check_variable_and_dtype( + value_buffer, "HashTable_Value", ("int32"), "graph_reindex" + ) + check_variable_and_dtype( + index_buffer, "HashTable_Index", ("int32"), "graph_reindex" + ) helper = LayerHelper("reindex_graph", **locals()) reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype) reindex_dst = helper.create_variable_for_type_inference(dtype=x.dtype) out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype) - helper.append_op(type="graph_reindex", - inputs={ - "X": - x, - "Neighbors": - neighbors, - "Count": - count, - "HashTable_Value": - value_buffer if use_buffer_hashtable else None, - "HashTable_Index": - index_buffer if use_buffer_hashtable else None, - }, - outputs={ - "Reindex_Src": reindex_src, - "Reindex_Dst": reindex_dst, - "Out_Nodes": out_nodes - }, - attrs={"flag_buffer_hashtable": use_buffer_hashtable}) + helper.append_op( + type="graph_reindex", + inputs={ + "X": x, + "Neighbors": neighbors, + "Count": count, + "HashTable_Value": value_buffer if use_buffer_hashtable else None, + "HashTable_Index": index_buffer if use_buffer_hashtable else None, + }, + outputs={ + "Reindex_Src": reindex_src, + "Reindex_Dst": reindex_dst, + "Out_Nodes": out_nodes, + }, + attrs={"flag_buffer_hashtable": use_buffer_hashtable}, + ) return reindex_src, reindex_dst, out_nodes -def reindex_heter_graph(x, - neighbors, - count, - value_buffer=None, - index_buffer=None, - name=None): +def reindex_heter_graph( + x, neighbors, count, value_buffer=None, index_buffer=None, name=None +): """ Reindex HeterGraph API. This API is mainly used in Graph Learning domain, which should be used - in conjunction with `graph_sample_neighbors` API. And the main purpose + in conjunction with `paddle.geometric.sample_neighbors` API. And the main purpose is to reindex the ids information of the input nodes, and return the corresponding graph edges after reindex. - **Notes**: - The number in x should be unique, otherwise it would cause potential errors. - We support multi-edge-types neighbors reindexing in reindex_heter_graph api. - We will reindex all the nodes from 0. - - Take input nodes x = [0, 1, 2] as an example. - For graph A, suppose we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], - then we know that the neighbors of 0 is [8, 9], the neighbors of 1 - is [0, 4, 7], and the neighbors of 2 is [6, 7]. - For graph B, suppose we have neighbors = [0, 2, 3, 5, 1], and count = [1, 3, 1], - then we know that the neighbors of 0 is [0], the neighbors of 1 is [2, 3, 5], - and the neighbors of 3 is [1]. - We will get following outputs: - 1. reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] - 2. reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] - 3. out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] + Take input nodes x = [0, 1, 2] as an example. For graph A, suppose we have neighbors = [8, 9, 0, 4, 7, 6, 7], and count = [2, 3, 2], + then we know that the neighbors of 0 is [8, 9], the neighbors of 1 is [0, 4, 7], and the neighbors of 2 is [6, 7]. For graph B, + suppose we have neighbors = [0, 2, 3, 5, 1], and count = [1, 3, 1], then we know that the neighbors of 0 is [0], the neighbors of 1 is [2, 3, 5], + and the neighbors of 3 is [1]. We will get following outputs: reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1], reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] + and out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5]. + + Note: + The number in x should be unique, otherwise it would cause potential errors. We support multi-edge-types neighbors reindexing in reindex_heter_graph api. We will reindex all the nodes from 0. Args: x (Tensor): The input nodes which we sample neighbors for. The available data type is int32, int64. - neighbors (list|tuple): The neighbors of the input nodes `x` from different graphs. + neighbors (list|tuple): The neighbors of the input nodes `x` from different graphs. The data type should be the same with `x`. - count (list|tuple): The neighbor counts of the input nodes `x` from different graphs. + count (list|tuple): The neighbor counts of the input nodes `x` from different graphs. And the data type should be int32. value_buffer (Tensor|None): Value buffer for hashtable. The data type should be int32, and should be filled with -1. Only useful for gpu version. @@ -188,48 +171,52 @@ def reindex_heter_graph(x, For more information, please refer to :ref:`api_guide_Name`. Returns: - reindex_src (Tensor): The source node index of graph edges after reindex. - reindex_dst (Tensor): The destination node index of graph edges after reindex. - out_nodes (Tensor): The index of unique input nodes and neighbors before reindex, - where we put the input nodes `x` in the front, and put neighbor - nodes in the back. - - Examples: + - reindex_src (Tensor), the source node index of graph edges after reindex. - .. code-block:: python + - reindex_dst (Tensor), the destination node index of graph edges after reindex. - import paddle + - out_nodes (Tensor), the index of unique input nodes and neighbors before reindex, + where we put the input nodes `x` in the front, and put neighbor + nodes in the back. - x = [0, 1, 2] - neighbors_a = [8, 9, 0, 4, 7, 6, 7] - count_a = [2, 3, 2] - x = paddle.to_tensor(x, dtype="int64") - neighbors_a = paddle.to_tensor(neighbors_a, dtype="int64") - count_a = paddle.to_tensor(count_a, dtype="int32") - - neighbors_b = [0, 2, 3, 5, 1] - count_b = [1, 3, 1] - neighbors_b = paddle.to_tensor(neighbors_b, dtype="int64") - count_b = paddle.to_tensor(count_b, dtype="int32") + Examples: + .. code-block:: python - neighbors = [neighbors_a, neighbors_b] - count = [count_a, count_b] - reindex_src, reindex_dst, out_nodes = \ - paddle.geometric.reindex_heter_graph(x, neighbors, count) - # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] - # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] - # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] + import paddle + x = [0, 1, 2] + neighbors_a = [8, 9, 0, 4, 7, 6, 7] + count_a = [2, 3, 2] + x = paddle.to_tensor(x, dtype="int64") + neighbors_a = paddle.to_tensor(neighbors_a, dtype="int64") + count_a = paddle.to_tensor(count_a, dtype="int32") + neighbors_b = [0, 2, 3, 5, 1] + count_b = [1, 3, 1] + neighbors_b = paddle.to_tensor(neighbors_b, dtype="int64") + count_b = paddle.to_tensor(count_b, dtype="int32") + neighbors = [neighbors_a, neighbors_b] + count = [count_a, count_b] + reindex_src, reindex_dst, out_nodes = paddle.geometric.reindex_heter_graph(x, neighbors, count) + # reindex_src: [3, 4, 0, 5, 6, 7, 6, 0, 2, 8, 9, 1] + # reindex_dst: [0, 0, 1, 1, 1, 2, 2, 0, 1, 1, 1, 2] + # out_nodes: [0, 1, 2, 8, 9, 4, 7, 6, 3, 5] """ - use_buffer_hashtable = True if value_buffer is not None \ - and index_buffer is not None else False + use_buffer_hashtable = ( + True if value_buffer is not None and index_buffer is not None else False + ) if _non_static_mode(): neighbors = paddle.concat(neighbors, axis=0) count = paddle.concat(count, axis=0) - reindex_src, reindex_dst, out_nodes = \ - _legacy_C_ops.graph_reindex(x, neighbors, count, value_buffer, index_buffer, - "flag_buffer_hashtable", use_buffer_hashtable) + reindex_src, reindex_dst, out_nodes = _legacy_C_ops.graph_reindex( + x, + neighbors, + count, + value_buffer, + index_buffer, + "flag_buffer_hashtable", + use_buffer_hashtable, + ) return reindex_src, reindex_dst, out_nodes if isinstance(neighbors, Variable): @@ -241,15 +228,18 @@ def reindex_heter_graph(x, count = paddle.concat(count, axis=0) check_variable_and_dtype(x, "X", ("int32", "int64"), "heter_graph_reindex") - check_variable_and_dtype(neighbors, "Neighbors", ("int32", "int64"), - "graph_reindex") + check_variable_and_dtype( + neighbors, "Neighbors", ("int32", "int64"), "graph_reindex" + ) check_variable_and_dtype(count, "Count", ("int32"), "graph_reindex") if use_buffer_hashtable: - check_variable_and_dtype(value_buffer, "HashTable_Value", ("int32"), - "graph_reindex") - check_variable_and_dtype(index_buffer, "HashTable_Index", ("int32"), - "graph_reindex") + check_variable_and_dtype( + value_buffer, "HashTable_Value", ("int32"), "graph_reindex" + ) + check_variable_and_dtype( + index_buffer, "HashTable_Index", ("int32"), "graph_reindex" + ) helper = LayerHelper("reindex_heter_graph", **locals()) reindex_src = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -257,23 +247,20 @@ def reindex_heter_graph(x, out_nodes = helper.create_variable_for_type_inference(dtype=x.dtype) neighbors = paddle.concat(neighbors, axis=0) count = paddle.concat(count, axis=0) - helper.append_op(type="graph_reindex", - inputs={ - "X": - x, - "Neighbors": - neighbors, - "Count": - count, - "HashTable_Value": - value_buffer if use_buffer_hashtable else None, - "HashTable_Index": - index_buffer if use_buffer_hashtable else None, - }, - outputs={ - "Reindex_Src": reindex_src, - "Reindex_Dst": reindex_dst, - "Out_Nodes": out_nodes - }, - attrs={"flag_buffer_hashtable": use_buffer_hashtable}) + helper.append_op( + type="graph_reindex", + inputs={ + "X": x, + "Neighbors": neighbors, + "Count": count, + "HashTable_Value": value_buffer if use_buffer_hashtable else None, + "HashTable_Index": index_buffer if use_buffer_hashtable else None, + }, + outputs={ + "Reindex_Src": reindex_src, + "Reindex_Dst": reindex_dst, + "Out_Nodes": out_nodes, + }, + attrs={"flag_buffer_hashtable": use_buffer_hashtable}, + ) return reindex_src, reindex_dst, out_nodes diff --git a/python/paddle/geometric/sampling/neighbors.py b/python/paddle/geometric/sampling/neighbors.py index a9619d54a852ede17729735d1f3f7b0aac5665e0..2dd5a9fb27c6715d7064e9fd80f5f2271fe5ff69 100644 --- a/python/paddle/geometric/sampling/neighbors.py +++ b/python/paddle/geometric/sampling/neighbors.py @@ -21,25 +21,27 @@ from paddle import _C_ops, _legacy_C_ops __all__ = [] -def sample_neighbors(row, - colptr, - input_nodes, - sample_size=-1, - eids=None, - return_eids=False, - perm_buffer=None, - name=None): +def sample_neighbors( + row, + colptr, + input_nodes, + sample_size=-1, + eids=None, + return_eids=False, + perm_buffer=None, + name=None, +): """ Graph Sample Neighbors API. This API is mainly used in Graph Learning domain, and the main purpose is to - provide high performance of graph sampling method. For example, we get the - CSC(Compressed Sparse Column) format of the input graph edges as `row` and + provide high performance of graph sampling method. For example, we get the + CSC(Compressed Sparse Column) format of the input graph edges as `row` and `colptr`, so as to convert graph data into a suitable format for sampling. - `input_nodes` means the nodes we need to sample neighbors, and `sample_sizes` + `input_nodes` means the nodes we need to sample neighbors, and `sample_sizes` means the number of neighbors and number of layers we want to sample. - Besides, we support fisher-yates sampling in GPU version. + Besides, we support fisher-yates sampling in GPU version. Args: row (Tensor): One of the components of the CSC format of the input graph, and @@ -50,10 +52,10 @@ def sample_neighbors(row, The data type should be the same with `row`. input_nodes (Tensor): The input nodes we need to sample neighbors for, and the data type should be the same with `row`. - sample_size (int): The number of neighbors we need to sample. Default value is -1, + sample_size (int): The number of neighbors we need to sample. Default value is -1, which means returning all the neighbors of the input nodes. eids (Tensor): The eid information of the input graph. If return_eids is True, - then `eids` should not be None. The data type should be the + then `eids` should not be None. The data type should be the same with `row`. Default is None. return_eids (bool): Whether to return eid information of sample edges. Default is False. perm_buffer (Tensor): Permutation buffer for fisher-yates sampling. If `use_perm_buffer` @@ -64,81 +66,106 @@ def sample_neighbors(row, For more information, please refer to :ref:`api_guide_Name`. Returns: - out_neighbors (Tensor): The sample neighbors of the input nodes. - out_count (Tensor): The number of sampling neighbors of each input node, and the shape - should be the same with `input_nodes`. - out_eids (Tensor): If `return_eids` is True, we will return the eid information of the - sample edges. + - out_neighbors (Tensor), the sample neighbors of the input nodes. + + - out_count (Tensor), the number of sampling neighbors of each input node, and the shape + should be the same with `input_nodes`. + + - out_eids (Tensor), if `return_eids` is True, we will return the eid information of the + sample edges. Examples: .. code-block:: python - import paddle - # edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4), - # (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8) - row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7] - colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13] - nodes = [0, 8, 1, 2] - sample_size = 2 - row = paddle.to_tensor(row, dtype="int64") - colptr = paddle.to_tensor(colptr, dtype="int64") - nodes = paddle.to_tensor(nodes, dtype="int64") - out_neighbors, out_count = \ - paddle.geometric.sample_neighbors(row, colptr, nodes, - sample_size=sample_size) + + import paddle + # edges: (3, 0), (7, 0), (0, 1), (9, 1), (1, 2), (4, 3), (2, 4), + # (9, 5), (3, 5), (9, 6), (1, 6), (9, 8), (7, 8) + row = [3, 7, 0, 9, 1, 4, 2, 9, 3, 9, 1, 9, 7] + colptr = [0, 2, 4, 5, 6, 7, 9, 11, 11, 13, 13] + nodes = [0, 8, 1, 2] + sample_size = 2 + row = paddle.to_tensor(row, dtype="int64") + colptr = paddle.to_tensor(colptr, dtype="int64") + nodes = paddle.to_tensor(nodes, dtype="int64") + out_neighbors, out_count = paddle.geometric.sample_neighbors(row, colptr, nodes, sample_size=sample_size) """ if return_eids: if eids is None: raise ValueError( - f"`eids` should not be None if `return_eids` is True.") + f"`eids` should not be None if `return_eids` is True." + ) use_perm_buffer = True if perm_buffer is not None else False if _non_static_mode(): - out_neighbors, out_count, out_eids = _legacy_C_ops.graph_sample_neighbors( - row, colptr, input_nodes, eids, perm_buffer, "sample_size", - sample_size, "return_eids", return_eids, "flag_perm_buffer", - use_perm_buffer) + ( + out_neighbors, + out_count, + out_eids, + ) = _legacy_C_ops.graph_sample_neighbors( + row, + colptr, + input_nodes, + eids, + perm_buffer, + "sample_size", + sample_size, + "return_eids", + return_eids, + "flag_perm_buffer", + use_perm_buffer, + ) if return_eids: return out_neighbors, out_count, out_eids return out_neighbors, out_count - check_variable_and_dtype(row, "Row", ("int32", "int64"), - "graph_sample_neighbors") - check_variable_and_dtype(colptr, "Col_Ptr", ("int32", "int64"), - "graph_sample_neighbors") - check_variable_and_dtype(input_nodes, "X", ("int32", "int64"), - "graph_sample_neighbors") + check_variable_and_dtype( + row, "Row", ("int32", "int64"), "graph_sample_neighbors" + ) + check_variable_and_dtype( + colptr, "Col_Ptr", ("int32", "int64"), "graph_sample_neighbors" + ) + check_variable_and_dtype( + input_nodes, "X", ("int32", "int64"), "graph_sample_neighbors" + ) if return_eids: - check_variable_and_dtype(eids, "Eids", ("int32", "int64"), - "graph_sample_neighbors") + check_variable_and_dtype( + eids, "Eids", ("int32", "int64"), "graph_sample_neighbors" + ) if use_perm_buffer: - check_variable_and_dtype(perm_buffer, "Perm_Buffer", ("int32", "int64"), - "graph_sample_neighbors") + check_variable_and_dtype( + perm_buffer, + "Perm_Buffer", + ("int32", "int64"), + "graph_sample_neighbors", + ) helper = LayerHelper("sample_neighbors", **locals()) out_neighbors = helper.create_variable_for_type_inference(dtype=row.dtype) out_count = helper.create_variable_for_type_inference(dtype=row.dtype) out_eids = helper.create_variable_for_type_inference(dtype=row.dtype) - helper.append_op(type="graph_sample_neighbors", - inputs={ - "Row": row, - "Col_Ptr": colptr, - "X": input_nodes, - "Eids": eids if return_eids else None, - "Perm_Buffer": perm_buffer if use_perm_buffer else None - }, - outputs={ - "Out": out_neighbors, - "Out_Count": out_count, - "Out_Eids": out_eids - }, - attrs={ - "sample_size": sample_size, - "return_eids": return_eids, - "flag_perm_buffer": use_perm_buffer - }) + helper.append_op( + type="graph_sample_neighbors", + inputs={ + "Row": row, + "Col_Ptr": colptr, + "X": input_nodes, + "Eids": eids if return_eids else None, + "Perm_Buffer": perm_buffer if use_perm_buffer else None, + }, + outputs={ + "Out": out_neighbors, + "Out_Count": out_count, + "Out_Eids": out_eids, + }, + attrs={ + "sample_size": sample_size, + "return_eids": return_eids, + "flag_perm_buffer": use_perm_buffer, + }, + ) if return_eids: return out_neighbors, out_count, out_eids return out_neighbors, out_count