Commit d4d6457e authored by Wei Luning

Init parameter data by default; only keep the data uninitialized (as MetaTensor) in auto parallel mode

Parent fe7141e9
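To illustrate the behavior this commit describes, a minimal sketch (not part of the diff; it assumes the mindspore.common API touched below and mirrors the updated test_parameter_lazy_init):

import numpy as np
from mindspore import context, Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

# Outside auto-parallel mode, a Parameter built from an Initializer is now
# materialized into a Tensor by default.
p = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'p1')
assert isinstance(p.default_input, Tensor)

# In "semi_auto_parallel" / "auto_parallel" mode the data stays a MetaTensor
# (shape and dtype only, no memory) until init_data() is called.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
q = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'q1')
assert not isinstance(q.default_input, Tensor)
q.init_data()
assert np.array_equal(q.default_input.asnumpy(), np.ones((1, 2, 3)))
context.reset_auto_parallel_context()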
......@@ -414,6 +414,10 @@ class _Executor:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode:
obj.load_parameter_slice(None)
self._updata_param_node_default_input(phase, replace)
# set parallel inputs in sink mode
......
......@@ -15,25 +15,22 @@
"""Parameter for cell."""
from copy import copy
from mindspore import context
from .._c_expression import ParamValue
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._tensor import _get_slice_index
from ..parallel._auto_parallel_context import auto_parallel_context
__all__ = ['Parameter', 'ParameterTuple']
PARAMETER_NAME_DEFAULT = "Parameter"
PARAMETER_NAME_PREFIX_MAX_LEN = 1024
def _check_type(x):
"""Check input data type"""
if not isinstance(x, Parameter):
raise ValueError("Should be `Parameter` collection.")
return True
def _is_in_parallel_mode():
"""Get parallel mode."""
return auto_parallel_context().get_parallel_mode() in ["semi_auto_parallel", "auto_parallel"]
class Parameter(MetaTensor):
......@@ -42,10 +39,10 @@ class Parameter(MetaTensor):
After initialized `Parameter` is a subtype of `Tensor`.
In graph mode, if init `Parameter` by a `Initializer`, the type of Parameter will be a `MetaTensor`
not a `Tensor`. `MetaTensor` only save the shape type info of a tensor with no memory usage. The shape
can be change while compile for auto-parallel. Call `init_data` will return a Tensor Parameter with
initialized data.
In auto_parallel mode ("semi_auto_parallel" and "auto_parallel"), if a `Parameter` is initialized
with an `Initializer`, the type of the Parameter will be `MetaTensor`, not `Tensor`. A `MetaTensor`
only holds the shape and dtype of a tensor and uses no memory. The shape can be changed during
compilation for auto-parallel. Calling `init_data` will return a Tensor Parameter with initialized data.
Note:
Each parameter of Cell is represented by Parameter class.
......@@ -67,7 +64,7 @@ class Parameter(MetaTensor):
input_class.__init__(obj, *class_init_args)
# it's better to make the Initializer a kind of metatensor.
obj.init_mode = None
if isinstance(default_input, Initializer):
if not isinstance(obj, Tensor):
obj.init_mode = default_input
return obj
......@@ -112,11 +109,10 @@ class Parameter(MetaTensor):
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Initializer):
if context.get_context("mode") == context.PYNATIVE_MODE:
# always init data while in pynative mode.
data = data.to_tensor()
else:
if _is_in_parallel_mode():
# do not init data while in auto parallel.
return (MetaTensor, data.dtype, data.shape)
data = data.to_tensor()
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
return (Tensor, data.asnumpy(),)
......@@ -127,9 +123,9 @@ class Parameter(MetaTensor):
return (Tensor, data)
def __str__(self):
value_str = MetaTensor.__repr__(self)
value_str = MetaTensor.__str__(self)
if isinstance(self, Tensor):
value_str = Tensor.__repr__(self)
value_str = Tensor.__str__(self)
return f'Parameter (name={self._value.name}, value={value_str})'
def __repr__(self):
......@@ -235,8 +231,6 @@ class Parameter(MetaTensor):
shape = self.shape
dtype = self.dtype
x.default_input = initializer(init, shape=shape, dtype=dtype)
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
return x
@property
......@@ -381,8 +375,12 @@ class ParameterTuple(tuple):
"""
def __new__(cls, iterable):
"""Create instance object of ParameterTuple."""
g = (x for x in iterable if _check_type(x))
return tuple.__new__(ParameterTuple, g)
data = tuple(iterable)
for x in data:
if not isinstance(x, Parameter):
raise TypeError(f"ParameterTuple input should be `Parameter` collection."
f"But got a {type(iterable)}, {iterable}")
return tuple.__new__(ParameterTuple, tuple(data))
def clone(self, prefix, init='same'):
"""
......
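For reference, a short sketch of the ParameterTuple change above (an illustration, not part of the diff): invalid elements are now rejected with TypeError rather than ValueError.

import pytest
from mindspore import Parameter, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer

p = Parameter(initializer('ones', [1], mstype.float32), 'p')
ParameterTuple([p])                    # a collection of Parameters is accepted
with pytest.raises(TypeError):         # non-Parameter elements now raise TypeError
    ParameterTuple(['not_a_parameter'])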
......@@ -120,7 +120,6 @@ def test_yolov3():
net = yolov3_resnet18(ConfigYOLOV3ResNet18())
net = YoloWithLossCell(net, ConfigYOLOV3ResNet18())
init_net_param(net)
total_epoch_size = 60
lr = Tensor(get_lr(learning_rate=lr_init, start_step=0,
......
......@@ -207,7 +207,6 @@ def test_bert_percision():
netwithgrads.set_train(True)
model = Model(netwithgrads)
callback = ModelCallback()
netwithloss.init_parameters_data()
params = netwithloss.trainable_params()
for param in params:
value = param.default_input
......@@ -279,7 +278,6 @@ def test_bert_performance():
netwithgrads.set_train(True)
model = Model(netwithgrads)
callback = ModelCallback()
netwithloss.init_parameters_data()
params = netwithloss.trainable_params()
for param in params:
value = param.default_input
......
......@@ -17,7 +17,7 @@
import numpy as np
import pytest
from mindspore import Tensor, Parameter, ParameterTuple
from mindspore import context, Tensor, Parameter, ParameterTuple
from mindspore._checkparam import _check_str_by_regular
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
......@@ -43,11 +43,11 @@ def test_parameter_tuple_illegal():
ParameterTuple(ptuple)
with pytest.raises(TypeError):
ParameterTuple(p1)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
ParameterTuple(plist2)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
ParameterTuple(ptuple_str)
with pytest.raises(ValueError):
with pytest.raises(TypeError):
ParameterTuple(pstr)
with pytest.raises(TypeError):
ParameterTuple(pnum)
......@@ -136,6 +136,9 @@ def test_check_str_by_regular():
_check_str_by_regular(str6)
def test_parameter_lazy_init():
# support lazy init in SEMI_AUTO_PARALLEL mode
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)
# Call init_data() without set default_input.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor)
......@@ -167,3 +170,4 @@ def test_parameter_lazy_init():
# expect no effect.
para.init_data()
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
context.reset_auto_parallel_context()
......@@ -71,7 +71,7 @@ def test_group_lr():
assert opt.dynamic_lr is False
assert opt.is_group_params_ordered is True
for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()):
if param in conv_params:
if 'conv' in param.name:
assert np.all(lr.data.asnumpy() == Tensor(conv_lr, mstype.float32).asnumpy())
else:
assert np.all(lr.data.asnumpy() == Tensor(default_lr, mstype.float32).asnumpy())
......@@ -103,7 +103,7 @@ def test_group_dynamic_1():
assert opt.dynamic_lr is True
assert opt.is_group_params_ordered is True
for lr, param, order_param in zip(opt.learning_rate, opt.parameters, net.trainable_params()):
if param in conv_params:
if 'conv' in param.name:
assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array([conv_lr] * 3).astype(np.float32)).asnumpy())
else:
......@@ -135,7 +135,7 @@ def test_group_dynamic_2():
assert opt.is_group is True
assert opt.dynamic_lr is True
for lr, param in zip(opt.learning_rate, opt.parameters):
if param in conv_params:
if 'conv' in param.name:
assert np.all(lr.learning_rate.data.asnumpy() == \
Tensor(np.array(list(conv_lr)).astype(np.float32)).asnumpy())
else:
......@@ -203,7 +203,7 @@ def test_weight_decay():
assert opt.is_group_params_ordered is True
for weight_decay, decay_flags, param, order_param in zip(
opt.weight_decay, opt.decay_flags, opt.parameters, net.trainable_params()):
if param in conv_params:
if 'conv' in param.name:
assert weight_decay == conv_weight_decay
assert decay_flags is True
else:
......@@ -303,11 +303,11 @@ def test_order_params_1():
assert opt.is_group_params_ordered is True
for weight_decay, decay_flags, lr, param, order_param in zip(
opt.weight_decay, opt.decay_flags, opt.learning_rate, opt.parameters, bias_params+conv_params):
if param in conv_params:
if 'conv' in param.name:
assert np.all(lr.data.asnumpy() == Tensor(0.1, mstype.float32).asnumpy())
assert weight_decay == 0.01
assert decay_flags is True
elif param in bias_params:
elif 'bias' in param.name:
assert np.all(lr.data.asnumpy() == Tensor(0.01, mstype.float32).asnumpy())
assert weight_decay == 0.0
assert decay_flags is False
......@@ -342,11 +342,11 @@ def test_order_params_2():
all_lr = opt.get_lr_parameter(fc1_params+conv_params)
for weight_decay, decay_flags, lr, param, order_param in zip(
opt.weight_decay, opt.decay_flags, all_lr, opt.parameters, fc1_params+conv_params):
if param in conv_params:
if 'conv' in param.name:
assert np.all(lr.data.asnumpy() == Tensor(np.array([default_lr] * 3), mstype.float32).asnumpy())
assert weight_decay == conv_weight_decay
assert decay_flags is True
elif param in fc1_params:
elif 'fc1' in param.name:
assert np.all(lr.data.asnumpy() == Tensor(fc1_lr, mstype.float32).asnumpy())
assert weight_decay == default_wd
assert decay_flags is False
......
......@@ -17,7 +17,7 @@ import numpy as np
import mindspore as ms
import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore import Tensor, context
from mindspore.common.api import ms_function
from mindspore.common.dtype import get_py_obj_dtype
from mindspore.ops import composite as C
......@@ -25,6 +25,9 @@ from mindspore.ops import functional as F
from mindspore.ops.composite import grad_all_with_sens
from ...ut_filter import non_graph_engine
# pylint: disable=unused-argument
def setup_module(module):
context.set_context(mode=context.PYNATIVE_MODE)
def mul(x, y):
return x * y
......