diff --git a/tutorials/source_en/use/saving_and_loading_model_parameters.md b/tutorials/source_en/use/saving_and_loading_model_parameters.md index ba46b8a89a72720b6771bb073133b8abd525039e..88d80140d166b2ab3309a71335de78899101b7b5 100644 --- a/tutorials/source_en/use/saving_and_loading_model_parameters.md +++ b/tutorials/source_en/use/saving_and_loading_model_parameters.md @@ -143,14 +143,14 @@ the code is as follows: ```python from mindspore.train.serialization import export import numpy as np -net = ResNet50() +resnet = ResNet50() # return a parameter dict for model -param_dict = load_checkpoint("resnet50-2_32.ckpt", net=resnet) +param_dict = load_checkpoint("resnet50-2_32.ckpt") # load the parameter into net -load_param_into_net(net) +load_param_into_net(resnet, param_dict) input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32) -export(net, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') +export(resnet, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') ``` Before using the `export` interface, you need to import` mindspore.train.serialization`. The `input` parameter is used to specify the input shape and data type of the exported model. -If you want to export the ONNX model, you only need to specify the `file_format` parameter in the` export` interface as ONNX: `file_format = 'ONNX'`. \ No newline at end of file +If you want to export the ONNX model, you only need to specify the `file_format` parameter in the` export` interface as ONNX: `file_format = 'ONNX'`. diff --git a/tutorials/source_zh_cn/use/saving_and_loading_model_parameters.md b/tutorials/source_zh_cn/use/saving_and_loading_model_parameters.md index 50f901b30a83ae3f9cecbbb97d502b775d7ea757..0bdf8fa7f711904001848ef415fdb693ef694d58 100644 --- a/tutorials/source_zh_cn/use/saving_and_loading_model_parameters.md +++ b/tutorials/source_zh_cn/use/saving_and_loading_model_parameters.md @@ -143,14 +143,14 @@ model.train(epoch, dataset) ```python from mindspore.train.serialization import export import numpy as np -net = ResNet50() +resnet = ResNet50() # return a parameter dict for model -param_dict = load_checkpoint("resnet50-2_32.ckpt", net=resnet) +param_dict = load_checkpoint("resnet50-2_32.ckpt") # load the parameter into net -load_param_into_net(net) +load_param_into_net(resnet, param_dict) input = np.random.uniform(0.0, 1.0, size = [32, 3, 224, 224]).astype(np.float32) -export(net, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') +export(resnet, input, file_name = 'resnet50-2_32.pb', file_format = 'GEIR') ``` 使用`export`接口之前,需要先导入`mindspore.train.serialization`。 `input`用来指定导出模型的输入shape以及数据类型。 -如果要导出ONNX模型,只需要将`export`接口中的`file_format`参数指定为ONNX即可:`file_format = 'ONNX'`。 \ No newline at end of file +如果要导出ONNX模型,只需要将`export`接口中的`file_format`参数指定为ONNX即可:`file_format = 'ONNX'`。