From d77cf3bc9302bb0cce10ae16e5014e6e14983c20 Mon Sep 17 00:00:00 2001 From: lijianshe02 Date: Wed, 14 Oct 2020 11:12:03 +0000 Subject: [PATCH] refine psgan code according to the reviewer's comment --- applications/run.sh | 9 - configs/makeup.yaml | 22 +- ppgan/datasets/makeup_dataset.py | 22 +- ppgan/faceutils/dlibutils/__init__.py | 4 +- .../dlibutils/{main.py => dlib_utils.py} | 1 - ppgan/faceutils/image.py | 17 -- ppgan/faceutils/mask/__init__.py | 2 +- .../mask/{main.py => face_parser.py} | 4 - ppgan/faceutils/mask/model.py | 32 +-- ppgan/models/discriminators/nlayers.py | 95 ++++++--- ppgan/models/makeup_model.py | 13 +- ppgan/models/vgg.py | 197 ++---------------- ppgan/modules/norm.py | 3 + 13 files changed, 105 insertions(+), 316 deletions(-) delete mode 100644 applications/run.sh rename ppgan/faceutils/dlibutils/{main.py => dlib_utils.py} (99%) rename ppgan/faceutils/mask/{main.py => face_parser.py} (90%) diff --git a/applications/run.sh b/applications/run.sh deleted file mode 100644 index 8dcc819..0000000 --- a/applications/run.sh +++ /dev/null @@ -1,9 +0,0 @@ -# Model description -# Currently includes DAIN (frame interpolation), DeOldify (colorization), DeepRemaster (denoising and colorization), EDVR (video super-resolution from consecutive frames), and RealSR (single-image super-resolution) -# Parameter description -# input: path to the input video -# output: path where the output video is saved -# proccess_order: the models to use, in order - -python tools/video-enhance.py \ ---input input.mp4 --output output --proccess_order DeOldify RealSR diff --git a/configs/makeup.yaml b/configs/makeup.yaml index b56678d..34a65ec 100644 --- a/configs/makeup.yaml +++ b/configs/makeup.yaml @@ -1,20 +1,6 @@ -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -epochs: 200 -isTrain: False -output_dir: tmp +epochs: 100 +isTrain: True +output_dir: degrade_match checkpoints_dir: checkpoints lambda_A: 10.0 lambda_B: 10.0 @@ -31,7 +17,7 @@ model: ndf: 64 n_layers: 3 input_nc: 3 - norm_type: batch + norm_type: spectral gan_mode: lsgan dataset: diff --git a/ppgan/datasets/makeup_dataset.py b/ppgan/datasets/makeup_dataset.py index c5394ab..cf8ac8d 100644 --- a/ppgan/datasets/makeup_dataset.py +++ b/ppgan/datasets/makeup_dataset.py @@ -17,7 +17,6 @@ import os.path from .base_dataset import BaseDataset, get_transform from .transforms.makeup_transforms import get_makeup_transform import paddle.vision.transforms as T -from .image_folder import make_dataset from PIL import Image import random import numpy as np @@ -28,15 +27,6 @@ from .builder import DATASETS @DATASETS.register() class MakeupDataset(BaseDataset): - """ - This dataset class can load unaligned/unpaired datasets. - - It requires two directories to host training images from domain A '/path/to/data/trainA' - and from domain B '/path/to/data/trainB' respectively. - You can train the model with the dataset flag '--dataroot /path/to/data'. - Similarly, you need to prepare two directories: - '/path/to/data/testA' and '/path/to/data/testB' during test time. - """ def __init__(self, cfg): """Initialize this dataset class. 
@@ -83,16 +73,12 @@ class MakeupDataset(BaseDataset): getattr(self, cls + "_lmks_filenames").append(splits[2]) def __getitem__(self, index): - """Return a data point and its metadata information. + """Return the parameters needed by MANet and MDNet. Parameters: index (int) -- a random integer for data indexing - Returns a dictionary that contains A, B, A_paths and B_paths - A (tensor) -- an image in the input domain - B (tensor) -- its corresponding image in the target domain - A_paths (str) -- image paths - B_paths (str) -- image paths + Returns a dictionary containing the needed parameters. """ try: index_A = random.randint( @@ -125,15 +111,11 @@ class MakeupDataset(BaseDataset): self.image_path, getattr(self, self.cls_B + "_mask_filenames")[index_B])).convert('L')) - #image_A.paste((200,200,200), (0,0), Image.fromarray(np.uint8(255*(np.array(mask_A)==0)))) - #image_B.paste((200,200,200), (0,0), Image.fromarray(np.uint8(255*(np.array(mask_B)==0)))) image_A = np.array(image_A) image_B = np.array(image_B) - print('image shape: ', image_A.shape) image_A = self.transform(image_A) image_B = self.transform(image_B) - print('image shape: ', image_A.shape) mask_A = cv2.resize(mask_A, (256, 256), interpolation=cv2.INTER_NEAREST) diff --git a/ppgan/faceutils/dlibutils/__init__.py b/ppgan/faceutils/dlibutils/__init__.py index cdae1ef..b56699f 100644 --- a/ppgan/faceutils/dlibutils/__init__.py +++ b/ppgan/faceutils/dlibutils/__init__.py @@ -1,3 +1 @@ -#!/usr/bin/python -# -*- encoding: utf-8 -*- -from .main import detect, crop, landmarks, crop_from_array +from .dlib_utils import detect, crop, landmarks, crop_from_array diff --git a/ppgan/faceutils/dlibutils/main.py b/ppgan/faceutils/dlibutils/dlib_utils.py similarity index 99% rename from ppgan/faceutils/dlibutils/main.py rename to ppgan/faceutils/dlibutils/dlib_utils.py index 1df3c64..8f1fb87 100644 --- a/ppgan/faceutils/dlibutils/main.py +++ b/ppgan/faceutils/dlibutils/dlib_utils.py @@ -61,7 +61,6 @@ def crop(image: Image, face, up_ratio, down_ratio, width_ratio): face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom) center = face_expand.center() width, height = image.size - # import ipdb; ipdb.set_trace() crop_left = img_left crop_top = img_top crop_right = img_right diff --git a/ppgan/faceutils/image.py b/ppgan/faceutils/image.py index 4583af9..aed144f 100644 --- a/ppgan/faceutils/image.py +++ b/ppgan/faceutils/image.py @@ -3,16 +3,6 @@ import cv2 from io import BytesIO -def load_image(path): - with path.open("rb") as reader: - data = np.fromstring(reader.read(), dtype=np.uint8) - img = cv2.imdecode(data, cv2.IMREAD_COLOR) - if img is None: - return - img = img[..., ::-1] - return img - - def resize_by_max(image, max_side=512, force=False): h, w = image.shape[:2] if max(h, w) < max_side and not force: @@ -22,10 +12,3 @@ def resize_by_max(image, max_side=512, force=False): w = int(w / ratio + 0.5) h = int(h / ratio + 0.5) return cv2.resize(image, (w, h)) - - -def image2buffer(image): - is_success, buffer = cv2.imencode(".jpg", image) - if not is_success: - return None - return BytesIO(buffer) diff --git a/ppgan/faceutils/mask/__init__.py b/ppgan/faceutils/mask/__init__.py index f9ae2b1..15c2e99 100644 --- a/ppgan/faceutils/mask/__init__.py +++ b/ppgan/faceutils/mask/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .main import FaceParser +from .face_parser import FaceParser diff --git a/ppgan/faceutils/mask/main.py b/ppgan/faceutils/mask/face_parser.py similarity index 90% rename from ppgan/faceutils/mask/main.py rename to ppgan/faceutils/mask/face_parser.py index dd00549..94e58e9 100644 --- a/ppgan/faceutils/mask/main.py +++ b/ppgan/faceutils/mask/face_parser.py @@ -67,9 +67,5 @@ class FaceParser: for j in range(w): result[i][j] = self.mapper[parse_np[i][j]] - with open('/workspace/PaddleGAN/mapper_out.pkl', 'rb') as f: - torch_out = pickle.load(f) - cm = np.allclose(torch_out, result) - print('cm out: ', cm) result = paddle.to_tensor(result).astype('float32') return result diff --git a/ppgan/faceutils/mask/model.py b/ppgan/faceutils/mask/model.py index c3516af..4cd665c 100644 --- a/ppgan/faceutils/mask/model.py +++ b/ppgan/faceutils/mask/model.py @@ -40,7 +40,6 @@ class ConvBNReLU(paddle.nn.Layer): bias_attr=False) self.bn = nn.BatchNorm2d(out_chan) self.relu = nn.ReLU() - #self.init_weight() def forward(self, x): x = self.conv(x) @@ -57,7 +56,6 @@ class BiSeNetOutput(paddle.nn.Layer): n_classes, kernel_size=1, bias_attr=False) - #self.init_weight() def forward(self, x): x = self.conv(x) @@ -75,11 +73,9 @@ class AttentionRefinementModule(paddle.nn.Layer): bias_attr=False) self.bn_atten = nn.BatchNorm(out_chan) self.sigmoid_atten = nn.Sigmoid() - #self.init_weight() def forward(self, x): feat = self.conv(x) - #atten = F.avg_pool2d(feat, feat.size()[2:]) atten = F.avg_pool2d(feat, feat.shape[2:]) atten = self.conv_atten(atten) atten = self.bn_atten(atten) @@ -87,12 +83,6 @@ class AttentionRefinementModule(paddle.nn.Layer): out = feat * atten return out - #def init_weight(self): - # for ly in self.children(): - # if isinstance(ly, nn.Conv2d): - # nn.init.kaiming_normal_(ly.weight, a=1) - # if not ly.bias is None: nn.init.constant_(ly.bias, 0) - class ContextPath(paddle.nn.Layer): def __init__(self, *args, **kwargs): @@ -104,37 +94,30 @@ class ContextPath(paddle.nn.Layer): self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) - #self.init_weight() - def forward(self, x): H0, W0 = x.shape[2:] feat8, feat16, feat32 = self.resnet(x) H8, W8 = feat8.shape[2:] H16, W16 = feat16.shape[2:] H32, W32 = feat32.shape[2:] - print('feat32.shape: ', feat32.shape[2:]) avg = F.avg_pool2d(feat32, feat32.shape[2:]) avg = self.conv_avg(avg) - #avg_up = F.interpolate(avg, (H32, W32), mode='nearest') - avg_up = F.resize_nearest(avg, out_shape=(H32, W32)) + avg_up = F.interpolate(avg, size=(H32, W32), mode='nearest') feat32_arm = self.arm32(feat32) feat32_sum = feat32_arm + avg_up - #feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest') - feat32_up = F.resize_nearest(feat32_sum, out_shape=(H16, W16)) + feat32_up = F.interpolate(feat32_sum, size=(H16, W16), mode='nearest') feat32_up = self.conv_head32(feat32_up) feat16_arm = self.arm16(feat16) feat16_sum = feat16_arm + feat32_up - #feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest') - feat16_up = F.resize_nearest(feat16_sum, out_shape=(H8, W8)) + feat16_up = F.interpolate(feat16_sum, size=(H8, W8), mode='nearest') feat16_up = self.conv_head16(feat16_up) return feat8, feat16_up, feat32_up # x8, x8, x16 -### This is not used, since I replace this with the resnet feature with the same size class SpatialPath(paddle.nn.Layer): def __init__(self, *args, **kwargs): super(SpatialPath, self).__init__() @@ -142,7 +125,6 @@ class SpatialPath(paddle.nn.Layer): self.conv2 = 
ConvBNReLU(64, 64, ks=3, stride=2, padding=1) self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1) self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0) - #self.init_weight() def forward(self, x): feat = self.conv1(x) @@ -174,7 +156,6 @@ class FeatureFusionModule(paddle.nn.Layer): def forward(self, fsp, fcp): fcat = paddle.concat([fsp, fcp], axis=1) feat = self.convblk(fcat) - #atten = F.avg_pool2d(feat, feat.size()[2:]) atten = F.avg_pool2d(feat, feat.shape[2:]) atten = self.conv1(atten) atten = self.relu(atten) @@ -189,7 +170,6 @@ class BiSeNet(paddle.nn.Layer): def __init__(self, n_classes, *args, **kwargs): super(BiSeNet, self).__init__() self.cp = ContextPath() - ## here self.sp is deleted self.ffm = FeatureFusionModule(256, 256) self.conv_out = BiSeNetOutput(256, 256, n_classes) self.conv_out16 = BiSeNetOutput(128, 64, n_classes) @@ -206,9 +186,9 @@ class BiSeNet(paddle.nn.Layer): feat_out16 = self.conv_out16(feat_cp8) feat_out32 = self.conv_out32(feat_cp16) - feat_out = F.resize_bilinear(feat_out, out_shape=(H, W)) - feat_out16 = F.resize_bilinear(feat_out16, out_shape=(H, W)) - feat_out32 = F.resize_bilinear(feat_out32, out_shape=(H, W)) + feat_out = F.interpolate(feat_out, size=(H, W), mode='bilinear') + feat_out16 = F.interpolate(feat_out16, size=(H, W), mode='bilinear') + feat_out32 = F.interpolate(feat_out32, size=(H, W), mode='bilinear') return feat_out, feat_out16, feat_out32 diff --git a/ppgan/models/discriminators/nlayers.py b/ppgan/models/discriminators/nlayers.py index 7cc0ccc..741233e 100644 --- a/ppgan/models/discriminators/nlayers.py +++ b/ppgan/models/discriminators/nlayers.py @@ -25,7 +25,7 @@ from .builder import DISCRIMINATORS @DISCRIMINATORS.register() -class NLayerDiscriminator(paddle.nn.Layer): +class NLayerDiscriminator(nn.Layer): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'): """Construct a PatchGAN discriminator @@ -47,53 +47,98 @@ class NLayerDiscriminator(paddle.nn.Layer): kw = 4 padw = 1 - #sequence = [Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.01)] - sequence = [ - Spectralnorm( - Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)), - nn.LeakyReLU(0.01) - ] + + if norm_type == 'spectral': + sequence = [ + Spectralnorm( + Conv2d(input_nc, + ndf, + kernel_size=kw, + stride=2, + padding=padw)), + nn.LeakyReLU(0.01) + ] + else: + sequence = [ + Conv2d(input_nc, + ndf, + kernel_size=kw, + stride=2, + padding=padw, + bias_attr=use_bias), + nn.LeakyReLU(0.2) + ] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2**n, 8) + if norm_type == 'spectral': + sequence += [ + Spectralnorm( + Conv2d(ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw)), + nn.LeakyReLU(0.01) + ] + else: + sequence += [ + Conv2d(ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias_attr=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + if norm_type == 'spectral': sequence += [ - #Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias_attr=use_bias), Spectralnorm( Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, - stride=2, + stride=1, padding=padw)), - #norm_layer(ndf * nf_mult), nn.LeakyReLU(0.01) ] - - nf_mult_prev = nf_mult - nf_mult = min(2**n_layers, 8) - sequence += [ - #Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, 
padding=padw, bias_attr=use_bias), - Spectralnorm( + else: + sequence += [ Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, - padding=padw)), - #norm_layer(ndf * nf_mult), - nn.LeakyReLU(0.01) - ] + padding=padw, + bias_attr=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2) + ] - #sequence += [Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map - sequence += [ - Spectralnorm( + if norm_type == 'spectral': + sequence += [ + Spectralnorm( + Conv2d(ndf * nf_mult, + 1, + kernel_size=kw, + stride=1, + padding=padw, + bias_attr=False)) + ] # output 1 channel prediction map + else: + sequence += [ Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw, - bias_attr=False)) - ] # output 1 channel prediction map + bias_attr=False) + ] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) def forward(self, input): diff --git a/ppgan/models/makeup_model.py b/ppgan/models/makeup_model.py index 1e11e1a..97bda2a 100644 --- a/ppgan/models/makeup_model.py +++ b/ppgan/models/makeup_model.py @@ -132,7 +132,6 @@ class MakeupModel(BaseModel): The option 'direction' can be used to swap domain A and domain B. """ self.real_A = paddle.to_tensor(input['image_A']) - print('real_a shape: ', self.real_A.shape) self.real_B = paddle.to_tensor(input['image_B']) self.c_m = paddle.to_tensor(input['consis_mask']) self.P_A = paddle.to_tensor(input['P_A']) @@ -338,9 +337,9 @@ class MakeupModel(BaseModel): g_B_eye_loss_his = self.criterionL1(fake_B_eye_masked, fake_match_eye_B) self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his + - g_A_skin_loss_his * 0.1) * 0.1 + g_A_skin_loss_his * 0.1) * 0.01 self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his + - g_B_skin_loss_his * 0.1) * 0.1 + g_B_skin_loss_his * 0.1) * 0.01 #self.loss_G_A_his = self.criterionL1(tmp_1, tmp_2) * 2048 * 255 #tmp_3 = self.hm_gt_B*self.hm_mask_weight_B @@ -361,8 +360,8 @@ class MakeupModel(BaseModel): vgg_r) * lambda_B * lambda_vgg self.loss_rec = (self.loss_cycle_A + self.loss_cycle_B + - self.loss_A_vgg + self.loss_B_vgg) * 0.1 - self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.1 + self.loss_A_vgg + self.loss_B_vgg) * 0.2 + self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.2 # bg consistency loss mask_A_consis = paddle.cast( @@ -370,8 +369,8 @@ class MakeupModel(BaseModel): (self.mask_A == 10), dtype='float32') + paddle.cast( (self.mask_A == 8), dtype='float32') mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1) - self.loss_G_bg_consis = self.criterionL1(self.real_A * mask_A_consis, - self.fake_A * mask_A_consis) + self.loss_G_bg_consis = self.criterionL1( + self.real_A * mask_A_consis, self.fake_A * mask_A_consis) * 0.1 # combined loss and calculate gradients diff --git a/ppgan/models/vgg.py b/ppgan/models/vgg.py index 16dab88..fc108ee 100644 --- a/ppgan/models/vgg.py +++ b/ppgan/models/vgg.py @@ -14,211 +14,38 @@ import paddle import paddle.nn as nn -import paddle.nn.functional as F - from paddle.utils.download import get_weights_path_from_url +from paddle.vision.models.vgg import make_layers -__all__ = [ - 'VGG', - 'vgg11', - 'vgg13', - 'vgg16', - 'vgg19', +cfg = [ + 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, + 512, 512, 'M' ] model_urls = { 'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams', - 'c788f453a3b999063e8da043456281ee') + '89bbffc0f87d260be9b8cdc169c991c4') } -class Classifier(paddle.nn.Layer): - def __init__(self, num_classes, 
classifier_activation='softmax'): - super(Classifier, self).__init__() - self.linear1 = nn.Linear(512 * 7 * 7, 4096) - self.linear2 = nn.Linear(4096, 4096) - self.linear3 = nn.Linear(4096, num_classes) - self.relu = nn.ReLU() - self.dropout = nn.Dropout(0.5) - self.softmax = nn.Softmax() - - def forward(self, x): - x = self.linear1(x) - x = self.relu(x) - x = self.dropout(x) - x = self.linear2(x) - x = self.relu(x) - x = self.dropout(x) - out = self.linear3(x) - out = self.softmax(out) - return out - - -class VGG(paddle.nn.Layer): - """VGG model from - `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `_ - Args: - features (nn.Layer): vgg features create by function make_layers. - num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer - will not be defined. Default: 1000. - classifier_activation (str): activation for the last fc layer. Default: 'softmax'. - Examples: - .. code-block:: python - from paddle.incubate.hapi.vision.models import VGG - from paddle.incubate.hapi.vision.models.vgg import make_layers - vgg11_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'] - features = make_layers(vgg11_cfg) - vgg11 = VGG(features) - """ - def __init__(self, - features, - num_classes=1000, - classifier_activation='softmax'): +class VGG(nn.Layer): + def __init__(self, features): super(VGG, self).__init__() self.features = features - self.num_classes = num_classes - - if num_classes > 0: - classifier = Classifier(num_classes, classifier_activation) - self.classifier = self.add_sublayer("classifier", - nn.Sequential(classifier)) def forward(self, x): x = self.features(x) - - #if self.num_classes > 0: - # x = fluid.layers.flatten(x, 1) - # x = self.classifier(x) return x -def make_layers(cfg, batch_norm=False): - layers = [] - in_channels = 3 - - for v in cfg: - if v == 'M': - layers += [nn.MaxPool2d(kernel_size=2, stride=2)] - else: - if batch_norm: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) - layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()] - else: - conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) - layers += [conv2d, nn.ReLU()] - in_channels = v - return nn.Sequential(*layers) - - -cfgs = { - 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'B': - [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], - 'D': [ - 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, - 512, 512, 'M' - ], - 'E': [ - 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, - 'M', 512, 512, 512, 512, 'M' - ], -} - - -def _vgg(arch, cfg, batch_norm, pretrained, **kwargs): - model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), - num_classes=1000, - **kwargs) +def vgg16(pretrained=False): + features = make_layers(cfg) + model = VGG(features) if pretrained: - #assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format( - # arch) - #weight_path = get_weights_path_from_url(model_urls[arch][0], - # model_urls[arch][1]) - #assert weight_path.endswith( - # '.pdparams'), "suffix of weight must be .pdparams" - weight_path = './vgg16.pdparams' + weight_path = get_weights_path_from_url(model_urls['vgg16'][0], + model_urls['vgg16'][1]) param, _ = paddle.load(weight_path) model.load_dict(param) return model - - -def vgg11(pretrained=False, batch_norm=False, **kwargs): - """VGG 11-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. 
- batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. - Examples: - .. code-block:: python - from paddle.incubate.hapi.vision.models import vgg11 - # build model - model = vgg11() - # build vgg11 model with batch_norm - model = vgg11(batch_norm=True) - """ - model_name = 'vgg11' - if batch_norm: - model_name += ('_bn') - return _vgg(model_name, 'A', batch_norm, pretrained, **kwargs) - - -def vgg13(pretrained=False, batch_norm=False, **kwargs): - """VGG 13-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. - batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. - Examples: - .. code-block:: python - from paddle.incubate.hapi.vision.models import vgg13 - # build model - model = vgg13() - # build vgg13 model with batch_norm - model = vgg13(batch_norm=True) - """ - model_name = 'vgg13' - if batch_norm: - model_name += ('_bn') - return _vgg(model_name, 'B', batch_norm, pretrained, **kwargs) - - -def vgg16(pretrained=False, batch_norm=False, **kwargs): - """VGG 16-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. - batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. - Examples: - .. code-block:: python - from paddle.incubate.hapi.vision.models import vgg16 - # build model - model = vgg16() - # build vgg16 model with batch_norm - model = vgg16(batch_norm=True) - """ - model_name = 'vgg16' - if batch_norm: - model_name += ('_bn') - return _vgg(model_name, 'D', batch_norm, pretrained, **kwargs) - - -def vgg19(pretrained=False, batch_norm=False, **kwargs): - """VGG 19-layer model - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False. - batch_norm (bool): If True, returns a model with batch_norm layer. Default: False. - Examples: - .. code-block:: python - from paddle.incubate.hapi.vision.models import vgg19 - # build model - model = vgg19() - # build vgg19 model with batch_norm - model = vgg19(batch_norm=True) - """ - model_name = 'vgg19' - if batch_norm: - model_name += ('_bn') - return _vgg(model_name, 'E', batch_norm, pretrained, **kwargs) diff --git a/ppgan/modules/norm.py b/ppgan/modules/norm.py index 66833fc..67b7b43 100644 --- a/ppgan/modules/norm.py +++ b/ppgan/modules/norm.py @@ -1,6 +1,7 @@ import paddle import functools import paddle.nn as nn +from .nn import Spectralnorm class Identity(nn.Layer): @@ -35,6 +36,8 @@ def build_norm_layer(norm_type='instance'): bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0), learning_rate=0.0, trainable=False)) + elif norm_type == 'spectral': + norm_layer = functools.partial(Spectralnorm) elif norm_type == 'none': def norm_layer(x): -- GitLab
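
Notes: a minimal usage sketch of the two main refactors in this patch, assuming the ppgan module layout the diff targets; the tensor shapes below are illustrative, not taken from the patch.

    import paddle
    from ppgan.models.discriminators.nlayers import NLayerDiscriminator
    from ppgan.models.vgg import vgg16

    # Discriminator as configured by the updated configs/makeup.yaml
    # (input_nc=3, ndf=64, n_layers=3, norm_type='spectral'); the new
    # 'spectral' branch wraps each Conv2d in Spectralnorm instead of
    # appending a separate norm layer.
    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3,
                               norm_type='spectral')
    pred = disc(paddle.randn([1, 3, 256, 256]))  # 1-channel PatchGAN map

    # vgg16 is now a bare feature extractor (classifier head removed);
    # pretrained=True fetches weights from the official URL instead of
    # reading a hard-coded local ./vgg16.pdparams.
    vgg = vgg16(pretrained=False)
    feats = vgg(paddle.randn([1, 3, 256, 256]))

The same 'spectral' choice is also reachable through config-driven code, since ppgan/modules/norm.py now resolves build_norm_layer('spectral') to Spectralnorm.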