From 6d5b4eb4c4e37bfbd12889cc522c82eab432ace0 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Fri, 26 Feb 2021 13:57:04 +0800 Subject: [PATCH] fix some errors in test (#179) * fix some errors in test --- docs/en_US/tutorials/face_parse.md | 2 +- docs/en_US/tutorials/psgan.md | 2 +- docs/zh_CN/tutorials/face_parse.md | 2 +- docs/zh_CN/tutorials/psgan.md | 2 +- ppgan/apps/psgan_predictor.py | 8 ++++---- ppgan/models/makeup_model.py | 7 +++++++ 6 files changed, 15 insertions(+), 8 deletions(-) diff --git a/docs/en_US/tutorials/face_parse.md b/docs/en_US/tutorials/face_parse.md index 3bf4acb..8f21b81 100644 --- a/docs/en_US/tutorials/face_parse.md +++ b/docs/en_US/tutorials/face_parse.md @@ -12,7 +12,7 @@ Runing the following command to complete the face parsing task. The output resul ``` cd applications -python face_parse.py --input_image ../docs/imgs/face.png +python tools/face_parse.py --input_image ../docs/imgs/face.png ``` **params:** diff --git a/docs/en_US/tutorials/psgan.md b/docs/en_US/tutorials/psgan.md index 74af4a6..368a212 100644 --- a/docs/en_US/tutorials/psgan.md +++ b/docs/en_US/tutorials/psgan.md @@ -20,7 +20,7 @@ python tools/psgan_infer.py \ --model_path /your/model/path \ --source_path docs/imgs/ps_source.png \ --reference_dir docs/imgs/ref \ - --evaluate-only True + --evaluate-only ``` **params:** - config-file: PSGAN network configuration file, yaml format diff --git a/docs/zh_CN/tutorials/face_parse.md b/docs/zh_CN/tutorials/face_parse.md index 24c1a62..931a76d 100644 --- a/docs/zh_CN/tutorials/face_parse.md +++ b/docs/zh_CN/tutorials/face_parse.md @@ -10,7 +10,7 @@ 运行如下命令,可以完成人脸解析任务,程序运行成功后,会在`output`文件夹生成解析后的图片文件。具体命令如下所示: ``` cd applications -python face_parse.py --input_image ../docs/imgs/face.png +python tools/face_parse.py --input_image ../docs/imgs/face.png ``` **参数:** diff --git a/docs/zh_CN/tutorials/psgan.md b/docs/zh_CN/tutorials/psgan.md index 5d55f14..fe172e8 100644 --- a/docs/zh_CN/tutorials/psgan.md +++ b/docs/zh_CN/tutorials/psgan.md @@ -20,7 +20,7 @@ python tools/psgan_infer.py \ --model_path /your/model/path \ --source_path docs/imgs/ps_source.png \ --reference_dir docs/imgs/ref \ - --evaluate-only True + --evaluate-only ``` **参数说明:** - config-file: PSGAN网络到参数配置文件,格式为yaml diff --git a/ppgan/apps/psgan_predictor.py b/ppgan/apps/psgan_predictor.py index 488a7a8..b39da8c 100644 --- a/ppgan/apps/psgan_predictor.py +++ b/ppgan/apps/psgan_predictor.py @@ -22,12 +22,12 @@ import numpy as np import paddle import paddle.vision.transforms as T +from paddle.utils.download import get_weights_path_from_url import ppgan.faceutils as futils from ppgan.utils.options import parse_args from ppgan.utils.config import get_config from ppgan.utils.setup import setup from ppgan.utils.filesystem import load -from ppgan.engine.trainer import Trainer from ppgan.models.builder import build_model from ppgan.utils.preprocess import * from .base_predictor import BasePredictor @@ -120,7 +120,7 @@ class PostProcess: class Inference: def __init__(self, config, model_path=''): - self.model = build_model(config) + self.model = build_model(config.model) self.preprocess = PreProcess(config) self.model_path = model_path @@ -154,6 +154,7 @@ class Inference: 'P_B': reference_input[2], 'consis_mask': consis_mask } + state_dicts = load(self.model_path) for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) @@ -175,8 +176,7 @@ class PSGANPredictor(BasePredictor): self.cfg = cfg self.weight_path = self.args.model_path if self.weight_path is None: - cur_path = os.path.abspath(os.path.dirname(__file__)) - self.weight_path = get_path_from_url(PS_WEIGHT_URL, cur_path) + self.weight_path = get_weights_path_from_url(PS_WEIGHT_URL) self.output_path = output_path def run(self): diff --git a/ppgan/models/makeup_model.py b/ppgan/models/makeup_model.py index 947191b..fefdf54 100644 --- a/ppgan/models/makeup_model.py +++ b/ppgan/models/makeup_model.py @@ -141,6 +141,13 @@ class MakeupModel(BaseModel): self.visual_items['fake_A'] = self.fake_A self.visual_items['rec_B'] = self.rec_B + def test(self, input): + with paddle.no_grad(): + return self.nets['netG'](input['image_A'], input['image_B'], + input['P_A'], input['P_B'], + input['consis_mask'], input['mask_A_aug'], + input['mask_B_aug']) + def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator -- GitLab