Commit 1d43a2ee, authored by caozhou

modified the content of merging sliced parameters

Parent 8487f303
......@@ -9,8 +9,8 @@
- [Integrating the Saved Checkpoint Files](#integrating-the-saved-checkpoint-files)
- [Overall Process](#overall-process)
- [Preparations](#preparations)
- [Importing the Checkpoint Files to the Network](#importing-the-checkpoint-files-to-the-network)
- [Obtaining a List of All Parameters on the Network](#obtaining-a-list-of-all-parameters-on-the-network)
- [Importing the Checkpoint Files in Rank ID Order](#importing-the-checkpoint-files-in-rank-id-order)
- [Obtaining the Slice Strategy of the Model](#obtaining-the-slice-strategy-of-model)
- [Integrate the Model Parallel Parameters](#integrate-the-model-parallel-parameters)
- [Saving the Data and Generating a New Checkpoint File](#saving-the-data-and-generating-a-new-checkpoint-file)
- [Loading the Integrated and Saved Checkpoint File](#loading-the-integrated-and-saved-checkpoint-file)
......@@ -71,7 +71,7 @@ For example, in the training stage 1, the training environment with 64 devices i
### Overall Process
Import the checkpoint files to be integrated to the network and obtain the list of all parameters through the API provided by MindSpore. See steps 1 and 2 in the following figure.
Import the checkpoint files to be integrated into the network in rank ID order, obtain the list of all parameters through the API provided by MindSpore, and then obtain the slice strategy of the model. See steps 1 and 2 in the following figure.
Then, update the parameter list and integrate the model parallel parameters. See step 3 in the following figure.
......@@ -81,120 +81,67 @@ Finally, save the updated parameter list to a file through the API provided by M
### Preparations
#### Importing the Checkpoint Files to the Network
#### Importing the Checkpoint Files in Rank ID Order
Define the network, call the `load_checkpoint` and `load_param_into_net` APIs, and import the checkpoint files to the network.
Define the network, call the `load_checkpoint` and `load_param_into_net` APIs to import the checkpoint files into the network in rank ID order, and then call the `parameters_and_names` API to obtain all parameters in the network.
```
param_dict = load_checkpoint("./CKP_1-4_32.ckpt") # checkpoint file name
net = Net()
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
load_param_into_net(net, param_dict)
param_dicts = []
for i in range(rank_size):
    file_name = os.path.join("./node"+str(i), "CKP_1-4_32.ckpt")  # checkpoint file name of current node
    param_dict = load_checkpoint(file_name)
    load_param_into_net(net, param_dict)
    param_dict = {}
    for _, param in net.parameters_and_names():
        param_dict[param.name] = param
    param_dicts.append(param_dict)
```
In the preceding information:
- `rank_size`: number of nodes used in the previous distributed training.
- `load_checkpoint`: loads the checkpoint model parameter file and returns a parameter dictionary.
- `load_param_into_net`: loads model parameter data to the network.
- `CKP_1-4_32.ckpt`: name of the saved checkpoint model parameter file.
> If a new checkpoint file is directly saved in the training environment based on the current training data and the parameter values already exist on the network, skip this step and you do not need to import the checkpoint files.
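The rank-ordered loading above depends on visiting the per-node directories in a fixed order. The following is a minimal sketch of that path enumeration only, assuming the `./node<i>/<checkpoint name>` layout used in the snippet; the helper name `checkpoint_paths` and its `root` argument are hypothetical and are not part of MindSpore.

```python
import os


def checkpoint_paths(rank_size, ckpt_name, root="."):
    """Return one checkpoint path per rank, ordered by rank ID (0 .. rank_size - 1).

    Assumes each node saved its file under <root>/node<rank>/<ckpt_name>.
    """
    if rank_size < 1:
        raise ValueError("rank_size must be a positive integer")
    return [os.path.join(root, "node" + str(rank), ckpt_name)
            for rank in range(rank_size)]


# Check up front that every slice file exists before loading any of them.
paths = checkpoint_paths(4, "CKP_1-4_32.ckpt")
for rank, path in enumerate(paths):
    print(rank, path, "found" if os.path.exists(path) else "missing")
```

Checking all paths first makes a missing slice visible before a partially loaded `param_dicts` list is built.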
#### Obtaining a List of All Parameters on the Network
Call the `parameters_and_names` API to obtain all parameter data on the network.
Call the `build_searched_strategy` API to obtain the slice strategy of model.
```
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
strategy = build_searched_strategy("./strategy_train.ckpt")
```
### Integrate the Model Parallel Parameters
In the preceding information:
The following uses a model parameter as an example to describe a specific integration process.
- `strategy_train.ckpt`: name of the file that stores the slice strategy of the model. Before training the network, users set it by calling the `set_auto_parallel_context` API with the `strategy_ckpt_save_file` parameter. The file saved on each node is the same.
The parameter name is model\_parallel\_weight and the data is Tensor \[\[1, 2, 3, 4], \[5, 6, 7, 8]].
### Integrate the Model Parallel Parameters
The dividing strategy is to perform dividing in a 4-device scenario based on \[2, 2]. That is, the data is first divided into two slices in the row dimension, then the two slices are respectively divided into two smaller slices in the column dimension, and finally four slices are obtained. Data distribution after dividing is as follows:
The following uses a model parameter as an example to describe a specific integration process.
| Device0 | Device1 | Device2 | Device3 |
|--------------|--------------|--------------|--------------|
| Value [1, 2] | Value [3, 4] | Value [5, 6] | Value [7, 8] |
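The \[2, 2] dividing strategy in the table can be reproduced with plain Python lists. This is an illustrative sketch, not MindSpore code: the helpers `split_2d` and `merge_2d` are hypothetical, and the row-major device order (Device0, Device1, ...) is an assumption taken from the table.

```python
def split_2d(matrix, row_parts, col_parts):
    """Split a 2-D list into row_parts x col_parts blocks, in row-major device order."""
    rows, cols = len(matrix), len(matrix[0])
    if rows % row_parts or cols % col_parts:
        raise ValueError("shape must be divisible by the dividing strategy")
    block_h, block_w = rows // row_parts, cols // col_parts
    blocks = []
    for r in range(row_parts):
        for c in range(col_parts):
            band = matrix[r * block_h:(r + 1) * block_h]
            blocks.append([row[c * block_w:(c + 1) * block_w] for row in band])
    return blocks


def merge_2d(blocks, row_parts, col_parts):
    """Inverse of split_2d: join blocks by column, then stack the joined rows."""
    merged = []
    for r in range(row_parts):
        row_blocks = blocks[r * col_parts:(r + 1) * col_parts]
        for i in range(len(row_blocks[0])):
            merged.append([v for block in row_blocks for v in block[i]])
    return merged


weight = [[1, 2, 3, 4], [5, 6, 7, 8]]
slices = split_2d(weight, 2, 2)
print(slices)                           # [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]
print(merge_2d(slices, 2, 2) == weight)  # True
```

Each block keeps its 2-D shape (here \[1, 2]), which is why the merge has to join by column before stacking by row.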
The parameter name is model\_parallel\_weight, and it is divided across four devices.
1. Obtain the data value on the current node for model parallel parameters.
1. Obtain the data value on all nodes for model parallel parameters.
```
param_data = param_dict["model_parallel_weight"]
param_data_moments = param_dict["moments.model_parallel_weight"]
sliced_parameters = []
for i in range(4):
    parameter = param_dicts[i].get("model_parallel_weight")
    sliced_parameters.append(parameter)
```
> To ensure that the parameter update speed remains unchanged, you need to integrate the parameters saved in the optimizer, for example, moments.model\_parallel\_weight.
2. Define, instantiate, and execute the `AllGather` Cell, and obtain data on all devices.
```
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
class AllGatherCell(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super(AllGatherCell, self).__init__(auto_prefix=False)
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
allgather_net = AllGatherCell()
param_data = allgather_net(param_data)
param_data_moments = allgather_net(param_data_moments)
```
The value of `param_data` is the integration of data on each device in dimension 0. The data value is \[\[1, 2], \[3, 4], \[5, 6], \[7, 8]], and the shape is \[4, 2]. The raw data value of `param_data` is \[\[1, 2, 3, 4], \[5, 6, 7, 8]], and the shape is \[2, 4]. The data needs to be redivided and integrated.
3. Divide the data obtained from `AllGather`.
```
slice_list = np.split(param_data.asnumpy(), 4, axis=0) # 4:group_size, number of nodes in cluster
slice_lis_moments = np.split(param_data_moments.asnumpy(), 4, axis=0) # 4: group_size, number of nodes in cluster
```
The result of `param_data` is as follows:
slice_list[0] --- [1, 2] Slice data on device0
slice_list[1] --- [3, 4] Slice data on device1
slice_list[2] --- [5, 6] Slice data on device2
slice_list[3] --- [7, 8] Slice data on device3
4. Reassemble the data according to the actual dividing strategy.
In the following code, slice 1 and slice 2, slice 3 and slice 4 are first spliced by column, and then the obtained data is spliced by row.
```
slice_line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1) # result [1,2,3,4]
slice_line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1) # result [5,6,7,8]
whole_data = np.concatenate((slice_line1, slice_line2), axis=0) # result [[1, 2, 3, 4], [5, 6, 7, 8]]
2. Call the `merge_sliced_parameter` API to merge the sliced parameters.
slice_moments_line1 = np.concatenate((slice_lis_moments[0], slice_lis_moments[1]), axis=1)
slice_moments_line2 = np.concatenate((slice_lis_moments[2], slice_lis_moments[3]), axis=1)
whole_moments_data = np.concatenate((slice_moments_line1, slice_moments_line2), axis=0)
```
5. Assign values to model parameters.
merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
```
param_data = Tensor(whole_data)
param_data_moments = Tensor(whole_moments_data)
```
> 1. If there are multiple model parallel parameters, repeat steps 1 to 5 to process them one by one.
> 2. If the data obtained in step 2 is the final data, skip the following steps. That is, the dividing strategy is to perform dividing only on shape0 and each device loads different slice data.
> If there are multiple model parallel parameters, repeat steps 1 and 2 to process them one by one.
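Repeating the two steps for several parameters amounts to a loop over parameter names. The sketch below shows only that control flow with toy data. `collect_slices`, `merge_all`, and `concat_rows` are hypothetical helpers; in the real workflow, `merge_fn` would wrap a call such as `merge_sliced_parameter(sliced_parameters, strategy)`, whereas the stand-in here simply concatenates lists.

```python
def collect_slices(param_dicts, name):
    """Gather one parameter's slice from every rank, keeping rank ID order."""
    slices = []
    for rank, param_dict in enumerate(param_dicts):
        if name not in param_dict:
            raise KeyError("rank %d has no parameter named %r" % (rank, name))
        slices.append(param_dict[name])
    return slices


def merge_all(param_dicts, names, merge_fn):
    """Apply merge_fn to the rank-ordered slices of each named parameter."""
    return {name: merge_fn(collect_slices(param_dicts, name)) for name in names}


def concat_rows(slices):
    """Stand-in merge (assumption): flat concatenation in rank order."""
    return [value for piece in slices for value in piece]


# Toy per-rank dictionaries standing in for the param_dicts list built earlier.
param_dicts = [
    {"model_parallel_weight": [1, 2], "moments.model_parallel_weight": [0.1, 0.2]},
    {"model_parallel_weight": [3, 4], "moments.model_parallel_weight": [0.3, 0.4]},
]
names = ["model_parallel_weight", "moments.model_parallel_weight"]
merged = merge_all(param_dicts, names, concat_rows)
print(merged["model_parallel_weight"])  # [1, 2, 3, 4]
```

Raising on a missing name, instead of silently appending `None` as `dict.get` would, surfaces an incomplete set of slices before the merge runs.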
### Saving the Data and Generating a New Checkpoint File
......@@ -324,106 +271,89 @@ User process:
```
python ./integrate_checkpoint.py "Path and name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration"
python ./integrate_checkpoint.py "Name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration" "Path and name of the strategy file" "Number of nodes"
```
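The four positional arguments of the updated command could also be read with `argparse`, which validates the node count and prints a usage message when an argument is missing. This is a sketch of an alternative, not what the script in this commit does (it reads `sys.argv` directly); the function name `parse_args` and the example output file name `./integrated.ckpt` are hypothetical.

```python
import argparse


def parse_args(argv=None):
    """Parse the four positional arguments of integrate_checkpoint.py."""
    parser = argparse.ArgumentParser(description="Integrate sliced checkpoint files.")
    parser.add_argument("old_ckpt_file",
                        help="name of the checkpoint file to be integrated")
    parser.add_argument("new_ckpt_file",
                        help="path and name of the checkpoint file generated after integration")
    parser.add_argument("strategy_file",
                        help="path and name of the strategy file")
    parser.add_argument("rank_size", type=int,
                        help="number of nodes")
    args = parser.parse_args(argv)
    if args.rank_size < 1:
        parser.error("rank_size must be a positive integer")
    return args


# Example invocation with explicit arguments instead of the real command line.
args = parse_args(["CKP_1-4_32.ckpt", "./integrated.ckpt", "./strategy_train.ckpt", "4"])
print(args.old_ckpt_file, args.rank_size)
```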
integrate\_checkpoint.py:
```
import numpy as np
import os
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication.management import init
from mindspore.train.serialization import save_checkpoint, load_checkpoint
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=True, device_id=devid)
init()
class Net(nn.Cell):
def __init__(self,weight_init):
super(Net, self).__init__()
self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
self.fc = P.MatMul(transpose_b=True)
def construct(self, x):
x = self.fc(x, self.weight1)
return x
class AllGatherNet(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super().__init__()
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
def integrate_ckpt_file(old_ckpt_file, new_ckpt_file):
weight = np.ones([2, 8]).astype(np.float32)
net = Net(weight)
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
# load CheckPoint into net
param_dict = load_checkpoint(old_ckpt_file)
load_param_into_net(net, param_dict)
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
# get layer wise model parallel parameter
layerwise_param = param_dict[paramname]
if isinstance(layerwise_param.data, Tensor):
param_data = layerwise_param.data
else:
param_data = Tensor(layerwise_param.data)
# merge the parallel parameters of the model
allgather_net = get_allgather_cell()
param_data = allgather_net(param_data)
layerwise_param.set_parameter_data(param_data, True)
# convert param_dict to list type data
param_list = []
for (key, value) in param_dict.items():
each_param = {}
each_param["name"] = key
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
each_param["data"] = param_data
param_list.append(each_param)
# call the API to generate a new CheckPoint file
save_checkpoint(param_list, new_ckpt_file)
return
if __name__ == "__main__":
try:
old_ckpt_file = sys.argv[1]
new_ckpt_file = sys.argv[2]
integrate(old_ckpt_file, new_ckpt_file)
except:
print("Fail to integrate checkpoint file)
sys.exit(-1)
import os
import sys
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.nn import Momentum, TrainOneStepCell
from mindspore.ops import operations as P
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, build_searched_strategy, merge_sliced_parameter

class Net(nn.Cell):
    def __init__(self, weight_init):
        super(Net, self).__init__()
        self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
        self.fc = P.MatMul(transpose_b=True)

    def construct(self, x):
        x = self.fc(x, self.weight)
        return x

def integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size):
    weight = np.ones([2, 8]).astype(np.float32)
    net = Net(weight)
    opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
    net = TrainOneStepCell(net, opt)

    # load CheckPoint into net in rank id order
    param_dicts = []
    for i in range(rank_size):
        file_name = os.path.join("./node"+str(i), old_ckpt_file)
        param_dict = load_checkpoint(file_name)
        load_param_into_net(net, param_dict)
        param_dict = {}
        for _, param in net.parameters_and_names():
            param_dict[param.name] = param
        param_dicts.append(param_dict)

    # read the slice strategy saved during training
    strategy = build_searched_strategy(strategy_file)
    param_dict = {}

    for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
        # get layer wise model parallel parameter
        sliced_parameters = []
        for i in range(rank_size):
            parameter = param_dicts[i].get(paramname)
            sliced_parameters.append(parameter)

        # merge the parallel parameters of the model
        merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
        param_dict[paramname] = merged_parameter

    # convert param_dict to list type data
    param_list = []
    for (key, value) in param_dict.items():
        each_param = {}
        each_param["name"] = key
        if isinstance(value.data, Tensor):
            param_data = value.data
        else:
            param_data = Tensor(value.data)
        each_param["data"] = param_data
        param_list.append(each_param)

    # call the API to generate a new CheckPoint file
    save_checkpoint(param_list, new_ckpt_file)

    return

if __name__ == "__main__":
    try:
        old_ckpt_file = sys.argv[1]
        new_ckpt_file = sys.argv[2]
        strategy_file = sys.argv[3]
        rank_size = int(sys.argv[4])
        integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size)
    except Exception as err:
        # report the cause instead of hiding it behind a generic message
        print("Fail to integrate checkpoint file:", err)
        sys.exit(-1)
```
In the preceding information:
- `mode=context.GRAPH_MODE`: sets the running mode to graph mode for distributed training. (The PyNative mode does not support parallel running.)
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on a computer where the device is located.
- `init`: completes the distributed training initialization.
The command output is as follows.
Before the script is executed, the parameter values in the checkpoint files are as follows:
......@@ -523,6 +453,7 @@ User process:
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......@@ -570,6 +501,12 @@ User process:
label = np.random.random((4, 4)).astype(np.float32)
train_mindspore_impl_fc(input, label, weight1)
```
In the preceding information:
- `mode=context.GRAPH_MODE`: sets the running mode to graph mode for distributed training. (The PyNative mode does not support parallel running.)
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on a computer where the device is located.
- `init`: completes the distributed training initialization.
Parameter values after loading:
......
......@@ -11,8 +11,8 @@
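- [Integrating the Saved Checkpoint Files](#对保存的checkpoint文件做合并处理)
- [Overall Process](#整体流程)
- [Preparations](#准备工作)
- [Importing the Checkpoint Files to the Network](#导入checkpoint文件到网络)
- [Obtaining a List of All Parameters on the Network](#获取网络中全量参数列表)
- [Importing the Checkpoint Files in Logical Order](#按逻辑顺序导入checkpoint文件)
- [Obtaining the Slice Strategy of the Model Parameters](#获取模型参数切分策略)
- [Integrating the Model Parallel Parameters](#对模型并行的参数做合并处理)
- [Saving the Data and Generating a New Checkpoint File](#保存数据生成新的checkpoint文件)
- [Loading the Integrated and Saved Checkpoint File](#加载合并保存的checkpoint文件)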
......@@ -79,7 +79,7 @@ MindSpore模型并行场景下,每个实例进程只保存有本节点对应
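### Overall Process
First, perform the preparations: import the checkpoint files to be integrated into the network and obtain the list of all parameters through the API provided by MindSpore. See Step 1 and Step 2 in the following figure.
First, perform the preparations: import the checkpoint files to be integrated into the network in logical order, obtain all model parameters and add them to a list, and then obtain the slice strategy of the model parameters. See Step 1 and Step 2 in the following figure.
Next, update the parameter list and integrate the parameters involved in model parallelism. See Step 3 in the following figure.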
......@@ -89,119 +89,64 @@ MindSpore模型并行场景下,每个实例进程只保存有本节点对应
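### Preparations
#### Importing the Checkpoint Files to the Network
#### Importing the Checkpoint Files in Logical Order
Define the network, and call the `load_checkpoint` and `load_param_into_net` APIs to import the checkpoint files into the network.
Define the network, call the `load_checkpoint` and `load_param_into_net` APIs to import the checkpoint files into the network in logical order, and then call the `parameters_and_names` API to obtain all parameter data in the network.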
```
param_dict = load_checkpoint(./CKP_1-4_32.ckpt) # checkpoint file name
net = Net()
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
load_param_into_net(net, param_dict)
param_dicts = []
for i in range(rank_size):
file_name = os.path.join("./node"+str(i), "CKP_1-4_32.ckpt") # checkpoint file name of current node
param_dict = load_checkpoint(file_name)
load_param_into_net(net, param_dict)
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
param_dicts.append(param_dict)
```
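In the preceding information:
- `rank_size`: number of nodes used in the previous distributed training.
- `load_checkpoint`: loads the checkpoint model parameter file and returns a parameter dictionary.
- `load_param_into_net`: loads model parameter data to the network.
- `CKP_1-4_32.ckpt`: name of the previously saved checkpoint model parameter file.
> If a new checkpoint file is saved directly in the training environment based on the data from the current training, the parameter values already exist in the network, so you can skip this step and do not need to import the checkpoint files.
#### Obtaining a List of All Parameters on the Network
#### Obtaining the Slice Strategy of the Model Parameters
Call the `parameters_and_names` API to obtain all parameter data in the network.
Call the `build_searched_strategy` API to obtain the slice strategy of each model parameter.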
```
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
strategy = build_searched_strategy("./strategy_train.ckpt")
```
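### Integrating the Model Parallel Parameters
In the preceding information:
The following uses a specific model parameter as an example to describe the integration process.
- `strategy_train.ckpt`: name of the saved file that stores the slice strategy of the model parameters. It is generated when the user calls the `set_auto_parallel_context` API with a custom `strategy_ckpt_save_file` parameter before training the network. The strategy file saved on each node is the same.
The parameter name is "model_parallel_weight" and the data is Tensor [[1, 2, 3, 4], [5, 6, 7, 8]].
### Integrating the Model Parallel Parameters
The slicing logic is a 4-device scenario sliced by [2, 2]: the data is first divided into two slices in the row dimension, each of the two slices is then divided into two smaller slices in the column dimension, and four slices are obtained in the end.
Data distribution after slicing is as follows:
The following uses a specific model parameter as an example to describe the integration process.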
| Device0 | Device1 | Device2 | Device3 |
| ------------- | ------------ | ------------- | ------------- |
| Value [1, 2] | Value [3, 4] | Value [5, 6] | Value [7, 8] |
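The parameter name is "model_parallel_weight" and the slicing logic is a 4-device scenario.
1. For parameters involved in model parallelism, obtain the data value on the current node.
1. For parameters involved in model parallelism, obtain the parameter data on all nodes.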
```
param_data = param_dict["model_parallel_weight"]
param_data_moments = param_dict["moments.model_parallel_weight"]
sliced_parameters = []
for i in range(4):
parameter = param_dicts[i].get("model_parallel_weight")
sliced_parameters.append(parameter)
```
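> To keep the parameter update speed unchanged, the parameters saved in the optimizer, such as "moments.model_parallel_weight", must be integrated in the same way.
2. Define an `AllGather` subgraph, instantiate and execute it, and obtain the data on all devices.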
```
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
class AllGatherCell(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super(AllGatherCell, self).__init__(auto_prefix=False)
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
allgather_net = AllGatherCell()
param_data = allgather_net(param_data)
param_data_moments = allgather_net(param_data_moments)
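```
The resulting `param_data` is the concatenation of the data on each device along dimension 0. Its value is [[1, 2], [3, 4], [5, 6], [7, 8]] and its shape is [4, 2].
The original value of `param_data` is [[1, 2, 3, 4], [5, 6, 7, 8]] with shape [2, 4], so the data must be re-sliced and reassembled.
3. Slice the data obtained through `AllGather`.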
```
slice_list = np.split(param_data.asnumpy(), 4, axis=0) # 4:group_size, number of nodes in cluster
slice_lis_moments = np.split(param_data_moments.asnumpy(), 4, axis=0) # 4: group_size, number of nodes in cluster
```
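The result for `param_data` is as follows:
slice_list[0] --- [1, 2] Slice data on device0
slice_list[1] --- [3, 4] Slice data on device1
slice_list[2] --- [5, 6] Slice data on device2
slice_list[3] --- [7, 8] Slice data on device3
4. Reassemble the data according to the actual slicing.
In the following code, slice 1 and slice 2, and slice 3 and slice 4, are first concatenated by column, and then the two results are concatenated by row.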
```
slice_line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1) # result [1,2,3,4]
slice_line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1) # result [5,6,7,8]
whole_data = np.concatenate((slice_line1, slice_line2), axis=0) # result [[1, 2, 3, 4], [5, 6, 7, 8]]
slice_moments_line1 = np.concatenate((slice_lis_moments[0], slice_lis_moments[1]), axis=1)
slice_moments_line2 = np.concatenate((slice_lis_moments[2], slice_lis_moments[3]), axis=1)
whole_moments_data = np.concatenate((slice_moments_line1, slice_moments_line2), axis=0)
```
5. Assign values to the model parameters.
2. Call the `merge_sliced_parameter` API to merge the parameters.
```
param_data = Tensor(whole_data)
param_data_moments = Tensor(whole_moments_data)
merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
```
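> 1. If there are multiple model parallel parameters, repeat steps 1 to 5 to process them one by one.
> 2. If the data obtained by executing the `allgather` subgraph in step 2 is already the final data, the subsequent steps can be skipped.
> That is, the slicing logic slices only on shape0, and each device loads different slice data.
> If there are multiple model parallel parameters, repeat steps 1 and 2 to process them one by one.
### Saving the Data and Generating a New Checkpoint File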
......@@ -327,7 +272,7 @@ load_param_into_net(opt, param_dict)
Command for running the script:
```
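python ./integrate_checkpoint.py "Path and name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration"
python ./integrate_checkpoint.py "Name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration" "Path and name of the strategy file" "Number of nodes"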
```
integrate_checkpoint.py:
......@@ -336,15 +281,9 @@ load_param_into_net(opt, param_dict)
import numpy as np
import os
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication.management import init
from mindspore.train.serialization import save_checkpoint, load_checkpoint
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=True, device_id=devid)
init()
from mindspore.train.serialization import save_checkpoint, load_checkpoint, build_searched_strategy, merge_sliced_parameter
class Net(nn.Cell):
def __init__(self,weight_init):
......@@ -356,43 +295,36 @@ load_param_into_net(opt, param_dict)
x = self.fc(x, self.weight)
return x
class AllGatherNet(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super().__init__()
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
def integrate_ckpt_file(old_ckpt_file, new_ckpt_file):
def integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size):
weight = np.ones([2, 8]).astype(np.float32)
net = Net(weight)
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
# load CheckPoint into net
param_dict = load_checkpoint(old_ckpt_file)
load_param_into_net(net, param_dict)
# load CheckPoint into net in rank id order
param_dicts = []
for i in range(rank_size):
file_name = os.path.join("./node"+str(i), old_ckpt_file)
param_dict = load_checkpoint(file_name)
load_param_into_net(net, param_dict)
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
param_dicts.append(param_dict)
strategy = build_searched_strategy(strategy_file)
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
# get layer wise model parallel parameter
layerwise_param = param_dict[paramname]
if isinstance(layerwise_param.data, Tensor):
param_data = layerwise_param.data
else:
param_data = Tensor(layerwise_param.data)
sliced_parameters = []
for i in range(rank_size):
parameter = param_dicts[i].get(paramname)
sliced_parameters.append(parameter)
# merge the parallel parameters of the model
allgather_net = get_allgather_cell()
param_data = allgather_net(param_data)
layerwise_param.set_parameter_data(param_data, True)
merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
param_dict[paramname] = merged_parameter
# convert param_dict to list type data
param_list = []
......@@ -415,18 +347,14 @@ load_param_into_net(opt, param_dict)
try:
old_ckpt_file = sys.argv[1]
new_ckpt_file = sys.argv[2]
integrate(old_ckpt_file, new_ckpt_file)
strategy_file = sys.argv[3]
rank_size = int(sys.argv[4])
integrate_ckpt_file(old_ckpt_file, new_ckpt_file, strategy_file, rank_size)
except:
print("Fail to integrate checkpoint file")
sys.exit(-1)
```
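In the preceding information:
- `mode=context.GRAPH_MODE`: sets the running mode to graph mode, which is required for distributed training. (The PyNative mode does not support parallel running.)
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on the machine where it is located.
- `init`: completes the distributed training initialization.
Execution result:
Before the script is executed, the parameter values in the checkpoint files are as follows: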
......@@ -526,6 +454,7 @@ load_param_into_net(opt, param_dict)
import os
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......@@ -573,6 +502,12 @@ load_param_into_net(opt, param_dict)
label = np.random.random((4, 4)).astype(np.float32)
train_mindspore_impl_fc(input, label, weight1)
```
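In the preceding information:
- `mode=context.GRAPH_MODE`: sets the running mode to graph mode, which is required for distributed training. (The PyNative mode does not support parallel running.)
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on the machine where it is located.
- `init`: completes the distributed training initialization.
Parameter values after loading: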
......