pytorch2paddle.md 3.6 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7
# PyTorch2Paddle

PyTorch2Paddle支持trace和script两种方式的转换,均是PyTorch动态图到Paddle动态图的转换,转换后的Paddle动态图运用动转静可转换为静态图模型。trace方式生成的代码可读性较强,较为接近原版PyTorch代码的组织结构;script方式不需要知道输入数据的类型和大小即可转换,使用上较为方便,但目前PyTorch支持的script代码方式有所限制,所以支持转换的代码也有所限制。用户可根据自身需求,选择转换方式。

## 环境依赖

python == 2.7 | python >= 3.5  
8
paddlepaddle >= 2.0.0  
S
SunAhong1993 已提交
9 10
pytorch:torch >=1.5.0 (script方式暂不支持1.7.0)

11
**使用trace方式需安装以下依赖**
S
SunAhong1993 已提交
12
pandas
13
treelib
S
SunAhong1993 已提交
14 15 16

## 使用方式

W
wjj19950828 已提交
17
```python
S
SunAhong1993 已提交
18
from x2paddle.convert import pytorch2paddle
19 20 21
pytorch2paddle(module=torch_module,
               save_dir="./pd_model",
               jit_type="trace",
S
SunAhong1993 已提交
22 23 24 25 26 27 28
               input_examples=[torch_input])
# module (torch.nn.Module): PyTorch的Module。
# save_dir (str): 转换后模型的保存路径。
# jit_type (str): 转换方式。默认为"trace"。
# input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。
```

W
wjj19950828 已提交
29 30 31
**注意:**
- jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静;
- jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。
S
SunAhong1993 已提交
32 33 34

## 使用示例

W
wjj19950828 已提交
35 36 37
### Trace 模式

```python
S
SunAhong1993 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51
import torch
import numpy as np
from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url
# 构建输入
input_data = np.random.rand(1, 3, 224, 224).astype("float32")
# 获取PyTorch Module
torch_module = AlexNet()
torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
52 53 54
pytorch2paddle(torch_module,
               save_dir="pd_model_trace",
               jit_type="trace",
S
SunAhong1993 已提交
55
               input_examples=[torch.tensor(input_data)])
S
SunAhong1993 已提交
56
```
W
wjj19950828 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101

### Script 模式动态 shape 导出

```python
import torch
import numpy as np
from torchvision.models import AlexNet
from torchvision.models.utils import load_state_dict_from_url

# 获取PyTorch Module
torch_module = AlexNet()
torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth')
torch_module.load_state_dict(torch_state_dict)
# 设置为eval模式
torch_module.eval()
# 进行转换
from x2paddle.convert import pytorch2paddle
pytorch2paddle(torch_module,
               save_dir="pd_model_script",
               jit_type="script",
               input_examples=None)
```

在自动生成的x2paddle_code.py中添加如下代码:

```python
def main(x0):
    # There are 0 inputs.
    paddle.disable_static()
    params = paddle.load('model.pdparams')
    model = AlexNet()
    model.set_dict(params)
    model.eval()
    ## convert to jit
    sepc_list = list()
    sepc_list.append(
            paddle.static.InputSpec(
                shape=[-1, 3, -1, -1], name="x0", dtype="float32"))
    static_model = paddle.jit.to_static(model, input_spec=sepc_list)
    paddle.jit.save(static_model, "pd_model_script/inference_model/model")
    out = model(x0)
    return out
```

运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export)