Unverified · Commit 130bd7f9 authored by LielinJiang, committed by GitHub

Fix 2.0-beta bugs (#183)

* fix 2.0-beta bugs

* update pretrained path

* add extract_weight.py
Parent 60066eb2
......@@ -10,6 +10,7 @@ model:
gan_criterion:
name: GANLoss
gan_mode: lsgan
# use your trained path
pretrain_ckpt: output_dir/AnimeGANV2PreTrainModel-2020-11-29-17-02/epoch_2_checkpoint.pdparams
g_adv_weight: 300.
d_adv_weight: 300.
......@@ -47,13 +48,12 @@ dataset:
test:
name: SingleDataset
dataroot: data/animedataset/test/HR_photo
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
serial_batches: False
pool_size: 50
transforms:
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
- name: ResizeToScale
size: [256, 256]
scale: 32
......@@ -62,6 +62,7 @@ dataset:
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
lr_scheduler:
name: LinearDecay
......
......@@ -21,44 +21,33 @@ model:
dataset:
train:
name: SingleDataset
dataroot: data/mnist/train
name: CommonVisionDataset
dataset_name: MNIST
num_workers: 0
batch_size: 128
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
return_label: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
mean: [127.5]
std: [127.5]
keys: [image]
test:
name: SingleDataset
dataroot: data/mnist/test
preprocess:
- name: LoadImageFromFile
key: A
- name: Transforms
input_keys: [A]
pipeline:
name: CommonVisionDataset
dataset_name: MNIST
num_workers: 0
batch_size: 128
return_label: False
transforms:
- name: Resize
size: [64, 64]
interpolation: 'bicubic' #cv2.INTER_CUBIC
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [127.5, 127.5, 127.5]
std: [127.5, 127.5, 127.5]
keys: [image, image]
mean: [127.5]
std: [127.5]
keys: [image]
lr_scheduler:
name: LinearDecay
......
......@@ -92,6 +92,21 @@ train model
python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
```
### Inference
After training finishes, use ``tools/extract_weight.py`` to extract the weights of the network you want to run:
```
python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output YOUR_WEIGHT_PATH.pdparams
```
Then use ``applications/tools/styleganv2.py`` to generate results:
```
python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
```
Note: ``--size`` must match the ``size`` setting in your config file.
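As a quick sanity check (not part of this commit; the path is the placeholder from above), the extracted file is an ordinary ``pdparams`` state dict and loads directly:
```
python -c "import paddle; sd = paddle.load('YOUR_WEIGHT_PATH.pdparams'); print(len(sd), 'tensors')"
```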
## Results
Random Samples:
......
......@@ -54,9 +54,56 @@ python -u tools/styleganv2.py \
- n_col: number of columns of sampled images
- cpu: whether to run inference on CPU; remove this flag from the command to use the GPU
### Training (TODO)
### Training
#### Prepare the dataset
You can download the corresponding dataset from [here](https://drive.google.com/drive/folders/1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP).
For convenience, we also provide [images256x256.tar](https://paddlegan.bj.bcebos.com/datasets/images256x256.tar).
The current config file assumes the dataset is structured as follows:
```
PaddleGAN
├── data
    ├── ffhq
        ├── images1024x1024
            ├── 00000.png
            ├── 00001.png
            ├── 00002.png
            ├── 00003.png
            ├── 00004.png
        ├── images256x256
            ├── 00000.png
            ├── 00001.png
            ├── 00002.png
            ├── 00003.png
            ├── 00004.png
        ├── custom_data
            ├── img0.png
            ├── img1.png
            ├── img2.png
            ├── img3.png
            ├── img4.png
            ...
```
Start training
```
python tools/main.py -c configs/stylegan_v2_256_ffhq.yaml
```
### Inference
After training, use ``tools/extract_weight.py`` to extract the corresponding weights, which ``applications/tools/styleganv2.py`` then uses for inference.
```
python tools/extract_weight.py output_dir/YOUR_TRAINED_WEIGHT.pdparams --net-name gen_ema --output stylegan_config_f.pdparams
```
```
python tools/styleganv2.py --output_path stylegan01 --weight_path YOUR_WEIGHT_PATH.pdparams --size 256
```
Training scripts will also be added later so that users can train more kinds of StyleGAN V2 image generators.
Note: ``--size`` must match the setting in your config file.
## Results
......
......@@ -20,7 +20,7 @@ from .base_dataset import BaseDataset
from .image_folder import ImageFolder
from .builder import DATASETS
from .transforms.builder import build_transforms
from .preprocess.builder import build_transforms
@DATASETS.register()
......
......@@ -17,7 +17,7 @@ import paddle
from .builder import DATASETS
from .base_dataset import BaseDataset
from .transforms.builder import build_transforms
from .preprocess.builder import build_transforms
@DATASETS.register()
......
......@@ -62,3 +62,15 @@ def build_preprocess(cfg):
preproccess = Compose(preproccess)
return preproccess
def build_transforms(cfg):
transforms = []
for trans_cfg in cfg:
temp_trans_cfg = copy.deepcopy(trans_cfg)
name = temp_trans_cfg.pop('name')
transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
transforms = Compose(transforms)
return transforms
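As an illustration only (values borrowed from the DCGAN config above, not from this diff), a config list flows through the new ``build_transforms`` like this:
```
cfg = [
    {'name': 'Resize', 'size': [64, 64], 'interpolation': 'bicubic', 'keys': ['image']},
    {'name': 'Normalize', 'mean': [127.5], 'std': [127.5], 'keys': ['image']},
]
pipeline = build_transforms(cfg)  # a Compose over the instantiated transforms
# pipeline(data) applies Resize, then Normalize, in order
```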
......@@ -264,3 +264,74 @@ class SRNoise(T.BaseTransform):
image = image + normed_noise
image = np.clip(image, 0., 1.)
return image
@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
"""Initialize Add Transform
Parameters:
value (List[int]) -- the [r, g, b] value to add to the image pixel-wise.
"""
super().__init__(keys=keys)
self.value = value
def _get_params(self, inputs):
params = {}
params['value'] = self.value
return params
def _apply_image(self, image):
return np.clip(image + self.params['value'], 0, 255).astype('uint8')
# return custom_F.add(image, self.params['value'])
@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
def __init__(self,
size: int,
scale: int,
interpolation='bilinear',
keys=None):
"""Initialize ResizeToScale Transform
Parameters:
size (int | List[int]) -- the minimum target size
scale (int) -- the stride; output height and width are multiples of it
interpolation (Optional[str]) -- interpolation method
"""
super().__init__(keys=keys)
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
self.scale = scale
self.interpolation = interpolation
def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
hw = image.shape[:2]
params = {}
params['target_size'] = self.reduce_to_scale(hw, self.size[::-1],
                                             self.scale)
return params
@staticmethod
def reduce_to_scale(img_hw, min_hw, scale):
im_h, im_w = img_hw
if im_h <= min_hw[0]:
im_h = min_hw[0]
else:
x = im_h % scale
im_h = im_h - x
if im_w < min_hw[1]:
im_w = min_hw[1]
else:
y = im_w % scale
im_w = im_w - y
return (im_h, im_w)
def _apply_image(self, image):
return F.resize(image, self.params['target_size'], self.interpolation)
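A worked example with assumed numbers makes ``reduce_to_scale`` concrete: each dimension above the minimum is rounded down to a multiple of ``scale``.
```
# 300x500 input, minimum 256x256, stride 32:
#   height: 300 > 256, 300 % 32 == 12, so 300 - 12 == 288
#   width:  500 > 256, 500 % 32 == 20, so 500 - 20 == 480
ResizeToScale.reduce_to_scale((300, 500), (256, 256), 32)  # -> (288, 480)
```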
from .transforms import ResizeToScale, PairedRandomCrop, PairedRandomHorizontalFlip, Add
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
from . import functional_cv2 as F_cv2
from paddle.vision.transforms.functional import _is_numpy_image, _is_pil_image
__all__ = ['add']
def add(pic, value):
if not (_is_pil_image(pic) or _is_numpy_image(pic)):
raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(
type(pic)))
if _is_pil_image(pic):
raise NotImplementedError('add not support pil image')
else:
return F_cv2.add(pic, value)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import division
import numpy as np
def add(image, value):
return np.clip(image + value, 0, 255).astype('uint8')
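A tiny illustration of the clipping behavior (values assumed):
```
import numpy as np

img = np.full((2, 2, 3), 250, dtype='uint8')
add(img, [10, 10, 10])   # 260 clips to 255 in every channel
add(img, [-255, 0, 0])   # -5 clips to 0 in the red channel
```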
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numbers
import collections
import numpy as np
import paddle.vision.transforms as T
import paddle.vision.transforms.functional as F
from . import functional as custom_F
from .builder import TRANSFORMS
if sys.version_info < (3, 3):
Sequence = collections.Sequence
Iterable = collections.Iterable
else:
Sequence = collections.abc.Sequence
Iterable = collections.abc.Iterable
TRANSFORMS.register(T.Resize)
TRANSFORMS.register(T.RandomCrop)
TRANSFORMS.register(T.RandomHorizontalFlip)
TRANSFORMS.register(T.Normalize)
TRANSFORMS.register(T.Transpose)
TRANSFORMS.register(T.Grayscale)
@TRANSFORMS.register()
class PairedRandomCrop(T.RandomCrop):
def __init__(self, size, keys=None):
super().__init__(size, keys=keys)
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
params = {}
params['crop_params'] = self._get_param(image, self.size)
return params
def _apply_image(self, img):
i, j, h, w = self.params['crop_params']
return F.crop(img, i, j, h, w)
@TRANSFORMS.register()
class PairedRandomHorizontalFlip(T.RandomHorizontalFlip):
def __init__(self, prob=0.5, keys=None):
super().__init__(prob, keys=keys)
def _get_params(self, inputs):
params = {}
params['flip'] = random.random() < self.prob
return params
def _apply_image(self, image):
if self.params['flip']:
return F.hflip(image)
return image
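The paired transforms draw their random parameters once in ``_get_params`` and reuse them for every key, which keeps paired inputs aligned. A hypothetical call (variable names assumed):
```
flip = PairedRandomHorizontalFlip(prob=0.5, keys=['image', 'image'])
# img_a and img_b are flipped together or not at all, because the
# single random draw in _get_params is shared across both keys.
img_a2, img_b2 = flip((img_a, img_b))
```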
@TRANSFORMS.register()
class Add(T.BaseTransform):
def __init__(self, value, keys=None):
"""Initialize Add Transform
Parameters:
value (List[int]) -- the [r, g, b] value to add to the image pixel-wise.
"""
super().__init__(keys=keys)
self.value = value
def _get_params(self, inputs):
params = {}
params['value'] = self.value
return params
def _apply_image(self, image):
return custom_F.add(image, self.params['value'])
@TRANSFORMS.register()
class ResizeToScale(T.BaseTransform):
def __init__(self,
size: int,
scale: int,
interpolation='bilinear',
keys=None):
"""Initialize ResizeToScale Transform
Parameters:
size (int | List[int]) -- the minimum target size
scale (int) -- the stride; output height and width are multiples of it
interpolation (Optional[str]) -- interpolation method
"""
super().__init__(keys=keys)
if isinstance(size, int):
self.size = (size, size)
else:
self.size = size
self.scale = scale
self.interpolation = interpolation
def _get_params(self, inputs):
image = inputs[self.keys.index('image')]
hw = image.shape[:2]
params = {}
params['target_size'] = self.reduce_to_scale(hw, self.size[::-1],
                                             self.scale)
return params
@staticmethod
def reduce_to_scale(img_hw, min_hw, scale):
im_h, im_w = img_hw
if im_h <= min_hw[0]:
im_h = min_hw[0]
else:
x = im_h % scale
im_h = im_h - x
if im_w < min_hw[1]:
im_w = min_hw[1]
else:
y = im_w % scale
im_w = im_w - y
return (im_h, im_w)
def _apply_image(self, image):
return F.resize(image, self.params['target_size'], self.interpolation)
......@@ -165,6 +165,8 @@ class Trainer:
iter_loader = IterLoader(self.train_dataloader)
# set model.is_train = True
self.model.setup_train_mode(is_train=True)
while self.current_iter < (self.total_iters + 1):
self.current_epoch = iter_loader.epoch
self.inner_iter = self.current_iter % self.iters_per_epoch
......@@ -219,6 +221,9 @@ class Trainer:
for metric in self.metrics.values():
metric.reset()
# set model.is_train = False
self.model.setup_train_mode(is_train=False)
for i in range(self.max_eval_steps):
data = next(iter_loader)
self.model.setup_input(data)
......@@ -289,7 +294,9 @@ class Trainer:
message += 'ips: %.5f images/s ' % self.ips
if hasattr(self, 'step_time'):
eta = self.step_time * (self.total_iters - self.current_iter - 1)
eta = self.step_time * (self.total_iters - self.current_iter)
eta = eta if eta > 0 else 0
eta_str = str(datetime.timedelta(seconds=int(eta)))
message += f'eta: {eta_str}'
......
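With illustrative numbers (not from the diff), the corrected ETA drops the old off-by-one and clamps at zero:
```
import datetime

step_time, total_iters, current_iter = 0.5, 100, 90
eta = step_time * (total_iters - current_iter)  # 5.0 s remaining
eta = eta if eta > 0 else 0                     # never negative on the last iter
print(str(datetime.timedelta(seconds=int(eta))))  # 0:00:05
```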
......@@ -83,7 +83,7 @@ class AnimeGANV2Model(BaseModel):
self.smooth_gray = paddle.to_tensor(input['smooth_gray'])
else:
self.real = paddle.to_tensor(input['A'])
self.image_paths = input['A_paths']
self.image_paths = input['A_path']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
......
......@@ -56,7 +56,8 @@ class DCGANModel(BaseModel):
input (dict): include the data itself and its metadata information.
"""
# get 1-channel gray image, or 3-channel color image
self.real = paddle.to_tensor(input['A'])
self.real = paddle.to_tensor(input['img'])
if 'img_path' in input:
self.image_paths = input['A_path']
def forward(self):
......
......@@ -74,10 +74,8 @@ class Pix2PixModel(BaseModel):
AtoB = self.direction == 'AtoB'
self.real_A = paddle.to_tensor(
input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(
input['B' if AtoB else 'A'])
self.real_A = paddle.to_tensor(input['A' if AtoB else 'B'])
self.real_B = paddle.to_tensor(input['B' if AtoB else 'A'])
self.image_paths = input['A_path' if AtoB else 'B_path']
......@@ -141,3 +139,7 @@ class Pix2PixModel(BaseModel):
optimizers['optimG'].clear_grad()
self.backward_G()
optimizers['optimG'].step()
def test_iter(self, metrics=None):
with paddle.no_grad():
self.forward()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -12,47 +12,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import traceback

from ...utils.registry import Registry

TRANSFORMS = Registry("TRANSFORMS")


class Compose(object):
    """
    Composes several transforms together into a single dataset transform.

    Args:
        transforms (list): List of transforms to compose.

    Returns:
        A callable Compose object; calling it applies each of the given
        :attr:`transforms` in sequence.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for f in self.transforms:
            try:
                data = f(data)
            except Exception as e:
                print(f)
                stack_info = traceback.format_exc()
                print("fail to perform transform [{}] with error: "
                      "{} and stack:\n{}".format(f, e, str(stack_info)))
                raise e
        return data


def build_transforms(cfg):
    transforms = []
    for trans_cfg in cfg:
        temp_trans_cfg = copy.deepcopy(trans_cfg)
        name = temp_trans_cfg.pop('name')
        transforms.append(TRANSFORMS.get(name)(**temp_trans_cfg))
    transforms = Compose(transforms)
    return transforms
......
import argparse

import paddle


def parse_args():
    parser = argparse.ArgumentParser(
        description='This script extracts weights from a checkpoint')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--net-name',
                        type=str,
                        help='net name in checkpoint dict')
    parser.add_argument('--output', type=str, help='destination file name')
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    assert args.output.endswith(".pdparams")
    ckpt = paddle.load(args.checkpoint)
    state_dict = ckpt[args.net_name]
    paddle.save(state_dict, args.output)


if __name__ == '__main__':
    main()
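For context (assumed layout, not shown in this diff): a training checkpoint is a dict keyed by network name, which is why ``--net-name`` is required; ``gen_ema`` from the README example above is one such key.
```
import paddle

ckpt = paddle.load('YOUR_TRAINED_WEIGHT.pdparams')  # placeholder path
print(list(ckpt.keys()))  # expected to include e.g. 'gen_ema'
```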