提交 e15ca80a 编写于 作者: W wjj19950828

Update readme

上级 0e7312fe
...@@ -26,7 +26,7 @@ pytorch2paddle(module=torch_module, ...@@ -26,7 +26,7 @@ pytorch2paddle(module=torch_module,
# input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。 # input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。
``` ```
**注意** **注意:**
- jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静;
- jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 - jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。
...@@ -55,11 +55,12 @@ pytorch2paddle(torch_module, ...@@ -55,11 +55,12 @@ pytorch2paddle(torch_module,
input_examples=[torch.tensor(input_data)]) input_examples=[torch.tensor(input_data)])
``` ```
### Script 模式动态 shape 导出 ### 动态 shape 导出
#### 方式一:PyTorch->ONNX->Paddle
```python ```python
import torch import torch
import numpy as np
from torchvision.models import AlexNet from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url from torchvision.models.utils import load_state_dict_from_url
...@@ -69,15 +70,27 @@ torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models ...@@ -69,15 +70,27 @@ torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models
torch_module.load_state_dict(torch_state_dict) torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式 # 设置为eval模式
torch_module.eval() torch_module.eval()
# 进行转换 input_names = ["input_0"]
from x2paddle.convert import pytorch2paddle output_names = ["output_0"]
pytorch2paddle(torch_module,
save_dir="pd_model_script", x = torch.randn((1, 3, 224, 224))
jit_type="script", y = torch.randn((1, 1000))
input_examples=None)
torch.onnx.export(torch_module, x, 'model.onnx', opset_version=11, input_names=input_names,
output_names=output_names, dynamic_axes={'input_0': [0], 'output_0': [0]})
```
导出 ONNX 动态 shape 模型,更多细节参考[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export)
然后通过 X2Paddle 命令导出 Paddle 模型
```shell
x2paddle --framework=onnx --model=model.onnx --save_dir=pd_model_dynamic
``` ```
在自动生成的x2paddle_code.py中添加如下代码: #### 方式二:手动动转静
在自动生成的 x2paddle_code.py 中添加如下代码:
```python ```python
def main(x0): def main(x0):
...@@ -91,11 +104,11 @@ def main(x0): ...@@ -91,11 +104,11 @@ def main(x0):
sepc_list = list() sepc_list = list()
sepc_list.append( sepc_list.append(
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[-1, 3, -1, -1], name="x0", dtype="float32")) shape=[-1, 3, 224, 224], name="x0", dtype="float32"))
static_model = paddle.jit.to_static(model, input_spec=sepc_list) static_model = paddle.jit.to_static(model, input_spec=sepc_list)
paddle.jit.save(static_model, "pd_model_script/inference_model/model") paddle.jit.save(static_model, "pd_model_trace/inference_model/model")
out = model(x0) out = model(x0)
return out return out
``` ```
运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export) 然后运行 main 函数导出动态 shape 模型
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册