提交 7a5004cc 编写于 作者: Y Yi Huaijie

Initialize the slices of an Initializer on different devices

上级 1b7cdc4c
......@@ -328,8 +328,13 @@ class _Executor:
def _params_init_data(self, obj, params):
if params is not None:
for _, param in params.items():
param.init_data()
for key, param in params.items():
if key not in obj.parameter_layout_dict:
logger.info("Layout dict does not contain the key %s.", key)
param.init_data()
else:
layout = obj.parameter_layout_dict[key]
param.init_data(layout)
obj.init_parameters_data()
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
......@@ -377,10 +382,11 @@ class _Executor:
if not do_convert:
return phase, True
if auto_parallel_mode and "train" in phase:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params)
if not enable_debug_runtime or enable_ge:
if auto_parallel_mode and "train" in phase:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
obj.load_parameter_slice(params)
# the following GE init process is not needed when use vm or ms backend
......
......@@ -41,6 +41,7 @@ class Initializer:
self._kwargs = kwargs
self.shape = None
self.dtype = None
self._seed = None
def _initialize(self, *kwargs):
raise NotImplementedError('Must be overridden!')
......@@ -48,6 +49,15 @@ class Initializer:
def __call__(self, arr):
return self._initialize(arr)
@property
def seed(self):
    """Get the random seed currently stored on this initializer."""
    return self._seed

@seed.setter
def seed(self, seed_):
    """Store `seed_` as the random seed of this initializer."""
    self._seed = seed_
@property
def shape(self):
return self._shape
......@@ -65,6 +75,7 @@ class Initializer:
self._dtype = dtype
def to_tensor(self):
"""Get the tensor format data of this Initializer."""
arr = None
try:
arr = np.ndarray(self.shape)
......@@ -72,7 +83,10 @@ class Initializer:
msg = "Error shape={}".format(self.shape)
logger.error(msg)
raise ValueError(msg)
if self._seed is not None:
np.random.seed(self.seed)
self.__call__(arr)
self._seed = None
return Tensor(arr, dtype=self.dtype)
def _register(*aliases):
......
......@@ -20,6 +20,7 @@ from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._utils import _set_clone_info, _CloneInfo
from ..parallel._tensor import _get_seed
__all__ = ['Parameter', 'ParameterTuple']
......@@ -55,6 +56,7 @@ class Parameter:
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
def __repr__(self):
......@@ -91,6 +93,11 @@ class Parameter:
raise ValueError("The type of the name should be `str` or `None`.")
self._name = name_
@property
def sliced(self):
    """Return the slice status flag (`_sliced`) of the parameter."""
    return self._sliced
@property
def is_init(self):
"""Get init status of the parameter."""
......@@ -196,11 +203,31 @@ class Parameter:
self.default_input = data
def init_data(self, layout=None):
    """
    Init data of the parameter.

    Nothing happens when `default_input` is not a `MetaTensor`. Otherwise the
    tensor is produced by `init_mode.to_tensor()`, `init_mode` is cleared and the
    parameter is marked as sliced.

    Args:
        layout (list[list[int]]): Parameter slice layout
            [dev_mat, tensor_map, slice_shape]. Default: None.

            - dev_mat (list[int]): Device matrix.
            - tensor_map (list[int]): Tensor map.
            - slice_shape (list[int]): Shape of slice.

    Raises:
        TypeError: If `layout` is given but is not a list.
        ValueError: If `layout` does not have exactly 3 elements.
    """
    if not isinstance(self.default_input, MetaTensor):
        return
    if layout is not None:
        if not isinstance(layout, list):
            raise TypeError("The layout should be list! layout is {}."
                            .format(layout))
        if len(layout) != 3:
            raise ValueError("The length of layout must be 3! layout is {}."
                             .format(layout))
        # Generate only this slice: use the slice shape, and a seed derived from
        # the device matrix and tensor map.
        self.init_mode.shape = layout[2]
        self.init_mode.seed = int(_get_seed(layout[0], layout[1]))

    self.default_input = self.init_mode.to_tensor()
    self.init_mode = None
    self._sliced = True
class ParameterTuple(tuple):
......
......@@ -249,6 +249,9 @@ class Cell:
if key not in self.parameter_layout_dict:
logger.info("layout dict does not contain the key %s", key)
continue
if self.parameters_dict()[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
self.parameters_dict()[key].set_parameter_data(new_tensor)
......@@ -258,6 +261,9 @@ class Cell:
if key not in self.parameter_layout_dict:
logger.info("layout dict does not contain the key %s", key)
continue
if params[key].sliced:
logger.info("Param %s is from initializer, already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
params[key].set_parameter_data(new_tensor)
......@@ -398,7 +404,12 @@ class Cell:
def init_parameters_data(self, recurse=True):
    """
    Init the data of the parameters returned by `get_parameters`.

    A parameter whose name appears in `parameter_layout_dict` is initialized with
    that layout; any other parameter is initialized without a layout.

    Args:
        recurse (bool): Forwarded to `get_parameters` as `expand`. Default: True.
    """
    for param in self.get_parameters(expand=recurse):
        # The layout lookup must come before any init_data() call: initializing
        # without the layout first would leave the layout branch with no effect.
        if param.name not in self.parameter_layout_dict:
            logger.info("Layout dict does not contain the key %s.", param.name)
            param.init_data()
        else:
            layout = self.parameter_layout_dict[param.name]
            param.init_data(layout)
def parameters_dict(self, recurse=True):
"""
......
......@@ -168,6 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
raise ValueError("The length of np_tensor does not match the length of strategy!")
return _chunk_tensor(np_tensor, strategy, len(strategy))
def _get_seed(dev_mat, tensor_map):
    """
    Get the random seed for current slice.

    The value returned is the tensor slice index computed for the current rank.

    Args:
        dev_mat (list): The device matrix of devices.
        tensor_map (list): The split strategy of tensor.

    Returns:
        Integer, the local random seed for this device.
    """
    current_rank = get_rank()
    strategy = _get_tensor_strategy(dev_mat, tensor_map)
    return _get_tensor_slice_index(dev_mat, strategy, tensor_map, current_rank)
def _load_tensor(tensor, dev_mat, tensor_map):
"""
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from mindspore import context
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
import mindspore as ms
import mindspore.common.api as me
from mindspore.common.initializer import initializer
from hccl_test.manage.api import Hccl
def test_initializer_weight_slice():
    """Check that an Initializer-backed weight is initialized per slice on each rank."""
    class Net(nn.Cell):
        """MatMul against a weight parameter, followed by ReLU."""
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).set_strategy(strategy1)
            self.relu = P.ReLU().set_strategy(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    def get_slice(rank):
        """Compile the net with the rank overridden to `rank` and return the data of `w1`."""
        hccl = Hccl()
        rank_save = hccl.rank_id
        hccl.rank_id = rank
        try:
            context.reset_auto_parallel_context()
            context.set_auto_parallel_context(device_num=8, global_rank=0)
            context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
            strategy1 = ((2, 1), (4, 1))
            strategy2 = ((2, 4),)
            context.set_context(mode=context.GRAPH_MODE)
            exe = me._executor

            x = Tensor(np.ones([32, 32]), dtype=ms.float32)
            weight = initializer("Uniform", [64, 32], ms.float32)
            net = Net(strategy1, strategy2, weight)
            net.set_auto_parallel()
            exe.compile(net, x, auto_parallel_mode=True, phase='train')
        finally:
            # Restore the rank even when compilation fails, so the override does
            # not carry over into later get_slice calls or other tests.
            hccl.rank_id = rank_save
        return net.parameters_dict()['w1'].data.asnumpy()

    slice0 = get_slice(0)
    slice1 = get_slice(1)
    slice4 = get_slice(4)
    slice_shape = slice0.shape

    slice0 = slice0.flatten()
    slice1 = slice1.flatten()
    slice4 = slice4.flatten()
    # The [64, 32] weight is expected to be cut into (16, 32) slices.
    expect_slice_shape = (16, 32)

    assert expect_slice_shape == slice_shape
    # Ranks 0 and 4 are expected to produce identical data, rank 1 different data.
    assert all(slice0 == slice4)
    assert any(slice0 != slice1)
if __name__ == '__main__':
    # Allow running this test directly as a script.
    test_initializer_weight_slice()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册