未验证 提交 3e4a59f2 编写于 作者: L LielinJiang 提交者: GitHub

Refine code (#44)

* refine code

* test

* fix apps

* update readme

* rm unused code

* fix apps output when input is image

* clean code

* update requirements.txt
上级 359db9ce
......@@ -61,7 +61,7 @@ pip install -v -e . # or "python setup.py develop"
Please refer to [data prepare](./docs/data_prepare.md) for dataset preparation.
## Get Start
Please refer [get stated](./docs/get_started.md) for the basic usage of PaddleGAN.
Please refer [get started](./docs/get_started.md) for the basic usage of PaddleGAN.
## Model tutorial
* [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md)
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
......@@ -39,12 +38,12 @@ dataset:
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
output_size: [256, 256]
size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Permute
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
......@@ -60,8 +59,8 @@ dataset:
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
- name: Permute
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_A: 10.0
lambda_B: 10.0
......@@ -38,12 +37,12 @@ dataset:
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop
output_size: [256, 256]
- name: RandomHorizontalFlip
prob: 0.5
- name: Permute
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
......@@ -60,8 +59,8 @@ dataset:
transform:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
- name: Permute
interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Transpose
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
......
epochs: 100
isTrain: True
output_dir: tmp
checkpoints_dir: checkpoints
lambda_A: 10.0
......@@ -24,14 +23,14 @@ dataset:
train:
name: MakeupDataset
trans_size: 256
dataroot: MT-Dataset
dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup]
phase: train
pool_size: 16
test:
name: MakeupDataset
trans_size: 256
dataroot: MT-Dataset
dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup]
phase: test
pool_size: 16
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
......@@ -36,15 +35,15 @@ dataset:
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......@@ -63,9 +62,9 @@ dataset:
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
......@@ -35,15 +34,15 @@ dataset:
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......@@ -62,9 +61,9 @@ dataset:
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......
epochs: 200
isTrain: True
output_dir: output_dir
lambda_L1: 100
......@@ -35,15 +34,15 @@ dataset:
transforms:
- name: Resize
size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: PairedRandomCrop
output_size: [256, 256]
size: [256, 256]
keys: [image, image]
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......@@ -62,9 +61,9 @@ dataset:
transforms:
- name: Resize
size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Permute
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
......
......@@ -14,6 +14,7 @@
import os
import cv2
from PIL import Image
import paddle
......@@ -61,9 +62,10 @@ class BasePredictor(object):
return out
def is_video(self, input):
def is_image(self, input):
try:
cv2.VideoCapture(input)
img = Image.open(input)
_ = img.size
return True
except:
return False
......
......@@ -128,13 +128,15 @@ class DeOldifyPredictor(BasePredictor):
return frame_pattern_combined, vid_out_path
def run(self, input):
if self.is_video(input):
if not self.is_image(input):
return self.run_video(input)
else:
pred_img = self.run_image(input)
out_path = None
if self.output:
base_name = os.path.basename(input)
pred_img.save(os.path.join(self.output, base_name + '.png'))
base_name = os.path.splitext(os.path.basename(input))[0]
out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)
return pred_img
return pred_img, out_path
......@@ -98,13 +98,18 @@ class RealSRPredictor(BasePredictor):
return frame_pattern_combined, vid_out_path
def run(self, input):
if self.is_video(input):
if not os.path.exists(self.output):
os.makedirs(self.output)
if not self.is_image(input):
return self.run_video(input)
else:
pred_img = self.run_image(input)
out_path = None
if self.output:
base_name = os.path.basename(input)
pred_img.save(os.path.join(self.output, base_name + '.png'))
base_name = os.path.splitext(os.path.basename(input))[0]
out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)
return pred_img
return pred_img, out_path
......@@ -16,7 +16,7 @@ class PairedDataset(BaseDataset):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
cfg (dict): configs of datasets.
"""
BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot,
......@@ -42,7 +42,7 @@ class PairedDataset(BaseDataset):
"""
# read a image given a random integer index
AB_path = self.AB_paths[index]
AB = cv2.imread(AB_path)
AB = cv2.cvtColor(cv2.imread(AB_path), cv2.COLOR_BGR2RGB)
# split AB image into A and B
h, w = AB.shape[:2]
......
from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute
from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip
......@@ -27,6 +27,7 @@ class Compose(object):
try:
data = f(data)
except Exception as e:
print(f)
stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info)))
......
......@@ -20,7 +20,7 @@ def get_makeup_transform(cfg, pic="image"):
if pic == "image":
transform = T.Compose([
T.Resize(size=cfg.trans_size),
T.Permute(to_rgb=False),
T.Transpose(),
])
else:
transform = T.Resize(size=cfg.trans_size,
......
......@@ -4,7 +4,7 @@ import numbers
import collections
import numpy as np
from paddle.utils import try_import
import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F
from .builder import TRANSFORMS
......@@ -16,261 +16,45 @@ else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
class Transform():
def _set_attributes(self, args):
"""
Set attributes from the input list of parameters.
Args:
args (list): list of parameters.
"""
if args:
for k, v in args.items():
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def apply_image(self, input):
raise NotImplementedError
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
else:
inputs = self.apply_image(inputs)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Resize(Transform):
"""Resize the input Image to the given size.
Args:
size (int|list|tuple): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def __init__(self, size, interpolation=1, keys=None):
super().__init__()
assert isinstance(size, int) or (isinstance(size, Iterable)
and len(size) == 2)
self._set_attributes(locals())
if isinstance(self.size, Iterable):
self.size = tuple(size)
def apply_image(self, img):
return F.resize(img, self.size, self.interpolation)
@TRANSFORMS.register()
class RandomCrop(Transform):
def __init__(self, output_size, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def _get_params(self, img):
h, w, _ = img.shape
th, tw = self.output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th)
j = random.randint(0, w - tw)
return i, j, th, tw
def apply_image(self, img):
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
@TRANSFORMS.register()
class PairedRandomCrop(RandomCrop):
def __init__(self, output_size, keys=None):
super().__init__(output_size, keys)
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def apply_image(self, img, crop_prams=None):
if crop_prams is not None:
i, j, h, w = crop_prams
else:
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
if isinstance(inputs, dict):
crop_params = self._get_params(inputs[self.keys[0]])
elif isinstance(inputs, (list, tuple)):
crop_params = self._get_params(inputs[0])
class PairedRandomCrop(T.RandomCrop):
def __init__(self, size, keys=None):
super().__init__(size, keys=keys)
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
crop_params)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
crop_params)
if isinstance(size, int):
self.size = (size, size)
else:
crop_params = self._get_params(inputs)
inputs = self.apply_image(inputs, crop_params)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
"""Horizontally flip the input data randomly with a given probability.
self.size = size
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
params = {}
params['crop_prams'] = self._get_param(image, self.size)
return params
def apply_image(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
def _apply_image(self, img):
i, j, h, w = self.params['crop_prams']
return F.crop(img, i, j, h, w)
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, flip):
if flip:
return F.flip(img, code=1)
return img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
flip = np.random.random() < self.prob
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
flip)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
else:
inputs = self.apply_image(inputs, flip)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Normalize(Transform):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def __init__(self, mean=0.0, std=1.0, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def apply_image(self, img):
return (img - self.mean) / self.std
@TRANSFORMS.register()
class Permute(Transform):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def __init__(self, mode="CHW", to_rgb=True, keys=None):
super().__init__()
self._set_attributes(locals())
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def apply_image(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
class Crop():
def __init__(self, pos, size):
self.pos = pos
self.size = size
super().__init__(prob, keys=keys)
def __call__(self, img):
oh, ow, _ = img.shape
x, y = self.pos
th = tw = self.size
if (ow > tw or oh > th):
return img[y:y + th, x:x + tw]
def _get_params(self, inputs):
params = {}
params['flip'] = random.random() < self.prob
return params
return img
def _apply_image(self, image):
if self.params['flip']:
return F.hflip(image)
return image
......@@ -64,8 +64,8 @@ class UnpairedDataset(BaseDataset):
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = cv2.imread(A_path)
B_img = cv2.imread(B_path)
A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB)
B_img = cv2.cvtColor(cv2.imread(B_path), cv2.COLOR_BGR2RGB)
# apply image transformation
A = self.transform_A(A_img)
B = self.transform_B(B_img)
......
......@@ -10,7 +10,7 @@ from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader
from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs
from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
......@@ -36,8 +36,8 @@ class Trainer:
# base config
self.output_dir = cfg.output_dir
self.epochs = cfg.epochs
self.start_epoch = 0
self.current_epoch = 0
self.start_epoch = 1
self.current_epoch = 1
self.batch_id = 0
self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval
......@@ -65,7 +65,7 @@ class Trainer:
reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager()
for epoch in range(self.start_epoch, self.epochs):
for epoch in range(self.start_epoch, self.epochs + 1):
self.current_epoch = epoch
start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader):
......@@ -91,8 +91,8 @@ class Trainer:
step_start_time = time.time()
self.logger.info(
'train one epoch time: {}'.format(time.time() - start_time))
self.logger.info('train one epoch time: {}'.format(time.time() -
start_time))
if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate()
self.model.lr_scheduler.step()
......@@ -102,8 +102,8 @@ class Trainer:
def validate(self):
if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader(
self.cfg.dataset.val, is_train=False)
self.val_dataloader = build_dataloader(self.cfg.dataset.val,
is_train=False)
metric_result = {}
......@@ -149,8 +149,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info(
'val iter: [%d/%d]' % (i, len(self.val_dataloader)))
self.logger.info('val iter: [%d/%d]' %
(i, len(self.val_dataloader)))
for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset)
......@@ -160,8 +160,8 @@ class Trainer:
def test(self):
if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader(
self.cfg.dataset.test, is_train=False)
self.test_dataloader = build_dataloader(self.cfg.dataset.test,
is_train=False)
# data[0]: img, data[1]: img path index
# test batch size must be 1
......@@ -185,8 +185,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0:
self.logger.info(
'Test iter: [%d/%d]' % (i, len(self.test_dataloader)))
self.logger.info('Test iter: [%d/%d]' %
(i, len(self.test_dataloader)))
def print_log(self):
losses = self.model.get_current_losses()
......@@ -208,7 +208,8 @@ class Trainer:
@property
def current_learning_rate(self):
return self.model.optimizers[0].get_lr()
for optimizer in self.model.optimizers.values():
return optimizer.get_lr()
def visual(self, results_dir, visual_results=None):
self.model.compute_visuals()
......@@ -216,7 +217,7 @@ class Trainer:
if visual_results is None:
visual_results = self.model.get_current_visuals()
if self.cfg.isTrain:
if self.cfg.is_train:
msg = 'epoch%.3d_' % self.current_epoch
else:
msg = ''
......@@ -240,10 +241,8 @@ class Trainer:
state_dicts = {}
save_filename = 'epoch_%s_%s.pkl' % (epoch, name)
save_path = os.path.join(self.output_dir, save_filename)
for net_name in self.model.model_names:
if isinstance(net_name, str):
net = getattr(self.model, 'net' + net_name)
state_dicts['net' + net_name] = net.state_dict()
for net_name, net in self.model.nets.items():
state_dicts[net_name] = net.state_dict()
if name == 'weight':
save(state_dicts, save_path)
......@@ -251,9 +250,7 @@ class Trainer:
state_dicts['epoch'] = epoch
for opt_name in self.model.optimizer_names:
if isinstance(opt_name, str):
opt = getattr(self.model, opt_name)
for opt_name, opt in self.model.optimizers.items():
state_dicts[opt_name] = opt.state_dict()
save(state_dicts, save_path)
......@@ -273,22 +270,14 @@ class Trainer:
if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1
for name in self.model.model_names:
if isinstance(name, str):
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name])
for name in self.model.optimizer_names:
if isinstance(name, str):
opt = getattr(self.model, name)
opt.set_dict(state_dicts[name])
for opt_name, opt in self.model.optimizers.items():
opt.set_dict(state_dicts[opt_name])
def load(self, weight_path):
state_dicts = load(weight_path)
for name in self.model.model_names:
if isinstance(name, str):
self.logger.info('laod model {} {} params!'.format(
self.cfg.model.name, 'net' + name))
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
for net_name, net in self.model.nets.items():
net.set_dict(state_dicts[net_name])
......@@ -13,13 +13,10 @@
# limitations under the License.
import paddle
from paddle import nn
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
import numpy as np
from .resnet import resnet18
from paddle.vision.models import resnet18
class ConvBNReLU(paddle.nn.Layer):
......@@ -32,13 +29,13 @@ class ConvBNReLU(paddle.nn.Layer):
*args,
**kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
self.conv = nn.Conv2D(in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
bias_attr=False)
self.bn = nn.BatchNorm2d(out_chan)
self.bn = nn.BatchNorm2D(out_chan)
self.relu = nn.ReLU()
def forward(self, x):
......@@ -52,7 +49,7 @@ class BiSeNetOutput(paddle.nn.Layer):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan,
self.conv_out = nn.Conv2D(mid_chan,
n_classes,
kernel_size=1,
bias_attr=False)
......@@ -67,7 +64,7 @@ class AttentionRefinementModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan,
self.conv_atten = nn.Conv2D(out_chan,
out_chan,
kernel_size=1,
bias_attr=False)
......@@ -87,16 +84,27 @@ class AttentionRefinementModule(paddle.nn.Layer):
class ContextPath(paddle.nn.Layer):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = resnet18()
self.backbone = resnet18(pretrained=True)
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
def backbone_forward(self, x):
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
c2 = self.backbone.layer2(x)
c3 = self.backbone.layer3(c2)
c4 = self.backbone.layer4(c3)
return c2, c3, c4
def forward(self, x):
H0, W0 = x.shape[2:]
feat8, feat16, feat32 = self.resnet(x)
feat8, feat16, feat32 = self.backbone_forward(x)
H8, W8 = feat8.shape[2:]
H16, W16 = feat16.shape[2:]
H32, W32 = feat32.shape[2:]
......@@ -138,13 +146,13 @@ class FeatureFusionModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
self.conv1 = nn.Conv2D(out_chan,
out_chan // 4,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.conv2 = nn.Conv2d(out_chan // 4,
self.conv2 = nn.Conv2D(out_chan // 4,
out_chan,
kernel_size=1,
stride=1,
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
import numpy as np
import math
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'0ba53eea9bc970962d0ef96f7b94057e'),
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm(out_chan)
self.relu = nn.ReLU()
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan,
out_chan,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = self.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(paddle.nn.Layer):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3,
64,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = nn.BatchNorm(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = self.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def resnet18(pretrained=False, **kwargs):
model = Resnet18()
arch = 'resnet18'
if pretrained:
weight_path = './resnet.pdparams'
param, _ = paddle.load(weight_path)
model.set_dict(param)
return model
......@@ -18,4 +18,3 @@ from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .makeup_model import MakeupModel
from .vgg import vgg16
......@@ -8,7 +8,7 @@ __all__ = [
def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding"
return nn.Conv2d(in_planes,
return nn.Conv2D(in_planes,
out_planes,
kernel_size=3,
stride=stride,
......@@ -53,16 +53,16 @@ class Bottleneck(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias_attr=False)
self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
self.bn1 = nn.BatchNorm(planes)
self.conv2 = nn.Conv2d(planes,
self.conv2 = nn.Conv2D(planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn2 = nn.BatchNorm(planes)
self.conv3 = nn.Conv2d(planes,
self.conv3 = nn.Conv2D(planes,
planes * 4,
kernel_size=1,
bias_attr=False)
......@@ -97,7 +97,7 @@ class ResNet(nn.Layer):
def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3,
self.conv1 = nn.Conv2D(3,
64,
kernel_size=7,
stride=2,
......@@ -117,7 +117,7 @@ class ResNet(nn.Layer):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes,
nn.Conv2D(self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
......
......@@ -17,47 +17,35 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
def __init__(self, cfg):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
Args:
cfg (Dict)-- configs of Model.
When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)>
In this function, you should first call <super(YourClass, self).__init__(self, cfg)>
Then, you need to define four lists:
-- self.losses (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.losses (dict): specify the training losses that you want to plot and save.
-- self.nets (dict): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
"""
self.opt = opt
self.isTrain = opt.isTrain
self.cfg = cfg
self.is_train = cfg.is_train
self.save_dir = os.path.join(
opt.output_dir,
opt.model.name) # save all the checkpoints to save_dir
cfg.output_dir,
cfg.model.name) # save all the checkpoints to save_dir
self.losses = OrderedDict()
self.model_names = []
self.visual_names = []
self.optimizers = []
self.optimizer_names = []
self.nets = OrderedDict()
self.visual_items = OrderedDict()
self.optimizers = OrderedDict()
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
......@@ -78,7 +66,7 @@ class BaseModel(ABC):
pass
def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler)
self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler)
def eval(self):
"""Make models eval mode during test time"""
......@@ -106,12 +94,8 @@ class BaseModel(ABC):
return self.image_paths
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str) and hasattr(self, name):
visual_ret[name] = getattr(self, name)
return visual_ret
"""Return visualization images."""
return self.visual_items
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
......@@ -119,7 +103,7 @@ class BaseModel(ABC):
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
Args:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
......@@ -128,6 +112,4 @@ class BaseModel(ABC):
for net in nets:
if net is not None:
for param in net.parameters():
# print('trainable:', param.trainable)
param.trainable = requires_grad
# param.stop_gradient = not requires_grad
......@@ -24,84 +24,63 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
def __init__(self, opt):
def __init__(self, cfg):
"""Initialize the CycleGAN class.
Parameters:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
"""
BaseModel.__init__(self, opt)
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
# combine visualizations for A and B
self.visual_names = visual_names_A + visual_names_B
# specify the models you want to save to the disk.
if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G_A', 'G_B']
super(CycleGANModel, self).__init__(cfg)
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = build_generator(opt.model.generator)
self.netG_B = build_generator(opt.model.generator)
init_weights(self.netG_A)
init_weights(self.netG_B)
if self.isTrain: # define discriminators
self.netD_A = build_discriminator(opt.model.discriminator)
self.netD_B = build_discriminator(opt.model.discriminator)
init_weights(self.netD_A)
init_weights(self.netD_B)
if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
self.nets['netG_A'] = build_generator(cfg.model.generator)
self.nets['netG_B'] = build_generator(cfg.model.generator)
init_weights(self.nets['netG_A'])
init_weights(self.nets['netG_B'])
if self.is_train: # define discriminators
self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD_A'])
init_weights(self.nets['netD_B'])
if self.is_train:
if cfg.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert (
opt.dataset.train.input_nc == opt.dataset.train.output_nc)
cfg.dataset.train.input_nc == cfg.dataset.train.output_nc)
# create image buffer to store previously generated images
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size)
self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
# create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size)
self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)
# define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss()
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netG_A.parameters() +
self.netG_B.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
parameter_list=self.nets['netG_A'].parameters() +
self.nets['netG_B'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters() +
self.netD_B.parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
parameter_list=self.nets['netD_A'].parameters() +
self.nets['netD_B'].parameters())
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Args:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
mode = 'train' if self.isTrain else 'test'
AtoB = self.opt.dataset[mode].direction == 'AtoB'
mode = 'train' if self.is_train else 'test'
AtoB = self.cfg.dataset[mode].direction == 'AtoB'
if AtoB:
if 'A' in input:
......@@ -122,12 +101,22 @@ class CycleGANModel(BaseModel):
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'):
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_B = self.nets['netG_A'](self.real_A) # G_A(A)
self.rec_A = self.nets['netG_B'](self.fake_B) # G_B(G_A(A))
# visual
self.visual_items['real_A'] = self.real_A
self.visual_items['fake_B'] = self.fake_B
self.visual_items['rec_A'] = self.rec_A
if hasattr(self, 'real_B'):
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
self.fake_A = self.nets['netG_B'](self.real_B) # G_B(B)
self.rec_B = self.nets['netG_A'](self.fake_A) # G_A(G_B(B))
# visual
self.visual_items['real_B'] = self.real_B
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
......@@ -148,40 +137,43 @@ class CycleGANModel(BaseModel):
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
# loss_D.backward()
if ParallelEnv().nranks > 1:
loss_D = netD.scale_loss(loss_D)
loss_D.backward()
netD.apply_collective_grads()
else:
loss_D.backward()
return loss_D
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
fake_B)
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
fake_A)
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
lambda_idt = self.cfg.lambda_identity
lambda_A = self.cfg.lambda_A
lambda_B = self.cfg.lambda_B
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.idt_A = self.nets['netG_A'](self.real_B)
self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.idt_B = self.nets['netG_B'](self.real_A)
# visual
self.visual_items['idt_A'] = self.idt_A
self.visual_items['idt_B'] = self.idt_B
self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
......@@ -189,9 +181,11 @@ class CycleGANModel(BaseModel):
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B),
True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A),
True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
......@@ -208,12 +202,6 @@ class CycleGANModel(BaseModel):
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if ParallelEnv().nranks > 1:
self.loss_G = self.netG_A.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG_A.apply_collective_grads()
self.netG_B.apply_collective_grads()
else:
self.loss_G.backward()
def optimize_parameters(self):
......@@ -223,21 +211,22 @@ class CycleGANModel(BaseModel):
self.forward()
# G_A and G_B
# Ds require no gradients when optimizing Gs
self.set_requires_grad([self.netD_A, self.netD_B], False)
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
False)
# set G_A and G_B's gradients to zero
self.optimizer_G.clear_gradients()
self.optimizers['optimizer_G'].clear_grad()
# calculate gradients for G_A and G_B
self.backward_G()
# update G_A and G_B's weights
self.optimizer_G.minimize(self.loss_G)
self.optimizers['optimizer_G'].step()
# D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True)
self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True)
# set D_A and D_B's gradients to zero
self.optimizer_D.clear_gradients()
self.optimizers['optimizer_D'].clear_grad()
# calculate gradients for D_A
self.backward_D_A()
# calculate graidents for D_B
self.backward_D_B()
# update D_A and D_B's weights
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B)
self.optimizers['optimizer_D'].step()
......@@ -41,9 +41,9 @@ class NLayerDiscriminator(nn.Layer):
if type(
norm_layer
) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == nn.InstanceNorm2d
use_bias = norm_layer.func == nn.InstanceNorm2D
else:
use_bias = norm_layer == nn.InstanceNorm2d
use_bias = norm_layer == nn.InstanceNorm2D
kw = 4
padw = 1
......@@ -51,7 +51,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral':
sequence = [
Spectralnorm(
nn.Conv2d(input_nc,
nn.Conv2D(input_nc,
ndf,
kernel_size=kw,
stride=2,
......@@ -60,7 +60,7 @@ class NLayerDiscriminator(nn.Layer):
]
else:
sequence = [
nn.Conv2d(input_nc,
nn.Conv2D(input_nc,
ndf,
kernel_size=kw,
stride=2,
......@@ -76,7 +76,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral':
sequence += [
Spectralnorm(
nn.Conv2d(ndf * nf_mult_prev,
nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
......@@ -85,7 +85,7 @@ class NLayerDiscriminator(nn.Layer):
]
else:
sequence += [
nn.Conv2d(ndf * nf_mult_prev,
nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=2,
......@@ -100,7 +100,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral':
sequence += [
Spectralnorm(
nn.Conv2d(ndf * nf_mult_prev,
nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
......@@ -109,7 +109,7 @@ class NLayerDiscriminator(nn.Layer):
]
else:
sequence += [
nn.Conv2d(ndf * nf_mult_prev,
nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult,
kernel_size=kw,
stride=1,
......@@ -122,7 +122,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral':
sequence += [
Spectralnorm(
nn.Conv2d(ndf * nf_mult,
nn.Conv2D(ndf * nf_mult,
1,
kernel_size=kw,
stride=1,
......@@ -131,7 +131,7 @@ class NLayerDiscriminator(nn.Layer):
] # output 1 channel prediction map
else:
sequence += [
nn.Conv2d(ndf * nf_mult,
nn.Conv2D(ndf * nf_mult,
1,
kernel_size=kw,
stride=1,
......
......@@ -2,9 +2,9 @@ import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import resnet101
from .hook import hook_outputs, model_sizes, dummy_eval
from ..backbones import resnet34, resnet101
from ...modules.nn import Spectralnorm
......@@ -137,7 +137,7 @@ def custom_conv_layer(ni: int,
bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True
if bias is None:
bias = not bn
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv = conv_func(ni,
nf,
......@@ -272,7 +272,7 @@ class PixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d((1, 0, 1, 0))
self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg')
self.relu = relu(True, leaky=leaky)
......@@ -298,7 +298,7 @@ def conv_layer(ni: int,
if padding is None: padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in ('Batch', 'BatchZero')
if bias is None: bias = not bn
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv = conv_func(ni,
nf,
......@@ -338,7 +338,7 @@ class CustomPixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d((1, 0, 1, 0))
self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg')
self.relu = nn.LeakyReLU(
leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky)
......@@ -409,7 +409,7 @@ class ReplicationPad2d(nn.Layer):
self.size = size
def forward(self, x):
return F.pad2d(x, self.size, mode="edge")
return F.pad(x, self.size, mode="replicate")
def conv1d(ni: int,
......@@ -419,7 +419,7 @@ def conv1d(ni: int,
padding: int = 0,
bias: bool = False):
"Create and initialize a `nn.Conv1d` layer with spectral normalization."
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias_attr=bias)
conv = nn.Conv1D(ni, no, ks, stride=stride, padding=padding, bias_attr=bias)
return Spectralnorm(conv)
......
......@@ -77,7 +77,7 @@ class Hooks():
def _hook_inner(m, i, o):
return o if isinstance(
o, paddle.framework.Variable) else o if is_listy(o) else list(o)
o, paddle.fluid.framework.Variable) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False):
......
......@@ -49,22 +49,22 @@ class ResidualBlock(paddle.nn.Layer):
bias_attr = None
self.main = nn.Sequential(
nn.Conv2d(dim_in,
nn.Conv2D(dim_in,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False),
nn.InstanceNorm2d(dim_out,
nn.InstanceNorm2D(dim_out,
weight_attr=weight_attr,
bias_attr=bias_attr), nn.ReLU(),
nn.Conv2d(dim_out,
nn.Conv2D(dim_out,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False),
nn.InstanceNorm2d(dim_out,
nn.InstanceNorm2D(dim_out,
weight_attr=weight_attr,
bias_attr=bias_attr))
......@@ -78,7 +78,7 @@ class StyleResidualBlock(paddle.nn.Layer):
def __init__(self, dim_in, dim_out):
super(StyleResidualBlock, self).__init__()
self.block1 = nn.Sequential(
nn.Conv2d(dim_in,
nn.Conv2D(dim_in,
dim_out,
kernel_size=3,
stride=1,
......@@ -86,18 +86,18 @@ class StyleResidualBlock(paddle.nn.Layer):
bias_attr=False), PONO())
ks = 3
pw = ks // 2
self.beta1 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma1 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.beta1 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma1 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.block2 = nn.Sequential(
nn.ReLU(),
nn.Conv2d(dim_out,
nn.Conv2D(dim_out,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False), PONO())
self.beta2 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma2 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.beta2 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma2 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
def forward(self, x, y):
"""forward"""
......@@ -119,14 +119,14 @@ class MDNet(paddle.nn.Layer):
layers = []
layers.append(
nn.Conv2d(3,
nn.Conv2D(3,
conv_dim,
kernel_size=7,
stride=1,
padding=3,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=None, bias_attr=None))
nn.InstanceNorm2D(conv_dim, weight_attr=None, bias_attr=None))
layers.append(nn.ReLU())
......@@ -134,14 +134,14 @@ class MDNet(paddle.nn.Layer):
curr_dim = conv_dim
for i in range(2):
layers.append(
nn.Conv2d(curr_dim,
nn.Conv2D(curr_dim,
curr_dim * 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim * 2,
nn.InstanceNorm2D(curr_dim * 2,
weight_attr=None,
bias_attr=None))
layers.append(nn.ReLU())
......@@ -166,14 +166,14 @@ class TNetDown(paddle.nn.Layer):
layers = []
layers.append(
nn.Conv2d(3,
nn.Conv2D(3,
conv_dim,
kernel_size=7,
stride=1,
padding=3,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=False, bias_attr=False))
nn.InstanceNorm2D(conv_dim, weight_attr=False, bias_attr=False))
layers.append(nn.ReLU())
......@@ -181,14 +181,14 @@ class TNetDown(paddle.nn.Layer):
curr_dim = conv_dim
for i in range(2):
layers.append(
nn.Conv2d(curr_dim,
nn.Conv2D(curr_dim,
curr_dim * 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim * 2,
nn.InstanceNorm2D(curr_dim * 2,
weight_attr=False,
bias_attr=False))
layers.append(nn.ReLU())
......@@ -210,13 +210,13 @@ class TNetDown(paddle.nn.Layer):
class GetMatrix(paddle.fluid.dygraph.Layer):
def __init__(self, dim_in, dim_out):
super(GetMatrix, self).__init__()
self.get_gamma = nn.Conv2d(dim_in,
self.get_gamma = nn.Conv2D(dim_in,
dim_out,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.get_beta = nn.Conv2d(dim_in,
self.get_beta = nn.Conv2D(dim_in,
dim_out,
kernel_size=1,
stride=1,
......@@ -236,8 +236,8 @@ class MANet(paddle.nn.Layer):
self.encoder = TNetDown(conv_dim=conv_dim, repeat_num=repeat_num)
curr_dim = conv_dim * 4
self.w = w
self.beta = nn.Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1)
self.gamma = nn.Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1)
self.beta = nn.Conv2D(curr_dim, curr_dim, kernel_size=3, padding=1)
self.gamma = nn.Conv2D(curr_dim, curr_dim, kernel_size=3, padding=1)
self.simple_spade = GetMatrix(curr_dim, 1) # get the makeup matrix
self.repeat_num = repeat_num
for i in range(repeat_num):
......@@ -252,28 +252,28 @@ class MANet(paddle.nn.Layer):
for i in range(2):
layers = []
layers.append(
nn.ConvTranspose2d(curr_dim,
nn.Conv2DTranspose(curr_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim // 2,
nn.InstanceNorm2D(curr_dim // 2,
weight_attr=False,
bias_attr=False))
setattr(self, "up_acts_" + str(i), nn.ReLU())
setattr(
self, "up_betas_" + str(i),
nn.ConvTranspose2d(y_dim,
nn.Conv2DTranspose(y_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
padding=1))
setattr(
self, "up_gammas_" + str(i),
nn.ConvTranspose2d(y_dim,
nn.Conv2DTranspose(y_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
......@@ -281,7 +281,7 @@ class MANet(paddle.nn.Layer):
setattr(self, "up_samplers_" + str(i), nn.Sequential(*layers))
curr_dim = curr_dim // 2
self.img_reg = [
nn.Conv2d(curr_dim,
nn.Conv2D(curr_dim,
3,
kernel_size=7,
stride=1,
......
......@@ -17,6 +17,7 @@ import functools
from ...modules.norm import build_norm_layer
from .builder import GENERATORS
@GENERATORS.register()
class MobileResnetGenerator(nn.Layer):
def __init__(self,
......@@ -31,39 +32,40 @@ class MobileResnetGenerator(nn.Layer):
norm_layer = build_norm_layer(norm_type)
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == InstanceNorm
use_bias = norm_layer.func == nn.InstanceNorm2D
else:
use_bias = norm_layer == InstanceNorm
use_bias = norm_layer == nn.InstanceNorm2D
self.model = nn.LayerList([
nn.ReflectionPad2d([3, 3, 3, 3]),
nn.Conv2d(
input_channel,
nn.Conv2D(input_channel,
int(ngf),
kernel_size=7,
padding=0,
bias_attr=use_bias), norm_layer(ngf), nn.ReLU()
bias_attr=use_bias),
norm_layer(ngf),
nn.ReLU()
])
n_downsampling = 2
for i in range(n_downsampling):
mult = 2**i
self.model.extend([
nn.Conv2d(
ngf * mult,
nn.Conv2D(ngf * mult,
ngf * mult * 2,
kernel_size=3,
stride=2,
padding=1,
bias_attr=use_bias), norm_layer(ngf * mult * 2), nn.ReLU()
bias_attr=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU()
])
mult = 2**n_downsampling
for i in range(n_blocks):
self.model.extend([
MobileResnetBlock(
ngf * mult,
MobileResnetBlock(ngf * mult,
ngf * mult,
padding_type=padding_type,
norm_layer=norm_layer,
......@@ -71,24 +73,23 @@ class MobileResnetGenerator(nn.Layer):
use_bias=use_bias)
])
for i in range(n_downsampling):
mult = 2**(n_downsampling - i)
output_size = (i + 1) * 128
self.model.extend([
nn.ConvTranspose2d(
ngf * mult,
nn.Conv2DTranspose(ngf * mult,
int(ngf * mult / 2),
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias_attr=use_bias), norm_layer(int(ngf * mult / 2)),
bias_attr=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU()
])
self.model.extend([nn.ReflectionPad2d([3, 3, 3, 3])])
self.model.extend([nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)])
self.model.extend([nn.Conv2D(ngf, output_nc, kernel_size=7, padding=0)])
self.model.extend([nn.Tanh()])
def forward(self, inputs):
......@@ -108,9 +109,9 @@ class MobileResnetBlock(nn.Layer):
p = 0
if self.padding_type == 'reflect':
self.conv_block.extend([nn.ReflectionPad2d([1, 1, 1, 1])])
self.conv_block.extend([nn.Pad2D([1, 1, 1, 1], mode='reflect')])
elif self.padding_type == 'replicate':
self.conv_block.extend([nn.ReplicationPad2d([1, 1, 1, 1])])
self.conv_block.extend([nn.Pad2D([1, 1, 1, 1], mode='replicate')])
elif self.padding_type == 'zero':
p = 1
else:
......@@ -118,12 +119,13 @@ class MobileResnetBlock(nn.Layer):
self.padding_type)
self.conv_block.extend([
SeparableConv2D(
num_channels=in_c,
SeparableConv2D(num_channels=in_c,
num_filters=out_c,
filter_size=3,
padding=p,
stride=1), norm_layer(out_c), nn.ReLU()
stride=1),
norm_layer(out_c),
nn.ReLU()
])
self.conv_block.extend([nn.Dropout(0.5)])
......@@ -139,12 +141,12 @@ class MobileResnetBlock(nn.Layer):
self.padding_type)
self.conv_block.extend([
SeparableConv2D(
num_channels=out_c,
SeparableConv2D(num_channels=out_c,
num_filters=in_c,
filter_size=3,
padding=p,
stride=1), norm_layer(in_c)
stride=1),
norm_layer(in_c)
])
def forward(self, inputs):
......@@ -154,6 +156,7 @@ class MobileResnetBlock(nn.Layer):
out = inputs + y
return out
class SeparableConv2D(nn.Layer):
def __init__(self,
num_channels,
......@@ -161,14 +164,14 @@ class SeparableConv2D(nn.Layer):
filter_size,
stride=1,
padding=0,
norm_layer=InstanceNorm,
norm_layer=nn.InstanceNorm2D,
use_bias=True,
scale_factor=1,
stddev=0.02):
super(SeparableConv2D, self).__init__()
self.conv = nn.LayerList([
nn.Conv2d(
nn.Conv2D(
in_channels=num_channels,
out_channels=num_channels * scale_factor,
kernel_size=filter_size,
......@@ -176,22 +179,20 @@ class SeparableConv2D(nn.Layer):
padding=padding,
groups=num_channels,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
loc=0.0, scale=stddev)),
initializer=nn.initializer.Normal(loc=0.0, scale=stddev)),
bias_attr=use_bias)
])
self.conv.extend([norm_layer(num_channels * scale_factor)])
self.conv.extend([
nn.Conv2d(
nn.Conv2D(
in_channels=num_channels * scale_factor,
out_channels=num_filters,
kernel_size=1,
stride=1,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(
loc=0.0, scale=stddev)),
initializer=nn.initializer.Normal(loc=0.0, scale=stddev)),
bias_attr=use_bias)
])
......@@ -199,4 +200,3 @@ class SeparableConv2D(nn.Layer):
for sublayer in self.conv:
inputs = sublayer(inputs)
return inputs
......@@ -67,7 +67,7 @@ class OcclusionAwareGenerator(nn.Layer):
'r' + str(i),
ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
self.final = nn.Conv2d(block_expansion,
self.final = nn.Conv2D(block_expansion,
num_channels,
kernel_size=(7, 7),
padding=(3, 3))
......
......@@ -11,7 +11,7 @@ class TempConv(nn.Layer):
stride=(1, 1, 1),
padding=(0, 1, 1)):
super(TempConv, self).__init__()
self.conv3d = nn.Conv3d(in_planes,
self.conv3d = nn.Conv3D(in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
......@@ -26,7 +26,7 @@ class Upsample(nn.Layer):
def __init__(self, in_planes, out_planes, scale_factor=(1, 2, 2)):
super(Upsample, self).__init__()
self.scale_factor = scale_factor
self.conv3d = nn.Conv3d(in_planes,
self.conv3d = nn.Conv3D(in_planes,
out_planes,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
......@@ -88,13 +88,13 @@ class SourceReferenceAttention(nn.Layer):
Number of input reference feature vector channels.
"""
super(SourceReferenceAttention, self).__init__()
self.query_conv = nn.Conv3d(in_channels=in_planes_s,
self.query_conv = nn.Conv3D(in_channels=in_planes_s,
out_channels=in_planes_s // 8,
kernel_size=1)
self.key_conv = nn.Conv3d(in_channels=in_planes_r,
self.key_conv = nn.Conv3D(in_channels=in_planes_r,
out_channels=in_planes_r // 8,
kernel_size=1)
self.value_conv = nn.Conv3d(in_channels=in_planes_r,
self.value_conv = nn.Conv3D(in_channels=in_planes_r,
out_channels=in_planes_r,
kernel_size=1)
self.gamma = self.create_parameter(
......@@ -128,7 +128,7 @@ class NetworkR(nn.Layer):
super(NetworkR, self).__init__()
self.layers = nn.Sequential(
nn.ReplicationPad3d((1, 1, 1, 1, 1, 1)),
nn.Pad3D((1, 1, 1, 1, 1, 1), mode='replicate'),
TempConv(1,
64,
kernel_size=(3, 3, 3),
......@@ -149,7 +149,7 @@ class NetworkR(nn.Layer):
TempConv(128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
TempConv(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
Upsample(64, 16),
nn.Conv3d(16,
nn.Conv3D(16,
1,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
......@@ -165,7 +165,7 @@ class NetworkC(nn.Layer):
super(NetworkC, self).__init__()
self.down1 = nn.Sequential(
nn.ReplicationPad3d((1, 1, 1, 1, 0, 0)),
nn.Pad3D((1, 1, 1, 1, 0, 0), mode='replicate'),
TempConv(1, 64, stride=(1, 2, 2), padding=(0, 0, 0)),
TempConv(64, 128), TempConv(128, 128),
TempConv(128, 256, stride=(1, 2, 2)), TempConv(256, 256),
......@@ -205,7 +205,7 @@ class NetworkC(nn.Layer):
padding=(1, 1, 1)))
self.up4 = nn.Sequential(
Upsample(16, 8), # 1/1
nn.Conv3d(8,
nn.Conv3D(8,
2,
kernel_size=(3, 3, 3),
stride=(1, 1, 1),
......
......@@ -13,7 +13,6 @@ class ResnetGenerator(nn.Layer):
code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self,
input_nc,
output_nc,
......@@ -38,14 +37,17 @@ class ResnetGenerator(nn.Layer):
norm_layer = build_norm_layer(norm_type)
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
use_bias = norm_layer.func == nn.InstanceNorm2D
else:
use_bias = norm_layer == nn.InstanceNorm2d
use_bias = norm_layer == nn.InstanceNorm2D
model = [
nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
nn.Conv2d(
input_nc, ngf, kernel_size=7, padding=0, bias_attr=use_bias),
nn.Conv2D(input_nc,
ngf,
kernel_size=7,
padding=0,
bias_attr=use_bias),
norm_layer(ngf),
nn.ReLU()
]
......@@ -54,8 +56,7 @@ class ResnetGenerator(nn.Layer):
for i in range(n_downsampling): # add downsampling layers
mult = 2**i
model += [
nn.Conv2d(
ngf * mult,
nn.Conv2D(ngf * mult,
ngf * mult * 2,
kernel_size=3,
stride=2,
......@@ -69,8 +70,7 @@ class ResnetGenerator(nn.Layer):
for i in range(n_blocks): # add ResNet blocks
model += [
ResnetBlock(
ngf * mult,
ResnetBlock(ngf * mult,
padding_type=padding_type,
norm_layer=norm_layer,
use_dropout=use_dropout,
......@@ -80,8 +80,7 @@ class ResnetGenerator(nn.Layer):
for i in range(n_downsampling): # add upsampling layers
mult = 2**(n_downsampling - i)
model += [
nn.ConvTranspose2d(
ngf * mult,
nn.Conv2DTranspose(ngf * mult,
int(ngf * mult / 2),
kernel_size=3,
stride=2,
......@@ -92,7 +91,7 @@ class ResnetGenerator(nn.Layer):
nn.ReLU()
]
model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Conv2D(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
......@@ -104,7 +103,6 @@ class ResnetGenerator(nn.Layer):
class ResnetBlock(nn.Layer):
"""Define a Resnet block"""
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
"""Initialize the Resnet block
......@@ -137,11 +135,11 @@ class ResnetBlock(nn.Layer):
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
raise NotImplementedError('padding [%s] is not implemented' %
padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
nn.Conv2D(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim),
nn.ReLU()
]
......@@ -154,10 +152,10 @@ class ResnetBlock(nn.Layer):
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError(
'padding [%s] is not implemented' % padding_type)
raise NotImplementedError('padding [%s] is not implemented' %
padding_type)
conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
nn.Conv2D(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim)
]
......
......@@ -10,14 +10,13 @@ class ResidualDenseBlock_5C(nn.Layer):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias_attr=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias_attr=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
self.conv1 = nn.Conv2D(nf, gc, 3, 1, 1, bias_attr=bias)
self.conv2 = nn.Conv2D(nf + gc, gc, 3, 1, 1, bias_attr=bias)
self.conv3 = nn.Conv2D(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv4 = nn.Conv2D(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv5 = nn.Conv2D(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1)))
......@@ -29,7 +28,6 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
......@@ -42,6 +40,7 @@ class RRDB(nn.Layer):
out = self.RDB3(out)
return out * 0.2 + x
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
......@@ -55,14 +54,14 @@ class RRDBNet(nn.Layer):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias_attr=True)
self.conv_first = nn.Conv2D(in_nc, nf, 3, 1, 1, bias_attr=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.trunk_conv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias_attr=True)
self.upconv1 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.upconv2 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.HRconv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.conv_last = nn.Conv2D(nf, out_nc, 3, 1, 1, bias_attr=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2)
......@@ -71,8 +70,10 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
......@@ -104,12 +104,12 @@ class UnetSkipConnectionBlock(nn.Layer):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm
use_bias = norm_layer.func == nn.InstanceNorm2D
else:
use_bias = norm_layer == nn.InstanceNorm
use_bias = norm_layer == nn.InstanceNorm2D
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc,
downconv = nn.Conv2D(input_nc,
inner_nc,
kernel_size=4,
stride=2,
......@@ -121,7 +121,7 @@ class UnetSkipConnectionBlock(nn.Layer):
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2,
upconv = nn.Conv2DTranspose(inner_nc * 2,
outer_nc,
kernel_size=4,
stride=2,
......@@ -130,7 +130,7 @@ class UnetSkipConnectionBlock(nn.Layer):
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc,
upconv = nn.Conv2DTranspose(inner_nc,
outer_nc,
kernel_size=4,
stride=2,
......@@ -140,7 +140,7 @@ class UnetSkipConnectionBlock(nn.Layer):
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2,
upconv = nn.Conv2DTranspose(inner_nc * 2,
outer_nc,
kernel_size=4,
stride=2,
......
......@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .base_model import BaseModel
from .builder import MODELS
......@@ -26,92 +28,62 @@ from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset
import numpy as np
from .vgg import vgg16
@MODELS.register()
class MakeupModel(BaseModel):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf
"""
def __init__(self, opt):
"""Initialize the CycleGAN class.
def __init__(self, cfg):
"""Initialize the PSGAN class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
cfg (dict)-- config of model.
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_A', 'rec_A']
visual_names_B = ['real_B', 'fake_B', 'rec_B']
if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
self.vgg = vgg16(pretrained=True)
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['G', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G']
super(MakeupModel, self).__init__(cfg)
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG = build_generator(opt.model.generator)
init_weights(self.netG, init_type='xavier', init_gain=1.0)
self.nets['netG'] = build_generator(cfg.model.generator)
init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)
if self.isTrain: # define discriminators
self.netD_A = build_discriminator(opt.model.discriminator)
self.netD_B = build_discriminator(opt.model.discriminator)
init_weights(self.netD_A, init_type='xavier', init_gain=1.0)
init_weights(self.netD_B, init_type='xavier', init_gain=1.0)
if self.is_train: # define discriminators
vgg = vgg16(pretrained=True)
self.vgg = vgg.features
self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)
if self.isTrain:
self.fake_A_pool = ImagePool(
opt.dataset.train.pool_size
cfg.dataset.train.pool_size
) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(
opt.dataset.train.pool_size
cfg.dataset.train.pool_size
) # create image buffer to store previously generated images
# define loss functions
self.criterionGAN = GANLoss(
opt.model.gan_mode) #.to(self.device) # define GAN loss.
cfg.model.gan_mode) #.to(self.device) # define GAN loss.
self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss()
self.criterionL1 = paddle.nn.L1Loss()
self.criterionL2 = paddle.nn.MSELoss()
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
# self.optimizer_D = paddle.optimizer.Adam(learning_rate=lr_scheduler_d, parameter_list=self.netD_A.parameters() + self.netD_B.parameters(), beta1=opt.beta1)
self.optimizer_DA = build_optimizer(
opt.optimizer,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_DA'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters())
self.optimizer_DB = build_optimizer(
opt.optimizer,
parameter_list=self.nets['netD_A'].parameters())
self.optimizers['optimizer_DB'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netD_B.parameters())
self.optimizers.append(self.optimizer_G)
# self.optimizers.append(self.optimizer_D)
self.optimizers.append(self.optimizer_DA)
self.optimizers.append(self.optimizer_DB)
self.optimizer_names.extend(
['optimizer_G', 'optimizer_DA', 'optimizer_DB'])
parameter_list=self.nets['netD_B'].parameters())
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
......@@ -129,37 +101,47 @@ class MakeupModel(BaseModel):
self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
if self.isTrain:
if self.is_train:
self.mask_A = paddle.to_tensor(input['mask_A'])
self.mask_B = paddle.to_tensor(input['mask_B'])
self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])
#self.hm_gt_A = self.hm_gt_A_lip + self.hm_gt_A_skin + self.hm_gt_A_eye
#self.hm_gt_B = self.hm_gt_B_lip + self.hm_gt_B_skin + self.hm_gt_B_eye
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_A, amm = self.netG(self.real_A, self.real_B, self.P_A,
self.P_B, self.c_m, self.mask_A_aug,
self.fake_A, amm = self.nets['netG'](self.real_A, self.real_B, self.P_A,
self.P_B, self.c_m,
self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.fake_B, _ = self.netG(self.real_B, self.real_A, self.P_B, self.P_A,
self.c_m_t, self.mask_A_aug,
self.fake_B, _ = self.nets['netG'](self.real_B, self.real_A, self.P_B,
self.P_A, self.c_m_t,
self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.rec_A, _ = self.netG(self.fake_A, self.real_A, self.P_A, self.P_A,
self.c_m_idt_a, self.mask_A_aug,
self.rec_A, _ = self.nets['netG'](self.fake_A, self.real_A, self.P_A,
self.P_A, self.c_m_idt_a,
self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.rec_B, _ = self.netG(self.fake_B, self.real_B, self.P_B, self.P_B,
self.c_m_idt_b, self.mask_A_aug,
self.rec_B, _ = self.nets['netG'](self.fake_B, self.real_B, self.P_B,
self.P_B, self.c_m_idt_b,
self.mask_A_aug,
self.mask_B_aug) # G_A(A)
# visual
self.visual_items['real_A'] = self.real_A
self.visual_items['fake_B'] = self.fake_B
self.visual_items['rec_A'] = self.rec_A
self.visual_items['real_B'] = self.real_B
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def forward_test(self, input):
'''
not implement now
'''
return self.netG(input['image_A'], input['image_B'], input['P_A'],
input['P_B'], input['consis_mask'],
input['mask_A_aug'], input['mask_B_aug'])
return self.nets['netG'](input['image_A'], input['image_B'],
input['P_A'], input['P_B'],
input['consis_mask'], input['mask_A_aug'],
input['mask_B_aug'])
def test(self, input):
"""Forward function used in test time.
......@@ -195,51 +177,52 @@ class MakeupModel(BaseModel):
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
fake_B)
self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
fake_A)
self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
'''
self.loss_names = [
'G_A_vgg',
'G_B_vgg',
'G_bg_consis'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A', 'amm_a']
visual_names_B = ['real_B', 'fake_A', 'rec_B', 'amm_b']
'''
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
lambda_idt = self.cfg.lambda_identity
lambda_A = self.cfg.lambda_A
lambda_B = self.cfg.lambda_B
lambda_vgg = 5e-3
# Identity loss
if lambda_idt > 0:
self.idt_A, _ = self.netG(self.real_A, self.real_A, self.P_A,
self.P_A, self.c_m_idt_a, self.mask_A_aug,
self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
self.P_A, self.P_A,
self.c_m_idt_a, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_A) * lambda_A * lambda_idt
self.idt_B, _ = self.netG(self.real_B, self.real_B, self.P_B,
self.P_B, self.c_m_idt_b, self.mask_A_aug,
self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
self.P_B, self.P_B,
self.c_m_idt_b, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_B) * lambda_B * lambda_idt
# visual
self.visual_items['idt_A'] = self.idt_A
self.visual_items['idt_B'] = self.idt_B
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_A), True)
self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A),
True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_B), True)
self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B),
True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
......@@ -381,27 +364,24 @@ class MakeupModel(BaseModel):
self.forward() # compute fake images and reconstruction images.
# G_A and G_B
self.set_requires_grad(
[self.netD_A, self.netD_B],
[self.nets['netD_A'], self.nets['netD_B']],
False) # Ds require no gradients when optimizing Gs
# self.optimizer_G.clear_gradients() #zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.minimize(
self.optimizers['optimizer_G'].minimize(
self.loss_G) #step() # update G_A and G_B's weights
self.optimizer_G.clear_gradients()
# self.optimizer_G.clear_gradients()
self.optimizers['optimizer_G'].clear_gradients()
# D_A and D_B
# self.set_requires_grad([self.netD_A, self.netD_B], True)
self.set_requires_grad(self.netD_A, True)
self.set_requires_grad(self.nets['netD_A'], True)
# self.optimizer_D.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
self.optimizer_DA.minimize(
self.optimizers['optimizer_DA'].minimize(
self.loss_D_A) #step() # update D_A and D_B's weights
self.optimizer_DA.clear_gradients() #zero_g
self.set_requires_grad(self.netD_B, True)
# self.optimizer_DB.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
self.optimizers['optimizer_DA'].clear_gradients() #zero_g
self.set_requires_grad(self.nets['netD_B'], True)
self.backward_D_B() # calculate graidents for D_B
self.optimizer_DB.minimize(
self.optimizers['optimizer_DB'].minimize(
self.loss_D_B) #step() # update D_A and D_B's weights
self.optimizer_DB.clear_gradients(
self.optimizers['optimizer_DB'].clear_gradients(
) #zero_grad() # set D_A and D_B's gradients to zero
......@@ -23,52 +23,38 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
"""
def __init__(self, opt):
def __init__(self, cfg):
"""Initialize the pix2pix class.
Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk.
if self.isTrain:
self.model_names = ['G', 'D']
else:
# during test time, only load G
self.model_names = ['G']
super(Pix2PixModel, self).__init__(cfg)
# define networks (both generator and discriminator)
self.netG = build_generator(opt.model.generator)
init_weights(self.netG)
self.nets['netG'] = build_generator(cfg.model.generator)
init_weights(self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.isTrain:
self.netD = build_discriminator(opt.model.discriminator)
init_weights(self.netD)
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD'])
if self.isTrain:
if self.is_train:
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode)
self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss()
# build optimizers
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
self.optimizer_D = build_optimizer(
opt.optimizer,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.netD.parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
parameter_list=self.nets['netD'].parameters())
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
......@@ -79,38 +65,39 @@ class Pix2PixModel(BaseModel):
The option 'direction' can be used to swap images in domain A and domain B.
"""
AtoB = self.opt.dataset.train.direction == 'AtoB'
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
AtoB = self.cfg.dataset.train.direction == 'AtoB'
# TODO: replace to_varialbe with to_tensor
self.real_A = paddle.fluid.dygraph.to_variable(
input['A' if AtoB else 'B'])
self.real_B = paddle.fluid.dygraph.to_variable(
input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A) # G(A)
self.fake_B = self.nets['netG'](self.real_A) # G(A)
def forward_test(self, input):
input = paddle.to_tensor(input)
return self.netG(input)
# put items to visual dict
self.visual_items['fake_B'] = self.fake_B
self.visual_items['real_A'] = self.real_A
self.visual_items['real_B'] = self.real_B
def backward_D(self):
"""Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B
# use conditional GANs; we need to feed both input and output to the discriminator
fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB.detach())
pred_fake = self.nets['netD'](fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real
real_AB = paddle.concat((self.real_A, self.real_B), 1)
pred_real = self.netD(real_AB)
pred_real = self.nets['netD'](real_AB)
self.loss_D_real = self.criterionGAN(pred_real, True)
# combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
if ParallelEnv().nranks > 1:
self.loss_D = self.netD.scale_loss(self.loss_D)
self.loss_D.backward()
self.netD.apply_collective_grads()
else:
self.loss_D.backward()
self.losses['D_fake_loss'] = self.loss_D_fake
......@@ -120,20 +107,15 @@ class Pix2PixModel(BaseModel):
"""Calculate GAN and L1 loss for the generator"""
# First, G(A) should fake the discriminator
fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB)
pred_fake = self.nets['netD'](fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B
self.loss_G_L1 = self.criterionL1(self.fake_B,
self.real_B) * self.opt.lambda_L1
self.real_B) * self.cfg.lambda_L1
# combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1
if ParallelEnv().nranks > 1:
self.loss_G = self.netG.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG.apply_collective_grads()
else:
self.loss_G.backward()
self.losses['G_adv_loss'] = self.loss_G_GAN
......@@ -144,13 +126,13 @@ class Pix2PixModel(BaseModel):
self.forward()
# update D
self.set_requires_grad(self.netD, True)
self.optimizer_D.clear_gradients()
self.set_requires_grad(self.nets['netD'], True)
self.optimizers['optimizer_D'].clear_grad()
self.backward_D()
self.optimizer_D.minimize(self.loss_D)
self.optimizers['optimizer_D'].step()
# update G
self.set_requires_grad(self.netD, False)
self.optimizer_G.clear_gradients()
self.set_requires_grad(self.nets['netD'], False)
self.optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizer_G.minimize(self.loss_G)
self.optimizers['optimizer_G'].step()
......@@ -30,7 +30,7 @@ class SRModel(BaseModel):
self.loss_names = ['l_total']
self.optimizers = []
if self.isTrain:
if self.is_train:
self.criterionL1 = paddle.nn.L1Loss()
self.build_lr_scheduler()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
from paddle.vision.models.vgg import make_layers
cfg = [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
512, 512, 'M'
]
model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
'89bbffc0f87d260be9b8cdc169c991c4')
}
class VGG(nn.Layer):
def __init__(self, features):
super(VGG, self).__init__()
self.features = features
def forward(self, x):
x = self.features(x)
return x
def vgg16(pretrained=False):
features = make_layers(cfg)
model = VGG(features)
if pretrained:
weight_path = get_weights_path_from_url(model_urls['vgg16'][0],
model_urls['vgg16'][1])
param = paddle.load(weight_path)
model.load_dict(param)
return model
......@@ -25,13 +25,13 @@ class DenseMotionNetwork(nn.Layer):
max_features=max_features,
num_blocks=num_blocks)
self.mask = nn.Conv2d(self.hourglass.out_filters,
self.mask = nn.Conv2D(self.hourglass.out_filters,
num_kp + 1,
kernel_size=(7, 7),
padding=(3, 3))
if estimate_occlusion_map:
self.occlusion = nn.Conv2d(self.hourglass.out_filters,
self.occlusion = nn.Conv2D(self.hourglass.out_filters,
1,
kernel_size=(7, 7),
padding=(3, 3))
......
......@@ -52,16 +52,16 @@ class ResBlock2d(nn.Layer):
"""
def __init__(self, in_features, kernel_size, padding):
super(ResBlock2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features,
self.conv1 = nn.Conv2D(in_channels=in_features,
out_channels=in_features,
kernel_size=kernel_size,
padding=padding)
self.conv2 = nn.Conv2d(in_channels=in_features,
self.conv2 = nn.Conv2D(in_channels=in_features,
out_channels=in_features,
kernel_size=kernel_size,
padding=padding)
self.norm1 = nn.BatchNorm2d(in_features)
self.norm2 = nn.BatchNorm2d(in_features)
self.norm1 = nn.BatchNorm2D(in_features)
self.norm2 = nn.BatchNorm2D(in_features)
def forward(self, x):
out = self.norm1(x)
......@@ -86,12 +86,12 @@ class UpBlock2d(nn.Layer):
groups=1):
super(UpBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features,
self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features,
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_features)
self.norm = nn.BatchNorm2D(out_features)
def forward(self, x):
out = F.interpolate(x, scale_factor=2)
......@@ -112,13 +112,13 @@ class DownBlock2d(nn.Layer):
padding=1,
groups=1):
super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features,
self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features,
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_features)
self.pool = nn.AvgPool2d(kernel_size=(2, 2))
self.norm = nn.BatchNorm2D(out_features)
self.pool = nn.AvgPool2D(kernel_size=(2, 2))
def forward(self, x):
out = self.conv(x)
......@@ -139,12 +139,12 @@ class SameBlock2d(nn.Layer):
kernel_size=3,
padding=1):
super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features,
self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features,
kernel_size=kernel_size,
padding=padding,
groups=groups)
self.norm = nn.BatchNorm2d(out_features)
self.norm = nn.BatchNorm2D(out_features)
def forward(self, x):
out = self.conv(x)
......
......@@ -26,14 +26,14 @@ class KPDetector(nn.Layer):
max_features=max_features,
num_blocks=num_blocks)
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters,
self.kp = nn.Conv2D(in_channels=self.predictor.out_filters,
out_channels=num_kp,
kernel_size=(7, 7),
padding=pad)
if estimate_jacobian:
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters,
self.jacobian = nn.Conv2D(in_channels=self.predictor.out_filters,
out_channels=4 * self.num_jacobian_maps,
kernel_size=(7, 7),
padding=pad)
......
......@@ -21,20 +21,19 @@ def build_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(
nn.BatchNorm,
weight_attr=paddle.ParamAttr(
param_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(1.0, 0.02)),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)),
trainable_statistics=True)
elif norm_type == 'instance':
norm_layer = functools.partial(
nn.InstanceNorm2d,
nn.InstanceNorm2D,
weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0),
learning_rate=0.0,
trainable=False),
bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0),
bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
learning_rate=0.0,
trainable=False))
elif norm_type == 'spectral':
......@@ -44,6 +43,6 @@ def build_norm_layer(norm_type='instance'):
def norm_layer(x):
return Identity()
else:
raise NotImplementedError(
'normalization layer [%s] is not found' % norm_type)
raise NotImplementedError('normalization layer [%s] is not found' %
norm_type)
return norm_layer
......@@ -15,7 +15,8 @@ def save(state_dicts, file_name):
for k, v in state_dict.items():
if isinstance(
v, (paddle.framework.Variable, paddle.fluid.core.VarBase)):
v,
(paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
model_dict[k] = v.numpy()
else:
model_dict[k] = v
......@@ -24,8 +25,9 @@ def save(state_dicts, file_name):
final_dict = {}
for k, v in state_dicts.items():
if isinstance(v,
(paddle.framework.Variable, paddle.fluid.core.VarBase)):
if isinstance(
v,
(paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
final_dict = convert(state_dicts)
break
elif isinstance(v, dict):
......
......@@ -122,11 +122,7 @@ def cal_hist(image):
hists = []
for i in range(0, 3):
channel = image[i]
# channel = image[i, :, :]
#channel = torch.from_numpy(channel)
hist, _ = np.histogram(channel, bins=256, range=(0, 255))
#hist = torch.histc(channel, bins=256, min=0, max=256)
# refHist=hist.view(256,1)
sum = hist.sum()
pdf = [v / sum for v in hist]
for i in range(1, 256):
......
......@@ -2,14 +2,14 @@ import os
import time
import paddle
from paddle.distributed import ParallelEnv
from .logger import setup_logger
def setup(args, cfg):
if args.evaluate_only:
cfg.isTrain = False
cfg.is_train = False
else:
cfg.is_train = True
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir,
......@@ -19,6 +19,7 @@ def setup(args, cfg):
logger.info('Configs: {}'.format(cfg))
place = paddle.CUDAPlace(ParallelEnv().dev_id) \
if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0)
paddle.disable_static(place)
if paddle.is_compiled_with_cuda():
paddle.set_device('gpu')
else:
paddle.set_device('cpu')
......@@ -2,3 +2,5 @@ tqdm
PyYAML>=5.1
scikit-image>=0.14.0
scipy>=1.1.0
opencv-python
imageio-ffmpeg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册