diff --git a/paddlepalm/backbone/ernie.py b/paddlepalm/backbone/ernie.py index 5619377f044ed90cfbfba2886326078b38321bb8..c8a29c9015de864ba188cfa9c673cc217eb55b59 100644 --- a/paddlepalm/backbone/ernie.py +++ b/paddlepalm/backbone/ernie.py @@ -31,7 +31,7 @@ class ERNIE(Backbone): def __init__(self, hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \ max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \ - hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, phase='train'): + hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise=False, use_task_emb=True, phase='train'): # self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变 @@ -54,6 +54,7 @@ class ERNIE(Backbone): self._task_emb_name = "task_embedding" self._emb_dtype = "float32" self._is_pairwise = is_pairwise + self._use_task_emb = use_task_emb self._phase=phase self._param_initializer = fluid.initializer.TruncatedNormal( scale=initializer_range) @@ -85,6 +86,10 @@ class ERNIE(Backbone): task_type_vocab_size = config['task_type_vocab_size'] else: task_type_vocab_size = config['type_vocab_size'] + if 'use_task_emb' in config: + use_task_emb = config['use_task_emb'] + else: + use_task_emb = True hidden_act = config['hidden_act'] hidden_dropout_prob = config['hidden_dropout_prob'] attention_probs_dropout_prob = config['attention_probs_dropout_prob'] @@ -96,7 +101,7 @@ class ERNIE(Backbone): return cls(hidden_size, num_hidden_layers, num_attention_heads, vocab_size, \ max_position_embeddings, sent_type_vocab_size, task_type_vocab_size, \ - hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, phase=phase) + hidden_act, hidden_dropout_prob, attention_probs_dropout_prob, initializer_range, is_pairwise, use_task_emb=use_task_emb, phase=phase) @property def inputs_attr(self): @@ -180,15 +185,16 @@ class ERNIE(Backbone): emb_out = emb_out + position_emb_out emb_out = emb_out + sent_emb_out - task_emb_out = fluid.embedding( - task_ids, - size=[self._task_types, self._emb_size], - dtype=self._emb_dtype, - param_attr=fluid.ParamAttr( - name=scope_name+self._task_emb_name, - initializer=self._param_initializer)) + if self._use_task_emb: + task_emb_out = fluid.embedding( + task_ids, + size=[self._task_types, self._emb_size], + dtype=self._emb_dtype, + param_attr=fluid.ParamAttr( + name=scope_name+self._task_emb_name, + initializer=self._param_initializer)) - emb_out = emb_out + task_emb_out + emb_out = emb_out + task_emb_out emb_out = pre_process_layer( emb_out, 'nd', self._prepostprocess_dropout, name=scope_name+'pre_encoder') diff --git a/paddlepalm/head/base_head.py b/paddlepalm/head/base_head.py index f71f491003b58a4a563babdd33931a1b069ce61a..c9a0ee572c5afb4fa18490126f27687441eb65a1 100644 --- a/paddlepalm/head/base_head.py +++ b/paddlepalm/head/base_head.py @@ -122,11 +122,11 @@ class Head(object): output_dir: 积累结果的保存路径。 """ if output_dir is not None: - for i in self._results_buffer: - print(i) - else: if not os.path.exists(output_dir): os.makedirs(output_dir) with open(os.path.join(output_dir, self._phase), 'w') as writer: for i in self._results_buffer: writer.write(json.dumps(i)+'\n') + else: + return self._results_buffer + diff --git a/paddlepalm/head/match.py b/paddlepalm/head/match.py index 4921f6cca5f7785e17a061646dcfd38431dfb5c7..5fb828bfa7a43aa7686cad275ee93ba778954e85 100644 --- a/paddlepalm/head/match.py +++ b/paddlepalm/head/match.py @@ -159,8 +159,6 @@ class Match(Head): else: return {'probs': pos_score} - - def batch_postprocess(self, rt_outputs): if not self._is_training: probs = [] @@ -170,6 +168,10 @@ class Match(Head): if self._learning_strategy == 'pointwise': logits = rt_outputs['logits'] self._preds_logits.extend(logits.tolist()) + + def reset(self): + self._preds_logits = [] + self._preds = [] def epoch_postprocess(self, post_inputs, output_dir=None): # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index de45bcca9bc81bd652a5ff6c65999d694130c381..600951f29cf58767bf07fd74c554d2d05ea65c8c 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -587,6 +587,9 @@ class Trainer(object): results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir) return results + def reset_buffer(self): + self._pred_head.reset() + def _check_phase(self, phase): assert phase in ['train', 'predict'], "Supported phase: train, predict,"