diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 5c1fba328ea60dcc9499b0e87f1a30b998dda354..529eb8060cf08221b0f96ed797bfe3c3fd75bf21 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -327,16 +327,19 @@ class _Executor:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))
 
-    def _params_init_data(self, obj, params):
+    def _params_init_data(self, obj, params, auto_parallel_mode=False):
+        """Init parameters' data."""
         if params is not None:
             for key, param in params.items():
-                if key not in obj.parameter_layout_dict:
-                    logger.info("Layout dict does not contain the key %s.", key)
+                if not auto_parallel_mode:
                     param.init_data()
+                elif key not in obj.parameter_layout_dict:
+                    logger.info("Layout dict does not contain the key %s.", key)
+                    param.init_data(set_sliced=True)
                 else:
                     layout = obj.parameter_layout_dict[key]
-                    param.init_data(layout)
-        obj.init_parameters_data()
+                    param.init_data(layout, set_sliced=True)
+        obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
 
     def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
         """
@@ -383,11 +386,11 @@ class _Executor:
         if not do_convert:
             return phase, True
 
-        if auto_parallel_mode and "train" in phase:
+        if auto_parallel_mode:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
-        self._params_init_data(obj, params)
+        self._params_init_data(obj, params, auto_parallel_mode)
         if not enable_debug_runtime or enable_ge:
-            if auto_parallel_mode and "train" in phase:
+            if auto_parallel_mode:
                 obj.load_parameter_slice(params)
 
         # set parallel inputs in sink mode
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 7b504965071b1097cea98e4fbc84055ea33fc645..788c2d03073a17a89aa2421fcc92a6a7f5ea9c81 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -99,6 +99,10 @@ class Parameter:
         """Get slice status of the parameter."""
         return self._sliced
 
+    @sliced.setter
+    def sliced(self, sliced_):
+        self._sliced = sliced_
+
     @property
     def is_init(self):
         """Get init status of the parameter."""
@@ -211,15 +215,18 @@ class Parameter:
 
         self.default_input = data
 
-    def init_data(self, layout=None):
+    def init_data(self, layout=None, set_sliced=False):
         """
         Init data of the parameter.
 
         Args:
-            layout (list[list[int]]): parameter slice layout [dev_mat, tensor_map, slice_shape].
-                dev_mat (list[int]): device matrix.
-                tensor_map (list[int]): tensor map.
-                slice_shape (list[int]): shape of slice.
+            layout (list[list[int]]): Parameter slice layout [dev_mat, tensor_map, slice_shape].
+
+                - dev_mat (list[int]): Device matrix.
+                - tensor_map (list[int]): Tensor map.
+                - slice_shape (list[int]): Shape of slice.
+            set_sliced (bool): True if the parameter should be marked as sliced after its data
+                is initialized. Default: False.
""" if not isinstance(self.default_input, MetaTensor): return @@ -235,7 +242,8 @@ class Parameter: self.default_input = self.init_mode.to_tensor() self.init_mode = None - self._sliced = True + if set_sliced: + self.sliced = True class ParameterTuple(tuple): diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index e0563a05fab608d5c276840d622f4dcec8bb8d59..e6d2dc738336e50853eec9a866649264585a3414 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -264,11 +264,12 @@ class Cell: logger.info("layout dict does not contain the key %s", key) continue if self.parameters_dict()[key].sliced: - logger.info("Param %s is from initializer, already sliced.", key) + logger.info("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) self.parameters_dict()[key].set_parameter_data(new_tensor) + self.parameters_dict()[key].sliced = True elif isinstance(params, OrderedDict): for key in params: tensor = params[key].data @@ -276,11 +277,12 @@ class Cell: logger.info("layout dict does not contain the key %s", key) continue if params[key].sliced: - logger.info("Param %s is from initializer, already sliced.", key) + logger.info("Param %s is already sliced.", key) continue layout = self.parameter_layout_dict[key] new_tensor = _load_tensor_by_layout(tensor, layout) params[key].set_parameter_data(new_tensor) + params[key].sliced = True else: raise TypeError('Parameters need OrderedDict type, but got {}'. format(type(params))) @@ -435,14 +437,17 @@ class Cell: """ raise NotImplementedError - def init_parameters_data(self, recurse=True): + def init_parameters_data(self, recurse=True, auto_parallel_mode=False): + """Init parameters' data.""" for param in self.get_parameters(expand=recurse): - if param.name not in self.parameter_layout_dict: - logger.info("Layout dict does not contain the key %s.", param.name) + if not auto_parallel_mode: param.init_data() + elif param.name not in self.parameter_layout_dict: + logger.info("Layout dict does not contain the key %s.", param.name) + param.init_data(set_sliced=True) else: layout = self.parameter_layout_dict[param.name] - param.init_data(layout) + param.init_data(layout, set_sliced=True) def parameters_dict(self, recurse=True): """