Commit c26b7c73 authored by Logan Adams

Whitespace and PR feedback

Parent 61e6d069
@@ -384,7 +384,7 @@ class TestMPSize(DistributedTest):
 @pytest.mark.seq_inference
-@pytest.mark.parametrize("model_w_task", [("EleutherAI/gpt-j-6B", "text-generation")], ids=["gpt-j"])
+@pytest.mark.parametrize("model_w_task", [("gpt2", "text-generation")], ids=["gpt2"])
 class TestLowCpuMemUsage(DistributedTest):
     world_size = 1
@@ -399,25 +399,16 @@ class TestLowCpuMemUsage(DistributedTest):
         model, task = model_w_task
         local_rank = int(os.getenv("LOCAL_RANK", "0"))
         tokenizer = AutoTokenizer.from_pretrained(model)
-        model = AutoModelForCausalLM.from_pretrained(model, low_cpu_mem_usage=True)
-        # We have to load these large models on CPU with pipeline because not
-        # enough GPU memory
-        pipe = pipeline(task, model=model, tokenizer=tokenizer, device=-1, framework="pt")
+        pipe = pipeline(task, model=model, model_kwargs={"low_cpu_mem_usage":True}, device=local_rank, framework="pt")
         bs_output = pipe(query, **inf_kwargs)
         pipe.model = deepspeed.init_inference(pipe.model,
                                               mp_size=self.world_size,
                                               dtype=dtype,
                                               replace_method="auto",
                                               replace_with_kernel_inject=True)
-        # Switch device to GPU so that input tensors are not on CPU
-        pipe.device = torch.device(f"cuda:{local_rank}")
         ds_output = pipe(query, **inf_kwargs)
         print(local_rank, "baseline", bs_output)
         print(local_rank, "deepspeed", ds_output)
         assert assert_fn(bs_output, ds_output)
......
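A minimal standalone sketch of the loading pattern this change adopts, assuming the transformers, torch, and deepspeed packages; the prompt, generation arguments, mp_size, and dtype below are illustrative assumptions and not part of the original test. The pipeline now loads the gpt2 checkpoint itself with low_cpu_mem_usage=True directly on the local GPU, and the model is then wrapped by deepspeed.init_inference, replacing the old pattern of loading on CPU (device=-1) and switching pipe.device afterwards.

# Illustrative sketch, not the original test code.
import os

import deepspeed
import torch
from transformers import pipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))

# Load directly through the pipeline on the local GPU, passing
# low_cpu_mem_usage=True via model_kwargs instead of loading on CPU first.
pipe = pipeline("text-generation",
                model="gpt2",
                model_kwargs={"low_cpu_mem_usage": True},
                device=local_rank,
                framework="pt")

baseline_output = pipe("DeepSpeed is", do_sample=False, max_length=20)

# Swap in DeepSpeed inference kernels for the underlying model.
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=1,
                                      dtype=torch.float16,
                                      replace_method="auto",
                                      replace_with_kernel_inject=True)

ds_output = pipe("DeepSpeed is", do_sample=False, max_length=20)
print(local_rank, "baseline", baseline_output)
print(local_rank, "deepspeed", ds_output)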