提交 bcc2dfe0 编写于 作者: K kinghuin 提交者: wuzewu

support roberta (#200)

上级 9dc26bac
......@@ -576,7 +576,7 @@ class Module(object):
fetch_dict[key] = program.global_block().var(var.name)
# update BERT/ERNIE's input tensor's sequence length to max_seq_len
if self.name.startswith("bert") or self.name.startswith("ernie"):
if "bert" in self.name or self.name.startswith("ernie"):
MAX_SEQ_LENGTH = 512
if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册