未验证 提交 bccc7e24 编写于 作者: C chenjian 提交者: GitHub

optimize ernie_vilg module doc (#1968)

* optimize doc

* update

* fix

* add logo

* add logo

* update new style

* fix
Co-authored-by: Nwuzewu <wuzewu@baidu.com>
上级 31a5c90a
...@@ -66,13 +66,15 @@ class ErnieVilG: ...@@ -66,13 +66,15 @@ class ErnieVilG:
text_prompts, text_prompts,
style: Optional[str] = "油画", style: Optional[str] = "油画",
topk: Optional[int] = 10, topk: Optional[int] = 10,
visualization: Optional[bool] = True,
output_dir: Optional[str] = 'ernievilg_output'): output_dir: Optional[str] = 'ernievilg_output'):
""" """
Create image by text prompts using ErnieVilG model. Create image by text prompts using ErnieVilG model.
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画 :param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画、探索无限。
:param topk: Top k images to save. :param topk: Top k images to save.
:param visualization: Whether to save images or not.
:output_dir: Output directory :output_dir: Output directory
""" """
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
...@@ -186,7 +188,8 @@ class ErnieVilG: ...@@ -186,7 +188,8 @@ class ErnieVilG:
for text, data in results.items(): for text, data in results.items():
for idx, imgdata in enumerate(data['imgUrls']): for idx, imgdata in enumerate(data['imgUrls']):
image = Image.open(BytesIO(requests.get(imgdata['image']).content)) image = Image.open(BytesIO(requests.get(imgdata['image']).content))
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx))) if visualization:
image.save(os.path.join(output_dir, '{}_{}.png'.format(text, idx)))
result_images.append(image) result_images.append(image)
if idx + 1 >= topk: if idx + 1 >= topk:
break break
...@@ -212,6 +215,7 @@ class ErnieVilG: ...@@ -212,6 +215,7 @@ class ErnieVilG:
results = self.generate_image(text_prompts=args.text_prompts, results = self.generate_image(text_prompts=args.text_prompts,
style=args.style, style=args.style,
topk=args.topk, topk=args.topk,
visualization=args.visualization,
output_dir=args.output_dir) output_dir=args.output_dir)
return results return results
...@@ -237,9 +241,10 @@ class ErnieVilG: ...@@ -237,9 +241,10 @@ class ErnieVilG:
self.arg_input_group.add_argument('--style', self.arg_input_group.add_argument('--style',
type=str, type=str,
default='油画', default='油画',
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'], choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画', '探索无限'],
help="绘画风格") help="绘画风格")
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张") self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张")
self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak") self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak")
self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk") self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk")
self.arg_input_group.add_argument('--visualization', type=bool, default=True, help="是否保存生成的图片")
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册