提交 edcbe6c1 编写于 作者: L liyanliu

update loading pre-trained models saved remotely using hub API for inference

上级 24899261
......@@ -53,14 +53,54 @@ MindSpore supports the following inference scenarios based on the hardware platf
### Inference Using a Checkpoint File
1. Input a validation dataset to validate a model using the `model.eval` API. The processing method of the validation dataset is the same as that of the training dataset.
1. Input a validation dataset to validate a model using the `model.eval` API.
1.1 Local Storage
When the pre-trained models are saved locally, the steps of performing inference on validation dataset are as follows: firstly creating a model, then loading model and parameters using `load_checkpoint` and `load_param_into_net` in `mindspore.train.serialization` module, and finally performing inference on validation dataset once created. The processing method of the validation dataset is the same as that of the training dataset.
```python
res = model.eval(dataset)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
dataset = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))
```
In the preceding information:
`model.eval` is an API for model validation. For details about the API, see <https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.html#mindspore.Model.eval>.
> Inference sample code: <https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/eval.py>.
1.2 Remote Storage
When the pre-trained models are saved remotely, the steps of performing inference on validation dataset are as follows: firstly creating a model, then loading model and parameters using `hub.load_weights`, and finally performing inference on validation dataset once created. The processing method of the validation dataset is the same as that of the training dataset.
```python
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
hub.load_weights(network, network_name="lenet", **{"device_target":
"ascend", "dataset":"mnist", "version": "0.5.0"})
dataset = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))
```
In the preceding information:
`hub.load_weights` is an API for loading model parameters. PLease check the details in <https://www.mindspore.cn/api/en/master/api/python/mindspore/mindspore.hub.html#mindspore.hub.load_weights>.
2. Use the `model.predict` API to perform inference.
```python
model.predict(input_data)
......
......@@ -53,14 +53,51 @@ CPU | ONNX格式 | 支持ONNX推理的runtime/SDK,如TensorRT。
### 使用checkpoint格式文件推理
1. 使用`model.eval`接口来进行模型验证,你只需传入验证数据集即可,验证数据集的处理方式与训练数据集相同。
1. 使用`model.eval`接口来进行模型验证。
1.1 模型已保存在本地
首先构建模型,然后使用`mindspore.train.serialization`模块的`load_checkpoint``load_param_into_net`从本地加载模型与参数,传入验证数据集后即可进行模型推理,验证数据集的处理方式与训练数据集相同。
```python
res = model.eval(dataset)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
dataset = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))
```
其中,
`model.eval`为模型验证接口,对应接口说明:<https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.html#mindspore.Model.eval>
> 推理样例代码:<https://gitee.com/mindspore/mindspore/blob/master/model_zoo/official/cv/lenet/eval.py>。
1.2 模型保存在华为云
首先构建模型,然后使用`hub.load_weights`从云端加载模型参数,传入验证数据集后即可进行推理,验证数据集的处理方式与训练数据集相同。
```python
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
hub.load_weights(network, network_name="lenet", **{"device_target":
"ascend", "dataset":"mnist", "version": "0.5.0"})
dataset = create_dataset(os.path.join(args.data_path, "test"),
cfg.batch_size,
1)
acc = model.eval(dataset, dataset_sink_mode=args.dataset_sink_mode)
print("============== {} ==============".format(acc))
```
其中,
`hub.load_weights`为加载模型参数接口,对应接口说明:<https://www.mindspore.cn/api/zh-CN/master/api/python/mindspore/mindspore.hub.html#mindspore.hub.load_weights>
2. 使用`model.predict`接口来进行推理操作。
```python
model.predict(input_data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册