From 520c7625fd472be00e653379d8472da362b9eb9b Mon Sep 17 00:00:00 2001 From: Xiaoyao Xi <24541791+xixiaoyao@users.noreply.github.com> Date: Tue, 24 Dec 2019 20:19:35 +0800 Subject: [PATCH] Update ernie.py --- paddlepalm/backbone/ernie.py | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/paddlepalm/backbone/ernie.py b/paddlepalm/backbone/ernie.py index f3f5662..26a8aae 100644 --- a/paddlepalm/backbone/ernie.py +++ b/paddlepalm/backbone/ernie.py @@ -67,34 +67,34 @@ class Model(backbone): @property def inputs_attr(self): - returns = {"token_ids": [[-1, -1], 'int64'], - "position_ids": [[-1, -1], 'int64'], - "segment_ids": [[-1, -1], 'int64'], - "input_mask": [[-1, -1, 1], 'float32'], - "task_ids": [[-1, -1], 'int64'] - } + ret = {"token_ids": [[-1, -1], 'int64'], + "position_ids": [[-1, -1], 'int64'], + "segment_ids": [[-1, -1], 'int64'], + "input_mask": [[-1, -1, 1], 'float32'], + "task_ids": [[-1, -1], 'int64'] + } if self._learning_strategy == 'pairwise' and self._phase=='train': - returns.update({"token_ids_neg": [[-1, -1], 'int64'], - "position_ids_neg": [[-1, -1], 'int64'], - "segment_ids_neg": [[-1, -1], 'int64'], - "input_mask_neg": [[-1, -1, 1], 'float32'], - "task_ids_neg": [[-1, -1], 'int64'] - }) - return returns + ret.update({"token_ids_neg": [[-1, -1], 'int64'], + "position_ids_neg": [[-1, -1], 'int64'], + "segment_ids_neg": [[-1, -1], 'int64'], + "input_mask_neg": [[-1, -1, 1], 'float32'], + "task_ids_neg": [[-1, -1], 'int64'] + }) + return ret @property def outputs_attr(self): - returns = {"word_embedding": [[-1, -1, self._emb_size], 'float32'], - "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'], - "encoder_outputs": [[-1, -1, self._emb_size], 'float32'], - "sentence_embedding": [[-1, self._emb_size], 'float32'], - "sentence_pair_embedding": [[-1, self._emb_size], 'float32']} + ret = {"word_embedding": [[-1, -1, self._emb_size], 'float32'], + "embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'], + "encoder_outputs": [[-1, -1, self._emb_size], 'float32'], + "sentence_embedding": [[-1, self._emb_size], 'float32'], + "sentence_pair_embedding": [[-1, self._emb_size], 'float32']} if self._learning_strategy == 'pairwise' and self._phase == 'train': - returns.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'], - "encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'], - "sentence_embedding_neg": [[-1, self._emb_size], 'float32'], - "sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']}) - return returns + ret.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'], + "encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'], + "sentence_embedding_neg": [[-1, self._emb_size], 'float32'], + "sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']}) + return ret def build(self, inputs, scope_name=""): -- GitLab