Unverified commit 793c23e5 authored by Connor Holmes, committed by GitHub

Explicitly check for OPT activation function (#3278)

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 145c3a75
@@ -72,11 +72,7 @@ class HFOPTLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
 
     def __init__(self, client_module, inference=True, use_load_prefix=True):
-        super().__init__(inference,
-                         linear_layer=True,
-                         mlp_act_func_type=ActivationFuncType.ReLU,
-                         pre_attn_norm=True,
-                         use_load_prefix=use_load_prefix)
+        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
         self.client_module = client_module
         try:
             import transformers
@@ -84,6 +80,18 @@ class HFOPTLayerPolicy(TransformerPolicy):
         except:
             HFOPTLayerPolicy._orig_layer_class = None
 
+        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
+                                                                      "activation_function"):
+            if TransformerPolicy.hf_model_config.activation_function == "relu":
+                self.mlp_act_func_type = ActivationFuncType.ReLU
+            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
+                self.mlp_act_func_type = ActivationFuncType.GELU
+            else:
+                raise ValueError("Unsupported activation function: {}".format(
+                    TransformerPolicy.hf_model_config.activation_function))
+        else:
+            self.mlp_act_func_type = ActivationFuncType.ReLU  # default
+
     def get_hidden_heads(self):
         return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads, \
......
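Before this change the OPT policy hard-coded ReLU in the `super().__init__` call; the diff instead reads `activation_function` from the Hugging Face model config and raises on values it cannot map. For quick reference, here is a minimal standalone sketch (not part of the commit) of that dispatch logic: `ActivationFuncType` below is a local stand-in for DeepSpeed's enum of the same name, and `resolve_opt_activation` is a hypothetical helper, whereas the real code sets `self.mlp_act_func_type` directly from `TransformerPolicy.hf_model_config`.

```python
from enum import IntEnum


class ActivationFuncType(IntEnum):
    # Local stand-in for deepspeed.utils.types.ActivationFuncType
    # (member values here are illustrative only).
    UNKNOWN = 0
    GELU = 1
    ReLU = 2


def resolve_opt_activation(activation_function: str) -> ActivationFuncType:
    """Map an HF config's `activation_function` string to an activation enum,
    mirroring the branch added in this commit (hypothetical helper)."""
    if activation_function == "relu":
        return ActivationFuncType.ReLU
    if activation_function in ["gelu", "gelu_new"]:
        return ActivationFuncType.GELU
    raise ValueError("Unsupported activation function: {}".format(activation_function))


# OPT checkpoints typically ship with activation_function="relu".
assert resolve_opt_activation("relu") is ActivationFuncType.ReLU
assert resolve_opt_activation("gelu_new") is ActivationFuncType.GELU
```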