未验证 提交 2c6e8194 编写于 作者: R Reza Yazdani 提交者: GitHub

Fix Checkpoint-loading with Meta-tensor (#2781)

* Reset KV-cache at the beginning of text-generation

* Pass the ckpt-loading arguments to work with meta-tensor

* remove unrelated changes
上级 c5b983e9
......@@ -37,7 +37,10 @@ class BLOOMLayerPolicy(TransformerPolicy):
inference=True,
use_load_prefix=True,
split_qkv=False):
-super().__init__(inference, linear_layer=True)
+super().__init__(inference,
+linear_layer=True,
+use_load_prefix=use_load_prefix,
+split_qkv=split_qkv)
self.client_module = client_module
try:
import transformers
......
......@@ -32,7 +32,7 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
version = 0
def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
-super().__init__(inference, megatron_v2=megatron_v2)
+super().__init__(inference, megatron_v2=megatron_v2, split_qkv=split_qkv)
self.client_module = client_module
if GPTNEOXLayerPolicy._orig_layer_class is None:
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
......
......@@ -27,7 +27,8 @@ class HFOPTLayerPolicy(TransformerPolicy):
super().__init__(inference,
linear_layer=True,
mlp_act_func_type=ActivationFuncType.ReLU,
-pre_attn_norm=True)
+pre_attn_norm=True,
+use_load_prefix=use_load_prefix)
self.client_module = client_module
try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册