Commit 8797a98e, authored by: lijianshe02

add psgan code

Parent b22c19ed
epochs: 200
isTrain: False
output_dir: tmp
checkpoints_dir: checkpoints
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
model:
name: MakeupModel
generator:
name: GeneratorPSGANAttention
conv_dim: 64
repeat_num: 6
discriminator:
name: NLayerDiscriminator
ndf: 64
n_layers: 3
input_nc: 3
norm_type: batch
gan_mode: lsgan
dataset:
train:
name: MakeupDataset
trans_size: 256
dataroot: MT-Dataset
cls_list: [non-makeup, makeup]
phase: train
pool_size: 16
test:
name: MakeupDataset
trans_size: 256
dataroot: MT-Dataset
cls_list: [non-makeup, makeup]
phase: test
pool_size: 16
optimizer:
name: Adam
beta1: 0.5
lr_scheduler:
name: linear
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
log_config:
interval: 10
visiual_interval: 500
snapshot_config:
interval: 1
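# With `name: linear` above, the learning rate is presumably held at 0.0002
# for the first `start_epoch` (100) epochs and then decayed linearly to 0 over
# the next `decay_epochs` (100) epochs, e.g. 0.0002 * (1 - 50/100) = 0.0001
# fifty epochs into the decay, reaching 0 at epoch 200.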
from .unpaired_dataset import UnpairedDataset
from .single_dataset import SingleDataset
from .paired_dataset import PairedDataset
from .sr_image_dataset import SRImageDataset
from .makeup_dataset import MakeupDataset
import cv2
import os.path
from .base_dataset import BaseDataset, get_transform
from .transforms.makeup_transforms import get_makeup_transform
import paddle.vision.transforms as T
from .image_folder import make_dataset
from PIL import Image
import random
import numpy as np
from ..utils.preprocess import *
from .builder import DATASETS
@DATASETS.register()
class MakeupDataset(BaseDataset):
"""
This dataset class can load unaligned/unpaired datasets.
It requires two directories to host training images from domain A '/path/to/data/trainA'
and from domain B '/path/to/data/trainB' respectively.
You can train the model with the dataset flag '--dataroot /path/to/data'.
Similarly, you need to prepare two directories:
'/path/to/data/testA' and '/path/to/data/testB' during test time.
"""
def __init__(self, cfg):
"""Initialize this dataset class.
Parameters:
            cfg (Config) -- stores all the experiment flags
"""
BaseDataset.__init__(self, cfg)
self.image_path = cfg.dataroot
self.mode = cfg.phase
self.transform = get_makeup_transform(cfg)
self.norm = T.Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5])
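        # Normalize with mean = std = 127.5 maps uint8 pixels from [0, 255]
        # to [-1, 1], matching the generator's tanh output range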
self.transform_mask = get_makeup_transform(cfg, pic="mask")
self.trans_size = cfg.trans_size
self.cls_list = cfg.cls_list
self.cls_A = self.cls_list[0]
self.cls_B = self.cls_list[1]
for cls in self.cls_list:
setattr(
self, cls + "_list_path",
os.path.join(self.image_path, self.mode + '_' + cls + ".txt"))
setattr(self, cls + "_lines",
open(getattr(self, cls + "_list_path"), 'r').readlines())
setattr(self, "num_of_" + cls + "_data",
len(getattr(self, cls + "_lines")))
        print('Start preprocessing dataset...')
        self.preprocess()
        print('Finished preprocessing dataset...')
def preprocess(self):
"""preprocess image"""
for cls in self.cls_list:
setattr(self, cls + "_filenames", [])
setattr(self, cls + "_mask_filenames", [])
setattr(self, cls + "_lmks_filenames", [])
lines = getattr(self, cls + "_lines")
random.shuffle(lines)
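            # each line is assumed to hold three whitespace-separated relative
            # paths: image, segmentation mask, landmarks file, e.g.
            #   images/xxx.png segs/xxx.png lmks/xxx.txt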
for i, line in enumerate(lines):
splits = line.split()
getattr(self, cls + "_filenames").append(splits[0])
getattr(self, cls + "_mask_filenames").append(splits[1])
getattr(self, cls + "_lmks_filenames").append(splits[2])
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
        Returns a dictionary containing:
            image_A / image_B (tensor)   -- normalized source / reference images
            mask_A / mask_B              -- face-parsing masks
            P_A / P_B                    -- landmark-derived pixel relation maps
            consis_mask*, mask_*_aug     -- consistency and per-region masks
"""
try:
            # random.randint is inclusive at both ends, so subtract 1
            index_A = random.randint(
                0, getattr(self, "num_of_" + self.cls_A + "_data") - 1)
            index_B = random.randint(
                0, getattr(self, "num_of_" + self.cls_B + "_data") - 1)
if self.mode == 'test':
num_b = getattr(self, 'num_of_' + self.cls_list[1] + '_data')
index_A = int(index / num_b)
index_B = int(index % num_b)
image_A = Image.open(
os.path.join(self.image_path,
getattr(self, self.cls_A +
"_filenames")[index_A])).convert("RGB")
image_B = Image.open(
os.path.join(self.image_path,
getattr(self, self.cls_B +
"_filenames")[index_B])).convert("RGB")
mask_A = np.array(
Image.open(
os.path.join(
self.image_path,
getattr(self,
self.cls_A + "_mask_filenames")[index_A])))
mask_B = np.array(
Image.open(
os.path.join(
self.image_path,
getattr(self, self.cls_B +
"_mask_filenames")[index_B])).convert('L'))
#image_A.paste((200,200,200), (0,0), Image.fromarray(np.uint8(255*(np.array(mask_A)==0))))
#image_B.paste((200,200,200), (0,0), Image.fromarray(np.uint8(255*(np.array(mask_B)==0))))
image_A = np.array(image_A)
image_B = np.array(image_B)
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)
mask_A = cv2.resize(mask_A, (256, 256),
interpolation=cv2.INTER_NEAREST)
mask_B = cv2.resize(mask_B, (256, 256),
interpolation=cv2.INTER_NEAREST)
lmks_A = np.loadtxt(
os.path.join(
self.image_path,
getattr(self, self.cls_A + "_lmks_filenames")[index_A]))
lmks_B = np.loadtxt(
os.path.join(
self.image_path,
getattr(self, self.cls_B + "_lmks_filenames")[index_B]))
lmks_A = lmks_A / image_A.shape[:2] * self.trans_size
lmks_B = lmks_B / image_B.shape[:2] * self.trans_size
P_A = generate_P_from_lmks(lmks_A, self.trans_size,
image_A.shape[0], image_A.shape[1])
P_B = generate_P_from_lmks(lmks_B, self.trans_size,
image_B.shape[0], image_B.shape[1])
mask_A_aug = generate_mask_aug(mask_A, lmks_A)
mask_B_aug = generate_mask_aug(mask_B, lmks_B)
consis_mask = calculate_consis_mask(mask_A_aug, mask_B_aug)
            # identity consistency masks compare each domain with itself
            consis_mask_idt_A = calculate_consis_mask(mask_A_aug, mask_A_aug)
            consis_mask_idt_B = calculate_consis_mask(mask_B_aug, mask_B_aug)
except Exception as e:
print(e)
            return self.__getitem__((index + 1) % len(self))
return {
'image_A': self.norm(image_A),
'image_B': self.norm(image_B),
'mask_A': np.float32(mask_A),
'mask_B': np.float32(mask_B),
'consis_mask': np.float32(consis_mask),
'P_A': np.float32(P_A),
'P_B': np.float32(P_B),
'consis_mask_idt_A': np.float32(consis_mask_idt_A),
'consis_mask_idt_B': np.float32(consis_mask_idt_B),
'mask_A_aug': np.float32(mask_A_aug),
'mask_B_aug': np.float32(mask_B_aug)
}
def __len__(self):
"""Return the total number of images in the dataset.
As we have two datasets with potentially different number of images,
we take a maximum of
"""
if self.mode == 'train':
num_A = getattr(self, 'num_of_' + self.cls_list[0] + '_data')
num_B = getattr(self, 'num_of_' + self.cls_list[1] + '_data')
return max(num_A, num_B)
elif self.mode == "test":
num_A = getattr(self, 'num_of_' + self.cls_list[0] + '_data')
num_B = getattr(self, 'num_of_' + self.cls_list[1] + '_data')
return num_A * num_B
return max(self.A_size, self.B_size)
import paddle.vision.transforms as T
import cv2
def get_makeup_transform(cfg, pic="image"):
if pic == "image":
transform = T.Compose([
T.Resize(size=cfg.trans_size),
T.Permute(to_rgb=False),
])
else:
transform = T.Resize(size=cfg.trans_size,
interpolation=cv2.INTER_NEAREST)
return transform
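# Minimal usage sketch (assuming a cfg object with `trans_size`, as in the
# config above):
#   transform = get_makeup_transform(cfg)        # resize + HWC->CHW, BGR kept
#   chw = transform(np.array(img))               # CHW ndarray, short side 256
#   mask_tf = get_makeup_transform(cfg, "mask")  # nearest-neighbor resize only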
#!/usr/bin/python
# -*- encoding: utf-8 -*-
#from . import faceplusplus as fpp
from . import dlibutils as dlib
from . import mask
from . import image
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import detect, crop, landmarks, crop_from_array
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import numpy as np
from PIL import Image
import dlib
import cv2
from ..image import resize_by_max
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(
osp.split(osp.realpath(__file__))[0] + '/lms.dat')
def detect(image: Image):
image = np.asarray(image)
h, w = image.shape[:2]
image = resize_by_max(image, 361)
actual_h, actual_w = image.shape[:2]
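    # detection runs on a copy downscaled to at most 361 px on the long side;
    # the boxes are mapped back to the original resolution below, with +0.5
    # providing round-to-nearest when truncating to int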
faces_on_small = detector(image, 1)
faces = dlib.rectangles()
for face in faces_on_small:
faces.append(
dlib.rectangle(int(face.left() / actual_w * w + 0.5),
int(face.top() / actual_h * h + 0.5),
int(face.right() / actual_w * w + 0.5),
int(face.bottom() / actual_h * h + 0.5)))
return faces
def crop(image: Image, face, up_ratio, down_ratio, width_ratio):
width, height = image.size
face_height = face.height()
face_width = face.width()
delta_up = up_ratio * face_height
delta_down = down_ratio * face_height
delta_width = width_ratio * width
img_left = int(max(0, face.left() - delta_width))
img_top = int(max(0, face.top() - delta_up))
img_right = int(min(width, face.right() + delta_width))
img_bottom = int(min(height, face.bottom() + delta_down))
image = image.crop((img_left, img_top, img_right, img_bottom))
face = dlib.rectangle(face.left() - img_left,
face.top() - img_top,
face.right() - img_left,
face.bottom() - img_top)
face_expand = dlib.rectangle(img_left, img_top, img_right, img_bottom)
center = face_expand.center()
width, height = image.size
crop_left = img_left
crop_top = img_top
crop_right = img_right
crop_bottom = img_bottom
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image.crop((left, 0, right, height))
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
crop_left += left
crop_right = crop_left + height
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image.crop((0, top, width, bottom))
face = dlib.rectangle(face.left(),
face.top() - top, face.right(),
face.bottom() - top)
crop_top += top
crop_bottom = crop_top + width
crop_face = dlib.rectangle(crop_left, crop_top, crop_right, crop_bottom)
return image, face, crop_face
def crop_by_image_size(image: Image, face):
center = face.center()
width, height = image.size
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image.crop((left, 0, right, height))
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image.crop((0, top, width, bottom))
face = dlib.rectangle(face.left(),
face.top() - top, face.right(),
face.bottom() - top)
return image, face
def landmarks(image: Image, face):
shape = predictor(np.asarray(image), face).parts()
return np.array([[p.y, p.x] for p in shape])
def crop_from_array(image: np.array, face):
ratio = 0.20 / 0.85 # delta_size / face_size
height, width = image.shape[:2]
face_height = face.height()
face_width = face.width()
delta_height = ratio * face_height
    delta_width = ratio * face_width  # scale by face width, matching delta_height
img_left = int(max(0, face.left() - delta_width))
img_top = int(max(0, face.top() - delta_height))
img_right = int(min(width, face.right() + delta_width))
img_bottom = int(min(height, face.bottom() + delta_height))
image = image[img_top:img_bottom, img_left:img_right]
face = dlib.rectangle(face.left() - img_left,
face.top() - img_top,
face.right() - img_left,
face.bottom() - img_top)
center = face.center()
height, width = image.shape[:2]
if width > height:
left = int(center.x - height / 2)
right = int(center.x + height / 2)
if left < 0:
left, right = 0, height
elif right > width:
left, right = width - height, width
image = image[0:height, left:right]
face = dlib.rectangle(face.left() - left, face.top(),
face.right() - left, face.bottom())
elif width < height:
top = int(center.y - width / 2)
bottom = int(center.y + width / 2)
if top < 0:
top, bottom = 0, width
elif bottom > height:
top, bottom = height - width, height
image = image[top:bottom, 0:width]
face = dlib.rectangle(face.left(),
face.top() - top, face.right(),
face.bottom() - top)
return image, face
import numpy as np
import cv2
from io import BytesIO
def load_image(path):
with path.open("rb") as reader:
        data = np.frombuffer(reader.read(), dtype=np.uint8)
img = cv2.imdecode(data, cv2.IMREAD_COLOR)
if img is None:
return
img = img[..., ::-1]
return img
def resize_by_max(image, max_side=512, force=False):
h, w = image.shape[:2]
if max(h, w) < max_side and not force:
return image
ratio = max(h, w) / max_side
w = int(w / ratio + 0.5)
h = int(h / ratio + 0.5)
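    # e.g. h, w = 768, 1024 with max_side=512: ratio = 2.0, so the image is
    # resized to w=512, h=384, preserving the aspect ratio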
return cv2.resize(image, (w, h))
def image2buffer(image):
is_success, buffer = cv2.imencode(".jpg", image)
if not is_success:
return None
return BytesIO(buffer)
#!/usr/bin/python
# -*- encoding: utf-8 -*-
from .main import FaceParser
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import os.path as osp
import numpy as np
import cv2
from PIL import Image
import paddle
import paddle.vision.transforms as T
import pickle
from .model import BiSeNet
class FaceParser:
def __init__(self, device="cpu"):
self.mapper = {
0: 0,
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 0,
7: 11,
8: 12,
9: 0,
10: 6,
11: 8,
12: 7,
13: 9,
14: 13,
15: 0,
16: 0,
17: 10,
18: 0
}
#self.dict = paddle.to_tensor(mapper)
self.save_pth = osp.split(
osp.realpath(__file__))[0] + '/resnet.pdparams'
self.net = BiSeNet(n_classes=19)
self.transforms = T.Compose([
T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def parse(self, image):
assert image.shape[:2] == (512, 512)
image = image / 255.0
image = image.transpose((2, 0, 1))
image = self.transforms(image)
state_dict, _ = paddle.load(self.save_pth)
self.net.set_dict(state_dict)
self.net.eval()
with paddle.no_grad():
image = paddle.to_tensor(image)
image = image.unsqueeze(0)
out = self.net(image)[0]
parsing = out.squeeze(0).argmax(0) #argmax(0).astype('float32')
#parsing = paddle.nn.functional.embedding(x=self.dict, weight=parsing)
parse_np = parsing.numpy()
h, w = parse_np.shape
result = np.zeros((h, w))
for i in range(h):
for j in range(w):
result[i][j] = self.mapper[parse_np[i][j]]
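        # a vectorized sketch equivalent to the pixel loop above:
        #   result = np.vectorize(self.mapper.get)(parse_np)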
result = paddle.to_tensor(result).astype('float32')
return result
#!/usr/bin/python
# -*- encoding: utf-8 -*-
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
import numpy as np
from .resnet import resnet18
class ConvBNReLU(paddle.nn.Layer):
def __init__(self,
in_chan,
out_chan,
ks=3,
stride=1,
padding=1,
*args,
**kwargs):
super(ConvBNReLU, self).__init__()
self.conv = nn.Conv2d(in_chan,
out_chan,
kernel_size=ks,
stride=stride,
padding=padding,
bias_attr=False)
self.bn = nn.BatchNorm2d(out_chan)
self.relu = nn.ReLU()
#self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class BiSeNetOutput(paddle.nn.Layer):
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
super(BiSeNetOutput, self).__init__()
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
self.conv_out = nn.Conv2d(mid_chan,
n_classes,
kernel_size=1,
bias_attr=False)
#self.init_weight()
def forward(self, x):
x = self.conv(x)
x = self.conv_out(x)
return x
class AttentionRefinementModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(AttentionRefinementModule, self).__init__()
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
self.conv_atten = nn.Conv2d(out_chan,
out_chan,
kernel_size=1,
bias_attr=False)
self.bn_atten = nn.BatchNorm(out_chan)
self.sigmoid_atten = nn.Sigmoid()
#self.init_weight()
def forward(self, x):
feat = self.conv(x)
#atten = F.avg_pool2d(feat, feat.size()[2:])
atten = F.avg_pool2d(feat, feat.shape[2:])
atten = self.conv_atten(atten)
atten = self.bn_atten(atten)
atten = self.sigmoid_atten(atten)
out = feat * atten
return out
#def init_weight(self):
# for ly in self.children():
# if isinstance(ly, nn.Conv2d):
# nn.init.kaiming_normal_(ly.weight, a=1)
# if not ly.bias is None: nn.init.constant_(ly.bias, 0)
class ContextPath(paddle.nn.Layer):
def __init__(self, *args, **kwargs):
super(ContextPath, self).__init__()
self.resnet = resnet18()
self.arm16 = AttentionRefinementModule(256, 128)
self.arm32 = AttentionRefinementModule(512, 128)
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
#self.init_weight()
def forward(self, x):
H0, W0 = x.shape[2:]
feat8, feat16, feat32 = self.resnet(x)
H8, W8 = feat8.shape[2:]
H16, W16 = feat16.shape[2:]
H32, W32 = feat32.shape[2:]
avg = F.avg_pool2d(feat32, feat32.shape[2:])
avg = self.conv_avg(avg)
#avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
avg_up = F.resize_nearest(avg, out_shape=(H32, W32))
feat32_arm = self.arm32(feat32)
feat32_sum = feat32_arm + avg_up
#feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
feat32_up = F.resize_nearest(feat32_sum, out_shape=(H16, W16))
feat32_up = self.conv_head32(feat32_up)
feat16_arm = self.arm16(feat16)
feat16_sum = feat16_arm + feat32_up
#feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
feat16_up = F.resize_nearest(feat16_sum, out_shape=(H8, W8))
feat16_up = self.conv_head16(feat16_up)
return feat8, feat16_up, feat32_up # x8, x8, x16
### This path is unused: it is replaced by the resnet feature of the same size
class SpatialPath(paddle.nn.Layer):
def __init__(self, *args, **kwargs):
super(SpatialPath, self).__init__()
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
#self.init_weight()
def forward(self, x):
feat = self.conv1(x)
feat = self.conv2(feat)
feat = self.conv3(feat)
feat = self.conv_out(feat)
return feat
class FeatureFusionModule(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, *args, **kwargs):
super(FeatureFusionModule, self).__init__()
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
self.conv1 = nn.Conv2d(out_chan,
out_chan // 4,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.conv2 = nn.Conv2d(out_chan // 4,
out_chan,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, fsp, fcp):
fcat = paddle.concat([fsp, fcp], axis=1)
feat = self.convblk(fcat)
#atten = F.avg_pool2d(feat, feat.size()[2:])
atten = F.avg_pool2d(feat, feat.shape[2:])
atten = self.conv1(atten)
atten = self.relu(atten)
atten = self.conv2(atten)
atten = self.sigmoid(atten)
feat_atten = feat * atten
feat_out = feat_atten + feat
return feat_out
class BiSeNet(paddle.nn.Layer):
def __init__(self, n_classes, *args, **kwargs):
super(BiSeNet, self).__init__()
self.cp = ContextPath()
## here self.sp is deleted
self.ffm = FeatureFusionModule(256, 256)
self.conv_out = BiSeNetOutput(256, 256, n_classes)
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
def forward(self, x):
H, W = x.shape[2:]
feat_res8, feat_cp8, feat_cp16 = self.cp(
x) # here return res3b1 feature
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
feat_fuse = self.ffm(feat_sp, feat_cp8)
feat_out = self.conv_out(feat_fuse)
feat_out16 = self.conv_out16(feat_cp8)
feat_out32 = self.conv_out32(feat_cp16)
feat_out = F.resize_bilinear(feat_out, out_shape=(H, W))
feat_out16 = F.resize_bilinear(feat_out16, out_shape=(H, W))
feat_out32 = F.resize_bilinear(feat_out32, out_shape=(H, W))
return feat_out, feat_out16, feat_out32
if __name__ == "__main__":
import pickle
paddle.disable_static()
net = BiSeNet(19)
param, _ = paddle.load('./resnet.pdparams')
net.set_dict(param)
net.eval()
#print(net.state_dict().keys())
#np.random.seed(2)
#x = np.random.randn(16,3,640,480).astype(np.float32)
with open('./x.pickle', 'rb') as f:
x = pickle.load(f)
in_ten = paddle.to_tensor(x)
out, out16, out32 = net(in_ten)
print(out.numpy().sum())
with open('./out.pickle', 'wb') as f:
pickle.dump(out.numpy(), f)
print(out.shape)
print(out16.shape)
print(out32.shape)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
import numpy as np
import math
#resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
model_urls = {
'resnet18': ('https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams',
'0ba53eea9bc970962d0ef96f7b94057e'),
}
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
class BasicBlock(paddle.nn.Layer):
def __init__(self, in_chan, out_chan, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_chan, out_chan, stride)
self.bn1 = nn.BatchNorm(out_chan)
self.conv2 = conv3x3(out_chan, out_chan)
self.bn2 = nn.BatchNorm(out_chan)
self.relu = nn.ReLU()
self.downsample = None
if in_chan != out_chan or stride != 1:
self.downsample = nn.Sequential(
nn.Conv2d(in_chan,
out_chan,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm(out_chan),
)
def forward(self, x):
residual = self.conv1(x)
residual = self.relu(self.bn1(residual))
residual = self.conv2(residual)
residual = self.bn2(residual)
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(x)
out = shortcut + residual
out = self.relu(out)
return out
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
for i in range(bnum - 1):
layers.append(BasicBlock(out_chan, out_chan, stride=1))
return nn.Sequential(*layers)
class Resnet18(paddle.nn.Layer):
def __init__(self):
super(Resnet18, self).__init__()
self.conv1 = nn.Conv2d(3,
64,
kernel_size=7,
stride=2,
padding=3,
bias_attr=False)
self.bn1 = nn.BatchNorm(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
# self.init_weight()
def forward(self, x):
x = self.conv1(x)
x = self.relu(self.bn1(x))
x = self.maxpool(x)
x = self.layer1(x)
feat8 = self.layer2(x) # 1/8
feat16 = self.layer3(feat8) # 1/16
feat32 = self.layer4(feat16) # 1/32
return feat8, feat16, feat32
def resnet18(pretrained=False, **kwargs):
model = Resnet18()
arch = 'resnet18'
if pretrained:
#weight_path = get_weights_path_from_url(model_urls[arch][0],
# model_urls[arch][1])
#assert weight_path.endswith(
# '.pdparams'), "suffix of weight must be .pdparams"
weight_path = './resnet.pdparams'
param, _ = paddle.load(weight_path)
model.set_dict(param)
return model
if __name__ == "__main__":
paddle.disable_static()
net = resnet18(pretrained=True)
x = paddle.to_tensor(
np.random.uniform(0, 1, (16, 3, 224, 224)).astype(np.float32))
out = net(x)
print(out[0].shape)
print(out[1].shape)
print(out[2].shape)
@@ -3,4 +3,5 @@ from .cycle_gan_model import CycleGANModel
from .pix2pix_model import Pix2PixModel
from .srgan_model import SRGANModel
from .sr_model import SRModel
from .makeup_model import MakeupModel
from .vgg import vgg16
import functools
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ...modules.nn import Conv2d, Spectralnorm
from ...modules.norm import build_norm_layer

from .builder import DISCRIMINATORS


@DISCRIMINATORS.register()
class NLayerDiscriminator(paddle.nn.Layer):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance'):
        """Construct a PatchGAN discriminator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_type (str) -- normalization layer type
        """
        super(NLayerDiscriminator, self).__init__()
        norm_layer = build_norm_layer(norm_type)
        if type(
                norm_layer
        ) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        #sequence = [Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.01)]
        sequence = [
            Spectralnorm(
                Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)),
            nn.LeakyReLU(0.01)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                #Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias_attr=use_bias),
                Spectralnorm(
                    Conv2d(ndf * nf_mult_prev,
                           ndf * nf_mult,
                           kernel_size=kw,
                           stride=2,
                           padding=padw)),
                #norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.01)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            #Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias_attr=use_bias),
            Spectralnorm(
                Conv2d(ndf * nf_mult_prev,
                       ndf * nf_mult,
                       kernel_size=kw,
                       stride=1,
                       padding=padw)),
            #norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.01)
        ]

        #sequence += [Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        sequence += [
            Spectralnorm(
                Conv2d(ndf * nf_mult,
                       1,
                       kernel_size=kw,
                       stride=1,
                       padding=padw,
                       bias_attr=False))
        ]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
from .resnet import ResnetGenerator
from .unet import UnetGenerator
from .rrdb_net import RRDBNet
from .makeup import GeneratorPSGANAttention
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import functools
import numpy as np
from ...modules.norm import build_norm_layer
from ...modules.nn import Conv2d, ConvTranspose2d
from .builder import GENERATORS
class PONO(paddle.nn.Layer):
def __init__(self, eps=1e-5):
super(PONO, self).__init__()
self.eps = eps
def forward(self, x):
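        # positional normalization: each spatial location is normalized across
        # the channel axis, (x - mu) / sqrt(var + eps), with mu/var over axis 1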
mean = paddle.mean(x, axis=1, keepdim=True)
var = paddle.mean(paddle.square(x - mean), axis=1, keepdim=True)
tmp = (x - mean) / paddle.sqrt(var + self.eps)
return tmp
class ResidualBlock(paddle.nn.Layer):
"""Residual Block with instance normalization."""
def __init__(self, dim_in, dim_out, mode=None):
super(ResidualBlock, self).__init__()
        if mode == 't':
            weight_attr = False
            bias_attr = False
        else:  # mode == 'p' or mode is None
            weight_attr = None
            bias_attr = None
self.main = nn.Sequential(
Conv2d(dim_in,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False),
nn.InstanceNorm2d(dim_out,
weight_attr=weight_attr,
bias_attr=bias_attr), nn.ReLU(),
Conv2d(dim_out,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False),
nn.InstanceNorm2d(dim_out,
weight_attr=weight_attr,
bias_attr=bias_attr))
def forward(self, x):
"""forward"""
return x + self.main(x)
class StyleResidualBlock(paddle.nn.Layer):
"""Residual Block with instance normalization."""
def __init__(self, dim_in, dim_out):
super(StyleResidualBlock, self).__init__()
self.block1 = nn.Sequential(
Conv2d(dim_in,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False), PONO())
ks = 3
pw = ks // 2
self.beta1 = Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma1 = Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.block2 = nn.Sequential(
nn.ReLU(),
Conv2d(dim_out,
dim_out,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False), PONO())
self.beta2 = Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
self.gamma2 = Conv2d(dim_in, dim_out, kernel_size=ks, padding=pw)
def forward(self, x, y):
"""forward"""
x_ = self.block1(x)
b = self.beta1(y)
g = self.gamma1(y)
x_ = (g + 1) * x_ + b
x_ = self.block2(x_)
b = self.beta2(y)
g = self.gamma2(y)
x_ = (g + 1) * x_ + b
return x + x_
class MDNet(paddle.nn.Layer):
"""MDNet in PSGAN"""
def __init__(self, conv_dim=64, repeat_num=3):
super(MDNet, self).__init__()
layers = []
layers.append(
Conv2d(3,
conv_dim,
kernel_size=7,
stride=1,
padding=3,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=None, bias_attr=None))
layers.append(nn.ReLU())
# Down-Sampling
curr_dim = conv_dim
for i in range(2):
layers.append(
Conv2d(curr_dim,
curr_dim * 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim * 2,
weight_attr=None,
bias_attr=None))
layers.append(nn.ReLU())
curr_dim = curr_dim * 2
# Bottleneck
for i in range(repeat_num):
layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
#layers.append(nn.InstanceNorm2d(curr_dim, weight_attr=None, bias_attr=None))
#layers.append(PONO())
self.main = nn.Sequential(*layers)
def forward(self, x):
"""forward"""
out = self.main(x)
return out
class TNetDown(paddle.nn.Layer):
"""MDNet in PSGAN"""
def __init__(self, conv_dim=64, repeat_num=3):
super(TNetDown, self).__init__()
layers = []
layers.append(
Conv2d(3,
conv_dim,
kernel_size=7,
stride=1,
padding=3,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(conv_dim, weight_attr=False, bias_attr=False))
layers.append(nn.ReLU())
# Down-Sampling
curr_dim = conv_dim
for i in range(2):
layers.append(
Conv2d(curr_dim,
curr_dim * 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim * 2,
weight_attr=False,
bias_attr=False))
layers.append(nn.ReLU())
curr_dim = curr_dim * 2
# Bottleneck
for i in range(repeat_num):
layers.append(
ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, mode='t'))
#layers.append(nn.InstanceNorm2d(curr_dim, weight_attr=False, bias_attr=False))
self.main = nn.Sequential(*layers)
def forward(self, x):
"""forward"""
out = self.main(x)
return out
class GetMatrix(paddle.nn.Layer):
def __init__(self, dim_in, dim_out):
super(GetMatrix, self).__init__()
self.get_gamma = Conv2d(dim_in,
dim_out,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
self.get_beta = Conv2d(dim_in,
dim_out,
kernel_size=1,
stride=1,
padding=0,
bias_attr=False)
def forward(self, x):
gamma = self.get_gamma(x)
beta = self.get_beta(x)
return gamma, beta
class MANet(paddle.nn.Layer):
"""MANet in PSGAN"""
def __init__(self, conv_dim=64, repeat_num=3, w=0.01):
super(MANet, self).__init__()
self.encoder = TNetDown(conv_dim=conv_dim, repeat_num=repeat_num)
curr_dim = conv_dim * 4
self.w = w
self.beta = Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1)
self.gamma = Conv2d(curr_dim, curr_dim, kernel_size=3, padding=1)
self.simple_spade = GetMatrix(curr_dim, 1) # get the makeup matrix
self.repeat_num = repeat_num
for i in range(repeat_num):
setattr(self, "bottlenecks_" + str(i),
ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, mode='t'))
# Up-Sampling
self.upsamplers = []
self.up_betas = []
self.up_gammas = []
self.up_acts = []
y_dim = curr_dim
for i in range(2):
layers = []
layers.append(
nn.ConvTranspose2d(curr_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
padding=1,
bias_attr=False))
layers.append(
nn.InstanceNorm2d(curr_dim // 2,
weight_attr=False,
bias_attr=False))
setattr(self, "up_acts_" + str(i), nn.ReLU())
#setattr(self, "up_betas_" + str(i), Conv2d(y_dim, curr_dim//2, kernel_size=3, padding=1))
setattr(
self, "up_betas_" + str(i),
nn.ConvTranspose2d(y_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
padding=1))
#setattr(self, "up_gammas_" + str(i), Conv2d(y_dim, curr_dim//2, kernel_size=3, padding=1))
setattr(
self, "up_gammas_" + str(i),
nn.ConvTranspose2d(y_dim,
curr_dim // 2,
kernel_size=4,
stride=2,
padding=1))
setattr(self, "up_samplers_" + str(i), nn.Sequential(*layers))
curr_dim = curr_dim // 2
self.img_reg = [
Conv2d(curr_dim,
3,
kernel_size=7,
stride=1,
padding=3,
bias_attr=False)
]
self.img_reg = nn.Sequential(*self.img_reg)
def forward(self, x, y, x_p, y_p, consistency_mask, mask_x, mask_y):
"""forward"""
# y -> ref feature
# x -> src img
x = self.encoder(x)
_, c, h, w = x.shape
x_flat = x.reshape([-1, c, h * w])
x_flat = self.w * x_flat
if x_p is not None:
x_flat = paddle.concat([x_flat, x_p], axis=1)
_, c2, h2, w2 = y.shape
y_flat = y.reshape([-1, c2, h2 * w2])
y_flat = self.w * y_flat
if y_p is not None:
y_flat = paddle.concat([y_flat, y_p], axis=1)
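        # attentive makeup morphing: `a_` is a (HW_x, HW_y) score matrix between
        # source and reference pixels; scores for pixel pairs outside the same
        # (presumed) facial region are pushed down by 100 before the softmax,
        # and gamma/beta below are warped from reference to source through `a`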
a_ = paddle.matmul(x_flat, y_flat, transpose_x=True) * 200.0
# mask softmax
if consistency_mask is not None:
a_ = a_ - 100.0 * (1 - consistency_mask)
#a_ = a_ * consistency_mask
a = F.softmax(a_, axis=-1)
#a = a * consistency_mask
gamma, beta = self.simple_spade(y)
beta = beta.reshape([-1, h2 * w2, 1])
beta = paddle.matmul(a, beta)
beta = beta.reshape([-1, 1, h2, w2])
gamma = gamma.reshape([-1, h2 * w2, 1])
gamma = paddle.matmul(a, gamma)
gamma = gamma.reshape([-1, 1, h2, w2])
x = x * (1 + gamma) + beta
for i in range(self.repeat_num):
layer = getattr(self, "bottlenecks_" + str(i))
x = layer(x)
for idx in range(2):
layer = getattr(self, "up_samplers_" + str(idx))
x = layer(x)
layer = getattr(self, "up_acts_" + str(idx))
x = layer(x)
x = self.img_reg(x)
x = paddle.tanh(x)
return x, a
@GENERATORS.register()
class GeneratorPSGANAttention(paddle.nn.Layer):
def __init__(self, conv_dim=64, repeat_num=3):
super(GeneratorPSGANAttention, self).__init__()
self.ma_net = MANet(conv_dim=conv_dim, repeat_num=repeat_num)
self.md_net = MDNet(conv_dim=conv_dim, repeat_num=repeat_num)
def forward(self, x, y, x_p, y_p, consistency_mask, mask_x, mask_y):
"""forward"""
y = self.md_net(y)
out, a = self.ma_net(x, y, x_p, y_p, consistency_mask, mask_x, mask_y)
return out, a
@@ -45,7 +45,6 @@ class GANLoss(nn.Layer):
Returns:
A label tensor filled with ground truth label, and with the size of the input
"""
if target_is_real:
if not hasattr(self, 'target_real_tensor'):
self.target_real_tensor = paddle.fill_constant(
......
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .base_model import BaseModel
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .losses import GANLoss
# from ..modules.nn import L1Loss
from ..solver import build_optimizer
from ..utils.image_pool import ImagePool
from ..utils.preprocess import *
from ..datasets.makeup_dataset import MakeupDataset
import numpy as np
from .vgg import vgg16
@MODELS.register()
class MakeupModel(BaseModel):
"""
This class implements the CycleGAN model, for learning image-to-image translation without paired data.
The model training requires '--dataset_mode unaligned' dataset.
By default, it uses a '--netG resnet_9blocks' ResNet generator,
a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
and a least-square GANs objective ('--gan_mode lsgan').
CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
"""
def __init__(self, opt):
"""Initialize the CycleGAN class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = [
'D_A',
'G_A',
'rec',
'idt',
'D_B',
'G_B',
'G_A_his',
'G_B_his',
'G_bg_consis',
'A_vgg',
'B_vgg',
]
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_A', 'rec_A']
visual_names_B = ['real_B', 'fake_B', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) and idt_A=G_B(A)
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
self.vgg = vgg16(pretrained=True)
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['G', 'D_A', 'D_B']
else: # during test time, only load Gs
self.model_names = ['G']
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG = build_generator(opt.model.generator)
if self.isTrain: # define discriminators
self.netD_A = build_discriminator(opt.model.discriminator)
self.netD_B = build_discriminator(opt.model.discriminator)
if self.isTrain:
self.fake_A_pool = ImagePool(
opt.dataset.train.pool_size
) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(
opt.dataset.train.pool_size
) # create image buffer to store previously generated images
# define loss functions
self.criterionGAN = GANLoss(
opt.model.gan_mode) #.to(self.device) # define GAN loss.
self.criterionCycle = paddle.nn.L1Loss()
self.criterionIdt = paddle.nn.L1Loss()
self.criterionL1 = paddle.nn.L1Loss()
self.criterionL2 = paddle.nn.MSELoss()
self.build_lr_scheduler()
self.optimizer_G = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netG.parameters())
# self.optimizer_D = paddle.optimizer.Adam(learning_rate=lr_scheduler_d, parameter_list=self.netD_A.parameters() + self.netD_B.parameters(), beta1=opt.beta1)
self.optimizer_DA = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD_A.parameters())
self.optimizer_DB = build_optimizer(
opt.optimizer,
self.lr_scheduler,
parameter_list=self.netD_B.parameters())
self.optimizers.append(self.optimizer_G)
# self.optimizers.append(self.optimizer_D)
self.optimizers.append(self.optimizer_DA)
self.optimizers.append(self.optimizer_DB)
self.optimizer_names.extend(
['optimizer_G', 'optimizer_DA', 'optimizer_DB'])
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
self.real_A = paddle.to_tensor(input['image_A'])
self.real_B = paddle.to_tensor(input['image_B'])
self.c_m = paddle.to_tensor(input['consis_mask'])
self.P_A = paddle.to_tensor(input['P_A'])
self.P_B = paddle.to_tensor(input['P_B'])
self.mask_A_aug = paddle.to_tensor(input['mask_A_aug'])
self.mask_B_aug = paddle.to_tensor(input['mask_B_aug'])
self.c_m_t = paddle.transpose(self.c_m, perm=[0, 2, 1])
if self.isTrain:
self.mask_A = paddle.to_tensor(input['mask_A'])
self.mask_B = paddle.to_tensor(input['mask_B'])
self.c_m_idt_a = paddle.to_tensor(input['consis_mask_idt_A'])
self.c_m_idt_b = paddle.to_tensor(input['consis_mask_idt_B'])
#self.hm_gt_A = self.hm_gt_A_lip + self.hm_gt_A_skin + self.hm_gt_A_eye
#self.hm_gt_B = self.hm_gt_B_lip + self.hm_gt_B_skin + self.hm_gt_B_eye
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_A, amm = self.netG(self.real_A, self.real_B, self.P_A,
self.P_B, self.c_m, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.fake_B, _ = self.netG(self.real_B, self.real_A, self.P_B, self.P_A,
self.c_m_t, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.rec_A, _ = self.netG(self.fake_A, self.real_A, self.P_A, self.P_A,
self.c_m_idt_a, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.rec_B, _ = self.netG(self.fake_B, self.real_B, self.P_B, self.P_B,
self.c_m_idt_b, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
def forward_test(self, input):
        '''
        Test-time forward: run the generator directly on a pre-processed
        input dict (no loss computation).
        '''
return self.netG(input['image_A'], input['image_B'], input['P_A'],
input['P_B'], input['consis_mask'],
input['mask_A_aug'], input['mask_B_aug'])
def test(self, input):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with paddle.no_grad():
return self.forward_test(input)
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
return loss_D
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
lambda_vgg = 5e-3
# Identity loss
if lambda_idt > 0:
self.idt_A, _ = self.netG(self.real_A, self.real_A, self.P_A,
self.P_A, self.c_m_idt_a, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_A = self.criterionIdt(
self.idt_A, self.real_A) * lambda_A * lambda_idt
self.idt_B, _ = self.netG(self.real_B, self.real_B, self.P_B,
self.P_B, self.c_m_idt_b, self.mask_A_aug,
self.mask_B_aug) # G_A(A)
self.loss_idt_B = self.criterionIdt(
self.idt_B, self.real_B) * lambda_B * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_A), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_B), True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A,
self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B,
self.real_B) * lambda_B
mask_A_lip = self.mask_A_aug[:, 0].unsqueeze(1)
mask_B_lip = self.mask_B_aug[:, 0].unsqueeze(1)
mask_A_lip_np = mask_A_lip.numpy().squeeze()
mask_B_lip_np = mask_B_lip.numpy().squeeze()
mask_A_lip_np, mask_B_lip_np, index_A_lip, index_B_lip = mask_preprocess(
mask_A_lip_np, mask_B_lip_np)
real_A = paddle.nn.clip((self.real_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
real_A_np = real_A.numpy().squeeze()
real_B = paddle.nn.clip((self.real_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
real_B_np = real_B.numpy().squeeze()
fake_A = paddle.nn.clip((self.fake_A + 1.0) / 2.0, 0.0, 1.0) * 255.0
fake_A_np = fake_A.numpy().squeeze()
fake_B = paddle.nn.clip((self.fake_B + 1.0) / 2.0, 0.0, 1.0) * 255.0
fake_B_np = fake_B.numpy().squeeze()
fake_match_lip_A = hisMatch(fake_A_np, real_B_np, mask_A_lip_np,
mask_B_lip_np, index_A_lip)
fake_match_lip_B = hisMatch(fake_B_np, real_A_np, mask_B_lip_np,
mask_A_lip_np, index_B_lip)
fake_match_lip_A = paddle.to_tensor(fake_match_lip_A)
fake_match_lip_A.stop_gradient = True
fake_match_lip_A = fake_match_lip_A.unsqueeze(0)
fake_match_lip_B = paddle.to_tensor(fake_match_lip_B)
fake_match_lip_B.stop_gradient = True
fake_match_lip_B = fake_match_lip_B.unsqueeze(0)
fake_A_lip_masked = fake_A * mask_A_lip
fake_B_lip_masked = fake_B * mask_B_lip
g_A_lip_loss_his = self.criterionL1(fake_A_lip_masked, fake_match_lip_A)
g_B_lip_loss_his = self.criterionL1(fake_B_lip_masked, fake_match_lip_B)
#skin
mask_A_skin = self.mask_A_aug[:, 1].unsqueeze(1)
mask_B_skin = self.mask_B_aug[:, 1].unsqueeze(1)
mask_A_skin_np = mask_A_skin.numpy().squeeze()
mask_B_skin_np = mask_B_skin.numpy().squeeze()
mask_A_skin_np, mask_B_skin_np, index_A_skin, index_B_skin = mask_preprocess(
mask_A_skin_np, mask_B_skin_np)
fake_match_skin_A = hisMatch(fake_A_np, real_B_np, mask_A_skin_np,
mask_B_skin_np, index_A_skin)
fake_match_skin_B = hisMatch(fake_B_np, real_A_np, mask_B_skin_np,
mask_A_skin_np, index_B_skin)
fake_match_skin_A = paddle.to_tensor(fake_match_skin_A)
fake_match_skin_A.stop_gradient = True
fake_match_skin_A = fake_match_skin_A.unsqueeze(0)
fake_match_skin_B = paddle.to_tensor(fake_match_skin_B)
fake_match_skin_B.stop_gradient = True
fake_match_skin_B = fake_match_skin_B.unsqueeze(0)
fake_A_skin_masked = fake_A * mask_A_skin
fake_B_skin_masked = fake_B * mask_B_skin
g_A_skin_loss_his = self.criterionL1(fake_A_skin_masked,
fake_match_skin_A)
g_B_skin_loss_his = self.criterionL1(fake_B_skin_masked,
fake_match_skin_B)
#eye
mask_A_eye = self.mask_A_aug[:, 2].unsqueeze(1)
mask_B_eye = self.mask_B_aug[:, 2].unsqueeze(1)
mask_A_eye_np = mask_A_eye.numpy().squeeze()
mask_B_eye_np = mask_B_eye.numpy().squeeze()
mask_A_eye_np, mask_B_eye_np, index_A_eye, index_B_eye = mask_preprocess(
mask_A_eye_np, mask_B_eye_np)
fake_match_eye_A = hisMatch(fake_A_np, real_B_np, mask_A_eye_np,
mask_B_eye_np, index_A_eye)
fake_match_eye_B = hisMatch(fake_B_np, real_A_np, mask_B_eye_np,
mask_A_eye_np, index_B_eye)
fake_match_eye_A = paddle.to_tensor(fake_match_eye_A)
fake_match_eye_A.stop_gradient = True
fake_match_eye_A = fake_match_eye_A.unsqueeze(0)
fake_match_eye_B = paddle.to_tensor(fake_match_eye_B)
fake_match_eye_B.stop_gradient = True
fake_match_eye_B = fake_match_eye_B.unsqueeze(0)
fake_A_eye_masked = fake_A * mask_A_eye
fake_B_eye_masked = fake_B * mask_B_eye
g_A_eye_loss_his = self.criterionL1(fake_A_eye_masked, fake_match_eye_A)
g_B_eye_loss_his = self.criterionL1(fake_B_eye_masked, fake_match_eye_B)
self.loss_G_A_his = (g_A_eye_loss_his + g_A_lip_loss_his +
g_A_skin_loss_his * 0.1) * 0.1
self.loss_G_B_his = (g_B_eye_loss_his + g_B_lip_loss_his +
g_B_skin_loss_his * 0.1) * 0.1
#self.loss_G_A_his = self.criterionL1(tmp_1, tmp_2) * 2048 * 255
#tmp_3 = self.hm_gt_B*self.hm_mask_weight_B
#tmp_4 = self.fake_B*self.hm_mask_weight_B
#self.loss_G_B_his = self.criterionL1(tmp_3, tmp_4) * 2048 * 255
#vgg loss
vgg_s = self.vgg(self.real_A)
vgg_s.stop_gradient = True
vgg_fake_A = self.vgg(self.fake_A)
self.loss_A_vgg = self.criterionL2(vgg_fake_A,
vgg_s) * lambda_A * lambda_vgg
vgg_r = self.vgg(self.real_B)
vgg_r.stop_gradient = True
vgg_fake_B = self.vgg(self.fake_B)
self.loss_B_vgg = self.criterionL2(vgg_fake_B,
vgg_r) * lambda_B * lambda_vgg
self.loss_rec = (self.loss_cycle_A + self.loss_cycle_B +
self.loss_A_vgg + self.loss_B_vgg) * 0.1
self.loss_idt = (self.loss_idt_A + self.loss_idt_B) * 0.1
# bg consistency loss
mask_A_consis = paddle.cast(
(self.mask_A == 0), dtype='float32') + paddle.cast(
(self.mask_A == 10), dtype='float32') + paddle.cast(
(self.mask_A == 8), dtype='float32')
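        # parsing classes 0, 8 and 10 are presumably non-makeup regions
        # (background / body / hair) that the generator should leave unchanged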
mask_A_consis = paddle.unsqueeze(paddle.clip(mask_A_consis, 0, 1), 1)
self.loss_G_bg_consis = self.criterionL1(self.real_A * mask_A_consis,
self.fake_A * mask_A_consis)
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_rec + self.loss_idt + self.loss_G_A_his + self.loss_G_B_his + self.loss_G_bg_consis
self.loss_G.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute fake images and reconstruction images.
# G_A and G_B
self.set_requires_grad(
[self.netD_A, self.netD_B],
False) # Ds require no gradients when optimizing Gs
# self.optimizer_G.clear_gradients() #zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.minimize(
self.loss_G) #step() # update G_A and G_B's weights
self.optimizer_G.clear_gradients()
# self.optimizer_G.clear_gradients()
# D_A and D_B
# self.set_requires_grad([self.netD_A, self.netD_B], True)
self.set_requires_grad(self.netD_A, True)
# self.optimizer_D.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
self.optimizer_DA.minimize(
self.loss_D_A) #step() # update D_A and D_B's weights
self.optimizer_DA.clear_gradients() #zero_g
self.set_requires_grad(self.netD_B, True)
# self.optimizer_DB.clear_gradients() #zero_grad() # set D_A and D_B's gradients to zero
        self.backward_D_B()  # calculate gradients for D_B
self.optimizer_DB.minimize(
self.loss_D_B) #step() # update D_A and D_B's weights
self.optimizer_DB.clear_gradients(
) #zero_grad() # set D_A and D_B's gradients to zero
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.utils.download import get_weights_path_from_url
__all__ = [
'VGG',
'vgg11',
'vgg13',
'vgg16',
'vgg19',
]
model_urls = {
'vgg16': ('https://paddle-hapi.bj.bcebos.com/models/vgg16.pdparams',
'c788f453a3b999063e8da043456281ee')
}
class Classifier(paddle.nn.Layer):
def __init__(self, num_classes, classifier_activation='softmax'):
super(Classifier, self).__init__()
self.linear1 = nn.Linear(512 * 7 * 7, 4096)
self.linear2 = nn.Linear(4096, 4096)
self.linear3 = nn.Linear(4096, num_classes)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.softmax = nn.Softmax()
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.linear2(x)
x = self.relu(x)
x = self.dropout(x)
out = self.linear3(x)
out = self.softmax(out)
return out
class VGG(paddle.nn.Layer):
"""VGG model from
`"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_
Args:
features (nn.Layer): vgg features create by function make_layers.
num_classes (int): output dim of last fc layer. If num_classes <=0, last fc layer
will not be defined. Default: 1000.
classifier_activation (str): activation for the last fc layer. Default: 'softmax'.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import VGG
from paddle.incubate.hapi.vision.models.vgg import make_layers
vgg11_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']
features = make_layers(vgg11_cfg)
vgg11 = VGG(features)
"""
def __init__(self,
features,
num_classes=1000,
classifier_activation='softmax'):
super(VGG, self).__init__()
self.features = features
self.num_classes = num_classes
if num_classes > 0:
classifier = Classifier(num_classes, classifier_activation)
self.classifier = self.add_sublayer("classifier",
nn.Sequential(classifier))
def forward(self, x):
x = self.features(x)
#if self.num_classes > 0:
# x = fluid.layers.flatten(x, 1)
# x = self.classifier(x)
return x
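        # only convolutional features are returned; the classifier head is
        # bypassed because the makeup model uses VGG as a feature extractor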
def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
if batch_norm:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.ReLU()]
in_channels = v
return nn.Sequential(*layers)
cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B':
[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512,
512, 512, 'M'
],
'E': [
64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512,
'M', 512, 512, 512, 512, 'M'
],
}
def _vgg(arch, cfg, batch_norm, pretrained, **kwargs):
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm),
num_classes=1000,
**kwargs)
if pretrained:
#assert arch in model_urls, "{} model do not have a pretrained model now, you should set pretrained=False".format(
# arch)
#weight_path = get_weights_path_from_url(model_urls[arch][0],
# model_urls[arch][1])
#assert weight_path.endswith(
# '.pdparams'), "suffix of weight must be .pdparams"
weight_path = './vgg16.pdparams'
param, _ = paddle.load(weight_path)
model.load_dict(param)
return model
def vgg11(pretrained=False, batch_norm=False, **kwargs):
"""VGG 11-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg11
# build model
model = vgg11()
# build vgg11 model with batch_norm
model = vgg11(batch_norm=True)
"""
model_name = 'vgg11'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'A', batch_norm, pretrained, **kwargs)
def vgg13(pretrained=False, batch_norm=False, **kwargs):
"""VGG 13-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg13
# build model
model = vgg13()
# build vgg13 model with batch_norm
model = vgg13(batch_norm=True)
"""
model_name = 'vgg13'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'B', batch_norm, pretrained, **kwargs)
def vgg16(pretrained=False, batch_norm=False, **kwargs):
"""VGG 16-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg16
# build model
model = vgg16()
# build vgg16 model with batch_norm
model = vgg16(batch_norm=True)
"""
model_name = 'vgg16'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'D', batch_norm, pretrained, **kwargs)
def vgg19(pretrained=False, batch_norm=False, **kwargs):
"""VGG 19-layer model
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet. Default: False.
batch_norm (bool): If True, returns a model with batch_norm layer. Default: False.
Examples:
.. code-block:: python
from paddle.incubate.hapi.vision.models import vgg19
# build model
model = vgg19()
# build vgg19 model with batch_norm
model = vgg19(batch_norm=True)
"""
model_name = 'vgg19'
if batch_norm:
model_name += ('_bn')
return _vgg(model_name, 'E', batch_norm, pretrained, **kwargs)
import paddle
import paddle.nn as nn
import math
class _SpectralNorm(nn.SpectralNorm):
@@ -50,3 +51,123 @@ class Spectralnorm(paddle.nn.Layer):
self.layer.weight = weight
out = self.layer(x)
return out
def initial_type(input,
                 op_type,
                 fan_out,
                 init="normal",
                 use_bias=False,
                 kernel_size=0,
                 stddev=0.02,
                 name=None):
    """Build weight/bias ParamAttrs for conv-like layers.

    Note: `input` is only inspected on the 'kaiming' branch, where fan_in
    is derived from its shape; the other branches ignore it.
    """
    if init == "kaiming":
        if op_type == 'conv':
            fan_in = input.shape[1] * kernel_size * kernel_size
        elif op_type == 'deconv':
            fan_in = fan_out * kernel_size * kernel_size
        else:
            if len(input.shape) > 2:
                fan_in = input.shape[1] * input.shape[2] * input.shape[3]
            else:
                fan_in = input.shape[1]
        bound = 1 / math.sqrt(fan_in)
        param_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Uniform(low=-bound, high=bound))
        if use_bias:
            bias_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.Uniform(low=-bound,
                                                          high=bound))
        else:
            bias_attr = False
    elif init == 'xavier':
        param_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Xavier(uniform=False))
        if use_bias:
            bias_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0))
        else:
            bias_attr = False
    else:
        param_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.NormalInitializer(loc=0.0,
                                                                scale=stddev))
        if use_bias:
            bias_attr = paddle.ParamAttr(
                initializer=paddle.nn.initializer.Constant(0.0))
        else:
            bias_attr = False
    return param_attr, bias_attr
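# A minimal usage sketch (hypothetical, for illustration): only the
# 'kaiming' branch of initial_type reads `input`, deriving fan_in from
# its shape, so any tensor of the right rank works here.
def _demo_initial_type():
    x = paddle.ones([1, 3, 64, 64])  # NCHW activation, arbitrary shape
    # conv + kaiming: bound = 1 / sqrt(C_in * k * k) = 1 / sqrt(3 * 3 * 3)
    w_attr, b_attr = initial_type(x,
                                  op_type='conv',
                                  fan_out=64,
                                  init='kaiming',
                                  use_bias=True,
                                  kernel_size=3)
    return w_attr, b_attr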
class Conv2d(paddle.nn.Conv2d):
    def __init__(self,
                 num_channels,
                 num_filters,
                 kernel_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW",
                 init_type='xavier'):
        # NOTE: `input` here resolves to the Python builtin; this only works
        # because the default 'xavier' branch of initial_type never reads it.
        param_attr, bias_attr = initial_type(
            input=input,
            op_type='conv',
            fan_out=num_filters,
            init=init_type,
            use_bias=bias_attr is not False,
            kernel_size=kernel_size)
        super(Conv2d, self).__init__(in_channels=num_channels,
                                     out_channels=num_filters,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     dilation=dilation,
                                     groups=groups,
                                     weight_attr=param_attr,
                                     bias_attr=bias_attr,
                                     data_format=data_format)
class ConvTranspose2d(paddle.nn.ConvTranspose2d):
    def __init__(self,
                 num_channels,
                 num_filters,
                 kernel_size,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW",
                 init_type='normal'):
        # NOTE: as in Conv2d above, `input` resolves to the Python builtin
        # and is never read by the default 'normal' branch of initial_type.
        param_attr, bias_attr = initial_type(
            input=input,
            op_type='deconv',
            fan_out=num_filters,
            init=init_type,
            use_bias=bias_attr is not False,
            kernel_size=kernel_size)
        # Pass the ParamAttr built above so init_type actually takes effect.
        super(ConvTranspose2d, self).__init__(in_channels=num_channels,
                                              out_channels=num_filters,
                                              kernel_size=kernel_size,
                                              padding=padding,
                                              stride=stride,
                                              dilation=dilation,
                                              groups=groups,
                                              weight_attr=param_attr,
                                              bias_attr=bias_attr,
                                              data_format=data_format)
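# A minimal usage sketch (hypothetical, for illustration): with the default
# init_type values ('xavier' / 'normal'), initial_type never reads its
# `input` argument, so the wrappers behave like the stock paddle layers.
def _demo_conv_wrappers():
    x = paddle.randn([1, 3, 32, 32])
    conv = Conv2d(num_channels=3, num_filters=8, kernel_size=3, padding=1)
    deconv = ConvTranspose2d(num_channels=8,
                             num_filters=3,
                             kernel_size=4,
                             stride=2,
                             padding=1)
    y = conv(x)    # -> [1, 8, 32, 32]
    z = deconv(y)  # -> [1, 3, 64, 64]
    return z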
import argparse
def parse_args():
parser = argparse.ArgumentParser(description='Segmentron')
parser.add_argument('--config-file', metavar="FILE",
parser.add_argument('--config-file',
metavar="FILE",
help='config file path')
# cuda setting
parser.add_argument('--no-cuda',
action='store_true',
default=False,
help='disables CUDA training')
# checkpoint and log
parser.add_argument('--resume',
type=str,
default=None,
help='put the path to resuming file if needed')
parser.add_argument('--load',
type=str,
default=None,
                        help='put the path to the checkpoint to load')
# for evaluation
parser.add_argument('--val-interval',
type=int,
default=1,
help='run validation every interval')
parser.add_argument('--evaluate-only',
action='store_true',
default=False,
                        help='only run evaluation, skipping training')
# config options
parser.add_argument('opts',
help='See config for all options',
default=None,
nargs=argparse.REMAINDER)
    # for inference
parser.add_argument("--source_path",
default="",
metavar="FILE",
help="path to source image")
parser.add_argument("--reference_dir",
default="",
help="path to reference images")
parser.add_argument("--model_path", default="", help="model for loading")
args = parser.parse_args()
return args
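# Hypothetical invocation (the entry-point script name and paths are
# assumptions, not taken from this commit):
#
#   python tools/main.py --config-file configs/makeup.yaml \
#       --model_path ./checkpoints/psgan.pdparams \
#       --source_path ./imgs/source.png \
#       --reference_dir ./imgs/refs \
#       --evaluate-only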
import cv2
import numpy as np
def generate_P_from_lmks(lmks, resize, w, h):
    """Generate the pixel-wise relative position feature P from 68 landmarks.

    Returns an array of shape (136, 64*64): for every cell of the 64x64
    grid, the normalized (y, x) offsets to each landmark. `w` and `h` are
    accepted but unused here.
    """
    diff_size = (64, 64)
    xs, ys = np.meshgrid(np.linspace(0, resize - 1, resize),
                         np.linspace(0, resize - 1, resize))
    xs = xs[None].repeat(68, axis=0)
    ys = ys[None].repeat(68, axis=0)
    fix = np.concatenate([ys, xs], axis=0)  # (136, resize, resize)
    lmks = lmks.transpose(1, 0).reshape(-1, 1, 1)  # (136, 1, 1)
    diff = fix - lmks  # per-pixel offset to every landmark
    diff = diff.transpose(1, 2, 0)
    diff = cv2.resize(diff, diff_size, interpolation=cv2.INTER_NEAREST)
    diff = diff.transpose(2, 0, 1).reshape(136, -1)
    norm = np.linalg.norm(diff, axis=0)
    P_np = diff / norm  # normalize each grid cell's offset vector
    return P_np
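# A minimal usage sketch (hypothetical, for illustration): random landmarks
# in a 256x256 image yield a (136, 64*64) feature, one normalized (y, x)
# offset pair per landmark for every cell of the 64x64 attention grid.
def _demo_generate_P():
    lmks = np.random.randint(0, 256, size=(68, 2)).astype(np.float32)
    P = generate_P_from_lmks(lmks, resize=256, w=256, h=256)
    assert P.shape == (136, 64 * 64)
    return P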
def copy_area(tar, src, lms):
    """Move the 16px-padded landmark bounding box from src into tar (in place)."""
    rect = [
int(min(lms[:, 1])) - 16,
int(min(lms[:, 0])) - 16,
int(max(lms[:, 1])) + 16 + 1,
int(max(lms[:, 0])) + 16 + 1
]
tar[rect[1]:rect[3], rect[0]:rect[2]] = \
src[rect[1]:rect[3], rect[0]:rect[2]]
src[rect[1]:rect[3], rect[0]:rect[2]] = 0
def rebound_box(mask, mask_B, mask_face):
    """Expand each mask to the 16px-padded bounding box of its nonzero
    pixels, filling the box with values from mask_face."""
    index_tmp = mask.nonzero()
    x_index = index_tmp[0]
    y_index = index_tmp[1]
    index_tmp = mask_B.nonzero()
    x_B_index = index_tmp[0]
    y_B_index = index_tmp[1]
    mask_temp = np.copy(mask)
    mask_B_temp = np.copy(mask_B)
    mask_temp[min(x_index) - 16:max(x_index) + 17,
              min(y_index) - 16:max(y_index) + 17] = \
        mask_face[min(x_index) - 16:max(x_index) + 17,
                  min(y_index) - 16:max(y_index) + 17]
    mask_B_temp[min(x_B_index) - 16:max(x_B_index) + 17,
                min(y_B_index) - 16:max(y_B_index) + 17] = \
        mask_face[min(x_B_index) - 16:max(x_B_index) + 17,
                  min(y_B_index) - 16:max(y_B_index) + 17]
    return mask_temp, mask_B_temp
def calculate_consis_mask(mask, mask_B):
    """Calculate the semantic consistency mask between two face masks.

    Both inputs are (3, H, W) lip/skin/eye planes; they are downsampled by
    4 and the result marks every pixel pair that shares a semantic class.
    """
    h_a, w_a = mask.shape[1:]
    h_b, w_b = mask_B.shape[1:]
    mask_transpose = np.transpose(mask, (1, 2, 0))
    mask_B_transpose = np.transpose(mask_B, (1, 2, 0))
    mask = cv2.resize(mask_transpose,
                      dsize=(w_a // 4, h_a // 4),
                      interpolation=cv2.INTER_NEAREST)
    mask = np.transpose(mask, (2, 0, 1))
    mask_B = cv2.resize(mask_B_transpose,
                        dsize=(w_b // 4, h_b // 4),
                        interpolation=cv2.INTER_NEAREST)
    mask_B = np.transpose(mask_B, (2, 0, 1))
    h_a, w_a = mask.shape[1:]
    h_b, w_b = mask_B.shape[1:]
    mask_lip = mask[0]
    mask_skin = mask[1]
    mask_eye = mask[2]
    mask_B_lip = mask_B[0]
    mask_B_skin = mask_B[1]
    mask_B_eye = mask_B[2]
    maskA_one_hot = np.zeros((h_a * w_a, 3))
    maskA_one_hot[:, 0] = mask_skin.flatten()
    maskA_one_hot[:, 1] = mask_eye.flatten()
    maskA_one_hot[:, 2] = mask_lip.flatten()
    maskB_one_hot = np.zeros((h_b * w_b, 3))
    maskB_one_hot[:, 0] = mask_B_skin.flatten()
    maskB_one_hot[:, 1] = mask_B_eye.flatten()
    maskB_one_hot[:, 2] = mask_B_lip.flatten()
    # (h_a*w_a, h_b*w_b): 1 where the two pixels share a class
    con_mask = np.matmul(maskA_one_hot, np.transpose(maskB_one_hot))
    con_mask = np.clip(con_mask, 0, 1)
    return con_mask
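# A minimal usage sketch (hypothetical, for illustration): two toy
# lip/skin/eye mask stacks of size 64x64 give a (16*16, 16*16) matrix
# marking pixel pairs that share a semantic class after 4x downsampling.
def _demo_consis_mask():
    mask_A = np.zeros((3, 64, 64), dtype=np.float32)
    mask_B = np.zeros((3, 64, 64), dtype=np.float32)
    mask_A[1, 8:56, 8:56] = 1  # skin region of face A
    mask_B[1, 4:60, 4:60] = 1  # skin region of face B
    con = calculate_consis_mask(mask_A, mask_B)
    assert con.shape == (16 * 16, 16 * 16)
    return con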
def cal_hist(image):
    """
    Compute the cumulative histogram (CDF) of each of the 3 channels.
    """
    hists = []
    for i in range(0, 3):
        channel = image[i]
        hist, _ = np.histogram(channel, bins=256, range=(0, 255))
        total = hist.sum()
        pdf = [v / total for v in hist]
        for k in range(1, 256):
            pdf[k] = pdf[k - 1] + pdf[k]
        hists.append(pdf)
    return hists
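# A quick check (hypothetical, for illustration): each returned list is a
# per-channel CDF, 256 non-decreasing values ending at 1.0.
def _demo_cal_hist():
    img = np.random.randint(0, 256, (3, 16, 16))
    cdfs = cal_hist(img)
    assert len(cdfs) == 3 and abs(cdfs[0][-1] - 1.0) < 1e-6
    return cdfs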
def cal_trans(ref, adj):
    """
    Calculate the histogram-matching transfer function;
    algorithm referring to the wiki item: Histogram matching.
    """
    table = list(range(0, 256))
    for i in list(range(1, 256)):
        for j in list(range(1, 256)):
            if adj[j - 1] <= ref[i] <= adj[j]:
                table[i] = j
                break
    table[255] = 255
    return table
def histogram_matching(dstImg, refImg, index):
    """
    Perform histogram matching:
    dstImg is transformed so that its histogram matches refImg's.
    index[0], index[1]: indices of the pixels to transform in dstImg
    index[2], index[3]: indices of the pixels whose histogram is computed in refImg
    """
dst_align = [dstImg[i, index[0], index[1]] for i in range(0, 3)]
ref_align = [refImg[i, index[2], index[3]] for i in range(0, 3)]
hist_ref = cal_hist(ref_align)
hist_dst = cal_hist(dst_align)
tables = [cal_trans(hist_dst[i], hist_ref[i]) for i in range(0, 3)]
mid = dst_align.copy()
for i in range(0, 3):
for k in range(0, len(index[0])):
dst_align[i][k] = tables[i][int(mid[i][k])]
for i in range(0, 3):
dstImg[i, index[0], index[1]] = dst_align[i]
return dstImg
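# A minimal usage sketch (hypothetical, for illustration): matching the
# full-image histograms by passing every pixel coordinate for both images.
def _demo_histogram_matching():
    dst = np.random.randint(0, 256, (3, 32, 32)).astype(np.float32)
    ref = np.random.randint(0, 256, (3, 32, 32)).astype(np.float32)
    ys, xs = np.nonzero(np.ones((32, 32)))
    matched = histogram_matching(dst.copy(), ref, [ys, xs, ys, xs])
    return matched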
def hisMatch(input_data, target_data, mask_src, mask_tar, index):
    """Histogram-match a masked region of input_data to target_data."""
mask_src = np.float32(np.clip(mask_src, 0, 1))
mask_tar = np.float32(np.clip(mask_tar, 0, 1))
input_masked = np.float32(input_data) * mask_src
target_masked = np.float32(target_data) * mask_tar
input_match = histogram_matching(input_masked, target_masked, index)
return input_match
def mask_preprocess(mask, mask_B):
    """Pack a mask pair with the nonzero pixel indices of each (both orderings)."""
index_tmp = mask.nonzero()
x_index = index_tmp[0]
y_index = index_tmp[1]
index_tmp = mask_B.nonzero()
x_B_index = index_tmp[0]
y_B_index = index_tmp[1]
index = [x_index, y_index, x_B_index, y_B_index]
index_2 = [x_B_index, y_B_index, x_index, y_index]
return [mask, mask_B, index, index_2]
def generate_mask_aug(mask, lmks):
    """Split a face-parsing label map into stacked lip / skin / eye planes."""
    lms_eye_left = lmks[42:48]
    lms_eye_right = lmks[36:42]
    mask_eye_left = np.zeros_like(mask)
    mask_eye_right = np.zeros_like(mask)
    mask_face = np.float32(mask == 1) + np.float32(mask == 6)  # face-skin labels
    copy_area(mask_eye_left, mask_face, lms_eye_left)
    copy_area(mask_eye_right, mask_face, lms_eye_right)
    mask_skin = mask_face
    mask_lip = np.float32(mask == 7) + np.float32(mask == 9)  # lip labels
    mask_eye = mask_eye_left + mask_eye_right
    mask_aug = np.concatenate((np.expand_dims(mask_lip, 0),
                               np.expand_dims(mask_skin, 0),
                               np.expand_dims(mask_eye, 0)), 0)
    return mask_aug
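# An end-to-end sketch (hypothetical, for illustration) of the per-region
# makeup color transfer: build binary region masks, align their nonzero
# indices with mask_preprocess, then histogram-match source to reference.
def _demo_region_transfer():
    src = np.random.randint(0, 256, (3, 64, 64)).astype(np.float32)
    ref = np.random.randint(0, 256, (3, 64, 64)).astype(np.float32)
    region_src = np.zeros((64, 64), dtype=np.float32)
    region_ref = np.zeros((64, 64), dtype=np.float32)
    region_src[16:48, 16:48] = 1
    region_ref[20:52, 20:52] = 1
    mask_src, mask_ref, index, _ = mask_preprocess(region_src, region_ref)
    out = hisMatch(src, ref, mask_src, mask_ref, index)
    return out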