diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 4e2c739119c577fc382f4584bedf1b37b6255090..e8557a2931013470baefdc526248178595e74d47 100755 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -20,8 +20,8 @@ import warnings import logging import paddle.fluid.core as core -from paddle.fluid.io import is_parameter, is_belong_to_optimizer from paddle.framework.io import _to_LodTensor +from paddle.fluid.io import is_parameter, is_belong_to_optimizer def is_valid_list_index(list, index): @@ -362,49 +362,97 @@ def make_data_unshard(dist_main_prog, dist_startup_prog, dist_context=None): dist_context.set_tensor_dist_attr_for_program(var, tensor_dist_attr) -def _check_addition_info(addition_info): - """ - Validity check of additional information - """ +def _update_addition_info(addition_info): + """ Update default addition_info with inputs """ + add_info = {"epoch": 0, "batch": 0, "batch_size": 0} if not addition_info: - return addition_info + return add_info elif not isinstance(addition_info, dict): - raise TypeError( - "The type of addition_info should be 'dict', but got {}".format( - str(type(addition_info)))) + raise TypeError("The type of 'addition_info' should be 'dict', " + "but got '{}'.".format(str(type(addition_info)))) else: - return addition_info + for item, value in addition_info.items(): + if item not in ["epoch", "batch", "batch_size"]: + raise ValueError( + "The key of 'addition_info' should be one of the " + "['epoch', 'batch', 'batch_size'], but got '{}'." 
+ .format(str(item))) + if not isinstance(value, int): + raise ValueError( + "The value of 'addition_info' should be 'int', " + "but got '{}'.".format(str(type(value)))) + add_info[item] = value + return add_info def _check_valid_path(file_path): - """ - Validity check of input file path - """ + """ Validity check of input file path """ if not file_path: return file_path - elif isinstance(file_path, str): - if not os.path.exists(file_path): - raise ValueError("The file_path '{}' does not exist.".format( - file_path)) - else: - return [file_path] elif isinstance(file_path, list): - if not all(isinstance(file, str) for file in file_path): - raise ValueError("The type of each file_path should be str.") - if not all(os.path.exists(file) for file in file_path): - raise ValueError("The file_path's file does not exist.") + for file in file_path: + if not isinstance(file, str): + raise TypeError("The type of file path should be 'str', " + "but got '{}'.".format(str(type(file)))) + if not os.path.exists(file): + raise ValueError("The file path '{}' does not exist." + .format(file)) return file_path else: - raise TypeError( - "The type of file_path should be 'str' or 'list', but got '{}'.". 
- format(str(type(file_path)))) + raise TypeError("The type of file path should be 'list', " + "but got '{}'.".format(str(type(file_path)))) + + +def _check_param_dict(param_dict): + if not param_dict: + raise ValueError("'param_dict' cannot be None.") + elif not isinstance(param_dict, dict): + raise TypeError("The type of 'param_dict' should be 'dict', " + "but got '{}'.".format(str(type(param_dict)))) + else: + for name, value in param_dict.items(): + if not isinstance(name, str): + raise TypeError( + "The type of key of 'param_dict' should be 'str', " + "but got '{}'.".format(str(type(name)))) + if not isinstance(value, paddle.fluid.LoDTensor): + raise TypeError( + "The type of value of 'param_dict' should be 'LoDTensor', " + "but got '{}'.".format(str(type(value)))) + return param_dict + + +def _check_dist_attr(dist_attr): + if not dist_attr: + return dist_attr + elif not isinstance(dist_attr, dict): + raise TypeError("The type of 'dist_attr' should be 'dict', " + "but got '{}'.".format(str(type(dist_attr)))) + else: + for name, value in dist_attr.items(): + if not isinstance(name, str): + raise TypeError( + "The type of param name of 'dist_attr' should be 'str', " + "but got '{}'.".format(str(type(name)))) + if not isinstance(value, dict): + raise TypeError( + "The type of distributed attribute should be 'dict', " + "but got '{}'".format(str(type(value)))) + attr = ['process_shape', 'process_group', 'dims_mapping'] + if list(value.keys()) != attr: + raise ValueError( + "The key of distributed attribute should be " + "'['process_shape', 'process_group', 'dims_mapping']', " + "but got {}.".format(str(value.keys()))) + return dist_attr def save_distributed_checkpoint(program, checkpoint_path, - is_integrated=False, + dist_attr_path, addition_info=None, - dist_attr_path=None): + is_integrated=False, + dist_context=None): """ Save model parameter state, optimzer state, distributed attribute and additional information of each rank. 
@@ -412,9 +460,11 @@ def save_distributed_checkpoint(program, Args: program(Program): The program to be saved. checkpoint_path(str): The path of the checkpoint file to be saved. + dist_attr_path(str): The path of distributed attribute file to be saved. + addition_info(dict, optional): Additional information, key should be selected in ['epoch', 'batch', 'batch_size']. + Default values are 0, when 'addition_info' is None. Default: None. is_integrated(bool, optional): Whether to integrate param before save. Default: False. - addition_info(dict, optional): Additional information. Default: None. - dist_attr_path(str, optional): The path of distributed attribute file to be saved. Default: None + dist_context(DistributedContext ,optional): collect related distributed information for program Returns: None @@ -422,78 +472,508 @@ def save_distributed_checkpoint(program, Examples: .. code-block:: python - ckpt_path = os.path.join(args.output_dir, "step_%d" % step) - os.makedirs(ckpt_path, exist_ok=True) - save_distributed_checkpoint(program, ckpt_path) + path = os.path.join("./output", "step_%d" % step) + os.makedirs(path, exist_ok=True) + add_info = {'batch': step, "batch_size": global_batch_size} + save_distributed_checkpoint(program, path, path, add_info) """ + from .dist_context import get_default_distributed_context + + assert isinstance(program, paddle.fluid.framework.Program) + assert isinstance(is_integrated, bool) + if dist_context is None: + dist_context = get_default_distributed_context() + addition_info = _update_addition_info(addition_info) + if not is_integrated: - rank = paddle.distributed.get_rank() - ckpt_file_name = os.path.join(checkpoint_path, - "model_state_rank{}.pdmodel".format(rank)) - - state_dict = { - "model": program.state_dict(), - "ranks": paddle.distributed.get_world_size() - } - if _check_addition_info(addition_info): - state_dict["addition_info"] = addition_info - - paddle.save(state_dict, ckpt_file_name) - logging.info("Already save model to 
{}".format(checkpoint_path)) - - if dist_attr_path: - raise NotImplementedError( - "Save distributed attribute has not been implemented.") + _save_distributed_state_dict(program, addition_info, checkpoint_path) + _save_distributed_attribute(program, dist_attr_path, dist_context) else: # TODO: integrate param before save raise NotImplementedError( "Integrating parameter has not been implemented.") -def load_distributed_checkpoint(checkpoint_path, - program=None, - dist_attr_path=None): +def load_distributed_checkpoint(checkpoint_path, dist_attr_path): """ - Load parameter, optimizer, distributed attribute and addition_info of model. + Load parameter, optimizer, distributed attribute and addition_info. Args: - checkpoint_path(str|list[str]): checkpoint_path's type can be 'str' or 'list', \ - which must be in order of rank id when type is 'list'. - program(Program, optional): The program to be updated with checkpoint_path. Default: None. - dist_attr_path(str|list[str], optional): dist_attr_path's type can be 'str' or 'list', \ - which must be in order of rank id when type is 'list'. Default: None. + checkpoint_path(list[str]): model parameter file path, must be in order of rank id. + dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. Returns: - None or addition_info which user saved in last train. + param_dict(dict): parameters' value of all ranks. + dist_attr(dict): parameters' distributed attribute. + addition_info(dict): additional information user saved in last training. + + Notes: + The return, 'addition_info', is belonging to the first file of checkpoint_path by default. + + Examples: + .. 
code-block:: python + + ckpt_path = ['./model_state_rank0.pdmodel', + './model_state_rank1.pdmodel'] + dist_attr_path = ['./dist_attr_rank0.pdattr', + './dist_attr_rank1.pdattr'] + param_dict, dist_attr, add_info = load_distributed_checkpoint(ckpt_path, dist_attr_path) + """ + assert _check_valid_path(checkpoint_path), \ + "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), \ + "'dist_attr_path' cannot be None." + + state_dict_info = _load_distributed_state_dict(checkpoint_path) + dist_attr = _load_distributed_attribute(dist_attr_path) + param_dict = state_dict_info["model"] + addition_info = state_dict_info["addition_info"] + return param_dict, dist_attr, addition_info + + +def load_checkpoint_into_program(checkpoint_path, + dist_attr_path, + program, + dist_context=None): + """ + Load parameter, optimizer, distributed attribute and addition_info into model. + + Args: + checkpoint_path(list[str]): model parameter file path, must be in order of rank id. + dist_attr_path(list[str]): distributed attribute file path, must be in order of rank id. + program(Program): the program to be updated with checkpoint_path. + dist_context(DistributedContext ,optional): collect related distributed information for program + + Returns: + addition_info(dict): user saved in last train. + + Notes: + The return, 'addition_info', is belonging to the first file of checkpoint_path by default. Examples: .. 
code-block:: python exe.run(startup_program) - ckpt_path = ['./output/step_10/model_state_rank0.pdmodel', - './output/step_10/model_state_rank1.pdmodel'] - load_distributed_checkpoint(ckpt_path, main_program) + ckpt_path = ['./model_state_rank0.pdmodel', + './model_state_rank1.pdmodel'] + dist_attr_path = ['./dist_attr_rank0.pdattr', + './dist_attr_rank1.pdattr'] + load_checkpoint_into_program(ckpt_path, dist_attr_path, main_program) """ - checkpoint_path = _check_valid_path(checkpoint_path) - dist_attr_path = _check_valid_path(dist_attr_path) + from .dist_context import get_default_distributed_context - if checkpoint_path and dist_attr_path: - raise NotImplementedError( - "Merge&Slice parameter with dist_attr has not been implemented.") - - elif checkpoint_path: - assert len(checkpoint_path) == paddle.distributed.get_world_size(), \ - "The number of checkpoint_path must equal to the number of ranks" - rank = paddle.distributed.get_rank() - state_dict_info = paddle.load(checkpoint_path[rank]) - state_dict = state_dict_info["model"] + assert isinstance(program, paddle.fluid.framework.Program) + assert _check_valid_path(checkpoint_path), \ + "'checkpoint_path' cannot be None." + assert _check_valid_path(dist_attr_path), \ + "'dist_attr_path' cannot be None." + if dist_context is None: + dist_context = get_default_distributed_context() + all_state_dict_info = _load_distributed_state_dict(checkpoint_path) + all_pre_dist_attr = _load_distributed_attribute(dist_attr_path) + all_cur_dist_attr = get_dist_attr(program, dist_context) + all_param_dict = all_state_dict_info["model"] + addition_info = all_state_dict_info["addition_info"] + sliced_param_dict = merge_and_slice_parameter( + all_param_dict, all_pre_dist_attr, all_cur_dist_attr) + load_parameter_into_program(sliced_param_dict, program) + + return addition_info + + +def load_parameter_into_program(param_dict, program): + """ + Load parameters into program. + + Args: + param_dict(dict): parameters' name and value. 
+ program(Program): the program to be updated + """ + _check_param_dict(param_dict) + assert program and isinstance(program, paddle.fluid.framework.Program) + program.set_state_dict(param_dict) + + +def _save_distributed_attribute(program, dist_attr_path, dist_context): + """ Save distributed attribute of all parameters """ + # TODO: just save a complete distributed attribute file + rank_id = paddle.distributed.get_rank() + dist_attr_name = os.path.join(dist_attr_path, + "dist_attr_rank{}.pdattr".format(rank_id)) + dist_attr_dict = { + "model": get_dist_attr(program, dist_context), + "world_size": paddle.distributed.get_world_size() + } + paddle.save(dist_attr_dict, dist_attr_name) + logging.info("Already saved distributed attribute to '{}'.".format( + dist_attr_path)) + + +def _load_distributed_attribute(dist_attr_path): + """ Load parameters' distributed attribute from dist_attr_path """ + total_dist_attr = {} + for dist_attr_file in dist_attr_path: + dist_attr = paddle.load(dist_attr_file) + pre_world_size = dist_attr["world_size"] + assert pre_world_size == len(dist_attr_path), \ + "The number of 'dist_attr_path' must be equal to the last training world size." 
+ for name, attr in dist_attr["model"].items(): + if name not in total_dist_attr: + total_dist_attr[name] = attr + + return total_dist_attr + + +def _save_distributed_state_dict(program, addition_info, checkpoint_path): + """ Save parameters' state_dict """ + rank = paddle.distributed.get_rank() + ckpt_file_name = os.path.join(checkpoint_path, + "model_state_rank{}.pdmodel".format(rank)) + state_dict = { + "model": program.state_dict(), + "world_size": paddle.distributed.get_world_size(), + "addition_info": addition_info + } + paddle.save(state_dict, ckpt_file_name) + logging.info("Already saved model to '{}'.".format(checkpoint_path)) + + +def _load_distributed_state_dict(checkpoint_path): + """ Load parameters' state_dict from checkpoint_path """ + all_state_dict = {} + for idx, ckpt_file in enumerate(checkpoint_path): + state_dict_info = paddle.load(ckpt_file) + pre_world_size = state_dict_info["world_size"] + assert pre_world_size == len(checkpoint_path), \ + "The number of 'checkpoint_path' must be equal to the last training world size." + if idx == 0: + addition_info = state_dict_info["addition_info"] + for name, value in state_dict_info["model"].items(): + if name in all_state_dict: + all_state_dict[name].append(np.array(value)) + else: + all_state_dict[name] = [np.array(value)] + + all_state_dict_info = { + "model": all_state_dict, + "addition_info": addition_info + } + return all_state_dict_info + + +def get_dist_attr(program, dist_context=None): + """ + Get distributed attribute of current rank. 
+ + Args: + program(Program): main program for training + """ + from .dist_context import get_default_distributed_context + + assert isinstance(program, paddle.fluid.framework.Program) + if dist_context is None: + dist_context = get_default_distributed_context() + dist_attr = {} + for var in program.list_vars(): + if is_parameter(var) or is_belong_to_optimizer(var): + tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program( + var) + process_mesh = tensor_dist_attr.process_mesh + dims_mapping = tensor_dist_attr.dims_mapping + dist_attr[var.name] = { + "process_shape": process_mesh.topology, + "process_group": process_mesh.processes, + "dims_mapping": dims_mapping + } + return dist_attr + + +def merge_and_slice_parameter(dist_param_dict, pre_dist_attr, cur_dist_attr): + """ + Merge parameters with previous dist_attr and slice parameters with current dist_attr + + Arags: + dist_param_dict(dict): parameters' value of all ranks. + pre_dist_attr(dict): parameters' dist_attr of last training process. + cur_dist_attr(dict): parameters' dist_attr of current training process. + + Returns: + dist_param_dict(dict): parameters' value of current rank. + """ + assert _check_dist_attr(pre_dist_attr), "'pre_dist_attr' cannot be None." + assert _check_dist_attr(cur_dist_attr), "'pre_dist_attr' cannot be None." + assert isinstance(dist_param_dict, dict), \ + "The type of 'dist_param_dict' should be 'dict', but got {}.".format( + str(type(dist_param_dict))) + for name, value in dist_param_dict.items(): + if not isinstance(name, str): + raise TypeError("The key of 'dist_param_dict' is parameter's name, " + "and its type should be 'str', but got {}." 
+ .format(str(type(name)))) + if not isinstance(value, list) or not all( + isinstance(v, np.ndarray) for v in value): + raise TypeError( + "The value of 'dist_param_dict' is parameter's value of all ranks, " + "and its type should be 'list(numpy.ndarray)'.") + + param_not_in_pre = [] + param_not_in_cur = [] + logging.info("Start to merge and slice parameters.") + for var_name in cur_dist_attr.keys(): + if var_name not in pre_dist_attr: + param_not_in_pre.append(var_name) + continue + + pre_attr = pre_dist_attr[var_name] + cur_attr = cur_dist_attr[var_name] + if pre_attr == cur_attr: + # skip merge and slice + rank_id = paddle.distributed.get_rank() + index = cur_attr["process_group"].index(rank_id) + param = dist_param_dict[var_name][index] + dist_param_dict[var_name] = _to_LodTensor(param) + continue + + pre_param = dist_param_dict[var_name] + pre_dims_mapping = pre_attr["dims_mapping"] + cur_dims_mapping = cur_attr["dims_mapping"] + if len(set(pre_dims_mapping)) > 1 or -1 not in pre_dims_mapping: + complete_param = _merge_parameter_with_dist_attr(pre_param, + pre_attr) + dist_param_dict[var_name] = complete_param + else: + complete_param = pre_param[0] + dist_param_dict[var_name] = _to_LodTensor(complete_param) + + if len(set(cur_dims_mapping)) > 1 or -1 not in cur_dims_mapping: + sliced_param = _slice_parameter_with_dist_attr(complete_param, + cur_attr) + dist_param_dict[var_name] = sliced_param + + for var_name in pre_dist_attr: + if var_name not in cur_dist_attr: + param_not_in_cur.append(var_name) + dist_param_dict.pop(var_name) + + if param_not_in_pre: + warnings.warn("Parameters '{}' are not found in last training process." + .format(str(param_not_in_pre))) + if param_not_in_cur: + warnings.warn( + "Parameters '{}' are not found in current training process." 
+ .format(str(param_not_in_cur))) + + return dist_param_dict + + +def _merge_parameter_with_dist_attr(param_list, dist_attr): + """ Merge parameter with distributed attribute """ + from .reshard import _compute_complete_shape, _compute_partition_index + + dims_mapping = dist_attr["dims_mapping"] + process_shape = dist_attr["process_shape"] + process_group = dist_attr["process_group"] + # get the complete shape of the parameter + complete_shape = _compute_complete_shape(param_list[0].shape, process_shape, + dims_mapping) + # merge the parameter with dist_attr + partition_param_list = [] + for process in process_group: + partition_index = _compute_partition_index( + process, complete_shape, dims_mapping, process_shape, process_group) + index = process_group.index(process) + _merge_parameter(partition_param_list, param_list[index], + partition_index) + assert len(partition_param_list) == 1 or not partition_param_list, \ + "Fail to merge parameter" + complete_param = _to_LodTensor(partition_param_list[0][0]) + return complete_param + + +def _slice_parameter_with_dist_attr(param, dist_attr): + """ Slice parameter with distributed attribute """ + param = np.array(param) if isinstance(param, + paddle.fluid.LoDTensor) else param + dims_mapping = dist_attr["dims_mapping"] + process_shape = dist_attr["process_shape"] + process_group = dist_attr["process_group"] + # slice the parameter with dist_attr + partition_index_list = _get_split_indices(param.shape, dims_mapping, + process_shape, process_group) + sliced_param_list = _slice_parameter(param, partition_index_list, + len(partition_index_list)) + # get the current parameter's index in sliced_param_list + rank_id = paddle.distributed.get_rank() + sliced_param_index = _get_sliced_param_index( + rank_id, param.shape, dims_mapping, process_shape, process_group) + sliced_param = _to_LodTensor(sliced_param_list[sliced_param_index]) + return sliced_param + + +def _merge_parameter(partition_param_list, param, partition_index): + 
""" + Merge partitial parameters to a complete one. + + Returns: + None + + Examples: + .. code-block:: python + + import numpy as np + partition_param_list = [(np.array([[[1.11, 1.12]]]), [[0,1],[0,1],[0,2]])] + param = np.array([[[1.13, 1.14]]]) + partition_index = [[0,1],[0,1],[2,4]] + + _merge_parameter(partition_param_list, param, partition_index) + # partition_param_list: [(np.array([[[1.11, 1.12, 1.13, 1.14]]]), [[0,1],[0,1],[0,4]])] + """ + from .reshard import _compute_concat_info + + if not partition_param_list: + partition_param_list.append((param, partition_index)) else: - raise ValueError("'checkpoint_path' can not be None.") + i = 0 + has_concat = False + while i < len(partition_param_list): + concat_axis, first_order, new_partition = _compute_concat_info( + partition_param_list[i][1], partition_index) + if concat_axis != -1: + has_concat = True + if first_order == 0: + new_param = np.concatenate( + (partition_param_list[i][0], param), axis=concat_axis) + else: + new_param = np.concatenate( + (param, partition_param_list[i][0]), axis=concat_axis) + + partition_param_list.pop(i) + _merge_parameter(partition_param_list, new_param, new_partition) + break + i += 1 + + if not has_concat: + need_append = True + for i in range(len(partition_param_list)): + if partition_index == partition_param_list[i][1]: + need_append = False + break + if need_append: + partition_param_list.append((param, partition_index)) + + +def _slice_parameter(complete_param, partition_index_list, length): + """ + Slice a complete parameter. + + Returns: + sliced_param_list(list): sliced parameters with 'partition_index_list' + + Examples: + .. 
code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + rank = 2 + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + sliced_param_list = _slice_parameter(complete_param, [[], [], [2, 4]], 3) + # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] + """ + sliced_param_list = [] + axis = len(complete_param.shape) - length + sliced_param = np.split( + complete_param, partition_index_list[axis], axis=axis) + if length == 1: + return sliced_param + for param in sliced_param: + sliced_param_list.extend( + _slice_parameter(param, partition_index_list, length - 1)) + return sliced_param_list + + +def _get_sliced_param_index(rank, complete_shape, dims_mapping, process_shape, + process_group): + """ + Get sliced_param's index of current rank in all sliced parameters list. + + Returns: + sliced_param_index(int): the index of sliced param in sliced_param_list + + Examples: + .. 
code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + rank = 2 + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + slice_param = _slice_parameter(complete_param, [[], [], [2, 4]], 3) + # slice_param: + # [array([[[1.11, 1.12]]]), array([[[1.13, 1.14]]]), array([[[1.15, 1.16]]])] + + index = _get_sliced_param_index(rank, complete_shape, dims_mapping + process_shape, process_group) + # index: 2 + """ + from .reshard import _compute_partition_index + + partition_index = _compute_partition_index( + rank, complete_shape, dims_mapping, process_shape, process_group) + sliced_param_index = 0 + for i, shape in enumerate(complete_shape): + if dims_mapping[i] == -1: + slice_shape = shape + else: + slice_shape = shape // process_shape[dims_mapping[i]] + if shape == 1: + index = 0 + else: + index = (partition_index[i][0] + 1) // slice_shape + sliced_param_index = sliced_param_index * (shape // slice_shape) + index + return sliced_param_index - program.set_state_dict(state_dict) if program else \ - warnings.warn("'Program' is None, parameters will not be loaded.") - if "addition_info" not in state_dict_info: - return +def _get_split_indices(complete_shape, dims_mapping, process_shape, + process_group): + """ + Get split indices of every dimension. + + Returns: + split_indices_list(list): the split indices of every dimension of the parameter - return state_dict_info["addition_info"] + Examples: + .. 
code-block:: python + + import numpy as np + complete_param = np.array([[[1.11, 1.12, 1.13, 1.14, 1.15, 1.16]]]) + complete_shape = [1, 1, 6] + dims_mapping = [-1, -1, 0] + process_shape = [3] + process_group = [0, 1, 2] + + index = _get_split_indices(complete_shape, dims_mapping, process_shape, process_group) + # index: [[], [], [2, 4]] + """ + from .reshard import _compute_partition_index + + split_indices_list = [] + for process in process_group: + partition_index = _compute_partition_index( + process, complete_shape, dims_mapping, process_shape, process_group) + if split_indices_list: + for dim in range(len(partition_index)): + split_indices_list[dim].extend(partition_index[dim]) + else: + split_indices_list = partition_index + split_indices_list = list( + map(lambda x, y: list(set(x) - set([y]) - set([0])), split_indices_list, + complete_shape)) + split_indices_list = [sorted(x) for x in split_indices_list] + return split_indices_list diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index e821140a0d1ec9f585c80b036d9ea6a465b23987..9e02347a13d9265779e8ab151e53394fb06a80aa 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -39,6 +39,7 @@ list(APPEND DIST_TEST_OPS test_parallel_class_center_sample) list(APPEND DIST_TEST_OPS test_parallel_margin_cross_entropy) list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard) list(APPEND DIST_TEST_OPS test_auto_parallel_save_load) +list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert) set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS}) #remove distribute unittests. 
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op) @@ -256,6 +257,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM)) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_cost_model) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_data_unshard) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_save_load) + LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_autoconvert) elseif(WITH_GPU) if (${CUDNN_VERSION} VERSION_LESS 7100) LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op) @@ -1035,6 +1037,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_margin_cross_entropy PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_data_unshard PROPERTIES TIMEOUT 120) set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120) + set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120) if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212) set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_transformer PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py new file mode 100644 index 0000000000000000000000000000000000000000..2277c69674b3faf4f2fddc43b8032152a465fd42 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel_autoconvert.py @@ -0,0 +1,376 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import random +import numpy as np +import os +import shutil + +import paddle +import paddle.nn as nn +import paddle.utils as utils +import paddle.static as static +import paddle.nn.functional as F +import paddle.distributed.auto_parallel as auto + +from paddle.distributed import fleet +from paddle.fluid.initializer import NumpyArrayInitializer +from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint, load_checkpoint_into_program +from paddle.distributed.auto_parallel.utils import get_dist_attr, merge_and_slice_parameter, load_parameter_into_program +from paddle.distributed.auto_parallel.reshard import HAS_SENT, HAS_RECV, HAS_ALLGATHER + +paddle.enable_static() +_global_parallel_strategy = None +_global_process_mesh = None +PP_MESH_0 = None +PP_MESH_1 = None + + +class MLPLayer(nn.Layer): + def __init__(self, + hidden_size=64, + intermediate_size=4 * 64, + initializer_range=0.02): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + np.random.seed(2021) + arr0 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward)) + arr1 = np.random.normal(0, 0.02, size=(d_model, dim_feedforward)) + weight_attr0 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr0)) + weight_attr1 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr1)) + bias_attr = None + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + + def forward(self, input): + if _global_parallel_strategy == "pp": + auto.shard_tensor( + self.linear0.weight, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + 
self.linear1.weight, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) + elif _global_parallel_strategy == "mp": + auto.shard_tensor( + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, 0] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) + elif _global_parallel_strategy == "dp": + auto.shard_tensor( + self.linear0.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + self.linear1.weight, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) + + out = self.norm(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = self.linear1(out) + return out + + +def mlp_forward(train_program, start_program): + with static.program_guard(train_program,start_program), \ + utils.unique_name.guard(): + batch_size = 4 + hidden_size = 64 + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32') + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32') + + if _global_parallel_strategy == "pp": + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": PP_MESH_0, + "dims_mapping": [-1, -1] + }) + auto.shard_tensor( + label, + dist_attr={ + "process_mesh": PP_MESH_1, + "dims_mapping": [-1, -1] + }) + elif _global_parallel_strategy == "dp": + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [0, -1] + }) + elif _global_parallel_strategy == "mp": + auto.shard_tensor( + input, + dist_attr={ + "process_mesh": _global_process_mesh, + "dims_mapping": [-1, -1] + }) + + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02) + predict = mlp(input) + error_cost = paddle.nn.functional.square_error_cost(predict, label) + loss = paddle.mean(error_cost) + return loss, 
train_program, start_program + + +def get_distributed_program(): + train_program = static.Program() + startup_program = static.Program() + dist_strategy = fleet.DistributedStrategy() + dist_strategy.semi_auto = True + fleet.init(is_collective=True, strategy=dist_strategy) + loss, train_program, startup_program = mlp_forward(train_program, + startup_program) + optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer) + _, _, dist_startup_prog, dist_main_prog = optimizer.minimize( + loss, startup_program) + + return dist_main_prog, dist_startup_prog, loss + + +class TestMLPAutoConvert(unittest.TestCase): + def setUp(self): + paddle.seed(2021) + random.seed(2021) + np.random.seed(2021) + + def tearDown(self): + os.remove("./model_state_rank{}.pdmodel".format( + str(paddle.distributed.get_rank()))) + os.remove("./dist_attr_rank{}.pdattr".format( + str(paddle.distributed.get_rank()))) + + def test_mlp_mp2pp(self): + global _global_parallel_strategy + _global_parallel_strategy = "mp" + global _global_process_mesh + _global_process_mesh = auto.ProcessMesh([0, 1]) + + input = np.random.random(size=(80, 64)).astype('float32') + label = np.random.random(size=(80, 1)).astype('float32') + + dist_main_prog, dist_start_prog, loss = get_distributed_program() + place = paddle.set_device("gpu") + exe = paddle.static.Executor(place) + exe.run(dist_start_prog) + + for step in range(20): + if step == 10: + save_distributed_checkpoint( + dist_main_prog, ".", dist_attr_path=".") + + res = exe.run(dist_main_prog, + feed={ + "input": input[step * 4:(step + 1) * 4, :], + "label": label[step * 4:(step + 1) * 4, :] + }, + fetch_list=[loss]) + last_res = res[0] + + _global_parallel_strategy = "pp" + _global_process_mesh = auto.ProcessMesh([0, 1]) + global PP_MESH_0 + PP_MESH_0 = auto.ProcessMesh(mesh=[0]) + global PP_MESH_1 + PP_MESH_1 = auto.ProcessMesh(mesh=[1]) + + dist_main_prog_load, dist_start_prog_load, loss_load = 
class TestMLPAutoConvert2(unittest.TestCase):
    """Checkpoint under pipeline parallelism, then resume under model parallelism.

    The test trains 20 steps with the "pp" strategy and saves a checkpoint at
    step 10. It then rebuilds the program with the "mp" strategy, restores the
    checkpoint through the merge-and-slice path, replays the remaining ten
    batches and compares the final loss with the one from the first run.
    """

    def setUp(self):
        # Seed every random source with the same value, in a fixed order.
        for seed_fn in (paddle.seed, random.seed, np.random.seed):
            seed_fn(2021)
        # Empty the module-level HAS_SENT / HAS_RECV / HAS_ALLGATHER
        # containers before the scenario starts.
        for record in (HAS_SENT, HAS_RECV, HAS_ALLGATHER):
            record.clear()

    def tearDown(self):
        # Delete the two files this rank wrote into the working directory.
        rank = str(paddle.distributed.get_rank())
        for template in ("./model_state_rank{}.pdmodel",
                         "./dist_attr_rank{}.pdattr"):
            os.remove(template.format(rank))

    def test_mlp_pp2mp(self):
        global _global_parallel_strategy
        global _global_process_mesh
        global PP_MESH_0
        global PP_MESH_1

        def batch_feed(data, target, step):
            # Four consecutive rows per step, matching the program's batch size.
            rows = slice(step * 4, (step + 1) * 4)
            return {"input": data[rows, :], "label": target[rows, :]}

        # Phase 1: train with the pipeline-parallel layout.
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1])

        samples = np.random.random(size=(80, 64)).astype('float32')
        targets = np.random.random(size=(80, 1)).astype('float32')

        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        rank = paddle.distributed.get_rank()
        for step in range(20):
            if step == 10:
                add_info = {"batch": step, "batch_size": 4}
                save_distributed_checkpoint(dist_main_prog, ".", ".", add_info)

            # Rank 0 runs without a fetch list; every other rank fetches
            # the loss.
            run_kwargs = {"feed": batch_feed(samples, targets, step)}
            if rank != 0:
                run_kwargs["fetch_list"] = [loss]
            res = exe.run(dist_main_prog, **run_kwargs)
        if rank == 1:
            last_res = res[0]

        # Phase 2: rebuild with the model-parallel layout and resume.
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1])

        dist_main_prog_load, dist_start_prog_load, loss_load = \
            get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog_load)

        ckpt_path = [
            "./model_state_rank0.pdmodel", "./model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./dist_attr_rank0.pdattr", "./dist_attr_rank1.pdattr"
        ]
        param_dict, pre_dist_attr, add_info = load_distributed_checkpoint(
            ckpt_path, dist_attr_path)

        # Skip the batches that were consumed before the checkpoint was taken.
        start_index = add_info["batch"] * add_info["batch_size"]
        samples = samples[start_index:, :]
        targets = targets[start_index:, :]

        cur_dist_attr = get_dist_attr(dist_main_prog_load)
        sliced_param_dict = merge_and_slice_parameter(
            param_dict, pre_dist_attr, cur_dist_attr)
        load_parameter_into_program(sliced_param_dict, dist_main_prog_load)

        for step in range(10):
            res = exe.run(dist_main_prog_load,
                          feed=batch_feed(samples, targets, step),
                          fetch_list=[loss_load])
        # Only rank 1 recorded a reference loss in phase 1.
        if paddle.distributed.get_rank() == 1:
            self.assertEqual(last_res, res[0])
self.assertRaises(TypeError): + save_distributed_checkpoint( + dist_main_prog, [""], [""], addition_info=[0]) + with self.assertRaises(ValueError): + save_distributed_checkpoint( + dist_main_prog, [""], [""], addition_info={"step": 0}) + with self.assertRaises(ValueError): + save_distributed_checkpoint( + dist_main_prog, [""], [""], addition_info={"batch": 0.0}) + with self.assertRaises(ValueError): + load_checkpoint_into_program(["./model_state_rank.pdmodel"], + ["./dist_attr_rank.pdattr"], + dist_main_prog) + with self.assertRaises(ValueError): + load_distributed_checkpoint(["./model_state_rank.pdmodel"], + ["./dist_attr_rank.pdattr"]) + with self.assertRaises(TypeError): + load_distributed_checkpoint({ + "0": "./model_state_rank.pdmodel" + }, {"1": "./dist_attr_rank.pdattr"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py b/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py index 6996fab09112f946a588083746f4309f049e6e89..35ee4f30da00ccaacb607789c6582ad492960516 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel_save_load.py @@ -29,12 +29,7 @@ import paddle.distributed.auto_parallel as auto from paddle.distributed import fleet from paddle.fluid.initializer import NumpyArrayInitializer -from paddle.distributed.auto_parallel.utils import make_data_unshard -from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_distributed_checkpoint -from paddle.distributed.auto_parallel.reshard import reshard -from paddle.distributed.auto_parallel.partitioner import Partitioner -from paddle.distributed.auto_parallel.dist_context import DistributedContext -from paddle.distributed.auto_parallel.process_group import get_all_process_groups +from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_checkpoint_into_program paddle.enable_static() 
_global_parallel_strategy = None @@ -204,7 +199,7 @@ class TestMLPSaveLoad(unittest.TestCase): if step == 10: path = "./output_dp{}".format(paddle.distributed.get_rank()) os.makedirs(path, exist_ok=True) - save_distributed_checkpoint(dist_main_prog, path) + save_distributed_checkpoint(dist_main_prog, path, path) res = exe.run(dist_main_prog, feed={ @@ -218,7 +213,11 @@ class TestMLPSaveLoad(unittest.TestCase): "./output_dp0/model_state_rank0.pdmodel", "./output_dp1/model_state_rank1.pdmodel" ] - load_distributed_checkpoint(ckpt_path, dist_main_prog) + dist_attr_path = [ + "./output_dp0/dist_attr_rank0.pdattr", + "./output_dp1/dist_attr_rank1.pdattr" + ] + load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog) for step in range(10, 20): res = exe.run(dist_main_prog, feed={ @@ -248,7 +247,7 @@ class TestMLPSaveLoad(unittest.TestCase): if step == 10: path = "./output_mp{}".format(paddle.distributed.get_rank()) os.makedirs(path, exist_ok=True) - save_distributed_checkpoint(dist_main_prog, path) + save_distributed_checkpoint(dist_main_prog, path, path) res = exe.run(dist_main_prog, feed={ @@ -262,7 +261,11 @@ class TestMLPSaveLoad(unittest.TestCase): "./output_mp0/model_state_rank0.pdmodel", "./output_mp1/model_state_rank1.pdmodel" ] - load_distributed_checkpoint(ckpt_path, dist_main_prog) + dist_attr_path = [ + "./output_mp0/dist_attr_rank0.pdattr", + "./output_mp1/dist_attr_rank1.pdattr" + ] + load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog) for step in range(10, 20): res = exe.run(dist_main_prog, feed={ @@ -296,7 +299,7 @@ class TestMLPSaveLoad(unittest.TestCase): if step == 10: path = "./output_pp{}".format(paddle.distributed.get_rank()) os.makedirs(path, exist_ok=True) - save_distributed_checkpoint(dist_main_prog, path) + save_distributed_checkpoint(dist_main_prog, path, path) if paddle.distributed.get_rank() in [0]: res = exe.run(dist_main_prog, @@ -319,7 +322,11 @@ class TestMLPSaveLoad(unittest.TestCase): 
class TestAutoParallelAutoConvert(TestMultipleGpus):
    """Launcher for the auto-convert scenarios in ``auto_parallel_autoconvert.py``.

    The script is handed to ``run_mnist_2gpu``, which is inherited from
    ``TestMultipleGpus`` (presumably a two-GPU launcher; it is defined outside
    this file, so confirm against that base class).
    """

    def test_auto_parallel_autoconvert(self):
        scenario_script = 'auto_parallel_autoconvert.py'
        self.run_mnist_2gpu(scenario_script)