diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 8e23e9184dd9c08107d379c2d17d33e6e935751a..371ebbb445428d04a44507091942bc914418a4a0 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -374,9 +374,6 @@ class _Executor:
             obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
             obj.load_parameter_slice(params)
 
-        if _get_parallel_mode() in ["hybrid_parallel"]:
-            obj.parameter_layout_dict = self._build_parameter_layout(obj)
-
         # the following GE init process is not needed when use vm or ms backend
         if enable_ge:
             # decide whether to sink based on whether the inputs is virtual or not
@@ -449,38 +446,6 @@ class _Executor:
             return self._exec_pip(obj, *args, phase=phase_real)
         raise KeyError('{} graph is not exist.'.format(phase_real))
 
-    def _build_parameter_layout(self, obj):
-        """
-        Build parameter layout, for layerwise_parallel parameter.
-
-        Args:
-            obj (Function or Cell): The function or cell instance need to be compiled.
-
-        Returns:
-            Dictionary, parameter layout info.
-        """
-        parameter_layout_dict = {}
-        layerwise_parallel_parameters = []
-        for key in obj.parameters_dict():
-            if obj.parameters_dict()[key].layerwise_parallel is True:
-                layerwise_parallel_parameters.append(key)
-
-        if not layerwise_parallel_parameters:
-            return parameter_layout_dict
-
-        from ..communication.management import get_group_size
-        group_size = [get_group_size()]
-        for key in layerwise_parallel_parameters:
-            tensor_map = [0]
-            shape = obj.parameters_dict()[key].data.shape()
-            for x in range(len(shape)):  # dim 0 set 0, others set -1
-                if x:
-                    tensor_map.append(-1)
-            layout = [group_size, tensor_map]
-            parameter_layout_dict[key] = layout
-
-        return parameter_layout_dict
-
     def del_net_res(self, net_id):
         self._executor.del_net_res(net_id)
 
diff --git a/mindspore/train/callback.py b/mindspore/train/callback.py
index 62f847089da33ce4ec717b123d6b8e53d78d989e..dcf630342c5d0bfbaea9a82e20b05e86e4163b55 100644
--- a/mindspore/train/callback.py
+++ b/mindspore/train/callback.py
@@ -24,7 +24,7 @@ import mindspore.context as context
 from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
 from mindspore.train._utils import _make_directory
 from mindspore import log as logger
-from mindspore._checkparam import check_int_non_negative
+from mindspore._checkparam import check_int_non_negative, check_bool
 from mindspore.common.tensor import Tensor
 from .summary.summary_record import _cache_summary_tensor_data
 
@@ -150,6 +150,8 @@ class CheckpointConfig:
         keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
         keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
             Can't be used with keep_checkpoint_max at the same time.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scenario. Default: True.
+            Integrated save is only supported in the automatic parallel scenario, not in manual parallel.
 
     Raises:
         ValueError: If the input_param is None or 0.
@@ -163,7 +165,8 @@ class CheckpointConfig:
                  save_checkpoint_steps=1,
                  save_checkpoint_seconds=0,
                  keep_checkpoint_max=5,
-                 keep_checkpoint_per_n_minutes=0):
+                 keep_checkpoint_per_n_minutes=0,
+                 integrated_save=True):
 
         if not save_checkpoint_steps and not save_checkpoint_seconds and \
                 not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
@@ -191,6 +194,8 @@ class CheckpointConfig:
         if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
             self._keep_checkpoint_max = 1
 
+        self._integrated_save = check_bool(integrated_save)
+
     @property
     def save_checkpoint_steps(self):
         """Get the value of _save_checkpoint_steps."""
@@ -211,6 +216,11 @@ class CheckpointConfig:
         """Get the value of _keep_checkpoint_per_n_minutes."""
         return self._keep_checkpoint_per_n_minutes
 
+    @property
+    def integrated_save(self):
+        """Get the value of _integrated_save."""
+        return self._integrated_save
+
     def get_checkpoint_policy(self):
         """Get the policy of checkpoint."""
         checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
                 _set_cur_net(cb_params.train_network)
                 cb_params.train_network.exec_checkpoint_graph()
 
-            _exec_save_checkpoint(cb_params.train_network, gen_file)
+            _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
 
             if os.path.exists(gen_file):
                 shutil.move(gen_file, cur_file)
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 0478bbc0718d6a40e776f8d11d512e836648f81d..b334c3e9d8c2ddd7880f29a1ed417e796df5aee3 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
     os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
 
 
-def _exec_save_checkpoint(train_network, ckpoint_file_name):
+def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
     """
     Saves checkpoint for 'ms' backend.
 
     Args:
         train_network (Network): The train network for training.
         ckpoint_file_name (str): The name of checkpoint file.
+        integrated_save (bool): Whether to perform integrated save in the automatic model parallel scenario.
     """
 
     param_dict = {}
@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
         else:
             param_data = Tensor(value.data)
 
-        # in model parallel scenario, some parameters were spliteds to all the devices,
+        # in the automatic model parallel scenario, some parameters are split across all the devices,
         # which should be combined before saving
-        if key in train_network.parameter_layout_dict:
+        if integrated_save and key in train_network.parameter_layout_dict:
             param_data = _get_merged_param_data(train_network, key, param_data)
 
         each_param["data"] = param_data
diff --git a/tests/ut/python/utils/test_callback.py b/tests/ut/python/utils/test_callback.py
index 60e4c6527a00b9ce74ed6497b7cba3a85ccd9de3..43cf8273302adb156fe57980e401fe52e9fb5391 100644
--- a/tests/ut/python/utils/test_callback.py
+++ b/tests/ut/python/utils/test_callback.py
@@ -308,10 +308,10 @@ def test_RunContext():
 def test_Checkpoint_Config():
     """Test CheckpointConfig all None or 0."""
    with pytest.raises(ValueError):
-        CheckpointConfig(0, 0, 0, 0)
+        CheckpointConfig(0, 0, 0, 0, True)
 
     with pytest.raises(ValueError):
-        CheckpointConfig(0, None, 0, 0)
+        CheckpointConfig(0, None, 0, 0, True)
 
 
 def test_step_end_save_graph():
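
Usage sketch for the new option. This is illustrative only: CheckpointConfig(integrated_save=...) and its use by ModelCheckpoint come from this patch, while the prefix/directory arguments and the model.train() call are assumed from the surrounding API and are not part of this change.

    from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

    # With integrated_save=True (the default), each parameter found in
    # parameter_layout_dict is merged across devices before the checkpoint is
    # written; with False, every rank saves only its local slice.
    config = CheckpointConfig(save_checkpoint_steps=100,
                              keep_checkpoint_max=5,
                              integrated_save=True)
    ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./ckpt", config=config)

    # model.train(epoch_size, dataset, callbacks=[ckpt_cb])  # hypothetical training call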