未验证 提交 520c7625 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update ernie.py

上级 26fd1533
...@@ -67,34 +67,34 @@ class Model(backbone): ...@@ -67,34 +67,34 @@ class Model(backbone):
@property @property
def inputs_attr(self): def inputs_attr(self):
returns = {"token_ids": [[-1, -1], 'int64'], ret = {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64'] "task_ids": [[-1, -1], 'int64']
} }
if self._learning_strategy == 'pairwise' and self._phase=='train': if self._learning_strategy == 'pairwise' and self._phase=='train':
returns.update({"token_ids_neg": [[-1, -1], 'int64'], ret.update({"token_ids_neg": [[-1, -1], 'int64'],
"position_ids_neg": [[-1, -1], 'int64'], "position_ids_neg": [[-1, -1], 'int64'],
"segment_ids_neg": [[-1, -1], 'int64'], "segment_ids_neg": [[-1, -1], 'int64'],
"input_mask_neg": [[-1, -1, 1], 'float32'], "input_mask_neg": [[-1, -1, 1], 'float32'],
"task_ids_neg": [[-1, -1], 'int64'] "task_ids_neg": [[-1, -1], 'int64']
}) })
return returns return ret
@property @property
def outputs_attr(self): def outputs_attr(self):
returns = {"word_embedding": [[-1, -1, self._emb_size], 'float32'], ret = {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'], "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'], "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'], "sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']} "sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
if self._learning_strategy == 'pairwise' and self._phase == 'train': if self._learning_strategy == 'pairwise' and self._phase == 'train':
returns.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'], ret.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'],
"encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'], "encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding_neg": [[-1, self._emb_size], 'float32'], "sentence_embedding_neg": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']}) "sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']})
return returns return ret
def build(self, inputs, scope_name=""): def build(self, inputs, scope_name=""):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册