Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
77dcdd89
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
77dcdd89
编写于
8月 21, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support parameter updata with implicit type conversion
上级
1166a091
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
32 addition
and
4 deletion
+32
-4
mindspore/common/dtype.py
mindspore/common/dtype.py
+3
-0
mindspore/common/parameter.py
mindspore/common/parameter.py
+7
-3
mindspore/common/tensor.py
mindspore/common/tensor.py
+1
-1
tests/ut/python/nn/test_parameter.py
tests/ut/python/nn/test_parameter.py
+21
-0
未找到文件。
mindspore/common/dtype.py
浏览文件 @
77dcdd89
...
...
@@ -119,6 +119,9 @@ int_type = (int8, int16, int32, int64,)
uint_type
=
(
uint8
,
uint16
,
uint32
,
uint64
)
float_type
=
(
float16
,
float32
,
float64
,)
implicit_conversion_seq
=
{
t
:
idx
for
idx
,
t
in
enumerate
((
bool_
,
int8
,
uint8
,
int16
,
int32
,
int64
,
float16
,
float32
,
float64
))}
_simple_types
=
{
list
:
list_
,
tuple
:
tuple_
,
...
...
mindspore/common/parameter.py
浏览文件 @
77dcdd89
...
...
@@ -313,8 +313,9 @@ class Parameter(MetaTensor):
Parameter, the parameter after set data.
"""
def
raise_type_error
(
incoming
):
raise
TypeError
(
f
"Can not change the Parameter dtype. Current dtype is
{
self
.
set_dtype
}
"
f
", and incoming is
{
incoming
}
. Use .set_dtype(xxx) to change the dtype."
)
raise
TypeError
(
f
"Incoming Parameter dtype can not be converted to current dtype implicitly. "
f
"Current dtype is
{
self
.
dtype
}
, and incoming is
{
incoming
}
. "
f
"Use .set_dtype(xxx) to change the dtype."
)
if
not
isinstance
(
data
,
(
MetaTensor
,
Initializer
,
int
,
float
)):
raise
TypeError
(
f
"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
...
...
@@ -338,7 +339,10 @@ class Parameter(MetaTensor):
raise
ValueError
(
f
"Can not change the shape of Parameter which has been initialized."
f
" Current shape is
{
self
.
shape
}
, and incoming is
{
data
.
shape
}
."
)
if
self
.
dtype
!=
data
.
dtype
:
raise_type_error
(
data
.
dtype
)
if
mstype
.
implicit_conversion_seq
[
self
.
dtype
]
<
mstype
.
implicit_conversion_seq
[
data
.
dtype
]:
raise_type_error
(
data
.
dtype
)
else
:
data
=
Tensor
(
data
,
self
.
dtype
)
if
isinstance
(
data
,
Initializer
):
# The parameter has been initializered, directly update by the data
if
is_current_tensor
:
...
...
mindspore/common/tensor.py
浏览文件 @
77dcdd89
...
...
@@ -74,7 +74,7 @@ class Tensor(Tensor_):
self
.
_virtual_flag
=
False
def
__repr__
(
self
):
return
str
(
Tensor_
.
__str__
(
self
)
)
return
Tensor_
.
__repr__
(
self
)
def
__add__
(
self
,
other
):
out
=
tensor_operator_registry
.
get
(
'__add__'
)(
self
,
other
)
...
...
tests/ut/python/nn/test_parameter.py
浏览文件 @
77dcdd89
...
...
@@ -157,6 +157,7 @@ def test_parameter_compute():
def
test_scalar_parameter_update
():
# float
fp
=
Parameter
(
0.5
,
'fp'
)
fp
.
default_input
=
0.8
assert
np
.
array_equal
(
fp
.
default_input
.
asnumpy
(),
np
.
array
(
0.8
,
np
.
float32
))
...
...
@@ -167,6 +168,26 @@ def test_scalar_parameter_update():
assert
np
.
array_equal
(
int_
.
default_input
.
asnumpy
(),
np
.
array
(
2
,
np
.
int32
))
with
pytest
.
raises
(
TypeError
):
int_
.
default_input
=
1.2
# Tensor
fp32
=
Tensor
(
0.5
,
mstype
.
float32
)
int32
=
Tensor
(
2
,
mstype
.
int32
)
fp16
=
Tensor
(
0.6
,
mstype
.
float16
)
int16
=
Tensor
(
3
,
mstype
.
int16
)
bool_
=
Tensor
(
np
.
array
(
True
,
dtype
=
np
.
bool_
))
# updata_by_tensor
fp32_p
=
Parameter
(
fp32
,
'fp32'
)
fp32_p
.
default_input
=
0.8
fp32_p
.
default_input
=
1
fp32_p
.
default_input
=
int32
fp32_p
.
default_input
=
fp32
fp32_p
.
default_input
=
int16
fp32_p
.
default_input
=
fp16
fp32_p
.
default_input
=
bool_
# updata_by_tensor
fp16_p
=
Parameter
(
fp16
,
'fp16'
)
with
pytest
.
raises
(
TypeError
):
fp16_p
.
default_input
=
fp32
def
test_parameter_lazy_init
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录