From 24b26ee1a8a096c950f03cb1534ec1378f423e0c Mon Sep 17 00:00:00 2001 From: leonwanghui Date: Tue, 21 Apr 2020 10:20:09 +0800 Subject: [PATCH] Move args_type_check function to _checkparam.py --- mindspore/_checkparam.py | 53 ++++++++++++++----- mindspore/_extends/__init__.py | 2 +- mindspore/_extends/pynative_helper.py | 44 --------------- mindspore/context.py | 17 +++--- mindspore/parallel/_auto_parallel_context.py | 2 +- mindspore/parallel/_cost_model_context.py | 2 +- mindspore/parallel/algo_parameter_config.py | 2 +- tests/ut/python/pynative_mode/test_backend.py | 10 ++-- 8 files changed, 60 insertions(+), 72 deletions(-) delete mode 100644 mindspore/_extends/pynative_helper.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index e9a928461..7b8c89351 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -14,8 +14,9 @@ # ============================================================================ """Check parameters.""" import re +import inspect from enum import Enum -from functools import reduce +from functools import reduce, wraps from itertools import repeat from collections.abc import Iterable @@ -181,7 +182,7 @@ class Validator: @staticmethod def check_subclass(arg_name, type_, template_type, prim_name): - """Checks whether some type is sublcass of another type""" + """Checks whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -240,7 +241,6 @@ class Validator: elem_types = map(_check_tensor_type, args.items()) reduce(_check_types_same, elem_types) - @staticmethod def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): """ @@ -261,7 +261,7 @@ class Validator: def _check_types_same(arg1, arg2): arg1_name, arg1_type = arg1 arg2_name, arg2_type = arg2 - excp_flag = False + except_flag = False if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): arg1_type = arg1_type.element_type() arg2_type = arg2_type.element_type() @@ -271,9 +271,9 @@ class Validator: arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type else: - excp_flag = True + except_flag = True - if excp_flag or arg1_type != arg2_type: + if except_flag or arg1_type != arg2_type: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') return arg1 @@ -283,11 +283,12 @@ class Validator: def check_value_type(arg_name, arg_value, valid_types, prim_name): """Checks whether a value is instance of some types.""" valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def raise_error_msg(): """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') @@ -303,6 +304,7 @@ class Validator: def check_type_name(arg_name, arg_type, valid_types, prim_name): """Checks whether a type in some specified types""" valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def get_typename(t): return t.__name__ if hasattr(t, '__name__') else str(t) @@ -368,9 +370,9 @@ class ParamValidator: @staticmethod def check_isinstance(arg_name, arg_value, classes): - """Check arg isintance of classes""" + """Check arg isinstance of classes""" if not isinstance(arg_value, classes): - raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') return arg_value @staticmethod @@ -384,7 +386,7 @@ class ParamValidator: @staticmethod def check_subclass(arg_name, type_, template_type, with_type_of=True): - """Check whether some type is sublcass of another type""" + """Check whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -402,9 +404,9 @@ class ParamValidator: @staticmethod def check_bool(arg_name, arg_value): - """Check arg isintance of bool""" + """Check arg isinstance of bool""" if not isinstance(arg_value, bool): - raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') return arg_value @staticmethod @@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): if re.match(reg, target, flag) is None: raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) return True + + +def args_type_check(*type_args, **type_kwargs): + """Check whether input data type is correct.""" + + def type_check(func): + sig = inspect.signature(func) + bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal bound_types + bound_values = sig.bind(*args, **kwargs) + argument_dict = bound_values.arguments + if "kwargs" in bound_types: + bound_types = bound_types["kwargs"] + if "kwargs" in argument_dict: + argument_dict = argument_dict["kwargs"] + for name, value in argument_dict.items(): + if name in bound_types: + if value is not None and not isinstance(value, bound_types[name]): + raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) + return func(*args, **kwargs) + + return wrapper + + return type_check diff --git a/mindspore/_extends/__init__.py b/mindspore/_extends/__init__.py index 5eabfcd97..91e1192e7 100644 --- a/mindspore/_extends/__init__.py +++ b/mindspore/_extends/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """ -Extension functions. +Extension functions. Python functions that will be called in the c++ parts of MindSpore. """ diff --git a/mindspore/_extends/pynative_helper.py b/mindspore/_extends/pynative_helper.py deleted file mode 100644 index 0b93ab926..000000000 --- a/mindspore/_extends/pynative_helper.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Pynative mode help module.""" -from inspect import signature -from functools import wraps - - -def args_type_check(*type_args, **type_kwargs): - """Check whether input data type is correct.""" - - def type_check(func): - sig = signature(func) - bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments - - @wraps(func) - def wrapper(*args, **kwargs): - nonlocal bound_types - bound_values = sig.bind(*args, **kwargs) - argument_dict = bound_values.arguments - if "kwargs" in bound_types: - bound_types = bound_types["kwargs"] - if "kwargs" in argument_dict: - argument_dict = argument_dict["kwargs"] - for name, value in argument_dict.items(): - if name in bound_types: - if value is not None and not isinstance(value, bound_types[name]): - raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) - return func(*args, **kwargs) - - return wrapper - - return type_check diff --git a/mindspore/context.py b/mindspore/context.py index 311ca745f..f6fe8705f 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -14,7 +14,7 @@ # ============================================================================ """ The context of mindspore, used to configure the current execution environment, -including execution mode, execution backend and other feature switchs. +including execution mode, execution backend and other feature switches. """ import os import threading @@ -22,7 +22,7 @@ from collections import namedtuple from types import FunctionType from mindspore import log as logger from mindspore._c_expression import MSContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context @@ -38,7 +38,7 @@ def _make_directory(path: str): """Make directory.""" real_path = None if path is None or not isinstance(path, str) or path.strip() == "": - raise ValueError(f"Input path `{path}` is invaild type") + raise ValueError(f"Input path `{path}` is invalid type") # convert the relative paths path = os.path.realpath(path) @@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. """ + def __init__(self): super(_ThreadLocalInfo, self).__init__() self._reserve_class_name_in_scope = True @@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local): Args: is_pynative (bool): Whether to adopt the PyNative mode. """ + def __init__(self, is_pynative): super(_ContextSwitchInfo, self).__init__() self.context_stack = [] @@ -209,7 +211,7 @@ class _Context: def device_target(self, target): success = self._context_handle.set_device_target(target) if not success: - raise ValueError("target device name is invalid!!!") + raise ValueError("Target device name is invalid!!!") @property def device_id(self): @@ -335,7 +337,7 @@ class _Context: @graph_memory_max_size.setter def graph_memory_max_size(self, graph_memory_max_size): - if check_input_fotmat(graph_memory_max_size): + if check_input_format(graph_memory_max_size): graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) else: @@ -347,7 +349,7 @@ class _Context: @variable_memory_max_size.setter def variable_memory_max_size(self, variable_memory_max_size): - if check_input_fotmat(variable_memory_max_size): + if check_input_format(variable_memory_max_size): variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) else: @@ -367,12 +369,13 @@ class _Context: thread_info.debug_runtime = enable -def check_input_fotmat(x): +def check_input_format(x): import re pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' result = re.match(pattern, x) return result is not None + _k_context = None diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index c99ac4a3c..bf4b99085 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -17,7 +17,7 @@ import threading import mindspore.context as context from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore._c_expression import AutoParallelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _AutoParallelContext: diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 0920d66f4..54cca5516 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -15,7 +15,7 @@ """Context of cost_model in auto_parallel""" import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _CostModelContext: diff --git a/mindspore/parallel/algo_parameter_config.py b/mindspore/parallel/algo_parameter_config.py index d1e4aa87a..244156da3 100644 --- a/mindspore/parallel/algo_parameter_config.py +++ b/mindspore/parallel/algo_parameter_config.py @@ -16,7 +16,7 @@ import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/tests/ut/python/pynative_mode/test_backend.py b/tests/ut/python/pynative_mode/test_backend.py index 7258b6948..fae197485 100644 --- a/tests/ut/python/pynative_mode/test_backend.py +++ b/tests/ut/python/pynative_mode/test_backend.py @@ -14,16 +14,13 @@ # ============================================================================ """ test_backend """ import os -import numpy as np import pytest from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore import context +from mindspore import context, ms_function from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter -from mindspore._extends.pynative_helper import args_type_check -from mindspore.common.tensor import Tensor -from mindspore.common.api import ms_function +from mindspore._checkparam import args_type_check def setup_module(module): @@ -32,6 +29,7 @@ def setup_module(module): class Net(nn.Cell): """ Net definition """ + def __init__(self): super(Net, self).__init__() self.add = P.TensorAdd() @@ -50,6 +48,7 @@ def test_vm_backend(): output = add() assert output.asnumpy().shape == (1, 3, 3, 4) + def test_vm_set_context(): """ test_vm_set_context """ context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) @@ -59,6 +58,7 @@ def test_vm_set_context(): assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 context.set_context(mode=context.PYNATIVE_MODE) + @args_type_check(v_str=str, v_int=int, v_tuple=tuple) def check_input(v_str, v_int, v_tuple): """ check_input """ -- GitLab