Unverified commit 2076bf23, authored by Jeff Rasley, committed by GitHub

get mask token from tokenizer (#2592)

Parent 35eabb0a
@@ -58,7 +58,9 @@ pipe = pipeline("fill-mask", model=args.model, framework="pt", device=args.local
 if dtype == torch.half:
     pipe.model.half()
-br = pipe("Hello I'm a [MASK] model")
+mask = pipe.tokenizer.mask_token
+br = pipe(f"Hello I'm a {mask} model")
 if args.deepspeed:
     pipe.model = deepspeed.init_inference(pipe.model,
                                           dtype=dtype,
@@ -74,7 +76,7 @@ mtimes = []
 for i in range(args.trials):
     torch.cuda.synchronize()
     start = time.time()
-    r = pipe("Hello I'm a [MASK] model")
+    r = pipe(f"Hello I'm a {mask} model")
     torch.cuda.synchronize()
     end = time.time()
     responses.append(r)
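Note (not part of the commit itself): the reason for reading pipe.tokenizer.mask_token instead of hardcoding "[MASK]" is that the mask token differs across model families, so the hardcoded prompt only matched BERT-style vocabularies. A minimal sketch of the idea, assuming the transformers library and the example checkpoints bert-base-uncased and roberta-base are available locally or downloadable:

from transformers import AutoTokenizer

# The mask token is model-specific: BERT-style tokenizers report "[MASK]",
# RoBERTa-style tokenizers report "<mask>". Building the prompt from
# tokenizer.mask_token keeps the same benchmark string valid for both.
for name in ("bert-base-uncased", "roberta-base"):
    tok = AutoTokenizer.from_pretrained(name)
    print(name, "->", tok.mask_token)
    print(f"Hello I'm a {tok.mask_token} model")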