Unverified · Commit 60fc06c6 authored by Reza Yazdani, committed by GitHub

Synchronize the GPUs for the text-generation inference test (#1805)

Parent: c3c8d5dd
@@ -356,6 +356,7 @@ class InferenceEngine(Module):
             input = input.to(torch.cuda.current_device())
         if not input.is_contiguous():
             input = input.contiguous()
+        dist.broadcast(input, 0)
         for k in kwargs:
             if torch.is_tensor(kwargs[k]):
                 kwargs[k] = kwargs[k].to(torch.cuda.current_device())
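The added `dist.broadcast(input, 0)` call overwrites every rank's copy of `input` with rank 0's tensor before generation, so all model-parallel GPUs decode from identical inputs instead of possibly divergent ones. Below is a minimal single-process sketch of that broadcast-from-source semantics, using plain Python lists to stand in for per-rank tensors; the `broadcast` helper and `per_rank_buffers` name are illustrative, not the `torch.distributed` API itself:

```python
# Sketch of broadcast(src=0): every rank's buffer is overwritten in place
# with the source rank's data, mirroring what torch.distributed.broadcast
# does across processes (here simulated in one process).
def broadcast(per_rank_buffers, src=0):
    source = list(per_rank_buffers[src])  # snapshot the source rank's data
    for rank in range(len(per_rank_buffers)):
        per_rank_buffers[rank][:] = source  # in-place, like a tensor overwrite
    return per_rank_buffers

# Four "ranks" start with divergent inputs; after the broadcast all match rank 0.
buffers = [[1, 2, 3], [9, 9, 9], [0, 0, 0], [5, 5, 5]]
broadcast(buffers, src=0)
assert all(b == [1, 2, 3] for b in buffers)
```

In the real engine the tensor is first moved to the current device and made contiguous, since collective ops such as broadcast require a contiguous buffer on the communication device.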