未验证 提交 8ece3c2d 编写于 作者: 农夫三拳_'s avatar 农夫三拳_ 提交者: GitHub

add stargan-v2 test and train code (#165)

* add configuration options
e.g. python tools/main.py --c configs/stylegan_v2_256_ffhq.yaml -o total_iters=1 log_config.visiual_interval=1

* add stargan-v2 test and train code

* code normalization

* modify FAN code

* add munch in requirements.txt
上级 0dab3be6
epochs: 200
output_dir: output_dir
model:
name: StarGANv2Model
latent_dim: &LATENT_DIM 16
lambda_sty: 1
lambda_ds: 2
lambda_cyc: 1
generator:
name: StarGANv2Generator
img_size: &IMAGE_SIZE 256
w_hpf: 0
style_dim: &STYLE_DIM 64
style:
name: StarGANv2Style
img_size: *IMAGE_SIZE
style_dim: *STYLE_DIM
num_domains: &NUM_DOMAINS 3
mapping:
name: StarGANv2Mapping
latent_dim: *LATENT_DIM
style_dim: *STYLE_DIM
num_domains: *NUM_DOMAINS
discriminator:
name: StarGANv2Discriminator
img_size: *IMAGE_SIZE
num_domains: *NUM_DOMAINS
dataset:
train:
name: StarGANv2Dataset
dataroot: data/stargan-v2/afhq/train
is_train: True
num_workers: 8
batch_size: 4
preprocess:
- name: LoadImageFromFile
key: src
- name: LoadImageFromFile
key: ref
- name: LoadImageFromFile
key: ref2
- name: Transforms
input_keys: [src, ref, ref2]
pipeline:
- name: RandomResizedCropProb
prob: 0.9
size: [*IMAGE_SIZE, *IMAGE_SIZE]
scale: [0.8, 1.0]
ratio: [0.9, 1.1]
interpolation: 'bilinear'
keys: [image, image, image]
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bilinear'
keys: [image, image, image]
- name: RandomHorizontalFlip
prob: 0.5
keys: [image, image, image]
- name: Transpose
keys: [image, image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image, image]
test:
name: StarGANv2Dataset
dataroot: data/stargan-v2/afhq/val
is_train: False
num_workers: 8
batch_size: 16
test_count: 16
preprocess:
- name: LoadImageFromFile
key: src
- name: LoadImageFromFile
key: ref
- name: Transforms
input_keys: [src, ref]
pipeline:
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 365
optimizer:
generator:
name: Adam
net_names:
- generator
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
style_encoder:
name: Adam
net_names:
- style_encoder
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
mapping_network:
name: Adam
net_names:
- mapping_network
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
discriminator:
name: Adam
net_names:
- discriminator
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
validate:
interval: 5000
save_img: false
log_config:
interval: 5
visiual_interval: 100
snapshot_config:
interval: 5
epochs: 200
output_dir: output_dir
model:
name: StarGANv2Model
latent_dim: &LATENT_DIM 16
lambda_sty: 1
lambda_ds: 1
lambda_cyc: 1
generator:
name: StarGANv2Generator
img_size: &IMAGE_SIZE 256
w_hpf: 1
style_dim: &STYLE_DIM 64
style:
name: StarGANv2Style
img_size: *IMAGE_SIZE
style_dim: *STYLE_DIM
num_domains: &NUM_DOMAINS 2
mapping:
name: StarGANv2Mapping
latent_dim: *LATENT_DIM
style_dim: *STYLE_DIM
num_domains: *NUM_DOMAINS
fan:
name: FAN
fname_pretrained: models/stargan-v2/wing.pdparams
discriminator:
name: StarGANv2Discriminator
img_size: *IMAGE_SIZE
num_domains: *NUM_DOMAINS
dataset:
train:
name: StarGANv2Dataset
dataroot: data/stargan-v2/celeba_hq/train/
is_train: True
num_workers: 8
batch_size: 4
preprocess:
- name: LoadImageFromFile
key: src
- name: LoadImageFromFile
key: ref
- name: LoadImageFromFile
key: ref2
- name: Transforms
input_keys: [src, ref, ref2]
pipeline:
- name: RandomResizedCropProb
prob: 0.9
size: [*IMAGE_SIZE, *IMAGE_SIZE]
scale: [0.8, 1.0]
ratio: [0.9, 1.1]
interpolation: 'bilinear'
keys: [image, image, image]
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bilinear'
keys: [image, image, image]
- name: RandomHorizontalFlip
prob: 0.5
keys: [image, image, image]
- name: Transpose
keys: [image, image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image, image]
test:
name: StarGANv2Dataset
dataroot: data/stargan-v2/celeba_hq/val/
is_train: False
num_workers: 8
batch_size: 16
test_count: 16
preprocess:
- name: LoadImageFromFile
key: src
- name: LoadImageFromFile
key: ref
- name: Transforms
input_keys: [src, ref]
pipeline:
- name: Resize
size: [*IMAGE_SIZE, *IMAGE_SIZE]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: LinearDecay
learning_rate: 0.0001
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 365
optimizer:
generator:
name: Adam
net_names:
- generator
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
style_encoder:
name: Adam
net_names:
- style_encoder
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
mapping_network:
name: Adam
net_names:
- mapping_network
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
discriminator:
name: Adam
net_names:
- discriminator
beta1: 0.0
beta2: 0.99
weight_decay: 0.0001
validate:
interval: 5000
save_img: false
log_config:
interval: 5
visiual_interval: 100
snapshot_config:
interval: 5
......@@ -20,3 +20,4 @@ from .makeup_dataset import MakeupDataset
from .common_vision_dataset import CommonVisionDataset
from .animeganv2_dataset import AnimeGANV2Dataset
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
......@@ -267,6 +267,25 @@ class SRNoise(T.BaseTransform):
return image
@TRANSFORMS.register()
class RandomResizedCropProb(T.RandomResizedCrop):
"""RandomResizedCropProb.
Args:
prob (float): probabilty of using random-resized cropping.
size (int): cropped size.
"""
def __init__(self, prob, size, scale, ratio, interpolation, keys=None):
super().__init__(size, scale, ratio, interpolation)
self.prob = prob
self.keys = keys
def _apply_image(self, image):
if random.random() < self.prob:
image = super()._apply_image(image)
return image
@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
......
import paddle
from .base_dataset import BaseDataset
from .builder import DATASETS
import os
from itertools import chain
from pathlib import Path
import traceback
import random
import numpy as np
from PIL import Image
from paddle.io import Dataset, WeightedRandomSampler
def listdir(dname):
fnames = list(chain(*[list(Path(dname).rglob('*.' + ext))
for ext in ['png', 'jpg', 'jpeg', 'JPG']]))
return fnames
def _make_balanced_sampler(labels):
class_counts = np.bincount(labels)
class_weights = 1. / class_counts
weights = class_weights[labels]
return WeightedRandomSampler(weights, len(weights))
class ImageFolder(Dataset):
def __init__(self, root, use_sampler=False):
self.samples, self.targets = self._make_dataset(root)
self.use_sampler = use_sampler
if self.use_sampler:
self.sampler = _make_balanced_sampler(self.targets)
self.iter_sampler = iter(self.sampler)
def _make_dataset(self, root):
domains = os.listdir(root)
fnames, labels = [], []
for idx, domain in enumerate(sorted(domains)):
class_dir = os.path.join(root, domain)
cls_fnames = listdir(class_dir)
fnames += cls_fnames
labels += [idx] * len(cls_fnames)
return fnames, labels
def __getitem__(self, i):
if self.use_sampler:
try:
index = next(self.iter_sampler)
except StopIteration:
self.iter_sampler = iter(self.sampler)
index = next(self.iter_sampler)
else:
index = i
fname = self.samples[index]
label = self.targets[index]
return fname, label
def __len__(self):
return len(self.targets)
class ReferenceDataset(Dataset):
def __init__(self, root, use_sampler=None):
self.samples, self.targets = self._make_dataset(root)
self.use_sampler = use_sampler
if self.use_sampler:
self.sampler = _make_balanced_sampler(self.targets)
self.iter_sampler = iter(self.sampler)
def _make_dataset(self, root):
domains = os.listdir(root)
fnames, fnames2, labels = [], [], []
for idx, domain in enumerate(sorted(domains)):
class_dir = os.path.join(root, domain)
cls_fnames = listdir(class_dir)
fnames += cls_fnames
fnames2 += random.sample(cls_fnames, len(cls_fnames))
labels += [idx] * len(cls_fnames)
return list(zip(fnames, fnames2)), labels
def __getitem__(self, i):
if self.use_sampler:
try:
index = next(self.iter_sampler)
except StopIteration:
self.iter_sampler = iter(self.sampler)
index = next(self.iter_sampler)
else:
index = i
fname, fname2 = self.samples[index]
label = self.targets[index]
return fname, fname2, label
def __len__(self):
return len(self.targets)
@DATASETS.register()
class StarGANv2Dataset(BaseDataset):
"""
"""
def __init__(self, dataroot, is_train, preprocess, test_count=0):
"""Initialize single dataset class.
Args:
dataroot (str): Directory of dataset.
preprocess (list[dict]): A sequence of data preprocess config.
"""
super(StarGANv2Dataset, self).__init__(preprocess)
self.dataroot = dataroot
self.is_train = is_train
if self.is_train:
self.src_loader = ImageFolder(self.dataroot, use_sampler=True)
self.ref_loader = ReferenceDataset(self.dataroot, use_sampler=True)
self.counts = len(self.src_loader)
else:
files = os.listdir(self.dataroot)
if 'src' in files and 'ref' in files:
self.src_loader = ImageFolder(os.path.join(self.dataroot, 'src'))
self.ref_loader = ImageFolder(os.path.join(self.dataroot, 'ref'))
else:
self.src_loader = ImageFolder(self.dataroot)
self.ref_loader = ImageFolder(self.dataroot)
self.counts = min(test_count, len(self.src_loader))
self.counts = min(self.counts, len(self.ref_loader))
def _fetch_inputs(self):
try:
x, y = next(self.iter_src)
except (AttributeError, StopIteration):
self.iter_src = iter(self.src_loader)
x, y = next(self.iter_src)
return x, y
def _fetch_refs(self):
try:
x, x2, y = next(self.iter_ref)
except (AttributeError, StopIteration):
self.iter_ref = iter(self.ref_loader)
x, x2, y = next(self.iter_ref)
return x, x2, y
def __getitem__(self, idx):
if self.is_train:
x, y = self._fetch_inputs()
x_ref, x_ref2, y_ref = self._fetch_refs()
datas = {
'src_path': x,
'src_cls': y,
'ref_path': x_ref,
'ref2_path': x_ref2,
'ref_cls': y_ref,
}
else:
x, y = self.src_loader[idx]
x_ref, y_ref = self.ref_loader[idx]
datas = {
'src_path': x,
'src_cls': y,
'ref_path': x_ref,
'ref_cls': y_ref,
}
if hasattr(self, 'preprocess') and self.preprocess:
datas = self.preprocess(datas)
return datas
def __len__(self):
return self.counts
def prepare_data_infos(self, dataroot):
pass
......@@ -142,6 +142,7 @@ class Trainer:
self.time_count = {}
self.best_metric = {}
self.model.set_total_iter(self.total_iters)
def distributed_data_parallel(self):
paddle.distributed.init_parallel_env()
......
......@@ -26,3 +26,4 @@ from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
from .wav2lip_hq_model import Wav2LipModelHq
from .starganv2_model import StarGANv2Model
......@@ -95,6 +95,9 @@ class BaseModel(ABC):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def set_total_iter(self, total_iter):
self.total_iter = total_iter
def test_iter(self, metrics=None):
"""Calculate metrics; called in every test iteration"""
self.eval()
......
......@@ -20,3 +20,4 @@ from .discriminator_animegan import AnimeDiscriminator
from .discriminator_styleganv2 import StyleGANv2Discriminator
from .syncnet import SyncNetColor
from .wav2lip_disc_qual import Wav2LipDiscQual
from .discriminator_starganv2 import StarGANv2Discriminator
import paddle.nn as nn
import paddle
from .builder import DISCRIMINATORS
from ..generators.generator_starganv2 import ResBlk
import numpy as np
@DISCRIMINATORS.register()
class StarGANv2Discriminator(nn.Layer):
def __init__(self, img_size=256, num_domains=2, max_conv_dim=512):
super().__init__()
dim_in = 2**14 // img_size
blocks = []
blocks += [nn.Conv2D(3, dim_in, 3, 1, 1)]
repeat_num = int(np.log2(img_size)) - 2
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, downsample=True)]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [nn.Conv2D(dim_out, dim_out, 4, 1, 0)]
blocks += [nn.LeakyReLU(0.2)]
blocks += [nn.Conv2D(dim_out, num_domains, 1, 1, 0)]
self.main = nn.Sequential(*blocks)
def forward(self, x, y):
out = self.main(x)
out = paddle.reshape(out, (out.shape[0], -1)) # (batch, num_domains)
idx = paddle.zeros_like(out)
for i in range(idx.shape[0]):
idx[i, y[i]] = 1
s = idx * out
s = paddle.sum(s, axis=1)
return s
......@@ -26,3 +26,5 @@ from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel
from .drn import DRNGenerator
from .generator_starganv2 import StarGANv2Generator, StarGANv2Style, StarGANv2Mapping, FAN
import paddle
from paddle import nn
import paddle.nn.functional as F
from .builder import GENERATORS
import numpy as np
import math
from ppgan.modules.wing import CoordConvTh, ConvBlock, HourGlass, preprocess
class AvgPool2D(nn.Layer):
"""
AvgPool2D
Peplace avg_pool2d because paddle.grad will cause avg_pool2d to report an error when training.
In the future Paddle framework will supports avg_pool2d and remove this class.
"""
def __init__(self):
super(AvgPool2D, self).__init__()
self.filter = paddle.to_tensor([[1, 1],
[1, 1]], dtype='float32')
def forward(self, x):
filter = self.filter.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1])
return F.conv2d(x, filter, stride=2, padding=0, groups=x.shape[1]) / 4
class ResBlk(nn.Layer):
def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
normalize=False, downsample=False):
super().__init__()
self.actv = actv
self.normalize = normalize
self.downsample = downsample
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out)
def _build_weights(self, dim_in, dim_out):
self.conv1 = nn.Conv2D(dim_in, dim_in, 3, 1, 1)
self.conv2 = nn.Conv2D(dim_in, dim_out, 3, 1, 1)
if self.normalize:
self.norm1 = nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True)
self.norm2 = nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True)
if self.learned_sc:
self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False)
def _shortcut(self, x):
if self.learned_sc:
x = self.conv1x1(x)
if self.downsample:
x = AvgPool2D()(x)
return x
def _residual(self, x):
if self.normalize:
x = self.norm1(x)
x = self.actv(x)
x = self.conv1(x)
if self.downsample:
x = AvgPool2D()(x)
if self.normalize:
x = self.norm2(x)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x):
x = self._shortcut(x) + self._residual(x)
return x / math.sqrt(2) # unit variance
class AdaIN(nn.Layer):
def __init__(self, style_dim, num_features):
super().__init__()
self.norm = nn.InstanceNorm2D(num_features, weight_attr=False, bias_attr=False)
self.fc = nn.Linear(style_dim, num_features*2)
def forward(self, x, s):
h = self.fc(s)
# h = h.view(h.size(0), h.size(1), 1, 1)
h = paddle.reshape(h, (h.shape[0], h.shape[1], 1, 1))
gamma, beta = paddle.chunk(h, chunks=2, axis=1)
return (1 + gamma) * self.norm(x) + beta
class AdainResBlk(nn.Layer):
def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
actv=nn.LeakyReLU(0.2), upsample=False):
super().__init__()
self.w_hpf = w_hpf
self.actv = actv
self.upsample = upsample
self.learned_sc = dim_in != dim_out
self._build_weights(dim_in, dim_out, style_dim)
def _build_weights(self, dim_in, dim_out, style_dim=64):
self.conv1 = nn.Conv2D(dim_in, dim_out, 3, 1, 1)
self.conv2 = nn.Conv2D(dim_out, dim_out, 3, 1, 1)
self.norm1 = AdaIN(style_dim, dim_in)
self.norm2 = AdaIN(style_dim, dim_out)
if self.learned_sc:
self.conv1x1 = nn.Conv2D(dim_in, dim_out, 1, 1, 0, bias_attr=False)
def _shortcut(self, x):
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.learned_sc:
x = self.conv1x1(x)
return x
def _residual(self, x, s):
x = self.norm1(x, s)
x = self.actv(x)
if self.upsample:
x = F.interpolate(x, scale_factor=2, mode='nearest')
x = self.conv1(x)
x = self.norm2(x, s)
x = self.actv(x)
x = self.conv2(x)
return x
def forward(self, x, s):
out = self._residual(x, s)
if self.w_hpf == 0:
out = (out + self._shortcut(x)) / math.sqrt(2)
return out
class HighPass(nn.Layer):
def __init__(self, w_hpf):
super(HighPass, self).__init__()
self.filter = paddle.to_tensor([[-1, -1, -1],
[-1, 8., -1],
[-1, -1, -1]]) / w_hpf
def forward(self, x):
# filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
filter = self.filter.unsqueeze(0).unsqueeze(1).tile([x.shape[1], 1, 1, 1])
return F.conv2d(x, filter, padding=1, groups=x.shape[1])
@GENERATORS.register()
class StarGANv2Generator(nn.Layer):
def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
super().__init__()
dim_in = 2**14 // img_size
self.img_size = img_size
self.from_rgb = nn.Conv2D(3, dim_in, 3, 1, 1)
self.encode = nn.LayerList()
self.decode = nn.LayerList()
self.to_rgb = nn.Sequential(
nn.InstanceNorm2D(dim_in, weight_attr=True, bias_attr=True),
nn.LeakyReLU(0.2),
nn.Conv2D(dim_in, 3, 1, 1, 0))
# down/up-sampling blocks
repeat_num = int(np.log2(img_size)) - 4
if w_hpf > 0:
repeat_num += 1
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
self.encode.append(
ResBlk(dim_in, dim_out, normalize=True, downsample=True))
if len(self.decode) == 0:
self.decode.append(AdainResBlk(dim_out, dim_in, style_dim,
w_hpf=w_hpf, upsample=True))
else:
self.decode.insert(
0, AdainResBlk(dim_out, dim_in, style_dim,
w_hpf=w_hpf, upsample=True)) # stack-like
dim_in = dim_out
# bottleneck blocks
for _ in range(2):
self.encode.append(
ResBlk(dim_out, dim_out, normalize=True))
self.decode.insert(
0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))
if w_hpf > 0:
self.hpf = HighPass(w_hpf)
def forward(self, x, s, masks=None):
x = self.from_rgb(x)
cache = {}
for block in self.encode:
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
cache[x.shape[2]] = x
x = block(x)
for block in self.decode:
x = block(x, s)
if (masks is not None) and (x.shape[2] in [32, 64, 128]):
mask = masks[0] if x.shape[2] in [32] else masks[1]
mask = F.interpolate(mask, size=[x.shape[2], x.shape[2]], mode='bilinear')
x = x + self.hpf(mask * cache[x.shape[2]])
return self.to_rgb(x)
@GENERATORS.register()
class StarGANv2Mapping(nn.Layer):
def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
super().__init__()
layers = []
layers += [nn.Linear(latent_dim, 512)]
layers += [nn.ReLU()]
for _ in range(3):
layers += [nn.Linear(512, 512)]
layers += [nn.ReLU()]
self.shared = nn.Sequential(*layers)
self.unshared = nn.LayerList()
for _ in range(num_domains):
self.unshared.append(nn.Sequential(nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, style_dim)))
def forward(self, z, y):
h = self.shared(z)
out = []
for layer in self.unshared:
out += [layer(h)]
out = paddle.stack(out, axis=1) # (batch, num_domains, style_dim)
idx = paddle.to_tensor(np.array(range(y.shape[0]))).astype('int')
s = []
for i in range(idx.shape[0]):
s += [out[idx[i].numpy().astype(np.int).tolist()[0], y[i].numpy().astype(np.int).tolist()[0]]]
s = paddle.stack(s)
s = paddle.reshape(s, (s.shape[0], -1))
return s
@GENERATORS.register()
class StarGANv2Style(nn.Layer):
def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
super().__init__()
dim_in = 2**14 // img_size
blocks = []
blocks += [nn.Conv2D(3, dim_in, 3, 1, 1)]
repeat_num = int(np.log2(img_size)) - 2
for _ in range(repeat_num):
dim_out = min(dim_in*2, max_conv_dim)
blocks += [ResBlk(dim_in, dim_out, downsample=True)]
dim_in = dim_out
blocks += [nn.LeakyReLU(0.2)]
blocks += [nn.Conv2D(dim_out, dim_out, 4, 1, 0)]
blocks += [nn.LeakyReLU(0.2)]
self.shared = nn.Sequential(*blocks)
self.unshared = nn.LayerList()
for _ in range(num_domains):
self.unshared.append(nn.Linear(dim_out, style_dim))
def forward(self, x, y):
h = self.shared(x)
h = paddle.reshape(h, (h.shape[0], -1))
out = []
for layer in self.unshared:
out += [layer(h)]
out = paddle.stack(out, axis=1) # (batch, num_domains, style_dim)
idx = paddle.to_tensor(np.array(range(y.shape[0]))).astype('int')
s = []
for i in range(idx.shape[0]):
s += [out[idx[i].numpy().astype(np.int).tolist()[0], y[i].numpy().astype(np.int).tolist()[0]]]
s = paddle.stack(s)
s = paddle.reshape(s, (s.shape[0], -1))
return s
@GENERATORS.register()
class FAN(nn.Layer):
def __init__(self, num_modules=1, end_relu=False, num_landmarks=98, fname_pretrained=None):
super(FAN, self).__init__()
self.num_modules = num_modules
self.end_relu = end_relu
# Base part
self.conv1 = CoordConvTh(256, 256, True, False,
in_channels=3, out_channels=64,
kernel_size=7, stride=2, padding=3)
self.bn1 = nn.BatchNorm2D(64)
self.conv2 = ConvBlock(64, 128)
self.conv3 = ConvBlock(128, 128)
self.conv4 = ConvBlock(128, 256)
# Stacking part
self.add_sublayer('m0', HourGlass(1, 4, 256, first_one=True))
self.add_sublayer('top_m_0', ConvBlock(256, 256))
self.add_sublayer('conv_last0', nn.Conv2D(256, 256, 1, 1, 0))
self.add_sublayer('bn_end0', nn.BatchNorm2D(256))
self.add_sublayer('l0', nn.Conv2D(256, num_landmarks+1, 1, 1, 0))
if fname_pretrained is not None:
self.load_pretrained_weights(fname_pretrained)
def load_pretrained_weights(self, fname):
import pickle
import six
with open(fname, 'rb') as f:
checkpoint = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1')
model_weights = self.state_dict()
model_weights.update({k: v for k, v in checkpoint['state_dict'].items()
if k in model_weights})
self.set_state_dict(model_weights)
def forward(self, x):
x, _ = self.conv1(x)
x = F.relu(self.bn1(x), True)
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
x = self.conv3(x)
x = self.conv4(x)
outputs = []
boundary_channels = []
tmp_out = None
ll, boundary_channel = self._sub_layers['m0'](x, tmp_out)
ll = self._sub_layers['top_m_0'](ll)
ll = F.relu(self._sub_layers['bn_end0']
(self._sub_layers['conv_last0'](ll)), True)
# Predict heatmaps
tmp_out = self._sub_layers['l0'](ll)
if self.end_relu:
tmp_out = F.relu(tmp_out) # HACK: Added relu
outputs.append(tmp_out)
boundary_channels.append(boundary_channel)
return outputs, boundary_channels
@paddle.no_grad()
def get_heatmap(self, x, b_preprocess=True):
''' outputs 0-1 normalized heatmap '''
x = F.interpolate(x, size=[256, 256], mode='bilinear')
x_01 = x*0.5 + 0.5
outputs, _ = self(x_01)
heatmaps = outputs[-1][:, :-1, :, :]
scale_factor = x.shape[2] // heatmaps.shape[2]
if b_preprocess:
heatmaps = F.interpolate(heatmaps, scale_factor=scale_factor,
mode='bilinear', align_corners=True)
heatmaps = preprocess(heatmaps)
return heatmaps
from paddle.fluid.layers.nn import soft_relu
from .base_model import BaseModel
from paddle import nn
import paddle
import paddle.nn.functional as F
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from ..modules.init import kaiming_normal_, constant_
from ppgan.utils.visual import make_grid, tensor2img
import numpy as np
def translate_using_reference(nets, w_hpf, x_src, x_ref, y_ref):
N, C, H, W = x_src.shape
wb = paddle.to_tensor(np.ones((1, C, H, W))).astype('float32')
x_src_with_wb = paddle.concat([wb, x_src], axis=0)
masks = nets['fan'].get_heatmap(x_src) if w_hpf > 0 else None
s_ref = nets['style_encoder'](x_ref, y_ref)
s_ref_list = paddle.unsqueeze(s_ref, axis=[1])
s_ref_lists = []
for _ in range(N):
s_ref_lists.append(s_ref_list)
s_ref_list = paddle.stack(s_ref_lists, axis=1)
s_ref_list = paddle.reshape(s_ref_list, (s_ref_list.shape[0], s_ref_list.shape[1], s_ref_list.shape[3]))
x_concat = [x_src_with_wb]
for i, s_ref in enumerate(s_ref_list):
x_fake = nets['generator'](x_src, s_ref, masks=masks)
x_fake_with_ref = paddle.concat([x_ref[i:i+1], x_fake], axis=0)
x_concat += [x_fake_with_ref]
x_concat = paddle.concat(x_concat, axis=0)
img = tensor2img(make_grid(x_concat, nrow=N+1, range=(0, 1)))
del x_concat
return img
def compute_d_loss(nets, lambda_reg, x_real, y_org, y_trg, z_trg=None, x_ref=None, masks=None):
assert (z_trg is None) != (x_ref is None)
# with real images
x_real.stop_gradient = False
out = nets['discriminator'](x_real, y_org)
loss_real = adv_loss(out, 1)
loss_reg = r1_reg(out, x_real)
# with fake images
with paddle.no_grad():
if z_trg is not None:
s_trg = nets['mapping_network'](z_trg, y_trg)
else: # x_ref is not None
s_trg = nets['style_encoder'](x_ref, y_trg)
x_fake = nets['generator'](x_real, s_trg, masks=masks)
out = nets['discriminator'](x_fake, y_trg)
loss_fake = adv_loss(out, 0)
loss = loss_real + loss_fake + lambda_reg * loss_reg
return loss, {'real': loss_real.numpy(),
'fake': loss_fake.numpy(),
'reg': loss_reg.numpy()}
def adv_loss(logits, target):
assert target in [1, 0]
targets = paddle.full_like(logits, fill_value=target)
loss = F.binary_cross_entropy_with_logits(logits, targets)
return loss
def r1_reg(d_out, x_in):
# zero-centered gradient penalty for real images
batch_size = x_in.shape[0]
grad_dout = paddle.grad(
outputs=d_out.sum(), inputs=x_in,
create_graph=True, retain_graph=True, only_inputs=True
)[0]
grad_dout2 = grad_dout.pow(2)
assert(grad_dout2.shape == x_in.shape)
reg = 0.5 * paddle.reshape(grad_dout2, (batch_size, -1)).sum(1).mean(0)
return reg
def soft_update(source, target, beta=1.0):
assert 0.0 <= beta <= 1.0
target_model_map = dict(target.named_parameters())
for param_name, source_param in source.named_parameters():
target_param = target_model_map[param_name]
target_param.set_value(beta * source_param + (1.0 - beta) * target_param)
def dump_model(model):
params = {}
for k in model.state_dict().keys():
if k.endswith('.scale'):
params[k] = model.state_dict()[k].shape
return params
def compute_g_loss(nets, w_hpf, lambda_sty, lambda_ds, lambda_cyc, x_real, y_org, y_trg, z_trgs=None, x_refs=None, masks=None):
assert (z_trgs is None) != (x_refs is None)
if z_trgs is not None:
z_trg, z_trg2 = z_trgs
if x_refs is not None:
x_ref, x_ref2 = x_refs
# adversarial loss
if z_trgs is not None:
s_trg = nets['mapping_network'](z_trg, y_trg)
else:
s_trg = nets['style_encoder'](x_ref, y_trg)
x_fake = nets['generator'](x_real, s_trg, masks=masks)
out = nets['discriminator'](x_fake, y_trg)
loss_adv = adv_loss(out, 1)
# style reconstruction loss
s_pred = nets['style_encoder'](x_fake, y_trg)
loss_sty = paddle.mean(paddle.abs(s_pred - s_trg))
# diversity sensitive loss
if z_trgs is not None:
s_trg2 = nets['mapping_network'](z_trg2, y_trg)
else:
s_trg2 = nets['style_encoder'](x_ref2, y_trg)
x_fake2 = nets['generator'](x_real, s_trg2, masks=masks)
loss_ds = paddle.mean(paddle.abs(x_fake - x_fake2))
# cycle-consistency loss
masks = nets['fan'].get_heatmap(x_fake) if w_hpf > 0 else None
s_org = nets['style_encoder'](x_real, y_org)
x_rec = nets['generator'](x_fake, s_org, masks=masks)
loss_cyc = paddle.mean(paddle.abs(x_rec - x_real))
loss = loss_adv + lambda_sty * loss_sty \
- lambda_ds * loss_ds + lambda_cyc * loss_cyc
return loss, {'adv': loss_adv.numpy(),
'sty': loss_sty.numpy(),
'ds:': loss_ds.numpy(),
'cyc': loss_cyc.numpy()}
def he_init(module):
if isinstance(module, nn.Conv2D):
kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
if module.bias is not None:
constant_(module.bias, 0)
if isinstance(module, nn.Linear):
kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
if module.bias is not None:
constant_(module.bias, 0)
@MODELS.register()
class StarGANv2Model(BaseModel):
def __init__(
self,
generator,
style=None,
mapping=None,
discriminator=None,
fan=None,
latent_dim=16,
lambda_reg=1,
lambda_sty=1,
lambda_ds=1,
lambda_cyc=1,
):
super(StarGANv2Model, self).__init__()
self.w_hpf = generator['w_hpf']
self.nets_ema = {}
self.nets['generator'] = build_generator(generator)
self.nets_ema['generator'] = build_generator(generator)
self.nets['style_encoder'] = build_generator(style)
self.nets_ema['style_encoder'] = build_generator(style)
self.nets['mapping_network'] = build_generator(mapping)
self.nets_ema['mapping_network'] = build_generator(mapping)
if discriminator:
self.nets['discriminator'] = build_discriminator(discriminator)
if self.w_hpf > 0:
fan_model = build_generator(fan)
fan_model.eval()
self.nets['fan'] = fan_model
self.nets_ema['fan'] = fan_model
self.latent_dim = latent_dim
self.lambda_reg = lambda_reg
self.lambda_sty = lambda_sty
self.lambda_ds = lambda_ds
self.lambda_cyc = lambda_cyc
self.nets['generator'].apply(he_init)
self.nets['style_encoder'].apply(he_init)
self.nets['mapping_network'].apply(he_init)
self.nets['discriminator'].apply(he_init)
# remember the initial value of ds weight
self.initial_lambda_ds = self.lambda_ds
def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Args:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap images in domain A and domain B.
"""
pass
self.input = input
self.input['z_trg'] = paddle.randn((input['src'].shape[0], self.latent_dim))
self.input['z_trg2'] = paddle.randn((input['src'].shape[0], self.latent_dim))
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
def _reset_grad(self, optims):
for optim in optims.values():
optim.clear_gradients()
def train_iter(self, optimizers=None):
#TODO
x_real, y_org = self.input['src'], self.input['src_cls']
x_ref, x_ref2, y_trg = self.input['ref'], self.input['ref2'], self.input['ref_cls']
z_trg, z_trg2 = self.input['z_trg'], self.input['z_trg2']
masks = self.nets['fan'].get_heatmap(x_real) if self.w_hpf > 0 else None
# train the discriminator
d_loss, d_losses_latent = compute_d_loss(
self.nets, self.lambda_reg, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
self._reset_grad(optimizers)
d_loss.backward()
optimizers['discriminator'].minimize(d_loss)
d_loss, d_losses_ref = compute_d_loss(
self.nets, self.lambda_reg, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
self._reset_grad(optimizers)
d_loss.backward()
optimizers['discriminator'].step()
# train the generator
g_loss, g_losses_latent = compute_g_loss(
self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
self._reset_grad(optimizers)
g_loss.backward()
optimizers['generator'].step()
optimizers['mapping_network'].step()
optimizers['style_encoder'].step()
g_loss, g_losses_ref = compute_g_loss(
self.nets, self.w_hpf, self.lambda_sty, self.lambda_ds, self.lambda_cyc, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
self._reset_grad(optimizers)
g_loss.backward()
optimizers['generator'].step()
# compute moving average of network parameters
soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
# decay weight for diversity sensitive loss
if self.lambda_ds > 0:
self.lambda_ds -= (self.initial_lambda_ds / self.total_iter)
for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
for key, value in loss.items():
self.losses[prefix + key] = value
self.losses['G/lambda_ds'] = self.lambda_ds
self.losses['Total iter'] = int(self.total_iter)
def test_iter(self, metrics=None):
#TODO
self.nets_ema['generator'].eval()
self.nets_ema['style_encoder'].eval()
soft_update(self.nets['generator'], self.nets_ema['generator'], beta=0.999)
soft_update(self.nets['mapping_network'], self.nets_ema['mapping_network'], beta=0.999)
soft_update(self.nets['style_encoder'], self.nets_ema['style_encoder'], beta=0.999)
src_img = self.input['src']
ref_img = self.input['ref']
ref_label = self.input['ref_cls']
with paddle.no_grad():
img = translate_using_reference(self.nets_ema, self.w_hpf,
paddle.to_tensor(src_img).astype('float32'),
paddle.to_tensor(ref_img).astype('float32'),
paddle.to_tensor(ref_label).astype('float32'))
self.visual_items['reference'] = img
self.nets_ema['generator'].train()
self.nets_ema['style_encoder'].train()
"""
StarGAN v2
Copyright (c) 2020-present NAVER Corp.
"""
from collections import namedtuple
from copy import deepcopy
from functools import partial
from munch import Munch
import numpy as np
import cv2
from skimage.filters import gaussian
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppgan.models.generators.builder import GENERATORS
class HourGlass(nn.Layer):
def __init__(self, num_modules, depth, num_features, first_one=False):
super(HourGlass, self).__init__()
self.num_modules = num_modules
self.depth = depth
self.features = num_features
self.coordconv = CoordConvTh(64, 64, True, True, 256, first_one,
out_channels=256,
kernel_size=1, stride=1, padding=0)
self._generate_network(self.depth)
def _generate_network(self, level):
self.add_sublayer('b1_' + str(level), ConvBlock(256, 256))
self.add_sublayer('b2_' + str(level), ConvBlock(256, 256))
if level > 1:
self._generate_network(level - 1)
else:
self.add_sublayer('b2_plus_' + str(level), ConvBlock(256, 256))
self.add_sublayer('b3_' + str(level), ConvBlock(256, 256))
def _forward(self, level, inp):
up1 = inp
up1 = self._sub_layers['b1_' + str(level)](up1)
low1 = F.avg_pool2d(inp, 2, stride=2)
low1 = self._sub_layers['b2_' + str(level)](low1)
if level > 1:
low2 = self._forward(level - 1, low1)
else:
low2 = low1
low2 = self._sub_layers['b2_plus_' + str(level)](low2)
low3 = low2
low3 = self._sub_layers['b3_' + str(level)](low3)
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
return up1 + up2
def forward(self, x, heatmap):
x, last_channel = self.coordconv(x, heatmap)
return self._forward(self.depth, x), last_channel
class AddCoordsTh(nn.Layer):
def __init__(self, height=64, width=64, with_r=False, with_boundary=False):
super(AddCoordsTh, self).__init__()
self.with_r = with_r
self.with_boundary = with_boundary
with paddle.no_grad():
x_coords = paddle.arange(height).unsqueeze(1).expand((height, width)).astype('float32')
y_coords = paddle.arange(width).unsqueeze(0).expand((height, width)).astype('float32')
x_coords = (x_coords / (height - 1)) * 2 - 1
y_coords = (y_coords / (width - 1)) * 2 - 1
coords = paddle.stack([x_coords, y_coords], axis=0) # (2, height, width)
if self.with_r:
rr = paddle.sqrt(paddle.pow(x_coords, 2) + paddle.pow(y_coords, 2)) # (height, width)
rr = (rr / paddle.max(rr)).unsqueeze(0)
coords = paddle.concat([coords, rr], axis=0)
self.coords = coords.unsqueeze(0) # (1, 2 or 3, height, width)
self.x_coords = x_coords
self.y_coords = y_coords
def forward(self, x, heatmap=None):
"""
x: (batch, c, x_dim, y_dim)
"""
coords = self.coords.tile((x.shape[0], 1, 1, 1))
if self.with_boundary and heatmap is not None:
boundary_channel = paddle.clip(heatmap[:, -1:, :, :], 0.0, 1.0)
zero_tensor = paddle.zeros_like(self.x_coords)
xx_boundary_channel = paddle.where(boundary_channel > 0.05, self.x_coords, zero_tensor)
yy_boundary_channel = paddle.where(boundary_channel > 0.05, self.y_coords, zero_tensor)
coords = paddle.concat([coords, xx_boundary_channel, yy_boundary_channel], axis=1)
x_and_coords = paddle.concat([x, coords], axis=1)
return x_and_coords
class CoordConvTh(nn.Layer):
"""CoordConv layer as in the paper."""
def __init__(self, height, width, with_r, with_boundary,
in_channels, first_one=False, *args, **kwargs):
super(CoordConvTh, self).__init__()
self.addcoords = AddCoordsTh(height, width, with_r, with_boundary)
in_channels += 2
if with_r:
in_channels += 1
if with_boundary and not first_one:
in_channels += 2
self.conv = nn.Conv2D(in_channels=in_channels, *args, **kwargs)
def forward(self, input_tensor, heatmap=None):
ret = self.addcoords(input_tensor, heatmap)
last_channel = ret[:, -2:, :, :]
ret = self.conv(ret)
return ret, last_channel
class ConvBlock(nn.Layer):
def __init__(self, in_planes, out_planes):
super(ConvBlock, self).__init__()
self.bn1 = nn.BatchNorm2D(in_planes)
conv3x3 = partial(nn.Conv2D, kernel_size=3, stride=1, padding=1, bias_attr=False, dilation=1)
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
self.bn2 = nn.BatchNorm2D(int(out_planes / 2))
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
self.bn3 = nn.BatchNorm2D(int(out_planes / 4))
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
self.downsample = None
if in_planes != out_planes:
self.downsample = nn.Sequential(nn.BatchNorm2D(in_planes),
nn.ReLU(True),
nn.Conv2D(in_planes, out_planes, 1, 1, bias_attr=False))
def forward(self, x):
residual = x
out1 = self.bn1(x)
out1 = F.relu(out1, True)
out1 = self.conv1(out1)
out2 = self.bn2(out1)
out2 = F.relu(out2, True)
out2 = self.conv2(out2)
out3 = self.bn3(out2)
out3 = F.relu(out3, True)
out3 = self.conv3(out3)
out3 = paddle.concat((out1, out2, out3), 1)
if self.downsample is not None:
residual = self.downsample(residual)
out3 += residual
return out3
# ========================== #
# Mask related functions #
# ========================== #
def normalize(x, eps=1e-6):
"""Apply min-max normalization."""
# x = x.contiguous()
N, C, H, W = x.shape
x_ = paddle.reshape(x, (N*C, -1))
max_val = paddle.max(x_, axis=1, keepdim=True)[0]
min_val = paddle.min(x_, axis=1, keepdim=True)[0]
x_ = (x_ - min_val) / (max_val - min_val + eps)
out = paddle.reshape(x_, (N, C, H, W))
return out
def truncate(x, thres=0.1):
"""Remove small values in heatmaps."""
return paddle.where(x < thres, paddle.zeros_like(x), x)
def resize(x, p=2):
"""Resize heatmaps."""
return x**p
def shift(x, N):
"""Shift N pixels up or down."""
x = x.numpy()
up = N >= 0
N = abs(N)
_, _, H, W = x.shape
head = np.arange(N)
tail = np.arange(H-N)
if up:
head = np.arange(H-N)+N
tail = np.arange(N)
else:
head = np.arange(N) + (H-N)
tail = np.arange(H-N)
# permutation indices
perm = np.concatenate([head, tail])
out = x[:, :, perm, :]
out = paddle.to_tensor(out)
return out
IDXPAIR = namedtuple('IDXPAIR', 'start end')
index_map = Munch(chin=IDXPAIR(0 + 8, 33 - 8),
eyebrows=IDXPAIR(33, 51),
eyebrowsedges=IDXPAIR(33, 46),
nose=IDXPAIR(51, 55),
nostrils=IDXPAIR(55, 60),
eyes=IDXPAIR(60, 76),
lipedges=IDXPAIR(76, 82),
lipupper=IDXPAIR(77, 82),
liplower=IDXPAIR(83, 88),
lipinner=IDXPAIR(88, 96))
OPPAIR = namedtuple('OPPAIR', 'shift resize')
def preprocess(x):
"""Preprocess 98-dimensional heatmaps."""
N, C, H, W = x.shape
x = truncate(x)
x = normalize(x)
sw = H // 256
operations = Munch(chin=OPPAIR(0, 3),
eyebrows=OPPAIR(-7*sw, 2),
nostrils=OPPAIR(8*sw, 4),
lipupper=OPPAIR(-8*sw, 4),
liplower=OPPAIR(8*sw, 4),
lipinner=OPPAIR(-2*sw, 3))
for part, ops in operations.items():
start, end = index_map[part]
x[:, start:end] = resize(shift(x[:, start:end], ops.shift), ops.resize)
zero_out = paddle.concat([paddle.arange(0, index_map.chin.start),
paddle.arange(index_map.chin.end, 33),
paddle.to_tensor([index_map.eyebrowsedges.start,
index_map.eyebrowsedges.end,
index_map.lipedges.start,
index_map.lipedges.end])])
x = x.numpy()
zero_out = zero_out.numpy()
x[:, zero_out] = 0
x = paddle.to_tensor(x)
start, end = index_map.nose
x[:, start+1:end] = shift(x[:, start+1:end], 4*sw)
x[:, start:end] = resize(x[:, start:end], 1)
start, end = index_map.eyes
x[:, start:end] = resize(x[:, start:end], 1)
x[:, start:end] = resize(shift(x[:, start:end], -8), 3) + \
shift(x[:, start:end], -24)
# Second-level mask
x2 = deepcopy(x)
x2[:, index_map.chin.start:index_map.chin.end] = 0 # start:end was 0:33
x2[:, index_map.lipedges.start:index_map.lipinner.end] = 0 # start:end was 76:96
x2[:, index_map.eyebrows.start:index_map.eyebrows.end] = 0 # start:end was 33:51
x = paddle.sum(x, axis=1, keepdim=True) # (N, 1, H, W)
x2 = paddle.sum(x2, axis=1, keepdim=True) # mask without faceline and mouth
x = x.numpy()
x2 = x2.numpy()
x[x != x] = 0 # set nan to zero
x2[x != x] = 0 # set nan to zero
x = paddle.to_tensor(x)
x2 = paddle.to_tensor(x2)
return x.clip(0, 1), x2.clip(0, 1)
......@@ -7,3 +7,4 @@ imageio-ffmpeg
librosa==0.7.0
numba==0.48
easydict
munch
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册