提交 77dcdd89 编写于 作者: W Wei Luning

support parameter updata with implicit type conversion

上级 1166a091
......@@ -119,6 +119,9 @@ int_type = (int8, int16, int32, int64,)
uint_type = (uint8, uint16, uint32, uint64)
float_type = (float16, float32, float64,)
implicit_conversion_seq = {t: idx for idx, t in enumerate((
bool_, int8, uint8, int16, int32, int64, float16, float32, float64))}
_simple_types = {
list: list_,
tuple: tuple_,
......
......@@ -313,8 +313,9 @@ class Parameter(MetaTensor):
Parameter, the parameter after set data.
"""
def raise_type_error(incoming):
raise TypeError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
f", and incoming is {incoming}. Use .set_dtype(xxx) to change the dtype.")
raise TypeError(f"Incoming Parameter dtype can not be converted to current dtype implicitly. "
f"Current dtype is {self.dtype}, and incoming is {incoming}. "
f"Use .set_dtype(xxx) to change the dtype.")
if not isinstance(data, (MetaTensor, Initializer, int, float)):
raise TypeError(f"Parameter data must be [`Initializer`, `int`, `float`] or a kind of `MetaTensor` "
......@@ -338,7 +339,10 @@ class Parameter(MetaTensor):
raise ValueError(f"Can not change the shape of Parameter which has been initialized."
f" Current shape is {self.shape}, and incoming is {data.shape}.")
if self.dtype != data.dtype:
raise_type_error(data.dtype)
if mstype.implicit_conversion_seq[self.dtype] < mstype.implicit_conversion_seq[data.dtype]:
raise_type_error(data.dtype)
else:
data = Tensor(data, self.dtype)
if isinstance(data, Initializer):
# The parameter has been initializered, directly update by the data
if is_current_tensor:
......
......@@ -74,7 +74,7 @@ class Tensor(Tensor_):
self._virtual_flag = False
def __repr__(self):
return str(Tensor_.__str__(self))
return Tensor_.__repr__(self)
def __add__(self, other):
out = tensor_operator_registry.get('__add__')(self, other)
......
......@@ -157,6 +157,7 @@ def test_parameter_compute():
def test_scalar_parameter_update():
# float
fp = Parameter(0.5, 'fp')
fp.default_input = 0.8
assert np.array_equal(fp.default_input.asnumpy(), np.array(0.8, np.float32))
......@@ -167,6 +168,26 @@ def test_scalar_parameter_update():
assert np.array_equal(int_.default_input.asnumpy(), np.array(2, np.int32))
with pytest.raises(TypeError):
int_.default_input = 1.2
# Tensor
fp32 = Tensor(0.5, mstype.float32)
int32 = Tensor(2, mstype.int32)
fp16 = Tensor(0.6, mstype.float16)
int16 = Tensor(3, mstype.int16)
bool_ = Tensor(np.array(True, dtype=np.bool_))
# updata_by_tensor
fp32_p = Parameter(fp32, 'fp32')
fp32_p.default_input = 0.8
fp32_p.default_input = 1
fp32_p.default_input = int32
fp32_p.default_input = fp32
fp32_p.default_input = int16
fp32_p.default_input = fp16
fp32_p.default_input = bool_
# updata_by_tensor
fp16_p = Parameter(fp16, 'fp16')
with pytest.raises(TypeError):
fp16_p.default_input = fp32
def test_parameter_lazy_init():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册