diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt
index 3cf638f3d4695db731eef55a179b8ab69a6f30c2..355cf6273a772589134e4c1f37bb884497575894 100644
--- a/mindspore/ccsrc/CMakeLists.txt
+++ b/mindspore/ccsrc/CMakeLists.txt
@@ -85,8 +85,6 @@ endif ()
 if (ENABLE_DUMP_PROTO)
     include_directories(${CMAKE_BINARY_DIR})
 
-    file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "utils/node_strategy.proto")
-    ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
 
     file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
             "utils/anf_ir.proto"
@@ -94,6 +92,7 @@ if (ENABLE_DUMP_PROTO)
             "utils/lineage.proto"
             "utils/checkpoint.proto"
             "utils/print.proto"
+            "utils/node_strategy.proto"
             )
     ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
 
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 19bb90fab4a015e12d73ca37154b1c513048b86a..6b93878d4b59c23fa9ab3020325c7db701434bba 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -23,6 +23,7 @@ import mindspore.nn as nn
 from mindspore import log as logger
 from mindspore.train.checkpoint_pb2 import Checkpoint
 from mindspore.train.print_pb2 import Print
+from mindspore.train.node_strategy_pb2 import ParallelStrategyMap
 from mindspore.common.tensor import Tensor
 from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
@@ -569,3 +570,214 @@ def parse_print(print_file_name):
         raise RuntimeError(e.__str__())
 
     return tensor_list
+
+
+def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
+    """
+    Merge data slices into one tensor with the whole data when strategy is not None.
+
+    Args:
+        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
+        parameter_name (str): Name of the parameter.
+        strategy (dict): Parameter slice strategy.
+        is_even (bool): Slice manner: True means the parameter is sliced evenly, False means unevenly.
+
+    Returns:
+        Tensor, the merged Tensor which has the whole data.
+
+    Raises:
+        ValueError: Failed to merge.
+    """
+    layout = strategy.get(parameter_name)
+    try:
+        dev_mat = list(layout.dev_matrix[0].dim)
+        tensor_map = list(layout.tensor_map[0].dim)
+        param_split_shape = list(layout.param_split_shape[0].dim)
+        field_size = int(layout.field)
+    except BaseException as e:
+        raise ValueError(f"{e.__str__()}. Please make sure that strategy matches node_strategy.proto.")
+
+    device_count = 1
+    for dim in dev_mat:
+        device_count *= dim
+
+    if len(sliced_data) != device_count:
+        raise ValueError(f"The sliced_parameters length should be equal to device_count. "
" + f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.") + + merged_tensor = None + if not param_split_shape: + if not is_even: + raise ValueError("The shape of every parameter in sliced_parameters should be the same " + "when slice manner is even.") + + all_gather_tensor = Tensor(np.concatenate(sliced_data)) + + if field_size > 0: + from mindspore.parallel._tensor import _reshape_param_data_with_weight + merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, [field_size]) + + else: + from mindspore.parallel._tensor import _reshape_param_data + merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map) + + else: + from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index + tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map) + + slice_count = 1 + for dim in tensor_strategy: + slice_count *= dim + + if len(param_split_shape) != slice_count: + raise ValueError(f"The param_split_shape length in strategy should be {slice_count}, " + f"but got {len(param_split_shape)}.") + + tensor_slices_new = list(range(slice_count)) + tensor_slices = sliced_data + for i in range(device_count): + slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i)) + if tensor_slices[i].shape[0] != param_split_shape[slice_index]: + raise ValueError(f"The slice {slice_index} is {param_split_shape[slice_index]} in 0 axis, " + f"but got {tensor_slices[i].shape[0]}.") + tensor_slices_new[slice_index] = np.array(tensor_slices[i]) + + dim_len = len(tensor_strategy) + for i in range(dim_len): + ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i]) + tensor_slices_new_inner = [] + for j in range(ele_count): + new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]] + for l in range(j * tensor_strategy[dim_len - 1 - i] + 1, + (j + 1) * tensor_strategy[dim_len - 1 - i]): + new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i) + tensor_slices_new_inner.insert(len(tensor_slices_new_inner), np.array(new_tensor)) + tensor_slices_new = tensor_slices_new_inner + merged_tensor = Tensor(tensor_slices_new[0]) + + return merged_tensor + + +def build_searched_strategy(strategy_filename): + """ + build strategy of every parameter in network. + + Args: + strategy_filename (str): name of strategy file. + + Returns: + Dictionary, whose key is parameter name and value is slice strategy of this parameter. + + Raises: + ValueError: strategy file is incorrect. + TypeError: strategy_filename is not str. 
+
+    Examples:
+        >>> strategy_filename = "./strategy_train.ckpt"
+        >>> strategy = build_searched_strategy(strategy_filename)
+    """
+    if not isinstance(strategy_filename, str):
+        raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")
+
+    if not os.path.isfile(strategy_filename):
+        raise ValueError(f"No such strategy file: {strategy_filename}.")
+
+    if os.path.getsize(strategy_filename) == 0:
+        raise ValueError("The strategy file should not be empty.")
+
+    parallel_strategy_map = ParallelStrategyMap()
+
+    with open(strategy_filename, 'rb') as f:
+        pb_content = f.read()
+    parallel_strategy_map.ParseFromString(pb_content)
+
+    layout_items = parallel_strategy_map.parallel_layout_item
+    if not layout_items:
+        raise ValueError("The strategy file has no sliced parameter.")
+
+    strategy = {}
+    for layout_item in layout_items:
+        parameter_name = layout_item.param_name
+        layout = layout_item.parallel_layouts
+        strategy[parameter_name] = layout
+
+    return strategy
+
+
+def merge_sliced_parameter(sliced_parameters, strategy=None):
+    """
+    Merge parameter slices into one whole parameter.
+
+    Args:
+        sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
+        strategy (dict): Parameter slice strategy. Default: None.
+            If strategy is None, just merge parameter slices in 0 axis order.
+
+            - key (str): Parameter name.
+            - value (ParallelLayouts): Slice strategy of this parameter.
+
+    Returns:
+        Parameter, the merged parameter which has the whole data.
+
+    Raises:
+        ValueError: Failed to merge.
+        TypeError: The sliced_parameters is incorrect or strategy is not dict.
+        KeyError: The parameter name is not in keys of strategy.
+
+    Examples:
+        >>> strategy = build_searched_strategy("./strategy_train.ckpt")
+        >>> sliced_parameters = [\
+            Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
+            Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
+            Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"), \
+            Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
+        >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
+    """
+    if not isinstance(sliced_parameters, list):
+        raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")
+
+    if not sliced_parameters:
+        raise ValueError("The sliced_parameters should not be empty.")
+
+    if strategy and not isinstance(strategy, dict):
+        raise TypeError(f"The strategy should be dict, but got {type(strategy)}.")
+
+    try:
+        parameter_name = sliced_parameters[0].name
+        parameter_shape = sliced_parameters[0].data.shape
+        parameter_shape_length = len(parameter_shape)
+    except BaseException as e:
+        raise TypeError(f"{e.__str__()}. "
+                        f"The element in sliced_parameters should be Parameter.")
+
+    is_even = True
+    for index, parameter in enumerate(sliced_parameters):
+        if not isinstance(parameter, Parameter):
+            raise TypeError(f"The element in sliced_parameters should be Parameter, "
+                            f"but got {type(parameter)} at index {index}.")
+
+        if parameter.name != parameter_name \
+                or len(parameter.data.shape) != parameter_shape_length \
+                or parameter.data.shape[1:] != parameter_shape[1:]:
+            raise ValueError("Please make sure that the elements in sliced_parameters have the same name, "
+                             "dimension length and shape except axis 0.")
+
+        if parameter.data.shape != parameter_shape:
+            is_even = False
+
+    layerwise_parallel = sliced_parameters[0].layerwise_parallel
+    requires_grad = sliced_parameters[0].requires_grad
+    sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
+    merged_parameter = None
+
+    if not strategy:
+        merged_tensor = Tensor(np.concatenate(sliced_data))
+        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+    else:
+        if parameter_name not in strategy.keys():
+            raise KeyError(f"The parameter name should be one key of strategy. "
+                           f"The parameter name is {parameter_name}.")
+        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
+        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
+
+    return merged_parameter
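
Usage sketch (not part of the patch): a minimal end-to-end example of the two new public APIs, reusing the sample values from the merge_sliced_parameter docstring. The strategy file path "./strategy_train.ckpt" and the parameter name "network.embedding_table" are placeholders for whatever an auto-parallel training run produced.

    import numpy as np
    from mindspore import Tensor
    from mindspore.common.parameter import Parameter
    from mindspore.train.serialization import build_searched_strategy, merge_sliced_parameter

    # Parse the per-parameter slice layouts recorded during auto-parallel training.
    strategy = build_searched_strategy("./strategy_train.ckpt")

    # One slice per rank, ordered by rank_id; all slices must share the parameter name.
    sliced_parameters = [
        Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"),
        Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"),
        Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"),
        Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table"),
    ]

    # Rearrange and concatenate the slices according to the recorded layout.
    merged = merge_sliced_parameter(sliced_parameters, strategy)

    # Without a strategy, the slices are simply concatenated along axis 0.
    merged_axis0 = merge_sliced_parameter(sliced_parameters)

Passing strategy=None is only appropriate when the checkpoint was sliced evenly along axis 0; for any other layout, the strategy file is required so the slices can be placed by device matrix and tensor map.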