From 35310e39513b9e8f63825725ba4a726903f3c967 Mon Sep 17 00:00:00 2001 From: liyanliu Date: Thu, 30 Jul 2020 21:32:56 +0800 Subject: [PATCH] update loading pre-trained models saved remotely using hub API --- .../source_en/use/multi_platform_inference.md | 47 +++++++++++++++++-- .../use/multi_platform_inference.md | 44 +++++++++++++++-- 2 files changed, 85 insertions(+), 6 deletions(-) diff --git a/tutorials/source_en/use/multi_platform_inference.md b/tutorials/source_en/use/multi_platform_inference.md index cc705c0a..edc6b07d 100644 --- a/tutorials/source_en/use/multi_platform_inference.md +++ b/tutorials/source_en/use/multi_platform_inference.md @@ -53,14 +53,54 @@ MindSpore supports the following inference scenarios based on the hardware platf ### Inference Using a Checkpoint File -1. Input a validation dataset to validate a model using the `model.eval` API. The processing method of the validation dataset is the same as that of the training dataset. +1. Input a validation dataset to validate a model using the `model.eval` API. + + 1.1 Local Storage + + When the pre-trained models are saved locally, the steps of performing inference on validation dataset are as follows: firstly creating a model, then loading model and parameters using `load_checkpoint` and `load_param_into_net` in `mindspore.train.serialization` module, and finally performing inference on validation dataset once created. The processing method of the validation dataset is the same as that of the training dataset. + ```python - res = model.eval(dataset) + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Testing ==============") + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(network, param_dict) + dataset = create_dataset(os.path.join(args.data_path, "test"), + cfg.batch_size, + 1) + acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode) + print("============== {} ==============".format(acc)) ``` In the preceding information: `model.eval` is an API for model validation. For details about the API, see . > Inference sample code: . + 1.2 Remote Storage + + When the pre-trained models are saved remotely, the steps of performing inference on validation dataset are as follows: firstly creating a model, then loading model and parameters using `hub.load_weights`, and finally performing inference on validation dataset once created. The processing method of the validation dataset is the same as that of the training dataset. + + ```python + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Testing ==============") + hub.load_weights(network, network_name="lenet", **{"device_target": + "ascend", "dataset":"mnist", "version": "0.5.0"}) + dataset = create_dataset(os.path.join(args.data_path, "test"), + cfg.batch_size, + 1) + acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode) + print("============== {} ==============".format(acc)) + ``` + In the preceding information: + + `hub.load_weights` is an API for loading model parameters. PLease check the details in . + 2. Use the `model.predict` API to perform inference. ```python model.predict(input_data) @@ -106,4 +146,5 @@ Similar to the inference on a GPU, the following steps are required: ## On-Device Inference -MindSpore Predict is an inference engine for on-device inference. For details, see [On-Device Inference](https://www.mindspore.cn/tutorial/en/r0.6/advanced_use/on_device_inference.html). \ No newline at end of file +MindSpore Predict is an inference engine for on-device inference. For details, see [On-Device Inference](https://www.mindspore.cn/tutorial/en/r0.6/advanced_use/on_device_inference.html). + diff --git a/tutorials/source_zh_cn/use/multi_platform_inference.md b/tutorials/source_zh_cn/use/multi_platform_inference.md index 2f83db74..0645e351 100644 --- a/tutorials/source_zh_cn/use/multi_platform_inference.md +++ b/tutorials/source_zh_cn/use/multi_platform_inference.md @@ -53,14 +53,52 @@ CPU | ONNX格式 | 支持ONNX推理的runtime/SDK,如TensorRT。 ### 使用checkpoint格式文件推理 -1. 使用`model.eval`接口来进行模型验证,你只需传入验证数据集即可,验证数据集的处理方式与训练数据集相同。 +1. 使用`model.eval`接口来进行模型验证。 + + 1.1 模型已保存在本地 + + 首先构建模型,然后使用mindspore.train.serialization模块的`load_checkpoint`和`load_param_into_net`从本地加载模型与参数,传入验证数据集后即可进行模型推理,验证数据集的处理方式与训练数据集相同。 + ```python - res = model.eval(dataset) + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Testing ==============") + param_dict = load_checkpoint(args.ckpt_path) + load_param_into_net(network, param_dict) + dataset = create_dataset(os.path.join(args.data_path, "test"), + cfg.batch_size, + 1) + acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode) + print("============== {} ==============".format(acc)) ``` 其中, `model.eval`为模型验证接口,对应接口说明:。 > 推理样例代码:。 + 1.2 模型保存在华为云 + + 首先构建模型,然后使用`hub.load_weights`从云端加载模型参数,传入验证数据集后即可进行推理,验证数据集的处理方式与训练数据集相同。 + ```python + network = LeNet5(cfg.num_classes) + net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Testing ==============") + hub.load_weights(network, network_name="lenet", **{"device_target": + "ascend", "dataset":"mnist", "version": "0.5.0"}) + dataset = create_dataset(os.path.join(args.data_path, "test"), + cfg.batch_size, + 1) + acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode) + print("============== {} ==============".format(acc)) + ``` + 其中, + `hub.load_weights`为加载模型参数接口,对应接口说明:。 + 2. 使用`model.predict`接口来进行推理操作。 ```python model.predict(input_data) @@ -106,4 +144,4 @@ Ascend 310 AI处理器上搭载了ACL框架,他支持OM格式,而OM格式需 ## 端侧推理 -端侧推理需使用MindSpore Predict推理引擎,详细操作请参考[端侧推理教程](https://www.mindspore.cn/tutorial/zh-CN/r0.6/advanced_use/on_device_inference.html)。 \ No newline at end of file +端侧推理需使用MindSpore Predict推理引擎,详细操作请参考[端侧推理教程](https://www.mindspore.cn/tutorial/zh-CN/r0.6/advanced_use/on_device_inference.html)。 -- GitLab