提交 d5634eb4 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1510 seperate auto_parallel and stand_alone when init initializer data

Merge pull request !1510 from yihuaijie/dev
...@@ -327,16 +327,19 @@ class _Executor: ...@@ -327,16 +327,19 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'. raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params))) format(type(params)))
def _params_init_data(self, obj, params): def _params_init_data(self, obj, params, auto_parallel_mode=False):
"""Init parameters' data."""
if params is not None: if params is not None:
for key, param in params.items(): for key, param in params.items():
if key not in obj.parameter_layout_dict: if not auto_parallel_mode:
logger.info("Layout dict does not contain the key %s.", key)
param.init_data() param.init_data()
elif key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
param.init_data(set_sliced=True)
else: else:
layout = obj.parameter_layout_dict[key] layout = obj.parameter_layout_dict[key]
param.init_data(layout) param.init_data(layout, set_sliced=True)
obj.init_parameters_data() obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
""" """
...@@ -383,11 +386,11 @@ class _Executor: ...@@ -383,11 +386,11 @@ class _Executor:
if not do_convert: if not do_convert:
return phase, True return phase, True
if auto_parallel_mode and "train" in phase: if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params) self._params_init_data(obj, params, auto_parallel_mode)
if not enable_debug_runtime or enable_ge: if not enable_debug_runtime or enable_ge:
if auto_parallel_mode and "train" in phase: if auto_parallel_mode:
obj.load_parameter_slice(params) obj.load_parameter_slice(params)
# set parallel inputs in sink mode # set parallel inputs in sink mode
......
...@@ -99,6 +99,10 @@ class Parameter: ...@@ -99,6 +99,10 @@ class Parameter:
"""Get slice status of the parameter.""" """Get slice status of the parameter."""
return self._sliced return self._sliced
@sliced.setter
def sliced(self, sliced_):
self._sliced = sliced_
@property @property
def is_init(self): def is_init(self):
"""Get init status of the parameter.""" """Get init status of the parameter."""
...@@ -211,15 +215,18 @@ class Parameter: ...@@ -211,15 +215,18 @@ class Parameter:
self.default_input = data self.default_input = data
def init_data(self, layout=None): def init_data(self, layout=None, set_sliced=False):
""" """
Init data of the parameter. Init data of the parameter.
Args: Args:
layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape]. layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].
dev_mat (list[int]): device matrix.
tensor_map (list[int]): tensor map. - dev_mat (list[int]): Device matrix.
slice_shape (list[int]): shape of slice. - tensor_map (list[int]): Tensor map.
- slice_shape (list[int]): Shape of slice.
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
""" """
if not isinstance(self.default_input, MetaTensor): if not isinstance(self.default_input, MetaTensor):
return return
...@@ -235,7 +242,8 @@ class Parameter: ...@@ -235,7 +242,8 @@ class Parameter:
self.default_input = self.init_mode.to_tensor() self.default_input = self.init_mode.to_tensor()
self.init_mode = None self.init_mode = None
self._sliced = True if set_sliced:
self.sliced = True
class ParameterTuple(tuple): class ParameterTuple(tuple):
......
...@@ -264,11 +264,12 @@ class Cell: ...@@ -264,11 +264,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key) logger.info("layout dict does not contain the key %s", key)
continue continue
if self.parameters_dict()[key].sliced: if self.parameters_dict()[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key) logger.info("Param %s is already sliced.", key)
continue continue
layout = self.parameter_layout_dict[key] layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout) new_tensor = _load_tensor_by_layout(tensor, layout)
self.parameters_dict()[key].set_parameter_data(new_tensor) self.parameters_dict()[key].set_parameter_data(new_tensor)
self.parameters_dict()[key].sliced = True
elif isinstance(params, OrderedDict): elif isinstance(params, OrderedDict):
for key in params: for key in params:
tensor = params[key].data tensor = params[key].data
...@@ -276,11 +277,12 @@ class Cell: ...@@ -276,11 +277,12 @@ class Cell:
logger.info("layout dict does not contain the key %s", key) logger.info("layout dict does not contain the key %s", key)
continue continue
if params[key].sliced: if params[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key) logger.info("Param %s is already sliced.", key)
continue continue
layout = self.parameter_layout_dict[key] layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout) new_tensor = _load_tensor_by_layout(tensor, layout)
params[key].set_parameter_data(new_tensor) params[key].set_parameter_data(new_tensor)
params[key].sliced = True
else: else:
raise TypeError('Parameters need OrderedDict type, but got {}'. raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params))) format(type(params)))
...@@ -435,14 +437,17 @@ class Cell: ...@@ -435,14 +437,17 @@ class Cell:
""" """
raise NotImplementedError raise NotImplementedError
def init_parameters_data(self, recurse=True): def init_parameters_data(self, recurse=True, auto_parallel_mode=False):
"""Init parameters' data."""
for param in self.get_parameters(expand=recurse): for param in self.get_parameters(expand=recurse):
if param.name not in self.parameter_layout_dict: if not auto_parallel_mode:
logger.info("Layout dict does not contain the key %s.", param.name)
param.init_data() param.init_data()
elif param.name not in self.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", param.name)
param.init_data(set_sliced=True)
else: else:
layout = self.parameter_layout_dict[param.name] layout = self.parameter_layout_dict[param.name]
param.init_data(layout) param.init_data(layout, set_sliced=True)
def parameters_dict(self, recurse=True): def parameters_dict(self, recurse=True):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册