未验证 提交 996d1e3d 编写于 作者: X Xintao 提交者: GitHub

Major revision: Support Pypi (#37)

* reorganize

* update inference

* update inference

* format
上级 77dc85b8
name: PyPI Publish
on: push
jobs:
build-n-publish:
runs-on: ubuntu-latest
if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.8
uses: actions/setup-python@v1
with:
python-version: 3.8
- name: Upgrade pip
run: pip install pip --upgrade
- name: Install PyTorch (cpu)
run: pip install torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install dependencies
run: pip install -r requirements.txt
- name: Build and install
run: rm -rf .eggs && pip install -e .
- name: Build for distribution
# remove bdist_wheel for pip installation with compiling cuda extensions
run: python setup.py sdist
- name: Publish distribution to PyPI
uses: pypa/gh-action-pypi-publish@master
with:
password: ${{ secrets.PYPI_API_TOKEN }}
......@@ -25,5 +25,5 @@ jobs:
- name: Lint
run: |
flake8 .
isort --check-only --diff data/ archs/ models/ train.py inference_gfpgan_full.py
yapf -r -d data/ archs/ models/ train.py inference_gfpgan_full.py
isort --check-only --diff gfpgan/ scripts/ inference_gfpgan.py setup.py
yapf -r -d gfpgan/ scripts/ inference_gfpgan.py setup.py
.vscode
# ignored folders
datasets/*
experiments/*
results/*
tb_logger/*
wandb/*
tmp/*
# ignored files
version.py
# ignored files with suffix
*.html
*.png
*.jpeg
*.jpg
*.gif
*.pth
*.zip
# template
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
......@@ -39,6 +31,8 @@ parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
......@@ -57,12 +51,14 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
......@@ -74,6 +70,7 @@ coverage.xml
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
......@@ -91,11 +88,26 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
......@@ -121,3 +133,8 @@ venv.bak/
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
include assets/*
include inputs/*
include scripts/*.py
include inference_gfpgan.py
include VERSION
include LICENSE
include requirements.txt
include gfpgan/weights/README.md
......@@ -27,6 +27,7 @@ If you want want to use the original model in our paper, please follow the instr
pip install facexlib
pip install -r requirements.txt
python setup.py develop
# remember to set BASICSR_JIT=True before your running commands
```
......@@ -45,6 +46,7 @@ If you want want to use the original model in our paper, please follow the instr
pip install facexlib
pip install -r requirements.txt
python setup.py develop
```
## :zap: Quick Inference
......@@ -58,17 +60,17 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/GFPGANv1.pth
- Option 1: Load extensions just-in-time(JIT)
```bash
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images
BASICSR_JIT=True python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
BASICSR_JIT=True python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
```
- Option 2: Have successfully compiled extensions during installation
```bash
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/whole_imgs --save_root results --arch original --channel 1
# for aligned images
python inference_gfpgan_full.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
python inference_gfpgan.py --model_path experiments/pretrained_models/GFPGANv1.pth --test_path inputs/cropped_faces --save_root results --arch original --channel 1 --aligned
```
# GFPGAN (CVPR 2021)
[![download](https://img.shields.io/github/downloads/TencentARC/GFPGAN/total.svg)](https://github.com/TencentARC/GFPGAN/releases)
[![PyPI](https://img.shields.io/pypi/v/gfpgan)](https://pypi.org/project/gfpgan/)
[![Open issue](https://isitmaintained.com/badge/open/TencentARC/GFPGAN.svg)](https://github.com/TencentARC/GFPGAN/issues)
[![LICENSE](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/TencentARC/GFPGAN/blob/master/LICENSE)
[![python lint](https://github.com/TencentARC/GFPGAN/actions/workflows/pylint.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/pylint.yml)
[![Publish-pip](https://github.com/TencentARC/GFPGAN/actions/workflows/publish-pip.yml/badge.svg)](https://github.com/TencentARC/GFPGAN/blob/master/.github/workflows/publish-pip.yml)
1. [Colab Demo](https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo) for GFPGAN <a href="https://colab.research.google.com/drive/1sVsoBd9AjckIXThgtZhGrHRfFI6UUYOo"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>; (Another [Colab Demo](https://colab.research.google.com/drive/1Oa1WwKB4M4l1GmR7CtswDVgOCOeSLChA?usp=sharing) for the original paper model)
1. We provide a *clean* version of GFPGAN, which can run without CUDA extensions. So that it can run in **Windows** or on **CPU mode**.
......@@ -59,6 +61,7 @@ If you want want to use the original model in our paper, please see [PaperModel.
pip install facexlib
pip install -r requirements.txt
python setup.py develop
```
## :zap: Quick Inference
......@@ -72,7 +75,7 @@ wget https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1
**Inference!**
```bash
python inference_gfpgan_full.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
python inference_gfpgan.py --upscale_factor 2 --test_path inputs/whole_imgs --save_root results
```
## :european_castle: Model Zoo
......@@ -90,10 +93,9 @@ You could improve it according to your own needs.
1. More high quality faces can improve the restoration quality.
2. You may need to perform some pre-processing, such as beauty makeup.
**Procedures**
(You can try a simple version ( `train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
(You can try a simple version ( `options/train_gfpgan_v1_simple.yml`) that does not require face component landmarks.)
1. Dataset preparation: [FFHQ](https://github.com/NVlabs/ffhq-dataset)
......@@ -102,11 +104,11 @@ You could improve it according to your own needs.
1. [Component locations of FFHQ: FFHQ_eye_mouth_landmarks_512.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/FFHQ_eye_mouth_landmarks_512.pth)
1. [A simple ArcFace model: arcface_resnet18.pth](https://github.com/TencentARC/GFPGAN/releases/download/v0.1.0/arcface_resnet18.pth)
1. Modify the configuration file `train_gfpgan_v1.yml` accordingly.
1. Modify the configuration file `options/train_gfpgan_v1.yml` accordingly.
1. Training
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 train.py -opt train_gfpgan_v1.yml --launcher pytorch
> python -m torch.distributed.launch --nproc_per_node=4 --master_port=22021 gfpgan/train.py -opt options/train_gfpgan_v1.yml --launcher pytorch
## :scroll: License and Acknowledgement
......
# flake8: noqa
from .archs import *
from .data import *
from .models import *
from .utils import *
from .version import __gitsha__, __version__
import importlib
from os import path as osp
from basicsr.utils import scandir
from os import path as osp
# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
# scan all the files that end with '_arch.py' under the archs folder
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames]
_arch_modules = [importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames]
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
......
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.stylegan2_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU,
StyleGAN2Generator)
from basicsr.ops.fused_act import FusedLeakyReLU
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class StyleGAN2GeneratorSFT(StyleGAN2Generator):
......
import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from basicsr.archs.arch_util import default_init_weights
from basicsr.utils.registry import ARCH_REGISTRY
from torch import nn
from torch.nn import functional as F
class NormStyleCode(nn.Module):
......
import importlib
from os import path as osp
from basicsr.utils import scandir
from os import path as osp
# automatically scan and import dataset modules for registry
# scan all the files under the data folder with '_dataset' in file names
# scan all the files that end with '_dataset.py' under the data folder
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
# import all the dataset modules
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
_dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames]
......@@ -4,14 +4,13 @@ import numpy as np
import os.path as osp
import torch
import torch.utils.data as data
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
from basicsr.data import degradations as degradations
from basicsr.data.data_util import paths_from_folder
from basicsr.data.transforms import augment
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize)
@DATASET_REGISTRY.register()
......
import importlib
from os import path as osp
from basicsr.utils import scandir
from os import path as osp
# automatically scan and import model modules for registry
# scan all the files under the 'models' folder and collect files ending with
# '_model.py'
# scan all the files that end with '_model.py' under the model folder
model_folder = osp.dirname(osp.abspath(__file__))
model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
# import all the model modules
_model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
_model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames]
import math
import os.path as osp
import torch
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.losses import r1_penalty
......@@ -13,6 +8,10 @@ from basicsr.metrics import calculate_metric
from basicsr.models.base_model import BaseModel
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
from torchvision.ops import roi_align
from tqdm import tqdm
@MODEL_REGISTRY.register()
......
# flake8: noqa
import os.path as osp
import archs # noqa: F401
import data # noqa: F401
import models # noqa: F401
from basicsr.train import train_pipeline
import gfpgan.archs
import gfpgan.data
import gfpgan.models
if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir))
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
train_pipeline(root_path)
import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torch.hub import download_url_to_file, get_dir
from torchvision.transforms.functional import normalize
from urllib.parse import urlparse
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
class GFPGANer():
def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None):
self.upscale = upscale
self.bg_upsampler = bg_upsampler
# initialize model
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# initialize the GFP-GAN
if arch == 'clean':
self.gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=False,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
else:
self.gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=channel_multiplier,
decoder_load_path=None,
fix_decoder=True,
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
# initialize face helper
self.face_helper = FaceRestoreHelper(
upscale,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=self.device)
if model_path.startswith('https://'):
model_path = load_file_from_url(url=model_path, model_dir='gfpgan/weights', progress=True, file_name=None)
loadnet = torch.load(model_path)
if 'params_ema' in loadnet:
keyname = 'params_ema'
else:
keyname = 'params'
self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
self.gfpgan.eval()
self.gfpgan = self.gfpgan.to(self.device)
@torch.no_grad()
def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True):
self.face_helper.clean_all()
if has_aligned:
img = cv2.resize(img, (512, 512))
self.face_helper.cropped_faces = [img]
else:
self.face_helper.read_image(img)
# get face landmarks for each face
self.face_helper.get_face_landmarks_5(only_center_face=only_center_face)
# align and warp each face
self.face_helper.align_warp_face()
# face restoration
for cropped_face in self.face_helper.cropped_faces:
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
try:
output = self.gfpgan(cropped_face_t, return_rgb=False)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
self.face_helper.add_restored_face(restored_face)
if not has_aligned and paste_back:
if self.bg_upsampler is not None:
# Now only support RealESRGAN
bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
else:
bg_img = None
self.face_helper.get_inverse_affine(None)
# paste each restored face to the input image
restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
else:
return self.face_helper.cropped_faces, self.face_helper.restored_faces, None
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
"""Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
"""
if model_dir is None:
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
if file_name is not None:
filename = file_name
cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
if not os.path.exists(cached_file):
print(f'Downloading: "{url}" to {cached_file}\n')
download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
return cached_file
# Weights
Put the downloaded weights to this folder.
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from basicsr.utils import imwrite
from gfpgan import GFPGANer
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--upscale', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--bg_upsampler', type=str, default='realesrgan')
parser.add_argument('--bg_tile', type=int, default=0)
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
args = parser.parse_args()
if args.test_path.endswith('/'):
args.test_path = args.test_path[:-1]
os.makedirs(args.save_root, exist_ok=True)
# background upsampler
if args.bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU
import warnings
warnings.warn('The unoptimized RealESRGAN is very slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.')
bg_upsampler = None
else:
from realesrgan import RealESRGANer
bg_upsampler = RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
tile=args.bg_tile,
tile_pad=10,
pre_pad=0,
half=True) # need to set False in CPU mode
else:
bg_upsampler = None
# set up GFPGAN restorer
restorer = GFPGANer(
model_path=args.model_path,
upscale=args.upscale,
arch=args.arch,
channel_multiplier=args.channel,
bg_upsampler=bg_upsampler)
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
for img_path in img_list:
# read image
img_name = os.path.basename(img_path)
print(f'Processing {img_name} ...')
basename, ext = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
cropped_faces, restored_faces, restored_img = restorer.enhance(
input_img, has_aligned=args.aligned, only_center_face=args.only_center_face, paste_back=args.paste_back)
# save faces
for idx, (cropped_face, restored_face) in enumerate(zip(cropped_faces, restored_faces)):
# save cropped face
save_crop_path = os.path.join(args.save_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
imwrite(restored_face, save_crop_path)
# save restored face
if args.suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{args.suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(args.save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(args.save_root, 'cmp', f'{basename}_{idx:02d}.png'))
# save restored img
if args.suffix is not None:
save_restore_path = os.path.join(args.save_root, 'restored_imgs', f'{basename}_{args.suffix}{ext}')
else:
save_restore_path = os.path.join(args.save_root, 'restored_imgs', img_name)
imwrite(restored_img, save_restore_path)
print(f'Results are in the [{args.save_root}] folder.')
if __name__ == '__main__':
main()
import argparse
import cv2
import glob
import numpy as np
import os
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize
from archs.gfpganv1_arch import GFPGANv1
from archs.gfpganv1_clean_arch import GFPGANv1Clean
from basicsr.utils import img2tensor, imwrite, tensor2img
def restoration(gfpgan,
face_helper,
img_path,
save_root,
has_aligned=False,
only_center_face=True,
suffix=None,
paste_back=False,
device='cuda'):
# read image
img_name = os.path.basename(img_path)
print(f'Processing {img_name} ...')
basename, _ = os.path.splitext(img_name)
input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
face_helper.clean_all()
if has_aligned:
input_img = cv2.resize(input_img, (512, 512))
face_helper.cropped_faces = [input_img]
else:
face_helper.read_image(input_img)
# get face landmarks for each face
face_helper.get_face_landmarks_5(only_center_face=only_center_face)
# align and warp each face
save_crop_path = os.path.join(save_root, 'cropped_faces', img_name)
face_helper.align_warp_face(save_crop_path)
# face restoration
for idx, cropped_face in enumerate(face_helper.cropped_faces):
# prepare data
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
try:
with torch.no_grad():
output = gfpgan(cropped_face_t, return_rgb=False)[0]
# convert to image
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
except RuntimeError as error:
print(f'\tFailed inference for GFPGAN: {error}.')
restored_face = cropped_face
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
if suffix is not None:
save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
else:
save_face_name = f'{basename}_{idx:02d}.png'
save_restore_path = os.path.join(save_root, 'restored_faces', save_face_name)
imwrite(restored_face, save_restore_path)
# save cmp image
cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
imwrite(cmp_img, os.path.join(save_root, 'cmp', f'{basename}_{idx:02d}.png'))
if not has_aligned and paste_back:
face_helper.get_inverse_affine(None)
save_restore_path = os.path.join(save_root, 'restored_imgs', img_name)
# paste each restored face to the input image
face_helper.paste_faces_to_input_image(save_restore_path)
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--upscale_factor', type=int, default=2)
parser.add_argument('--arch', type=str, default='clean')
parser.add_argument('--channel', type=int, default=2)
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/GFPGANCleanv1-NoCE-C2.pth')
parser.add_argument('--test_path', type=str, default='inputs/whole_imgs')
parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces')
parser.add_argument('--only_center_face', action='store_true')
parser.add_argument('--aligned', action='store_true')
parser.add_argument('--paste_back', action='store_false')
parser.add_argument('--save_root', type=str, default='results')
args = parser.parse_args()
if args.test_path.endswith('/'):
args.test_path = args.test_path[:-1]
os.makedirs(args.save_root, exist_ok=True)
# initialize the GFP-GAN
if args.arch == 'clean':
gfpgan = GFPGANv1Clean(
out_size=512,
num_style_feat=512,
channel_multiplier=args.channel,
decoder_load_path=None,
fix_decoder=False,
# for stylegan decoder
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
else:
gfpgan = GFPGANv1(
out_size=512,
num_style_feat=512,
channel_multiplier=args.channel,
decoder_load_path=None,
fix_decoder=True,
# for stylegan decoder
num_mlp=8,
input_is_latent=True,
different_w=True,
narrow=1,
sft_half=True)
gfpgan.load_state_dict(torch.load(args.model_path, map_location=lambda storage, loc: storage)['params_ema'])
gfpgan.to(device).eval()
# initialize face helper
face_helper = FaceRestoreHelper(
args.upscale_factor,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=device)
img_list = sorted(glob.glob(os.path.join(args.test_path, '*')))
for img_path in img_list:
restoration(
gfpgan,
face_helper,
img_path,
args.save_root,
has_aligned=args.aligned,
only_center_face=args.only_center_face,
suffix=args.suffix,
paste_back=args.paste_back,
device=device)
print(f'Results are in the [{args.save_root}] folder.')
......@@ -2,9 +2,8 @@ import cv2
import json
import numpy as np
import torch
from collections import OrderedDict
from basicsr.utils import FileClient, imfrombytes
from collections import OrderedDict
print('Load JSON metadata...')
# use the json file in FFHQ dataset
......
......@@ -16,7 +16,7 @@ split_before_expression_after_opening_paren = true
line_length = 120
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = basicsr
known_third_party = cv2,facexlib,numpy,torch,torchvision,tqdm
known_first_party = gfpgan
known_third_party = basicsr,cv2,facexlib,numpy,torch,torchvision,tqdm
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
#!/usr/bin/env python
from setuptools import find_packages, setup
import os
import subprocess
import time
version_file = 'gfpgan/version.py'
def readme():
with open('README.md', encoding='utf-8') as f:
content = f.read()
return content
def get_git_hash():
def _minimal_ext_cmd(cmd):
# construct minimal environment
env = {}
for k in ['SYSTEMROOT', 'PATH', 'HOME']:
v = os.environ.get(k)
if v is not None:
env[k] = v
# LANGUAGE is used on win32
env['LANGUAGE'] = 'C'
env['LANG'] = 'C'
env['LC_ALL'] = 'C'
out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
return out
try:
out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
sha = out.strip().decode('ascii')
except OSError:
sha = 'unknown'
return sha
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
elif os.path.exists(version_file):
try:
from facexlib.version import __version__
sha = __version__.split('+')[-1]
except ImportError:
raise ImportError('Unable to get git version')
else:
sha = 'unknown'
return sha
def write_version_py():
content = """# GENERATED VERSION FILE
# TIME: {}
__version__ = '{}'
__gitsha__ = '{}'
version_info = ({})
"""
sha = get_hash()
with open('VERSION', 'r') as f:
SHORT_VERSION = f.read().strip()
VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
with open(version_file, 'w') as f:
f.write(version_file_str)
def get_version():
with open(version_file, 'r') as f:
exec(compile(f.read(), version_file, 'exec'))
return locals()['__version__']
def get_requirements(filename='requirements.txt'):
here = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(here, filename), 'r') as f:
requires = [line.replace('\n', '') for line in f.readlines()]
return requires
if __name__ == '__main__':
write_version_py()
setup(
name='gfpgan',
version=get_version(),
description='GFPGAN aims at developing Practical Algorithms for Real-world Face Restoration',
long_description=readme(),
long_description_content_type='text/markdown',
author='Xintao Wang',
author_email='xintao.wang@outlook.com',
keywords='computer vision, pytorch, image restoration, super-resolution, face restoration, gan, gfpgan',
url='https://github.com/TencentARC/GFPGAN',
include_package_data=True,
packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
classifiers=[
'Development Status :: 4 - Beta',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
],
license='Apache License Version 2.0',
setup_requires=['cython', 'numpy'],
install_requires=get_requirements(),
zip_safe=False)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册