Commit 30fd8808 authored by mindspore-ci-bot, committed by Gitee

!4413 merge sliced parameter

Merge pull request !4413 from caozhou/merge_parameter
@@ -85,8 +85,6 @@ endif ()
if (ENABLE_DUMP_PROTO)
include_directories(${CMAKE_BINARY_DIR})
file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "utils/node_strategy.proto")
ms_protobuf_generate(PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
file(GLOB_RECURSE PROTO_PY RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"utils/anf_ir.proto"
@@ -94,6 +92,7 @@ if (ENABLE_DUMP_PROTO)
"utils/lineage.proto"
"utils/checkpoint.proto"
"utils/print.proto"
"utils/node_strategy.proto"
)
ms_protobuf_generate_py(PY_SRCS PY_HDRS PY_PYS ${PROTO_PY})
@@ -23,6 +23,7 @@ import mindspore.nn as nn
from mindspore import log as logger
from mindspore.train.checkpoint_pb2 import Checkpoint
from mindspore.train.print_pb2 import Print
from mindspore.train.node_strategy_pb2 import ParallelStrategyMap
from mindspore.common.tensor import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
@@ -569,3 +570,214 @@ def parse_print(print_file_name):
        raise RuntimeError(e.__str__())
    return tensor_list


def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
    """
    Merge data slices into one tensor holding the whole data when strategy is not None.

    Args:
        sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
        parameter_name (str): Name of the parameter.
        strategy (dict): Parameter slice strategy.
        is_even (bool): Slice manner; True means the parameter is sliced evenly,
            False means it is sliced unevenly.

    Returns:
        Tensor, the merged Tensor which has the whole data.

    Raises:
        ValueError: Failed to merge.
    """
    layout = strategy.get(parameter_name)
    try:
        dev_mat = list(layout.dev_matrix[0].dim)
        tensor_map = list(layout.tensor_map[0].dim)
        param_split_shape = list(layout.param_split_shape[0].dim)
        field_size = int(layout.field)
    except BaseException as e:
        raise ValueError(f"{e.__str__()}. Please make sure that strategy matches the node_strategy.proto.")

    device_count = 1
    for dim in dev_mat:
        device_count *= dim

    if len(sliced_data) != device_count:
        raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                         f"The sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")
    merged_tensor = None
    if not param_split_shape:
        if not is_even:
            raise ValueError("The shape of every parameter in sliced_parameters should be the same "
                             "when the slice manner is even.")

        all_gather_tensor = Tensor(np.concatenate(sliced_data))

        if field_size > 0:
            from mindspore.parallel._tensor import _reshape_param_data_with_weight
            merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, [field_size])
        else:
            from mindspore.parallel._tensor import _reshape_param_data
            merged_tensor = _reshape_param_data(all_gather_tensor, dev_mat, tensor_map)
    else:
        from mindspore.parallel._tensor import _get_tensor_strategy, _get_tensor_slice_index

        tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)

        slice_count = 1
        for dim in tensor_strategy:
            slice_count *= dim

        if len(param_split_shape) != slice_count:
            raise ValueError(f"The param_split_shape length in strategy should be {slice_count}, "
                             f"but got {len(param_split_shape)}.")

        # Reorder the slices from rank_id order into slice-index order.
        tensor_slices_new = list(range(slice_count))
        tensor_slices = sliced_data
        for i in range(device_count):
            slice_index = int(_get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, i))
            if tensor_slices[i].shape[0] != param_split_shape[slice_index]:
                raise ValueError(f"The slice {slice_index} should be {param_split_shape[slice_index]} on the 0 axis, "
                                 f"but got {tensor_slices[i].shape[0]}.")
            tensor_slices_new[slice_index] = np.array(tensor_slices[i])

        # Concatenate the slices dimension by dimension, from the last axis to the first.
        dim_len = len(tensor_strategy)
        for i in range(dim_len):
            ele_count = int(len(tensor_slices_new) / tensor_strategy[dim_len - 1 - i])
            tensor_slices_new_inner = []
            for j in range(ele_count):
                new_tensor = tensor_slices_new[j * tensor_strategy[dim_len - 1 - i]]
                for l in range(j * tensor_strategy[dim_len - 1 - i] + 1,
                               (j + 1) * tensor_strategy[dim_len - 1 - i]):
                    new_tensor = np.concatenate((new_tensor, tensor_slices_new[l]), axis=dim_len - 1 - i)
                tensor_slices_new_inner.append(np.array(new_tensor))
            tensor_slices_new = tensor_slices_new_inner
        merged_tensor = Tensor(tensor_slices_new[0])

    return merged_tensor
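

# Editorial note: a minimal, self-contained sketch of the axis-by-axis
# reassembly performed in the loop above, not part of the original patch.
# The 4x4 matrix and the [2, 2] tensor_strategy are hypothetical values,
# and the slices are assumed to be in row-major slice-index order, which
# the concatenation loop above implies.
def _demo_axis_by_axis_merge():
    """Rebuild a 4x4 matrix from a 2x2 grid of slices, last axis first."""
    full = np.arange(16).reshape(4, 4)
    slices = [full[:2, :2], full[:2, 2:], full[2:, :2], full[2:, 2:]]
    tensor_strategy = [2, 2]
    dim_len = len(tensor_strategy)
    for i in range(dim_len):
        axis = dim_len - 1 - i  # concatenate along the last axis first
        group = tensor_strategy[axis]  # number of slices merged per group on this axis
        slices = [np.concatenate(slices[j * group:(j + 1) * group], axis=axis)
                  for j in range(len(slices) // group)]
    # The single remaining block is the whole tensor.
    assert np.array_equal(slices[0], full)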


def build_searched_strategy(strategy_filename):
    """
    Build the strategy of every parameter in the network.

    Args:
        strategy_filename (str): Name of the strategy file.

    Returns:
        Dictionary, whose key is parameter name and value is slice strategy of this parameter.

    Raises:
        ValueError: Strategy file is incorrect.
        TypeError: strategy_filename is not str.

    Examples:
        >>> strategy_filename = "./strategy_train.ckpt"
        >>> strategy = build_searched_strategy(strategy_filename)
    """
    if not isinstance(strategy_filename, str):
        raise TypeError(f"The strategy_filename should be str, but got {type(strategy_filename)}.")

    if not os.path.isfile(strategy_filename):
        raise ValueError(f"No such strategy file: {strategy_filename}.")

    if os.path.getsize(strategy_filename) == 0:
        raise ValueError("The strategy file should not be empty.")

    parallel_strategy_map = ParallelStrategyMap()

    with open(strategy_filename, 'rb') as f:
        pb_content = f.read()
    parallel_strategy_map.ParseFromString(pb_content)

    layout_items = parallel_strategy_map.parallel_layout_item
    if not layout_items:
        raise ValueError("The strategy file has no sliced parameter.")

    strategy = {}
    for layout_item in layout_items:
        parameter_name = layout_item.param_name
        layout = layout_item.parallel_layouts
        strategy[parameter_name] = layout

    return strategy
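

# Editorial note: a hedged usage sketch, not part of the original patch. The
# strategy file name is hypothetical and must point at a file produced
# beforehand by an auto-parallel training run.
def _demo_inspect_strategy(strategy_filename="./strategy_train.ckpt"):
    strategy = build_searched_strategy(strategy_filename)
    for name, layouts in strategy.items():
        # dev_matrix and tensor_map are the ParallelLayouts fields that
        # _merge_param_with_strategy reads above.
        print(name, list(layouts.dev_matrix[0].dim), list(layouts.tensor_map[0].dim))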


def merge_sliced_parameter(sliced_parameters, strategy=None):
    """
    Merge parameter slices into one whole parameter.

    Args:
        sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
        strategy (dict): Parameter slice strategy. Default: None.
            If strategy is None, just merge the parameter slices in 0 axis order.

            - key (str): Parameter name.
            - value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.

    Returns:
        Parameter, the merged parameter which has the whole data.

    Raises:
        ValueError: Failed to merge.
        TypeError: The sliced_parameters is incorrect or strategy is not dict.
        KeyError: The parameter name is not in keys of strategy.

    Examples:
        >>> strategy = build_searched_strategy("./strategy_train.ckpt")
        >>> sliced_parameters = [
        ...     Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"),
        ...     Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"),
        ...     Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"),
        ...     Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
        >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
    """
    if not isinstance(sliced_parameters, list):
        raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")

    if not sliced_parameters:
        raise ValueError("The sliced_parameters should not be empty.")

    if strategy and not isinstance(strategy, dict):
        raise TypeError(f"The strategy should be dict, but got {type(strategy)}.")

    try:
        parameter_name = sliced_parameters[0].name
        parameter_shape = sliced_parameters[0].data.shape
        parameter_shape_length = len(parameter_shape)
    except BaseException as e:
        raise TypeError(f"{e.__str__()}. The element in sliced_parameters should be Parameter.")

    is_even = True
    for index, parameter in enumerate(sliced_parameters):
        if not isinstance(parameter, Parameter):
            raise TypeError(f"The element in sliced_parameters should be Parameter, "
                            f"but got {type(parameter)} at index {index}.")

        if parameter.name != parameter_name \
                or len(parameter.data.shape) != parameter_shape_length \
                or parameter.data.shape[1:] != parameter_shape[1:]:
            raise ValueError("Please make sure that the elements in sliced_parameters have the same name, "
                             "dimension length and shape except the 0 axis.")

        if parameter.data.shape != parameter_shape:
            is_even = False

    layerwise_parallel = sliced_parameters[0].layerwise_parallel
    requires_grad = sliced_parameters[0].requires_grad
    sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
    merged_parameter = None

    if not strategy:
        merged_tensor = Tensor(np.concatenate(sliced_data))
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)
    else:
        if parameter_name not in strategy.keys():
            raise KeyError(f"The parameter name should be one key of strategy. "
                           f"The parameter name is {parameter_name}.")
        merged_tensor = _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even)
        merged_parameter = Parameter(merged_tensor, parameter_name, requires_grad, layerwise_parallel)

    return merged_parameter
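

# Editorial note: a hedged end-to-end sketch of the strategy=None path, not
# part of the original patch: slices are simply concatenated along the 0 axis.
# The parameter name and values are made up for illustration; Parameter,
# Tensor, and np are the module-level imports added by this patch.
def _demo_merge_without_strategy():
    slices = [
        Parameter(Tensor(np.array([[1.0, 2.0]])), "fc.weight"),
        Parameter(Tensor(np.array([[3.0, 4.0]])), "fc.weight"),
    ]
    merged = merge_sliced_parameter(slices)  # no strategy: concat on the 0 axis
    # merged.data.asnumpy() -> [[1., 2.], [3., 4.]]
    return merged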