提交 24b26ee1 编写于 作者: L leonwanghui

Move args_type_check function to _checkparam.py

上级 5d467874
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
# ============================================================================ # ============================================================================
"""Check parameters.""" """Check parameters."""
import re import re
import inspect
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce, wraps
from itertools import repeat from itertools import repeat
from collections.abc import Iterable from collections.abc import Iterable
...@@ -181,7 +182,7 @@ class Validator: ...@@ -181,7 +182,7 @@ class Validator:
@staticmethod @staticmethod
def check_subclass(arg_name, type_, template_type, prim_name): def check_subclass(arg_name, type_, template_type, prim_name):
"""Checks whether some type is sublcass of another type""" """Checks whether some type is subclass of another type"""
if not isinstance(template_type, Iterable): if not isinstance(template_type, Iterable):
template_type = (template_type,) template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]): if not any([mstype.issubclass_(type_, x) for x in template_type]):
...@@ -240,7 +241,6 @@ class Validator: ...@@ -240,7 +241,6 @@ class Validator:
elem_types = map(_check_tensor_type, args.items()) elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types) reduce(_check_types_same, elem_types)
@staticmethod @staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
""" """
...@@ -261,7 +261,7 @@ class Validator: ...@@ -261,7 +261,7 @@ class Validator:
def _check_types_same(arg1, arg2): def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1 arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2 arg2_name, arg2_type = arg2
excp_flag = False except_flag = False
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
arg1_type = arg1_type.element_type() arg1_type = arg1_type.element_type()
arg2_type = arg2_type.element_type() arg2_type = arg2_type.element_type()
...@@ -271,9 +271,9 @@ class Validator: ...@@ -271,9 +271,9 @@ class Validator:
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
else: else:
excp_flag = True except_flag = True
if excp_flag or arg1_type != arg2_type: if except_flag or arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1 return arg1
...@@ -283,11 +283,12 @@ class Validator: ...@@ -283,11 +283,12 @@ class Validator:
def check_value_type(arg_name, arg_value, valid_types, prim_name): def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Checks whether a value is instance of some types.""" """Checks whether a value is instance of some types."""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def raise_error_msg(): def raise_error_msg():
"""func for raising error message when check failed""" """func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types] type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types) num_types = len(valid_types)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
...@@ -303,6 +304,7 @@ class Validator: ...@@ -303,6 +304,7 @@ class Validator:
def check_type_name(arg_name, arg_type, valid_types, prim_name): def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types""" """Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def get_typename(t): def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t) return t.__name__ if hasattr(t, '__name__') else str(t)
...@@ -368,9 +370,9 @@ class ParamValidator: ...@@ -368,9 +370,9 @@ class ParamValidator:
@staticmethod @staticmethod
def check_isinstance(arg_name, arg_value, classes): def check_isinstance(arg_name, arg_value, classes):
"""Check arg isintance of classes""" """Check arg isinstance of classes"""
if not isinstance(arg_value, classes): if not isinstance(arg_value, classes):
raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value return arg_value
@staticmethod @staticmethod
...@@ -384,7 +386,7 @@ class ParamValidator: ...@@ -384,7 +386,7 @@ class ParamValidator:
@staticmethod @staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True): def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is sublcass of another type""" """Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable): if not isinstance(template_type, Iterable):
template_type = (template_type,) template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]): if not any([mstype.issubclass_(type_, x) for x in template_type]):
...@@ -402,9 +404,9 @@ class ParamValidator: ...@@ -402,9 +404,9 @@ class ParamValidator:
@staticmethod @staticmethod
def check_bool(arg_name, arg_value): def check_bool(arg_name, arg_value):
"""Check arg isintance of bool""" """Check arg isinstance of bool"""
if not isinstance(arg_value, bool): if not isinstance(arg_value, bool):
raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value return arg_value
@staticmethod @staticmethod
...@@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): ...@@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if re.match(reg, target, flag) is None: if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True return True
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""
def type_check(func):
sig = inspect.signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Extension functions. Extension functions.
Python functions that will be called in the c++ parts of MindSpore. Python functions that will be called in the c++ parts of MindSpore.
""" """
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative mode help module."""
from inspect import signature
from functools import wraps
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""
def type_check(func):
sig = signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
""" """
The context of mindspore, used to configure the current execution environment, The context of mindspore, used to configure the current execution environment,
including execution mode, execution backend and other feature switchs. including execution mode, execution backend and other feature switches.
""" """
import os import os
import threading import threading
...@@ -22,7 +22,7 @@ from collections import namedtuple ...@@ -22,7 +22,7 @@ from collections import namedtuple
from types import FunctionType from types import FunctionType
from mindspore import log as logger from mindspore import log as logger
from mindspore._c_expression import MSContext from mindspore._c_expression import MSContext
from mindspore._extends.pynative_helper import args_type_check from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context _reset_auto_parallel_context
...@@ -38,7 +38,7 @@ def _make_directory(path: str): ...@@ -38,7 +38,7 @@ def _make_directory(path: str):
"""Make directory.""" """Make directory."""
real_path = None real_path = None
if path is None or not isinstance(path, str) or path.strip() == "": if path is None or not isinstance(path, str) or path.strip() == "":
raise ValueError(f"Input path `{path}` is invaild type") raise ValueError(f"Input path `{path}` is invalid type")
# convert the relative paths # convert the relative paths
path = os.path.realpath(path) path = os.path.realpath(path)
...@@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local): ...@@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local):
""" """
Thread local Info used for store thread local attributes. Thread local Info used for store thread local attributes.
""" """
def __init__(self): def __init__(self):
super(_ThreadLocalInfo, self).__init__() super(_ThreadLocalInfo, self).__init__()
self._reserve_class_name_in_scope = True self._reserve_class_name_in_scope = True
...@@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local): ...@@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local):
Args: Args:
is_pynative (bool): Whether to adopt the PyNative mode. is_pynative (bool): Whether to adopt the PyNative mode.
""" """
def __init__(self, is_pynative): def __init__(self, is_pynative):
super(_ContextSwitchInfo, self).__init__() super(_ContextSwitchInfo, self).__init__()
self.context_stack = [] self.context_stack = []
...@@ -209,7 +211,7 @@ class _Context: ...@@ -209,7 +211,7 @@ class _Context:
def device_target(self, target): def device_target(self, target):
success = self._context_handle.set_device_target(target) success = self._context_handle.set_device_target(target)
if not success: if not success:
raise ValueError("target device name is invalid!!!") raise ValueError("Target device name is invalid!!!")
@property @property
def device_id(self): def device_id(self):
...@@ -335,7 +337,7 @@ class _Context: ...@@ -335,7 +337,7 @@ class _Context:
@graph_memory_max_size.setter @graph_memory_max_size.setter
def graph_memory_max_size(self, graph_memory_max_size): def graph_memory_max_size(self, graph_memory_max_size):
if check_input_fotmat(graph_memory_max_size): if check_input_format(graph_memory_max_size):
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
else: else:
...@@ -347,7 +349,7 @@ class _Context: ...@@ -347,7 +349,7 @@ class _Context:
@variable_memory_max_size.setter @variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size): def variable_memory_max_size(self, variable_memory_max_size):
if check_input_fotmat(variable_memory_max_size): if check_input_format(variable_memory_max_size):
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
else: else:
...@@ -367,12 +369,13 @@ class _Context: ...@@ -367,12 +369,13 @@ class _Context:
thread_info.debug_runtime = enable thread_info.debug_runtime = enable
def check_input_fotmat(x): def check_input_format(x):
import re import re
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result = re.match(pattern, x) result = re.match(pattern, x)
return result is not None return result is not None
_k_context = None _k_context = None
......
...@@ -17,7 +17,7 @@ import threading ...@@ -17,7 +17,7 @@ import threading
import mindspore.context as context import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore._c_expression import AutoParallelContext from mindspore._c_expression import AutoParallelContext
from mindspore._extends.pynative_helper import args_type_check from mindspore._checkparam import args_type_check
class _AutoParallelContext: class _AutoParallelContext:
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"""Context of cost_model in auto_parallel""" """Context of cost_model in auto_parallel"""
import threading import threading
from mindspore._c_expression import CostModelContext from mindspore._c_expression import CostModelContext
from mindspore._extends.pynative_helper import args_type_check from mindspore._checkparam import args_type_check
class _CostModelContext: class _CostModelContext:
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import threading import threading
from mindspore._c_expression import CostModelContext from mindspore._c_expression import CostModelContext
from mindspore._extends.pynative_helper import args_type_check from mindspore._checkparam import args_type_check
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
......
...@@ -14,16 +14,13 @@ ...@@ -14,16 +14,13 @@
# ============================================================================ # ============================================================================
""" test_backend """ """ test_backend """
import os import os
import numpy as np
import pytest import pytest
from mindspore.ops import operations as P from mindspore.ops import operations as P
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context, ms_function
from mindspore.common.initializer import initializer from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter from mindspore.common.parameter import Parameter
from mindspore._extends.pynative_helper import args_type_check from mindspore._checkparam import args_type_check
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
def setup_module(module): def setup_module(module):
...@@ -32,6 +29,7 @@ def setup_module(module): ...@@ -32,6 +29,7 @@ def setup_module(module):
class Net(nn.Cell): class Net(nn.Cell):
""" Net definition """ """ Net definition """
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.add = P.TensorAdd() self.add = P.TensorAdd()
...@@ -50,6 +48,7 @@ def test_vm_backend(): ...@@ -50,6 +48,7 @@ def test_vm_backend():
output = add() output = add()
assert output.asnumpy().shape == (1, 3, 3, 4) assert output.asnumpy().shape == (1, 3, 3, 4)
def test_vm_set_context(): def test_vm_set_context():
""" test_vm_set_context """ """ test_vm_set_context """
context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE)
...@@ -59,6 +58,7 @@ def test_vm_set_context(): ...@@ -59,6 +58,7 @@ def test_vm_set_context():
assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0
context.set_context(mode=context.PYNATIVE_MODE) context.set_context(mode=context.PYNATIVE_MODE)
@args_type_check(v_str=str, v_int=int, v_tuple=tuple) @args_type_check(v_str=str, v_int=int, v_tuple=tuple)
def check_input(v_str, v_int, v_tuple): def check_input(v_str, v_int, v_tuple):
""" check_input """ """ check_input """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册