提交 b77de9f1 编写于 作者: C chenjian

fix

上级 54734432
...@@ -68,6 +68,8 @@ ...@@ -68,6 +68,8 @@
text_prompts:str, text_prompts:str,
style: Optional[str] = "油画", style: Optional[str] = "油画",
topk: Optional[int] = 10, topk: Optional[int] = 10,
ak: Optional[str] = None,
sk: Optional[str] = None,
output_dir: Optional[str] = 'ernievilg_output') output_dir: Optional[str] = 'ernievilg_output')
``` ```
...@@ -77,7 +79,9 @@ ...@@ -77,7 +79,9 @@
- text_prompts(str): 输入的语句,描述想要生成的图像的内容。 - text_prompts(str): 输入的语句,描述想要生成的图像的内容。
- style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。 - style(Optional[str]): 生成图像的风格,当前支持'油画','水彩','粉笔画','卡通','儿童画','蜡笔画'。
- topk(Optional[int]): 保存前多少张图,最多保存10张。 - topk(Optional[int]): 保存前多少张图,最多保存10张。'
- ak:(Optional[str]): 用于申请文心api使用token的ak,可不填。
- sk:(Optional[str]): 用于申请文心api使用token的sk,可不填。
- output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。 - output_dir(Optional[str]): 保存输出图像的目录,默认为"ernievilg_output"。
......
...@@ -31,6 +31,8 @@ class ErnieVilG: ...@@ -31,6 +31,8 @@ class ErnieVilG:
text_prompts, text_prompts,
style: Optional[str] = "油画", style: Optional[str] = "油画",
topk: Optional[int] = 10, topk: Optional[int] = 10,
ak: Optional[str] = None,
sk: Optional[str] = None,
output_dir: Optional[str] = 'ernievilg_output'): output_dir: Optional[str] = 'ernievilg_output'):
""" """
Create image by text prompts using ErnieVilG model. Create image by text prompts using ErnieVilG model.
...@@ -38,12 +40,16 @@ class ErnieVilG: ...@@ -38,12 +40,16 @@ class ErnieVilG:
:param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like. :param text_prompts: Phrase, sentence, or string of words and phrases describing what the image should look like.
:param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画 :param style: Image stype, currently supported 油画、水彩、粉笔画、卡通、儿童画、蜡笔画
:param topk: Top k images to save. :param topk: Top k images to save.
:param ak: ak for applying token to request wenxin api.
:param sk: sk for applying token to request wenxin api.
:output_dir: Output directory :output_dir: Output directory
""" """
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE' if ak == None:
sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs' ak = 'G26BfAOLpGIRBN5XrOV2eyPA25CE01lE'
if sk == None:
sk = 'txLZOWIjEqXYMU3lSm05ViW4p9DWGOWs'
token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token' token_host = 'https://wenxin.baidu.com/younger/portal/api/oauth/token'
response = requests.get(token_host, response = requests.get(token_host,
params={ params={
...@@ -165,6 +171,8 @@ class ErnieVilG: ...@@ -165,6 +171,8 @@ class ErnieVilG:
results = self.generate_image(text_prompts=args.text_prompts, results = self.generate_image(text_prompts=args.text_prompts,
style=args.style, style=args.style,
topk=args.topk, topk=args.topk,
ak=args.ak,
sk=args.sk,
output_dir=args.output_dir) output_dir=args.output_dir)
return results return results
...@@ -179,4 +187,6 @@ class ErnieVilG: ...@@ -179,4 +187,6 @@ class ErnieVilG:
choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'], choices=['油画', '水彩', '粉笔画', '卡通', '儿童画', '蜡笔画'],
help="绘画风格") help="绘画风格")
self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张") self.arg_input_group.add_argument('--topk', type=int, default=10, help="选取保存前多少张图,最多10张")
self.arg_input_group.add_argument('--ak', type=str, default=None, help="申请文心api使用token的ak")
self.arg_input_group.add_argument('--sk', type=str, default=None, help="申请文心api使用token的sk")
self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output') self.arg_input_group.add_argument('--output_dir', type=str, default='ernievilg_output')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册