From 1cd6ae0de5d6c4bc4580a299f7867d468ed273e6 Mon Sep 17 00:00:00 2001 From: wangna11BD <79366697+wangna11BD@users.noreply.github.com> Date: Mon, 17 Jan 2022 10:52:24 +0800 Subject: [PATCH] fix lapstyle input (#560) * fix lapstyle input * fix image path * fix image path * fix image path --- applications/tools/lapstyle.py | 3 ++- ppgan/apps/lapstyle_predictor.py | 16 +++++++--------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/applications/tools/lapstyle.py b/applications/tools/lapstyle.py index 2b1f6c8..df24d2c 100644 --- a/applications/tools/lapstyle.py +++ b/applications/tools/lapstyle.py @@ -24,6 +24,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--content_img_path", type=str, + required=True, help="path to content image") parser.add_argument("--output_path", @@ -46,7 +47,7 @@ if __name__ == "__main__": parser.add_argument("--style_image_path", type=str, - default=None, + required=True, help="path to style image") parser.add_argument("--cpu", diff --git a/ppgan/apps/lapstyle_predictor.py b/ppgan/apps/lapstyle_predictor.py index 36ae6f0..a590ee7 100644 --- a/ppgan/apps/lapstyle_predictor.py +++ b/ppgan/apps/lapstyle_predictor.py @@ -144,10 +144,17 @@ class LapStylePredictor(BasePredictor): self.net_rev_2.eval() def run(self, content_img_path, style_image_path): + if not self.is_image(content_img_path): + raise ValueError( + 'The path of content_img does not exist or is not image') + if not self.is_image(style_image_path): + raise ValueError( + 'The path of style_image does not exist or is not image') content_img, style_img, h, w = img_read(content_img_path, style_image_path) content_img_visual = tensor2img(content_img, min_max=(0., 1.)) content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR) + content_img_visual = cv.resize(content_img_visual, (w, h)) cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual) style_img_visual = tensor2img(style_img, min_max=(0., 1.)) style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR) @@ -159,20 +166,11 @@ class LapStylePredictor(BasePredictor): cF = self.net_enc(pyr_ci[2]) sF = self.net_enc(pyr_si[2]) stylized_small = self.net_dec(cF, sF) - stylized_small_visual = tensor2img(stylized_small, min_max=(0., 1.)) - stylized_small_visual = cv.cvtColor(stylized_small_visual, - cv.COLOR_RGB2BGR) - cv.imwrite(os.path.join(self.output, 'stylized_small.png'), - stylized_small_visual) stylized_up = F.interpolate(stylized_small, scale_factor=2) revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1) stylized_rev_lap = self.net_rev(revnet_input) stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small]) - stylized_rev_visual = tensor2img(stylized_rev, min_max=(0., 1.)) - stylized_rev_visual = cv.cvtColor(stylized_rev_visual, cv.COLOR_RGB2BGR) - cv.imwrite(os.path.join(self.output, 'stylized_rev_first.png'), - stylized_rev_visual) stylized_up = F.interpolate(stylized_rev, scale_factor=2) revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1) -- GitLab