From 3b7e01b1640c2be1c16bc7c35dda03a6e86bce72 Mon Sep 17 00:00:00 2001 From: shiyutang <34859558+shiyutang@users.noreply.github.com> Date: Mon, 10 Jan 2022 17:25:48 +0800 Subject: [PATCH] add_torch2paddlepy (#5460) --- .../mobilenetv3_prod/Step1-5/torch2paddle.py | 34 +++++++++++++++++++ .../ArticleReproduction_CV.md | 5 ++- 2 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 tutorials/mobilenetv3_prod/Step1-5/torch2paddle.py diff --git a/tutorials/mobilenetv3_prod/Step1-5/torch2paddle.py b/tutorials/mobilenetv3_prod/Step1-5/torch2paddle.py new file mode 100644 index 00000000..2a1afe3f --- /dev/null +++ b/tutorials/mobilenetv3_prod/Step1-5/torch2paddle.py @@ -0,0 +1,34 @@ +import numpy as np +import torch +import paddle + + +def torch2paddle(): + torch_path = "./data/mobilenet_v3_small-047dcff4.pth" + paddle_path = "./data/mv3_small_paddle.pdparams" + torch_state_dict = torch.load(torch_path) + fc_names = ["classifier"] + paddle_state_dict = {} + for k in torch_state_dict: + if "num_batches_tracked" in k: + continue + v = torch_state_dict[k].detach().cpu().numpy() + flag = [i in k for i in fc_names] + if any(flag) and "weight" in k: # ignore bias + new_shape = [1, 0] + list(range(2, v.ndim)) + print( + f"name: {k}, ori shape: {v.shape}, new shape: {v.transpose(new_shape).shape}" + ) + v = v.transpose(new_shape) + k = k.replace("running_var", "_variance") + k = k.replace("running_mean", "_mean") + # if k not in model_state_dict: + if False: + print(k) + else: + paddle_state_dict[k] = v + paddle.save(paddle_state_dict, paddle_path) + + +if __name__ == "__main__": + torch2paddle() diff --git a/tutorials/tipc/train_infer_python/ArticleReproduction_CV.md b/tutorials/tipc/train_infer_python/ArticleReproduction_CV.md index b79c2849..81c68e12 100644 --- a/tutorials/tipc/train_infer_python/ArticleReproduction_CV.md +++ b/tutorials/tipc/train_infer_python/ArticleReproduction_CV.md @@ -218,14 +218,13 @@ MobilnetV3网络结构的PyTorch实现: [mobilenetv3_prod/Step1-5/mobilenetv3_re 将mobilenetv3-torch的[模型参数](https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth)保存在本地之后,就可以通过下面的权重转换示例进行转换: ```python - import numpy as np import torch import paddle def torch2paddle(): - torch_path = "./mobilenet_v3_small-047dcff4.pth" - paddle_path = "./mv3_small_paddle.pdparams" + torch_path = "./data/mobilenet_v3_small-047dcff4.pth" + paddle_path = "./data/mv3_small_paddle.pdparams" torch_state_dict = torch.load(torch_path) fc_names = ["classifier"] paddle_state_dict = {} -- GitLab