From b793e5b3ec4a901e89316c957129f0698b470d61 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 21:15:25 +0800 Subject: [PATCH] Add dynamic shape --- .../pytorch2paddle.md | 54 +++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index 74976b5..f4840fb 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -14,7 +14,7 @@ treelib ## 使用方式 -``` python +```python from x2paddle.convert import pytorch2paddle pytorch2paddle(module=torch_module, save_dir="./pd_model", @@ -27,11 +27,14 @@ pytorch2paddle(module=torch_module, ``` **注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - 当jit_type为"script"时",input_examples不为None时,才可以进行动转静。 + + 当jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 ## 使用示例 -``` python +### Trace 模式 + +```python import torch import numpy as np from torchvision.models import AlexNet @@ -51,3 +54,48 @@ pytorch2paddle(torch_module, jit_type="trace", input_examples=[torch.tensor(input_data)]) ``` + +### Script 模式动态 shape 导出 + +```python +import torch +import numpy as np +from torchvision.models import AlexNet +from torchvision.models.utils import load_state_dict_from_url + +# 获取PyTorch Module +torch_module = AlexNet() +torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth') +torch_module.load_state_dict(torch_state_dict) +# 设置为eval模式 +torch_module.eval() +# 进行转换 +from x2paddle.convert import pytorch2paddle +pytorch2paddle(torch_module, + save_dir="pd_model_script", + jit_type="script", + input_examples=None) +``` + +在自动生成的x2paddle_code.py中添加如下代码: + +```python +def main(x0): + # There are 0 inputs. + paddle.disable_static() + params = paddle.load('model.pdparams') + model = AlexNet() + model.set_dict(params) + model.eval() + ## convert to jit + sepc_list = list() + sepc_list.append( + paddle.static.InputSpec( + shape=[-1, 3, -1, -1], name="x0", dtype="float32")) + static_model = paddle.jit.to_static(model, input_spec=sepc_list) + paddle.jit.save(static_model, "pd_model_script/inference_model/model") + out = model(x0) + return out +``` + +运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export) -- GitLab