未验证 提交 3e4a59f2 编写于 作者: L LielinJiang 提交者: GitHub

Refine code (#44)

* refine code

* test

* fix apps

* update readme

* rm unused code

* fix apps output when input is image

* clean code

* update requirements.txt
上级 359db9ce
...@@ -61,7 +61,7 @@ pip install -v -e . # or "python setup.py develop" ...@@ -61,7 +61,7 @@ pip install -v -e . # or "python setup.py develop"
Please refer to [data prepare](./docs/data_prepare.md) for dataset preparation. Please refer to [data prepare](./docs/data_prepare.md) for dataset preparation.
## Get Start ## Get Start
Please refer [get stated](./docs/get_started.md) for the basic usage of PaddleGAN. Please refer [get started](./docs/get_started.md) for the basic usage of PaddleGAN.
## Model tutorial ## Model tutorial
* [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md) * [Pixel2Pixel and CycleGAN](./docs/tutorials/pix2pix_cyclegan.md)
......
epochs: 200 epochs: 200
isTrain: True
output_dir: output_dir output_dir: output_dir
lambda_A: 10.0 lambda_A: 10.0
lambda_B: 10.0 lambda_B: 10.0
...@@ -39,12 +38,12 @@ dataset: ...@@ -39,12 +38,12 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [286, 286] size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop - name: RandomCrop
output_size: [256, 256] size: [256, 256]
- name: RandomHorizontalFlip - name: RandomHorizontalFlip
prob: 0.5 prob: 0.5
- name: Permute - name: Transpose
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
...@@ -60,8 +59,8 @@ dataset: ...@@ -60,8 +59,8 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Permute - name: Transpose
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
......
epochs: 200 epochs: 200
isTrain: True
output_dir: output_dir output_dir: output_dir
lambda_A: 10.0 lambda_A: 10.0
lambda_B: 10.0 lambda_B: 10.0
...@@ -38,12 +37,12 @@ dataset: ...@@ -38,12 +37,12 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [286, 286] size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: RandomCrop - name: RandomCrop
output_size: [256, 256] output_size: [256, 256]
- name: RandomHorizontalFlip - name: RandomHorizontalFlip
prob: 0.5 prob: 0.5
- name: Permute - name: Transpose
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
...@@ -60,8 +59,8 @@ dataset: ...@@ -60,8 +59,8 @@ dataset:
transform: transform:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
- name: Permute - name: Transpose
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5] std: [127.5, 127.5, 127.5]
......
epochs: 100 epochs: 100
isTrain: True
output_dir: tmp output_dir: tmp
checkpoints_dir: checkpoints checkpoints_dir: checkpoints
lambda_A: 10.0 lambda_A: 10.0
...@@ -24,14 +23,14 @@ dataset: ...@@ -24,14 +23,14 @@ dataset:
train: train:
name: MakeupDataset name: MakeupDataset
trans_size: 256 trans_size: 256
dataroot: MT-Dataset dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup] cls_list: [non-makeup, makeup]
phase: train phase: train
pool_size: 16 pool_size: 16
test: test:
name: MakeupDataset name: MakeupDataset
trans_size: 256 trans_size: 256
dataroot: MT-Dataset dataroot: data/MT-Dataset
cls_list: [non-makeup, makeup] cls_list: [non-makeup, makeup]
phase: test phase: test
pool_size: 16 pool_size: 16
......
epochs: 200 epochs: 200
isTrain: True
output_dir: output_dir output_dir: output_dir
lambda_L1: 100 lambda_L1: 100
...@@ -36,15 +35,15 @@ dataset: ...@@ -36,15 +35,15 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [286, 286] size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: PairedRandomCrop - name: PairedRandomCrop
output_size: [256, 256] size: [256, 256]
keys: [image, image] keys: [image, image]
- name: PairedRandomHorizontalFlip - name: PairedRandomHorizontalFlip
prob: 0.5 prob: 0.5
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
...@@ -63,9 +62,9 @@ dataset: ...@@ -63,9 +62,9 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
......
epochs: 200 epochs: 200
isTrain: True
output_dir: output_dir output_dir: output_dir
lambda_L1: 100 lambda_L1: 100
...@@ -35,15 +34,15 @@ dataset: ...@@ -35,15 +34,15 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [286, 286] size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: PairedRandomCrop - name: PairedRandomCrop
output_size: [256, 256] size: [256, 256]
keys: [image, image] keys: [image, image]
- name: PairedRandomHorizontalFlip - name: PairedRandomHorizontalFlip
prob: 0.5 prob: 0.5
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
...@@ -62,9 +61,9 @@ dataset: ...@@ -62,9 +61,9 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
......
epochs: 200 epochs: 200
isTrain: True
output_dir: output_dir output_dir: output_dir
lambda_L1: 100 lambda_L1: 100
...@@ -35,15 +34,15 @@ dataset: ...@@ -35,15 +34,15 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [286, 286] size: [286, 286]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: PairedRandomCrop - name: PairedRandomCrop
output_size: [256, 256] size: [256, 256]
keys: [image, image] keys: [image, image]
- name: PairedRandomHorizontalFlip - name: PairedRandomHorizontalFlip
prob: 0.5 prob: 0.5
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
...@@ -62,9 +61,9 @@ dataset: ...@@ -62,9 +61,9 @@ dataset:
transforms: transforms:
- name: Resize - name: Resize
size: [256, 256] size: [256, 256]
interpolation: 2 #cv2.INTER_CUBIC interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image] keys: [image, image]
- name: Permute - name: Transpose
keys: [image, image] keys: [image, image]
- name: Normalize - name: Normalize
mean: [127.5, 127.5, 127.5] mean: [127.5, 127.5, 127.5]
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import cv2 import cv2
from PIL import Image
import paddle import paddle
...@@ -61,9 +62,10 @@ class BasePredictor(object): ...@@ -61,9 +62,10 @@ class BasePredictor(object):
return out return out
def is_video(self, input): def is_image(self, input):
try: try:
cv2.VideoCapture(input) img = Image.open(input)
_ = img.size
return True return True
except: except:
return False return False
......
...@@ -128,13 +128,15 @@ class DeOldifyPredictor(BasePredictor): ...@@ -128,13 +128,15 @@ class DeOldifyPredictor(BasePredictor):
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def run(self, input): def run(self, input):
if self.is_video(input): if not self.is_image(input):
return self.run_video(input) return self.run_video(input)
else: else:
pred_img = self.run_image(input) pred_img = self.run_image(input)
out_path = None
if self.output: if self.output:
base_name = os.path.basename(input) base_name = os.path.splitext(os.path.basename(input))[0]
pred_img.save(os.path.join(self.output, base_name + '.png')) out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)
return pred_img return pred_img, out_path
...@@ -98,13 +98,18 @@ class RealSRPredictor(BasePredictor): ...@@ -98,13 +98,18 @@ class RealSRPredictor(BasePredictor):
return frame_pattern_combined, vid_out_path return frame_pattern_combined, vid_out_path
def run(self, input): def run(self, input):
if self.is_video(input): if not os.path.exists(self.output):
os.makedirs(self.output)
if not self.is_image(input):
return self.run_video(input) return self.run_video(input)
else: else:
pred_img = self.run_image(input) pred_img = self.run_image(input)
out_path = None
if self.output: if self.output:
base_name = os.path.basename(input) base_name = os.path.splitext(os.path.basename(input))[0]
pred_img.save(os.path.join(self.output, base_name + '.png')) out_path = os.path.join(self.output, base_name + '.png')
pred_img.save(out_path)
return pred_img return pred_img, out_path
...@@ -16,7 +16,7 @@ class PairedDataset(BaseDataset): ...@@ -16,7 +16,7 @@ class PairedDataset(BaseDataset):
"""Initialize this dataset class. """Initialize this dataset class.
Args: Args:
cfg (dict) -- stores all the experiment flags cfg (dict): configs of datasets.
""" """
BaseDataset.__init__(self, cfg) BaseDataset.__init__(self, cfg)
self.dir_AB = os.path.join(cfg.dataroot, self.dir_AB = os.path.join(cfg.dataroot,
...@@ -42,7 +42,7 @@ class PairedDataset(BaseDataset): ...@@ -42,7 +42,7 @@ class PairedDataset(BaseDataset):
""" """
# read a image given a random integer index # read a image given a random integer index
AB_path = self.AB_paths[index] AB_path = self.AB_paths[index]
AB = cv2.imread(AB_path) AB = cv2.cvtColor(cv2.imread(AB_path), cv2.COLOR_BGR2RGB)
# split AB image into A and B # split AB image into A and B
h, w = AB.shape[:2] h, w = AB.shape[:2]
......
from .transforms import RandomCrop, Resize, RandomHorizontalFlip, PairedRandomCrop, PairedRandomHorizontalFlip, Normalize, Permute from .transforms import PairedRandomCrop, PairedRandomHorizontalFlip
...@@ -27,6 +27,7 @@ class Compose(object): ...@@ -27,6 +27,7 @@ class Compose(object):
try: try:
data = f(data) data = f(data)
except Exception as e: except Exception as e:
print(f)
stack_info = traceback.format_exc() stack_info = traceback.format_exc()
print("fail to perform transform [{}] with error: " print("fail to perform transform [{}] with error: "
"{} and stack:\n{}".format(f, e, str(stack_info))) "{} and stack:\n{}".format(f, e, str(stack_info)))
......
...@@ -20,7 +20,7 @@ def get_makeup_transform(cfg, pic="image"): ...@@ -20,7 +20,7 @@ def get_makeup_transform(cfg, pic="image"):
if pic == "image": if pic == "image":
transform = T.Compose([ transform = T.Compose([
T.Resize(size=cfg.trans_size), T.Resize(size=cfg.trans_size),
T.Permute(to_rgb=False), T.Transpose(),
]) ])
else: else:
transform = T.Resize(size=cfg.trans_size, transform = T.Resize(size=cfg.trans_size,
......
...@@ -4,7 +4,7 @@ import numbers ...@@ -4,7 +4,7 @@ import numbers
import collections import collections
import numpy as np import numpy as np
from paddle.utils import try_import import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F import paddle.vision.transforms.functional as F
from .builder import TRANSFORMS from .builder import TRANSFORMS
...@@ -16,261 +16,45 @@ else: ...@@ -16,261 +16,45 @@ else:
Sequence = collections.abc.Sequence Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable Iterable = collections.abc.Iterable
TRANSFORMS.register(T.Resize)
class Transform(): TRANSFORMS.register(T.RandomCrop)
def _set_attributes(self, args): TRANSFORMS.register(T.RandomHorizontalFlip)
""" TRANSFORMS.register(T.Normalize)
Set attributes from the input list of parameters. TRANSFORMS.register(T.Transpose)
Args:
args (list): list of parameters.
"""
if args:
for k, v in args.items():
if k != "self" and not k.startswith("_"):
setattr(self, k, v)
def apply_image(self, input):
raise NotImplementedError
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key])
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i])
else:
inputs = self.apply_image(inputs)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register() @TRANSFORMS.register()
class Resize(Transform): class PairedRandomCrop(T.RandomCrop):
"""Resize the input Image to the given size. def __init__(self, size, keys=None):
super().__init__(size, keys=keys)
Args: if isinstance(size, int):
size (int|list|tuple): Desired output size. If size is a sequence like self.size = (size, size)
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Interpolation mode of resize. Default: 1.
0 : cv2.INTER_NEAREST
1 : cv2.INTER_LINEAR
2 : cv2.INTER_CUBIC
3 : cv2.INTER_AREA
4 : cv2.INTER_LANCZOS4
5 : cv2.INTER_LINEAR_EXACT
7 : cv2.INTER_MAX
8 : cv2.WARP_FILL_OUTLIERS
16: cv2.WARP_INVERSE_MAP
"""
def __init__(self, size, interpolation=1, keys=None):
super().__init__()
assert isinstance(size, int) or (isinstance(size, Iterable)
and len(size) == 2)
self._set_attributes(locals())
if isinstance(self.size, Iterable):
self.size = tuple(size)
def apply_image(self, img):
return F.resize(img, self.size, self.interpolation)
@TRANSFORMS.register()
class RandomCrop(Transform):
def __init__(self, output_size, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else: else:
self.output_size = output_size self.size = size
def _get_params(self, img):
h, w, _ = img.shape
th, tw = self.output_size
if w == tw and h == th:
return 0, 0, h, w
i = random.randint(0, h - th) def _get_params(self, inputs):
j = random.randint(0, w - tw) image = inputs[self.keys.index('image')]
return i, j, th, tw params = {}
params['crop_prams'] = self._get_param(image, self.size)
return params
def apply_image(self, img): def _apply_image(self, img):
i, j, h, w = self._get_params(img) i, j, h, w = self.params['crop_prams']
cropped_img = img[i:i + h, j:j + w] return F.crop(img, i, j, h, w)
return cropped_img
@TRANSFORMS.register() @TRANSFORMS.register()
class PairedRandomCrop(RandomCrop): class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def __init__(self, output_size, keys=None):
super().__init__(output_size, keys)
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
self.output_size = output_size
def apply_image(self, img, crop_prams=None):
if crop_prams is not None:
i, j, h, w = crop_prams
else:
i, j, h, w = self._get_params(img)
cropped_img = img[i:i + h, j:j + w]
return cropped_img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
if self.keys is not None:
if isinstance(inputs, dict):
crop_params = self._get_params(inputs[self.keys[0]])
elif isinstance(inputs, (list, tuple)):
crop_params = self._get_params(inputs[0])
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
crop_params)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i],
crop_params)
else:
crop_params = self._get_params(inputs)
inputs = self.apply_image(inputs, crop_params)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class RandomHorizontalFlip(Transform):
"""Horizontally flip the input data randomly with a given probability.
Args:
prob (float): Probability of the input data being flipped. Default: 0.5
"""
def __init__(self, prob=0.5, keys=None): def __init__(self, prob=0.5, keys=None):
super().__init__() super().__init__(prob, keys=keys)
self._set_attributes(locals())
def apply_image(self, img):
if np.random.random() < self.prob:
return F.flip(img, code=1)
return img
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img, flip):
if flip:
return F.flip(img, code=1)
return img
def __call__(self, inputs):
if isinstance(inputs, tuple):
inputs = list(inputs)
flip = np.random.random() < self.prob
if self.keys is not None:
for i, key in enumerate(self.keys):
if isinstance(inputs, dict):
inputs[key] = getattr(self, 'apply_' + key)(inputs[key],
flip)
elif isinstance(inputs, (list, tuple)):
inputs[i] = getattr(self, 'apply_' + key)(inputs[i], flip)
else:
inputs = self.apply_image(inputs, flip)
if isinstance(inputs, list):
inputs = tuple(inputs)
return inputs
@TRANSFORMS.register()
class Normalize(Transform):
"""Normalize the input data with mean and standard deviation.
Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels,
this transform will normalize each channel of the input data.
``output[channel] = (input[channel] - mean[channel]) / std[channel]``
Args:
mean (int|float|list): Sequence of means for each channel.
std (int|float|list): Sequence of standard deviations for each channel.
"""
def __init__(self, mean=0.0, std=1.0, keys=None):
super().__init__()
self._set_attributes(locals())
if isinstance(mean, numbers.Number):
mean = [mean, mean, mean]
if isinstance(std, numbers.Number):
std = [std, std, std]
self.mean = np.array(mean, dtype=np.float32).reshape(len(mean), 1, 1)
self.std = np.array(std, dtype=np.float32).reshape(len(std), 1, 1)
def apply_image(self, img):
return (img - self.mean) / self.std
@TRANSFORMS.register()
class Permute(Transform):
"""Change input data to a target mode.
For example, most transforms use HWC mode image,
while the Neural Network might use CHW mode input tensor.
Input image should be HWC mode and an instance of numpy.ndarray.
Args:
mode (str): Output mode of input. Default: "CHW".
to_rgb (bool): Convert 'bgr' image to 'rgb'. Default: True.
"""
def __init__(self, mode="CHW", to_rgb=True, keys=None):
super().__init__()
self._set_attributes(locals())
assert mode in [
"CHW"
], "Only support 'CHW' mode, but received mode: {}".format(mode)
self.mode = mode
self.to_rgb = to_rgb
def apply_image(self, img):
if self.to_rgb:
img = img[..., ::-1]
if self.mode == "CHW":
return img.transpose((2, 0, 1))
return img
class Crop():
def __init__(self, pos, size):
self.pos = pos
self.size = size
def __call__(self, img): def _get_params(self, inputs):
oh, ow, _ = img.shape params = {}
x, y = self.pos params['flip'] = random.random() < self.prob
th = tw = self.size return params
if (ow > tw or oh > th):
return img[y:y + th, x:x + tw]
return img def _apply_image(self, image):
if self.params['flip']:
return F.hflip(image)
return image
...@@ -64,8 +64,8 @@ class UnpairedDataset(BaseDataset): ...@@ -64,8 +64,8 @@ class UnpairedDataset(BaseDataset):
index_B = random.randint(0, self.B_size - 1) index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B] B_path = self.B_paths[index_B]
A_img = cv2.imread(A_path) A_img = cv2.cvtColor(cv2.imread(A_path), cv2.COLOR_BGR2RGB)
B_img = cv2.imread(B_path) B_img = cv2.cvtColor(cv2.imread(B_path), cv2.COLOR_BGR2RGB)
# apply image transformation # apply image transformation
A = self.transform_A(A_img) A = self.transform_A(A_img)
B = self.transform_B(B_img) B = self.transform_B(B_img)
......
...@@ -10,7 +10,7 @@ from paddle.distributed import ParallelEnv ...@@ -10,7 +10,7 @@ from paddle.distributed import ParallelEnv
from ..datasets.builder import build_dataloader from ..datasets.builder import build_dataloader
from ..models.builder import build_model from ..models.builder import build_model
from ..utils.visual import tensor2img, save_image from ..utils.visual import tensor2img, save_image
from ..utils.filesystem import save, load, makedirs from ..utils.filesystem import makedirs, save, load
from ..utils.timer import TimeAverager from ..utils.timer import TimeAverager
from ..metric.psnr_ssim import calculate_psnr, calculate_ssim from ..metric.psnr_ssim import calculate_psnr, calculate_ssim
...@@ -36,8 +36,8 @@ class Trainer: ...@@ -36,8 +36,8 @@ class Trainer:
# base config # base config
self.output_dir = cfg.output_dir self.output_dir = cfg.output_dir
self.epochs = cfg.epochs self.epochs = cfg.epochs
self.start_epoch = 0 self.start_epoch = 1
self.current_epoch = 0 self.current_epoch = 1
self.batch_id = 0 self.batch_id = 0
self.weight_interval = cfg.snapshot_config.interval self.weight_interval = cfg.snapshot_config.interval
self.log_interval = cfg.log_config.interval self.log_interval = cfg.log_config.interval
...@@ -65,7 +65,7 @@ class Trainer: ...@@ -65,7 +65,7 @@ class Trainer:
reader_cost_averager = TimeAverager() reader_cost_averager = TimeAverager()
batch_cost_averager = TimeAverager() batch_cost_averager = TimeAverager()
for epoch in range(self.start_epoch, self.epochs): for epoch in range(self.start_epoch, self.epochs + 1):
self.current_epoch = epoch self.current_epoch = epoch
start_time = step_start_time = time.time() start_time = step_start_time = time.time()
for i, data in enumerate(self.train_dataloader): for i, data in enumerate(self.train_dataloader):
...@@ -91,8 +91,8 @@ class Trainer: ...@@ -91,8 +91,8 @@ class Trainer:
step_start_time = time.time() step_start_time = time.time()
self.logger.info( self.logger.info('train one epoch time: {}'.format(time.time() -
'train one epoch time: {}'.format(time.time() - start_time)) start_time))
if self.validate_interval > -1 and epoch % self.validate_interval: if self.validate_interval > -1 and epoch % self.validate_interval:
self.validate() self.validate()
self.model.lr_scheduler.step() self.model.lr_scheduler.step()
...@@ -102,8 +102,8 @@ class Trainer: ...@@ -102,8 +102,8 @@ class Trainer:
def validate(self): def validate(self):
if not hasattr(self, 'val_dataloader'): if not hasattr(self, 'val_dataloader'):
self.val_dataloader = build_dataloader( self.val_dataloader = build_dataloader(self.cfg.dataset.val,
self.cfg.dataset.val, is_train=False) is_train=False)
metric_result = {} metric_result = {}
...@@ -149,8 +149,8 @@ class Trainer: ...@@ -149,8 +149,8 @@ class Trainer:
self.visual('visual_val', visual_results=visual_results) self.visual('visual_val', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info( self.logger.info('val iter: [%d/%d]' %
'val iter: [%d/%d]' % (i, len(self.val_dataloader))) (i, len(self.val_dataloader)))
for metric_name in metric_result.keys(): for metric_name in metric_result.keys():
metric_result[metric_name] /= len(self.val_dataloader.dataset) metric_result[metric_name] /= len(self.val_dataloader.dataset)
...@@ -160,8 +160,8 @@ class Trainer: ...@@ -160,8 +160,8 @@ class Trainer:
def test(self): def test(self):
if not hasattr(self, 'test_dataloader'): if not hasattr(self, 'test_dataloader'):
self.test_dataloader = build_dataloader( self.test_dataloader = build_dataloader(self.cfg.dataset.test,
self.cfg.dataset.test, is_train=False) is_train=False)
# data[0]: img, data[1]: img path index # data[0]: img, data[1]: img path index
# test batch size must be 1 # test batch size must be 1
...@@ -185,8 +185,8 @@ class Trainer: ...@@ -185,8 +185,8 @@ class Trainer:
self.visual('visual_test', visual_results=visual_results) self.visual('visual_test', visual_results=visual_results)
if i % self.log_interval == 0: if i % self.log_interval == 0:
self.logger.info( self.logger.info('Test iter: [%d/%d]' %
'Test iter: [%d/%d]' % (i, len(self.test_dataloader))) (i, len(self.test_dataloader)))
def print_log(self): def print_log(self):
losses = self.model.get_current_losses() losses = self.model.get_current_losses()
...@@ -208,7 +208,8 @@ class Trainer: ...@@ -208,7 +208,8 @@ class Trainer:
@property @property
def current_learning_rate(self): def current_learning_rate(self):
return self.model.optimizers[0].get_lr() for optimizer in self.model.optimizers.values():
return optimizer.get_lr()
def visual(self, results_dir, visual_results=None): def visual(self, results_dir, visual_results=None):
self.model.compute_visuals() self.model.compute_visuals()
...@@ -216,7 +217,7 @@ class Trainer: ...@@ -216,7 +217,7 @@ class Trainer:
if visual_results is None: if visual_results is None:
visual_results = self.model.get_current_visuals() visual_results = self.model.get_current_visuals()
if self.cfg.isTrain: if self.cfg.is_train:
msg = 'epoch%.3d_' % self.current_epoch msg = 'epoch%.3d_' % self.current_epoch
else: else:
msg = '' msg = ''
...@@ -240,10 +241,8 @@ class Trainer: ...@@ -240,10 +241,8 @@ class Trainer:
state_dicts = {} state_dicts = {}
save_filename = 'epoch_%s_%s.pkl' % (epoch, name) save_filename = 'epoch_%s_%s.pkl' % (epoch, name)
save_path = os.path.join(self.output_dir, save_filename) save_path = os.path.join(self.output_dir, save_filename)
for net_name in self.model.model_names: for net_name, net in self.model.nets.items():
if isinstance(net_name, str): state_dicts[net_name] = net.state_dict()
net = getattr(self.model, 'net' + net_name)
state_dicts['net' + net_name] = net.state_dict()
if name == 'weight': if name == 'weight':
save(state_dicts, save_path) save(state_dicts, save_path)
...@@ -251,10 +250,8 @@ class Trainer: ...@@ -251,10 +250,8 @@ class Trainer:
state_dicts['epoch'] = epoch state_dicts['epoch'] = epoch
for opt_name in self.model.optimizer_names: for opt_name, opt in self.model.optimizers.items():
if isinstance(opt_name, str): state_dicts[opt_name] = opt.state_dict()
opt = getattr(self.model, opt_name)
state_dicts[opt_name] = opt.state_dict()
save(state_dicts, save_path) save(state_dicts, save_path)
...@@ -273,22 +270,14 @@ class Trainer: ...@@ -273,22 +270,14 @@ class Trainer:
if state_dicts.get('epoch', None) is not None: if state_dicts.get('epoch', None) is not None:
self.start_epoch = state_dicts['epoch'] + 1 self.start_epoch = state_dicts['epoch'] + 1
for name in self.model.model_names: for net_name, net in self.model.nets.items():
if isinstance(name, str): net.set_dict(state_dicts[net_name])
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
for name in self.model.optimizer_names: for opt_name, opt in self.model.optimizers.items():
if isinstance(name, str): opt.set_dict(state_dicts[opt_name])
opt = getattr(self.model, name)
opt.set_dict(state_dicts[name])
def load(self, weight_path): def load(self, weight_path):
state_dicts = load(weight_path) state_dicts = load(weight_path)
for name in self.model.model_names: for net_name, net in self.model.nets.items():
if isinstance(name, str): net.set_dict(state_dicts[net_name])
self.logger.info('laod model {} {} params!'.format(
self.cfg.model.name, 'net' + name))
net = getattr(self.model, 'net' + name)
net.set_dict(state_dicts['net' + name])
...@@ -13,13 +13,10 @@ ...@@ -13,13 +13,10 @@
# limitations under the License. # limitations under the License.
import paddle import paddle
from paddle import nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url from paddle.vision.models import resnet18
import numpy as np
from .resnet import resnet18
class ConvBNReLU(paddle.nn.Layer): class ConvBNReLU(paddle.nn.Layer):
...@@ -32,13 +29,13 @@ class ConvBNReLU(paddle.nn.Layer): ...@@ -32,13 +29,13 @@ class ConvBNReLU(paddle.nn.Layer):
*args, *args,
**kwargs): **kwargs):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan, self.conv = nn.Conv2D(in_chan,
out_chan, out_chan,
kernel_size=ks, kernel_size=ks,
stride=stride, stride=stride,
padding=padding, padding=padding,
bias_attr=False) bias_attr=False)
self.bn = nn.BatchNorm2d(out_chan) self.bn = nn.BatchNorm2D(out_chan)
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, x): def forward(self, x):
...@@ -52,7 +49,7 @@ class BiSeNetOutput(paddle.nn.Layer): ...@@ -52,7 +49,7 @@ class BiSeNetOutput(paddle.nn.Layer):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs): def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__() super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan, self.conv_out = nn.Conv2D(mid_chan,
n_classes, n_classes,
kernel_size=1, kernel_size=1,
bias_attr=False) bias_attr=False)
...@@ -67,7 +64,7 @@ class AttentionRefinementModule(paddle.nn.Layer): ...@@ -67,7 +64,7 @@ class AttentionRefinementModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs): def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__() super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan, self.conv_atten = nn.Conv2D(out_chan,
out_chan, out_chan,
kernel_size=1, kernel_size=1,
bias_attr=False) bias_attr=False)
...@@ -87,16 +84,27 @@ class AttentionRefinementModule(paddle.nn.Layer): ...@@ -87,16 +84,27 @@ class AttentionRefinementModule(paddle.nn.Layer):
class ContextPath(paddle.nn.Layer): class ContextPath(paddle.nn.Layer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__() super(ContextPath, self).__init__()
self.resnet = resnet18() self.backbone = resnet18(pretrained=True)
self.arm16 = AttentionRefinementModule(256, 128) self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128) self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
def backbone_forward(self, x):
x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
x = self.backbone.relu(x)
x = self.backbone.maxpool(x)
x = self.backbone.layer1(x)
c2 = self.backbone.layer2(x)
c3 = self.backbone.layer3(c2)
c4 = self.backbone.layer4(c3)
return c2, c3, c4
def forward(self, x): def forward(self, x):
H0, W0 = x.shape[2:] H0, W0 = x.shape[2:]
feat8, feat16, feat32 = self.resnet(x) feat8, feat16, feat32 = self.backbone_forward(x)
H8, W8 = feat8.shape[2:] H8, W8 = feat8.shape[2:]
H16, W16 = feat16.shape[2:] H16, W16 = feat16.shape[2:]
H32, W32 = feat32.shape[2:] H32, W32 = feat32.shape[2:]
...@@ -138,13 +146,13 @@ class FeatureFusionModule(paddle.nn.Layer): ...@@ -138,13 +146,13 @@ class FeatureFusionModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs): def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__() super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan, self.conv1 = nn.Conv2D(out_chan,
out_chan // 4, out_chan // 4,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
bias_attr=False) bias_attr=False)
self.conv2 = nn.Conv2d(out_chan // 4, self.conv2 = nn.Conv2D(out_chan // 4,
out_chan, out_chan,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
......
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
import numpy as np
import math
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'0ba53eea9bc970962d0ef96f7b94057e'),
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm(out_chan)
self.relu = nn.ReLU()
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan,
out_chan,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = self.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(paddle.nn.Layer):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3,
64,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = nn.BatchNorm(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
def forward(self, x):
x = self.conv1(x)
x = self.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def resnet18(pretrained=False, **kwargs):
model = Resnet18()
arch = 'resnet18'
if pretrained:
weight_path = './resnet.pdparams'
param, _ = paddle.load(weight_path)
model.set_dict(param)
return model
...@@ -18,4 +18,3 @@ from .pix2pix_model import Pix2PixModel ...@@ -18,4 +18,3 @@ from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel from .srgan_model import SRGANModel
from .sr_model import SRModel from .sr_model import SRModel
from .makeup_model import MakeupModel from .makeup_model import MakeupModel
from .vgg import vgg16
...@@ -8,7 +8,7 @@ __all__ = [ ...@@ -8,7 +8,7 @@ __all__ = [
def conv3x3(in_planes, out_planes, stride=1): def conv3x3(in_planes, out_planes, stride=1):
"3x3 convolution with padding" "3x3 convolution with padding"
return nn.Conv2d(in_planes, return nn.Conv2D(in_planes,
out_planes, out_planes,
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
...@@ -53,16 +53,16 @@ class Bottleneck(nn.Layer): ...@@ -53,16 +53,16 @@ class Bottleneck(nn.Layer):
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias_attr=False) self.conv1 = nn.Conv2D(inplanes, planes, kernel_size=1, bias_attr=False)
self.bn1 = nn.BatchNorm(planes) self.bn1 = nn.BatchNorm(planes)
self.conv2 = nn.Conv2d(planes, self.conv2 = nn.Conv2D(planes,
planes, planes,
kernel_size=3, kernel_size=3,
stride=stride, stride=stride,
padding=1, padding=1,
bias_attr=False) bias_attr=False)
self.bn2 = nn.BatchNorm(planes) self.bn2 = nn.BatchNorm(planes)
self.conv3 = nn.Conv2d(planes, self.conv3 = nn.Conv2D(planes,
planes * 4, planes * 4,
kernel_size=1, kernel_size=1,
bias_attr=False) bias_attr=False)
...@@ -97,7 +97,7 @@ class ResNet(nn.Layer): ...@@ -97,7 +97,7 @@ class ResNet(nn.Layer):
def __init__(self, block, layers, num_classes=1000): def __init__(self, block, layers, num_classes=1000):
self.inplanes = 64 self.inplanes = 64
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(3, self.conv1 = nn.Conv2D(3,
64, 64,
kernel_size=7, kernel_size=7,
stride=2, stride=2,
...@@ -117,7 +117,7 @@ class ResNet(nn.Layer): ...@@ -117,7 +117,7 @@ class ResNet(nn.Layer):
downsample = None downsample = None
if stride != 1 or self.inplanes != planes * block.expansion: if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
nn.Conv2d(self.inplanes, nn.Conv2D(self.inplanes,
planes * block.expansion, planes * block.expansion,
kernel_size=1, kernel_size=1,
stride=stride, stride=stride,
......
...@@ -17,47 +17,35 @@ class BaseModel(ABC): ...@@ -17,47 +17,35 @@ class BaseModel(ABC):
-- <optimize_parameters>: calculate losses, gradients, and update network weights. -- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options. -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
""" """
def __init__(self, opt): def __init__(self, cfg):
"""Initialize the BaseModel class. """Initialize the BaseModel class.
Parameters: Args:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions cfg (Dict)-- configs of Model.
When creating your custom class, you need to implement your own initialization. When creating your custom class, you need to implement your own initialization.
In this function, you should first call <BaseModel.__init__(self, opt)> In this function, you should first call <super(YourClass, self).__init__(self, cfg)>
Then, you need to define four lists: Then, you need to define four lists:
-- self.losses (str list): specify the training losses that you want to plot and save. -- self.losses (dict): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training. -- self.nets (dict): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save. -- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. -- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
""" """
self.opt = opt self.cfg = cfg
self.isTrain = opt.isTrain self.is_train = cfg.is_train
self.save_dir = os.path.join( self.save_dir = os.path.join(
opt.output_dir, cfg.output_dir,
opt.model.name) # save all the checkpoints to save_dir cfg.model.name) # save all the checkpoints to save_dir
self.losses = OrderedDict() self.losses = OrderedDict()
self.model_names = [] self.nets = OrderedDict()
self.visual_names = [] self.visual_items = OrderedDict()
self.optimizers = [] self.optimizers = OrderedDict()
self.optimizer_names = []
self.image_paths = [] self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau' self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod @abstractmethod
def set_input(self, input): def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
...@@ -78,7 +66,7 @@ class BaseModel(ABC): ...@@ -78,7 +66,7 @@ class BaseModel(ABC):
pass pass
def build_lr_scheduler(self): def build_lr_scheduler(self):
self.lr_scheduler = build_lr_scheduler(self.opt.lr_scheduler) self.lr_scheduler = build_lr_scheduler(self.cfg.lr_scheduler)
def eval(self): def eval(self):
"""Make models eval mode during test time""" """Make models eval mode during test time"""
...@@ -106,12 +94,8 @@ class BaseModel(ABC): ...@@ -106,12 +94,8 @@ class BaseModel(ABC):
return self.image_paths return self.image_paths
def get_current_visuals(self): def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" """Return visualization images."""
visual_ret = OrderedDict() return self.visual_items
for name in self.visual_names:
if isinstance(name, str) and hasattr(self, name):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self): def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
...@@ -119,7 +103,7 @@ class BaseModel(ABC): ...@@ -119,7 +103,7 @@ class BaseModel(ABC):
def set_requires_grad(self, nets, requires_grad=False): def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters: Args:
nets (network list) -- a list of networks nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not requires_grad (bool) -- whether the networks require gradients or not
""" """
...@@ -128,6 +112,4 @@ class BaseModel(ABC): ...@@ -128,6 +112,4 @@ class BaseModel(ABC):
for net in nets: for net in nets:
if net is not None: if net is not None:
for param in net.parameters(): for param in net.parameters():
# print('trainable:', param.trainable)
param.trainable = requires_grad param.trainable = requires_grad
# param.stop_gradient = not requires_grad
...@@ -24,84 +24,63 @@ class CycleGANModel(BaseModel): ...@@ -24,84 +24,63 @@ class CycleGANModel(BaseModel):
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
""" """
def __init__(self, opt): def __init__(self, cfg):
"""Initialize the CycleGAN class. """Initialize the CycleGAN class.
Parameters: Parameters:
opt (config)-- stores all the experiment flags; needs to be a subclass of Dict opt (config)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
BaseModel.__init__(self, opt) super(CycleGANModel, self).__init__(cfg)
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
# if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
# combine visualizations for A and B
self.visual_names = visual_names_A + visual_names_B
# specify the models you want to save to the disk.
if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G_A', 'G_B']
# define networks (both Generators and discriminators) # define networks (both Generators and discriminators)
# The naming is different from those used in the paper. # The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A = build_generator(opt.model.generator) self.nets['netG_A'] = build_generator(cfg.model.generator)
self.netG_B = build_generator(opt.model.generator) self.nets['netG_B'] = build_generator(cfg.model.generator)
init_weights(self.netG_A) init_weights(self.nets['netG_A'])
init_weights(self.netG_B) init_weights(self.nets['netG_B'])
if self.isTrain: # define discriminators if self.is_train: # define discriminators
self.netD_A = build_discriminator(opt.model.discriminator) self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
self.netD_B = build_discriminator(opt.model.discriminator) self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
init_weights(self.netD_A) init_weights(self.nets['netD_A'])
init_weights(self.netD_B) init_weights(self.nets['netD_B'])
if self.isTrain: if self.is_train:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels if cfg.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert ( assert (
opt.dataset.train.input_nc == opt.dataset.train.output_nc) cfg.dataset.train.input_nc == cfg.dataset.train.output_nc)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_A_pool = ImagePool(opt.dataset.train.pool_size) self.fake_A_pool = ImagePool(cfg.dataset.train.pool_size)
# create image buffer to store previously generated images # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.dataset.train.pool_size) self.fake_B_pool = ImagePool(cfg.dataset.train.pool_size)
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode) self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionCycle = paddle.nn.L1Loss() self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss() self.criterionIdt = paddle.nn.L1Loss()
self.build_lr_scheduler() self.build_lr_scheduler()
self.optimizer_G = build_optimizer( self.optimizers['optimizer_G'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netG_A.parameters() + parameter_list=self.nets['netG_A'].parameters() +
self.netG_B.parameters()) self.nets['netG_B'].parameters())
self.optimizer_D = build_optimizer( self.optimizers['optimizer_D'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netD_A.parameters() + parameter_list=self.nets['netD_A'].parameters() +
self.netD_B.parameters()) self.nets['netD_B'].parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
def set_input(self, input): def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters: Args:
input (dict): include the data itself and its metadata information. input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B. The option 'direction' can be used to swap domain A and domain B.
""" """
mode = 'train' if self.isTrain else 'test' mode = 'train' if self.is_train else 'test'
AtoB = self.opt.dataset[mode].direction == 'AtoB' AtoB = self.cfg.dataset[mode].direction == 'AtoB'
if AtoB: if AtoB:
if 'A' in input: if 'A' in input:
...@@ -122,12 +101,22 @@ class CycleGANModel(BaseModel): ...@@ -122,12 +101,22 @@ class CycleGANModel(BaseModel):
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
if hasattr(self, 'real_A'): if hasattr(self, 'real_A'):
self.fake_B = self.netG_A(self.real_A) # G_A(A) self.fake_B = self.nets['netG_A'](self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.rec_A = self.nets['netG_B'](self.fake_B) # G_B(G_A(A))
# visual
self.visual_items['real_A'] = self.real_A
self.visual_items['fake_B'] = self.fake_B
self.visual_items['rec_A'] = self.rec_A
if hasattr(self, 'real_B'): if hasattr(self, 'real_B'):
self.fake_A = self.netG_B(self.real_B) # G_B(B) self.fake_A = self.nets['netG_B'](self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) self.rec_B = self.nets['netG_A'](self.fake_A) # G_A(G_B(B))
# visual
self.visual_items['real_B'] = self.real_B
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def backward_D_basic(self, netD, real, fake): def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator """Calculate GAN loss for the discriminator
...@@ -148,40 +137,43 @@ class CycleGANModel(BaseModel): ...@@ -148,40 +137,43 @@ class CycleGANModel(BaseModel):
loss_D_fake = self.criterionGAN(pred_fake, False) loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients # Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D = (loss_D_real + loss_D_fake) * 0.5
# loss_D.backward()
if ParallelEnv().nranks > 1: loss_D.backward()
loss_D = netD.scale_loss(loss_D)
loss_D.backward()
netD.apply_collective_grads()
else:
loss_D.backward()
return loss_D return loss_D
def backward_D_A(self): def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
fake_B)
self.losses['D_A_loss'] = self.loss_D_A self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
fake_A)
self.losses['D_B_loss'] = self.loss_D_B self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity lambda_idt = self.cfg.lambda_identity
lambda_A = self.opt.lambda_A lambda_A = self.cfg.lambda_A
lambda_B = self.opt.lambda_B lambda_B = self.cfg.lambda_B
# Identity loss # Identity loss
if lambda_idt > 0: if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B|| # G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B) self.idt_A = self.nets['netG_A'](self.real_B)
self.loss_idt_A = self.criterionIdt( self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_B) * lambda_B * lambda_idt self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A|| # G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A) self.idt_B = self.nets['netG_B'](self.real_A)
# visual
self.visual_items['idt_A'] = self.idt_A
self.visual_items['idt_B'] = self.idt_B
self.loss_idt_B = self.criterionIdt( self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_A) * lambda_A * lambda_idt self.idt_B, self.real_A) * lambda_A * lambda_idt
else: else:
...@@ -189,9 +181,11 @@ class CycleGANModel(BaseModel): ...@@ -189,9 +181,11 @@ class CycleGANModel(BaseModel):
self.loss_idt_B = 0 self.loss_idt_B = 0
# GAN loss D_A(G_A(A)) # GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True) self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_B),
True)
# GAN loss D_B(G_B(B)) # GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_A),
True)
# Forward cycle loss || G_B(G_A(A)) - A|| # Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A self.real_A) * lambda_A
...@@ -208,13 +202,7 @@ class CycleGANModel(BaseModel): ...@@ -208,13 +202,7 @@ class CycleGANModel(BaseModel):
# combined loss and calculate gradients # combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if ParallelEnv().nranks > 1: self.loss_G.backward()
self.loss_G = self.netG_A.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG_A.apply_collective_grads()
self.netG_B.apply_collective_grads()
else:
self.loss_G.backward()
def optimize_parameters(self): def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration""" """Calculate losses, gradients, and update network weights; called in every training iteration"""
...@@ -223,21 +211,22 @@ class CycleGANModel(BaseModel): ...@@ -223,21 +211,22 @@ class CycleGANModel(BaseModel):
self.forward() self.forward()
# G_A and G_B # G_A and G_B
# Ds require no gradients when optimizing Gs # Ds require no gradients when optimizing Gs
self.set_requires_grad([self.netD_A, self.netD_B], False) self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']],
False)
# set G_A and G_B's gradients to zero # set G_A and G_B's gradients to zero
self.optimizer_G.clear_gradients() self.optimizers['optimizer_G'].clear_grad()
# calculate gradients for G_A and G_B # calculate gradients for G_A and G_B
self.backward_G() self.backward_G()
# update G_A and G_B's weights # update G_A and G_B's weights
self.optimizer_G.minimize(self.loss_G) self.optimizers['optimizer_G'].step()
# D_A and D_B # D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True) self.set_requires_grad([self.nets['netD_A'], self.nets['netD_B']], True)
# set D_A and D_B's gradients to zero # set D_A and D_B's gradients to zero
self.optimizer_D.clear_gradients() self.optimizers['optimizer_D'].clear_grad()
# calculate gradients for D_A # calculate gradients for D_A
self.backward_D_A() self.backward_D_A()
# calculate graidents for D_B # calculate graidents for D_B
self.backward_D_B() self.backward_D_B()
# update D_A and D_B's weights # update D_A and D_B's weights
self.optimizer_D.minimize(self.loss_D_A + self.loss_D_B) self.optimizers['optimizer_D'].step()
...@@ -41,9 +41,9 @@ class NLayerDiscriminator(nn.Layer): ...@@ -41,9 +41,9 @@ class NLayerDiscriminator(nn.Layer):
if type( if type(
norm_layer norm_layer
) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters ) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func == nn.InstanceNorm2d use_bias = norm_layer.func == nn.InstanceNorm2D
else: else:
use_bias = norm_layer == nn.InstanceNorm2d use_bias = norm_layer == nn.InstanceNorm2D
kw = 4 kw = 4
padw = 1 padw = 1
...@@ -51,7 +51,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -51,7 +51,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral': if norm_type == 'spectral':
sequence = [ sequence = [
Spectralnorm( Spectralnorm(
nn.Conv2d(input_nc, nn.Conv2D(input_nc,
ndf, ndf,
kernel_size=kw, kernel_size=kw,
stride=2, stride=2,
...@@ -60,7 +60,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -60,7 +60,7 @@ class NLayerDiscriminator(nn.Layer):
] ]
else: else:
sequence = [ sequence = [
nn.Conv2d(input_nc, nn.Conv2D(input_nc,
ndf, ndf,
kernel_size=kw, kernel_size=kw,
stride=2, stride=2,
...@@ -76,7 +76,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -76,7 +76,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral': if norm_type == 'spectral':
sequence += [ sequence += [
Spectralnorm( Spectralnorm(
nn.Conv2d(ndf * nf_mult_prev, nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult, ndf * nf_mult,
kernel_size=kw, kernel_size=kw,
stride=2, stride=2,
...@@ -85,7 +85,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -85,7 +85,7 @@ class NLayerDiscriminator(nn.Layer):
] ]
else: else:
sequence += [ sequence += [
nn.Conv2d(ndf * nf_mult_prev, nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult, ndf * nf_mult,
kernel_size=kw, kernel_size=kw,
stride=2, stride=2,
...@@ -100,7 +100,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -100,7 +100,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral': if norm_type == 'spectral':
sequence += [ sequence += [
Spectralnorm( Spectralnorm(
nn.Conv2d(ndf * nf_mult_prev, nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult, ndf * nf_mult,
kernel_size=kw, kernel_size=kw,
stride=1, stride=1,
...@@ -109,7 +109,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -109,7 +109,7 @@ class NLayerDiscriminator(nn.Layer):
] ]
else: else:
sequence += [ sequence += [
nn.Conv2d(ndf * nf_mult_prev, nn.Conv2D(ndf * nf_mult_prev,
ndf * nf_mult, ndf * nf_mult,
kernel_size=kw, kernel_size=kw,
stride=1, stride=1,
...@@ -122,7 +122,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -122,7 +122,7 @@ class NLayerDiscriminator(nn.Layer):
if norm_type == 'spectral': if norm_type == 'spectral':
sequence += [ sequence += [
Spectralnorm( Spectralnorm(
nn.Conv2d(ndf * nf_mult, nn.Conv2D(ndf * nf_mult,
1, 1,
kernel_size=kw, kernel_size=kw,
stride=1, stride=1,
...@@ -131,7 +131,7 @@ class NLayerDiscriminator(nn.Layer): ...@@ -131,7 +131,7 @@ class NLayerDiscriminator(nn.Layer):
] # output 1 channel prediction map ] # output 1 channel prediction map
else: else:
sequence += [ sequence += [
nn.Conv2d(ndf * nf_mult, nn.Conv2D(ndf * nf_mult,
1, 1,
kernel_size=kw, kernel_size=kw,
stride=1, stride=1,
......
...@@ -2,9 +2,9 @@ import numpy as np ...@@ -2,9 +2,9 @@ import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.vision.models import resnet101
from .hook import hook_outputs, model_sizes, dummy_eval from .hook import hook_outputs, model_sizes, dummy_eval
from ..backbones import resnet34, resnet101
from ...modules.nn import Spectralnorm from ...modules.nn import Spectralnorm
...@@ -137,7 +137,7 @@ def custom_conv_layer(ni: int, ...@@ -137,7 +137,7 @@ def custom_conv_layer(ni: int,
bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True bn = norm_type in ('Batch', 'Batchzero') or extra_bn == True
if bias is None: if bias is None:
bias = not bn bias = not bn
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv = conv_func(ni, conv = conv_func(ni,
nf, nf,
...@@ -272,7 +272,7 @@ class PixelShuffle_ICNR(nn.Layer): ...@@ -272,7 +272,7 @@ class PixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale) self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d((1, 0, 1, 0)) self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg')
self.relu = relu(True, leaky=leaky) self.relu = relu(True, leaky=leaky)
...@@ -298,7 +298,7 @@ def conv_layer(ni: int, ...@@ -298,7 +298,7 @@ def conv_layer(ni: int,
if padding is None: padding = (ks - 1) // 2 if not transpose else 0 if padding is None: padding = (ks - 1) // 2 if not transpose else 0
bn = norm_type in ('Batch', 'BatchZero') bn = norm_type in ('Batch', 'BatchZero')
if bias is None: bias = not bn if bias is None: bias = not bn
conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d conv_func = nn.Conv2DTranspose if transpose else nn.Conv1d if is_1d else nn.Conv2D
conv = conv_func(ni, conv = conv_func(ni,
nf, nf,
...@@ -338,7 +338,7 @@ class CustomPixelShuffle_ICNR(nn.Layer): ...@@ -338,7 +338,7 @@ class CustomPixelShuffle_ICNR(nn.Layer):
self.shuf = PixelShuffle(scale) self.shuf = PixelShuffle(scale)
self.pad = ReplicationPad2d((1, 0, 1, 0)) self.pad = ReplicationPad2d([1, 0, 1, 0])
self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg') self.blur = nn.Pool2D(2, pool_stride=1, pool_type='avg')
self.relu = nn.LeakyReLU( self.relu = nn.LeakyReLU(
leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky) leaky) if leaky is not None else nn.ReLU() #relu(True, leaky=leaky)
...@@ -409,7 +409,7 @@ class ReplicationPad2d(nn.Layer): ...@@ -409,7 +409,7 @@ class ReplicationPad2d(nn.Layer):
self.size = size self.size = size
def forward(self, x): def forward(self, x):
return F.pad2d(x, self.size, mode="edge") return F.pad(x, self.size, mode="replicate")
def conv1d(ni: int, def conv1d(ni: int,
...@@ -419,7 +419,7 @@ def conv1d(ni: int, ...@@ -419,7 +419,7 @@ def conv1d(ni: int,
padding: int = 0, padding: int = 0,
bias: bool = False): bias: bool = False):
"Create and initialize a `nn.Conv1d` layer with spectral normalization." "Create and initialize a `nn.Conv1d` layer with spectral normalization."
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias_attr=bias) conv = nn.Conv1D(ni, no, ks, stride=stride, padding=padding, bias_attr=bias)
return Spectralnorm(conv) return Spectralnorm(conv)
......
...@@ -77,7 +77,7 @@ class Hooks(): ...@@ -77,7 +77,7 @@ class Hooks():
def _hook_inner(m, i, o): def _hook_inner(m, i, o):
return o if isinstance( return o if isinstance(
o, paddle.framework.Variable) else o if is_listy(o) else list(o) o, paddle.fluid.framework.Variable) else o if is_listy(o) else list(o)
def hook_output(module, detach=True, grad=False): def hook_output(module, detach=True, grad=False):
......
...@@ -49,22 +49,22 @@ class ResidualBlock(paddle.nn.Layer): ...@@ -49,22 +49,22 @@ class ResidualBlock(paddle.nn.Layer):
bias_attr = None bias_attr = None
self.main = nn.Sequential( self.main = nn.Sequential(
nn.Conv2d(dim_in, nn.Conv2D(dim_in,
dim_out, dim_out,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
bias_attr=False), bias_attr=False),
nn.InstanceNorm2d(dim_out, nn.InstanceNorm2D(dim_out,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr), nn.ReLU(), bias_attr=bias_attr), nn.ReLU(),
nn.Conv2d(dim_out, nn.Conv2D(dim_out,
dim_out, dim_out,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
bias_attr=False), bias_attr=False),
nn.InstanceNorm2d(dim_out, nn.InstanceNorm2D(dim_out,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr)) bias_attr=bias_attr))
...@@ -78,7 +78,7 @@ class StyleResidualBlock(paddle.nn.Layer): ...@@ -78,7 +78,7 @@ class StyleResidualBlock(paddle.nn.Layer):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super(StyleResidualBlock, self).__init__() super(StyleResidualBlock, self).__init__()
self.block1 = nn.Sequential( self.block1 = nn.Sequential(
nn.Conv2d(dim_in, nn.Conv2D(dim_in,
dim_out, dim_out,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
...@@ -86,18 +86,18 @@ class StyleResidualBlock(paddle.nn.Layer): ...@@ -86,18 +86,18 @@ class StyleResidualBlock(paddle.nn.Layer):
bias_attr=False), PONO()) bias_attr=False), PONO())
ks = 3 ks = 3
pw = ks // 2 pw = ks // 2
self.beta1 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw) self.beta1 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma1 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw) self.gamma1 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.block2 = nn.Sequential( self.block2 = nn.Sequential(
nn.ReLU(), nn.ReLU(),
nn.Conv2d(dim_out, nn.Conv2D(dim_out,
dim_out, dim_out,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
bias_attr=False), PONO()) bias_attr=False), PONO())
self.beta2 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw) self.beta2 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma2 = nn.Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw) self.gamma2 = nn.Conv2D(dim_in, dim_out, kernel_size=ks, padding=pw)
def forward(self, x, y): def forward(self, x, y):
"""forward""" """forward"""
...@@ -119,14 +119,14 @@ class MDNet(paddle.nn.Layer): ...@@ -119,14 +119,14 @@ class MDNet(paddle.nn.Layer):
layers = [] layers = []
layers.append( layers.append(
nn.Conv2d(3, nn.Conv2D(3,
conv_dim, conv_dim,
kernel_size=7, kernel_size=7,
stride=1, stride=1,
padding=3, padding=3,
bias_attr=False)) bias_attr=False))
layers.append( layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=None, bias_attr=None)) nn.InstanceNorm2D(conv_dim, weight_attr=None, bias_attr=None))
layers.append(nn.ReLU()) layers.append(nn.ReLU())
...@@ -134,14 +134,14 @@ class MDNet(paddle.nn.Layer): ...@@ -134,14 +134,14 @@ class MDNet(paddle.nn.Layer):
curr_dim = conv_dim curr_dim = conv_dim
for i in range(2): for i in range(2):
layers.append( layers.append(
nn.Conv2d(curr_dim, nn.Conv2D(curr_dim,
curr_dim * 2, curr_dim * 2,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
padding=1, padding=1,
bias_attr=False)) bias_attr=False))
layers.append( layers.append(
nn.InstanceNorm2d(curr_dim * 2, nn.InstanceNorm2D(curr_dim * 2,
weight_attr=None, weight_attr=None,
bias_attr=None)) bias_attr=None))
layers.append(nn.ReLU()) layers.append(nn.ReLU())
...@@ -166,14 +166,14 @@ class TNetDown(paddle.nn.Layer): ...@@ -166,14 +166,14 @@ class TNetDown(paddle.nn.Layer):
layers = [] layers = []
layers.append( layers.append(
nn.Conv2d(3, nn.Conv2D(3,
conv_dim, conv_dim,
kernel_size=7, kernel_size=7,
stride=1, stride=1,
padding=3, padding=3,
bias_attr=False)) bias_attr=False))
layers.append( layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=False, bias_attr=False)) nn.InstanceNorm2D(conv_dim, weight_attr=False, bias_attr=False))
layers.append(nn.ReLU()) layers.append(nn.ReLU())
...@@ -181,14 +181,14 @@ class TNetDown(paddle.nn.Layer): ...@@ -181,14 +181,14 @@ class TNetDown(paddle.nn.Layer):
curr_dim = conv_dim curr_dim = conv_dim
for i in range(2): for i in range(2):
layers.append( layers.append(
nn.Conv2d(curr_dim, nn.Conv2D(curr_dim,
curr_dim * 2, curr_dim * 2,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
padding=1, padding=1,
bias_attr=False)) bias_attr=False))
layers.append( layers.append(
nn.InstanceNorm2d(curr_dim * 2, nn.InstanceNorm2D(curr_dim * 2,
weight_attr=False, weight_attr=False,
bias_attr=False)) bias_attr=False))
layers.append(nn.ReLU()) layers.append(nn.ReLU())
...@@ -210,13 +210,13 @@ class TNetDown(paddle.nn.Layer): ...@@ -210,13 +210,13 @@ class TNetDown(paddle.nn.Layer):
class GetMatrix(paddle.fluid.dygraph.Layer): class GetMatrix(paddle.fluid.dygraph.Layer):
def __init__(self, dim_in, dim_out): def __init__(self, dim_in, dim_out):
super(GetMatrix, self).__init__() super(GetMatrix, self).__init__()
self.get_gamma = nn.Conv2d(dim_in, self.get_gamma = nn.Conv2D(dim_in,
dim_out, dim_out,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
bias_attr=False) bias_attr=False)
self.get_beta = nn.Conv2d(dim_in, self.get_beta = nn.Conv2D(dim_in,
dim_out, dim_out,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
...@@ -236,8 +236,8 @@ class MANet(paddle.nn.Layer): ...@@ -236,8 +236,8 @@ class MANet(paddle.nn.Layer):
self.encoder = TNetDown(conv_dim=conv_dim, repeat_num=repeat_num) self.encoder = TNetDown(conv_dim=conv_dim, repeat_num=repeat_num)
curr_dim = conv_dim * 4 curr_dim = conv_dim * 4
self.w = w self.w = w
self.beta = nn.Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1) self.beta = nn.Conv2D(curr_dim, curr_dim, kernel_size=3, padding=1)
self.gamma = nn.Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1) self.gamma = nn.Conv2D(curr_dim, curr_dim, kernel_size=3, padding=1)
self.simple_spade = GetMatrix(curr_dim, 1) # get the makeup matrix self.simple_spade = GetMatrix(curr_dim, 1) # get the makeup matrix
self.repeat_num = repeat_num self.repeat_num = repeat_num
for i in range(repeat_num): for i in range(repeat_num):
...@@ -252,28 +252,28 @@ class MANet(paddle.nn.Layer): ...@@ -252,28 +252,28 @@ class MANet(paddle.nn.Layer):
for i in range(2): for i in range(2):
layers = [] layers = []
layers.append( layers.append(
nn.ConvTranspose2d(curr_dim, nn.Conv2DTranspose(curr_dim,
curr_dim // 2, curr_dim // 2,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
padding=1, padding=1,
bias_attr=False)) bias_attr=False))
layers.append( layers.append(
nn.InstanceNorm2d(curr_dim // 2, nn.InstanceNorm2D(curr_dim // 2,
weight_attr=False, weight_attr=False,
bias_attr=False)) bias_attr=False))
setattr(self, "up_acts_" + str(i), nn.ReLU()) setattr(self, "up_acts_" + str(i), nn.ReLU())
setattr( setattr(
self, "up_betas_" + str(i), self, "up_betas_" + str(i),
nn.ConvTranspose2d(y_dim, nn.Conv2DTranspose(y_dim,
curr_dim // 2, curr_dim // 2,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
padding=1)) padding=1))
setattr( setattr(
self, "up_gammas_" + str(i), self, "up_gammas_" + str(i),
nn.ConvTranspose2d(y_dim, nn.Conv2DTranspose(y_dim,
curr_dim // 2, curr_dim // 2,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
...@@ -281,7 +281,7 @@ class MANet(paddle.nn.Layer): ...@@ -281,7 +281,7 @@ class MANet(paddle.nn.Layer):
setattr(self, "up_samplers_" + str(i), nn.Sequential(*layers)) setattr(self, "up_samplers_" + str(i), nn.Sequential(*layers))
curr_dim = curr_dim // 2 curr_dim = curr_dim // 2
self.img_reg = [ self.img_reg = [
nn.Conv2d(curr_dim, nn.Conv2D(curr_dim,
3, 3,
kernel_size=7, kernel_size=7,
stride=1, stride=1,
......
...@@ -17,6 +17,7 @@ import functools ...@@ -17,6 +17,7 @@ import functools
from ...modules.norm import build_norm_layer from ...modules.norm import build_norm_layer
from .builder import GENERATORS from .builder import GENERATORS
@GENERATORS.register() @GENERATORS.register()
class MobileResnetGenerator(nn.Layer): class MobileResnetGenerator(nn.Layer):
def __init__(self, def __init__(self,
...@@ -31,64 +32,64 @@ class MobileResnetGenerator(nn.Layer): ...@@ -31,64 +32,64 @@ class MobileResnetGenerator(nn.Layer):
norm_layer = build_norm_layer(norm_type) norm_layer = build_norm_layer(norm_type)
if type(norm_layer) == functools.partial: if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == InstanceNorm use_bias = norm_layer.func == nn.InstanceNorm2D
else: else:
use_bias = norm_layer == InstanceNorm use_bias = norm_layer == nn.InstanceNorm2D
self.model = nn.LayerList([ self.model = nn.LayerList([
nn.ReflectionPad2d([3, 3, 3, 3]), nn.ReflectionPad2d([3, 3, 3, 3]),
nn.Conv2d( nn.Conv2D(input_channel,
input_channel, int(ngf),
int(ngf), kernel_size=7,
kernel_size=7, padding=0,
padding=0, bias_attr=use_bias),
bias_attr=use_bias), norm_layer(ngf), nn.ReLU() norm_layer(ngf),
nn.ReLU()
]) ])
n_downsampling = 2 n_downsampling = 2
for i in range(n_downsampling): for i in range(n_downsampling):
mult = 2**i mult = 2**i
self.model.extend([ self.model.extend([
nn.Conv2d( nn.Conv2D(ngf * mult,
ngf * mult, ngf * mult * 2,
ngf * mult * 2, kernel_size=3,
kernel_size=3, stride=2,
stride=2, padding=1,
padding=1, bias_attr=use_bias),
bias_attr=use_bias), norm_layer(ngf * mult * 2), nn.ReLU() norm_layer(ngf * mult * 2),
nn.ReLU()
]) ])
mult = 2**n_downsampling mult = 2**n_downsampling
for i in range(n_blocks): for i in range(n_blocks):
self.model.extend([ self.model.extend([
MobileResnetBlock( MobileResnetBlock(ngf * mult,
ngf * mult, ngf * mult,
ngf * mult, padding_type=padding_type,
padding_type=padding_type, norm_layer=norm_layer,
norm_layer=norm_layer, use_dropout=use_dropout,
use_dropout=use_dropout, use_bias=use_bias)
use_bias=use_bias)
]) ])
for i in range(n_downsampling): for i in range(n_downsampling):
mult = 2**(n_downsampling - i) mult = 2**(n_downsampling - i)
output_size = (i + 1) * 128 output_size = (i + 1) * 128
self.model.extend([ self.model.extend([
nn.ConvTranspose2d( nn.Conv2DTranspose(ngf * mult,
ngf * mult, int(ngf * mult / 2),
int(ngf * mult / 2), kernel_size=3,
kernel_size=3, stride=2,
stride=2, padding=1,
padding=1, output_padding=1,
output_padding=1, bias_attr=use_bias),
bias_attr=use_bias), norm_layer(int(ngf * mult / 2)), norm_layer(int(ngf * mult / 2)),
nn.ReLU() nn.ReLU()
]) ])
self.model.extend([nn.ReflectionPad2d([3, 3, 3, 3])]) self.model.extend([nn.ReflectionPad2d([3, 3, 3, 3])])
self.model.extend([nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]) self.model.extend([nn.Conv2D(ngf, output_nc, kernel_size=7, padding=0)])
self.model.extend([nn.Tanh()]) self.model.extend([nn.Tanh()])
def forward(self, inputs): def forward(self, inputs):
...@@ -108,9 +109,9 @@ class MobileResnetBlock(nn.Layer): ...@@ -108,9 +109,9 @@ class MobileResnetBlock(nn.Layer):
p = 0 p = 0
if self.padding_type == 'reflect': if self.padding_type == 'reflect':
self.conv_block.extend([nn.ReflectionPad2d([1, 1, 1, 1])]) self.conv_block.extend([nn.Pad2D([1, 1, 1, 1], mode='reflect')])
elif self.padding_type == 'replicate': elif self.padding_type == 'replicate':
self.conv_block.extend([nn.ReplicationPad2d([1, 1, 1, 1])]) self.conv_block.extend([nn.Pad2D([1, 1, 1, 1], mode='replicate')])
elif self.padding_type == 'zero': elif self.padding_type == 'zero':
p = 1 p = 1
else: else:
...@@ -118,12 +119,13 @@ class MobileResnetBlock(nn.Layer): ...@@ -118,12 +119,13 @@ class MobileResnetBlock(nn.Layer):
self.padding_type) self.padding_type)
self.conv_block.extend([ self.conv_block.extend([
SeparableConv2D( SeparableConv2D(num_channels=in_c,
num_channels=in_c, num_filters=out_c,
num_filters=out_c, filter_size=3,
filter_size=3, padding=p,
padding=p, stride=1),
stride=1), norm_layer(out_c), nn.ReLU() norm_layer(out_c),
nn.ReLU()
]) ])
self.conv_block.extend([nn.Dropout(0.5)]) self.conv_block.extend([nn.Dropout(0.5)])
...@@ -139,12 +141,12 @@ class MobileResnetBlock(nn.Layer): ...@@ -139,12 +141,12 @@ class MobileResnetBlock(nn.Layer):
self.padding_type) self.padding_type)
self.conv_block.extend([ self.conv_block.extend([
SeparableConv2D( SeparableConv2D(num_channels=out_c,
num_channels=out_c, num_filters=in_c,
num_filters=in_c, filter_size=3,
filter_size=3, padding=p,
padding=p, stride=1),
stride=1), norm_layer(in_c) norm_layer(in_c)
]) ])
def forward(self, inputs): def forward(self, inputs):
...@@ -154,6 +156,7 @@ class MobileResnetBlock(nn.Layer): ...@@ -154,6 +156,7 @@ class MobileResnetBlock(nn.Layer):
out = inputs + y out = inputs + y
return out return out
class SeparableConv2D(nn.Layer): class SeparableConv2D(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
...@@ -161,14 +164,14 @@ class SeparableConv2D(nn.Layer): ...@@ -161,14 +164,14 @@ class SeparableConv2D(nn.Layer):
filter_size, filter_size,
stride=1, stride=1,
padding=0, padding=0,
norm_layer=InstanceNorm, norm_layer=nn.InstanceNorm2D,
use_bias=True, use_bias=True,
scale_factor=1, scale_factor=1,
stddev=0.02): stddev=0.02):
super(SeparableConv2D, self).__init__() super(SeparableConv2D, self).__init__()
self.conv = nn.LayerList([ self.conv = nn.LayerList([
nn.Conv2d( nn.Conv2D(
in_channels=num_channels, in_channels=num_channels,
out_channels=num_channels * scale_factor, out_channels=num_channels * scale_factor,
kernel_size=filter_size, kernel_size=filter_size,
...@@ -176,22 +179,20 @@ class SeparableConv2D(nn.Layer): ...@@ -176,22 +179,20 @@ class SeparableConv2D(nn.Layer):
padding=padding, padding=padding,
groups=num_channels, groups=num_channels,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(loc=0.0, scale=stddev)),
loc=0.0, scale=stddev)),
bias_attr=use_bias) bias_attr=use_bias)
]) ])
self.conv.extend([norm_layer(num_channels * scale_factor)]) self.conv.extend([norm_layer(num_channels * scale_factor)])
self.conv.extend([ self.conv.extend([
nn.Conv2d( nn.Conv2D(
in_channels=num_channels * scale_factor, in_channels=num_channels * scale_factor,
out_channels=num_filters, out_channels=num_filters,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal( initializer=nn.initializer.Normal(loc=0.0, scale=stddev)),
loc=0.0, scale=stddev)),
bias_attr=use_bias) bias_attr=use_bias)
]) ])
...@@ -199,4 +200,3 @@ class SeparableConv2D(nn.Layer): ...@@ -199,4 +200,3 @@ class SeparableConv2D(nn.Layer):
for sublayer in self.conv: for sublayer in self.conv:
inputs = sublayer(inputs) inputs = sublayer(inputs)
return inputs return inputs
...@@ -67,7 +67,7 @@ class OcclusionAwareGenerator(nn.Layer): ...@@ -67,7 +67,7 @@ class OcclusionAwareGenerator(nn.Layer):
'r' + str(i), 'r' + str(i),
ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1))) ResBlock2d(in_features, kernel_size=(3, 3), padding=(1, 1)))
self.final = nn.Conv2d(block_expansion, self.final = nn.Conv2D(block_expansion,
num_channels, num_channels,
kernel_size=(7, 7), kernel_size=(7, 7),
padding=(3, 3)) padding=(3, 3))
......
...@@ -11,7 +11,7 @@ class TempConv(nn.Layer): ...@@ -11,7 +11,7 @@ class TempConv(nn.Layer):
stride=(1, 1, 1), stride=(1, 1, 1),
padding=(0, 1, 1)): padding=(0, 1, 1)):
super(TempConv, self).__init__() super(TempConv, self).__init__()
self.conv3d = nn.Conv3d(in_planes, self.conv3d = nn.Conv3D(in_planes,
out_planes, out_planes,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
...@@ -26,7 +26,7 @@ class Upsample(nn.Layer): ...@@ -26,7 +26,7 @@ class Upsample(nn.Layer):
def __init__(self, in_planes, out_planes, scale_factor=(1, 2, 2)): def __init__(self, in_planes, out_planes, scale_factor=(1, 2, 2)):
super(Upsample, self).__init__() super(Upsample, self).__init__()
self.scale_factor = scale_factor self.scale_factor = scale_factor
self.conv3d = nn.Conv3d(in_planes, self.conv3d = nn.Conv3D(in_planes,
out_planes, out_planes,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
...@@ -88,13 +88,13 @@ class SourceReferenceAttention(nn.Layer): ...@@ -88,13 +88,13 @@ class SourceReferenceAttention(nn.Layer):
Number of input reference feature vector channels. Number of input reference feature vector channels.
""" """
super(SourceReferenceAttention, self).__init__() super(SourceReferenceAttention, self).__init__()
self.query_conv = nn.Conv3d(in_channels=in_planes_s, self.query_conv = nn.Conv3D(in_channels=in_planes_s,
out_channels=in_planes_s // 8, out_channels=in_planes_s // 8,
kernel_size=1) kernel_size=1)
self.key_conv = nn.Conv3d(in_channels=in_planes_r, self.key_conv = nn.Conv3D(in_channels=in_planes_r,
out_channels=in_planes_r // 8, out_channels=in_planes_r // 8,
kernel_size=1) kernel_size=1)
self.value_conv = nn.Conv3d(in_channels=in_planes_r, self.value_conv = nn.Conv3D(in_channels=in_planes_r,
out_channels=in_planes_r, out_channels=in_planes_r,
kernel_size=1) kernel_size=1)
self.gamma = self.create_parameter( self.gamma = self.create_parameter(
...@@ -128,7 +128,7 @@ class NetworkR(nn.Layer): ...@@ -128,7 +128,7 @@ class NetworkR(nn.Layer):
super(NetworkR, self).__init__() super(NetworkR, self).__init__()
self.layers = nn.Sequential( self.layers = nn.Sequential(
nn.ReplicationPad3d((1, 1, 1, 1, 1, 1)), nn.Pad3D((1, 1, 1, 1, 1, 1), mode='replicate'),
TempConv(1, TempConv(1,
64, 64,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
...@@ -149,7 +149,7 @@ class NetworkR(nn.Layer): ...@@ -149,7 +149,7 @@ class NetworkR(nn.Layer):
TempConv(128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)), TempConv(128, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
TempConv(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)), TempConv(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
Upsample(64, 16), Upsample(64, 16),
nn.Conv3d(16, nn.Conv3D(16,
1, 1,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
...@@ -165,7 +165,7 @@ class NetworkC(nn.Layer): ...@@ -165,7 +165,7 @@ class NetworkC(nn.Layer):
super(NetworkC, self).__init__() super(NetworkC, self).__init__()
self.down1 = nn.Sequential( self.down1 = nn.Sequential(
nn.ReplicationPad3d((1, 1, 1, 1, 0, 0)), nn.Pad3D((1, 1, 1, 1, 0, 0), mode='replicate'),
TempConv(1, 64, stride=(1, 2, 2), padding=(0, 0, 0)), TempConv(1, 64, stride=(1, 2, 2), padding=(0, 0, 0)),
TempConv(64, 128), TempConv(128, 128), TempConv(64, 128), TempConv(128, 128),
TempConv(128, 256, stride=(1, 2, 2)), TempConv(256, 256), TempConv(128, 256, stride=(1, 2, 2)), TempConv(256, 256),
...@@ -205,7 +205,7 @@ class NetworkC(nn.Layer): ...@@ -205,7 +205,7 @@ class NetworkC(nn.Layer):
padding=(1, 1, 1))) padding=(1, 1, 1)))
self.up4 = nn.Sequential( self.up4 = nn.Sequential(
Upsample(16, 8), # 1/1 Upsample(16, 8), # 1/1
nn.Conv3d(8, nn.Conv3D(8,
2, 2,
kernel_size=(3, 3, 3), kernel_size=(3, 3, 3),
stride=(1, 1, 1), stride=(1, 1, 1),
......
...@@ -13,7 +13,6 @@ class ResnetGenerator(nn.Layer): ...@@ -13,7 +13,6 @@ class ResnetGenerator(nn.Layer):
code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
""" """
def __init__(self, def __init__(self,
input_nc, input_nc,
output_nc, output_nc,
...@@ -38,14 +37,17 @@ class ResnetGenerator(nn.Layer): ...@@ -38,14 +37,17 @@ class ResnetGenerator(nn.Layer):
norm_layer = build_norm_layer(norm_type) norm_layer = build_norm_layer(norm_type)
if type(norm_layer) == functools.partial: if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d use_bias = norm_layer.func == nn.InstanceNorm2D
else: else:
use_bias = norm_layer == nn.InstanceNorm2d use_bias = norm_layer == nn.InstanceNorm2D
model = [ model = [
nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"), nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect"),
nn.Conv2d( nn.Conv2D(input_nc,
input_nc, ngf, kernel_size=7, padding=0, bias_attr=use_bias), ngf,
kernel_size=7,
padding=0,
bias_attr=use_bias),
norm_layer(ngf), norm_layer(ngf),
nn.ReLU() nn.ReLU()
] ]
...@@ -54,13 +56,12 @@ class ResnetGenerator(nn.Layer): ...@@ -54,13 +56,12 @@ class ResnetGenerator(nn.Layer):
for i in range(n_downsampling): # add downsampling layers for i in range(n_downsampling): # add downsampling layers
mult = 2**i mult = 2**i
model += [ model += [
nn.Conv2d( nn.Conv2D(ngf * mult,
ngf * mult, ngf * mult * 2,
ngf * mult * 2, kernel_size=3,
kernel_size=3, stride=2,
stride=2, padding=1,
padding=1, bias_attr=use_bias),
bias_attr=use_bias),
norm_layer(ngf * mult * 2), norm_layer(ngf * mult * 2),
nn.ReLU() nn.ReLU()
] ]
...@@ -69,30 +70,28 @@ class ResnetGenerator(nn.Layer): ...@@ -69,30 +70,28 @@ class ResnetGenerator(nn.Layer):
for i in range(n_blocks): # add ResNet blocks for i in range(n_blocks): # add ResNet blocks
model += [ model += [
ResnetBlock( ResnetBlock(ngf * mult,
ngf * mult, padding_type=padding_type,
padding_type=padding_type, norm_layer=norm_layer,
norm_layer=norm_layer, use_dropout=use_dropout,
use_dropout=use_dropout, use_bias=use_bias)
use_bias=use_bias)
] ]
for i in range(n_downsampling): # add upsampling layers for i in range(n_downsampling): # add upsampling layers
mult = 2**(n_downsampling - i) mult = 2**(n_downsampling - i)
model += [ model += [
nn.ConvTranspose2d( nn.Conv2DTranspose(ngf * mult,
ngf * mult, int(ngf * mult / 2),
int(ngf * mult / 2), kernel_size=3,
kernel_size=3, stride=2,
stride=2, padding=1,
padding=1, output_padding=1,
output_padding=1, bias_attr=use_bias),
bias_attr=use_bias),
norm_layer(int(ngf * mult / 2)), norm_layer(int(ngf * mult / 2)),
nn.ReLU() nn.ReLU()
] ]
model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")] model += [nn.Pad2D(padding=[3, 3, 3, 3], mode="reflect")]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Conv2D(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()] model += [nn.Tanh()]
self.model = nn.Sequential(*model) self.model = nn.Sequential(*model)
...@@ -104,7 +103,6 @@ class ResnetGenerator(nn.Layer): ...@@ -104,7 +103,6 @@ class ResnetGenerator(nn.Layer):
class ResnetBlock(nn.Layer): class ResnetBlock(nn.Layer):
"""Define a Resnet block""" """Define a Resnet block"""
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
"""Initialize the Resnet block """Initialize the Resnet block
...@@ -137,11 +135,11 @@ class ResnetBlock(nn.Layer): ...@@ -137,11 +135,11 @@ class ResnetBlock(nn.Layer):
elif padding_type == 'zero': elif padding_type == 'zero':
p = 1 p = 1
else: else:
raise NotImplementedError( raise NotImplementedError('padding [%s] is not implemented' %
'padding [%s] is not implemented' % padding_type) padding_type)
conv_block += [ conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias), nn.Conv2D(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim), norm_layer(dim),
nn.ReLU() nn.ReLU()
] ]
...@@ -154,10 +152,10 @@ class ResnetBlock(nn.Layer): ...@@ -154,10 +152,10 @@ class ResnetBlock(nn.Layer):
elif padding_type == 'zero': elif padding_type == 'zero':
p = 1 p = 1
else: else:
raise NotImplementedError( raise NotImplementedError('padding [%s] is not implemented' %
'padding [%s] is not implemented' % padding_type) padding_type)
conv_block += [ conv_block += [
nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias), nn.Conv2D(dim, dim, kernel_size=3, padding=p, bias_attr=use_bias),
norm_layer(dim) norm_layer(dim)
] ]
......
...@@ -10,14 +10,13 @@ class ResidualDenseBlock_5C(nn.Layer): ...@@ -10,14 +10,13 @@ class ResidualDenseBlock_5C(nn.Layer):
def __init__(self, nf=64, gc=32, bias=True): def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__() super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels # gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias_attr=bias) self.conv1 = nn.Conv2D(nf, gc, 3, 1, 1, bias_attr=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias_attr=bias) self.conv2 = nn.Conv2D(nf + gc, gc, 3, 1, 1, bias_attr=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias) self.conv3 = nn.Conv2D(nf + 2 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias) self.conv4 = nn.Conv2D(nf + 3 * gc, gc, 3, 1, 1, bias_attr=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias) self.conv5 = nn.Conv2D(nf + 4 * gc, nf, 3, 1, 1, bias_attr=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2) self.lrelu = nn.LeakyReLU(negative_slope=0.2)
def forward(self, x): def forward(self, x):
x1 = self.lrelu(self.conv1(x)) x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1))) x2 = self.lrelu(self.conv2(paddle.concat((x, x1), 1)))
...@@ -29,7 +28,6 @@ class ResidualDenseBlock_5C(nn.Layer): ...@@ -29,7 +28,6 @@ class ResidualDenseBlock_5C(nn.Layer):
class RRDB(nn.Layer): class RRDB(nn.Layer):
'''Residual in Residual Dense Block''' '''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32): def __init__(self, nf, gc=32):
super(RRDB, self).__init__() super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc) self.RDB1 = ResidualDenseBlock_5C(nf, gc)
...@@ -42,6 +40,7 @@ class RRDB(nn.Layer): ...@@ -42,6 +40,7 @@ class RRDB(nn.Layer):
out = self.RDB3(out) out = self.RDB3(out)
return out * 0.2 + x return out * 0.2 + x
def make_layer(block, n_layers): def make_layer(block, n_layers):
layers = [] layers = []
for _ in range(n_layers): for _ in range(n_layers):
...@@ -55,14 +54,14 @@ class RRDBNet(nn.Layer): ...@@ -55,14 +54,14 @@ class RRDBNet(nn.Layer):
super(RRDBNet, self).__init__() super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias_attr=True) self.conv_first = nn.Conv2D(in_nc, nf, 3, 1, 1, bias_attr=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb) self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) self.trunk_conv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
#### upsampling #### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) self.upconv1 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) self.upconv2 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias_attr=True) self.HRconv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias_attr=True) self.conv_last = nn.Conv2D(nf, out_nc, 3, 1, 1, bias_attr=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2) self.lrelu = nn.LeakyReLU(negative_slope=0.2)
...@@ -71,8 +70,10 @@ class RRDBNet(nn.Layer): ...@@ -71,8 +70,10 @@ class RRDBNet(nn.Layer):
trunk = self.trunk_conv(self.RRDB_trunk(fea)) trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) fea = self.lrelu(
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
fea = self.lrelu(
self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea))) out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out return out
...@@ -104,12 +104,12 @@ class UnetSkipConnectionBlock(nn.Layer): ...@@ -104,12 +104,12 @@ class UnetSkipConnectionBlock(nn.Layer):
super(UnetSkipConnectionBlock, self).__init__() super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost self.outermost = outermost
if type(norm_layer) == functools.partial: if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm use_bias = norm_layer.func == nn.InstanceNorm2D
else: else:
use_bias = norm_layer == nn.InstanceNorm use_bias = norm_layer == nn.InstanceNorm2D
if input_nc is None: if input_nc is None:
input_nc = outer_nc input_nc = outer_nc
downconv = nn.Conv2d(input_nc, downconv = nn.Conv2D(input_nc,
inner_nc, inner_nc,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
...@@ -121,7 +121,7 @@ class UnetSkipConnectionBlock(nn.Layer): ...@@ -121,7 +121,7 @@ class UnetSkipConnectionBlock(nn.Layer):
upnorm = norm_layer(outer_nc) upnorm = norm_layer(outer_nc)
if outermost: if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, upconv = nn.Conv2DTranspose(inner_nc * 2,
outer_nc, outer_nc,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
...@@ -130,7 +130,7 @@ class UnetSkipConnectionBlock(nn.Layer): ...@@ -130,7 +130,7 @@ class UnetSkipConnectionBlock(nn.Layer):
up = [uprelu, upconv, nn.Tanh()] up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up model = down + [submodule] + up
elif innermost: elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, upconv = nn.Conv2DTranspose(inner_nc,
outer_nc, outer_nc,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
...@@ -140,7 +140,7 @@ class UnetSkipConnectionBlock(nn.Layer): ...@@ -140,7 +140,7 @@ class UnetSkipConnectionBlock(nn.Layer):
up = [uprelu, upconv, upnorm] up = [uprelu, upconv, upnorm]
model = down + up model = down + up
else: else:
upconv = nn.ConvTranspose2d(inner_nc * 2, upconv = nn.Conv2DTranspose(inner_nc * 2,
outer_nc, outer_nc,
kernel_size=4, kernel_size=4,
stride=2, stride=2,
......
...@@ -11,10 +11,12 @@ ...@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.vision.models import vgg16
from .base_model import BaseModel from .base_model import BaseModel
from .builder import MODELS from .builder import MODELS
...@@ -26,92 +28,62 @@ from ..solver import build_optimizer ...@@ -26,92 +28,62 @@ from ..solver import build_optimizer
from ..utils.image_pool import ImagePool from ..utils.image_pool import ImagePool
from ..utils.preprocess import * from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset from ..datasets.makeup_dataset import MakeupDataset
import numpy as np
from .vgg import vgg16
@MODELS.register() @MODELS.register()
class MakeupModel(BaseModel): class MakeupModel(BaseModel):
""" """
This class implements the CycleGAN model, for learning image-to-image translation without paired data. PSGAN paper: https://arxiv.org/pdf/1909.06956.pdf
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
""" """
def __init__(self, opt): def __init__(self, cfg):
"""Initialize the CycleGAN class. """Initialize the PSGAN class.
Parameters: Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions cfg (dict)-- config of model.
""" """
BaseModel.__init__(self, opt) super(MakeupModel, self).__init__(cfg)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_A', 'rec_A']
visual_names_B = ['real_B', 'fake_B', 'rec_B']
if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
self.vgg = vgg16(pretrained=True)
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['G', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G']
# define networks (both Generators and discriminators) # define networks (both Generators and discriminators)
# The naming is different from those used in the paper. # The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG = build_generator(opt.model.generator) self.nets['netG'] = build_generator(cfg.model.generator)
init_weights(self.netG, init_type='xavier', init_gain=1.0) init_weights(self.nets['netG'], init_type='xavier', init_gain=1.0)
if self.isTrain: # define discriminators if self.is_train: # define discriminators
self.netD_A = build_discriminator(opt.model.discriminator) vgg = vgg16(pretrained=True)
self.netD_B = build_discriminator(opt.model.discriminator) self.vgg = vgg.features
init_weights(self.netD_A, init_type='xavier', init_gain=1.0) self.nets['netD_A'] = build_discriminator(cfg.model.discriminator)
init_weights(self.netD_B, init_type='xavier', init_gain=1.0) self.nets['netD_B'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD_A'], init_type='xavier', init_gain=1.0)
init_weights(self.nets['netD_B'], init_type='xavier', init_gain=1.0)
if self.isTrain:
self.fake_A_pool = ImagePool( self.fake_A_pool = ImagePool(
opt.dataset.train.pool_size cfg.dataset.train.pool_size
) # create image buffer to store previously generated images ) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool( self.fake_B_pool = ImagePool(
opt.dataset.train.pool_size cfg.dataset.train.pool_size
) # create image buffer to store previously generated images ) # create image buffer to store previously generated images
# define loss functions # define loss functions
self.criterionGAN = GANLoss( self.criterionGAN = GANLoss(
opt.model.gan_mode) #.to(self.device) # define GAN loss. cfg.model.gan_mode) #.to(self.device) # define GAN loss.
self.criterionCycle = paddle.nn.L1Loss() self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss() self.criterionIdt = paddle.nn.L1Loss()
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
self.criterionL2 = paddle.nn.MSELoss() self.criterionL2 = paddle.nn.MSELoss()
self.build_lr_scheduler() self.build_lr_scheduler()
self.optimizer_G = build_optimizer( self.optimizers['optimizer_G'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netG.parameters()) parameter_list=self.nets['netG'].parameters())
# self.optimizer_D = paddle.optimizer.Adam(learning_rate=lr_scheduler_d, parameter_list=self.netD_A.parameters() + self.netD_B.parameters(), beta1=opt.beta1) self.optimizers['optimizer_DA'] = build_optimizer(
self.optimizer_DA = build_optimizer( cfg.optimizer,
opt.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netD_A.parameters()) parameter_list=self.nets['netD_A'].parameters())
self.optimizer_DB = build_optimizer( self.optimizers['optimizer_DB'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netD_B.parameters()) parameter_list=self.nets['netD_B'].parameters())
self.optimizers.append(self.optimizer_G)
# self.optimizers.append(self.optimizer_D)
self.optimizers.append(self.optimizer_DA)
self.optimizers.append(self.optimizer_DB)
self.optimizer_names.extend(
['optimizer_G', 'optimizer_DA', 'optimizer_DB'])
def set_input(self, input): def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
...@@ -129,37 +101,47 @@ class MakeupModel(BaseModel): ...@@ -129,37 +101,47 @@ class MakeupModel(BaseModel):
self.mask_A_aug = paddle.to_tensor(input['mask_A_aug']) self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
self.mask_B_aug = paddle.to_tensor(input['mask_B_aug']) self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1]) self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
if self.isTrain: if self.is_train:
self.mask_A = paddle.to_tensor(input['mask_A']) self.mask_A = paddle.to_tensor(input['mask_A'])
self.mask_B = paddle.to_tensor(input['mask_B']) self.mask_B = paddle.to_tensor(input['mask_B'])
self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A']) self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B']) self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])
#self.hm_gt_A = self.hm_gt_A_lip + self.hm_gt_A_skin + self.hm_gt_A_eye
#self.hm_gt_B = self.hm_gt_B_lip + self.hm_gt_B_skin + self.hm_gt_B_eye
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_A, amm = self.netG(self.real_A, self.real_B, self.P_A, self.fake_A, amm = self.nets['netG'](self.real_A, self.real_B, self.P_A,
self.P_B, self.c_m, self.mask_A_aug, self.P_B, self.c_m,
self.mask_B_aug) # G_A(A) self.mask_A_aug,
self.fake_B, _ = self.netG(self.real_B, self.real_A, self.P_B, self.P_A, self.mask_B_aug) # G_A(A)
self.c_m_t, self.mask_A_aug, self.fake_B, _ = self.nets['netG'](self.real_B, self.real_A, self.P_B,
self.mask_B_aug) # G_A(A) self.P_A, self.c_m_t,
self.rec_A, _ = self.netG(self.fake_A, self.real_A, self.P_A, self.P_A, self.mask_A_aug,
self.c_m_idt_a, self.mask_A_aug, self.mask_B_aug) # G_A(A)
self.mask_B_aug) # G_A(A) self.rec_A, _ = self.nets['netG'](self.fake_A, self.real_A, self.P_A,
self.rec_B, _ = self.netG(self.fake_B, self.real_B, self.P_B, self.P_B, self.P_A, self.c_m_idt_a,
self.c_m_idt_b, self.mask_A_aug, self.mask_A_aug,
self.mask_B_aug) # G_A(A) self.mask_B_aug) # G_A(A)
self.rec_B, _ = self.nets['netG'](self.fake_B, self.real_B, self.P_B,
self.P_B, self.c_m_idt_b,
self.mask_A_aug,
self.mask_B_aug) # G_A(A)
# visual
self.visual_items['real_A'] = self.real_A
self.visual_items['fake_B'] = self.fake_B
self.visual_items['rec_A'] = self.rec_A
self.visual_items['real_B'] = self.real_B
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def forward_test(self, input): def forward_test(self, input):
''' '''
not implement now not implement now
''' '''
return self.netG(input['image_A'], input['image_B'], input['P_A'], return self.nets['netG'](input['image_A'], input['image_B'],
input['P_B'], input['consis_mask'], input['P_A'], input['P_B'],
input['mask_A_aug'], input['mask_B_aug']) input['consis_mask'], input['mask_A_aug'],
input['mask_B_aug'])
def test(self, input): def test(self, input):
"""Forward function used in test time. """Forward function used in test time.
...@@ -195,51 +177,52 @@ class MakeupModel(BaseModel): ...@@ -195,51 +177,52 @@ class MakeupModel(BaseModel):
def backward_D_A(self): def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A""" """Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B) fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B) self.loss_D_A = self.backward_D_basic(self.nets['netD_A'], self.real_B,
fake_B)
self.losses['D_A_loss'] = self.loss_D_A self.losses['D_A_loss'] = self.loss_D_A
def backward_D_B(self): def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B""" """Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A) fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) self.loss_D_B = self.backward_D_basic(self.nets['netD_B'], self.real_A,
fake_A)
self.losses['D_B_loss'] = self.loss_D_B self.losses['D_B_loss'] = self.loss_D_B
def backward_G(self): def backward_G(self):
"""Calculate the loss for generators G_A and G_B""" """Calculate the loss for generators G_A and G_B"""
'''
self.loss_names = [ lambda_idt = self.cfg.lambda_identity
'G_A_vgg', lambda_A = self.cfg.lambda_A
'G_B_vgg', lambda_B = self.cfg.lambda_B
'G_bg_consis'
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A', 'amm_a']
visual_names_B = ['real_B', 'fake_A', 'rec_B', 'amm_b']
'''
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
lambda_vgg = 5e-3 lambda_vgg = 5e-3
# Identity loss # Identity loss
if lambda_idt > 0: if lambda_idt > 0:
self.idt_A, _ = self.netG(self.real_A, self.real_A, self.P_A, self.idt_A, _ = self.nets['netG'](self.real_A, self.real_A,
self.P_A, self.c_m_idt_a, self.mask_A_aug, self.P_A, self.P_A,
self.mask_B_aug) # G_A(A) self.c_m_idt_a, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_A = self.criterionIdt( self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_A) * lambda_A * lambda_idt self.idt_A, self.real_A) * lambda_A * lambda_idt
self.idt_B, _ = self.netG(self.real_B, self.real_B, self.P_B, self.idt_B, _ = self.nets['netG'](self.real_B, self.real_B,
self.P_B, self.c_m_idt_b, self.mask_A_aug, self.P_B, self.P_B,
self.mask_B_aug) # G_A(A) self.c_m_idt_b, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_B = self.criterionIdt( self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_B) * lambda_B * lambda_idt self.idt_B, self.real_B) * lambda_B * lambda_idt
# visual
self.visual_items['idt_A'] = self.idt_A
self.visual_items['idt_B'] = self.idt_B
else: else:
self.loss_idt_A = 0 self.loss_idt_A = 0
self.loss_idt_B = 0 self.loss_idt_B = 0
# GAN loss D_A(G_A(A)) # GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_A), True) self.loss_G_A = self.criterionGAN(self.nets['netD_A'](self.fake_A),
True)
# GAN loss D_B(G_B(B)) # GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_B), True) self.loss_G_B = self.criterionGAN(self.nets['netD_B'](self.fake_B),
True)
# Forward cycle loss || G_B(G_A(A)) - A|| # Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A self.real_A) * lambda_A
...@@ -381,27 +364,24 @@ class MakeupModel(BaseModel): ...@@ -381,27 +364,24 @@ class MakeupModel(BaseModel):
self.forward() # compute fake images and reconstruction images. self.forward() # compute fake images and reconstruction images.
# G_A and G_B # G_A and G_B
self.set_requires_grad( self.set_requires_grad(
[self.netD_A, self.netD_B], [self.nets['netD_A'], self.nets['netD_B']],
False) # Ds require no gradients when optimizing Gs False) # Ds require no gradients when optimizing Gs
# self.optimizer_G.clear_gradients() #zero_grad() # set G_A and G_B's gradients to zero # self.optimizer_G.clear_gradients() #zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.minimize( self.optimizers['optimizer_G'].minimize(
self.loss_G) #step() # update G_A and G_B's weights self.loss_G) #step() # update G_A and G_B's weights
self.optimizer_G.clear_gradients() self.optimizers['optimizer_G'].clear_gradients()
# self.optimizer_G.clear_gradients()
# D_A and D_B # D_A and D_B
# self.set_requires_grad([self.netD_A, self.netD_B], True) self.set_requires_grad(self.nets['netD_A'], True)
self.set_requires_grad(self.netD_A, True)
# self.optimizer_D.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero # self.optimizer_D.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A self.backward_D_A() # calculate gradients for D_A
self.optimizer_DA.minimize( self.optimizers['optimizer_DA'].minimize(
self.loss_D_A) #step() # update D_A and D_B's weights self.loss_D_A) #step() # update D_A and D_B's weights
self.optimizer_DA.clear_gradients() #zero_g self.optimizers['optimizer_DA'].clear_gradients() #zero_g
self.set_requires_grad(self.netD_B, True) self.set_requires_grad(self.nets['netD_B'], True)
# self.optimizer_DB.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_B() # calculate graidents for D_B self.backward_D_B() # calculate graidents for D_B
self.optimizer_DB.minimize( self.optimizers['optimizer_DB'].minimize(
self.loss_D_B) #step() # update D_A and D_B's weights self.loss_D_B) #step() # update D_A and D_B's weights
self.optimizer_DB.clear_gradients( self.optimizers['optimizer_DB'].clear_gradients(
) #zero_grad() # set D_A and D_B's gradients to zero ) #zero_grad() # set D_A and D_B's gradients to zero
...@@ -23,52 +23,38 @@ class Pix2PixModel(BaseModel): ...@@ -23,52 +23,38 @@ class Pix2PixModel(BaseModel):
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
""" """
def __init__(self, opt): def __init__(self, cfg):
"""Initialize the pix2pix class. """Initialize the pix2pix class.
Parameters: Parameters:
opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict opt (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
""" """
BaseModel.__init__(self, opt) super(Pix2PixModel, self).__init__(cfg)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real_A', 'fake_B', 'real_B']
# specify the models you want to save to the disk.
if self.isTrain:
self.model_names = ['G', 'D']
else:
# during test time, only load G
self.model_names = ['G']
# define networks (both generator and discriminator) # define networks (both generator and discriminator)
self.netG = build_generator(opt.model.generator) self.nets['netG'] = build_generator(cfg.model.generator)
init_weights(self.netG) init_weights(self.nets['netG'])
# define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
if self.isTrain: if self.is_train:
self.netD = build_discriminator(opt.model.discriminator) self.nets['netD'] = build_discriminator(cfg.model.discriminator)
init_weights(self.netD) init_weights(self.nets['netD'])
if self.isTrain: if self.is_train:
self.losses = {} self.losses = {}
# define loss functions # define loss functions
self.criterionGAN = GANLoss(opt.model.gan_mode) self.criterionGAN = GANLoss(cfg.model.gan_mode)
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
# build optimizers # build optimizers
self.build_lr_scheduler() self.build_lr_scheduler()
self.optimizer_G = build_optimizer( self.optimizers['optimizer_G'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netG.parameters()) parameter_list=self.nets['netG'].parameters())
self.optimizer_D = build_optimizer( self.optimizers['optimizer_D'] = build_optimizer(
opt.optimizer, cfg.optimizer,
self.lr_scheduler, self.lr_scheduler,
parameter_list=self.netD.parameters()) parameter_list=self.nets['netD'].parameters())
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.optimizer_names.extend(['optimizer_G', 'optimizer_D'])
def set_input(self, input): def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps. """Unpack input data from the dataloader and perform necessary pre-processing steps.
...@@ -79,39 +65,40 @@ class Pix2PixModel(BaseModel): ...@@ -79,39 +65,40 @@ class Pix2PixModel(BaseModel):
The option 'direction' can be used to swap images in domain A and domain B. The option 'direction' can be used to swap images in domain A and domain B.
""" """
AtoB = self.opt.dataset.train.direction == 'AtoB' AtoB = self.cfg.dataset.train.direction == 'AtoB'
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A']) # TODO: replace to_varialbe with to_tensor
self.real_A = paddle.fluid.dygraph.to_variable(
input['A' if AtoB else 'B'])
self.real_B = paddle.fluid.dygraph.to_variable(
input['B' if AtoB else 'A'])
self.image_paths = input['A_paths' if AtoB else 'B_paths'] self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self): def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>.""" """Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG(self.real_A) # G(A) self.fake_B = self.nets['netG'](self.real_A) # G(A)
def forward_test(self, input): # put items to visual dict
input = paddle.to_tensor(input) self.visual_items['fake_B'] = self.fake_B
return self.netG(input) self.visual_items['real_A'] = self.real_A
self.visual_items['real_B'] = self.real_B
def backward_D(self): def backward_D(self):
"""Calculate GAN loss for the discriminator""" """Calculate GAN loss for the discriminator"""
# Fake; stop backprop to the generator by detaching fake_B # Fake; stop backprop to the generator by detaching fake_B
# use conditional GANs; we need to feed both input and output to the discriminator # use conditional GANs; we need to feed both input and output to the discriminator
fake_AB = paddle.concat((self.real_A, self.fake_B), 1) fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB.detach()) pred_fake = self.nets['netD'](fake_AB.detach())
self.loss_D_fake = self.criterionGAN(pred_fake, False) self.loss_D_fake = self.criterionGAN(pred_fake, False)
# Real # Real
real_AB = paddle.concat((self.real_A, self.real_B), 1) real_AB = paddle.concat((self.real_A, self.real_B), 1)
pred_real = self.netD(real_AB) pred_real = self.nets['netD'](real_AB)
self.loss_D_real = self.criterionGAN(pred_real, True) self.loss_D_real = self.criterionGAN(pred_real, True)
# combine loss and calculate gradients # combine loss and calculate gradients
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
if ParallelEnv().nranks > 1:
self.loss_D = self.netD.scale_loss(self.loss_D) self.loss_D.backward()
self.loss_D.backward()
self.netD.apply_collective_grads()
else:
self.loss_D.backward()
self.losses['D_fake_loss'] = self.loss_D_fake self.losses['D_fake_loss'] = self.loss_D_fake
self.losses['D_real_loss'] = self.loss_D_real self.losses['D_real_loss'] = self.loss_D_real
...@@ -120,21 +107,16 @@ class Pix2PixModel(BaseModel): ...@@ -120,21 +107,16 @@ class Pix2PixModel(BaseModel):
"""Calculate GAN and L1 loss for the generator""" """Calculate GAN and L1 loss for the generator"""
# First, G(A) should fake the discriminator # First, G(A) should fake the discriminator
fake_AB = paddle.concat((self.real_A, self.fake_B), 1) fake_AB = paddle.concat((self.real_A, self.fake_B), 1)
pred_fake = self.netD(fake_AB) pred_fake = self.nets['netD'](fake_AB)
self.loss_G_GAN = self.criterionGAN(pred_fake, True) self.loss_G_GAN = self.criterionGAN(pred_fake, True)
# Second, G(A) = B # Second, G(A) = B
self.loss_G_L1 = self.criterionL1(self.fake_B, self.loss_G_L1 = self.criterionL1(self.fake_B,
self.real_B) * self.opt.lambda_L1 self.real_B) * self.cfg.lambda_L1
# combine loss and calculate gradients # combine loss and calculate gradients
self.loss_G = self.loss_G_GAN + self.loss_G_L1 self.loss_G = self.loss_G_GAN + self.loss_G_L1
if ParallelEnv().nranks > 1: self.loss_G.backward()
self.loss_G = self.netG.scale_loss(self.loss_G)
self.loss_G.backward()
self.netG.apply_collective_grads()
else:
self.loss_G.backward()
self.losses['G_adv_loss'] = self.loss_G_GAN self.losses['G_adv_loss'] = self.loss_G_GAN
self.losses['G_L1_loss'] = self.loss_G_L1 self.losses['G_L1_loss'] = self.loss_G_L1
...@@ -144,13 +126,13 @@ class Pix2PixModel(BaseModel): ...@@ -144,13 +126,13 @@ class Pix2PixModel(BaseModel):
self.forward() self.forward()
# update D # update D
self.set_requires_grad(self.netD, True) self.set_requires_grad(self.nets['netD'], True)
self.optimizer_D.clear_gradients() self.optimizers['optimizer_D'].clear_grad()
self.backward_D() self.backward_D()
self.optimizer_D.minimize(self.loss_D) self.optimizers['optimizer_D'].step()
# update G # update G
self.set_requires_grad(self.netD, False) self.set_requires_grad(self.nets['netD'], False)
self.optimizer_G.clear_gradients() self.optimizers['optimizer_G'].clear_grad()
self.backward_G() self.backward_G()
self.optimizer_G.minimize(self.loss_G) self.optimizers['optimizer_G'].step()
...@@ -30,7 +30,7 @@ class SRModel(BaseModel): ...@@ -30,7 +30,7 @@ class SRModel(BaseModel):
self.loss_names = ['l_total'] self.loss_names = ['l_total']
self.optimizers = [] self.optimizers = []
if self.isTrain: if self.is_train:
self.criterionL1 = paddle.nn.L1Loss() self.criterionL1 = paddle.nn.L1Loss()
self.build_lr_scheduler() self.build_lr_scheduler()
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle.utils.download import get_weights_path_from_url
from paddle.vision.models.vgg import make_layers
cfg = [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
512, 512, 'M'
]
model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
'89bbffc0f87d260be9b8cdc169c991c4')
}
class VGG(nn.Layer):
def __init__(self, features):
super(VGG, self).__init__()
self.features = features
def forward(self, x):
x = self.features(x)
return x
def vgg16(pretrained=False):
features = make_layers(cfg)
model = VGG(features)
if pretrained:
weight_path = get_weights_path_from_url(model_urls['vgg16'][0],
model_urls['vgg16'][1])
param = paddle.load(weight_path)
model.load_dict(param)
return model
...@@ -25,13 +25,13 @@ class DenseMotionNetwork(nn.Layer): ...@@ -25,13 +25,13 @@ class DenseMotionNetwork(nn.Layer):
max_features=max_features, max_features=max_features,
num_blocks=num_blocks) num_blocks=num_blocks)
self.mask = nn.Conv2d(self.hourglass.out_filters, self.mask = nn.Conv2D(self.hourglass.out_filters,
num_kp + 1, num_kp + 1,
kernel_size=(7, 7), kernel_size=(7, 7),
padding=(3, 3)) padding=(3, 3))
if estimate_occlusion_map: if estimate_occlusion_map:
self.occlusion = nn.Conv2d(self.hourglass.out_filters, self.occlusion = nn.Conv2D(self.hourglass.out_filters,
1, 1,
kernel_size=(7, 7), kernel_size=(7, 7),
padding=(3, 3)) padding=(3, 3))
......
...@@ -52,16 +52,16 @@ class ResBlock2d(nn.Layer): ...@@ -52,16 +52,16 @@ class ResBlock2d(nn.Layer):
""" """
def __init__(self, in_features, kernel_size, padding): def __init__(self, in_features, kernel_size, padding):
super(ResBlock2d, self).__init__() super(ResBlock2d, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_features, self.conv1 = nn.Conv2D(in_channels=in_features,
out_channels=in_features, out_channels=in_features,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding) padding=padding)
self.conv2 = nn.Conv2d(in_channels=in_features, self.conv2 = nn.Conv2D(in_channels=in_features,
out_channels=in_features, out_channels=in_features,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding) padding=padding)
self.norm1 = nn.BatchNorm2d(in_features) self.norm1 = nn.BatchNorm2D(in_features)
self.norm2 = nn.BatchNorm2d(in_features) self.norm2 = nn.BatchNorm2D(in_features)
def forward(self, x): def forward(self, x):
out = self.norm1(x) out = self.norm1(x)
...@@ -86,12 +86,12 @@ class UpBlock2d(nn.Layer): ...@@ -86,12 +86,12 @@ class UpBlock2d(nn.Layer):
groups=1): groups=1):
super(UpBlock2d, self).__init__() super(UpBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features, out_channels=out_features,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding, padding=padding,
groups=groups) groups=groups)
self.norm = nn.BatchNorm2d(out_features) self.norm = nn.BatchNorm2D(out_features)
def forward(self, x): def forward(self, x):
out = F.interpolate(x, scale_factor=2) out = F.interpolate(x, scale_factor=2)
...@@ -112,13 +112,13 @@ class DownBlock2d(nn.Layer): ...@@ -112,13 +112,13 @@ class DownBlock2d(nn.Layer):
padding=1, padding=1,
groups=1): groups=1):
super(DownBlock2d, self).__init__() super(DownBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features, out_channels=out_features,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding, padding=padding,
groups=groups) groups=groups)
self.norm = nn.BatchNorm2d(out_features) self.norm = nn.BatchNorm2D(out_features)
self.pool = nn.AvgPool2d(kernel_size=(2, 2)) self.pool = nn.AvgPool2D(kernel_size=(2, 2))
def forward(self, x): def forward(self, x):
out = self.conv(x) out = self.conv(x)
...@@ -139,12 +139,12 @@ class SameBlock2d(nn.Layer): ...@@ -139,12 +139,12 @@ class SameBlock2d(nn.Layer):
kernel_size=3, kernel_size=3,
padding=1): padding=1):
super(SameBlock2d, self).__init__() super(SameBlock2d, self).__init__()
self.conv = nn.Conv2d(in_channels=in_features, self.conv = nn.Conv2D(in_channels=in_features,
out_channels=out_features, out_channels=out_features,
kernel_size=kernel_size, kernel_size=kernel_size,
padding=padding, padding=padding,
groups=groups) groups=groups)
self.norm = nn.BatchNorm2d(out_features) self.norm = nn.BatchNorm2D(out_features)
def forward(self, x): def forward(self, x):
out = self.conv(x) out = self.conv(x)
......
...@@ -26,14 +26,14 @@ class KPDetector(nn.Layer): ...@@ -26,14 +26,14 @@ class KPDetector(nn.Layer):
max_features=max_features, max_features=max_features,
num_blocks=num_blocks) num_blocks=num_blocks)
self.kp = nn.Conv2d(in_channels=self.predictor.out_filters, self.kp = nn.Conv2D(in_channels=self.predictor.out_filters,
out_channels=num_kp, out_channels=num_kp,
kernel_size=(7, 7), kernel_size=(7, 7),
padding=pad) padding=pad)
if estimate_jacobian: if estimate_jacobian:
self.num_jacobian_maps = 1 if single_jacobian_map else num_kp self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
self.jacobian = nn.Conv2d(in_channels=self.predictor.out_filters, self.jacobian = nn.Conv2D(in_channels=self.predictor.out_filters,
out_channels=4 * self.num_jacobian_maps, out_channels=4 * self.num_jacobian_maps,
kernel_size=(7, 7), kernel_size=(7, 7),
padding=pad) padding=pad)
......
...@@ -21,22 +21,21 @@ def build_norm_layer(norm_type='instance'): ...@@ -21,22 +21,21 @@ def build_norm_layer(norm_type='instance'):
if norm_type == 'batch': if norm_type == 'batch':
norm_layer = functools.partial( norm_layer = functools.partial(
nn.BatchNorm, nn.BatchNorm,
weight_attr=paddle.ParamAttr( param_attr=paddle.ParamAttr(
initializer=nn.initializer.Normal(1.0, 0.02)), initializer=nn.initializer.Normal(1.0, 0.02)),
bias_attr=paddle.ParamAttr( bias_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(0.0)), initializer=nn.initializer.Constant(0.0)),
trainable_statistics=True) trainable_statistics=True)
elif norm_type == 'instance': elif norm_type == 'instance':
norm_layer = functools.partial( norm_layer = functools.partial(
nn.InstanceNorm2d, nn.InstanceNorm2D,
weight_attr=paddle.ParamAttr( weight_attr=paddle.ParamAttr(
initializer=nn.initializer.Constant(1.0), initializer=nn.initializer.Constant(1.0),
learning_rate=0.0, learning_rate=0.0,
trainable=False), trainable=False),
bias_attr=paddle.ParamAttr( bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
initializer=nn.initializer.Constant(0.0), learning_rate=0.0,
learning_rate=0.0, trainable=False))
trainable=False))
elif norm_type == 'spectral': elif norm_type == 'spectral':
norm_layer = functools.partial(Spectralnorm) norm_layer = functools.partial(Spectralnorm)
elif norm_type == 'none': elif norm_type == 'none':
...@@ -44,6 +43,6 @@ def build_norm_layer(norm_type='instance'): ...@@ -44,6 +43,6 @@ def build_norm_layer(norm_type='instance'):
def norm_layer(x): def norm_layer(x):
return Identity() return Identity()
else: else:
raise NotImplementedError( raise NotImplementedError('normalization layer [%s] is not found' %
'normalization layer [%s] is not found' % norm_type) norm_type)
return norm_layer return norm_layer
...@@ -15,7 +15,8 @@ def save(state_dicts, file_name): ...@@ -15,7 +15,8 @@ def save(state_dicts, file_name):
for k, v in state_dict.items(): for k, v in state_dict.items():
if isinstance( if isinstance(
v, (paddle.framework.Variable, paddle.fluid.core.VarBase)): v,
(paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
model_dict[k] = v.numpy() model_dict[k] = v.numpy()
else: else:
model_dict[k] = v model_dict[k] = v
...@@ -24,8 +25,9 @@ def save(state_dicts, file_name): ...@@ -24,8 +25,9 @@ def save(state_dicts, file_name):
final_dict = {} final_dict = {}
for k, v in state_dicts.items(): for k, v in state_dicts.items():
if isinstance(v, if isinstance(
(paddle.framework.Variable, paddle.fluid.core.VarBase)): v,
(paddle.fluid.framework.Variable, paddle.fluid.core.VarBase)):
final_dict = convert(state_dicts) final_dict = convert(state_dicts)
break break
elif isinstance(v, dict): elif isinstance(v, dict):
......
...@@ -122,11 +122,7 @@ def cal_hist(image): ...@@ -122,11 +122,7 @@ def cal_hist(image):
hists = [] hists = []
for i in range(0, 3): for i in range(0, 3):
channel = image[i] channel = image[i]
# channel = image[i, :, :]
#channel = torch.from_numpy(channel)
hist, _ = np.histogram(channel, bins=256, range=(0, 255)) hist, _ = np.histogram(channel, bins=256, range=(0, 255))
#hist = torch.histc(channel, bins=256, min=0, max=256)
# refHist=hist.view(256,1)
sum = hist.sum() sum = hist.sum()
pdf = [v / sum for v in hist] pdf = [v / sum for v in hist]
for i in range(1, 256): for i in range(1, 256):
......
...@@ -2,14 +2,14 @@ import os ...@@ -2,14 +2,14 @@ import os
import time import time
import paddle import paddle
from paddle.distributed import ParallelEnv
from .logger import setup_logger from .logger import setup_logger
def setup(args, cfg): def setup(args, cfg):
if args.evaluate_only: if args.evaluate_only:
cfg.isTrain = False cfg.is_train = False
else:
cfg.is_train = True
cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime()) cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
cfg.output_dir = os.path.join(cfg.output_dir, cfg.output_dir = os.path.join(cfg.output_dir,
...@@ -19,6 +19,7 @@ def setup(args, cfg): ...@@ -19,6 +19,7 @@ def setup(args, cfg):
logger.info('Configs: {}'.format(cfg)) logger.info('Configs: {}'.format(cfg))
place = paddle.CUDAPlace(ParallelEnv().dev_id) \ if paddle.is_compiled_with_cuda():
if ParallelEnv().nranks > 1 else paddle.CUDAPlace(0) paddle.set_device('gpu')
paddle.disable_static(place) else:
paddle.set_device('cpu')
...@@ -2,3 +2,5 @@ tqdm ...@@ -2,3 +2,5 @@ tqdm
PyYAML>=5.1 PyYAML>=5.1
scikit-image>=0.14.0 scikit-image>=0.14.0
scipy>=1.1.0 scipy>=1.1.0
opencv-python
imageio-ffmpeg
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册