Commit 3947150e authored by mindspore-ci-bot, committed by Gitee

!29 [Auto parallel] Fix some text of mixed_precision.

Merge pull request !29 from Xiaoda/master
@@ -80,14 +80,13 @@ label = Tensor(np.zeros([1, 10]).astype(np.float32))
 scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
 # Define Loss and Optimizer
-net.set_train()
 loss = MSELoss()
 optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
 net_with_loss = WithLossCell(net, loss)
 train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")
 # Run training
-output = train_network(inputs, label, scaling_sens)
+output = train_network(predict, label, scaling_sens)
 ```
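Since the diff shows only the changed hunk, here is a minimal end-to-end sketch of the automatic mixed precision flow as it reads after this change. The toy network and the import paths are assumptions for illustration (the tutorial trains its own LeNet5); the amp.build_train_network call and the training-step invocation follow the hunk above.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, amp
from mindspore.common import dtype as mstype


class ToyNet(nn.Cell):
    """Hypothetical stand-in for the tutorial's LeNet5."""
    def __init__(self):
        super(ToyNet, self).__init__()
        self.dense = nn.Dense(32, 10)

    def construct(self, x):
        return self.dense(x)


net = ToyNet()

# Training data, label, and loss-scaling sensitivity, shaped for the toy net
predict = Tensor(np.ones([1, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)

# Define loss and optimizer (net.set_train() is removed by this commit)
loss = nn.MSELoss()
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

# level="O2" makes build_train_network insert the FP16 casts automatically
net_with_loss = nn.WithLossCell(net, loss)
train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")

# One training step; the commit fixes the first argument from `inputs` to `predict`
output = train_network(predict, label, scaling_sens)
```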
@@ -98,7 +97,7 @@ MindSpore also supports manual mixed precision. It is assumed that only one dens
 The following is the procedure for implementing manual mixed precision:
 1. Define the network. This step is similar to step 2 in the automatic mixed precision. Note: The fc3 operator in LeNet needs to be manually set to FP32.
-2. Configure the mixed precision. Use net.add_flags_recursive(fp16=True) to set all operators of the cell and its sub-cells to FP16.
+2. Configure the mixed precision. Use net.to_float(mstype.float16) to set all operators of the cell and its sub-cells to FP16.
 3. Use TrainOneStepWithLossScaleCell to encapsulate the network model and optimizer.
@@ -113,7 +112,7 @@ class LeNet5(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
         self.fc1 = nn.Dense(16 * 5 * 5, 120)
         self.fc2 = nn.Dense(120, 84)
-        self.fc3 = nn.Dense(84, 10).add_flags_recursive(fp32=True)
+        self.fc3 = nn.Dense(84, 10).to_float(mstype.float32)
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=2)
         self.flatten = P.Flatten()
@@ -129,7 +128,7 @@ class LeNet5(nn.Cell):
 # Initialize network and set mixed precision
 net = LeNet5()
-net.add_flags_recursive(fp16=True)
+net.to_float(mstype.float16)
 # Define training data, label and sens
 predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
@@ -144,5 +143,5 @@ net_with_loss = WithLossCell(net, loss)
 train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
 # Run training
-output = train_network(inputs, label, scaling_sens)
+output = train_network(predict, label, scaling_sens)
 ```
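For the manual path, the hunks only show the lines that changed, so the following consolidates them into one runnable sketch. The conv1 layer, the construct method, and the pooling stride are not visible in the diff and are filled in as assumptions from the standard LeNet5 of this tutorial series; the call signature of TrainOneStepWithLossScaleCell follows the era's API shown above (newer MindSpore releases take the loss scale in the constructor instead).

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class LeNet5(nn.Cell):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')  # assumed, not in the hunks
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        # fc3 stays FP32 even after the whole net is cast to FP16 below
        self.fc3 = nn.Dense(84, 10).to_float(mstype.float32)
        self.relu = nn.ReLU()
        # stride=2 assumed so the flattened size matches 16 * 5 * 5
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = P.Flatten()

    def construct(self, x):  # assumed, not in the hunks
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


# Cast every operator in the cell tree to FP16 (fc3 keeps its FP32 override)
net = LeNet5()
net.to_float(mstype.float16)

# Training data, label, and loss-scaling sensitivity
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)

# Wrap loss, optimizer, and loss scaling into a single training cell
loss = nn.MSELoss()
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = nn.WithLossCell(net, loss)
train_network = nn.TrainOneStepWithLossScaleCell(net_with_loss, optimizer)

# One training step
output = train_network(predict, label, scaling_sens)
```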
@@ -12,7 +12,7 @@
 ## Overview
-The mixed precision training method accelerates the deep neural network training process by using a mixture of single-precision and half-precision data formats, while maintaining the network accuracy achievable with single-precision training. Mixed precision training speeds up computation, reduces memory usage and access, and allows larger models or batch sizes to be trained on specific hardware.
+The mixed precision training method accelerates the deep neural network training process by using a mixture of single-precision and half-precision data formats, while maintaining the network accuracy achievable with single-precision training. Mixed precision training speeds up computation, reduces memory usage and access, and makes it possible to train larger models or batch sizes on specific hardware.
 ## Computation Process
@@ -79,14 +79,13 @@ label = Tensor(np.zeros([1, 10]).astype(np.float32))
 scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
 # Define Loss and Optimizer
-net.set_train()
 loss = MSELoss()
 optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
 net_with_loss = WithLossCell(net, loss)
 train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")
 # Run training
-output = train_network(inputs, label, scaling_sens)
+output = train_network(predict, label, scaling_sens)
 ```
@@ -97,7 +96,7 @@ MindSpore also supports manual mixed precision. It is assumed that there is only one Dense Layer
 The following are the steps to implement manual mixed precision:
 1. Define the network: this step is similar to step 2 in automatic mixed precision. Note: the fc3 operator in LeNet needs to be manually configured to FP32.
-2. Configure the mixed precision: LeNet uses net.add_flags_recursive(fp16=True) to configure all operators in the Cell and its sub-Cells to FP16.
+2. Configure the mixed precision: LeNet uses net.to_float(mstype.float16) to configure all operators in the Cell and its sub-Cells to FP16.
 3. Use TrainOneStepWithLossScaleCell to encapsulate the network model and optimizer.
@@ -112,7 +111,7 @@ class LeNet5(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
         self.fc1 = nn.Dense(16 * 5 * 5, 120)
         self.fc2 = nn.Dense(120, 84)
-        self.fc3 = nn.Dense(84, 10).add_flags_recursive(fp32=True)
+        self.fc3 = nn.Dense(84, 10).to_float(mstype.float32)
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=2)
         self.flatten = P.Flatten()
@@ -128,7 +127,7 @@ class LeNet5(nn.Cell):
 # Initialize network and set mixed precision
 net = LeNet5()
-net.add_flags_recursive(fp16=True)
+net.to_float(mstype.float16)
 # Define training data, label and sens
 predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
@@ -143,5 +142,5 @@ net_with_loss = WithLossCell(net, loss)
 train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
 # Run training
-output = train_network(inputs, label, scaling_sens)
+output = train_network(predict, label, scaling_sens)
 ```
\ No newline at end of file
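Across both files, the substantive change is the same: the tutorial now documents the public Cell.to_float API instead of the flag-based helper. A minimal before/after sketch of the call being swapped, assuming both behave as the tutorial describes:

```python
import mindspore.nn as nn
from mindspore.common import dtype as mstype

dense = nn.Dense(84, 10)

# Before this commit the tutorial used the flag-based helper:
#   dense.add_flags_recursive(fp32=True)
# After, it uses Cell.to_float, which casts the inputs of the cell and all of
# its sub-cells to the given float type; it returns the cell, so it chains:
fc3 = dense.to_float(mstype.float32)
```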