Commit fada0b07 authored by X Xiaoda, committed by Gitee

update tutorials/source_en/advanced_use/mixed_precision.md.

Parent e3de99d1
@@ -46,9 +46,20 @@ The procedure is as follows:
 A code example is as follows:
 
 ```python
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.nn import WithLossCell
+from mindspore.nn import Momentum
+from mindspore.nn.loss import MSELoss
+
 # The interface of Auto_mixed precision
 from mindspore.train import amp
 
+context.set_context(mode=context.GRAPH_MODE)
+context.set_context(device_target="Ascend")
 # Define network
 class LeNet5(nn.Cell):
     def __init__(self):
@@ -59,7 +70,7 @@ class LeNet5(nn.Cell):
         self.fc2 = nn.Dense(120, 84)
         self.fc3 = nn.Dense(84, 10)
         self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=2)
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flatten = P.Flatten()
 
     def construct(self, x):
@@ -95,15 +106,28 @@ output = train_network(predict, label, scaling_sens)
 MindSpore also supports manual mixed precision. It is assumed that only one dense layer in the network needs to be computed in FP32, and the other layers are computed in FP16. Mixed precision is configured at the granularity of the cell; by default, a cell runs in FP32.
 
 The following is the procedure for implementing manual mixed precision:
-1. Define the network. This step is similar to step 2 in the automatic mixed precision. Note: the fc3 operator in LeNet needs to be manually set to FP32.
-2. Configure the mixed precision. Use net.to_float(mstype.float16) to set all operators of the cell and its sub-cells to FP16.
+1. Define the network. This step is similar to step 2 in the automatic mixed precision.
+2. Configure the mixed precision. Use net.to_float(mstype.float16) to set all operators of the cell and its sub-cells to FP16. Then, configure fc3 to FP32.
 3. Use TrainOneStepWithLossScaleCell to encapsulate the network model and optimizer.
 A code example is as follows:
 
 ```python
+import numpy as np
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Tensor, context
+from mindspore.ops import operations as P
+from mindspore.nn import WithLossCell, TrainOneStepWithLossScaleCell
+from mindspore.nn import Momentum
+from mindspore.nn.loss import MSELoss
+
+context.set_context(mode=context.GRAPH_MODE)
+context.set_context(device_target="Ascend")
+
 # Define network
 class LeNet5(nn.Cell):
     def __init__(self):
@@ -112,9 +136,9 @@ class LeNet5(nn.Cell):
         self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
         self.fc1 = nn.Dense(16 * 5 * 5, 120)
         self.fc2 = nn.Dense(120, 84)
-        self.fc3 = nn.Dense(84, 10).to_float(mstype.float32)
+        self.fc3 = nn.Dense(84, 10)
         self.relu = nn.ReLU()
-        self.max_pool2d = nn.MaxPool2d(kernel_size=2)
+        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flatten = P.Flatten()
 
     def construct(self, x):
@@ -129,6 +153,7 @@ class LeNet5(nn.Cell):
 # Initialize network and set mixed precision
 net = LeNet5()
 net.to_float(mstype.float16)
+net.fc3.to_float(mstype.float32)
 
 # Define training data, label and sens
 predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
@@ -141,6 +166,7 @@ loss = MSELoss()
 optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
 net_with_loss = WithLossCell(net, loss)
 train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
+train_network.set_train()
 
 # Run training
 output = train_network(predict, label, scaling_sens)
...
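
Note: the hunk `@@ -95,15 +106,28 @@` above elides the lines of the automatic mixed-precision example where the training network is actually assembled; only the `from mindspore.train import amp` import is visible in the diff. The following is a minimal sketch of what that elided step typically looks like, assuming the `amp.build_train_network` interface and an `"O2"` level; the exact arguments in the tutorial may differ.

```python
# Minimal sketch, not part of this commit: assembling the automatic
# mixed-precision training network. The argument list is an assumption
# based on the `from mindspore.train import amp` import shown above.
net = LeNet5()
loss = MSELoss()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)

# Build a single training cell that applies the mixed-precision policy;
# level="O2" casts the network to FP16 while keeping the loss in FP32.
train_network = amp.build_train_network(net, optimizer, loss, level="O2")
```

By contrast, the manual path in the second example keeps full control: the whole network is cast with `net.to_float(mstype.float16)`, and the single precision-sensitive layer is pulled back with `net.fc3.to_float(mstype.float32)`.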