提交 71496334 编写于 作者: X Xiaoda 提交者: Gitee

update tutorials/source_zh_cn/advanced_use/mixed_precision.md.

上级 29e26f19
......@@ -45,9 +45,20 @@ MindSpore混合精度典型的计算流程如下图所示:
代码样例如下:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
# The interface of Auto_mixed precision
from mindspore.train import amp
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
......@@ -58,7 +69,7 @@ class LeNet5(nn.Cell):
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
def construct(self, x):
......@@ -103,6 +114,19 @@ MindSpore还支持手动混合精度。假定在网络中只有一个Dense Layer
代码样例如下:
```python
import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context
from mindspore.ops import operations as P
from mindspore.nn import WithLossCell, TrainOneStepWithLossScaleCell
from mindspore.nn import Momentum
from mindspore.nn.loss import MSELoss
context.set_context(mode=context.GRAPH_MODE)
context.set_context(device_target="Ascend")
# Define network
class LeNet5(nn.Cell):
def __init__(self):
......@@ -111,9 +135,9 @@ class LeNet5(nn.Cell):
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, 10).to_float(mstype.float32)
self.fc3 = nn.Dense(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = P.Flatten()
def construct(self, x):
......@@ -128,6 +152,7 @@ class LeNet5(nn.Cell):
# Initialize network and set mixing precision
net = LeNet5()
net.to_float(mstype.float16)
net.fc3.to_float(mstype.float32)
# Define training data, label and sens
predict = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
......@@ -140,6 +165,7 @@ loss = MSELoss()
optimizer = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
train_network.set_train()
# Run training
output = train_network(predict, label, scaling_sens)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册