diff --git a/docs/en_US/tutorials/face_parse.md b/docs/en_US/tutorials/face_parse.md index 3bf4acbd5d489e5a2cdce1426a4377cb44a1939d..8f21b8138cd974b32efdba8dad849f4247990675 100644 --- a/docs/en_US/tutorials/face_parse.md +++ b/docs/en_US/tutorials/face_parse.md @@ -12,7 +12,7 @@ Runing the following command to complete the face parsing task. The output resul ``` cd applications -python face_parse.py --input_image ../docs/imgs/face.png +python tools/face_parse.py --input_image ../docs/imgs/face.png ``` **params:** diff --git a/docs/en_US/tutorials/psgan.md b/docs/en_US/tutorials/psgan.md index 74af4a6bb413b4aaf03c63c0ddaf4577142d8cff..368a212b4dac1a61e2b8a6dad1c7166c431640f6 100644 --- a/docs/en_US/tutorials/psgan.md +++ b/docs/en_US/tutorials/psgan.md @@ -20,7 +20,7 @@ python tools/psgan_infer.py \ --model_path /your/model/path \ --source_path docs/imgs/ps_source.png \ --reference_dir docs/imgs/ref \ - --evaluate-only True + --evaluate-only ``` **params:** - config-file: PSGAN network configuration file, yaml format diff --git a/docs/zh_CN/tutorials/face_parse.md b/docs/zh_CN/tutorials/face_parse.md index 24c1a622f644f065cdaf12cd1ab6e6544064d44b..931a76d6548a929b381cf74daf93917137fb83ce 100644 --- a/docs/zh_CN/tutorials/face_parse.md +++ b/docs/zh_CN/tutorials/face_parse.md @@ -10,7 +10,7 @@ 运行如下命令,可以完成人脸解析任务,程序运行成功后,会在`output`文件夹生成解析后的图片文件。具体命令如下所示: ``` cd applications -python face_parse.py --input_image ../docs/imgs/face.png +python tools/face_parse.py --input_image ../docs/imgs/face.png ``` **参数:** diff --git a/docs/zh_CN/tutorials/psgan.md b/docs/zh_CN/tutorials/psgan.md index 5d55f146dc46ab9f6df9bb4cda3a4d722efe3aed..fe172e897feffa3a478c5bd3495555e9ff2fe8d3 100644 --- a/docs/zh_CN/tutorials/psgan.md +++ b/docs/zh_CN/tutorials/psgan.md @@ -20,7 +20,7 @@ python tools/psgan_infer.py \ --model_path /your/model/path \ --source_path docs/imgs/ps_source.png \ --reference_dir docs/imgs/ref \ - --evaluate-only True + --evaluate-only ``` **参数说明:** - config-file: PSGAN网络到参数配置文件,格式为yaml diff --git a/ppgan/apps/psgan_predictor.py b/ppgan/apps/psgan_predictor.py index 488a7a8b61b206dc8c9a6df9ee1bf74ddc3a7729..b39da8cb42f1a13eb8f37a03c235a626588f16a6 100644 --- a/ppgan/apps/psgan_predictor.py +++ b/ppgan/apps/psgan_predictor.py @@ -22,12 +22,12 @@ import numpy as np import paddle import paddle.vision.transforms as T +from paddle.utils.download import get_weights_path_from_url import ppgan.faceutils as futils from ppgan.utils.options import parse_args from ppgan.utils.config import get_config from ppgan.utils.setup import setup from ppgan.utils.filesystem import load -from ppgan.engine.trainer import Trainer from ppgan.models.builder import build_model from ppgan.utils.preprocess import * from .base_predictor import BasePredictor @@ -120,7 +120,7 @@ class PostProcess: class Inference: def __init__(self, config, model_path=''): - self.model = build_model(config) + self.model = build_model(config.model) self.preprocess = PreProcess(config) self.model_path = model_path @@ -154,6 +154,7 @@ class Inference: 'P_B': reference_input[2], 'consis_mask': consis_mask } + state_dicts = load(self.model_path) for net_name, net in self.model.nets.items(): net.set_state_dict(state_dicts[net_name]) @@ -175,8 +176,7 @@ class PSGANPredictor(BasePredictor): self.cfg = cfg self.weight_path = self.args.model_path if self.weight_path is None: - cur_path = os.path.abspath(os.path.dirname(__file__)) - self.weight_path = get_path_from_url(PS_WEIGHT_URL, cur_path) + self.weight_path = get_weights_path_from_url(PS_WEIGHT_URL) self.output_path = output_path def run(self): diff --git a/ppgan/models/makeup_model.py b/ppgan/models/makeup_model.py index 947191b053690115968aa768b8a5815076088f81..fefdf54df1db97eb765887850e8fa68188518d32 100644 --- a/ppgan/models/makeup_model.py +++ b/ppgan/models/makeup_model.py @@ -141,6 +141,13 @@ class MakeupModel(BaseModel): self.visual_items['fake_A'] = self.fake_A self.visual_items['rec_B'] = self.rec_B + def test(self, input): + with paddle.no_grad(): + return self.nets['netG'](input['image_A'], input['image_B'], + input['P_A'], input['P_B'], + input['consis_mask'], input['mask_A_aug'], + input['mask_B_aug']) + def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator