提交 c40c9018 编写于 作者: Y yuzhenhua666 提交者: yuzhenhua

fix buf of code example

上级 c62d98b4
...@@ -143,14 +143,14 @@ the code is as follows: ...@@ -143,14 +143,14 @@ the code is as follows:
```python ```python
from mindspore.train.serialization import export from mindspore.train.serialization import export
import numpy as np import numpy as np
net = ResNet50() resnet = ResNet50()
# return a parameter dict for model # return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt", net=resnet) param_dict = load_checkpoint("resnet50-2_32.ckpt")
# load the parameter into net # load the parameter into net
load_param_into_net(net) load_param_into_net(resnet, param_dict)
input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32) input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32)
export(net, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') export(resnet, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR')
``` ```
Before using the `export` interface, you need to import` mindspore.train.serialization`. Before using the `export` interface, you need to import` mindspore.train.serialization`.
The `input` parameter is used to specify the input shape and data type of the exported model. The `input` parameter is used to specify the input shape and data type of the exported model.
If you want to export the ONNX model, you only need to specify the `file_format` parameter in the` export` interface as ONNX: `file_format = 'ONNX'`. If you want to export the ONNX model, you only need to specify the `file_format` parameter in the` export` interface as ONNX: `file_format = 'ONNX'`.
\ No newline at end of file
...@@ -143,14 +143,14 @@ model.train(epoch, dataset) ...@@ -143,14 +143,14 @@ model.train(epoch, dataset)
```python ```python
from mindspore.train.serialization import export from mindspore.train.serialization import export
import numpy as np import numpy as np
net = ResNet50() resnet = ResNet50()
# return a parameter dict for model # return a parameter dict for model
param_dict = load_checkpoint("resnet50-2_32.ckpt", net=resnet) param_dict = load_checkpoint("resnet50-2_32.ckpt")
# load the parameter into net # load the parameter into net
load_param_into_net(net) load_param_into_net(resnet, param_dict)
input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32) input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32)
export(net, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') export(resnet, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR')
``` ```
使用`export`接口之前,需要先导入`mindspore.train.serialization` 使用`export`接口之前,需要先导入`mindspore.train.serialization`
`input`用来指定导出模型的输入shape以及数据类型。 `input`用来指定导出模型的输入shape以及数据类型。
如果要导出ONNX模型,只需要将`export`接口中的`file_format`参数指定为ONNX即可:`file_format = 'ONNX'` 如果要导出ONNX模型,只需要将`export`接口中的`file_format`参数指定为ONNX即可:`file_format = 'ONNX'`
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册