Commit 24b26ee1 authored by leonwanghui

Move args_type_check function to _checkparam.py

Parent 5d467874
......@@ -14,8 +14,9 @@
# ============================================================================
"""Check parameters."""
import re
import inspect
from enum import Enum
-from functools import reduce
+from functools import reduce, wraps
from itertools import repeat
from collections.abc import Iterable
......@@ -181,7 +182,7 @@ class Validator:
@staticmethod
def check_subclass(arg_name, type_, template_type, prim_name):
"""Checks whether some type is sublcass of another type"""
"""Checks whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
......@@ -240,7 +241,6 @@ class Validator:
elem_types = map(_check_tensor_type, args.items())
reduce(_check_types_same, elem_types)
@staticmethod
def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False):
"""
......@@ -261,7 +261,7 @@ class Validator:
def _check_types_same(arg1, arg2):
arg1_name, arg1_type = arg1
arg2_name, arg2_type = arg2
-excp_flag = False
+except_flag = False
if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)):
arg1_type = arg1_type.element_type()
arg2_type = arg2_type.element_type()
......@@ -271,9 +271,9 @@ class Validator:
arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type
arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type
else:
-excp_flag = True
+except_flag = True
-if excp_flag or arg1_type != arg2_type:
+if except_flag or arg1_type != arg2_type:
raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,'
f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.')
return arg1
......@@ -283,11 +283,12 @@ class Validator:
def check_value_type(arg_name, arg_value, valid_types, prim_name):
"""Checks whether a value is instance of some types."""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def raise_error_msg():
"""func for raising error message when check failed"""
type_names = [t.__name__ for t in valid_types]
num_types = len(valid_types)
msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The'
raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}'
f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.')
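For illustration, a hedged sketch of the message raise_error_msg builds (the operator name 'ReduceSum' and argument name 'keep_dims' are made up, and the collapsed remainder of check_value_type is assumed to call raise_error_msg when its isinstance check fails):

Validator.check_value_type('keep_dims', 1, [bool], 'ReduceSum')
# expected: TypeError: For 'ReduceSum' the type of `keep_dims` should be bool, but got int.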
......@@ -303,6 +304,7 @@ class Validator:
def check_type_name(arg_name, arg_type, valid_types, prim_name):
"""Checks whether a type in some specified types"""
valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,)
def get_typename(t):
return t.__name__ if hasattr(t, '__name__') else str(t)
......@@ -368,9 +370,9 @@ class ParamValidator:
@staticmethod
def check_isinstance(arg_name, arg_value, classes):
"""Check arg isintance of classes"""
"""Check arg isinstance of classes"""
if not isinstance(arg_value, classes):
-raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.')
+raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.')
return arg_value
@staticmethod
......@@ -384,7 +386,7 @@ class ParamValidator:
@staticmethod
def check_subclass(arg_name, type_, template_type, with_type_of=True):
"""Check whether some type is sublcass of another type"""
"""Check whether some type is subclass of another type"""
if not isinstance(template_type, Iterable):
template_type = (template_type,)
if not any([mstype.issubclass_(type_, x) for x in template_type]):
......@@ -402,9 +404,9 @@ class ParamValidator:
@staticmethod
def check_bool(arg_name, arg_value):
"""Check arg isintance of bool"""
"""Check arg isinstance of bool"""
if not isinstance(arg_value, bool):
-raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.')
+raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.')
return arg_value
@staticmethod
......@@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if re.match(reg, target, flag) is None:
raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag))
return True
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""
def type_check(func):
sig = inspect.signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check
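For context, a minimal usage sketch of the relocated decorator (the configure function and its parameters below are hypothetical, not part of this commit): the keyword arguments given to args_type_check map parameter names to expected types, and any non-None argument of the wrong type raises TypeError before the wrapped function runs.

from mindspore._checkparam import args_type_check

# Hypothetical function used only to illustrate the decorator.
@args_type_check(device_num=int, parallel_mode=str)
def configure(device_num=None, parallel_mode=None):
    return device_num, parallel_mode

configure(device_num=8, parallel_mode="stand_alone")  # passes: both types match
try:
    configure(device_num="8")  # wrong type for device_num
except TypeError as err:
    print(err)  # Argument device_num must be <class 'int'>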
......@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative mode help module."""
from inspect import signature
from functools import wraps
def args_type_check(*type_args, **type_kwargs):
"""Check whether input data type is correct."""
def type_check(func):
sig = signature(func)
bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments
@wraps(func)
def wrapper(*args, **kwargs):
nonlocal bound_types
bound_values = sig.bind(*args, **kwargs)
argument_dict = bound_values.arguments
if "kwargs" in bound_types:
bound_types = bound_types["kwargs"]
if "kwargs" in argument_dict:
argument_dict = argument_dict["kwargs"]
for name, value in argument_dict.items():
if name in bound_types:
if value is not None and not isinstance(value, bound_types[name]):
raise TypeError('Argument {} must be {}'.format(name, bound_types[name]))
return func(*args, **kwargs)
return wrapper
return type_check
......@@ -14,7 +14,7 @@
# ============================================================================
"""
The context of mindspore, used to configure the current execution environment,
-including execution mode, execution backend and other feature switchs.
+including execution mode, execution backend and other feature switches.
"""
import os
import threading
......@@ -22,7 +22,7 @@ from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext
-from mindspore._extends.pynative_helper import args_type_check
+from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
......@@ -38,7 +38,7 @@ def _make_directory(path: str):
"""Make directory."""
real_path = None
if path is None or not isinstance(path, str) or path.strip() == "":
raise ValueError(f"Input path `{path}` is invaild type")
raise ValueError(f"Input path `{path}` is invalid type")
# convert the relative paths
path = os.path.realpath(path)
......@@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local):
"""
Thread local Info used for store thread local attributes.
"""
def __init__(self):
super(_ThreadLocalInfo, self).__init__()
self._reserve_class_name_in_scope = True
......@@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local):
Args:
is_pynative (bool): Whether to adopt the PyNative mode.
"""
def __init__(self, is_pynative):
super(_ContextSwitchInfo, self).__init__()
self.context_stack = []
......@@ -209,7 +211,7 @@ class _Context:
def device_target(self, target):
success = self._context_handle.set_device_target(target)
if not success:
raise ValueError("target device name is invalid!!!")
raise ValueError("Target device name is invalid!!!")
@property
def device_id(self):
......@@ -335,7 +337,7 @@ class _Context:
@graph_memory_max_size.setter
def graph_memory_max_size(self, graph_memory_max_size):
-if check_input_fotmat(graph_memory_max_size):
+if check_input_format(graph_memory_max_size):
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
else:
......@@ -347,7 +349,7 @@ class _Context:
@variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size):
-if check_input_fotmat(variable_memory_max_size):
+if check_input_format(variable_memory_max_size):
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
else:
......@@ -367,12 +369,13 @@ class _Context:
thread_info.debug_runtime = enable
-def check_input_fotmat(x):
+def check_input_format(x):
import re
pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result = re.match(pattern, x)
return result is not None
_k_context = None
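For reference, the renamed check_input_format helper only accepts sizes written with a GB suffix; a small sketch of the behavior implied by the regular expression above, run alongside the definition (sample values are illustrative):

assert check_input_format("5GB")        # whole gigabytes
assert check_input_format("2.5GB")      # fractional sizes
assert check_input_format("0.5GB")      # sub-gigabyte values need the leading "0."
assert not check_input_format("500MB")  # anything without the GB suffix is rejected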
......
......@@ -17,7 +17,7 @@ import threading
import mindspore.context as context
from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size
from mindspore._c_expression import AutoParallelContext
-from mindspore._extends.pynative_helper import args_type_check
+from mindspore._checkparam import args_type_check
class _AutoParallelContext:
......
......@@ -15,7 +15,7 @@
"""Context of cost_model in auto_parallel"""
import threading
from mindspore._c_expression import CostModelContext
-from mindspore._extends.pynative_helper import args_type_check
+from mindspore._checkparam import args_type_check
class _CostModelContext:
......
......@@ -16,7 +16,7 @@
import threading
from mindspore._c_expression import CostModelContext
-from mindspore._extends.pynative_helper import args_type_check
+from mindspore._checkparam import args_type_check
__all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"]
......
......@@ -14,16 +14,13 @@
# ============================================================================
""" test_backend """
import os
import numpy as np
import pytest
from mindspore.ops import operations as P
import mindspore.nn as nn
-from mindspore import context
+from mindspore import context, ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
-from mindspore._extends.pynative_helper import args_type_check
from mindspore.common.tensor import Tensor
-from mindspore.common.api import ms_function
+from mindspore._checkparam import args_type_check
def setup_module(module):
......@@ -32,6 +29,7 @@ def setup_module(module):
class Net(nn.Cell):
""" Net definition """
def __init__(self):
super(Net, self).__init__()
self.add = P.TensorAdd()
......@@ -50,6 +48,7 @@ def test_vm_backend():
output = add()
assert output.asnumpy().shape == (1, 3, 3, 4)
def test_vm_set_context():
""" test_vm_set_context """
context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE)
......@@ -59,6 +58,7 @@ def test_vm_set_context():
assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0
context.set_context(mode=context.PYNATIVE_MODE)
@args_type_check(v_str=str, v_int=int, v_tuple=tuple)
def check_input(v_str, v_int, v_tuple):
""" check_input """
......
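The body of check_input is collapsed above; purely as a hedged illustration of the declared contract (this is not the actual test code), a call that violates the annotated types fails before the body runs:

with pytest.raises(TypeError):
    check_input(1, 2, (3,))  # v_str receives an int, so the decorator raises TypeError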