@property
def inputs_attr(self):
    """Declare the input slots this backbone consumes.

    Returns a dict mapping input name -> [shape, dtype]; -1 marks a
    dynamic dimension. When the learning strategy is 'pairwise' and the
    phase is 'train', a mirrored ``*_neg`` slot is added for every base
    input, so the negative sample of each pair can be fed alongside the
    positive one.
    """
    spec = {
        "token_ids": [[-1, -1], 'int64'],
        "position_ids": [[-1, -1], 'int64'],
        "segment_ids": [[-1, -1], 'int64'],
        "input_mask": [[-1, -1, 1], 'float32'],
        "task_ids": [[-1, -1], 'int64'],
    }
    if self._learning_strategy != 'pairwise' or self._phase != 'train':
        return spec
    # Pairwise training: every input gets an identically-shaped negative twin.
    for name, attr in list(spec.items()):
        spec[name + "_neg"] = [list(attr[0]), attr[1]]
    return spec

@property
def outputs_attr(self):
    """Declare the output slots this backbone produces.

    Returns a dict mapping output name -> [shape, dtype]; -1 marks a
    dynamic dimension. Under pairwise training, ``*_neg`` twins are added
    for every output except ``embedding_table`` (the table is shared
    between the positive and negative branches).
    """
    spec = {
        "word_embedding": [[-1, -1, self._emb_size], 'float32'],
        "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
        "encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
        "sentence_embedding": [[-1, self._emb_size], 'float32'],
        "sentence_pair_embedding": [[-1, self._emb_size], 'float32'],
    }
    if self._learning_strategy != 'pairwise' or self._phase != 'train':
        return spec
    # Mirror the per-branch outputs only; embedding_table has no _neg twin.
    for name in ("word_embedding", "encoder_outputs",
                 "sentence_embedding", "sentence_pair_embedding"):
        shape, dtype = spec[name]
        spec[name + "_neg"] = [list(shape), dtype]
    return spec