未验证 提交 3b7e01b1 编写于 作者: S shiyutang 提交者: GitHub

add_torch2paddlepy (#5460)

上级 511e2e28
import numpy as np
import torch
import paddle
def torch2paddle():
torch_path = "./data/mobilenet_v3_small-047dcff4.pth"
paddle_path = "./data/mv3_small_paddle.pdparams"
torch_state_dict = torch.load(torch_path)
fc_names = ["classifier"]
paddle_state_dict = {}
for k in torch_state_dict:
if "num_batches_tracked" in k:
continue
v = torch_state_dict[k].detach().cpu().numpy()
flag = [i in k for i in fc_names]
if any(flag) and "weight" in k: # ignore bias
new_shape = [1, 0] + list(range(2, v.ndim))
print(
f"name: {k}, ori shape: {v.shape}, new shape: {v.transpose(new_shape).shape}"
)
v = v.transpose(new_shape)
k = k.replace("running_var", "_variance")
k = k.replace("running_mean", "_mean")
# if k not in model_state_dict:
if False:
print(k)
else:
paddle_state_dict[k] = v
paddle.save(paddle_state_dict, paddle_path)
if __name__ == "__main__":
torch2paddle()
......@@ -218,14 +218,13 @@ MobilnetV3网络结构的PyTorch实现: [mobilenetv3_prod/Step1-5/mobilenetv3_re
将mobilenetv3-torch的[模型参数](https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth)保存在本地之后,就可以通过下面的权重转换示例进行转换:
```python
import numpy as np
import torch
import paddle
def torch2paddle():
torch_path = "./mobilenet_v3_small-047dcff4.pth"
paddle_path = "./mv3_small_paddle.pdparams"
torch_path = "./data/mobilenet_v3_small-047dcff4.pth"
paddle_path = "./data/mv3_small_paddle.pdparams"
torch_state_dict = torch.load(torch_path)
fc_names = ["classifier"]
paddle_state_dict = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册