- [Integrating the Saved Checkpoint Files](#integrating-the-saved-checkpoint-files)
- [Overall Process](#overall-process)
- [Preparations](#preparations)
- [Importing the Checkpoint Files in Rank ID Order](#importing-the-checkpoint-files-in-rank-id-order)
- [Obtaining the Slice Strategy of Model](#obtaining-the-slice-strategy-of-model)
- [Integrate the Model Parallel Parameters](#integrate-the-model-parallel-parameters)
- [Saving the Data and Generating a New Checkpoint File](#saving-the-data-and-generating-a-new-checkpoint-file)
- [Loading the Integrated and Saved Checkpoint File](#loading-the-integrated-and-saved-checkpoint-file)
...
...
### Overall Process
Import the checkpoint files to be integrated into the network in rank id order, obtain the list of all parameters through the API provided by MindSpore, and then obtain the slice strategy of the model. See steps 1 and 2 in the following figure.
Then, update the parameter list and integrate the model parallel parameters. See step 3 in the following figure.
...
...
### Preparations
#### Importing the Checkpoint Files in Rank ID Order
Define the network, call the `load_checkpoint` and `load_param_into_net` APIs to import the checkpoint files into the network in rank id order, and then call the `parameters_and_names` API to obtain all parameters in this network.
```
param_dicts = []
for i in range(rank_size):
    # checkpoint file name of the current node
    file_name = os.path.join("./node"+str(i), "CKP_1-4_32.ckpt")
    param_dict = load_checkpoint(file_name)
    load_param_into_net(net, param_dict)
    param_dict = {}
    for _, param in net.parameters_and_names():
        param_dict[param.name] = param
    param_dicts.append(param_dict)
```
In the preceding information:
- `rank_size`: number of nodes in the previous distributed training.
- `load_checkpoint`: loads the checkpoint model parameter file and returns a parameter dictionary.
- `load_param_into_net`: loads model parameter data to the network.
- `CKP_1-4_32.ckpt`: name of the saved checkpoint model parameter file.
> If a new checkpoint file is saved directly in the training environment based on the current training data, the parameter values already exist on the network. In this case, skip this step; you do not need to import the checkpoint files.
#### Obtaining the Slice Strategy of Model
Call the `build_searched_strategy` API to obtain the slice strategy of the model.
- `strategy_train.ckpt`: name of the model slice strategy file, set by the user by calling the `set_auto_parallel_context` API and customizing the `strategy_ckpt_save_file` parameter before training the network; the file saved on each node is the same.
### Integrate the Model Parallel Parameters
The following uses a model parameter as an example to describe a specific integration process.
The parameter name is model\_parallel\_weight and the data is Tensor \[\[1, 2, 3, 4], \[5, 6, 7, 8]].
The dividing strategy is to perform dividing in a 4-device scenario based on \[2, 2]. That is, the data is first divided into two slices in the row dimension, and each of the two slices is then divided into two smaller slices in the column dimension, yielding four slices in total. Data distribution after dividing is as follows:
> To ensure that the parameter update speed remains unchanged, you need to integrate the parameters saved in the optimizer, for example, moments.model\_parallel\_weight.
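The \[2, 2] dividing described above can be sketched with plain NumPy (an illustrative stand-in, not MindSpore's internal slicing logic):

```python
import numpy as np

# The example parameter model_parallel_weight (shape [2, 4]).
weight = np.array([[1, 2, 3, 4],
                   [5, 6, 7, 8]])

# First divide into two slices in the row dimension ...
row_slices = np.split(weight, 2, axis=0)
# ... then divide each row slice into two slices in the column dimension,
# giving one [1, 2] slice per device.
device_slices = [col for row in row_slices for col in np.split(row, 2, axis=1)]

for rank, data in enumerate(device_slices):
    print("device{}: {}".format(rank, data.tolist()))
```

Each of the four devices ends up holding one \[1, 2] slice, matching the distribution used in the rest of this example.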
2. Define, instantiate, and execute the `AllGather` Cell, and obtain data on all devices.
```
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather

class AllGatherCell(Cell):
    """
    Allgather cell, used in model parallel scenario.
    To allgather the selected parameter slice from each device.
    """
    def __init__(self):
        super(AllGatherCell, self).__init__()
        self.allgather = AllGather()

    def construct(self, x):
        return self.allgather(x)

allgather = AllGatherCell()
param_data = allgather(param_data)
param_data_moments = allgather(param_data_moments)
```
The value of `param_data` is the integration of data on each device in dimension 0. The data value is \[\[1, 2], \[3, 4], \[5, 6], \[7, 8]] and the shape is \[4, 2], while the raw data value of the parameter is \[\[1, 2, 3, 4], \[5, 6, 7, 8]] with shape \[2, 4]. The data needs to be redivided and integrated.
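Numerically, the `AllGather` result can be mimicked by concatenating the per-device slices along dimension 0 with NumPy (a simplified stand-in for the communication operator itself):

```python
import numpy as np

# Per-device slices of model_parallel_weight, each of shape [1, 2].
device_slices = [np.array([[1, 2]]), np.array([[3, 4]]),
                 np.array([[5, 6]]), np.array([[7, 8]])]

# AllGather integrates the slices in dimension 0 on every device.
gathered = np.concatenate(device_slices, axis=0)
print(gathered.tolist())  # [[1, 2], [3, 4], [5, 6], [7, 8]], shape [4, 2]
```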
3. Divide the data obtained from `AllGather`.
```
slice_list = np.split(param_data.asnumpy(), 4, axis=0)                  # 4: group_size, number of nodes in the cluster
slice_list_moments = np.split(param_data_moments.asnumpy(), 4, axis=0)  # 4: group_size, number of nodes in the cluster
```
The result of `param_data` is as follows:
```
slice_list[0]  ---  [1, 2]    Slice data on device0
slice_list[1]  ---  [3, 4]    Slice data on device1
slice_list[2]  ---  [5, 6]    Slice data on device2
slice_list[3]  ---  [7, 8]    Slice data on device3
```
4. Reassemble the data based on the site requirements.
In the following code, slice 1 and slice 2 are spliced by column, as are slice 3 and slice 4, and the two resulting rows are then spliced by row.
```
slice_line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1)   # result: [1, 2, 3, 4]
slice_line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1)   # result: [5, 6, 7, 8]
whole_data = np.concatenate((slice_line1, slice_line2), axis=0)        # result: [[1, 2, 3, 4], [5, 6, 7, 8]]
```
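As an end-to-end sanity check with plain NumPy (using a NumPy array in place of the gathered Tensor), dividing the gathered data and splicing it back recovers the raw \[2, 4] parameter:

```python
import numpy as np

# NumPy stand-in for the AllGather output (shape [4, 2]).
gathered = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])

slice_list = np.split(gathered, 4, axis=0)  # one [1, 2] slice per device
line1 = np.concatenate((slice_list[0], slice_list[1]), axis=1)
line2 = np.concatenate((slice_list[2], slice_list[3]), axis=1)
whole_data = np.concatenate((line1, line2), axis=0)

print(whole_data.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```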
> If there are multiple model parallel parameters, repeat steps 1 to 2 to process them one by one.
### Saving the Data and Generating a New Checkpoint File
...
...
```
python ./integrate_checkpoint.py "Name of the checkpoint file to be integrated" "Path and name of the checkpoint file generated after integration" "Path and name of the strategy file" "Number of nodes"
```
integrate\_checkpoint.py:
```
import numpy as np
import os
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication.management import init
from mindspore.train.serialization import save_checkpoint, load_checkpoint