diff --git a/docs/pytorch_project_convertor/API_docs/vision/torchvision.models.md b/docs/pytorch_project_convertor/API_docs/vision/torchvision.models.md index e97f1adbfb041340b64d2b107767d8e4e415a995..2a3a8358b771bbae5bea3287d91db14b88aece5f 100644 --- a/docs/pytorch_project_convertor/API_docs/vision/torchvision.models.md +++ b/docs/pytorch_project_convertor/API_docs/vision/torchvision.models.md @@ -17,20 +17,20 @@ out = model(x) 目前支持的模型为: | PyTorch模型 | Paddle模型 | | ------------------------------------------------------------ | -------------------------------- | -| [torchvision.models.resnet18](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet18) | x2paddle.models.resnet18_pth | -| [torchvision.models.resnet34](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet34) | x2paddle.models.resnet34_pth | -| [torchvision.models.resnet50](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet50) | x2paddle.models.resnet50_pth | -| [torchvision.models.resnet101](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet101) | x2paddle.models.resnet101_pth | -| [torchvision.models.resnet152](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet152) | x2paddle.models.resnet152_pth | -| [torchvision.models.resnext50_32x4d](https://pytorch.org/vision/stable/models.html#torchvision.models.resnext50_32x4d) | x2paddle.models.resnext50_32x4d_pth | -| [torchvision.models.resnext101_32x8d](https://pytorch.org/vision/stable/models.html#torchvision.models.resnext101_32x8d) | x2paddle.resnext101_32x8d_pth | -| [torchvision.models.wide_resnet50_2](https://pytorch.org/vision/stable/models.html#torchvision.models.wide_resnet50_2) | x2paddle.models.wide_resnet50_2_pth | -| [torchvision.models.wide_resnet101_2](https://pytorch.org/vision/stable/models.html#torchvision.models.wide_resnet101_2) | x2paddle.models.wide_resnet101_2_pth | -| [torchvision.models.vgg11](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg11) | x2paddle.models.vgg11_pth | -| [torchvision.models.vgg11_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg11_bn) | x2paddle.models.vgg11_bn_pth | -| [torchvision.models.vgg13](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg13) | x2paddle.models.vgg13_pth | -| [torchvision.models.vgg13_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg13_bn) | x2paddle.models.vgg13_bn_pth | -| [torchvision.models.vgg16](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg16) | x2paddle.models.vgg16_pth | -| [torchvision.models.vgg16_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg16_bn) | x2paddle.models.vgg16_bn_pth | -| [torchvision.models.vgg19](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg19) | x2paddle.models.vgg19_pth | -| [torchvision.models.vgg19_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg19_bn) | x2paddle.models.vgg19_bn_pth | +| [torchvision.models.resnet18](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet18) | [x2paddle.models.resnet18_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L310) | +| [torchvision.models.resnet34](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet34) | [x2paddle.models.resnet34_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L322) | +| [torchvision.models.resnet50](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet50) | [x2paddle.models.resnet50_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L334) | +| [torchvision.models.resnet101](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet101) | [x2paddle.models.resnet101_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L346) | +| [torchvision.models.resnet152](https://pytorch.org/vision/stable/models.html#torchvision.models.resnet152) | [x2paddle.models.resnet152_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L358) | +| [torchvision.models.resnext50_32x4d](https://pytorch.org/vision/stable/models.html#torchvision.models.resnext50_32x4d) | [x2paddle.models.resnext50_32x4d_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L370) | +| [torchvision.models.resnext101_32x8d](https://pytorch.org/vision/stable/models.html#torchvision.models.resnext101_32x8d) | [x2paddle.resnext101_32x8d_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L385) | +| [torchvision.models.wide_resnet50_2](https://pytorch.org/vision/stable/models.html#torchvision.models.wide_resnet50_2) | [x2paddle.models.wide_resnet50_2_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L400) | +| [torchvision.models.wide_resnet101_2](https://pytorch.org/vision/stable/models.html#torchvision.models.wide_resnet101_2) | [x2paddle.models.wide_resnet101_2_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/resnet.py#L419) | +| [torchvision.models.vgg11](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg11) | [x2paddle.models.vgg11_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L117) | +| [torchvision.models.vgg11_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg11_bn) | [x2paddle.models.vgg11_bn_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L128) | +| [torchvision.models.vgg13](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg13) | [x2paddle.models.vgg13_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L139) | +| [torchvision.models.vgg13_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg13_bn) | [x2paddle.models.vgg13_bn_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L150) | +| [torchvision.models.vgg16](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg16) | [x2paddle.models.vgg16_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L161) | +| [torchvision.models.vgg16_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg16_bn) | [x2paddle.models.vgg16_bn_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L172) | +| [torchvision.models.vgg19](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg19) | [x2paddle.models.vgg19_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L183) | +| [torchvision.models.vgg19_bn](https://pytorch.org/vision/stable/models.html#torchvision.models.vgg19_bn) | [x2paddle.models.vgg19_bn_pth](https://github.com/PaddlePaddle/X2Paddle/blob/develop/x2paddle/project_convertor/pytorch/models/vgg.py#L194) | diff --git a/x2paddle/utils.py b/x2paddle/utils.py index 37409cdb2d187d3b5e077ef92e3edc9f6b161db0..3747916bd8b1de40c0b66b2878b5b6bd2bcb6847 100644 --- a/x2paddle/utils.py +++ b/x2paddle/utils.py @@ -45,16 +45,15 @@ class PaddleDtypes(): self.t_int64 = paddle.int64 self.t_bool = paddle.bool else: - from paddle.fluid.core import VarDesc - self.t_float16 = VarDesc.VarType.FP16 - self.t_float32 = VarDesc.VarType.FP32 - self.t_float64 = VarDesc.VarType.FP64 - self.t_uint8 = VarDesc.VarType.UINT8 - self.t_int8 = VarDesc.VarType.INT8 - self.t_int16 = VarDesc.VarType.INT16 - self.t_int32 = VarDesc.VarType.INT32 - self.t_int64 = VarDesc.VarType.INT64 - self.t_bool = VarDesc.VarType.BOOL + self.t_float16 = "paddle.fluid.core.VarDesc.VarType.FP16" + self.t_float32 = "paddle.fluid.core.VarDesc.VarType.FP32" + self.t_float64 = "paddle.fluid.core.VarDesc.VarType.FP64" + self.t_uint8 = "paddle.fluid.core.VarDesc.VarType.UINT8" + self.t_int8 = "paddle.fluid.core.VarDesc.VarType.INT8" + self.t_int16 = "paddle.fluid.core.VarDesc.VarType.INT16" + self.t_int32 = "paddle.fluid.core.VarDesc.VarType.INT32" + self.t_int64 = "paddle.fluid.core.VarDesc.VarType.INT64" + self.t_bool = "paddle.fluid.core.VarDesc.VarType.BOOL" is_new_version = check_version()