Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
24b26ee1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
24b26ee1
编写于
4年前
作者:
L
leonwanghui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move args_type_check function to _checkparam.py
上级
5d467874
master
r0.2
r0.3
r0.5
r0.6
r0.7
v0.7.0-beta
v0.6.0-beta
v0.5.0-beta
v0.3.1-alpha
v0.3.0-alpha
v0.2.0-alpha
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
60 addition
and
72 deletion
+60
-72
mindspore/_checkparam.py
mindspore/_checkparam.py
+41
-12
mindspore/_extends/__init__.py
mindspore/_extends/__init__.py
+1
-1
mindspore/_extends/pynative_helper.py
mindspore/_extends/pynative_helper.py
+0
-44
mindspore/context.py
mindspore/context.py
+10
-7
mindspore/parallel/_auto_parallel_context.py
mindspore/parallel/_auto_parallel_context.py
+1
-1
mindspore/parallel/_cost_model_context.py
mindspore/parallel/_cost_model_context.py
+1
-1
mindspore/parallel/algo_parameter_config.py
mindspore/parallel/algo_parameter_config.py
+1
-1
tests/ut/python/pynative_mode/test_backend.py
tests/ut/python/pynative_mode/test_backend.py
+5
-5
未找到文件。
mindspore/_checkparam.py
浏览文件 @
24b26ee1
...
...
@@ -14,8 +14,9 @@
# ============================================================================
"""Check parameters."""
import
re
import
inspect
from
enum
import
Enum
from
functools
import
reduce
from
functools
import
reduce
,
wraps
from
itertools
import
repeat
from
collections.abc
import
Iterable
...
...
@@ -181,7 +182,7 @@ class Validator:
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
prim_name
):
"""Checks whether some type is sub
lc
ass of another type"""
"""Checks whether some type is sub
cl
ass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
...
...
@@ -240,7 +241,6 @@ class Validator:
elem_types
=
map
(
_check_tensor_type
,
args
.
items
())
reduce
(
_check_types_same
,
elem_types
)
@
staticmethod
def
check_scalar_or_tensor_type_same
(
args
,
valid_values
,
prim_name
,
allow_mix
=
False
):
"""
...
...
@@ -261,7 +261,7 @@ class Validator:
def
_check_types_same
(
arg1
,
arg2
):
arg1_name
,
arg1_type
=
arg1
arg2_name
,
arg2_type
=
arg2
exc
p
_flag
=
False
exc
ept
_flag
=
False
if
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
and
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
)):
arg1_type
=
arg1_type
.
element_type
()
arg2_type
=
arg2_type
.
element_type
()
...
...
@@ -271,9 +271,9 @@ class Validator:
arg1_type
=
arg1_type
.
element_type
()
if
isinstance
(
arg1_type
,
type
(
mstype
.
tensor
))
else
arg1_type
arg2_type
=
arg2_type
.
element_type
()
if
isinstance
(
arg2_type
,
type
(
mstype
.
tensor
))
else
arg2_type
else
:
exc
p
_flag
=
True
exc
ept
_flag
=
True
if
exc
p
_flag
or
arg1_type
!=
arg2_type
:
if
exc
ept
_flag
or
arg1_type
!=
arg2_type
:
raise
TypeError
(
f
'For
\'
{
prim_name
}
\'
type of `
{
arg2_name
}
` should be same as `
{
arg1_name
}
`,'
f
' but `
{
arg1_name
}
` is
{
arg1_type
}
and `
{
arg2_name
}
` is
{
arg2_type
}
.'
)
return
arg1
...
...
@@ -283,11 +283,12 @@ class Validator:
def
check_value_type
(
arg_name
,
arg_value
,
valid_types
,
prim_name
):
"""Checks whether a value is instance of some types."""
valid_types
=
valid_types
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
def
raise_error_msg
():
"""func for raising error message when check failed"""
type_names
=
[
t
.
__name__
for
t
in
valid_types
]
num_types
=
len
(
valid_types
)
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
'The'
msg_prefix
=
f
'For
\'
{
prim_name
}
\'
the'
if
prim_name
else
'The'
raise
TypeError
(
f
'
{
msg_prefix
}
type of `
{
arg_name
}
` should be
{
"one of "
if
num_types
>
1
else
""
}
'
f
'
{
type_names
if
num_types
>
1
else
type_names
[
0
]
}
, but got
{
type
(
arg_value
).
__name__
}
.'
)
...
...
@@ -303,6 +304,7 @@ class Validator:
def
check_type_name
(
arg_name
,
arg_type
,
valid_types
,
prim_name
):
"""Checks whether a type in some specified types"""
valid_types
=
valid_types
if
isinstance
(
valid_types
,
Iterable
)
else
(
valid_types
,)
def
get_typename
(
t
):
return
t
.
__name__
if
hasattr
(
t
,
'__name__'
)
else
str
(
t
)
...
...
@@ -368,9 +370,9 @@ class ParamValidator:
@
staticmethod
def
check_isinstance
(
arg_name
,
arg_value
,
classes
):
"""Check arg isintance of classes"""
"""Check arg isin
s
tance of classes"""
if
not
isinstance
(
arg_value
,
classes
):
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isintance of
{
classes
}
, but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isin
s
tance of
{
classes
}
, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
...
...
@@ -384,7 +386,7 @@ class ParamValidator:
@
staticmethod
def
check_subclass
(
arg_name
,
type_
,
template_type
,
with_type_of
=
True
):
"""Check whether some type is sub
lc
ass of another type"""
"""Check whether some type is sub
cl
ass of another type"""
if
not
isinstance
(
template_type
,
Iterable
):
template_type
=
(
template_type
,)
if
not
any
([
mstype
.
issubclass_
(
type_
,
x
)
for
x
in
template_type
]):
...
...
@@ -402,9 +404,9 @@ class ParamValidator:
@
staticmethod
def
check_bool
(
arg_name
,
arg_value
):
"""Check arg isintance of bool"""
"""Check arg isin
s
tance of bool"""
if
not
isinstance
(
arg_value
,
bool
):
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isintance of bool, but got
{
arg_value
}
.'
)
raise
ValueError
(
f
'The `
{
arg_name
}
` should be isin
s
tance of bool, but got
{
arg_value
}
.'
)
return
arg_value
@
staticmethod
...
...
@@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII):
if
re
.
match
(
reg
,
target
,
flag
)
is
None
:
raise
ValueError
(
"'{}' is illegal, it should be match regular'{}' by flags'{}'"
.
format
(
target
,
reg
,
flag
))
return
True
def
args_type_check
(
*
type_args
,
**
type_kwargs
):
"""Check whether input data type is correct."""
def
type_check
(
func
):
sig
=
inspect
.
signature
(
func
)
bound_types
=
sig
.
bind_partial
(
*
type_args
,
**
type_kwargs
).
arguments
@
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
bound_types
bound_values
=
sig
.
bind
(
*
args
,
**
kwargs
)
argument_dict
=
bound_values
.
arguments
if
"kwargs"
in
bound_types
:
bound_types
=
bound_types
[
"kwargs"
]
if
"kwargs"
in
argument_dict
:
argument_dict
=
argument_dict
[
"kwargs"
]
for
name
,
value
in
argument_dict
.
items
():
if
name
in
bound_types
:
if
value
is
not
None
and
not
isinstance
(
value
,
bound_types
[
name
]):
raise
TypeError
(
'Argument {} must be {}'
.
format
(
name
,
bound_types
[
name
]))
return
func
(
*
args
,
**
kwargs
)
return
wrapper
return
type_check
This diff is collapsed.
Click to expand it.
mindspore/_extends/__init__.py
浏览文件 @
24b26ee1
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""
Extension
functions.
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
...
...
This diff is collapsed.
Click to expand it.
mindspore/_extends/pynative_helper.py
已删除
100644 → 0
浏览文件 @
5d467874
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Pynative mode help module."""
from
inspect
import
signature
from
functools
import
wraps
def
args_type_check
(
*
type_args
,
**
type_kwargs
):
"""Check whether input data type is correct."""
def
type_check
(
func
):
sig
=
signature
(
func
)
bound_types
=
sig
.
bind_partial
(
*
type_args
,
**
type_kwargs
).
arguments
@
wraps
(
func
)
def
wrapper
(
*
args
,
**
kwargs
):
nonlocal
bound_types
bound_values
=
sig
.
bind
(
*
args
,
**
kwargs
)
argument_dict
=
bound_values
.
arguments
if
"kwargs"
in
bound_types
:
bound_types
=
bound_types
[
"kwargs"
]
if
"kwargs"
in
argument_dict
:
argument_dict
=
argument_dict
[
"kwargs"
]
for
name
,
value
in
argument_dict
.
items
():
if
name
in
bound_types
:
if
value
is
not
None
and
not
isinstance
(
value
,
bound_types
[
name
]):
raise
TypeError
(
'Argument {} must be {}'
.
format
(
name
,
bound_types
[
name
]))
return
func
(
*
args
,
**
kwargs
)
return
wrapper
return
type_check
This diff is collapsed.
Click to expand it.
mindspore/context.py
浏览文件 @
24b26ee1
...
...
@@ -14,7 +14,7 @@
# ============================================================================
"""
The context of mindspore, used to configure the current execution environment,
including execution mode, execution backend and other feature switchs.
including execution mode, execution backend and other feature switch
e
s.
"""
import
os
import
threading
...
...
@@ -22,7 +22,7 @@ from collections import namedtuple
from
types
import
FunctionType
from
mindspore
import
log
as
logger
from
mindspore._c_expression
import
MSContext
from
mindspore._
extends.pynative_helper
import
args_type_check
from
mindspore._
checkparam
import
args_type_check
from
mindspore.parallel._auto_parallel_context
import
_set_auto_parallel_context
,
_get_auto_parallel_context
,
\
_reset_auto_parallel_context
...
...
@@ -38,7 +38,7 @@ def _make_directory(path: str):
"""Make directory."""
real_path
=
None
if
path
is
None
or
not
isinstance
(
path
,
str
)
or
path
.
strip
()
==
""
:
raise
ValueError
(
f
"Input path `
{
path
}
` is inva
il
d type"
)
raise
ValueError
(
f
"Input path `
{
path
}
` is inva
li
d type"
)
# convert the relative paths
path
=
os
.
path
.
realpath
(
path
)
...
...
@@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local):
"""
Thread local Info used for store thread local attributes.
"""
def
__init__
(
self
):
super
(
_ThreadLocalInfo
,
self
).
__init__
()
self
.
_reserve_class_name_in_scope
=
True
...
...
@@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local):
Args:
is_pynative (bool): Whether to adopt the PyNative mode.
"""
def
__init__
(
self
,
is_pynative
):
super
(
_ContextSwitchInfo
,
self
).
__init__
()
self
.
context_stack
=
[]
...
...
@@ -209,7 +211,7 @@ class _Context:
def
device_target
(
self
,
target
):
success
=
self
.
_context_handle
.
set_device_target
(
target
)
if
not
success
:
raise
ValueError
(
"
t
arget device name is invalid!!!"
)
raise
ValueError
(
"
T
arget device name is invalid!!!"
)
@
property
def
device_id
(
self
):
...
...
@@ -335,7 +337,7 @@ class _Context:
@
graph_memory_max_size
.
setter
def
graph_memory_max_size
(
self
,
graph_memory_max_size
):
if
check_input_fo
t
mat
(
graph_memory_max_size
):
if
check_input_fo
r
mat
(
graph_memory_max_size
):
graph_memory_max_size_
=
graph_memory_max_size
[:
-
2
]
+
" * 1024 * 1024 * 1024"
self
.
_context_handle
.
set_graph_memory_max_size
(
graph_memory_max_size_
)
else
:
...
...
@@ -347,7 +349,7 @@ class _Context:
@
variable_memory_max_size
.
setter
def
variable_memory_max_size
(
self
,
variable_memory_max_size
):
if
check_input_fo
t
mat
(
variable_memory_max_size
):
if
check_input_fo
r
mat
(
variable_memory_max_size
):
variable_memory_max_size_
=
variable_memory_max_size
[:
-
2
]
+
" * 1024 * 1024 * 1024"
self
.
_context_handle
.
set_variable_memory_max_size
(
variable_memory_max_size_
)
else
:
...
...
@@ -367,12 +369,13 @@ class _Context:
thread_info
.
debug_runtime
=
enable
def
check_input_fo
t
mat
(
x
):
def
check_input_fo
r
mat
(
x
):
import
re
pattern
=
r
'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB'
result
=
re
.
match
(
pattern
,
x
)
return
result
is
not
None
_k_context
=
None
...
...
This diff is collapsed.
Click to expand it.
mindspore/parallel/_auto_parallel_context.py
浏览文件 @
24b26ee1
...
...
@@ -17,7 +17,7 @@ import threading
import
mindspore.context
as
context
from
mindspore.parallel._dp_allreduce_fusion
import
_set_fusion_strategy_by_idx
,
_set_fusion_strategy_by_size
from
mindspore._c_expression
import
AutoParallelContext
from
mindspore._
extends.pynative_helper
import
args_type_check
from
mindspore._
checkparam
import
args_type_check
class
_AutoParallelContext
:
...
...
This diff is collapsed.
Click to expand it.
mindspore/parallel/_cost_model_context.py
浏览文件 @
24b26ee1
...
...
@@ -15,7 +15,7 @@
"""Context of cost_model in auto_parallel"""
import
threading
from
mindspore._c_expression
import
CostModelContext
from
mindspore._
extends.pynative_helper
import
args_type_check
from
mindspore._
checkparam
import
args_type_check
class
_CostModelContext
:
...
...
This diff is collapsed.
Click to expand it.
mindspore/parallel/algo_parameter_config.py
浏览文件 @
24b26ee1
...
...
@@ -16,7 +16,7 @@
import
threading
from
mindspore._c_expression
import
CostModelContext
from
mindspore._
extends.pynative_helper
import
args_type_check
from
mindspore._
checkparam
import
args_type_check
__all__
=
[
"get_algo_parameters"
,
"reset_algo_parameters"
,
"set_algo_parameters"
]
...
...
This diff is collapsed.
Click to expand it.
tests/ut/python/pynative_mode/test_backend.py
浏览文件 @
24b26ee1
...
...
@@ -14,16 +14,13 @@
# ============================================================================
""" test_backend """
import
os
import
numpy
as
np
import
pytest
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore
import
context
,
ms_function
from
mindspore.common.initializer
import
initializer
from
mindspore.common.parameter
import
Parameter
from
mindspore._extends.pynative_helper
import
args_type_check
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.api
import
ms_function
from
mindspore._checkparam
import
args_type_check
def
setup_module
(
module
):
...
...
@@ -32,6 +29,7 @@ def setup_module(module):
class
Net
(
nn
.
Cell
):
""" Net definition """
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
add
=
P
.
TensorAdd
()
...
...
@@ -50,6 +48,7 @@ def test_vm_backend():
output
=
add
()
assert
output
.
asnumpy
().
shape
==
(
1
,
3
,
3
,
4
)
def
test_vm_set_context
():
""" test_vm_set_context """
context
.
set_context
(
save_graphs
=
True
,
save_graphs_path
=
"mindspore_ir_path"
,
mode
=
context
.
GRAPH_MODE
)
...
...
@@ -59,6 +58,7 @@ def test_vm_set_context():
assert
context
.
get_context
(
"save_graphs_path"
).
find
(
"mindspore_ir_path"
)
>
0
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
@
args_type_check
(
v_str
=
str
,
v_int
=
int
,
v_tuple
=
tuple
)
def
check_input
(
v_str
,
v_int
,
v_tuple
):
""" check_input """
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部