未验证 提交 ac9e026d 编写于 作者: L lijianshe02 提交者: GitHub

fix some errors in test (#179) (#180)

* fix some errors in test
上级 228373ca
......@@ -12,7 +12,7 @@ Runing the following command to complete the face parsing task. The output resul
```
cd applications
python face_parse.py --input_image ../docs/imgs/face.png
python tools/face_parse.py --input_image ../docs/imgs/face.png
```
**params:**
......
......@@ -20,7 +20,7 @@ python tools/psgan_infer.py \
--model_path /your/model/path \
--source_path docs/imgs/ps_source.png \
--reference_dir docs/imgs/ref \
--evaluate-only True
--evaluate-only
```
**params:**
- config-file: PSGAN network configuration file, yaml format
......
......@@ -10,7 +10,7 @@
运行如下命令,可以完成人脸解析任务,程序运行成功后,会在`output`文件夹生成解析后的图片文件。具体命令如下所示:
```
cd applications
python face_parse.py --input_image ../docs/imgs/face.png
python tools/face_parse.py --input_image ../docs/imgs/face.png
```
**参数:**
......
......@@ -20,7 +20,7 @@ python tools/psgan_infer.py \
--model_path /your/model/path \
--source_path docs/imgs/ps_source.png \
--reference_dir docs/imgs/ref \
--evaluate-only True
--evaluate-only
```
**参数说明:**
- config-file: PSGAN网络到参数配置文件,格式为yaml
......
......@@ -22,12 +22,12 @@ import numpy as np
import paddle
import paddle.vision.transforms as T
from paddle.utils.download import get_weights_path_from_url
import ppgan.faceutils as futils
from ppgan.utils.options import parse_args
from ppgan.utils.config import get_config
from ppgan.utils.setup import setup
from ppgan.utils.filesystem import load
from ppgan.engine.trainer import Trainer
from ppgan.models.builder import build_model
from ppgan.utils.preprocess import *
from .base_predictor import BasePredictor
......@@ -120,7 +120,7 @@ class PostProcess:
class Inference:
def __init__(self, config, model_path=''):
self.model = build_model(config)
self.model = build_model(config.model)
self.preprocess = PreProcess(config)
self.model_path = model_path
......@@ -154,6 +154,7 @@ class Inference:
'P_B': reference_input[2],
'consis_mask': consis_mask
}
state_dicts = load(self.model_path)
for net_name, net in self.model.nets.items():
net.set_state_dict(state_dicts[net_name])
......@@ -175,8 +176,7 @@ class PSGANPredictor(BasePredictor):
self.cfg = cfg
self.weight_path = self.args.model_path
if self.weight_path is None:
cur_path = os.path.abspath(os.path.dirname(__file__))
self.weight_path = get_path_from_url(PS_WEIGHT_URL, cur_path)
self.weight_path = get_weights_path_from_url(PS_WEIGHT_URL)
self.output_path = output_path
def run(self):
......
......@@ -141,6 +141,13 @@ class MakeupModel(BaseModel):
self.visual_items['fake_A'] = self.fake_A
self.visual_items['rec_B'] = self.rec_B
def test(self, input):
with paddle.no_grad():
return self.nets['netG'](input['image_A'], input['image_B'],
input['P_A'], input['P_B'],
input['consis_mask'], input['mask_A_aug'],
input['mask_B_aug'])
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册