提交 94e6c978 编写于 作者: W WeibiaoYu

add guide for checkpoint in hybrid parallel scence

上级 b3c5d4db
# 手动设置并行场景模型参数的保存和加载
<!-- TOC -->
- [手动设置并行场景模型参数的保存和加载](#手动设置并行场景模型参数的保存和加载)
- [概述](#概述)
- [背景](#背景)
- [使用场景](#使用场景)
- [对保存的CheckPoint文件做合并处理](#对保存的checkpoint文件做合并处理)
- [整体流程](#整体流程)
- [准备工作](#准备工作)
- [导入CheckPoint文件到网络](#导入checkpoint文件到网络)
- [获取网络中全量参数列表](#获取网络中全量参数列表)
- [对模型并行的参数做合并处理](#对模型并行的参数做合并处理)
- [保存数据生成新的CheckPoint文件](#保存数据生成新的checkpoint文件)
- [加载合并保存的CheckPoint文件](#加载合并保存的checkpoint文件)
- [整体流程](#整体流程-1)
- [步骤1:加载CheckPoint文件](#步骤1加载checkpoint文件)
- [步骤2:对模型并行参数做切分处理](#步骤2对模型并行参数做切分处理)
- [步骤3:将修改后的参数数据加载到网络中](#步骤3将修改后的参数数据加载到网络中)
- [示例](#示例)
- [示例场景说明](#示例场景说明)
- [示例代码](#示例代码)
<!-- /TOC -->
## 概述
### 背景
MindSpore模型并行场景下,每个实例进程只保存有本节点对应的参数数据。对于模型并行的Cell,其在每个节点上的参数数据,都是完整参数数据的一个切片。比如完整参数数据shape为[8, 8],每个节点上的参数数据为其中的一部分,如shape[2, 8]。
对于自动切分的模型并行场景(Auto Parallel),切分逻辑由MindSpore自动生成,MindSpore的CheckPoint模块可以支持自动合并保存和基于合并保存的加载能力。
对于用户手动设置的并行场景(HyBrid Parallel),切分逻辑由用户自己实现,MindSpore在每个节点只保存本节点上的数据,用户需要自己实现CheckPoint文件的合并保存与加载功能。本教程用于指导用户在手动切分场景下,实现CheckPoint的合并保存与加载能力。
### 使用场景
如果你遇到如下两个场景,需要参考本教程操作,完成CheckPoint的合并保存与加载:
场景1:多卡训练,单卡推理。
以在64卡上训练,并在单卡上推理为例,整体操作流程如下:
1. 执行训练,自动生成CheckPoint文件。
2. 用户对保存的CheckPoint文件做合并处理。
根据具体的切分逻辑,对于存在切分的具体模型参数做合并处理,生成新CheckPoint文件。
3. 在单卡环境加载新的CheckPoint文件,之后再根据需要调用export接口导出用于推理的模型。
若CheckPoint的保存环境和加载环境集群的卡数相同,比如在同一训练环境保存加载CheckPoint,或者单卡训练单卡推理,则可以不需要做合并保存和加载。
场景2:训练分为多阶段,每个阶段的集群大小不一样。
​ 以训练阶段一是64卡训练环境,阶段二是56卡训练环境为例,整体操作流程如下:
1. 执行阶段一训练,自动生成CheckPoint文件。
2. 用户对保存的CheckPoint文件做合并处理。
根据具体的切分逻辑,对于存在切分的具体模型参数做合并处理,生成新CheckPoint文件。
3. 在阶段二集群上加载合并保存的CheckPoint文件。
在加载过程中,用户需要根据新的训练环境配置,重新切分CheckPoint文件中的参数数据。
4. 执行阶段二训练。
## 对保存的CheckPoint文件做合并处理
### 整体流程
首先,执行准备工作,将待合并处理的CheckPoint文件导入网络,并通过MindSpore提供的API获取全量参数列表。对应下图中的Step1和Step2。
其次,更新参数列表,对涉及模型并行的参数做合并处理。对应下图中的Step3。
最后,将更新之后的参数列表,通过MindSpore提供的API保存到文件,生成新的CheckPoint文件。对应下图中的Step4。
![img](./images/checkpoint_integration_process.jpg)
### 准备工作
#### 导入CheckPoint文件到网络
定义网络,并调用`load_checkpoint``load_param_into_net`接口,将CheckPoint文件导入网络。
```
param_dict = load_checkpoint(./CKP_1-4_32.ckpt) # checkpoint file name
net = Net()
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
load_param_into_net(net, param_dict)
```
其中,
- `load_checkpoint()`:通过该接口加载CheckPoint模型参数文件,返回一个参数字典。
- `load_param_into_net()`:模型参数数据加载到网络中。
- `CKP_1-4_32.ckpt`:之前保存的CheckPoint模型参数文件名称。
> 如果直接在训练环境上,基于当前训练得到的数据直接保存新的CheckPoint文件,参数值已经存在在网络中,则可以省略该步骤,无需导入CheckPoint文件。
#### 获取网络中全量参数列表
调用`parameters_and_names`接口,获取网络里所有的参数数据。
```
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
```
### 对模型并行的参数做合并处理
下面以一个具体的模型参数为例,说明下参数合并处理的具体流程。
参数名称为"model_parallel_weight",数据为Tensor [[1, 2, 3, 4], [5, 6, 7, 8]]。
切分逻辑为4卡场景,按[2, 2]切分,即先在行维度切分为2个切片,之后再对得到的2个切片,分别在列维度分再切分为2个更小的切片,最后得到4个切片。
切分后数据分布情况如下:
| Device0 | Device1 | Device2 | Device3 |
| ------------- | ------------ | ------------- | ------------- |
| Value [1, 2] | Value [3, 4] | Value [5, 6] | Value [7, 8] |
1. 针对涉及模型并行的参数,获取本节点上的数据值。
```
param_data = param_dict[“model_parallel_weight”]
param_data_moments = param_dict[“moments.model_parallel_weight”]
```
> 如果要保证参数更新速度不变,需要对优化器中保存的参数,如“moments.model_parallel_weight”,同样做合并处理。
2. 定义AllGather类型子图,并实例化和执行,获取所有卡上的数据。
```
from mindspore.nn.cell import Cell
from mindspore.ops.operations.comm_ops import AllGather
class AllGatherCell(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super(AllGatherCell, self).__init__(auto_prefix=False)
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
allgather_net = AllGatherCell()
param_data = allgather_net(param_data)
param_data_moments = allgather_net(param_data_moments)
```
​得到的数据param_data为每卡上的数据在维度0上的合并,数据值为 [[1, 2], [3, 4], [5, 6], [7, 8]],shape为[4, 2]。
​param_data原始数据值为[[1, 2, 3, 4], [5, 6, 7, 8]],shape为[2, 4],需要对数据重新切分合并。
3. 切分通过AllGather得到的数据。
```
slice_list = np.split(param_data.asnumpy(), 4, axis=0) # 4:group_size, number of nodes in cluster
slice_lis_moments = np.split(param_data_moments.asnumpy(), 4, axis=0) # 4: group_size, number of nodes in cluster
```
得到结果param_data为:
slice_list[0] --- [1, 2] device0上的切片数据
slice_list[1] --- [3, 4] device1上的切片数据
slice_list[2] --- [5, 6] device2上的切片数据
slice_list[3] --- [7, 8] device3上的切片数据
4. 按照实际情况,重新组装数据。
如下代码,先分别对切片1和切片2,切片3和切片4按列拼接,之后对前两步得到的数据按行拼接。
```
slice_line1 = np.concatenate((slice_list[0], slice_list[1]), aix=1) # result [1,2,3,4]
slice_line2 = np.concatenate((slice_list[2], slice_list[3]), aix=1) # result [5,6,7,8]
whole_data = np.concatenate((slice_line1, slice_line2), aix=0) # result [[1, 2, 3, 4], [5, 6, 7, 8]]
slice_moments_line1 = np.concatenate((slice_lis_moments[0], slice_lis_moments[1]), aix=1)
slice_moments_line2 = np.concatenate((slice_lis_moments[2], slice_lis_moments[3]), aix=1)
whole_moments_data = np.concatenate((slice_moments_line1, slice_moments_line2), aix=0)
```
5. 对模型参数赋值。
```
param_data = Tensor(whole_data)
param_data_moments = Tensor(whole_moments_data)
```
> 1. 如果存在多个模型并行的参数,则需要重复步骤1到步骤5循环逐个处理。
> 2. 如果步骤2执行allgather子图获取的数据,已经是最终的数据,则后面的步骤可省略。
> 即本身切分逻辑是仅在shape0上切分,每个卡加载不同切片数据。
### 保存数据生成新的CheckPoint文件
1. 将param_dict转换为list类型数据。
```
param_list = []
for (key, value) in param_dict.items():
each_param = {}
each_param["name"] = key
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
each_param["data"] = param_data
param_list.append(each_param)
```
2. 调用`save_checkpoint`接口,将参数数据写入文件,生成新的CheckPoint文件。
```
save_checkpoint(param_list, “./CKP-Integrated_1-4_32.ckpt”)
```
其中,
- `save_checkpoint`: 通过该接口将网络模型参数信息存入文件。
- `CKP-Integrated_1-4_32.ckpt`: 新生成的CheckPoint模型参数文件名称。
## 加载合并保存的CheckPoint文件
### 整体流程
如果需要将合并保存的CheckPoint加载到多卡训练或推理中,在模型参数真正加载到网络前,需要对于涉及并行的参数数据按照新的逻辑切分。
如下步骤在训练前脚本里实现,步骤1和3同单机CheckPoint加载的逻辑,步骤2为新增,用于涉及并行的模型参数的切分。
对于加载到单卡训练/推理的场景,不涉及数据切分,则步骤2可省略。
### 步骤1:加载CheckPoint文件
调用`load_checkpoint`接口,从CheckPoint文件中加载模型参数数据。
```
param_dict = load_checkpoint("./CKP-Integrated_1-4_32.ckpt")
```
- `load_checkpoint()`:通过该接口加载CheckPoint模型参数文件,返回一个参数字典。
- `CKP-Integrated_1-4_32.ckpt`:需要加载的CheckPoint模型参数文件名称。
### 步骤2:对模型并行参数做切分处理
下面以一个具体的模型参数为例,参数名称为“model_parallel_weight", 数据值为Tensor [[1, 2, 3, 4], [5, 6, 7, 8]],切分逻辑为2卡场景,按[2, 1]切分。
​切分后数据分布情况如下:
| Device0 | Device1 |
| ------------------- | -------------------- |
| Value [1, 2, 3, 4] | Value [5, 6, 7, 8] |
1. 对模型参数数据做切分。
如下代码示例,在维度0上,将数据切分为两个切片。
```
new_param = parameter_dict[“model_parallel_weight”]
slice_list = np.split(new_param.data.asnumpy(), 2, axis=0)
new_param_moments = parameter_dict[“moments.model_parallel_weight”]
slice_moments_list = np.split(new_param_moments.data.asnumpy(), 2, axis=0)
```
切分后的数据情况:
slice_list[0] --- [1, 2, 3, 4] 对应device0
slice_list[1] --- [5, 6, 7, 8] 对应device1
与slice_list类似,slice_moments_list 也被切分为两个shape为[1, 4]的Tensor。
2. 在每个节点分别加载对应的数据切片。
获取本节点的rank_id,根据rank_id加载数据。
```
rank = get_rank()
tensor_slice = Tensor(slice_list[rank])
tensor_slice_moments = Tensor(slice_moments_list[rank])
```
- `get_rank`:获取当前设备在集群中的ID。
3. 修改模型参数数据值。
```
new_param.set_parameter_data(tensor_slice)
new_param_moments.set_parameter_data(tensor_slice_moments)
```
- `set_parameter_data`:设置模型参数的值,接口参数类型为Tensor 或number。
### 步骤3:将修改后的参数数据加载到网络中
调用`load_param_into_net`接口,将模型参数数据加载到网络中。
```
net = Net()
opt = Momentum(learning_rate=0.01, momentum=0.9, params=parallel_net.get_parameters())
load_param_into_net(net, param_dict)
load_param_into_net(opt, param_dict)
```
## 示例
### 示例场景说明
整体场景:训练分为两个阶段,两阶段的集群规模不一致, 模拟FC层MatMul算子并行。
用户流程:
1. 执行阶段1训练:阶段1为4卡训练环境,每卡上MatMul算子weight的shape为[2, 8],训练过程中自动导出CheckPoint。
2. 执行脚本对CheckPoint文件做合并处理,根据具体的切分逻辑,对于存在切分的具体模型参数做合并处理,生成合并的CheckPoint文件。
3. 执行阶段2训练:阶段2为2卡训练环境,每卡上MatMul算子weight的shape为[4, 8],从合并后的CheckPoint文件加载初始化模型参数数据,之后执行训练。
> 具体分布式环境配置和训练部分代码,此处不做详细说明,可以参考[分布式并行训练](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/distributed_training.html)
章节。
>
> 本文档附上对CheckPoint文件做合并处理以及分布式训练前加载CheckPoint文件的示例代码,仅作为参考,实际请参考具体情况实现。
### 示例代码
1. 执行脚本对CheckPoint文件做合并处理。
脚本执行命令:
```
python ./integrate_checkpoint.py "待合并的CheckPoint文件路径&名称" "合并生成的CheckPoint文件路径&名称"
```
integrate_checkpoint.py:
```
import numpy as np
import os
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import AllGather
from mindspore.communication.management import init
from mindspore.train.serialization import save_checkpoint, load_checkpoint
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', save_graphs=True, device_id=devid)
init()
class Net(nn.Cell):
def __init__(self,weight_init):
super(Net, self).__init__()
self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
self.fc = P.MatMul(transpose_b=True)
def construct(self, x):
x = self.fc(x, self.weight1)
return x
class AllGatherNet(Cell):
"""
Allgather cell, used in model parallel scenario.
To allgather the selected parameter slice from each device.
"""
def __init__(self):
super().__init__()
self.allgather = AllGather()
def construct(self, x):
x = self.allgather(x)
return x
def integrate_ckpt_file(old_ckpt_file, new_ckpt_file):
weight = np.ones([2, 8]).astype(np.float32)
net = Net(weight)
opt = Momentum(learning_rate=0.01, momentum=0.9, params=net.get_parameters())
net = TrainOneStepCell(net, opt)
# load CheckPoint into net
param_dict = load_checkpoint(old_ckpt_file)
load_param_into_net(net, param_dict)
param_dict = {}
for _, param in net.parameters_and_names():
param_dict[param.name] = param
for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
# get layer wise model parallel parameter
layerwise_param = param_dict[paramname]
if isinstance(layerwise_param.data, Tensor):
param_data = layerwise_param.data
else:
param_data = Tensor(layerwise_param.data)
# merge the parallel parameters of the model
allgather_net = get_allgather_cell()
param_data = allgather_net(param_data)
layerwise_param.set_parameter_data(param_data)
# convert param_dict to list type data
param_list = []
for (key, value) in param_dict.items():
each_param = {}
each_param["name"] = key
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
each_param["data"] = param_data
param_list.append(each_param)
# call the API to generate a new CheckPoint file
save_checkpoint(param_list, new_ckpt_file)
return
if __name__ == "__main__":
try:
old_ckpt_file = sys.argv[1]
new_ckpt_file = sys.argv[2]
integrate(old_ckpt_file, new_ckpt_file)
except:
print("Fail to integrate checkpoint file)
sys.exit(-1)
```
其中,
- `mode=context.GRAPH_MODE`:使用分布式训练需要指定运行模式为图模式(PyNative模式不支持并行)。
- `device_id`:卡物理序号,即卡所在机器中的实际序号。
- `init()`:完成分布式训练初始化操作。
执行结果:
脚本执行前,CheckPoint文件中参数值:
```
device0:
name is model_parallel_weight
value is
[[0.87537426 1.0448935 0.86736983 0.8836905 0.77354026 0.69588304 0.9183654 0.7792076]
[0.87224025 0.8726848 0.771446 0.81967723 0.88974726 0.7988162 0.72919345 0.7677011]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_weight
value is
[[0.2567724 -0.07485991 0.282002 0.2456022 0.454939 0.619168 0.18964815 0.45714882]
[0.25946522 0.24344791 0.45677605 0.3611395 0.23378398 0.41439137 0.5312468 0.4696194]]
device1:
name is model_parallel_weight
value is
[[0.9210751 0.9050457 0.9827775 0.920396 0.9240526 0.9750359 1.0275179 1.0819869]
[0.73605865 0.84631145 0.9746683 0.9386582 0.82902765 0.83565056 0.9702136 1.0514659]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_weight
value is
[[0.2417504 0.28193963 0.06713893 0.21510397 0.23380603 0.11424308 0.0218009 -0.11969765]
[0.45955992 0.22664294 0.01990281 0.0731914 0.27125207 0.27298513 -0.01716102 -0.15327111]]
device2:
name is model_parallel_weight
value is
[[1.0108461 0.8689414 0.91719437 0.8805056 0.7994629 0.8999671 0.7585804 1.0287056 ]
[0.90653455 0.60146594 0.7206475 0.8306303 0.8364681 0.89625114 0.7354735 0.8447268]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_weight
value is
[[0.03440702 0.41419312 0.24817684 0.30765256 0.48516113 0.24904746 0.57791173 0.00955463]
[0.13458519 0.6690533 0.49259356 0.28319967 0.25951773 0.16777472 0.45696738 0.24933104]]
device3:
name is model_parallel_weight
value is
[[0.7147005 0.9168278 0.80178416 0.6258351 0.8413766 0.5909515 0.696347 0.71359116]
[0.20506378 0.03691584 0.2454556 0.12978578 0.19065076 0.23904312 0.27509746 0.34614682]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_parallel_weight
value is
[[0.14152306 0.5040985 0.24455397 0.10907605 0.11319532 0.19538902 0.01208619 0.40430856]
[-0.7773164 -0.47611716 -0.6041424 -0.6144473 -0.2651842 -0.31909415 -0.4510405 -0.12860501]]
```
脚本执行后,CheckPoint文件中参数值:
```
name is model_parallel_weight
value is
[[1.1138763 1.0962057 1.3516843 1.0812817 1.1579804 1.1078343 1.0906502 1.3207073]
[0.916671 1.0781671 1.0368758 0.9680898 1.1735439 1.0628364 0.9960786 1.0135143]
[0.8828271 0.7963984 0.90675324 0.9830291 0.89010954 0.897052 0.7890109 0.89784735]
[1.0011744 1.0840297 1.0201758 1.0882459 0.94232416 1.0775206 1.0195118 1.0528734]
[1.0053468 0.98402303 0.99762845 0.97587246 1.0259694 1.0055295 0.99420834 0.9496847]
[1.0851002 1.0295962 1.0999886 1.0958165 0.9765328 1.146529 1.0970603 1.1388365]
[0.7147005 0.9168278 0.80178416 0.6258351 0.8413766 0.5909515 0.696347 0.71359116]
[0.20506378 0.03691584 0.2454556 0.12978578 0.19065076 0.23904312 0.27509746 0.34614682]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_parallel_weight
value is
[[0.2567724 -0.07485991 0.282002 0.2456022 0.454939 0.619168 0.18964815 0.45714882]
[0.25946522 0.24344791 0.45677605 0.3611395 0.23378398 0.41439137 0.5312468 0.4696194 ]
[0.2417504 0.28193963 0.06713893 0.21510397 0.23380603 0.11424308 0.0218009 -0.11969765]
[0.45955992 0.22664294 0.01990281 0.0731914 0.27125207 0.27298513 -0.01716102 -0.15327111]
[0.03440702 0.41419312 0.24817684 0.30765256 0.48516113 0.24904746 0.57791173 0.00955463]
[0.13458519 0.6690533 0.49259356 0.28319967 0.25951773 0.16777472 0.45696738 0.24933104]
[0.14152306 0.5040985 0.24455397 0.10907605 0.11319532 0.19538902 0.01208619 0.40430856]
[-0.7773164 -0.47611716 -0.6041424 -0.6144473 -0.2651842 -0.31909415 -0.4510405
-0.12860501]]
```
2. 执行阶段2训练,训练前加载CheckPoint文件。其中训练代码部分,需要根据实际情况补充。
```
import numpy as np
import os
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,device_target='Ascend',save_graphs=True, device_id=devid)
init()
class Net(nn.Cell):
def __init__(self,weight_init):
super(Net, self).__init__()
self.weight = Parameter(Tensor(weight_init), "model_parallel_weight", layerwise_parallel=True)
self.fc = P.MatMul(transpose_b=True)
def construct(self, x):
x = self.fc(x, self.weight1)
return x
def train_mindspore_impl_fc(input, label, ckpt_file):
param_dict = load_checkpoint(ckpt_file)
for paramname in ["model_parallel_weight", "moments.model_parallel_weight"]:
# get layer wise model parallel parameter
new_param = parameter_dict[paramname]
# split the model parameter data
slice_list = np.split(new_param.data.asnumpy(), 2, axis=0)
# Load the corresponding data slice
rank = get_rank()
tensor_slice = Tensor(slice_list[rank])
# modify model parameter data values
new_param.set_parameter_data(tensor_slice)
# load the modified parameter data into the network
weight = np.ones([4, 8]).astype(np.float32)
net = Net(weight)
load_param_into_net(net, param_dict)
opt = Momentum(learning_rate=0.01, momentum=0.9, params=parallel_net.get_parameters())
load_param_into_net(opt, param_dict)
# train code
...
if __name__ == "__main__":
input = np.random.random((4, 8)).astype(np.float32)
print("mean = ", np.mean(input,axis=1, keepdims=True))
label = np.random.random((4, 4)).astype(np.float32)
train_mindspore_impl_fc(input, label, weight1)
```
加载后的参数值:
```
device0:
name is model_parallel_weight
value is
[[0.87537426 1.0448935 0.86736983 0.8836905 0.77354026 0.69588304 0.9183654 0.7792076]
[0.87224025 0.8726848 0.771446 0.81967723 0.88974726 0.7988162 0.72919345 0.7677011]
[0.8828271 0.7963984 0.90675324 0.9830291 0.89010954 0.897052 0.7890109 0.89784735]
[1.0011744 1.0840297 1.0201758 1.0882459 0.94232416 1.0775206 1.0195118 1.0528734]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_weight
value is
[[0.2567724 -0.07485991 0.282002 0.2456022 0.454939 0.619168 0.18964815 0.45714882]
[0.25946522 0.24344791 0.45677605 0.3611395 0.23378398 0.41439137 0.5312468 0.4696194]
[0.2417504 0.28193963 0.06713893 0.21510397 0.23380603 0.11424308 0.0218009 -0.11969765]
[0.45955992 0.22664294 0.01990281 0.0731914 0.27125207 0.27298513 -0.01716102 -0.15327111]]
device1:
name is model_parallel_weight
value is
[[1.0053468 0.98402303 0.99762845 0.97587246 1.0259694 1.0055295 0.99420834 0.9496847]
[1.0851002 1.0295962 1.0999886 1.0958165 0.9765328 1.146529 1.0970603 1.1388365]
[0.7147005 0.9168278 0.80178416 0.6258351 0.8413766 0.5909515 0.696347 0.71359116]
[0.20506378 0.03691584 0.2454556 0.12978578 0.19065076 0.23904312 0.27509746 0.34614682]]
name is learning_rate
value is [0.01]
name is momentum
value is [0.9]
name is moments.model_weight
value is
[[0.03440702 0.41419312 0.24817684 0.30765256 0.48516113 0.24904746 0.57791173 0.00955463]
[0.13458519 0.6690533 0.49259356 0.28319967 0.25951773 0.16777472 0.45696738 0.24933104]
[0.14152306 0.5040985 0.24455397 0.10907605 0.11319532 0.19538902 0.01208619 0.40430856]
[-0.7773164 -0.47611716 -0.6041424 -0.6144473 -0.2651842 -0.31909415 -0.4510405 -0.12860501]]
```
......@@ -30,6 +30,7 @@ MindSpore教程
advanced_use/visualization_tutorials
advanced_use/mixed_precision
advanced_use/distributed_training
advanced_use/checkpoint_for_hybrid_parallel
advanced_use/computer_vision_application
advanced_use/nlp_application
advanced_use/customized_debugging_information
......@@ -38,4 +39,4 @@ MindSpore教程
advanced_use/model_security
advanced_use/network_migration
advanced_use/mindspore_cpu_win_install
advanced_use/community
\ No newline at end of file
advanced_use/community
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册