diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index a4c77ef3f8da31ff58697070d202425888deca54..9f70204e2f726d50b0083387c84da73d723a0d82 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -20,12 +20,13 @@ import numpy as np from mindspore._c_dataengine import Graph from mindspore._c_dataengine import Tensor -from .validators import check_gnn_get_all_nodes, check_gnn_get_all_neighbors, check_gnn_get_node_feature +from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_neighbors, \ + check_gnn_get_node_feature class GraphData: """ - Reads th graph dataset used for GNN training from the shared file and database. + Reads the graph dataset used for GNN training from the shared file and database. Args: dataset_file (str): One of file names in dataset. @@ -33,6 +34,7 @@ class GraphData: (default=None). """ + @check_gnn_graphdata def __init__(self, dataset_file, num_parallel_workers=None): self._dataset_file = dataset_file if num_parallel_workers is None: @@ -45,7 +47,7 @@ class GraphData: Get all nodes in the graph. Args: - node_type (int): Specify the tpye of node. + node_type (int): Specify the type of node. Returns: numpy.ndarray: array of nodes. @@ -67,7 +69,7 @@ class GraphData: Args: node_list (list or numpy.ndarray): The given list of nodes. - neighbor_type (int): Specify the tpye of neighbor. + neighbor_type (int): Specify the type of neighbor. Returns: numpy.ndarray: array of nodes. diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index c9c06e559ca0a0e13c80e906654ea54b16ec3319..abbc15f0fec92bd15db59a9f24ca3d924e4bfee8 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -19,6 +19,7 @@ import inspect as ins import os from functools import wraps from multiprocessing import cpu_count +import numpy as np from mindspore._c_expression import typing from . import samplers from . import datasets @@ -1075,14 +1076,48 @@ def check_split(method): return new_method -def check_list_or_ndarray(param, param_name): - if (not isinstance(param, list)) and (not hasattr(param, 'tolist')): - raise TypeError("Wrong input type for {0}, should be list, got {1}".format( +def check_gnn_graphdata(method): + """check the input arguments of graphdata.""" + + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + # check dataset_file; required argument + dataset_file = param_dict.get('dataset_file') + if dataset_file is None: + raise ValueError("dataset_file is not provided.") + check_dataset_file(dataset_file) + + nreq_param_int = ['num_parallel_workers'] + + check_param_type(nreq_param_int, param_dict, int) + + return method(*args, **kwargs) + + return new_method + + +def check_gnn_list_or_ndarray(param, param_name): + """Check if the input parameter is list or numpy.ndarray.""" + + if isinstance(param, list): + for m in param: + if not isinstance(m, int): + raise TypeError( + "Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m))) + elif isinstance(param, np.ndarray): + if not param.dtype == np.int32: + raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( + param_name, param.dtype)) + else: + raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( param_name, type(param))) def check_gnn_get_all_nodes(method): """A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) @@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method): param_dict = make_param_dict(method, args, kwargs) # check node_list; required argument - check_list_or_ndarray(param_dict.get("node_list"), 'node_list') + check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') # check neighbor_type; required argument check_type(param_dict.get("neighbor_type"), 'neighbor_type', int) @@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method): return new_method -def check_aligned_list(param, param_name): +def check_aligned_list(param, param_name, membor_type): """Check whether the structure of each member of the list is the same.""" + if not isinstance(param, list): raise TypeError("Parameter {0} is not a list".format(param_name)) membor_have_list = None list_len = None for membor in param: if isinstance(membor, list): - check_aligned_list(membor, param_name) + check_aligned_list(membor, param_name, membor_type) if membor_have_list not in (None, True): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name): membor_have_list = True list_len = len(membor) else: + if not isinstance(membor, membor_type): + raise TypeError("Each membor in {0} should be of type int. Got {1}.".format( + param_name, type(membor))) if membor_have_list not in (None, False): raise TypeError("The type of each member of the parameter {0} is inconsistent".format( param_name)) @@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name): def check_gnn_get_node_feature(method): """A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" + @wraps(method) def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) # check node_list; required argument node_list = param_dict.get("node_list") - check_list_or_ndarray(node_list, 'node_list') if isinstance(node_list, list): - check_aligned_list(node_list, 'node_list') + check_aligned_list(node_list, 'node_list', int) + elif isinstance(node_list, np.ndarray): + if not node_list.dtype == np.int32: + raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( + node_list, node_list.dtype)) + else: + raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( + 'node_list', type(node_list))) # check feature_types; required argument - check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types') + check_gnn_list_or_ndarray(param_dict.get( + "feature_types"), 'feature_types') return method(*args, **kwargs) diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 35c6d02fc7e8cf940fcc430ff6f4a35cf962357c..4aa4fc89ee206e51393648e4b44f2acbfbd297eb 100644 --- a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import pytest +import numpy as np import mindspore.dataset as ds from mindspore import log as logger @@ -23,8 +24,7 @@ def test_graphdata_getfullneighbor(): g = ds.GraphData(DATASET_FILE, 2) nodes = g.get_all_nodes(1) assert len(nodes) == 10 - nodes_list = nodes.tolist() - neighbor = g.get_all_neighbors(nodes_list, 2) + neighbor = g.get_all_neighbors(nodes, 2) assert neighbor.shape == (10, 6) row_tensor = g.get_node_feature(neighbor.tolist(), [2, 3]) assert row_tensor[0].shape == (10, 6) @@ -60,6 +60,14 @@ def test_graphdata_getnodefeature_input_check(): input_list = [[1, 1], [1, 1]] g.get_node_feature(input_list, 1) + with pytest.raises(TypeError): + input_list = [[1, 0.1], [1, 1]] + g.get_node_feature(input_list, 1) + + with pytest.raises(TypeError): + input_list = np.array([[1, 0.1], [1, 1]]) + g.get_node_feature(input_list, 1) + with pytest.raises(TypeError): input_list = [[1, 1], [1, 1]] g.get_node_feature(input_list, ["a"])