提交 d5634eb4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1510 seperate auto_parallel and stand_alone when init initializer data

Merge pull request !1510 from yihuaijie/dev
......@@ -327,16 +327,19 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def _params_init_data(self, obj, params):
def _params_init_data(self, obj, params, auto_parallel_mode=False):
"""Init parameters' data."""
if params is not None:
for key, param in params.items():
if key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
if not auto_parallel_mode:
param.init_data()
elif key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
param.init_data(set_sliced=True)
else:
layout = obj.parameter_layout_dict[key]
param.init_data(layout)
obj.init_parameters_data()
param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
......@@ -383,11 +386,11 @@ class _Executor:
if not do_convert:
return phase, True
if auto_parallel_mode and "train" in phase:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params)
self._params_init_data(obj, params, auto_parallel_mode)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode and "train" in phase:
if auto_parallel_mode:
obj.load_parameter_slice(params)
# set parallel inputs in sink mode
......
......@@ -99,6 +99,10 @@ class Parameter:
"""Get slice status of the parameter."""
return self._sliced
@sliced.setter
def sliced(self, sliced_):
self._sliced = sliced_
@property
def is_init(self):
"""Get init status of the parameter."""
......@@ -211,15 +215,18 @@ class Parameter:
self.default_input = data
def init_data(self, layout=None):
def init_data(self, layout=None, set_sliced=False):
"""
Init data of the parameter.
Args:
layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape].
dev_mat (list[int]): device matrix.
tensor_map (list[int]): tensor map.
slice_shape (list[int]): shape of slice.
layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].
- dev_mat (list[int]): Device matrix.
- tensor_map (list[int]): Tensor map.
- slice_shape (list[int]): Shape of slice.
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if not isinstance(self.default_input, MetaTensor):
return
......@@ -235,7 +242,8 @@ class Parameter:
self.default_input = self.init_mode.to_tensor()
self.init_mode = None
self._sliced = True
if set_sliced:
self.sliced = True
class ParameterTuple(tuple):
......
......@@ -264,11 +264,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key)
continue
if self.parameters_dict()[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
logger.info("Param %s is already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
self.parameters_dict()[key].set_parameter_data(new_tensor)
self.parameters_dict()[key].sliced = True
elif isinstance(params, OrderedDict):
for key in params:
tensor = params[key].data
......@@ -276,11 +277,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key)
continue
if params[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
logger.info("Param %s is already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
params[key].set_parameter_data(new_tensor)
params[key].sliced = True
else:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
......@@ -435,14 +437,17 @@ class Cell:
"""
raise NotImplementedError
def init_parameters_data(self, recurse=True):
def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
"""Init parameters' data."""
for param in self.get_parameters(expand=recurse):
if param.name not in self.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", param.name)
if not auto_parallel_mode:
param.init_data()
elif param.name not in self.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", param.name)
param.init_data(set_sliced=True)
else:
layout = self.parameter_layout_dict[param.name]
param.init_data(layout)
param.init_data(layout, set_sliced=True)
def parameters_dict(self, recurse=True):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册