未验证 提交 1cd6ae0d 编写于 作者: W wangna11BD 提交者: GitHub

fix lapstyle input (#560)

* fix lapstyle input

* fix image path

* fix image path

* fix image path
上级 3e17a8c2
......@@ -24,6 +24,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--content_img_path",
type=str,
required=True,
help="path to content image")
parser.add_argument("--output_path",
......@@ -46,7 +47,7 @@ if __name__ == "__main__":
parser.add_argument("--style_image_path",
type=str,
default=None,
required=True,
help="path to style image")
parser.add_argument("--cpu",
......
......@@ -144,10 +144,17 @@ class LapStylePredictor(BasePredictor):
self.net_rev_2.eval()
def run(self, content_img_path, style_image_path):
if not self.is_image(content_img_path):
raise ValueError(
'The path of content_img does not exist or is not image')
if not self.is_image(style_image_path):
raise ValueError(
'The path of style_image does not exist or is not image')
content_img, style_img, h, w = img_read(content_img_path,
style_image_path)
content_img_visual = tensor2img(content_img, min_max=(0., 1.))
content_img_visual = cv.cvtColor(content_img_visual, cv.COLOR_RGB2BGR)
content_img_visual = cv.resize(content_img_visual, (w, h))
cv.imwrite(os.path.join(self.output, 'content.png'), content_img_visual)
style_img_visual = tensor2img(style_img, min_max=(0., 1.))
style_img_visual = cv.cvtColor(style_img_visual, cv.COLOR_RGB2BGR)
......@@ -159,20 +166,11 @@ class LapStylePredictor(BasePredictor):
cF = self.net_enc(pyr_ci[2])
sF = self.net_enc(pyr_si[2])
stylized_small = self.net_dec(cF, sF)
stylized_small_visual = tensor2img(stylized_small, min_max=(0., 1.))
stylized_small_visual = cv.cvtColor(stylized_small_visual,
cv.COLOR_RGB2BGR)
cv.imwrite(os.path.join(self.output, 'stylized_small.png'),
stylized_small_visual)
stylized_up = F.interpolate(stylized_small, scale_factor=2)
revnet_input = paddle.concat(x=[pyr_ci[1], stylized_up], axis=1)
stylized_rev_lap = self.net_rev(revnet_input)
stylized_rev = fold_laplace_pyramid([stylized_rev_lap, stylized_small])
stylized_rev_visual = tensor2img(stylized_rev, min_max=(0., 1.))
stylized_rev_visual = cv.cvtColor(stylized_rev_visual, cv.COLOR_RGB2BGR)
cv.imwrite(os.path.join(self.output, 'stylized_rev_first.png'),
stylized_rev_visual)
stylized_up = F.interpolate(stylized_rev, scale_factor=2)
revnet_input = paddle.concat(x=[pyr_ci[0], stylized_up], axis=1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册