From 925c686062221c12a16a3d08b95a492516ed01ce Mon Sep 17 00:00:00 2001 From: SunAhong1993 <48579383+SunAhong1993@users.noreply.github.com> Date: Wed, 26 May 2021 19:50:19 +0800 Subject: [PATCH] =?UTF-8?q?Add=20PyTorch=20Project=20Convertor=E2=80=94?= =?UTF-8?q?=E2=80=94CRAFT=20(#606)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix the code * fix the visit_tuple * Update stargan.md * Update ultra_light_fast_generic_face_detector.md * fix the docs * remove static * fix * fix * fix * fix the docs * fix the project_convertor * remove * fix nn.py * modify the doc * modify the doc * fix Co-authored-by: channingss --- docs/pytorch_project_convertor/demo/README.md | 1 + .../project_convertor/pytorch/api_mapper/nn.py | 8 ++++---- .../project_convertor/pytorch/api_mapper/ops.py | 16 +++++++++++----- x2paddle/project_convertor/pytorch/ast_update.py | 2 ++ .../project_convertor/pytorch/models/resnet.py | 6 +++--- x2paddle/project_convertor/pytorch/models/vgg.py | 8 ++++---- .../project_convertor/pytorch/torch2paddle/nn.py | 13 +++++++++++++ 7 files changed, 38 insertions(+), 16 deletions(-) diff --git a/docs/pytorch_project_convertor/demo/README.md b/docs/pytorch_project_convertor/demo/README.md index f8092a0..8a70fd2 100644 --- a/docs/pytorch_project_convertor/demo/README.md +++ b/docs/pytorch_project_convertor/demo/README.md @@ -5,5 +5,6 @@ |------|-----|----------|------| | StaGAN | [demo](stargan.md)| [code](https://github.com/yunjey/stargan)|[code](https://github.com/SunAhong1993/stargan/tree/paddle)| | Ultra-Light-Fast-Generic-Face-Detector |[demo](ultra_light_fast_generic_face_detector.md)| [code](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB) |[code](https://github.com/SunAhong1993/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/paddle)| +| CRAFT-pytorch |[demo](craft.md)| [code](https://github.com/clovaai/CRAFT-pytorch) |[code](https://github.com/SunAhong1993/CRAFT-pytorch/tree/paddle)| ***持续更新...*** diff --git a/x2paddle/project_convertor/pytorch/api_mapper/nn.py b/x2paddle/project_convertor/pytorch/api_mapper/nn.py index d80f5d9..7b85f3a 100644 --- a/x2paddle/project_convertor/pytorch/api_mapper/nn.py +++ b/x2paddle/project_convertor/pytorch/api_mapper/nn.py @@ -386,10 +386,10 @@ class ReLUModuleMapper(Mapper): target_name=None): super().__init__(func_name, pytorch_api_name, args, kwargs, target_name) - def delete_attrs(self): - delete_key(self.kwargs, "inplace") - if len(self.args) > 0: - self.args.clear() + def run(self): + if len(self.args) > 0 or len(self.kwargs) > 0: + self.func_name = "x2paddle.torch2paddle.ReLU" + return super().run() class SoftmaxModuleMapper(Mapper): diff --git a/x2paddle/project_convertor/pytorch/api_mapper/ops.py b/x2paddle/project_convertor/pytorch/api_mapper/ops.py index 619c786..bdb48a0 100644 --- a/x2paddle/project_convertor/pytorch/api_mapper/ops.py +++ b/x2paddle/project_convertor/pytorch/api_mapper/ops.py @@ -93,6 +93,8 @@ class SetDeviceMapper(Mapper): def run(self): self.process_attrs() insert_codes = list() + if self.target_name is None: + self.target_name = "tmp" insert_codes.append("{} = {}".format(self.target_name, self.useful_attrs["device"])) insert_codes.append("{} = {}.replace('cuda', 'gpu')".format( @@ -539,19 +541,23 @@ class LinspaceMapper(Mapper): out2, self.useful_attrs["requires_grad"]) return out1, out2, out3 + class ToTensorMapper(Mapper): - def __init__(self, + def __init__(self, func_name, - pytorch_api_name, - args, kwargs, + pytorch_api_name, + args, + kwargs, target_name=None): super().__init__(func_name, pytorch_api_name, args, kwargs, target_name) def process_attrs(self): rename_key(self.kwargs, "device", "place") + def run(self): if self.rename_func_name("paddle.to_tensor"): - return [], generate_api_code(self.func_name, self.args, self.kwargs), [] + return [], generate_api_code(self.func_name, self.args, + self.kwargs), [] else: self.convert_args2kwargs() - return self.convert_to_paddle() \ No newline at end of file + return self.convert_to_paddle() diff --git a/x2paddle/project_convertor/pytorch/ast_update.py b/x2paddle/project_convertor/pytorch/ast_update.py index 8d2caf0..9e56ee1 100644 --- a/x2paddle/project_convertor/pytorch/ast_update.py +++ b/x2paddle/project_convertor/pytorch/ast_update.py @@ -497,6 +497,8 @@ class AstUpdater(ast.NodeVisitor): mapper = API_MAPPER[pytorch_api][1](func_name, pytorch_api, args_list, kw_dict, target_name) prefix_insert_codes, new_code, suffix_insert_codes = mapper.run() + if mapper.func_name.startswith("x2paddle."): + self.is_import_x2paddle = True scope_node = self._get_scope_node() if isinstance(ast.parse(new_code).body[0], ast.Assign): node_index = self._get_current_index(scope_node, node) diff --git a/x2paddle/project_convertor/pytorch/models/resnet.py b/x2paddle/project_convertor/pytorch/models/resnet.py index 83eb628..5e86756 100644 --- a/x2paddle/project_convertor/pytorch/models/resnet.py +++ b/x2paddle/project_convertor/pytorch/models/resnet.py @@ -79,7 +79,7 @@ class BasicBlock(nn.Layer): # Both self.conv1 and self.downsample layers downsample the input when stride != 1 self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) - self.relu = nn.ReLU() + self.relu = torch2paddle.ReLU(True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) self.downsample = downsample @@ -133,7 +133,7 @@ class Bottleneck(nn.Layer): self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) - self.relu = nn.ReLU() + self.relu = torch2paddle.ReLU(True) self.downsample = downsample self.stride = stride @@ -195,7 +195,7 @@ class ResNet(nn.Layer): padding=3, bias_attr=False) self.bn1 = norm_layer(self.inplanes) - self.relu = nn.ReLU() + self.relu = torch2paddle.ReLU(True) self.maxpool = nn.MaxPool2D(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer( diff --git a/x2paddle/project_convertor/pytorch/models/vgg.py b/x2paddle/project_convertor/pytorch/models/vgg.py index 5e4316f..5393174 100644 --- a/x2paddle/project_convertor/pytorch/models/vgg.py +++ b/x2paddle/project_convertor/pytorch/models/vgg.py @@ -39,10 +39,10 @@ class VGG(nn.Layer): self.avgpool = nn.AdaptiveAvgPool2D((7, 7)) self.classifier = nn.Sequential( nn.Linear(512 * 7 * 7, 4096), - nn.ReLU(), + torch2paddle.ReLU(True), nn.Dropout(), nn.Linear(4096, 4096), - nn.ReLU(), + torch2paddle.ReLU(True), nn.Dropout(), nn.Linear(4096, num_classes), ) if init_weights: @@ -81,9 +81,9 @@ def make_layers(cfg: List[Union[str, int]], v = cast(int, v) conv2d = nn.Conv2D(in_channels, v, kernel_size=3, padding=1) if batch_norm: - layers += [conv2d, nn.BatchNorm2D(v), nn.ReLU()] + layers += [conv2d, nn.BatchNorm2D(v), torch2paddle.ReLU(True)] else: - layers += [conv2d, nn.ReLU()] + layers += [conv2d, torch2paddle.ReLU(True)] in_channels = v return nn.Sequential(*layers) diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/nn.py b/x2paddle/project_convertor/pytorch/torch2paddle/nn.py index 79b0762..9b3fb9b 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/nn.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/nn.py @@ -529,6 +529,19 @@ class MaxUnpool2D(paddle.nn.Layer): return out +class ReLU(paddle.nn.ReLU): + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + def forward(self, x): + if self.inplace: + out = paddle.nn.functional.relu_(x) + else: + out = super().forward(x) + return out + + class ReflectionPad2D(paddle.nn.Pad2D): def __init__(self, padding): super().__init__(padding, mode="reflect") -- GitLab