提交 afeea0fa 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!83 [Mixed-precision] Fix the runing errors of mix_precision

Merge pull request !83 from Xiaoda/master
......@@ -36,7 +36,6 @@ This document describes the computation process by using examples of automatic a
## Automatic Mixed Precision
To use the automatic mixed precision, you need to invoke the corresponding API, which takes the network to be trained and the optimizer as the input. This API converts the operators of the entire network into FP16 operators (except the BatchNorm and Loss operators).
In addition, after the mixed precision is employed, the loss scale must be used to avoid data overflow.
The procedure is as follows:
1. Introduce the MindSpore mixed precision API.
......@@ -49,57 +48,44 @@ A code example is as follows:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
# The interface of Auto_mixed precision
from mindspore.train import amp
from mindspore import amp
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
class Net(nn.Cell):
def __init__(self, input_channel, out_channel):
super(Net, self).__init__()
self.dense = nn.Dense(input_channel, out_channel)
self.relu = P.ReLU()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
x = self.dense(x)
x = self.relu(x)
return x
# Initialize network
net = LeNet5()
net = Net(512, 128)
# Define training data, label and sens
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))
# Define Loss and Optimizer
loss = MSELoss()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")
train_network = amp.build_train_network(net, optimizer, loss, level="O2", loss_scale_manager=None)
# Run training
output = train_network(predict, label, scaling_sens)
output = train_network(predict, label)
```
......@@ -110,66 +96,53 @@ MindSpore also supports manual mixed precision. It is assumed that only one dens
The following is the procedure for implementing manual mixed precision:
1. Define the network. This step is similar to step 2 in the automatic mixed precision.
2. Configure the mixed precision. Use net.to_float(mstype.float16) to set all operators of the cell and its sub-cells to FP16. Then, configure the fc3 to FP32.
2. Configure the mixed precision. Use net.to_float(mstype.float16) to set all operators of the cell and its sub-cells to FP16. Then, configure the dense to FP32.
3. Use TrainOneStepWithLossScaleCell to encapsulate the network model and optimizer.
3. Use TrainOneStepCell to encapsulate the network model and optimizer.
A code example is as follows:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepWithLossScaleCell
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
class Net(nn.Cell):
def __init__(self, input_channel, out_channel):
super(Net, self).__init__()
self.dense = nn.Dense(input_channel, out_channel)
self.relu = P.ReLU()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
x = self.dense(x)
x = self.relu(x)
return x
# Initialize network and set mixing precision
net = LeNet5()
net = Net(512, 128)
net.to_float(mstype.float16)
net.fc3.to_float(mstype.float32)
net.dense.to_float(mstype.float32)
# Define training data, label and sens
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))
# Define Loss and Optimizer
net.set_train()
loss = MSELoss()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
# Run training
output = train_network(predict, label, scaling_sens)
output = train_network(predict, label)
```
......@@ -35,70 +35,56 @@ MindSpore混合精度典型的计算流程如下图所示:
## 自动混合精度
使用自动混合精度,需要调用相应的接口,将待训练网络和优化器作为输入传进去;该接口会将整张网络的算子转换成FP16算子(除BatchNorm算子和Loss涉及到的算子外)。
另外要注意:使用混合精度后,一般要用上Loss Scale,避免数值计算溢出。
具体的实现步骤为:
1. 引入MindSpore的混合精度的接口amp;
2. 定义网络:该步骤和普通的网络定义没有区别(无需手动配置某个算子的精度);
3. 使用amp.build_train_network()接口封装网络模型和优化器,在该步骤中MindSpore会将有需要的算子自动进行类型转换。
3. 使用amp.build_train_network()接口封装网络模型、优化器和损失函数,在该步骤中MindSpore会将有需要的算子自动进行类型转换。
代码样例如下:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
# The interface of Auto_mixed precision
from mindspore.train import amp
from mindspore import amp
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
class Net(nn.Cell):
def __init__(self, input_channel, out_channel):
super(Net, self).__init__()
self.dense = nn.Dense(input_channel, out_channel)
self.relu = P.ReLU()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
x = self.dense(x)
x = self.relu(x)
return x
# Initialize network
net = LeNet5()
net = Net(512, 128)
# Define training data, label and sens
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))
# Define Loss and Optimizer
loss = MSELoss()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = amp.build_train_network(net_with_loss, optimizer, level="O2")
train_network = amp.build_train_network(net, optimizer, loss, level="O2", loss_scale_manager=None)
# Run training
output = train_network(predict, label, scaling_sens)
output = train_network(predict, label)
```
......@@ -109,66 +95,53 @@ MindSpore还支持手动混合精度。假定在网络中只有一个Dense Layer
以下是一个手动混合精度的实现步骤:
1. 定义网络: 该步骤与自动混合精度中的步骤2类似;
2. 配置混合精度: LeNet通过net.to_float(mstype.float16),把该Cell及其子Cell中所有的算子都配置成FP16;然后,将LeNet中的fc3算子手动配置成FP32;
2. 配置混合精度: 通过net.to_float(mstype.float16),把该Cell及其子Cell中所有的算子都配置成FP16;然后,将模型中的dense算子手动配置成FP32;
3. 使用TrainOneStepWithLossScaleCell封装网络模型和优化器。
3. 使用TrainOneStepCell封装网络模型和优化器。
代码样例如下:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepWithLossScaleCell
from mindspore.nn import WithLossCell, TrainOneStepCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
class Net(nn.Cell):
def __init__(self, input_channel, out_channel):
super(Net, self).__init__()
self.dense = nn.Dense(input_channel, out_channel)
self.relu = P.ReLU()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
x = self.dense(x)
x = self.relu(x)
return x
# Initialize network and set mixing precision
net = LeNet5()
net = Net(512, 128)
net.to_float(mstype.float16)
net.fc3.to_float(mstype.float32)
net.dense.to_float(mstype.float32)
# Define training data, label and sens
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([1, 10]).astype(np.float32))
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
# Define training data, label
predict = Tensor(np.ones([64, 512]).astype(np.float32) * 0.01)
label = Tensor(np.zeros([64, 128]).astype(np.float32))
# Define Loss and Optimizer
net.set_train()
loss = MSELoss()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
train_network = TrainOneStepCell(net_with_loss, optimizer)
train_network.set_train()
# Run training
output = train_network(predict, label, scaling_sens)
output = train_network(predict, label)
```
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册