提交 797c0284 编写于 作者: W wuzewu

fix version issue

上级 986b7509
...@@ -265,7 +265,7 @@ class TransformerModule(NLPBaseModule): ...@@ -265,7 +265,7 @@ class TransformerModule(NLPBaseModule):
**kwargs) **kwargs)
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
if version_compare(paddle.__version__, '1.8.0'): if version_compare(paddle.__version__, '1.8'):
with tmp_dir() as _dir: with tmp_dir() as _dir:
input_dict, output_dict, program = self.context( input_dict, output_dict, program = self.context(
max_seq_len=max_seq_len) max_seq_len=max_seq_len)
...@@ -479,7 +479,7 @@ class TransformerModule(NLPBaseModule): ...@@ -479,7 +479,7 @@ class TransformerModule(NLPBaseModule):
return self.params_layer return self.params_layer
def forward(self, input_ids, position_ids, segment_ids, input_mask): def forward(self, input_ids, position_ids, segment_ids, input_mask):
if version_compare(paddle.__version__, '1.8.0'): if version_compare(paddle.__version__, '1.8'):
pooled_output, sequence_output = self.model_runner( pooled_output, sequence_output = self.model_runner(
input_ids, position_ids, segment_ids, input_mask) input_ids, position_ids, segment_ids, input_mask)
return { return {
...@@ -488,5 +488,5 @@ class TransformerModule(NLPBaseModule): ...@@ -488,5 +488,5 @@ class TransformerModule(NLPBaseModule):
} }
else: else:
raise RuntimeError( raise RuntimeError(
'{} only support dynamic graph mode in paddle >= 1.8.0'.format( '{} only support dynamic graph mode in paddle >= 1.8'.format(
self.name)) self.name))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册