Unverified commit de37fd75, authored by Xiaoyao Xi, committed by GitHub

Merge pull request #42 from wangxiao1021/AddPairwise

add pairwise L2R
@@ -44,6 +44,11 @@ class Model(backbone):
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
self._sent_emb_name = "sent_embedding"
self._phase = phase
if 'learning_strategy' in config:
self._learning_strategy = config['learning_strategy']
else:
self._learning_strategy = 'pointwise'
# Initialize all weights by truncated normal initializer, and all biases
# will be initialized by constant zero by default.
@@ -52,18 +57,32 @@ class Model(backbone):
@property
def inputs_attr(self):
ret = {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
}
if self._learning_strategy == 'pairwise' and self._phase=='train':
ret.update({"token_ids_neg": [[-1, -1], 'int64'],
"position_ids_neg": [[-1, -1], 'int64'],
"segment_ids_neg": [[-1, -1], 'int64'],
"input_mask_neg": [[-1, -1, 1], 'float32'],
})
return ret
@property
def outputs_attr(self):
ret = {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
if self._learning_strategy == 'pairwise' and self._phase == 'train':
ret.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'],
"encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding_neg": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']})
return ret
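Note for readers: the snippet below is an illustrative sketch only and is not part of this diff; it just restates how the new config flag is resolved and what it changes, using names taken from the code above.

# Illustrative sketch only (not part of this diff): resolving the new flag.
config = {'vocab_size': 30522, 'learning_strategy': 'pairwise'}  # example values assumed
learning_strategy = config['learning_strategy'] if 'learning_strategy' in config else 'pointwise'
# equivalently: config.get('learning_strategy', 'pointwise')
# When learning_strategy == 'pairwise' and phase == 'train', inputs_attr/outputs_attr
# additionally carry the *_neg entries (token_ids_neg, ..., sentence_pair_embedding_neg)
# describing the negative sample that is fed alongside each positive pair.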
def build(self, inputs, scope_name=""):
src_ids = inputs['token_ids']
@@ -72,83 +91,111 @@ class Model(backbone):
input_mask = inputs['input_mask']
self._emb_dtype = 'float32'
input_buffer = {}
output_buffer = {}
input_buffer['base'] = [src_ids, pos_ids, sent_ids, input_mask]
output_buffer['base'] = {}
if self._learning_strategy == 'pairwise' and self._phase =='train':
src_ids = inputs['token_ids_neg']
pos_ids = inputs['position_ids_neg']
sent_ids = inputs['segment_ids_neg']
input_mask = inputs['input_mask_neg']
input_buffer['neg'] = [src_ids, pos_ids, sent_ids, input_mask]
output_buffer['neg'] = {}
for key, (src_ids, pos_ids, sent_ids, input_mask) in input_buffer.items():
# padding id in vocabulary must be set to 0
emb_out = fluid.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
position_emb_out = fluid.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._sent_emb_name, initializer=self._param_initializer))
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name=scope_name+'pre_encoder')
self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True)
self_attn_mask = fluid.layers.scale(
x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
enc_out = encoder(
enc_input=emb_out,
attn_bias=n_head_self_attn_mask,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._emb_size * 4,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
hidden_act=self._hidden_act,
preprocess_cmd="",
postprocess_cmd="dan",
param_initializer=self._param_initializer,
name=scope_name+'encoder')
next_sent_feat = fluid.layers.slice(
input=enc_out, axes=[1], starts=[0], ends=[1])
next_sent_feat = fluid.layers.reshape(next_sent_feat, [-1, next_sent_feat.shape[-1]])
next_sent_feat = fluid.layers.fc(
input=next_sent_feat,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name=scope_name+"pooled_fc.w_0", initializer=self._param_initializer),
bias_attr=scope_name+"pooled_fc.b_0")
output_buffer[key]['word_embedding'] = emb_out
output_buffer[key]['encoder_outputs'] = enc_out
output_buffer[key]['sentence_embedding'] = next_sent_feat
output_buffer[key]['sentence_pair_embedding'] = next_sent_feat
ret = {}
ret['embedding_table'] = embedding_table
ret['word_embedding'] = output_buffer['base']['word_embedding']
ret['encoder_outputs'] = output_buffer['base']['encoder_outputs']
ret['sentence_embedding'] = output_buffer['base']['sentence_embedding']
ret['sentence_pair_embedding'] = output_buffer['base']['sentence_pair_embedding']
if self._learning_strategy == 'pairwise' and self._phase == 'train':
ret['word_embedding_neg'] = output_buffer['neg']['word_embedding']
ret['encoder_outputs_neg'] = output_buffer['neg']['encoder_outputs']
ret['sentence_embedding_neg'] = output_buffer['neg']['sentence_embedding']
ret['sentence_pair_embedding_neg'] = output_buffer['neg']['sentence_pair_embedding']
return ret
def postprocess(self, rt_outputs):
pass
......
@@ -40,6 +40,10 @@ class Model(backbone):
self._n_head = config['num_attention_heads']
self._voc_size = config['vocab_size']
self._max_position_seq_len = config['max_position_embeddings']
if 'learning_strategy' in config:
self._learning_strategy = config['learning_strategy']
else:
self._learning_strategy = 'pointwise'
if config['sent_type_vocab_size']:
self._sent_types = config['sent_type_vocab_size']
else:
@@ -56,25 +60,41 @@ class Model(backbone):
self._sent_emb_name = "sent_embedding"
self._task_emb_name = "task_embedding"
self._emb_dtype = "float32"
self._phase = phase
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
@property
def inputs_attr(self):
ret = {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64']
}
if self._learning_strategy == 'pairwise' and self._phase=='train':
ret.update({"token_ids_neg": [[-1, -1], 'int64'],
"position_ids_neg": [[-1, -1], 'int64'],
"segment_ids_neg": [[-1, -1], 'int64'],
"input_mask_neg": [[-1, -1, 1], 'float32'],
"task_ids_neg": [[-1, -1], 'int64']
})
return ret
@property
def outputs_attr(self):
ret = {"word_embedding": [[-1, -1, self._emb_size], 'float32'],
"embedding_table": [[-1, self._voc_size, self._emb_size], 'float32'],
"encoder_outputs": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding": [[-1, self._emb_size], 'float32']}
if self._learning_strategy == 'pairwise' and self._phase == 'train':
ret.update({"word_embedding_neg": [[-1, -1, self._emb_size], 'float32'],
"encoder_outputs_neg": [[-1, -1, self._emb_size], 'float32'],
"sentence_embedding_neg": [[-1, self._emb_size], 'float32'],
"sentence_pair_embedding_neg": [[-1, self._emb_size], 'float32']})
return ret
def build(self, inputs, scope_name=""):
@@ -84,92 +104,120 @@ class Model(backbone):
input_mask = inputs['input_mask']
task_ids = inputs['task_ids']
input_buffer = {}
output_buffer = {}
input_buffer['base'] = [src_ids, pos_ids, sent_ids, input_mask, task_ids]
output_buffer['base'] = {}
if self._learning_strategy == 'pairwise' and self._phase =='train':
src_ids = inputs['token_ids_neg']
pos_ids = inputs['position_ids_neg']
sent_ids = inputs['segment_ids_neg']
input_mask = inputs['input_mask_neg']
task_ids = inputs['task_ids_neg']
input_buffer['neg'] = [src_ids, pos_ids, sent_ids, input_mask, task_ids]
output_buffer['neg'] = {}
for key, (src_ids, pos_ids, sent_ids, input_mask, task_ids) in input_buffer.items():
# padding id in vocabulary must be set to 0
emb_out = fluid.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
# fluid.global_scope().find_var('backbone-word_embedding').get_tensor()
embedding_table = fluid.default_main_program().global_block().var(scope_name+self._word_emb_name)
position_emb_out = fluid.embedding(
input=pos_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._pos_emb_name, initializer=self._param_initializer))
sent_emb_out = fluid.embedding(
sent_ids,
size=[self._sent_types, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._sent_emb_name, initializer=self._param_initializer))
emb_out = emb_out + position_emb_out
emb_out = emb_out + sent_emb_out
task_emb_out = fluid.embedding(
task_ids,
size=[self._task_types, self._emb_size],
dtype=self._emb_dtype,
param_attr=fluid.ParamAttr(
name=scope_name+self._task_emb_name,
initializer=self._param_initializer))
emb_out = emb_out + task_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name=scope_name+'pre_encoder')
self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True)
self_attn_mask = fluid.layers.scale(
x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
enc_out = encoder(
enc_input=emb_out,
attn_bias=n_head_self_attn_mask,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._emb_size * 4,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
hidden_act=self._hidden_act,
preprocess_cmd="",
postprocess_cmd="dan",
param_initializer=self._param_initializer,
name=scope_name+'encoder')
next_sent_feat = fluid.layers.slice(
input=enc_out, axes=[1], starts=[0], ends=[1])
next_sent_feat = fluid.layers.reshape(next_sent_feat, [-1, next_sent_feat.shape[-1]])
next_sent_feat = fluid.layers.fc(
input=next_sent_feat,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name=scope_name+"pooled_fc.w_0", initializer=self._param_initializer),
bias_attr=scope_name+"pooled_fc.b_0")
output_buffer[key]['word_embedding'] = emb_out
output_buffer[key]['encoder_outputs'] = enc_out
output_buffer[key]['sentence_embedding'] = next_sent_feat
output_buffer[key]['sentence_pair_embedding'] = next_sent_feat
ret = {}
ret['embedding_table'] = embedding_table
ret['word_embedding'] = output_buffer['base']['word_embedding']
ret['encoder_outputs'] = output_buffer['base']['encoder_outputs']
ret['sentence_embedding'] = output_buffer['base']['sentence_embedding']
ret['sentence_pair_embedding'] = output_buffer['base']['sentence_pair_embedding']
if self._learning_strategy == 'pairwise' and self._phase == 'train':
ret['word_embedding_neg'] = output_buffer['neg']['word_embedding']
ret['encoder_outputs_neg'] = output_buffer['neg']['encoder_outputs']
ret['sentence_embedding_neg'] = output_buffer['neg']['sentence_embedding']
ret['sentence_pair_embedding_neg'] = output_buffer['neg']['sentence_pair_embedding']
return ret
def postprocess(self, rt_outputs):
pass
@@ -397,6 +397,7 @@ class Controller(object):
iterators = []
prefixes = []
mrs = []
for inst in instances:
iterators.append(inst.reader['train'].iterator())
prefixes.append(inst.name)
@@ -415,8 +416,9 @@ class Controller(object):
train_prog = fluid.default_main_program()
train_init_prog = fluid.default_startup_program()
bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_')
assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys())
pred_prog = fluid.Program()
pred_init_prog = fluid.Program()
@@ -432,18 +434,17 @@ class Controller(object):
task_inputs = {'backbone': bb_output_vars}
task_inputs_from_reader = _decode_inputs(net_inputs, inst.name)
task_inputs['reader'] = task_inputs_from_reader
scope = inst.task_reuse_scope + '/'
with fluid.unique_name.guard(scope):
output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope)
output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()}
old = len(task_output_vars) # for debug
task_output_vars.update(output_vars)
assert len(task_output_vars) - old == len(output_vars) # for debug
# prepare predict vars for saving inference model
if inst.is_target:
with fluid.program_guard(pred_prog, pred_init_prog):
cur_inputs = _decode_inputs(pred_net_inputs, inst.name)
inst.pred_input = cur_inputs
@@ -720,7 +721,6 @@ class Controller(object):
for feed in inst.reader['pred'].iterator():
feed = _encode_inputs(feed, inst.name, cand_set=mapper)
feed = {mapper[k]: v for k,v in feed.items()}
rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
inst.postprocess(rt_outputs, phase='pred')
......
@@ -85,9 +85,8 @@ class Reader(reader):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask',
'label_ids']
outputs = {n: i for n,i in zip(names, x)}
if not self._is_training:
del outputs['label_ids']
return outputs
......
@@ -22,15 +22,23 @@ class Reader(reader):
"""
Args:
phase: train, eval, pred
"""
self._is_training = phase == 'train'
if 'learning_strategy' in config:
self._learning_strategy = config['learning_strategy']
else:
self._learning_strategy = 'pointwise'
reader = ClassifyReader(config['vocab_path'],
max_seq_len=config['max_seq_len'],
do_lower_case=config.get('do_lower_case', True),
for_cn=config.get('for_cn', False),
random_seed=config.get('seed', None),
learning_strategy=self._learning_strategy,
phase=phase
)
self._reader = reader
self._dev_count = dev_count
@@ -59,22 +67,24 @@ class Reader(reader):
@property
def outputs_attr(self):
returns = {"token_ids": [[-1, -1], 'int64'],
"position_ids": [[-1, -1], 'int64'],
"segment_ids": [[-1, -1], 'int64'],
"input_mask": [[-1, -1, 1], 'float32'],
"task_ids": [[-1, -1], 'int64']
}
if self._is_training:
if self._learning_strategy == 'pointwise':
returns.update({"label_ids": [[-1], 'int64']})
elif self._learning_strategy == 'pairwise':
returns.update({"token_ids_neg": [[-1, -1], 'int64'],
"position_ids_neg": [[-1, -1], 'int64'],
"segment_ids_neg": [[-1, -1], 'int64'],
"input_mask_neg": [[-1, -1, 1], 'float32'],
"task_ids_neg": [[-1, -1], 'int64']
})
return returns
def load_data(self):
self._data_generator = self._reader.data_generator(self._input_file, self._batch_size, self._num_epochs, dev_count=self._dev_count, shuffle=self._shuffle, phase=self._phase)
@@ -82,17 +92,20 @@ class Reader(reader):
def iterator(self):
def list_to_dict(x):
names = ['token_ids', 'segment_ids', 'position_ids', 'task_ids', 'input_mask']
if self._is_training:
if self._learning_strategy == 'pairwise':
names += ['token_ids_neg', 'segment_ids_neg', 'position_ids_neg', 'task_ids_neg', 'input_mask_neg']
elif self._learning_strategy == 'pointwise':
names += ['label_ids']
outputs = {n: i for n,i in zip(names, x)}
return outputs
for batch in self._data_generator():
yield list_to_dict(batch)
@property
def num_examples(self):
return self._reader.get_num_examples(phase=self._phase)
......
@@ -57,8 +57,10 @@ class BaseReader(object):
do_lower_case=True,
in_tokens=False,
is_inference=False,
learning_strategy='pointwise',
random_seed=None,
tokenizer="FullTokenizer",
phase='train',
is_classify=True,
is_regression=False,
for_cn=True,
@@ -72,7 +74,9 @@ class BaseReader(object):
self.sep_id = self.vocab["[SEP]"]
self.mask_id = self.vocab["[MASK]"]
self.in_tokens = in_tokens
self.phase = phase
self.is_inference = is_inference
self.learning_strategy = learning_strategy
self.for_cn = for_cn
self.task_id = task_id
@@ -124,6 +128,7 @@ class BaseReader(object):
tokens_a.pop()
else:
tokens_b.pop()
def _convert_example_to_record(self, example, max_seq_length, tokenizer):
"""Converts a single `Example` into a single `Record`."""
@@ -131,26 +136,33 @@ class BaseReader(object):
text_a = tokenization.convert_to_unicode(example.text_a)
tokens_a = tokenizer.tokenize(text_a)
tokens_b = None
has_text_b = False
has_text_b_neg = False
if isinstance(example, dict):
has_text_b = "text_b" in example.keys()
has_text_b_neg = "text_b_neg" in example.keys()
else:
has_text_b = "text_b" in example._fields
has_text_b_neg = "text_b_neg" in example._fields
if has_text_b:
text_b = tokenization.convert_to_unicode(example.text_b)
tokens_b = tokenizer.tokenize(text_b)
# Modifies `tokens_a` and `tokens_b` in place so that the total
# length is less than the specified length.
# Account for [CLS], [SEP], [SEP] with "- 3"
self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
if has_text_b_neg and self.phase == 'train':
tokens_a_neg = tokenizer.tokenize(text_a)
text_b_neg = tokenization.convert_to_unicode(example.text_b_neg)
tokens_b_neg = tokenizer.tokenize(text_b_neg)
self._truncate_seq_pair(tokens_a_neg, tokens_b_neg, max_seq_length - 3)
else:
# Account for [CLS] and [SEP] with "- 2"
if len(tokens_a) > max_seq_length - 2:
tokens_a = tokens_a[0:(max_seq_length - 2)]
# The convention in BERT/ERNIE is:
# (a) For sequence pairs:
@@ -173,6 +185,7 @@ class BaseReader(object):
tokens = []
text_type_ids = []
tokens.append("[CLS]")
text_type_ids.append(0)
for token in tokens_a:
tokens.append(token)
@@ -190,6 +203,29 @@ class BaseReader(object):
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
if has_text_b_neg and self.phase == 'train':
tokens_neg = []
text_type_ids_neg = []
tokens_neg.append("[CLS]")
text_type_ids_neg.append(0)
for token in tokens_a_neg:
tokens_neg.append(token)
text_type_ids_neg.append(0)
tokens_neg.append("[SEP]")
text_type_ids_neg.append(0)
if tokens_b_neg:
for token in tokens_b_neg:
tokens_neg.append(token)
text_type_ids_neg.append(1)
tokens_neg.append("[SEP]")
text_type_ids_neg.append(1)
token_ids_neg = tokenizer.convert_tokens_to_ids(tokens_neg)
position_ids_neg = list(range(len(token_ids_neg)))
if self.is_inference:
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids'])
@@ -198,25 +234,38 @@ class BaseReader(object):
text_type_ids=text_type_ids,
position_ids=position_ids)
else:
qid = None
if "qid" in example._fields:
qid = example.qid
if self.learning_strategy == 'pairwise' and self.phase == 'train':
Record = namedtuple('Record',
['token_ids', 'text_type_ids', 'position_ids', 'token_ids_neg', 'text_type_ids_neg', 'position_ids_neg', 'qid'])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
token_ids_neg=token_ids_neg,
text_type_ids_neg=text_type_ids_neg,
position_ids_neg=position_ids_neg,
qid=qid)
else:
if self.label_map:
label_id = self.label_map[example.label]
else:
label_id = example.label
Record = namedtuple('Record', [
'token_ids', 'text_type_ids', 'position_ids', 'label_id', 'qid'
])
record = Record(
token_ids=token_ids,
text_type_ids=text_type_ids,
position_ids=position_ids,
label_id=label_id,
qid=qid)
return record
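For context, with the pairwise strategy each training example is expected to carry a negative candidate in a text_b_neg field next to text_a/text_b; the row below is a hypothetical illustration of such an input (column layout assumed, not taken from this diff).

# Hypothetical pairwise training row (tab-separated columns; illustrative only):
# text_a                 text_b                          text_b_neg
# "how to reset password"    "steps to reset your password"    "shipping policy for returns"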
def _prepare_batch_data(self, examples, batch_size, phase=None):
@@ -228,7 +277,7 @@ class BaseReader(object):
if phase == "train":
self.current_example = index
record = self._convert_example_to_record(example, self.max_seq_len,
self.tokenizer)
max_len = max(max_len, len(record.token_ids))
if self.in_tokens:
to_append = (len(batch_records) + 1) * max_len <= batch_size
@@ -285,6 +334,7 @@ class BaseReader(object):
if len(all_dev_batches) == dev_count:
for batch in all_dev_batches:
yield batch
all_dev_batches = []
def f():
for i in wrapper():
@@ -368,13 +418,6 @@ class MaskLMReader(BaseReader):
token_ids = tokenizer.convert_tokens_to_ids(tokens)
position_ids = list(range(len(token_ids)))
return [token_ids, text_type_ids, position_ids]
def batch_reader(self, examples, batch_size, in_tokens, phase):
@@ -457,7 +500,6 @@ class ClassifyReader(BaseReader):
index for index, h in enumerate(headers) if h != "label"
]
Example = namedtuple('Example', headers)
examples = []
for line in reader:
for index, text in enumerate(line):
@@ -474,15 +516,20 @@ class ClassifyReader(BaseReader):
batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records]
batch_position_ids = [record.position_ids for record in batch_records]
if self.phase=='train' and self.learning_strategy == 'pairwise':
batch_token_ids_neg = [record.token_ids_neg for record in batch_records]
batch_text_type_ids_neg = [record.text_type_ids_neg for record in batch_records]
batch_position_ids_neg = [record.position_ids_neg for record in batch_records]
if not self.is_inference:
if not self.learning_strategy == 'pairwise':
batch_labels = [record.label_id for record in batch_records]
if self.is_classify:
batch_labels = np.array(batch_labels).astype("int64").reshape(
[-1])
elif self.is_regression:
batch_labels = np.array(batch_labels).astype("float32").reshape(
[-1])
if batch_records[0].qid:
batch_qids = [record.qid for record in batch_records]
@@ -505,8 +552,23 @@ class ClassifyReader(BaseReader):
padded_token_ids, padded_text_type_ids, padded_position_ids,
padded_task_ids, input_mask
]
if self.phase=='train':
if self.learning_strategy == 'pairwise':
padded_token_ids_neg, input_mask_neg = pad_batch_data(
batch_token_ids_neg, pad_idx=self.pad_id, return_input_mask=True)
padded_text_type_ids_neg = pad_batch_data(
batch_text_type_ids_neg, pad_idx=self.pad_id)
padded_position_ids_neg = pad_batch_data(
batch_position_ids_neg, pad_idx=self.pad_id)
padded_task_ids_neg = np.ones_like(
padded_token_ids_neg, dtype="int64") * self.task_id
return_list += [padded_token_ids_neg, padded_text_type_ids_neg, \
padded_position_ids_neg, padded_task_ids_neg, input_mask_neg]
elif self.learning_strategy == 'pointwise':
return_list += [batch_labels]
return return_list
......
@@ -18,6 +18,19 @@ from paddle.fluid import layers
from paddlepalm.interface import task_paradigm
import numpy as np
import os
import json
def computeHingeLoss(pos, neg, margin):
loss_part1 = fluid.layers.elementwise_sub(
fluid.layers.fill_constant_batch_size_like(
input=pos, shape=[-1, 1], value=margin, dtype='float32'), pos)
loss_part2 = fluid.layers.elementwise_add(loss_part1, neg)
loss_part3 = fluid.layers.elementwise_max(
fluid.layers.fill_constant_batch_size_like(
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'), loss_part2)
return loss_part3
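In plain terms, computeHingeLoss above evaluates the pairwise hinge loss max(0, margin - pos + neg) element-wise over the batch, so the loss is zero once the positive score exceeds the negative score by at least the margin. A NumPy equivalent, shown only as an illustrative sketch and not part of this diff:

# Illustrative NumPy equivalent of the hinge loss above (not part of this diff).
import numpy as np
def hinge_loss(pos, neg, margin=0.5):
    # max(0, margin - pos + neg), computed element-wise
    return np.maximum(0.0, margin - pos + neg)
# e.g. hinge_loss(0.9, 0.2) == 0.0  (pair already separated by the margin)
#      hinge_loss(0.6, 0.5) == 0.4  (positive not ranked far enough above negative)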
class TaskParadigm(task_paradigm):
'''
@@ -26,12 +39,25 @@ class TaskParadigm(task_paradigm):
def __init__(self, config, phase, backbone_config=None):
self._is_training = phase == 'train'
self._hidden_size = backbone_config['hidden_size']
self._batch_size = config['batch_size']
self._num_classes = config.get('num_classes', 2)
if 'learning_strategy' in config:
self._learning_strategy = config['learning_strategy']
else:
self._learning_strategy = 'pointwise'
if 'margin' in config:
self._margin = config['margin']
else:
self._margin = 0.5
if 'initializer_range' in config:
self._param_initializer = config['initializer_range']
else:
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=backbone_config.get('initializer_range', 0.02))
if 'dropout_prob' in config:
self._dropout_prob = config['dropout_prob']
else:
@@ -39,15 +65,19 @@ class TaskParadigm(task_paradigm):
self._pred_output_path = config.get('pred_output_path', None)
self._preds = []
self._preds_logits = []
@property
def inputs_attrs(self):
reader = {}
bb = {"sentence_pair_embedding": [[-1, self._hidden_size], 'float32']}
if self._is_training:
if self._learning_strategy == 'pointwise':
reader["label_ids"] = [[-1], 'int64']
elif self._learning_strategy == 'pairwise':
bb["sentence_pair_embedding_neg"] = [[-1, self._hidden_size], 'float32']
return {'reader': reader, 'backbone': bb}
@property
@@ -55,52 +85,110 @@ class TaskParadigm(task_paradigm):
if self._is_training:
return {"loss": [[1], 'float32']}
else:
return {"logits": [[-1, 2], 'float32']} if self._learning_strategy=='paiwise':
return {"probs": [[-1, 1], 'float32']}
else:
return {"logits": [[-1, 2], 'float32'],
"probs": [[-1, 2], 'float32']}
def build(self, inputs, scope_name=""):
# inputs
cls_feats = inputs["backbone"]["sentence_pair_embedding"]
if self._is_training:
cls_feats = fluid.layers.dropout(
x=cls_feats,
dropout_prob=self._dropout_prob,
dropout_implementation="upscale_in_train")
if self._learning_strategy == 'pairwise':
cls_feats_neg = inputs["backbone"]["sentence_pair_embedding_neg"]
cls_feats_neg = fluid.layers.dropout(
x=cls_feats_neg,
dropout_prob=self._dropout_prob,
dropout_implementation="upscale_in_train")
elif self._learning_strategy == 'pointwise':
labels = inputs["reader"]["label_ids"]
# loss
# for pointwise
if self._learning_strategy == 'pointwise':
logits = fluid.layers.fc(
input=cls_feats,
size=self._num_classes,
param_attr=fluid.ParamAttr(
name=scope_name+"cls_out_w",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b",
initializer=fluid.initializer.Constant(0.)))
probs = fluid.layers.softmax(logits)
if self._is_training:
ce_loss = fluid.layers.cross_entropy(
input=probs, label=labels)
loss = fluid.layers.mean(x=ce_loss)
return {'loss': loss}
# for pred
else:
return {'logits': logits,
'probs': probs}
# for pairwise
elif self._learning_strategy == 'pairwise':
pos_score = fluid.layers.fc(
input=cls_feats,
size=1,
act = "sigmoid",
param_attr=fluid.ParamAttr(
name=scope_name+"cls_out_w_pr",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b_pr",
initializer=fluid.initializer.Constant(0.)))
pos_score = fluid.layers.reshape(x=pos_score, shape=[-1, 1], inplace=True)
if self._is_training:
neg_score = fluid.layers.fc(
input=cls_feats_neg,
size=1,
act = "sigmoid",
param_attr=fluid.ParamAttr(
name=scope_name+"cls_out_w_pr",
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(
name=scope_name+"cls_out_b_pr",
initializer=fluid.initializer.Constant(0.)))
neg_score = fluid.layers.reshape(x=neg_score, shape=[-1, 1], inplace=True)
loss = fluid.layers.mean(computeHingeLoss(pos_score, neg_score, self._margin))
return {'loss': loss}
# for pred
else:
return {'probs': pos_score}
def postprocess(self, rt_outputs):
if not self._is_training:
probs = []
logits = []
probs = rt_outputs['probs']
self._preds.extend(probs.tolist())
if self._learning_strategy == 'pointwise':
logits = rt_outputs['logits']
self._preds_logits.extend(logits.tolist())
def epoch_postprocess(self, post_inputs):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if not self._is_training:
if self._pred_output_path is None:
raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.')
with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer:
for i in range(len(self._preds)):
if self._learning_strategy == 'pointwise':
label = 0 if self._preds[i][0] > self._preds[i][1] else 1
result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]}
elif self._learning_strategy == 'pairwise':
label = 0 if self._preds[i][0] < 0.5 else 1
result = {'index': i, 'label': label, 'probs': self._preds[i][0]}
result = json.dumps(result)
writer.write(result+'\n')
print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json'))
\ No newline at end of file
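For reference, each line written to predictions.json by the code above is a single JSON object; the two lines below are illustrative only, with made-up values:

{"index": 0, "label": 1, "logits": [-0.3, 1.2], "probs": [0.18, 0.82]}   # pointwise strategy
{"index": 0, "label": 1, "probs": 0.82}                                  # pairwise strategy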
@@ -23,12 +23,14 @@ from paddle import fluid
def _check_and_adapt_shape_dtype(rt_val, attr, message=""):
if not isinstance(rt_val, np.ndarray):
rt_val = np.array(rt_val)
assert rt_val.dtype != np.dtype('O'), "yielded data is not a valid tensor(number of elements on some dimension may differ)."
if rt_val.dtype == np.dtype('float64'):
rt_val = rt_val.astype('float32')
shape, dtype = attr
assert rt_val.dtype == np.dtype(dtype), message+"yielded data type not consistent with attr settings. Expect: {}, receive: {}.".format(rt_val.dtype, np.dtype(dtype))
assert len(shape) == rt_val.ndim, message+"yielded data rank(ndim) not consistent with attr settings. Expect: {}, receive: {}.".format(len(shape), rt_val.ndim)
for rt, exp in zip(rt_val.shape, shape):
......