提交 22c6baee 编写于 作者: W WeibiaoYu

Support to config whether to save integeated checkpoint, in auto model parallel scene

上级 60f7a95b
......@@ -374,9 +374,6 @@ class _Executor:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
obj.load_parameter_slice(params)
if _get_parallel_mode() in ["hybrid_parallel"]:
obj.parameter_layout_dict = self._build_parameter_layout(obj)
# the following GE init process is not needed when use vm or ms backend
if enable_ge:
# decide whether to sink based on whether the inputs is virtual or not
......@@ -449,38 +446,6 @@ class _Executor:
return self._exec_pip(obj, *args, phase=phase_real)
raise KeyError('{} graph is not exist.'.format(phase_real))
def _build_parameter_layout(self, obj):
"""
Build parameter layout, for layerwise_parallel parameter.
Args:
obj (Function or Cell): The function or cell instance need to be compiled.
Returns:
Dictionary, parameter layout info.
"""
parameter_layout_dict = {}
layerwise_parallel_parameters = []
for key in obj.parameters_dict():
if obj.parameters_dict()[key].layerwise_parallel is True:
layerwise_parallel_parameters.append(key)
if not layerwise_parallel_parameters:
return parameter_layout_dict
from ..communication.management import get_group_size
group_size = [get_group_size()]
for key in layerwise_parallel_parameters:
tensor_map = [0]
shape = obj.parameters_dict()[key].data.shape()
for x in range(len(shape)): # dim 0 set 0, others set -1
if x:
tensor_map.append(-1)
layout = [group_size, tensor_map]
parameter_layout_dict[key] = layout
return parameter_layout_dict
def del_net_res(self, net_id):
self._executor.del_net_res(net_id)
......
......@@ -24,7 +24,7 @@ import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
from mindspore.train._utils import _make_directory
from mindspore import log as logger
from mindspore._checkparam import check_int_non_negative
from mindspore._checkparam import check_int_non_negative, check_bool
from mindspore.common.tensor import Tensor
from .summary.summary_record import _cache_summary_tensor_data
......@@ -150,6 +150,8 @@ class CheckpointConfig:
keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parall scene. Default: True.
Integrated save function is only supported in automatic parall scene, not supported in manual parallel.
Raises:
ValueError: If the input_param is None or 0.
......@@ -163,7 +165,8 @@ class CheckpointConfig:
save_checkpoint_steps=1,
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0):
keep_checkpoint_per_n_minutes=0,
integrated_save=True):
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
......@@ -191,6 +194,8 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1
self._integrated_save = check_bool(integrated_save)
@property
def save_checkpoint_steps(self):
"""Get the value of _save_checkpoint_steps."""
......@@ -211,6 +216,11 @@ class CheckpointConfig:
"""Get the value of _keep_checkpoint_per_n_minutes."""
return self._keep_checkpoint_per_n_minutes
@property
def integrated_save(self):
"""Get the value of _integrated_save."""
return self._integrated_save
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
......@@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
_set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file)
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)
......
......@@ -279,13 +279,14 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpoint_file_name):
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
"""
param_dict = {}
......@@ -300,9 +301,9 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name):
else:
param_data = Tensor(value.data)
# in model parallel scenario, some parameters were spliteds to all the devices,
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if key in train_network.parameter_layout_dict:
if integrated_save and key in train_network.parameter_layout_dict:
param_data = _get_merged_param_data(train_network, key, param_data)
each_param["data"] = param_data
......
......@@ -308,10 +308,10 @@ def test_RunContext():
def test_Checkpoint_Config():
"""Test CheckpointConfig all None or 0."""
with pytest.raises(ValueError):
CheckpointConfig(0, 0, 0, 0)
CheckpointConfig(0, 0, 0, 0, True)
with pytest.raises(ValueError):
CheckpointConfig(0, None, 0, 0)
CheckpointConfig(0, None, 0, 0, True)
def test_step_end_save_graph():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册