Unverified commit 34a9fbf1, authored by Reza Yazdani, committed by GitHub

Fix gpt-j inference issue (#3639)

* fix gpt-j inference issue for mlp_gemm_func call

* bring back the gpt-j inference-test

* fix formatting

* fix the neox and pythia injection issue
Parent 7e59ef12
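Why the rewrite in the first hunk matters: in the removed loop, `dst = src` only rebinds the local loop variable and never touches the attribute on the module, so when a container legitimately carries `None` for a norm tensor (the gpt-j case named in the commit title), the module kept its stale placeholder instead of being cleared. A minimal self-contained sketch of that pitfall, using a hypothetical `Mlp` stand-in rather than the real DeepSpeed classes:

```python
import torch

class Mlp:
    """Hypothetical stand-in for the fused inference module."""
    def __init__(self):
        self.attn_nw = torch.ones(4)  # placeholder created at module init

mlp = Mlp()
src = None  # the container has no post-attention LayerNorm weight here

# Old pattern (removed lines): rebinding the local is a no-op on the module.
dst = mlp.attn_nw
if src is None:
    dst = src                    # only the local name `dst` changes
assert mlp.attn_nw is not None   # the stale placeholder survives

# New pattern (added lines): setattr actually clears the attribute.
setattr(mlp, 'attn_nw', None)
assert mlp.attn_nw is None
```

The fix keys the dict by attribute name and uses `setattr`, which works the same way whether the source tensor is present or `None`.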
@@ -249,17 +249,21 @@ class BaseTransformerContainer(ABC):
                                                allocate_tensor=reversed_dim)
 
     def copy_data_to_new_module(self):
-        params = {
-            self.module.mlp.attn_nw: self.attn_nw,
-            self.module.mlp.attn_nb: self.attn_nb,
-            self.module.norm_w: self.input_nw,
-            self.module.norm_b: self.input_nb
-        }
-        for dst, src in params.items():
-            if src is None:
-                dst = src
+        params = {'attn_nw': self.attn_nw, 'attn_nb': self.attn_nb}
+        for key in params:
+            if params[key] is None:
+                setattr(self.module.mlp, key, None)
             else:
-                dst.data.copy_(src.to(get_accelerator().current_device_name()))
+                setattr(self.module.mlp, key,
+                        torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name())))
+
+        params = {'norm_w': self.input_nw, 'norm_b': self.input_nb}
+        for key in params:
+            if params[key] is None:
+                setattr(self.module, key, None)
+            else:
+                setattr(self.module, key,
+                        torch.nn.parameter.Parameter(params[key].to(get_accelerator().current_device_name())))
 
     def transpose(self):
         self.transpose_attention()
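Beyond the `None` handling, the added lines also change how present tensors are bound: instead of `data.copy_` into a preallocated tensor, each source tensor is moved to the accelerator and attached as a fresh `torch.nn.parameter.Parameter`. A hedged sketch of that placement pattern; `DummyLayer` and the `weights` dict are hypothetical, and the snippet assumes a working DeepSpeed install for `get_accelerator()`:

```python
import torch
from deepspeed.accelerator import get_accelerator

class DummyLayer:
    """Hypothetical stand-in for the injected module."""
    norm_w = None
    norm_b = None

layer = DummyLayer()
weights = {'norm_w': torch.randn(16), 'norm_b': None}  # one weight absent

for key, value in weights.items():
    if value is None:
        setattr(layer, key, None)  # propagate the absence explicitly
    else:
        # Move to the accelerator device, then bind a fresh Parameter.
        setattr(layer, key,
                torch.nn.parameter.Parameter(value.to(get_accelerator().current_device_name())))

assert isinstance(layer.norm_w, torch.nn.Parameter)
assert layer.norm_b is None
```

One consequence of this design is that the module no longer needs a correctly shaped placeholder to copy into; the new Parameter simply replaces whatever was there.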
@@ -49,7 +49,7 @@ _gpt_models = [
     "gpt2",
     "distilgpt2",
     "Norod78/hebrew-bad_wiki-gpt_neo-tiny",
-    #"EleutherAI/gpt-j-6B", # Removed as this is causing OOM errors randomly
+    "EleutherAI/gpt-j-6B",  # bring back this model as we did not catch an error before by merging some changes! TODO: we need to fix the OOM issue later!
     "bigscience/bloom-560m",
 ]
 _opt_models = [