提交 797c0284 编写于 作者: W wuzewu

fix version issue

上级 986b7509
......@@ -265,7 +265,7 @@ class TransformerModule(NLPBaseModule):
**kwargs)
self.max_seq_len = max_seq_len
if version_compare(paddle.__version__, '1.8.0'):
if version_compare(paddle.__version__, '1.8'):
with tmp_dir() as _dir:
input_dict, output_dict, program = self.context(
max_seq_len=max_seq_len)
......@@ -479,7 +479,7 @@ class TransformerModule(NLPBaseModule):
return self.params_layer
def forward(self, input_ids, position_ids, segment_ids, input_mask):
if version_compare(paddle.__version__, '1.8.0'):
if version_compare(paddle.__version__, '1.8'):
pooled_output, sequence_output = self.model_runner(
input_ids, position_ids, segment_ids, input_mask)
return {
......@@ -488,5 +488,5 @@ class TransformerModule(NLPBaseModule):
}
else:
raise RuntimeError(
'{} only support dynamic graph mode in paddle >= 1.8.0'.format(
'{} only support dynamic graph mode in paddle >= 1.8'.format(
self.name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册