未验证 提交 99e483b7 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update bert.py

上级 520c7625
...@@ -57,32 +57,32 @@ class Model(backbone): ...@@ -57,32 +57,32 @@ class Model(backbone):
@property @property
def inputs_attr(self): def inputs_attr(self):
returns = {"token_ids": [[-1, -1], 'int64'], ret = {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'], "position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'], "segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'], "input_mask": [[-1, -1, 1], 'float32'],
} }
if self._learning_strategy == 'pairwise' and self._phase=='train': if self._learning_strategy == 'pairwise' and self._phase=='train':
returns.update({"token_ids_neg": [[-1, -1], 'int64'], ret.update({"token_ids_neg": [[-1, -1], 'int64'],
"position_ids_neg": [[-1, -1], 'int64'], "position_ids_neg": [[-1, -1], 'int64'],
"segment_ids_neg": [[-1, -1], 'int64'], "segment_ids_neg": [[-1, -1], 'int64'],
"input_mask_neg": [[-1, -1, 1], 'float32'], "input_mask_neg": [[-1, -1, 1], 'float32'],
}) })
return returns return ret
@property @property
def outputs_attr(self): def outputs_attr(self):
returns = {"word_embedding": [[-1, -1, self._emb_size], 'float32'], ret = {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'], "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'], "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'], "sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']} "sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
if self._learning_strategy == 'pairwise' and self._phase == 'train': if self._learning_strategy == 'pairwise' and self._phase == 'train':
returns.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'], ret.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'],
"encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'], "encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding_neg": [[-1, self._emb_size], 'float32'], "sentence_embedding_neg": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']}) "sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']})
return returns return ret
def build(self, inputs, scope_name=""): def build(self, inputs, scope_name=""):
src_ids = inputs['token_ids'] src_ids = inputs['token_ids']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册